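Autoencoders are artificial neural networks (ANNs) that learn efficient encodings of unlabeled data. The network is trained to reproduce its input at its output, and because the signal must pass through a narrow intermediate layer (the code, or bottleneck), it is forced to learn a compact representation. Autoencoders have become important tools in machine learning and deep learning. This chapter gives a step-by-step guide to implementing an autoencoder in Python, using the MNIST dataset of handwritten digits as the example.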
pip install numpy matplotlib tensorflow
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape
from tensorflow.keras.optimizers import Adam
# Load the dataset (the labels are not needed to train an autoencoder)
(x_train, _), (x_test, _) = mnist.load_data()

# Normalize pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Reshape the data to include the channel dimension: (N, 28, 28) -> (N, 28, 28, 1)
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))
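To check the preprocessing without downloading MNIST, the same two steps can be wrapped in a small function and run on synthetic data. This is a minimal sketch: the helper name `preprocess` and the random stand-in images are illustrative assumptions, not part of the tutorial's code.

```python
import numpy as np

def preprocess(images):
    """Scale uint8 pixels to [0, 1] and add a trailing channel axis.

    Mirrors the steps above: astype('float32') / 255.0, then a reshape
    from (N, 28, 28) to (N, 28, 28, 1).
    """
    images = images.astype("float32") / 255.0
    return np.reshape(images, (len(images), 28, 28, 1))

# Synthetic stand-in for MNIST: 5 random 28x28 uint8 "images"
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
fake[0, 0, 0] = 255  # force one white pixel
fake[0, 0, 1] = 0    # and one black pixel

out = preprocess(fake)
print(out.shape, out.dtype)  # (5, 28, 28, 1) float32
```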
# Define the input shape for the autoencoder
input_shape = (28, 28, 1)

# Define the encoder part of the autoencoder:
# flatten each image to 784 values, then compress it to a 64-dimensional code
input_img = Input(shape=input_shape)
x = Flatten()(input_img)
encoded = Dense(64, activation='relu')(x)

# Define the decoder part of the autoencoder:
# expand the code back to 784 values in [0, 1], then restore the image shape
decoded = Dense(784, activation='sigmoid')(encoded)
decoded = Reshape((28, 28, 1))(decoded)

# Define the complete autoencoder model
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer=Adam(), loss='binary_crossentropy')

# Print the summary of the autoencoder model
autoencoder.summary()
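For readers who want to see what the Keras layers above compute, the forward pass of the same 784 → 64 → 784 architecture can be written out in NumPy. This sketch assumes untrained weights drawn with the Glorot-uniform formula that Keras uses by default for `Dense` kernels, with zero biases; the function names `glorot_uniform` and `forward` are hypothetical helpers, and the reconstructions are meaningless until the weights are trained.

```python
import numpy as np

rng = np.random.default_rng(42)

def glorot_uniform(fan_in, fan_out):
    # Uniform in [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out))
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out)).astype("float32")

# Encoder: Dense(64, relu); decoder: Dense(784, sigmoid). Biases start at zero.
W_enc, b_enc = glorot_uniform(784, 64), np.zeros(64, dtype="float32")
W_dec, b_dec = glorot_uniform(64, 784), np.zeros(784, dtype="float32")

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(batch):
    """batch: (N, 28, 28, 1) floats in [0, 1] -> (code, reconstruction)."""
    x = batch.reshape(len(batch), -1)          # Flatten: (N, 784)
    code = relu(x @ W_enc + b_enc)             # encoded: (N, 64)
    recon = sigmoid(code @ W_dec + b_dec)      # decoded: (N, 784)
    return code, recon.reshape(-1, 28, 28, 1)  # Reshape: (N, 28, 28, 1)

batch = rng.random((3, 28, 28, 1), dtype=np.float32)
code, recon = forward(batch)
print(code.shape, recon.shape)  # (3, 64) (3, 28, 28, 1)
```

The ReLU keeps the code non-negative, and the final sigmoid keeps every output pixel in [0, 1], the same range as the normalized inputs.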
# Train the autoencoder: the input images are also the targets
autoencoder.fit(
    x_train, x_train,
    epochs=50,        # Number of epochs to train
    batch_size=256,   # Batch size for training
    shuffle=True,
    validation_data=(x_test, x_test),
)
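The `binary_crossentropy` loss treats each pixel as an independent target in [0, 1], which fits the normalized data and the sigmoid output layer. A minimal NumPy sketch of the per-pixel formula is shown here; the clipping constant `eps=1e-7` mirrors the Keras default fuzz factor, but Keras' exact reduction over batch and axes may differ from this plain mean, and the function is an illustration rather than the library's implementation.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip predictions away from 0 and 1 so log() stays finite
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_pixel = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return per_pixel.mean()

target = np.array([0.0, 1.0, 0.5])
good = np.array([0.01, 0.99, 0.5])  # close to the target
bad = np.array([0.9, 0.1, 0.5])     # far from the target
print(binary_crossentropy(target, good) < binary_crossentropy(target, bad))  # True
```

Because MNIST pixels are not strictly 0 or 1, this loss cannot reach zero even for a perfect reconstruction: for a gray target of 0.5 it bottoms out at log 2 when the prediction equals the target.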
# Predict the reconstructed images from the test set
decoded_imgs = autoencoder.predict(x_test)

# Number of digits to display
n = 10

# Create a figure with a specified size
plt.figure(figsize=(20, 4))

# Loop through the first n test images
for i in range(n):
    # Display the original image in the top row
    plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28), cmap='gray')
    plt.title("Original")
    plt.axis('off')

    # Display the reconstructed image in the bottom row
    plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28), cmap='gray')
    plt.title("Reconstructed")
    plt.axis('off')

# Show the figure
plt.show()
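Beyond eyeballing the first ten digits, one way to quantify reconstruction quality is the mean squared error of each test image. This sketch uses synthetic arrays in place of `x_test` and `decoded_imgs`, and the helper `per_image_mse` is hypothetical.

```python
import numpy as np

def per_image_mse(originals, reconstructions):
    """Mean squared error for each image, averaged over all of its pixels."""
    flat_orig = originals.reshape(len(originals), -1)
    flat_recon = reconstructions.reshape(len(reconstructions), -1)
    return np.mean((flat_orig - flat_recon) ** 2, axis=1)

# Synthetic stand-ins for x_test and decoded_imgs
rng = np.random.default_rng(1)
x = rng.random((4, 28, 28, 1))
recon = x.copy()
recon[2] = 1.0 - recon[2]  # make image 2 a deliberately poor reconstruction

errors = per_image_mse(x, recon)
worst = np.argsort(errors)[::-1]  # indices ordered from worst to best
print(worst[0])  # 2
```

With the trained model, `per_image_mse(x_test, decoded_imgs)` would give one error per test image, and the first entries of `worst` could replace `range(n)` in the plotting loop to display the digits the autoencoder reconstructs least well.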
Running autoencoder.summary() prints the following (layer-name suffixes such as _3 or _6 depend on how many models were built earlier in the session):

Model: "functional_1"

Layer (type) | Output Shape | Param # |
input_layer_3 (InputLayer) | (None, 28, 28, 1) | 0 |
flatten_3 (Flatten) | (None, 784) | 0 |
dense_6 (Dense) | (None, 64) | 50,240 |
dense_7 (Dense) | (None, 784) | 50,960 |
reshape_3 (Reshape) | (None, 28, 28, 1) | 0 |

Total params: 101,200 (395.31 KB)
Trainable params: 101,200 (395.31 KB)
Non-trainable params: 0 (0.00 B)
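The parameter counts in the summary follow directly from the layer sizes: a `Dense` layer with `n_in` inputs and `n_out` units has an `n_in × n_out` weight matrix plus `n_out` biases, and each float32 parameter occupies 4 bytes. This short check reproduces the numbers; the helper name `dense_params` is illustrative.

```python
def dense_params(n_in, n_out):
    # Kernel of shape (n_in, n_out) plus one bias per output unit
    return n_in * n_out + n_out

encoder = dense_params(28 * 28, 64)  # dense_6
decoder = dense_params(64, 28 * 28)  # dense_7
total = encoder + decoder
size_kb = total * 4 / 1024           # float32 = 4 bytes per parameter
print(f"{encoder:,} {decoder:,} {total:,} {size_kb:.2f} KB")
# 50,240 50,960 101,200 395.31 KB
```

The same sizes give the compression ratio of the bottleneck: each 784-pixel image is represented by a 64-value code, 12.25 times fewer numbers.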
