Skip to content

PyTorch 生成对抗网络

GAN简介

生成对抗网络(Generative Adversarial Networks, GAN)是一种深度学习架构,由Ian Goodfellow等人于2014年提出。GAN包含两个相互对抗的神经网络:生成器(Generator)从随机噪声生成样本,判别器(Discriminator)判断样本来自真实数据还是生成器。两者交替训练、相互博弈,最终使生成器能够生成逼真的数据。

python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

基础GAN实现

1. 生成器网络

python
class Generator(nn.Module):
    """Convolutional generator mapping a latent vector to an image.

    The latent vector is projected to a 128-channel feature map of side
    ``img_size // 4``, upsampled twice by a factor of 2, and mapped to
    ``img_channels`` channels by a final conv + Tanh (outputs in [-1, 1]).

    Args:
        latent_dim: size of the input noise vector.
        img_channels: number of channels of the generated image.
        img_size: height/width of the generated (square) image. Must be a
            positive multiple of 4 because the network upsamples by 4 in total.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 4 (otherwise
            the output would silently have the wrong size).
    """

    def __init__(self, latent_dim=100, img_channels=1, img_size=28):
        super(Generator, self).__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 4, got {img_size}"
            )
        self.latent_dim = latent_dim
        self.img_channels = img_channels
        self.img_size = img_size

        # Side length of the first feature map; two 2x upsamplings restore img_size.
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, 128 * self.init_size ** 2)
        )

        # NOTE: BatchNorm2d's second positional parameter is ``eps`` (not the
        # momentum). The value 0.8 is kept, now spelled out explicitly, so the
        # model behaves exactly as before.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to images of shape
        (batch, img_channels, img_size, img_size)."""
        out = self.l1(z)
        # Reshape the projected vector into a 128 x init_size x init_size map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)

# Quick shape check of the generator on a batch of 4 noise vectors.
latent_dim = 100
generator = Generator(img_size=28, img_channels=1, latent_dim=latent_dim)

z = torch.randn(4, latent_dim)  # standard-normal noise, one row per sample
fake_imgs = generator(z)
print(f"生成图像形状: {fake_imgs.shape}")

2. 判别器网络

python
class Discriminator(nn.Module):
    """Convolutional discriminator producing one real-probability per image.

    Four 3x3 stride-2 conv blocks downsample the image; a linear layer
    followed by a Sigmoid then produces a value in (0, 1) per image.

    Args:
        img_channels: channels of the input images.
        img_size: height/width of the (square) input images.
    """

    def __init__(self, img_channels=1, img_size=28):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(3x3, stride 2, pad 1) -> [BatchNorm] -> LeakyReLU(0.2) -> Dropout2d(0.25)."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                # BatchNorm2d's second positional parameter is eps; 0.8 is kept
                # as before (it is not the momentum).
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # A 3x3 / stride-2 / padding-1 conv maps a side of length n to
        # ceil(n / 2) == (n + 1) // 2, so follow the size through all four
        # blocks. The former ``img_size // 2 ** 4`` was only correct for
        # multiples of 16 (e.g. 28 -> 14 -> 7 -> 4 -> 2, not 1).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return probabilities of shape (batch, 1) for images (batch, C, H, W)."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)

# Quick shape check of the discriminator on ``fake_imgs``.
discriminator = Discriminator(img_size=28, img_channels=1)
validity = discriminator(fake_imgs)
print(f"判别器输出形状: {validity.shape}")

3. 基础GAN训练

python
class BasicGAN:
    """Minimal trainer for the convolutional Generator / Discriminator pair.

    Both networks use BCE loss and Adam (lr=``lr``, betas=(0.5, 0.999)); each
    batch performs one discriminator update followed by one generator update.

    Args:
        latent_dim: size of the noise vector fed to the generator.
        img_channels: channels of the images being modelled.
        img_size: height/width of the (square) images.
        lr: learning rate of both optimizers.
    """

    def __init__(self, latent_dim=100, img_channels=1, img_size=28, lr=0.0002):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Networks
        self.generator = Generator(latent_dim, img_channels, img_size).to(self.device)
        self.discriminator = Discriminator(img_channels, img_size).to(self.device)

        # Binary cross-entropy on the discriminator's probability output
        self.adversarial_loss = nn.BCELoss()

        # Optimizers
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

        self.latent_dim = latent_dim

    def train_discriminator(self, real_imgs):
        """Return the discriminator loss for one batch (caller runs backward/step).

        Real images are labelled 1 and freshly generated fakes 0; the two BCE
        terms are averaged.
        """
        batch_size = real_imgs.size(0)

        # Target labels (plain tensors never require grad by default)
        valid = torch.ones(batch_size, 1, device=self.device)
        fake = torch.zeros(batch_size, 1, device=self.device)

        real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)

        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        # detach(): the discriminator loss must not backpropagate into the generator.
        fake_imgs = self.generator(z).detach()
        fake_loss = self.adversarial_loss(self.discriminator(fake_imgs), fake)

        return (real_loss + fake_loss) / 2

    def train_generator(self, batch_size):
        """Return (generator loss, generated batch); caller runs backward/step.

        The generator is rewarded when the discriminator labels its samples as real.
        """
        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_imgs = self.generator(z)

        valid = torch.ones(batch_size, 1, device=self.device)
        g_loss = self.adversarial_loss(self.discriminator(fake_imgs), valid)

        return g_loss, fake_imgs

    def train_epoch(self, dataloader):
        """Train over one pass of ``dataloader`` (yielding ``(images, labels)``).

        Returns:
            (mean discriminator loss, mean generator loss) over the epoch, or
            (nan, nan) when the loader yields no batches.
        """
        d_losses = []
        g_losses = []

        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(self.device)
            batch_size = real_imgs.size(0)

            # Discriminator step
            self.optimizer_D.zero_grad()
            d_loss = self.train_discriminator(real_imgs)
            d_loss.backward()
            self.optimizer_D.step()

            # Generator step. Its backward pass also accumulates gradients in the
            # discriminator; those are cleared by optimizer_D.zero_grad() above
            # before the next discriminator step.
            self.optimizer_G.zero_grad()
            g_loss, _ = self.train_generator(batch_size)
            g_loss.backward()
            self.optimizer_G.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

            if i % 100 == 0:
                print(f'Batch {i}/{len(dataloader)}, D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')

        if not d_losses:
            # Avoid np.mean([]) (RuntimeWarning) for an empty loader.
            return float('nan'), float('nan')
        return np.mean(d_losses), np.mean(g_losses)

    def generate_samples(self, num_samples=16):
        """Generate ``num_samples`` images without tracking gradients.

        The generator is switched to eval mode for sampling and then restored
        to whatever mode (train/eval) it was in before the call.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(num_samples, self.latent_dim, device=self.device)
                fake_imgs = self.generator(z)
        finally:
            self.generator.train(was_training)
        return fake_imgs

