keras.fit() 和 keras.fit_generator()


简介

Keras 中的 fit() 和 fit_generator() 方法使得在 Python 中训练深度神经网络变得非常容易。fit() 方法可以有效地处理和训练批次数据,对于可以加载到内存中的较小数据集特别有用。另一方面,fit_generator() 方法可以动态加载和处理批次数据,更适合无法一次全部加载到内存中的较大型数据集。

Keras 基础

如今,技术在各个方面都得到了改进。因此,在这个先进的技术和未来的环境中,Keras 已成为著名的 Python 库之一。它主要用于创建、评估和训练深度神经网络。这些深度神经网络是一种机器学习算法,模仿人脑的结构和运作方式。图像识别、自然语言处理、语音识别等是它的一些活动。

fit() 方法

这是 Keras 用于模型训练的首选方法,它适用于小型到中型数据集。如果您的数据集太大而无法放入内存,则 fit() 函数需要您在训练前将其完全放入内存。无论如何,使用 fit() 方法都很简单。只需输入您的训练数据、标签、时期数和批次大小即可。然后,Keras 通过在您的数据集上迭代指定数量的时期和批次大小来训练您的模型。fit() 方法非常适合基本训练,但它也存在一些局限性。

fit() 的局限性

Keras 的 fit() 函数通常用于训练深度神经网络 (DNN)。但是,它存在一些重大限制,可能会限制其实用性。fit() 最严重的问题之一是在处理大型数据集时。尽管它是训练 DNN 的强大而通用的工具,但大型数据集会导致 fit() 函数处理时间过长,从而导致您的项目延迟。内存使用是 fit() 的另一个限制。当您的内存有限时,这可能是一个严重的问题。它在定制方面也不是特别灵活。

fit_generator() 方法

fit_generator() 方法与 fit() 方法一样,都用于在数据集上训练神经网络。唯一的区别在于数据处理方式。fit() 一次将整个数据集加载到内存中,而 fit_generator() 则分批处理数据。这种细微的差别可能看起来并不重要,但它带来了许多独特的优势。首先,它使您可以处理大得多的数据集。训练过程中再也不会出现内存不足的情况了!此外,它允许您以更复杂的方式自定义数据处理。

语法

fit() −

model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val))

fit_generator −

model.fit_generator(generator=train_generator, steps_per_epoch=steps_per_epoch,  epochs=epochs, validation_data=val_generator, validation_steps=val_steps)

解释

我们需要模型实例本身,它是训练过程的基础。接下来,您需要输入数据 (x_train) 和标签 (y_train),这些数据将用于训练模型。

batch_size 选项指定在更新模型的内部参数之前将处理的样本数。此值会影响训练过程的速度和效率。此外,时期数指定在训练期间整个训练数据集将通过模型处理的次数。

如果可用,您可以添加可选的 validation_data 来测试模型的性能,并在整个训练过程中跟踪不同数据集上的损失。这使您可以检查模型的泛化能力并进行必要的改进。

或者,您可以使用 train_generator,它可以动态生成批次训练数据。steps_per_epoch 选项指示生成器应为每个时期生成多少批次。类似地,val_generator 可以与 val_steps 参数一起使用,该参数指定验证生成器应为每个时期生成多少批次。

算法

  • 步骤 1 − 导入所有必要的库和模块。

  • 步骤 2 − 接下来,请记住加载或生成所有所需的训练和验证数据。

  • 步骤 3 − 必须使用 Keras API 提供模型架构。

  • 步骤 4 − 设置损失函数、优化器和评估矩阵后,编译模型。

  • 步骤 5 − 要训练模型,请使用上面给出的任何一个模型(即 fit() 或 fit_generator)

方法 1:使用 fit() 进行监督学习。

此代码指示计算机程序使用神经网络学习数据中的模式。该程序有两层:一层查找数据中的模式,另一层预测数据的类别。

示例

# Importing required libraries
import tensorflow as tf
from tensorflow import keras
import numpy as np
# Defining the Sequential model with 2 layers
model = keras.Sequential([
	keras.layers.Dense(units=64, activation='relu', input_dim=100),
	keras.layers.Dense(units=10, activation='softmax')
])
# Compiling the model with required configuration
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
# Creating a dummy dataset
x_train = np.random.random((1000, 100))
y_train = np.random.random((1000, 10))
# Training the model
model.fit(x_train, y_train, epochs=20, batch_size=128)

输出

Epoch 1/20
8/8 [==============================] - 1s 3ms/step - loss: 12.2208 - accuracy: 0.1070
Epoch 2/20
8/8 [==============================] - 0s 3ms/step - loss: 12.6586 - accuracy: 0.1090
...
Epoch 20/20
8/8 [==============================] - 0s 2ms/step - loss: 35170420.0000 - accuracy: 0.1060
<keras.callbacks.History at 0x7f14d04b2c80>

方法 2:使用 Keras.fit_generator() 的示例

在此代码中,我们将使用 Keras 的 ImageDataGenerator 类执行数据增强和重新缩放。数据生成器动态生成批次数据,从而能够使用大规模数据集进行训练。

示例

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator
x_train = np.random.random((1000, 20))
y_train = np.random.randint(2, size=(1000, 1))
x_val = np.random.random((100, 20))
y_val = np.random.randint(2, size=(100, 1))
train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow(x_train, y_train, batch_size=32)
val_generator = val_datagen.flow(x_val, y_val, batch_size=32)
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=20))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
model.fit_generator(generator=train_generator, steps_per_epoch=30, epochs=10, validation_data=val_generator, validation_steps=3)

输出

Epoch 1/10
30/30 [==============================] - 0s 5ms/step - loss: 0.7105 - accuracy: 0.4969 - val_loss: 0.6967 - val_accuracy: 0.4792
...
Epoch 10/10
30/30 [==============================] - 0s 2ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.4792

结论

请记住,fit() 方法非常适合小型数据集和简单模型,但它也存在局限性。fit_generator() 方法为大型数据集和更复杂的模型提供了更多灵活性和自定义选项。请根据您的具体需求明智地选择。

更新于: 2023年10月13日

1K+ 浏览量

开启您的 职业生涯

通过完成课程获得认证

立即开始
广告