
Generative Adversarial Networks (GANs)

Generative Adversarial Networks (GANs) are a class of deep learning models introduced by Ian Goodfellow in 2014. A GAN learns a data distribution through the adversarial training of two neural networks and can generate high-quality synthetic data.

GAN Fundamentals

What Is a GAN?

A GAN consists of two networks:

  • Generator: learns to produce fake data that resembles the real data
  • Discriminator: learns to distinguish real data from generated data

The two networks compete against each other during training, ideally converging to a Nash equilibrium.
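
In the original formulation, the generator G and discriminator D play the following minimax game over the value function V(D, G):

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$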

python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Example of a basic GAN architecture
def create_generator(latent_dim, output_shape):
    """
    Create the generator network.
    """
    model = keras.Sequential([
        keras.layers.Dense(128, activation='relu', input_shape=(latent_dim,)),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(256, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Dense(np.prod(output_shape), activation='tanh'),
        keras.layers.Reshape(output_shape)
    ])
    return model

def create_discriminator(input_shape):
    """
    Create the discriminator network.
    """
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=input_shape),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.3),
        keras.layers.Dense(256, activation='relu'),
        keras.layers.Dropout(0.3),
        keras.layers.Dense(1, activation='sigmoid')
    ])
    return model

# Example: build a simple GAN
latent_dim = 100
img_shape = (28, 28, 1)

generator = create_generator(latent_dim, img_shape)
discriminator = create_discriminator(img_shape)

print("Generator architecture:")
generator.summary()
print("\nDiscriminator architecture:")
discriminator.summary()
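
As a quick sanity check, the untrained networks defined above can be run on random noise to confirm that the shapes line up (a minimal sketch; the outputs are meaningless until training):

python
# Sanity check: push random noise through the untrained networks
noise = tf.random.normal((16, latent_dim))            # 16 latent vectors
fake_images = generator(noise, training=False)        # -> (16, 28, 28, 1), values in [-1, 1] from tanh
scores = discriminator(fake_images, training=False)   # -> (16, 1) probabilities from the sigmoid output
print(fake_images.shape, scores.shape)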

DCGAN Implementation

python
def create_dcgan_generator(latent_dim):
    """
    Create the DCGAN generator.
    """
    model = keras.Sequential([
        # Input layer: project the latent vector
        keras.layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        
        # Reshape to 7x7x256
        keras.layers.Reshape((7, 7, 256)),
        
        # 7x7x256 -> 7x7x128 (strides of 1, no upsampling here)
        keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), 
                                   padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        
        # Upsample to 14x14x64
        keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), 
                                   padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),
        
        # Upsample to 28x28x1
        keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), 
                                   padding='same', use_bias=False, 
                                   activation='tanh')
    ])
    
    return model

def create_dcgan_discriminator():
    """
    Create the DCGAN discriminator.
    """
    model = keras.Sequential([
        # 28x28x1 -> 14x14x64
        keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                           input_shape=[28, 28, 1]),
        keras.layers.LeakyReLU(),
        keras.layers.Dropout(0.3),
        
        # 14x14x64 -> 7x7x128
        keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        keras.layers.LeakyReLU(),
        keras.layers.Dropout(0.3),
        
        # Flatten and output a single logit
        keras.layers.Flatten(),
        keras.layers.Dense(1)
    ])
    
    return model

# Build the DCGAN models
dcgan_generator = create_dcgan_generator(100)
dcgan_discriminator = create_dcgan_discriminator()

print("DCGAN generator:")
dcgan_generator.summary()
print("\nDCGAN discriminator:")
dcgan_discriminator.summary()

Loss Functions

python
# Binary cross-entropy loss (from_logits=True matches the DCGAN discriminator, which outputs raw logits)
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """
    Discriminator loss: real samples labeled 1, fake samples labeled 0.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    """
    Generator loss: reward fooling the discriminator into predicting "real".
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# WGAN loss functions
def wasserstein_discriminator_loss(real_output, fake_output):
    """
    Wasserstein critic (discriminator) loss.
    """
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

def wasserstein_generator_loss(fake_output):
    """
    Wasserstein generator loss.
    """
    return -tf.reduce_mean(fake_output)

# LSGAN loss functions
def lsgan_discriminator_loss(real_output, fake_output):
    """
    LSGAN (least-squares) discriminator loss.
    """
    real_loss = tf.reduce_mean(tf.square(real_output - 1))
    fake_loss = tf.reduce_mean(tf.square(fake_output))
    return 0.5 * (real_loss + fake_loss)

def lsgan_generator_loss(fake_output):
    """
    LSGAN (least-squares) generator loss.
    """
    return 0.5 * tf.reduce_mean(tf.square(fake_output - 1))

Training Loop

python
class GAN:
    def __init__(self, generator, discriminator, latent_dim):
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        
        # Optimizers
        self.generator_optimizer = keras.optimizers.Adam(1e-4)
        self.discriminator_optimizer = keras.optimizers.Adam(1e-4)
        
        # Loss tracking metrics
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
    
    @tf.function
    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        
        # Sample random latent noise
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Train the discriminator
        with tf.GradientTape() as disc_tape:
            # Generate fake images
            fake_images = self.generator(random_latent_vectors, training=True)

            # Discriminator predictions on real and fake images
            real_predictions = self.discriminator(real_images, training=True)
            fake_predictions = self.discriminator(fake_images, training=True)

            # Discriminator loss
            disc_loss = discriminator_loss(real_predictions, fake_predictions)
        
        # Compute discriminator gradients and apply the update
        disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            zip(disc_gradients, self.discriminator.trainable_variables)
        )

        # Train the generator
        with tf.GradientTape() as gen_tape:
            # Generate fake images
            fake_images = self.generator(random_latent_vectors, training=True)

            # Discriminator predictions on the fake images
            fake_predictions = self.discriminator(fake_images, training=True)

            # Generator loss
            gen_loss = generator_loss(fake_predictions)

        # Compute generator gradients and apply the update
        gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gen_gradients, self.generator.trainable_variables)
        )

        # Update loss trackers
        self.gen_loss_tracker.update_state(gen_loss)
        self.disc_loss_tracker.update_state(disc_loss)
        
        return {
            "generator_loss": self.gen_loss_tracker.result(),
            "discriminator_loss": self.disc_loss_tracker.result(),
        }

# Create the GAN instance
gan = GAN(dcgan_generator, dcgan_discriminator, latent_dim=100)

Data Preparation and Training

python
def prepare_mnist_data():
    """
    Load and preprocess the MNIST data.
    """
    (x_train, _), (_, _) = keras.datasets.mnist.load_data()
    
    # Normalize to [-1, 1]
    x_train = x_train.astype('float32')
    x_train = (x_train - 127.5) / 127.5
    
    # Add a channel dimension
    x_train = np.expand_dims(x_train, axis=-1)
    
    return x_train

def create_dataset(data, batch_size=256, buffer_size=60000):
    """
    Create the training dataset.
    """
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.shuffle(buffer_size).batch(batch_size)
    return dataset

# Prepare the data
train_images = prepare_mnist_data()
train_dataset = create_dataset(train_images)

print(f"Training data shape: {train_images.shape}")

def train_gan(gan, dataset, epochs=100):
    """
    Train the GAN model.
    """
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        
        # Train for one epoch
        for batch in dataset:
            losses = gan.train_step(batch)
        
        # Generate sample images every 10 epochs
        if (epoch + 1) % 10 == 0:
            generate_and_save_images(gan.generator, epoch + 1)
            print(f"Generator Loss: {losses['generator_loss']:.4f}")
            print(f"Discriminator Loss: {losses['discriminator_loss']:.4f}")

def generate_and_save_images(generator, epoch, num_examples=16):
    """
    Generate and visualize sample images.
    """
    # Sample random noise
    noise = tf.random.normal([num_examples, 100])
    
    # Generate images
    generated_images = generator(noise, training=False)
    
    # Visualize
    fig = plt.figure(figsize=(4, 4))
    
    for i in range(generated_images.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(generated_images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    
    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.show()

# Start training (example; real training needs many more epochs)
# train_gan(gan, train_dataset, epochs=50)

Conditional GAN (cGAN)

python
def create_conditional_generator(latent_dim, num_classes, img_shape):
    """
    Create the conditional generator.
    """
    # Noise input
    noise_input = keras.layers.Input(shape=(latent_dim,))
    
    # Label input, embedded into a dense vector
    label_input = keras.layers.Input(shape=(1,))
    label_embedding = keras.layers.Embedding(num_classes, 50)(label_input)
    label_embedding = keras.layers.Flatten()(label_embedding)
    
    # Concatenate noise and label embedding
    merged_input = keras.layers.Concatenate()([noise_input, label_embedding])
    
    # Generator network
    x = keras.layers.Dense(7*7*256, use_bias=False)(merged_input)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Reshape((7, 7, 256))(x)
    
    x = keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), 
                                   padding='same', use_bias=False)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)
    
    x = keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), 
                                   padding='same', use_bias=False)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU()(x)
    
    output = keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), 
                                        padding='same', use_bias=False, 
                                        activation='tanh')(x)
    
    model = keras.Model([noise_input, label_input], output)
    return model

def create_conditional_discriminator(img_shape, num_classes):
    """
    Create the conditional discriminator.
    """
    # Image input
    img_input = keras.layers.Input(shape=img_shape)
    
    # Label input, embedded and reshaped to the image size
    label_input = keras.layers.Input(shape=(1,))
    label_embedding = keras.layers.Embedding(num_classes, 50)(label_input)
    label_embedding = keras.layers.Flatten()(label_embedding)
    label_embedding = keras.layers.Dense(np.prod(img_shape))(label_embedding)
    label_embedding = keras.layers.Reshape(img_shape)(label_embedding)
    
    # Concatenate image and label map along the channel axis
    merged_input = keras.layers.Concatenate()([img_input, label_embedding])
    
    # Discriminator network
    x = keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(merged_input)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Dropout(0.3)(x)
    
    x = keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Dropout(0.3)(x)
    
    x = keras.layers.Flatten()(x)
    output = keras.layers.Dense(1)(x)
    
    model = keras.Model([img_input, label_input], output)
    return model

# Build the conditional GAN models
cond_generator = create_conditional_generator(100, 10, (28, 28, 1))
cond_discriminator = create_conditional_discriminator((28, 28, 1), 10)
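
The conditional models take two inputs: a latent vector and an integer class label. Below is a minimal usage sketch (labels 0-9 are assumed to correspond to the MNIST digit classes):

python
# Sample one (untrained) image per class from the conditional generator
num_classes = 10
noise = tf.random.normal((num_classes, 100))
labels = tf.reshape(tf.range(num_classes, dtype=tf.float32), (-1, 1))  # shape (10, 1), one label per sample

conditional_images = cond_generator([noise, labels], training=False)
print(conditional_images.shape)  # expected: (10, 28, 28, 1)

# The conditional discriminator scores an (image, label) pair and returns raw logits
logits = cond_discriminator([conditional_images, labels], training=False)
print(logits.shape)  # (10, 1)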

CycleGAN Implementation

python
def create_cyclegan_generator():
    """
    Create the CycleGAN generator (ResNet architecture).
    """
    def residual_block(x, filters):
        shortcut = x
        
        x = keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU()(x)
        
        x = keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        
        x = keras.layers.Add()([shortcut, x])
        x = keras.layers.ReLU()(x)
        
        return x
    
    inputs = keras.layers.Input(shape=(256, 256, 3))
    
    # Encoder
    x = keras.layers.Conv2D(64, 7, padding='same')(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    # Downsampling
    x = keras.layers.Conv2D(128, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    x = keras.layers.Conv2D(256, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    # Residual blocks
    for _ in range(9):
        x = residual_block(x, 256)
    
    # Upsampling
    x = keras.layers.Conv2DTranspose(128, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    x = keras.layers.Conv2DTranspose(64, 3, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)
    
    outputs = keras.layers.Conv2D(3, 7, padding='same', activation='tanh')(x)
    
    model = keras.Model(inputs, outputs)
    return model

def create_cyclegan_discriminator():
    """
    Create the CycleGAN discriminator (PatchGAN).
    """
    inputs = keras.layers.Input(shape=(256, 256, 3))
    
    x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(256, 4, strides=2, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(512, 4, strides=1, padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    outputs = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)
    
    model = keras.Model(inputs, outputs)
    return model

# CycleGAN loss functions
def cycle_consistency_loss(real_image, cycled_image, lambda_cycle=10.0):
    """
    Cycle-consistency loss.
    """
    loss = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return lambda_cycle * loss

def identity_loss(real_image, same_image, lambda_identity=0.5):
    """
    Identity loss.
    """
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return lambda_identity * loss
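
These terms combine with an adversarial term into the full generator objective. A minimal sketch for the G: X→Y direction, assuming an LSGAN-style (least-squares) adversarial loss as in the original CycleGAN setup:

python
def cyclegan_generator_total_loss(disc_fake_y, real_x, cycled_x, real_y, same_y):
    """
    Total loss for generator G (X -> Y):
      - adversarial term on D_Y's score for the translated image G(x)
      - cycle-consistency term F(G(x)) ~ x
      - identity term G(y) ~ y
    """
    adversarial = tf.reduce_mean(tf.square(disc_fake_y - 1))   # LSGAN-style adversarial loss
    cycle = cycle_consistency_loss(real_x, cycled_x)           # lambda_cycle applied inside
    identity = identity_loss(real_y, same_y)                   # lambda_identity applied inside
    return adversarial + cycle + identity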

StyleGAN Basics

python
def create_stylegan_mapping_network(latent_dim=512, num_layers=8):
    """
    Create the StyleGAN mapping network.
    """
    model = keras.Sequential()
    
    for _ in range(num_layers):
        model.add(keras.layers.Dense(latent_dim, activation='relu'))
    
    model.build(input_shape=(None, latent_dim))
    return model

def adaptive_instance_normalization(content_features, style_features):
    """
    Adaptive instance normalization (AdaIN).
    """
    # Mean and standard deviation of the content features
    content_mean = tf.reduce_mean(content_features, axis=[1, 2], keepdims=True)
    content_std = tf.math.reduce_std(content_features, axis=[1, 2], keepdims=True)
    
    # Mean and standard deviation of the style features
    style_mean = tf.reduce_mean(style_features, axis=[1, 2], keepdims=True)
    style_std = tf.math.reduce_std(style_features, axis=[1, 2], keepdims=True)
    
    # Normalize the content features
    normalized_content = (content_features - content_mean) / (content_std + 1e-8)
    
    # Apply the style statistics
    stylized_features = normalized_content * style_std + style_mean
    
    return stylized_features

class StyleGANGenerator(keras.Model):
    """
    Simplified StyleGAN generator.
    """
    def __init__(self, latent_dim=512, num_layers=6):
        super(StyleGANGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        
        # Mapping network
        self.mapping_network = create_stylegan_mapping_network(latent_dim)
        
        # Learned constant input
        self.constant_input = self.add_weight(
            shape=(1, 4, 4, 512),
            initializer='random_normal',
            trainable=True,
            name='constant_input'
        )
        
        # Generator layers
        self.conv_layers = []
        self.adain_layers = []
        
        channels = [512, 512, 256, 128, 64, 32]
        for i in range(num_layers):
            self.conv_layers.append(
                keras.layers.Conv2D(channels[i], 3, padding='same', activation='relu')
            )
            self.adain_layers.append(
                keras.layers.Dense(channels[i] * 2)  # produces the per-channel style mean and std
            )
    
    def call(self, latent_codes):
        batch_size = tf.shape(latent_codes)[0]
        
        # Map the latent code z to the intermediate latent w
        w = self.mapping_network(latent_codes)
        
        # Start from the learned constant input
        x = tf.tile(self.constant_input, [batch_size, 1, 1, 1])
        
        # Generate layer by layer
        for i in range(self.num_layers):
            # Convolution
            x = self.conv_layers[i](x)
            
            # AdaIN
            style_params = self.adain_layers[i](w)
            style_mean, style_std = tf.split(style_params, 2, axis=-1)
            
            # Apply AdaIN using the style statistics
            x_mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
            x_std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
            x = (x - x_mean) / (x_std + 1e-8)
            x = x * tf.expand_dims(tf.expand_dims(style_std, 1), 1) + \
                tf.expand_dims(tf.expand_dims(style_mean, 1), 1)
            
            # Upsample (except after the last layer)
            if i < self.num_layers - 1:
                x = keras.layers.UpSampling2D()(x)
        
        return x
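
A quick usage sketch of the simplified generator above. Note that this toy version stops at a 128x128 feature map with 32 channels; a full StyleGAN would add a final "to RGB" convolution and per-layer noise injection:

python
# Instantiate the simplified generator and run a forward pass
stylegan_gen = StyleGANGenerator(latent_dim=512, num_layers=6)
z = tf.random.normal((4, 512))   # batch of 4 latent codes
features = stylegan_gen(z)
print(features.shape)            # (4, 128, 128, 32)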

Evaluation Metrics

python
import scipy.linalg  # provides sqrtm, the matrix square root used in the FID computation

def calculate_fid_score(real_images, generated_images, model_name='inception_v3'):
    """
    Compute the FID score (Fréchet Inception Distance).
    """
    # Load a pretrained Inception model as the feature extractor
    inception_model = keras.applications.InceptionV3(
        include_top=False, 
        pooling='avg',
        input_shape=(299, 299, 3)
    )
    
    def preprocess_images(images):
        # Resize images to 299x299
        images = tf.image.resize(images, [299, 299])
        # Apply Inception-specific preprocessing
        images = keras.applications.inception_v3.preprocess_input(images)
        return images
    
    # Preprocess the images
    real_images = preprocess_images(real_images)
    generated_images = preprocess_images(generated_images)
    
    # Extract features
    real_features = inception_model.predict(real_images)
    generated_features = inception_model.predict(generated_images)
    
    # Mean and covariance of the feature distributions
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)
    
    mu_gen = np.mean(generated_features, axis=0)
    sigma_gen = np.cov(generated_features, rowvar=False)
    
    # Compute the FID
    diff = mu_real - mu_gen
    covmean = scipy.linalg.sqrtm(sigma_real.dot(sigma_gen))
    
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
    return fid

def calculate_inception_score(generated_images, num_splits=10):
    """
    Compute the Inception Score (IS).
    """
    # Load the Inception classifier (with its top classification layer)
    inception_model = keras.applications.InceptionV3(
        include_top=True,
        input_shape=(299, 299, 3)
    )
    
    # Preprocess the images
    images = tf.image.resize(generated_images, [299, 299])
    images = keras.applications.inception_v3.preprocess_input(images)
    
    # Get class probability predictions
    predictions = inception_model.predict(images)
    
    # Compute the IS over several splits
    scores = []
    for i in range(num_splits):
        part = predictions[i * len(predictions) // num_splits:
                         (i + 1) * len(predictions) // num_splits]
        
        # KL divergence between p(y|x) and the marginal p(y)
        py = np.mean(part, axis=0)
        kl_div = part * (np.log(part + 1e-8) - np.log(py + 1e-8))
        kl_div = np.mean(np.sum(kl_div, axis=1))
        scores.append(np.exp(kl_div))
    
    return np.mean(scores), np.std(scores)
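
Both metrics feed images through Inception, which expects 3-channel inputs, so grayscale MNIST samples have to be tiled to RGB and rescaled to [0, 255] first. A minimal illustrative sketch (a stable FID estimate requires thousands of samples, not 64):

python
def to_rgb(images):
    """Rescale [-1, 1] grayscale images to [0, 255] and tile them to 3 channels."""
    images = tf.convert_to_tensor(images, dtype=tf.float32) * 127.5 + 127.5
    return tf.tile(images, [1, 1, 1, 3])

real_batch = train_images[:64]                      # (64, 28, 28, 1), already in [-1, 1]
fake_batch = dcgan_generator(tf.random.normal((64, 100)), training=False)

fid = calculate_fid_score(to_rgb(real_batch), to_rgb(fake_batch))
print(f"FID: {fid:.2f}")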

Training Tips and Best Practices

python
# 1. Progressive growing
class ProgressiveGAN:
    def __init__(self):
        self.current_resolution = 4
        self.max_resolution = 256
        
    def grow_network(self):
        """
        Gradually increase the working resolution.
        """
        if self.current_resolution < self.max_resolution:
            self.current_resolution *= 2
            # Add new layers to the generator and discriminator here
            
    def fade_in_new_layers(self, alpha):
        """
        Fade in the newly added layers.
        """
        # Blend the old and new outputs using alpha
        pass

# 2. Spectral normalization
# Note: tf.keras does not register a 'SpectralNormalization' custom object by default;
# the wrapper below uses tensorflow_addons (newer Keras releases also ship
# keras.layers.SpectralNormalization).
def spectral_normalization(layer):
    """
    Wrap a layer with spectral normalization (requires tensorflow_addons).
    """
    import tensorflow_addons as tfa
    return tfa.layers.SpectralNormalization(layer)

# 3. Self-attention (as in SAGAN)
class SelfAttention(keras.layers.Layer):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.query_conv = keras.layers.Conv2D(channels // 8, 1)
        self.key_conv = keras.layers.Conv2D(channels // 8, 1)
        self.value_conv = keras.layers.Conv2D(channels, 1)
        self.gamma = self.add_weight(shape=(), initializer='zeros', trainable=True)
        
    def call(self, x):
        batch_size, height, width, channels = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        
        # Compute the query, key, and value projections
        query = self.query_conv(x)
        key = self.key_conv(x)
        value = self.value_conv(x)
        
        # Flatten the spatial dimensions into matrices
        query = tf.reshape(query, [batch_size, -1, channels // 8])
        key = tf.reshape(key, [batch_size, -1, channels // 8])
        value = tf.reshape(value, [batch_size, -1, channels])
        
        # Attention weights and weighted sum
        attention = tf.nn.softmax(tf.matmul(query, key, transpose_b=True))
        out = tf.matmul(attention, value)
        out = tf.reshape(out, [batch_size, height, width, channels])
        
        # Residual connection scaled by a learned gamma
        out = self.gamma * out + x
        return out

# 4. Gradient penalty (WGAN-GP)
def gradient_penalty(discriminator, real_images, fake_images, batch_size):
    """
    Compute the gradient penalty on interpolated samples.
    """
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
    interpolated = alpha * real_images + (1 - alpha) * fake_images
    
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        pred = discriminator(interpolated, training=True)
    
    gradients = tape.gradient(pred, interpolated)
    gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean((gradients_norm - 1.) ** 2)
    
    return gradient_penalty
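
The gradient penalty plugs directly into the Wasserstein critic loss defined earlier. A minimal sketch of the combined WGAN-GP critic objective (lambda_gp = 10 is the value recommended in the WGAN-GP paper):

python
def wgan_gp_discriminator_loss(discriminator, real_images, fake_images,
                               real_output, fake_output, lambda_gp=10.0):
    """
    WGAN-GP critic loss: Wasserstein loss plus the weighted gradient penalty.
    """
    batch_size = tf.shape(real_images)[0]
    w_loss = wasserstein_discriminator_loss(real_output, fake_output)
    gp = gradient_penalty(discriminator, real_images, fake_images, batch_size)
    return w_loss + lambda_gp * gp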

Practical Application Examples

python
def image_to_image_translation():
    """
    Image-to-image translation example (Pix2Pix-style U-Net).
    """
    # Build a Pix2Pix-style generator
    def create_pix2pix_generator():
        # U-Net architecture
        inputs = keras.layers.Input(shape=(256, 256, 3))
        
        # Encoder (downsampling stack)
        down_stack = [
            keras.layers.Conv2D(64, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(128, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(256, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2D(512, 4, strides=2, padding='same', use_bias=False),
        ]
        
        # Decoder (upsampling stack)
        up_stack = [
            keras.layers.Conv2DTranspose(256, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same', use_bias=False),
            keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', use_bias=False),
        ]
        
        x = inputs
        
        # Downsampling, keeping skip connections
        skips = []
        for down in down_stack:
            x = down(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.LeakyReLU()(x)
            skips.append(x)
        
        skips = reversed(skips[:-1])
        
        # Upsampling and concatenating skip connections
        for up, skip in zip(up_stack, skips):
            x = up(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Concatenate()([x, skip])
        
        # Final layer back to a 3-channel output
        last = keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')
        x = last(x)
        
        return keras.Model(inputs=inputs, outputs=x)
    
    return create_pix2pix_generator()

def super_resolution_gan():
    """
    Super-resolution GAN (SRGAN) example.
    """
    def create_srgan_generator():
        def residual_block(x):
            shortcut = x
            x = keras.layers.Conv2D(64, 3, padding='same')(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.PReLU(shared_axes=[1, 2])(x)  # shared_axes needed for variable-size inputs
            x = keras.layers.Conv2D(64, 3, padding='same')(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.Add()([shortcut, x])
            return x
        
        inputs = keras.layers.Input(shape=(None, None, 3))
        
        # Initial convolution
        x = keras.layers.Conv2D(64, 9, padding='same')(inputs)
        x = keras.layers.PReLU(shared_axes=[1, 2])(x)
        
        # Residual blocks
        for _ in range(16):
            x = residual_block(x)
        
        # Upsampling via sub-pixel convolution (pixel shuffle)
        x = keras.layers.Conv2D(256, 3, padding='same')(x)
        x = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x = keras.layers.PReLU(shared_axes=[1, 2])(x)
        
        x = keras.layers.Conv2D(256, 3, padding='same')(x)
        x = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x = keras.layers.PReLU(shared_axes=[1, 2])(x)
        
        # Output layer
        outputs = keras.layers.Conv2D(3, 9, padding='same', activation='tanh')(x)
        
        return keras.Model(inputs, outputs)
    
    return create_srgan_generator()
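
For completeness, Pix2Pix-style generators are usually trained with an adversarial term plus a weighted L1 reconstruction term against the paired target image. A minimal sketch (lambda_l1 = 100 follows the Pix2Pix paper; disc_generated_output is assumed to be the PatchGAN discriminator's logits for the generated image):

python
def pix2pix_generator_loss(disc_generated_output, generated_image, target_image, lambda_l1=100.0):
    """
    Pix2Pix generator loss: adversarial (BCE on PatchGAN logits) + weighted L1 to the target.
    """
    bce = keras.losses.BinaryCrossentropy(from_logits=True)
    adversarial_loss = bce(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target_image - generated_image))
    return adversarial_loss + lambda_l1 * l1_loss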

Summary

GANs are among the most innovative techniques in deep learning and opened a new era for generative modeling. From the basic GAN to advanced variants such as StyleGAN and BigGAN, GANs are widely used in image generation, style transfer, data augmentation, and related areas.

Key takeaways:

  1. Adversarial training: the generator and discriminator play a minimax game
  2. Loss function design: different losses (standard, WGAN, LSGAN) suit different scenarios
  3. Training stability: hyperparameters and training strategy require careful tuning
  4. Evaluation metrics: FID, IS, and similar metrics help assess generation quality

In the next chapter we will work through hands-on projects, applying the theory covered here to concrete tasks.
