# 生成对抗网络 (GAN)
生成对抗网络(Generative Adversarial Networks, GAN)是由Ian Goodfellow等人于2014年提出的一种深度学习模型。GAN通过两个神经网络的对抗训练来学习数据分布,从而能够生成高质量的合成数据。
## GAN基础概念
### 什么是GAN?
GAN由两个网络组成:
- 生成器(Generator):学习生成与真实数据相似的假数据
- 判别器(Discriminator):学习区分真实数据和生成数据
这两个网络在训练过程中相互对抗:判别器不断提高辨别真假的能力,生成器则不断提高“以假乱真”的能力。理想情况下,训练最终收敛到纳什均衡,此时判别器无法区分真实样本与生成样本。
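用公式表示,这一对抗过程对应原论文中的极小极大博弈,其中 $G$ 为生成器、$D$ 为判别器、$z$ 为服从先验分布 $p_z$ 的噪声向量:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$

下面先用全连接网络搭建一个最基本的GAN结构: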
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os
import scipy.linalg  # 供后文FID计算中的矩阵平方根(sqrtm)使用
# 设置随机种子
tf.random.set_seed(42)
np.random.seed(42)
# 基本的GAN架构示例
def create_generator(latent_dim, output_shape):
"""
创建生成器网络
"""
model = keras.Sequential([
keras.layers.Dense(128, activation='relu', input_shape=(latent_dim,)),
keras.layers.BatchNormalization(),
keras.layers.Dense(256, activation='relu'),
keras.layers.BatchNormalization(),
keras.layers.Dense(512, activation='relu'),
keras.layers.BatchNormalization(),
keras.layers.Dense(np.prod(output_shape), activation='tanh'),
keras.layers.Reshape(output_shape)
])
return model
def create_discriminator(input_shape):
"""
创建判别器网络
"""
model = keras.Sequential([
keras.layers.Flatten(input_shape=input_shape),
keras.layers.Dense(512, activation='relu'),
keras.layers.Dropout(0.3),
keras.layers.Dense(256, activation='relu'),
keras.layers.Dropout(0.3),
keras.layers.Dense(1, activation='sigmoid')
])
return model
# 示例:创建简单的GAN
latent_dim = 100
img_shape = (28, 28, 1)
generator = create_generator(latent_dim, img_shape)
discriminator = create_discriminator(img_shape)
print("生成器结构:")
generator.summary()
print("\n判别器结构:")
discriminator.summary()
```
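在继续之前,可以用随机噪声做一次前向传播来检查输入输出形状(示意代码:模型尚未训练,生成结果只是噪声):

```python
# 形状检查:生成器输出一张28x28x1的"图像",判别器对其输出一个概率
noise = tf.random.normal((1, latent_dim))
fake_image = generator(noise, training=False)
score = discriminator(fake_image, training=False)
print(fake_image.shape)  # 期望 (1, 28, 28, 1)
print(score.shape)       # 期望 (1, 1)
```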
## DCGAN实现

```python
def create_dcgan_generator(latent_dim):
"""
创建DCGAN生成器
"""
model = keras.Sequential([
# 输入层
keras.layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
keras.layers.BatchNormalization(),
keras.layers.LeakyReLU(),
# 重塑为7x7x256
keras.layers.Reshape((7, 7, 256)),
        # 7x7x256 -> 7x7x128(strides=1,空间尺寸保持7x7)
keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
padding='same', use_bias=False),
keras.layers.BatchNormalization(),
keras.layers.LeakyReLU(),
# 上采样到14x14x64
keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
padding='same', use_bias=False),
keras.layers.BatchNormalization(),
keras.layers.LeakyReLU(),
# 上采样到28x28x1
keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
padding='same', use_bias=False,
activation='tanh')
])
return model
def create_dcgan_discriminator():
"""
创建DCGAN判别器
"""
model = keras.Sequential([
# 28x28x1 -> 14x14x64
keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
input_shape=[28, 28, 1]),
keras.layers.LeakyReLU(),
keras.layers.Dropout(0.3),
# 14x14x64 -> 7x7x128
keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
keras.layers.LeakyReLU(),
keras.layers.Dropout(0.3),
# 展平并输出
keras.layers.Flatten(),
        keras.layers.Dense(1)  # 输出logits(不加sigmoid),与 from_logits=True 的损失配合使用
])
return model
# 创建DCGAN模型
dcgan_generator = create_dcgan_generator(100)
dcgan_discriminator = create_dcgan_discriminator()
print("DCGAN生成器:")
dcgan_generator.summary()
print("\nDCGAN判别器:")
dcgan_discriminator.summary()
```
## 损失函数

```python
# 二元交叉熵损失(from_logits=True 表示判别器输出的是未经 sigmoid 的 logits)
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
"""
判别器损失函数
"""
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
def generator_loss(fake_output):
"""
生成器损失函数
"""
return cross_entropy(tf.ones_like(fake_output), fake_output)
# WGAN损失函数
def wasserstein_discriminator_loss(real_output, fake_output):
"""
Wasserstein判别器损失
"""
return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
def wasserstein_generator_loss(fake_output):
"""
Wasserstein生成器损失
"""
return -tf.reduce_mean(fake_output)
# LSGAN损失函数
def lsgan_discriminator_loss(real_output, fake_output):
"""
LSGAN判别器损失
"""
real_loss = tf.reduce_mean(tf.square(real_output - 1))
fake_loss = tf.reduce_mean(tf.square(fake_output))
return 0.5 * (real_loss + fake_loss)
def lsgan_generator_loss(fake_output):
"""
LSGAN生成器损失
"""
    return 0.5 * tf.reduce_mean(tf.square(fake_output - 1))
```
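这几组损失函数的输入都是判别器给出的打分张量,接口一致,因此可以在同一个训练循环中互相替换。下面用随机打分做一个调用方式的演示(示意代码,数值本身没有意义):

```python
# 接口演示:三种判别器损失函数的调用方式完全相同
real_scores = tf.random.normal((8, 1))
fake_scores = tf.random.normal((8, 1))
print("标准GAN损失:", discriminator_loss(real_scores, fake_scores).numpy())
print("WGAN损失:  ", wasserstein_discriminator_loss(real_scores, fake_scores).numpy())
print("LSGAN损失: ", lsgan_discriminator_loss(real_scores, fake_scores).numpy())
```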
## 训练循环

```python
class GAN:
def __init__(self, generator, discriminator, latent_dim):
self.generator = generator
self.discriminator = discriminator
self.latent_dim = latent_dim
        # 优化器(实践中常按DCGAN论文将Adam的beta_1设为0.5以提高训练稳定性)
self.generator_optimizer = keras.optimizers.Adam(1e-4)
self.discriminator_optimizer = keras.optimizers.Adam(1e-4)
# 损失跟踪
self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
@tf.function
def train_step(self, real_images):
batch_size = tf.shape(real_images)[0]
# 生成随机噪声
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# 训练判别器
with tf.GradientTape() as disc_tape:
# 生成假图像
fake_images = self.generator(random_latent_vectors, training=True)
# 判别器预测
real_predictions = self.discriminator(real_images, training=True)
fake_predictions = self.discriminator(fake_images, training=True)
# 计算判别器损失
disc_loss = discriminator_loss(real_predictions, fake_predictions)
# 计算判别器梯度并更新
disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
self.discriminator_optimizer.apply_gradients(
zip(disc_gradients, self.discriminator.trainable_variables)
)
# 训练生成器
with tf.GradientTape() as gen_tape:
# 生成假图像
fake_images = self.generator(random_latent_vectors, training=True)
# 判别器预测
fake_predictions = self.discriminator(fake_images, training=True)
# 计算生成器损失
gen_loss = generator_loss(fake_predictions)
# 计算生成器梯度并更新
gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
self.generator_optimizer.apply_gradients(
zip(gen_gradients, self.generator.trainable_variables)
)
# 更新损失跟踪
self.gen_loss_tracker.update_state(gen_loss)
self.disc_loss_tracker.update_state(disc_loss)
return {
"generator_loss": self.gen_loss_tracker.result(),
"discriminator_loss": self.disc_loss_tracker.result(),
}
# 创建GAN实例
gan = GAN(dcgan_generator, dcgan_discriminator, latent_dim=100)
```
## 数据准备和训练

```python
def prepare_mnist_data():
"""
准备MNIST数据
"""
(x_train, _), (_, _) = keras.datasets.mnist.load_data()
# 归一化到[-1, 1]
x_train = x_train.astype('float32')
x_train = (x_train - 127.5) / 127.5
# 添加通道维度
x_train = np.expand_dims(x_train, axis=-1)
return x_train
def create_dataset(data, batch_size=256, buffer_size=60000):
"""
创建训练数据集
"""
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.shuffle(buffer_size).batch(batch_size)
return dataset
# 准备数据
train_images = prepare_mnist_data()
train_dataset = create_dataset(train_images)
print(f"训练数据形状: {train_images.shape}")
def train_gan(gan, dataset, epochs=100):
"""
训练GAN模型
"""
for epoch in range(epochs):
print(f"Epoch {epoch + 1}/{epochs}")
# 训练一个epoch
for batch in dataset:
losses = gan.train_step(batch)
# 每10个epoch生成样本
if (epoch + 1) % 10 == 0:
generate_and_save_images(gan.generator, epoch + 1)
print(f"Generator Loss: {losses['generator_loss']:.4f}")
print(f"Discriminator Loss: {losses['discriminator_loss']:.4f}")
def generate_and_save_images(generator, epoch, num_examples=16):
"""
生成并保存图像
"""
# 生成随机噪声
noise = tf.random.normal([num_examples, 100])
# 生成图像
generated_images = generator(noise, training=False)
# 可视化
fig = plt.figure(figsize=(4, 4))
for i in range(generated_images.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(generated_images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.suptitle(f'Epoch {epoch}')
plt.tight_layout()
plt.show()
# 开始训练(示例,实际训练需要更多epoch)
# train_gan(gan, train_dataset, epochs=50)
```
## 条件GAN (cGAN)

```python
def create_conditional_generator(latent_dim, num_classes, img_shape):
"""
创建条件生成器
"""
# 噪声输入
noise_input = keras.layers.Input(shape=(latent_dim,))
# 标签输入
label_input = keras.layers.Input(shape=(1,))
label_embedding = keras.layers.Embedding(num_classes, 50)(label_input)
label_embedding = keras.layers.Flatten()(label_embedding)
# 合并噪声和标签
merged_input = keras.layers.Concatenate()([noise_input, label_embedding])
# 生成器网络
x = keras.layers.Dense(7*7*256, use_bias=False)(merged_input)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU()(x)
x = keras.layers.Reshape((7, 7, 256))(x)
x = keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
padding='same', use_bias=False)(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU()(x)
x = keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
padding='same', use_bias=False)(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU()(x)
output = keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
padding='same', use_bias=False,
activation='tanh')(x)
model = keras.Model([noise_input, label_input], output)
return model
def create_conditional_discriminator(img_shape, num_classes):
"""
创建条件判别器
"""
# 图像输入
img_input = keras.layers.Input(shape=img_shape)
# 标签输入
label_input = keras.layers.Input(shape=(1,))
label_embedding = keras.layers.Embedding(num_classes, 50)(label_input)
label_embedding = keras.layers.Flatten()(label_embedding)
label_embedding = keras.layers.Dense(np.prod(img_shape))(label_embedding)
label_embedding = keras.layers.Reshape(img_shape)(label_embedding)
# 合并图像和标签
merged_input = keras.layers.Concatenate()([img_input, label_embedding])
# 判别器网络
x = keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(merged_input)
x = keras.layers.LeakyReLU()(x)
x = keras.layers.Dropout(0.3)(x)
x = keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
x = keras.layers.LeakyReLU()(x)
x = keras.layers.Dropout(0.3)(x)
x = keras.layers.Flatten()(x)
output = keras.layers.Dense(1)(x)
model = keras.Model([img_input, label_input], output)
return model
# 创建条件GAN
cond_generator = create_conditional_generator(100, 10, (28, 28, 1))
cond_discriminator = create_conditional_discriminator((28, 28, 1), 10)
```
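条件生成器需要同时输入噪声和类别标签。下面的示意代码(模型尚未训练,仅演示输入格式)为0~9每个数字各生成一张样本:

```python
# 两路输入:噪声 + 类别标签
noise = tf.random.normal((10, 100))
labels = tf.constant([[i] for i in range(10)])  # 形状 (10, 1)
cond_samples = cond_generator([noise, labels], training=False)
print(cond_samples.shape)  # 期望 (10, 28, 28, 1)
```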
## CycleGAN实现

```python
def create_cyclegan_generator():
"""
创建CycleGAN生成器(ResNet架构)
"""
def residual_block(x, filters):
shortcut = x
x = keras.layers.Conv2D(filters, 3, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(filters, 3, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Add()([shortcut, x])
x = keras.layers.ReLU()(x)
return x
inputs = keras.layers.Input(shape=(256, 256, 3))
# 编码器
x = keras.layers.Conv2D(64, 7, padding='same')(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
# 下采样
x = keras.layers.Conv2D(128, 3, strides=2, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2D(256, 3, strides=2, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
# 残差块
for _ in range(9):
x = residual_block(x, 256)
# 上采样
x = keras.layers.Conv2DTranspose(128, 3, strides=2, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Conv2DTranspose(64, 3, strides=2, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
outputs = keras.layers.Conv2D(3, 7, padding='same', activation='tanh')(x)
model = keras.Model(inputs, outputs)
return model
def create_cyclegan_discriminator():
"""
创建CycleGAN判别器(PatchGAN)
"""
inputs = keras.layers.Input(shape=(256, 256, 3))
x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
x = keras.layers.LeakyReLU(0.2)(x)
x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU(0.2)(x)
x = keras.layers.Conv2D(256, 4, strides=2, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU(0.2)(x)
x = keras.layers.Conv2D(512, 4, strides=1, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU(0.2)(x)
outputs = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)
model = keras.Model(inputs, outputs)
return model
# CycleGAN损失函数
def cycle_consistency_loss(real_image, cycled_image, lambda_cycle=10.0):
"""
循环一致性损失
"""
loss = tf.reduce_mean(tf.abs(real_image - cycled_image))
return lambda_cycle * loss
def identity_loss(real_image, same_image, lambda_identity=0.5):
"""
身份损失
"""
loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return lambda_identity * loss
```
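在完整的CycleGAN训练中,每个方向的生成器总损失通常由对抗损失、循环一致性损失和身份损失三部分相加而成。下面是一个组合方式的示意(假设对抗部分采用前文定义的LSGAN形式损失,即CycleGAN论文使用的最小二乘对抗损失;函数名为本文自拟):

```python
def cyclegan_generator_total_loss(disc_fake_output, real_input, cycled_input,
                                  real_target, same_target):
    """G: X->Y 方向生成器的总损失(示意)"""
    adv_loss = lsgan_generator_loss(disc_fake_output)            # 对抗损失
    cyc_loss = cycle_consistency_loss(real_input, cycled_input)  # X->Y->X 循环一致性
    id_loss = identity_loss(real_target, same_target)            # G(y) ≈ y 的身份损失
    return adv_loss + cyc_loss + id_loss
```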
## StyleGAN基础

```python
def create_stylegan_mapping_network(latent_dim=512, num_layers=8):
"""
创建StyleGAN映射网络
"""
model = keras.Sequential()
for _ in range(num_layers):
model.add(keras.layers.Dense(latent_dim, activation='relu'))
model.build(input_shape=(None, latent_dim))
return model
def adaptive_instance_normalization(content_features, style_features):
"""
自适应实例归一化
"""
# 计算内容特征的均值和标准差
content_mean = tf.reduce_mean(content_features, axis=[1, 2], keepdims=True)
content_std = tf.math.reduce_std(content_features, axis=[1, 2], keepdims=True)
# 计算风格特征的均值和标准差
style_mean = tf.reduce_mean(style_features, axis=[1, 2], keepdims=True)
style_std = tf.math.reduce_std(style_features, axis=[1, 2], keepdims=True)
# 归一化内容特征
normalized_content = (content_features - content_mean) / (content_std + 1e-8)
# 应用风格统计
stylized_features = normalized_content * style_std + style_mean
return stylized_features
class StyleGANGenerator(keras.Model):
"""
简化的StyleGAN生成器
"""
def __init__(self, latent_dim=512, num_layers=6):
super(StyleGANGenerator, self).__init__()
self.latent_dim = latent_dim
self.num_layers = num_layers
# 映射网络
self.mapping_network = create_stylegan_mapping_network(latent_dim)
# 常数输入
self.constant_input = self.add_weight(
shape=(1, 4, 4, 512),
initializer='random_normal',
trainable=True,
name='constant_input'
)
# 生成器层
self.conv_layers = []
self.adain_layers = []
channels = [512, 512, 256, 128, 64, 32]
for i in range(num_layers):
self.conv_layers.append(
keras.layers.Conv2D(channels[i], 3, padding='same', activation='relu')
)
self.adain_layers.append(
keras.layers.Dense(channels[i] * 2) # 用于生成均值和标准差
)
def call(self, latent_codes):
batch_size = tf.shape(latent_codes)[0]
# 映射网络
w = self.mapping_network(latent_codes)
# 从常数开始
x = tf.tile(self.constant_input, [batch_size, 1, 1, 1])
# 逐层生成
for i in range(self.num_layers):
# 卷积
x = self.conv_layers[i](x)
# AdaIN
style_params = self.adain_layers[i](w)
style_mean, style_std = tf.split(style_params, 2, axis=-1)
# 应用AdaIN
x_mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
x_std = tf.math.reduce_std(x, axis=[1, 2], keepdims=True)
x = (x - x_mean) / (x_std + 1e-8)
x = x * tf.expand_dims(tf.expand_dims(style_std, 1), 1) + \
tf.expand_dims(tf.expand_dims(style_mean, 1), 1)
# 上采样(除了最后一层)
if i < self.num_layers - 1:
x = keras.layers.UpSampling2D()(x)
        return x
```
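可以实例化这个简化生成器并前向传播一次来检查输出(示意代码;注意它输出的是128x128的特征图,尚未映射为3通道RGB图像,完整的StyleGAN还包含toRGB层、噪声注入和样式混合等组件):

```python
# 前向传播检查:4x4常数输入经过5次上采样得到128x128特征图
stylegan_gen = StyleGANGenerator(latent_dim=512, num_layers=6)
z = tf.random.normal((2, 512))
features = stylegan_gen(z)
print(features.shape)  # 期望 (2, 128, 128, 32)
```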
## 评估指标

```python
def calculate_fid_score(real_images, generated_images):
"""
计算FID分数(Fréchet Inception Distance)
"""
# 加载预训练的Inception模型
inception_model = keras.applications.InceptionV3(
include_top=False,
pooling='avg',
input_shape=(299, 299, 3)
)
def preprocess_images(images):
# 调整图像大小到299x299
images = tf.image.resize(images, [299, 299])
        # 预处理(preprocess_input 假设输入像素取值范围为[0, 255])
        images = keras.applications.inception_v3.preprocess_input(images)
return images
# 预处理图像
real_images = preprocess_images(real_images)
generated_images = preprocess_images(generated_images)
# 提取特征
real_features = inception_model.predict(real_images)
generated_features = inception_model.predict(generated_images)
# 计算均值和协方差
mu_real = np.mean(real_features, axis=0)
sigma_real = np.cov(real_features, rowvar=False)
mu_gen = np.mean(generated_features, axis=0)
sigma_gen = np.cov(generated_features, rowvar=False)
# 计算FID
diff = mu_real - mu_gen
covmean = scipy.linalg.sqrtm(sigma_real.dot(sigma_gen))
if np.iscomplexobj(covmean):
covmean = covmean.real
fid = diff.dot(diff) + np.trace(sigma_real + sigma_gen - 2 * covmean)
return fid
def calculate_inception_score(generated_images, num_splits=10):
"""
计算Inception Score
"""
# 加载Inception模型
inception_model = keras.applications.InceptionV3(
include_top=True,
input_shape=(299, 299, 3)
)
# 预处理图像
images = tf.image.resize(generated_images, [299, 299])
images = keras.applications.inception_v3.preprocess_input(images)
# 获取预测
predictions = inception_model.predict(images)
# 计算IS
scores = []
for i in range(num_splits):
part = predictions[i * len(predictions) // num_splits:
(i + 1) * len(predictions) // num_splits]
# 计算KL散度
py = np.mean(part, axis=0)
kl_div = part * (np.log(part + 1e-8) - np.log(py + 1e-8))
kl_div = np.mean(np.sum(kl_div, axis=1))
scores.append(np.exp(kl_div))
    return np.mean(scores), np.std(scores)
```
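需要注意,Inception网络要求3通道输入,且上面的preprocess_input假设像素取值在[0, 255]之间。对于取值在[-1, 1]的灰度MNIST生成结果,可以像下面这样先做转换(示意代码;实际计算较耗时,因此调用被注释掉):

```python
# 将[-1, 1]的灰度图还原到[0, 255]并复制为3通道
real_batch = tf.convert_to_tensor(train_images[:100]) * 127.5 + 127.5
fake_batch = gan.generator(tf.random.normal((100, 100)), training=False) * 127.5 + 127.5
real_rgb = tf.image.grayscale_to_rgb(real_batch)
fake_rgb = tf.image.grayscale_to_rgb(fake_batch)
# fid = calculate_fid_score(real_rgb, fake_rgb)
# print(f"FID: {fid:.2f}")
```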
## 训练技巧和最佳实践

```python
# 1. 渐进式训练
class ProgressiveGAN:
def __init__(self):
self.current_resolution = 4
self.max_resolution = 256
def grow_network(self):
"""
逐步增加网络分辨率
"""
if self.current_resolution < self.max_resolution:
self.current_resolution *= 2
# 添加新的层到生成器和判别器
def fade_in_new_layers(self, alpha):
"""
淡入新层
"""
# 使用alpha混合旧输出和新输出
pass
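# 淡入(fade-in)的核心思想示意:新分辨率层的输出与上采样后的旧输出按系数alpha线性混合,
# alpha在若干训练步内从0逐渐增大到1
def fade_in_blend(old_output, new_output, alpha):
    """alpha=0时完全使用旧输出,alpha=1时完全使用新层输出(示意实现)"""
    upsampled_old = tf.image.resize(old_output, tf.shape(new_output)[1:3], method='nearest')
    return alpha * new_output + (1.0 - alpha) * upsampled_old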
# 2. 谱归一化
# 注:tf.keras 本身没有内置的谱归一化包装层;下面假设安装了 tensorflow-addons,
#    较新版本的 Keras 也提供 keras.layers.SpectralNormalization,可按需替换
def spectral_normalization(layer):
    """
    用谱归一化包装给定层,约束其Lipschitz常数(常用于判别器)
    """
    import tensorflow_addons as tfa
    return tfa.layers.SpectralNormalization(layer)
# 3. 自注意力机制
class SelfAttention(keras.layers.Layer):
def __init__(self, channels):
super(SelfAttention, self).__init__()
self.channels = channels
self.query_conv = keras.layers.Conv2D(channels // 8, 1)
self.key_conv = keras.layers.Conv2D(channels // 8, 1)
self.value_conv = keras.layers.Conv2D(channels, 1)
self.gamma = self.add_weight(shape=(), initializer='zeros', trainable=True)
def call(self, x):
batch_size, height, width, channels = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
# 计算query, key, value
query = self.query_conv(x)
key = self.key_conv(x)
value = self.value_conv(x)
# 重塑为矩阵形式
query = tf.reshape(query, [batch_size, -1, channels // 8])
key = tf.reshape(key, [batch_size, -1, channels // 8])
value = tf.reshape(value, [batch_size, -1, channels])
# 计算注意力
attention = tf.nn.softmax(tf.matmul(query, key, transpose_b=True))
out = tf.matmul(attention, value)
out = tf.reshape(out, [batch_size, height, width, channels])
# 残差连接
out = self.gamma * out + x
return out
# 4. 梯度惩罚(WGAN-GP)
def gradient_penalty(discriminator, real_images, fake_images, batch_size):
"""
计算梯度惩罚
"""
alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
interpolated = alpha * real_images + (1 - alpha) * fake_images
with tf.GradientTape() as tape:
tape.watch(interpolated)
pred = discriminator(interpolated, training=True)
gradients = tape.gradient(pred, interpolated)
gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
gradient_penalty = tf.reduce_mean((gradients_norm - 1.) ** 2)
    return gradient_penalty
```
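梯度惩罚通常与前文的Wasserstein损失配合使用,惩罚系数一般取10。下面给出组合判别器总损失的一种写法(示意,函数名为本文自拟):

```python
def wgan_gp_discriminator_loss(discriminator, real_images, fake_images,
                               real_output, fake_output, batch_size, lambda_gp=10.0):
    """WGAN-GP判别器总损失 = Wasserstein损失 + lambda_gp * 梯度惩罚(示意)"""
    gp = gradient_penalty(discriminator, real_images, fake_images, batch_size)
    return wasserstein_discriminator_loss(real_output, fake_output) + lambda_gp * gp
```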
## 实际应用示例

```python
def image_to_image_translation():
"""
图像到图像翻译示例
"""
# 创建Pix2Pix模型
def create_pix2pix_generator():
# U-Net架构
inputs = keras.layers.Input(shape=(256, 256, 3))
# 编码器
down_stack = [
keras.layers.Conv2D(64, 4, strides=2, padding='same', use_bias=False),
keras.layers.Conv2D(128, 4, strides=2, padding='same', use_bias=False),
keras.layers.Conv2D(256, 4, strides=2, padding='same', use_bias=False),
keras.layers.Conv2D(512, 4, strides=2, padding='same', use_bias=False),
]
# 解码器
up_stack = [
keras.layers.Conv2DTranspose(256, 4, strides=2, padding='same', use_bias=False),
keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same', use_bias=False),
keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', use_bias=False),
]
x = inputs
# 下采样
skips = []
for down in down_stack:
x = down(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.LeakyReLU()(x)
skips.append(x)
skips = reversed(skips[:-1])
# 上采样
for up, skip in zip(up_stack, skips):
x = up(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
x = keras.layers.Concatenate()([x, skip])
# 最后一层
last = keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')
x = last(x)
return keras.Model(inputs=inputs, outputs=x)
return create_pix2pix_generator()
def super_resolution_gan():
"""
超分辨率GAN示例
"""
def create_srgan_generator():
def residual_block(x):
shortcut = x
x = keras.layers.Conv2D(64, 3, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
            x = keras.layers.PReLU(shared_axes=[1, 2])(x)  # shared_axes 使PReLU兼容动态空间尺寸
x = keras.layers.Conv2D(64, 3, padding='same')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Add()([shortcut, x])
return x
inputs = keras.layers.Input(shape=(None, None, 3))
# 初始卷积
x = keras.layers.Conv2D(64, 9, padding='same')(inputs)
        x = keras.layers.PReLU(shared_axes=[1, 2])(x)
# 残差块
for _ in range(16):
x = residual_block(x)
# 上采样
x = keras.layers.Conv2D(256, 3, padding='same')(x)
x = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x = keras.layers.PReLU(shared_axes=[1, 2])(x)
x = keras.layers.Conv2D(256, 3, padding='same')(x)
x = keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x = keras.layers.PReLU(shared_axes=[1, 2])(x)
# 输出
outputs = keras.layers.Conv2D(3, 9, padding='same', activation='tanh')(x)
return keras.Model(inputs, outputs)
    return create_srgan_generator()
```
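这个SRGAN生成器通过两次depth_to_space各放大2倍,整体实现4倍超分辨率。可以用随机输入验证输出尺寸(示意代码,模型未经训练):

```python
# 低分辨率输入 64x64 -> 超分输出 256x256
srgan_generator = super_resolution_gan()
low_res = tf.random.normal((1, 64, 64, 3))
high_res = srgan_generator(low_res, training=False)
print(high_res.shape)  # 期望 (1, 256, 256, 3)
```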
## 总结

GAN是深度学习中最具创新性的技术之一,它开创了深度生成模型的新时代。从最初的GAN到DCGAN、CycleGAN、StyleGAN、BigGAN等变体,GAN在图像生成、图像到图像翻译、风格迁移、超分辨率和数据增强等领域都有广泛应用。
关键要点:
- 对抗训练:生成器和判别器的博弈过程
- 损失函数设计:不同的损失函数适用于不同场景
- 训练稳定性:需要仔细调节超参数和训练策略
- 评估指标:FID、IS等指标帮助评估生成质量
下一章我们将学习实战项目,将所学的理论知识应用到具体的项目中。