如何使用Tensorflow配置数据集以提高性能?


花卉数据集可以通过缓冲预取、shuffle 方法和 cache 方法来配置以提高性能。缓冲预取可以用于确保数据可以从磁盘获取,而不会导致 I/O 阻塞。Dataset.cache() 会在第一个 epoch 加载完磁盘上的图像后将其保存在内存中。Dataset.prefetch() 会在训练期间重叠数据预处理和模型执行。

阅读更多: 什么是 TensorFlow 以及 Keras 如何与 TensorFlow 协作创建神经网络?

使用 Keras Sequential API,它有助于构建一个顺序模型,该模型用于处理简单的层堆栈,其中每一层只有一个输入张量和一个输出张量。

我们正在使用 Google Colaboratory 来运行以下代码。Google Colab 或 Colaboratory 帮助在浏览器上运行 Python 代码,无需任何配置,并且可以免费访问 GPU(图形处理单元)。Colaboratory 是建立在 Jupyter Notebook 之上的。

print("Configuring the dataset for better performance")
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

代码来源: https://tensorflowcn.cn/tutorials/images/classification

输出

Configuring the dataset for better performance

解释

  • 可以使用缓冲预取的概念,以便数据可以从磁盘获取,而不会导致 I/O 阻塞。
  • 加载数据时可以使用两种重要的方法。
    • cache() 会在第一个 epoch 加载完磁盘上的图像后将其保存在内存中。
    • 这将确保在训练模型时数据集不会成为瓶颈。
    • 如果数据集太大而无法放入内存,则可以使用此方法创建高性能的磁盘缓存。
    • prefetch() 会在训练期间重叠数据预处理和模型执行。

更新于: 2021年2月20日

826 次浏览

开启你的 职业生涯

通过完成课程获得认证

立即开始
广告