如何使用Tensorflow配置数据集以提高性能?
花卉数据集可以通过缓冲预取、shuffle 方法和 cache 方法来配置以提高性能。缓冲预取可以用于确保数据可以从磁盘获取,而不会导致 I/O 阻塞。Dataset.cache() 会在第一个 epoch 加载完磁盘上的图像后将其保存在内存中。Dataset.prefetch() 会在训练期间重叠数据预处理和模型执行。
阅读更多: 什么是 TensorFlow 以及 Keras 如何与 TensorFlow 协作创建神经网络?
使用 Keras Sequential API,它有助于构建一个顺序模型,该模型用于处理简单的层堆栈,其中每一层只有一个输入张量和一个输出张量。
我们正在使用 Google Colaboratory 来运行以下代码。Google Colab 或 Colaboratory 帮助在浏览器上运行 Python 代码,无需任何配置,并且可以免费访问 GPU(图形处理单元)。Colaboratory 是建立在 Jupyter Notebook 之上的。
print("Configuring the dataset for better performance") AUTOTUNE = tf.data.AUTOTUNE train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE) val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
代码来源: https://tensorflowcn.cn/tutorials/images/classification
输出
Configuring the dataset for better performance
解释
- 可以使用缓冲预取的概念,以便数据可以从磁盘获取,而不会导致 I/O 阻塞。
- 加载数据时可以使用两种重要的方法。
- cache() 会在第一个 epoch 加载完磁盘上的图像后将其保存在内存中。
- 这将确保在训练模型时数据集不会成为瓶颈。
- 如果数据集太大而无法放入内存,则可以使用此方法创建高性能的磁盘缓存。
- prefetch() 会在训练期间重叠数据预处理和模型执行。
广告