如何使用Tensorflow配置花卉数据集以提高性能?
花卉数据集在创建模型时会给出一定的准确率。如果需要配置模型以提高性能,则定义一个函数,该函数第二次执行缓冲预取,然后进行洗牌。此函数用于训练数据集以提高模型的性能。
阅读更多: 什么是TensorFlow以及Keras如何与TensorFlow一起创建神经网络?
我们将使用花卉数据集,其中包含数千朵花的图像。它包含5个子目录,每个类都有一个子目录。
我们使用Google Colaboratory运行以下代码。Google Colab或Colaboratory帮助通过浏览器运行Python代码,无需任何配置,并可免费访问GPU(图形处理单元)。Colaboratory构建在Jupyter Notebook之上。
print("A function is defined that configures the dataset for perfromance") def configure_for_performance(ds): ds = ds.cache() ds = ds.shuffle(buffer_size=1000) ds = ds.batch(batch_size) ds = ds.prefetch(buffer_size=AUTOTUNE) return ds print("The function is called on training dataset") train_ds = configure_for_performance(train_ds) print("The function is called on validation dataset") val_ds = configure_for_performance(val_ds)
代码来源:https://tensorflowcn.cn/tutorials/load_data/images
输出
A function is defined that configures the dataset for perfromance The function is called on training dataset The function is called on validation dataset
解释
- 需要使用数据集训练模型。
- 模型首先被充分洗牌,然后分批处理,然后这些批次可用。
- 这些功能是使用“tf.data”API添加的。
广告