# Usage example: train BasicGAN on MNIST.
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5, mapping ToTensor's
# [0, 1] pixels to [-1, 1] -- the range of the generator's final Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

mnist_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
dataloader = DataLoader(mnist_dataset, batch_size=64, shuffle=True)

gan = BasicGAN(latent_dim=100, img_channels=1, img_size=28)

num_epochs = 50
for epoch in range(num_epochs):
    d_loss, g_loss = gan.train_epoch(dataloader)
    print(f'Epoch {epoch+1}/{num_epochs}, D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')

    # Draw 16 samples every 10 epochs (visualisation omitted here).
    if (epoch + 1) % 10 == 0:
        samples = gan.generate_samples(16)

DCGAN实现

1. DCGAN生成器与判别器

python
class DCGANGenerator(nn.Module):
    """DCGAN generator built from stacked transposed convolutions.

    For an input of shape (N, latent_dim, 1, 1) the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, with channels shrinking from
    8 * feature_maps to feature_maps, and a Tanh output with ``img_channels``
    channels.

    Args:
        latent_dim: channels of the input noise tensor.
        img_channels: channels of the generated image.
        feature_maps: base channel width of the hidden layers.
    """

    def __init__(self, latent_dim=100, img_channels=3, feature_maps=64):
        super(DCGANGenerator, self).__init__()

        layers = []
        in_ch = latent_dim
        # Hidden stages use 8x, 4x, 2x, 1x feature_maps channels.
        for stage, mult in enumerate((8, 4, 2, 1)):
            out_ch = feature_maps * mult
            # The first stage projects 1x1 -> 4x4 (stride 1, no padding);
            # every later stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch

        # Final upsample to image channels, squashed into [-1, 1].
        layers += [
            nn.ConvTranspose2d(in_ch, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the noise tensor through the transposed-convolution stack."""
        generated = self.main(input)
        return generated

class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator: strided convolutions down to one probability per image.

    For a (N, img_channels, 64, 64) input the spatial size shrinks
    64 -> 32 -> 16 -> 8 -> 4 -> 1 while channels grow from feature_maps to
    8 * feature_maps; a Sigmoid maps the final score into (0, 1).

    Args:
        img_channels: channels of the input image.
        feature_maps: base channel width of the hidden layers.
    """

    def __init__(self, img_channels=3, feature_maps=64):
        super(DCGANDiscriminator, self).__init__()

        # First layer: no BatchNorm.
        layers = [
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages, each doubling the channel count.
        channels = feature_maps
        for _ in range(3):
            layers.extend([
                nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            channels *= 2
        # 4x4 -> 1x1 score, squashed into a probability.
        layers.extend([
            nn.Conv2d(channels, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a flat tensor of probabilities, one per image for 64x64 inputs."""
        scores = self.main(input).view(-1, 1)
        return scores.squeeze(1)

WGAN实现

1. WGAN损失函数与权重裁剪

python
class WGAN:
    """Weight-clipped Wasserstein GAN trainer.

    The critic maximises ``mean(D(real)) - mean(D(fake))`` (implemented as
    minimising its negative) and its weights are clamped to
    ``[-clip_value, clip_value]`` after every update; the generator minimises
    ``-mean(D(fake))``. Critic outputs are averaged directly, so a critic that
    ends in a Sigmoid would confine scores to (0, 1).

    Args:
        generator: module mapping noise of shape (batch, latent_dim, 1, 1) to samples.
        discriminator: critic module whose outputs are used as scores.
        lr: RMSprop learning rate for both networks.
        clip_value: bound for critic weight clipping.
        latent_dim: size of the noise vector (was hard-coded to 100; that is
            still the default).
    """

    def __init__(self, generator, discriminator, lr=0.00005, clip_value=0.01, latent_dim=100):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)

        # WGAN uses the RMSprop optimizer
        self.optimizer_G = optim.RMSprop(self.generator.parameters(), lr=lr)
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(), lr=lr)

        self.clip_value = clip_value
        self.latent_dim = latent_dim

    def train_discriminator(self, real_imgs, n_critic=5):
        """Run ``n_critic`` critic updates and return their mean loss.

        NOTE: the same ``real_imgs`` batch is reused for every critic step;
        fresh fake samples are drawn each time.
        """
        d_losses = []

        for _ in range(n_critic):
            self.optimizer_D.zero_grad()

            batch_size = real_imgs.size(0)

            real_validity = self.discriminator(real_imgs)

            z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
            # detach(): critic updates must not backpropagate into the generator.
            fake_imgs = self.generator(z).detach()
            fake_validity = self.discriminator(fake_imgs)

            # Negative Wasserstein estimate
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)

            d_loss.backward()
            self.optimizer_D.step()

            # Weight clipping, done in-place outside autograd.
            with torch.no_grad():
                for p in self.discriminator.parameters():
                    p.clamp_(-self.clip_value, self.clip_value)

            d_losses.append(d_loss.item())

        return np.mean(d_losses)

    def train_generator(self, batch_size):
        """Run one generator update; return (loss as float, generated batch)."""
        self.optimizer_G.zero_grad()

        z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
        fake_imgs = self.generator(z)

        # The generator tries to raise the critic's score on its samples.
        fake_validity = self.discriminator(fake_imgs)
        g_loss = -torch.mean(fake_validity)

        g_loss.backward()
        self.optimizer_G.step()

        return g_loss.item(), fake_imgs

2. WGAN-GP实现

python
class WGAN_GP:
    """Wasserstein GAN trainer with gradient penalty (WGAN-GP).

    Critic loss: ``-mean(D(real)) + mean(D(fake)) + lambda_gp * GP`` where GP
    is ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)`` over random interpolates
    ``x_hat`` between real and fake samples. Generator loss: ``-mean(D(fake))``.

    Args:
        generator: module mapping noise of shape (batch, latent_dim, 1, 1) to samples.
        discriminator: critic module whose outputs are used as scores.
        lr: Adam learning rate (betas=(0.0, 0.9)) for both networks.
        lambda_gp: weight of the gradient-penalty term.
        latent_dim: size of the noise vector (was hard-coded to 100; that is
            still the default).
    """

    def __init__(self, generator, discriminator, lr=0.0001, lambda_gp=10, latent_dim=100):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.generator = generator.to(self.device)
        self.discriminator = discriminator.to(self.device)

        # WGAN-GP uses Adam
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=lr, betas=(0.0, 0.9))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.0, 0.9))

        self.lambda_gp = lambda_gp
        self.latent_dim = latent_dim

    def gradient_penalty(self, real_imgs, fake_imgs):
        """Return the (differentiable) gradient penalty as a scalar tensor.

        Works for inputs of any rank with the batch on dim 0 (e.g. images
        (B, C, H, W) or flat vectors (B, D)).
        """
        batch_size = real_imgs.size(0)

        # One mixing coefficient per sample, broadcast over all other dims.
        alpha_shape = (batch_size,) + (1,) * (real_imgs.dim() - 1)
        alpha = torch.rand(alpha_shape, device=self.device)
        # detach(): x_hat must be a fresh leaf so requires_grad_ is allowed even
        # when real/fake tensors carry their own autograd history.
        interpolates = (alpha * real_imgs + (1 - alpha) * fake_imgs).detach()
        interpolates.requires_grad_(True)

        d_interpolates = self.discriminator(interpolates)

        # create_graph=True so the penalty itself can be backpropagated into
        # the critic's parameters.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradients = gradients.view(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def train_discriminator(self, real_imgs, n_critic=5):
        """Run ``n_critic`` critic updates and return their mean loss.

        NOTE: the same ``real_imgs`` batch is reused for every critic step;
        fresh fake samples are drawn each time.
        """
        d_losses = []

        for _ in range(n_critic):
            self.optimizer_D.zero_grad()

            batch_size = real_imgs.size(0)

            real_validity = self.discriminator(real_imgs)

            z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
            fake_imgs = self.generator(z).detach()
            fake_validity = self.discriminator(fake_imgs)

            gp = self.gradient_penalty(real_imgs, fake_imgs)

            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + self.lambda_gp * gp

            d_loss.backward()
            self.optimizer_D.step()

            d_losses.append(d_loss.item())

        return np.mean(d_losses)

    def train_generator(self, batch_size):
        """Run one generator update; return (loss as float, generated batch)."""
        self.optimizer_G.zero_grad()

        z = torch.randn(batch_size, self.latent_dim, 1, 1, device=self.device)
        fake_imgs = self.generator(z)

        # The generator tries to raise the critic's score on its samples.
        g_loss = -torch.mean(self.discriminator(fake_imgs))

        g_loss.backward()
        self.optimizer_G.step()

        return g_loss.item(), fake_imgs

条件GAN (cGAN)

1. 条件生成器与判别器

python
class ConditionalGenerator(nn.Module):
    """Fully-connected conditional generator (cGAN).

    A learned embedding of the class label is concatenated with the noise
    vector and passed through an MLP (128 -> 256 -> 512 -> 1024) whose Tanh
    output is reshaped to (batch, img_channels, img_size, img_size).

    Args:
        latent_dim: size of the noise vector.
        num_classes: number of class labels (also the embedding size).
        img_channels: channels of the generated image.
        img_size: height/width of the generated (square) image.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_channels=1, img_size=28):
        super(ConditionalGenerator, self).__init__()

        # One learned num_classes-dimensional vector per label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        widths = [latent_dim + num_classes, 128, 256, 512, 1024]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx:  # the first hidden layer is left un-normalised
                # Second positional BatchNorm1d parameter is eps (0.8, as before).
                layers.append(nn.BatchNorm1d(fan_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.extend([
            nn.Linear(widths[-1], img_channels * img_size * img_size),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

        self.img_size = img_size
        self.img_channels = img_channels

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair."""
        # Condition on the label by prepending its embedding to the noise.
        conditioned = torch.cat((self.label_emb(labels), noise), -1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), self.img_channels, self.img_size, self.img_size)

class ConditionalDiscriminator(nn.Module):
    """Convolutional conditional discriminator (cGAN).

    The class label is embedded as an extra img_size x img_size channel,
    concatenated with the image, and the result is downsampled by four
    3x3 stride-2 conv blocks; a linear layer + Sigmoid gives one probability
    per image.

    Args:
        num_classes: number of class labels.
        img_channels: channels of the input images.
        img_size: height/width of the (square) input images.
    """

    def __init__(self, num_classes=10, img_channels=1, img_size=28):
        super(ConditionalDiscriminator, self).__init__()

        # Each label maps to a full img_size x img_size plane.
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(3x3, stride 2, pad 1) -> [BatchNorm] -> LeakyReLU(0.2) -> Dropout2d(0.25)."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                # BatchNorm2d's second positional parameter is eps (0.8, as before).
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_channels + 1, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # A 3x3 / stride-2 / padding-1 conv maps side n to ceil(n / 2); follow
        # the size through the four blocks. ``img_size // 2 ** 4`` was wrong for
        # sizes that are not multiples of 16 (28 -> 2, not 1).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()
        )

        self.img_size = img_size

    def forward(self, img, labels):
        """Return probabilities of shape (batch, 1) for (image, label) pairs."""
        # Turn each label into a 1-channel image plane.
        label_embedding = self.label_embedding(labels)
        label_embedding = label_embedding.view(labels.shape[0], 1, self.img_size, self.img_size)

        # Stack the label plane onto the image channels.
        d_in = torch.cat((img, label_embedding), 1)

        out = self.model(d_in)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)

StyleGAN基础实现

1. 样式生成器

python
class StyleGenerator(nn.Module):
    """Simplified StyleGAN-like generator.

    A mapping MLP turns the latent ``z`` into a style vector ``w``, which a
    synthesis network (``SynthesisNetwork``) turns into an image.

    Args:
        latent_dim: size of the latent vector ``z``.
        style_dim: size of the style vector ``w``.
        img_channels: channels of the generated image.
    """

    def __init__(self, latent_dim=512, style_dim=512, img_channels=3):
        super(StyleGenerator, self).__init__()

        # Mapping network: four Linear -> LeakyReLU(0.2) layers, z -> w.
        mapping_layers = []
        in_dim = latent_dim
        for _ in range(4):
            mapping_layers += [nn.Linear(in_dim, style_dim), nn.LeakyReLU(0.2)]
            in_dim = style_dim
        self.mapping = nn.Sequential(*mapping_layers)

        # Synthesis network: w (+ optional noise) -> image.
        self.synthesis = SynthesisNetwork(style_dim, img_channels)

    def forward(self, z, noise=None):
        """Map ``z`` to style space, then synthesise the image."""
        return self.synthesis(self.mapping(z), noise)

class SynthesisNetwork(nn.Module):
    """Synthesis network: learned constant -> style blocks -> RGB.

    A learned 512 x 4 x 4 constant is tiled over the batch and passed through
    five StyleBlocks, each followed by a 2x bilinear upsample
    (4 -> 8 -> 16 -> 32 -> 64 -> 128), then a 1x1 conv and Tanh.

    Args:
        style_dim: size of the style vector ``w``.
        img_channels: channels of the output image.
    """

    def __init__(self, style_dim=512, img_channels=3):
        super(SynthesisNetwork, self).__init__()

        # Learned starting tensor shared by all samples.
        self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4))

        # (in_channels, out_channels) of each style block, in order.
        channel_plan = [(512, 512), (512, 512), (512, 256), (256, 128), (128, 64)]
        self.style_blocks = nn.ModuleList(
            [StyleBlock(c_in, c_out, style_dim) for c_in, c_out in channel_plan]
        )

        # 1x1 projection from the last 64 feature channels to image channels.
        self.to_rgb = nn.Conv2d(64, img_channels, 1)

    def forward(self, w, noise=None):
        """Synthesise a batch of images from style vectors ``w``."""
        # Tile the constant input across the batch.
        x = self.const_input.repeat(w.size(0), 1, 1, 1)

        for block in self.style_blocks:
            x = F.interpolate(block(x, w, noise), scale_factor=2, mode='bilinear', align_corners=False)

        return torch.tanh(self.to_rgb(x))

class StyleBlock(nn.Module):
    """Two style-modulated 3x3 convolutions with optional noise injection.

    Before each convolution the input is scaled channel-wise by a linear
    projection of the style vector ``w``; after each convolution a learned
    scalar times ``noise`` is added (when given), followed by LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        style_dim: size of the style vector ``w``.
    """

    def __init__(self, in_channels, out_channels, style_dim):
        super(StyleBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Per-channel scales applied to the inputs of conv1 / conv2.
        self.style1 = nn.Linear(style_dim, in_channels)
        self.style2 = nn.Linear(style_dim, out_channels)

        # Learned noise strengths; initialised to 0, so noise has no effect at first.
        self.noise1 = nn.Parameter(torch.zeros(1))
        self.noise2 = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _match_noise(noise, x):
        """Resize a 4-D ``noise`` map to the spatial size of ``x`` (nearest neighbour).

        SynthesisNetwork hands the same ``noise`` to blocks at every
        resolution, so a fixed-size noise map could not broadcast against all
        of them. Non-4-D noise, or noise that already matches, is returned as is.
        """
        target = tuple(x.shape[-2:])
        if noise.dim() == 4 and tuple(noise.shape[-2:]) != target:
            noise = F.interpolate(noise, size=target, mode='nearest')
        return noise

    def forward(self, x, w, noise=None):
        """Apply both modulated convolutions; the spatial size of ``x`` is kept."""
        # 3x3 convs with padding=1 keep the spatial size, so one resize serves both.
        if noise is not None:
            noise = self._match_noise(noise, x)

        # First convolution with style modulation
        style1 = self.style1(w).unsqueeze(2).unsqueeze(3)
        x = self.conv1(x * style1)
        if noise is not None:
            x = x + self.noise1 * noise
        x = F.leaky_relu(x, 0.2)

        # Second convolution with style modulation
        style2 = self.style2(w).unsqueeze(2).unsqueeze(3)
        x = self.conv2(x * style2)
        if noise is not None:
            x = x + self.noise2 * noise
        x = F.leaky_relu(x, 0.2)

        return x

训练技巧和优化

1. 渐进式训练

python
class ProgressiveGAN:
    """Bookkeeping for progressive growing: resolution schedule and fade-in factor.

    ``generator`` and ``discriminator`` must provide an ``add_layer()`` method,
    which is called once every time the resolution doubles (4 up to 256).
    The adversarial update inside ``train_step`` is left as a stub.
    """

    def __init__(self, generator, discriminator):
        self.generator = generator
        self.discriminator = discriminator
        self.current_resolution = 4
        self.max_resolution = 256
        # Fade-in factor in [0, 1]: reset to 0 on each growth step and ramped
        # towards 1 by update_fade_in (how it blends layers is up to the networks).
        self.fade_in_alpha = 0.0

    def grow_network(self):
        """Double the resolution (up to max_resolution) and add a layer to both networks."""
        if self.current_resolution < self.max_resolution:
            self.current_resolution *= 2
            self.fade_in_alpha = 0.0

            self.generator.add_layer()
            self.discriminator.add_layer()

    def update_fade_in(self, step, fade_in_steps):
        """Set ``fade_in_alpha`` to ``step / fade_in_steps`` clamped to [0, 1].

        A non-positive ``fade_in_steps`` means "no fade-in": alpha becomes 1.0
        immediately (this previously raised ZeroDivisionError for 0).
        """
        if fade_in_steps <= 0:
            self.fade_in_alpha = 1.0
            return
        self.fade_in_alpha = min(1.0, max(0.0, step / fade_in_steps))

    def train_step(self, real_imgs, step, fade_in_steps):
        """Progressive training step (stub).

        Updates the fade-in factor and resizes ``real_imgs`` (presumably
        (N, C, H, W) -- confirm with the caller) to the current resolution;
        the GAN update itself is not implemented here.
        """
        self.update_fade_in(step, fade_in_steps)

        # Resize the real images to the current training resolution.
        if real_imgs.size(-1) != self.current_resolution:
            real_imgs = F.interpolate(
                real_imgs, size=self.current_resolution,
                mode='bilinear', align_corners=False
            )

        # Regular GAN training step goes here (omitted in this tutorial).
        # ...

2. 谱归一化

python
from torch.nn.utils import spectral_norm

class SpectralNormDiscriminator(nn.Module):
    """DCGAN-style discriminator with spectral normalisation on every conv.

    Four stride-2 convs (channels feature_maps -> 8 * feature_maps) are
    followed by a 4x4 conv to a single unbounded score (no Sigmoid).

    Args:
        img_channels: channels of the input image.
        feature_maps: base channel width of the hidden layers.
    """

    def __init__(self, img_channels=3, feature_maps=64):
        super(SpectralNormDiscriminator, self).__init__()

        widths = [img_channels, feature_maps, feature_maps * 2, feature_maps * 4, feature_maps * 8]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, 4, 2, 1)))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final 4x4 conv to one score per image (for 64x64 inputs).
        layers.append(spectral_norm(nn.Conv2d(widths[-1], 1, 4, 1, 0)))
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the flattened critic scores."""
        flat = self.main(input).view(-1, 1)
        return flat.squeeze(1)

3. 自注意力机制

python
class SelfAttention(nn.Module):
    """Self-attention over spatial positions with a learned residual gate (SAGAN-style).

    Output is ``gamma * attention(x) + x``; ``gamma`` starts at 0, so the
    module is initially the identity.

    Args:
        in_channels: channels of the input feature map.
        reduction: divisor for the query/key channel count (default 8, as
            before). The result is clamped to at least 1 so inputs with fewer
            than ``reduction`` channels still get a usable projection.
    """

    def __init__(self, in_channels, reduction=8):
        super(SelfAttention, self).__init__()
        self.in_channels = in_channels

        # Previously in_channels // 8 == 0 for in_channels < 8 (empty projection).
        qk_channels = max(1, in_channels // reduction)
        self.query = nn.Conv2d(in_channels, qk_channels, 1)
        self.key = nn.Conv2d(in_channels, qk_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, C, H, W); output has the same shape."""
        batch_size, channels, height, width = x.size()
        n = width * height

        # Queries (B, N, C'), keys (B, C', N), values (B, C, N) with N = H * W.
        proj_query = self.query(x).view(batch_size, -1, n).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, n)
        proj_value = self.value(x).view(batch_size, -1, n)

        # attention[b, i, j]: softmax weight of position j for query position i.
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)

        # Each output position is the attention-weighted sum of value vectors.
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)

        # Residual connection gated by gamma
        return self.gamma * out + x

# 在生成器中使用自注意力
class SAGenerator(nn.Module):
    """DCGAN-style generator with a self-attention layer after the second stage.

    For a (N, latent_dim, 1, 1) input the spatial size grows
    4 -> 8 (self-attention here, 256 channels) -> 16 -> 32 -> 64, ending in a
    Tanh image with ``img_channels`` channels.

    Args:
        latent_dim: channels of the input noise tensor.
        img_channels: channels of the generated image.
    """

    def __init__(self, latent_dim=100, img_channels=3):
        super(SAGenerator, self).__init__()

        def up_block(c_in, c_out, stride, padding):
            """ConvTranspose2d(kernel 4) -> BatchNorm2d -> ReLU."""
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        self.main = nn.Sequential(
            *up_block(latent_dim, 512, 1, 0),
            *up_block(512, 256, 2, 1),
            SelfAttention(256),  # self-attention on the 256-channel features
            *up_block(256, 128, 2, 1),
            *up_block(128, 64, 2, 1),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, input):
        """Run the noise tensor through the generator stack."""
        generated = self.main(input)
        return generated

评估指标

1. FID分数

python
import scipy.linalg
from torchvision.models import inception_v3

class FIDCalculator:
    """Frechet Inception Distance between a real and a generated image set.

    Features are the inputs of Inception-v3's final ``fc`` layer (the head is
    replaced by an identity; 2048-d for torchvision's model), which is what
    FID is defined on -- not the 1000-way class logits.

    Args:
        device: torch device on which Inception runs.
    """

    def __init__(self, device):
        self.device = device
        self.inception = inception_v3(pretrained=True, transform_input=False).to(device)
        # Drop the classification head so forward() returns pooled features.
        self.inception.fc = nn.Identity()
        self.inception.eval()

    def get_activations(self, images, batch_size=64):
        """Return Inception features of ``images`` as a NumPy array (N, D).

        Args:
            images: (N, C, H, W) tensor; 1-channel images are repeated to 3
                channels and everything is resized to 299x299.
                NOTE(review): with transform_input=False the pixels are fed as
                given, so they must already be scaled the way the model
                expects -- confirm for your data.
            batch_size: images per Inception forward pass.

        Raises:
            ValueError: if ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("get_activations() needs at least one image")

        chunks = []
        with torch.no_grad():
            for start in range(0, len(images), batch_size):
                batch = images[start:start + batch_size].to(self.device)
                # Inception takes 3-channel input; replicate grayscale images.
                if batch.size(1) == 1:
                    batch = batch.repeat(1, 3, 1, 1)
                # Resize to 299x299
                if tuple(batch.shape[-2:]) != (299, 299):
                    batch = F.interpolate(batch, size=299, mode='bilinear', align_corners=False)
                chunks.append(self.inception(batch).cpu())

        return torch.cat(chunks).numpy()

    def calculate_fid(self, real_images, fake_images):
        """Return the FID between ``real_images`` and ``fake_images``.

        FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2)).

        Raises:
            ValueError: if either set has fewer than 2 images (a covariance
                cannot be estimated from one sample).
        """
        if len(real_images) < 2 or len(fake_images) < 2:
            raise ValueError("FID needs at least 2 real and 2 generated images")

        real_features = self.get_activations(real_images)
        fake_features = self.get_activations(fake_images)

        # Feature means and covariances
        mu_real = np.mean(real_features, axis=0)
        sigma_real = np.cov(real_features, rowvar=False)

        mu_fake = np.mean(fake_features, axis=0)
        sigma_fake = np.cov(fake_features, rowvar=False)

        diff = mu_real - mu_fake
        covmean, _ = scipy.linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)

        # Numerical error can leave a tiny imaginary component; discard it.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)

        return fid

# Usage example.
# ``real_images`` / ``fake_images`` are placeholders: supply (N, C, H, W)
# image tensors (at least 2 each) before running this snippet.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # was undefined here
fid_calculator = FIDCalculator(device)
fid_score = fid_calculator.calculate_fid(real_images, fake_images)
print(f"FID分数: {fid_score:.2f}")

2. IS分数

python
def inception_score(images, batch_size=32, splits=10):
    """Compute the Inception Score of ``images``.

    IS = exp(E_x[KL(p(y|x) || p(y))]), evaluated separately on ``splits``
    equal chunks of the images.

    Args:
        images: (N, C, H, W) tensor of generated images; 1-channel images are
            repeated to 3 channels and resized to 299x299.
            NOTE(review): with transform_input=False the pixels are fed as
            given -- confirm they are scaled the way the model expects.
        batch_size: images per Inception forward pass.
        splits: number of chunks the score is computed over.

    Returns:
        (mean, std) of the per-chunk scores.

    Raises:
        ValueError: if ``splits`` < 1 or there are fewer images than splits
            (empty chunks previously produced NaN).
    """
    from scipy.special import rel_entr

    if splits < 1 or len(images) < splits:
        raise ValueError(
            f"need splits >= 1 and at least `splits` images, got {len(images)} images, splits={splits}"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception = inception_v3(pretrained=True, transform_input=False).to(device)
    inception.eval()

    def get_pred(x):
        """Class probabilities (batch, 1000) for a batch of images."""
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        if x.size(-1) != 299:
            x = F.interpolate(x, size=299, mode='bilinear', align_corners=False)
        logits = inception(x)
        return F.softmax(logits, dim=1).cpu().numpy()

    preds = np.zeros((len(images), 1000))

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size].to(device)
            preds[i:i + batch_size] = get_pred(batch)

    scores = []
    for k in range(splits):
        part = preds[k * len(preds) // splits:(k + 1) * len(preds) // splits]
        marginal = np.mean(part, axis=0, keepdims=True)
        # KL(p(y|x) || p(y)) per image; rel_entr defines 0 * log(0 / q) = 0,
        # avoiding NaN when a probability underflows to exactly zero.
        kl = np.mean(np.sum(rel_entr(part, marginal), axis=1))
        scores.append(np.exp(kl))

    return np.mean(scores), np.std(scores)

# Usage example (``fake_images`` must be an image tensor defined beforehand).
is_mean, is_std = inception_score(fake_images)
print(f"IS分数: {is_mean:.2f} ± {is_std:.2f}")

总结

生成对抗网络是深度学习中的重要技术,本章介绍了以下内容:

  1. 基础GAN:生成器、判别器的基本结构和训练过程
  2. DCGAN:深度卷积GAN的实现
  3. WGAN:Wasserstein GAN及其改进版本WGAN-GP(梯度惩罚)
  4. 条件GAN:以类别标签等条件信息控制生成结果的模型
  5. StyleGAN:基于样式调制的生成网络的简化实现
  6. 训练技巧:渐进式训练、谱归一化、自注意力等
  7. 评估指标:FID、IS等生成质量评估方法

掌握GAN技术将帮助你在图像生成、数据增强等领域开展创新工作!

本站内容仅供学习和研究使用。