Tensorflow 如何配置花卉数据集以提高性能?


花卉数据集在创建模型时会给出一定的准确率。如果需要配置模型以提高性能,则可以使用缓冲预取与 Rescaling 层。通过将 Rescaling 层作为 Keras 模型的一部分,可以在数据集上使用 Keras 模型应用此层。

阅读更多:什么是 TensorFlow 以及 Keras 如何与 TensorFlow 协作创建神经网络?

我们将使用花卉数据集,其中包含数千朵花的图像。它包含 5 个子目录,每个子目录对应一个类别。

我们使用 Google Colaboratory 来运行以下代码。Google Colab 或 Colaboratory 帮助通过浏览器运行 Python 代码,无需任何配置,并可以免费访问 GPU(图形处理单元)。Colaboratory 建立在 Jupyter Notebook 之上。

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

num_classes = 5
print("A sequential model is built")
model = tf.keras.Sequential([
   layers.experimental.preprocessing.Rescaling(1./255),
   layers.Conv2D(32, 3, activation='relu'),
   layers.MaxPooling2D(),
   layers.Conv2D(32, 3, activation='relu'),
   layers.MaxPooling2D(),
   layers.Conv2D(32, 3, activation='relu'),
   layers.MaxPooling2D(),
   layers.Flatten(),
   layers.Dense(128, activation='relu'),
   layers.Dense(num_classes)
])

代码来源:https://tensorflowcn.cn/tutorials/load_data/images

输出

A sequential model is built

解释

  • 使用缓冲预取,以便可以从磁盘生成数据,而不会阻塞 I/O。
  • 这是加载数据时的一个重要步骤。
  • '.cache()' 方法有助于在第一个 epoch 从磁盘加载图像后将其保留在内存中。
  • 这确保了数据集在训练模型时不会成为障碍。
  • 如果数据集太大而无法放入内存,则可以使用相同的方法来创建高性能的磁盘缓存。
  • '.prefetch()' 方法在训练数据时重叠数据预处理和模型执行操作。

更新于: 2021年2月19日

154 次查看

开启你的 职业生涯

通过完成课程获得认证

开始学习
广告