训练生成对抗网络 (GAN)



我们探讨了生成对抗网络的架构及其工作原理。在本章中,我们将通过一个实际示例演示如何实现和训练一个 GAN 来生成手写数字,与 MNIST 数据集中的数字相同。我们将使用 Python 以及 TensorFlow 和 Keras 来实现此示例。

训练生成对抗网络的过程

GAN 的训练涉及迭代地优化生成器模型和判别器模型。让我们使用以下步骤了解生成对抗网络 (GAN) 的训练过程

初始化

  • 该过程从两个神经网络开始:生成器网络 (G) 和判别器网络 (D)。
  • 生成器接收一个随机种子或噪声向量作为输入,并生成生成的样本。
  • 判别器接收真实数据样本或生成样本作为输入,并将它们分类为真实或虚假。

生成虚假数据

  • 随机噪声向量被馈送到生成器网络。
  • 生成器处理此噪声并输出生成的样本,这些样本旨在类似于真实数据。

生成器训练

  • 首先,它从输入随机噪声生成虚假数据。
  • 然后,它使用判别器的输出计算生成器的损失。
  • 最后,它更新生成器的权重以最小化损失。

判别器训练

  • 首先,它获取一批真实数据和一批虚假数据。
  • 然后,它计算真实数据和虚假数据的判别器损失。
  • 最后,它更新判别器的权重以最小化损失。

迭代训练

  • 重复步骤 2 到 4。在每次迭代中,生成器和判别器都会交替训练,并试图提高彼此的性能。
  • 这种交替优化将持续进行,直到生成器生成的数据与真实数据相同,并且判别器无法再可靠地区分真实数据和虚假数据。

训练和构建 GAN

在这里,我们将展示使用 Python 和 MNIST 数据集训练和构建 GAN 的分步过程 -

步骤 1:设置环境

在开始之前,我们需要使用必要的库设置 Python 环境。确保您的计算机上安装了 TensorFlow 和 Keras。您可以使用 pip 如下安装它们 -

pip install tensorflow

步骤 2:导入必要的库

我们需要导入必要的库 -

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

步骤 3:加载和预处理 MNIST 数据集

MNIST 数据集包含 60,000 个训练图像和 10,000 个测试图像的手写数字,每个图像的大小为 28x28 像素。我们将像素值归一化到 [-1, 1] 范围内,以提高训练效率 -

# Load the dataset
(x_train, _), (_, _) = mnist.load_data()

# Normalize the images to [-1, 1]
x_train = (x_train - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)

# Set batch size and buffer size
BUFFER_SIZE = 60000
BATCH_SIZE = 256

步骤 4:创建生成器和判别器模型

生成器从随机噪声创建虚假图像,而判别器试图区分真实图像和虚假图像。

生成器模型的实现

生成器模型接收一个随机噪声向量作为输入,并将其通过一系列层进行转换以生成虚假图像 -

def build_generator():
   model = models.Sequential()
   model.add(layers.Dense(256, use_bias=False, input_shape=(100,)))
   model.add(layers.BatchNormalization())
   model.add(layers.LeakyReLU())
    
   model.add(layers.Dense(512, use_bias=False))
   model.add(layers.BatchNormalization())
   model.add(layers.LeakyReLU())
    
   model.add(layers.Dense(28 * 28 * 1, use_bias=False, activation='tanh'))
   model.add(layers.Reshape((28, 28, 1)))
    
   return model

generator = build_generator()

判别器模型的实现

判别器模型接收图像作为输入(真实或生成),并输出一个概率值,指示该图像是否为真实图像 -

def build_discriminator():
   model = models.Sequential()
   model.add(layers.Flatten(input_shape=(28, 28, 1)))
   model.add(layers.Dense(512))
   model.add(layers.LeakyReLU())
   model.add(layers.Dropout(0.3))

   model.add(layers.Dense(256))
   model.add(layers.LeakyReLU())
   model.add(layers.Dropout(0.3))

   model.add(layers.Dense(1, activation='sigmoid'))

   return model

discriminator = build_discriminator()

步骤 5:定义损失函数和优化器

在此步骤中,我们将对生成器和判别器都使用二元交叉熵损失。生成器的目标是最大化判别器出错的概率,而判别器的目标是最小化其分类错误。

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(fake_output):
   return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
   real_loss = cross_entropy(tf.ones_like(real_output), real_output)
   fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
   total_loss = real_loss + fake_loss
   return total_loss

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

步骤 6:定义训练循环

GAN 的训练过程涉及迭代地训练生成器和判别器。在这里,我们将定义一个训练步骤,其中包括生成虚假图像、计算损失以及使用反向传播更新模型权重。

@tf.function
def train_step(images):
   noise = tf.random.normal([BATCH_SIZE, 100])
    
   with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)
        
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
        
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)
    
   gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
   gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
   generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
   discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
   for epoch in range(epochs):
      for image_batch in dataset:
         train_step(image_batch)
      print(f'Epoch {epoch+1} completed')

步骤 7:准备数据集并训练 GAN

接下来,我们将通过对 MNIST 图像进行混洗和批处理来准备数据集,然后我们将开始训练过程。

# Prepare the dataset for training
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Train the GAN
EPOCHS = 50
train(train_dataset, EPOCHS)

步骤 8:生成和显示图像

现在,在训练 GAN 之后,我们可以生成和显示生成器创建的新图像。它涉及创建随机噪声,将其馈送到生成器,并显示生成的图像。

def generate_and_save_images(model, epoch, test_input):
   predictions = model(test_input, training=False)
    
   fig = plt.figure(figsize=(7.50, 3.50))
    
   for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i + 1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
    
   plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
   plt.show()

seed = tf.random.normal([16, 100])
generate_and_save_images(generator, EPOCHS, seed)

实现后,当您运行此代码时,您将获得以下输出 -

Training and Building a GAN

结论

使用 Python 训练 GAN 涉及几个关键步骤,例如设置环境、创建生成器和判别器模型、定义损失函数和优化器以及实现训练循环。通过遵循这些步骤,您可以训练自己的 GAN 并探索生成对抗网络的迷人世界。

在本章中,我们提供了使用 Python 编程语言构建和训练 GAN 的详细指南。我们在示例中使用了 TensorFlow 和 Keras 库以及 MNIST 数据集。

广告