如何使用Tensorflow配置花卉数据集以提高性能?


花卉数据集在创建模型时会给出一定的准确率。如果需要配置模型以提高性能,则定义一个函数,该函数第二次执行缓冲预取,然后进行洗牌。此函数用于训练数据集以提高模型的性能。

阅读更多: 什么是TensorFlow以及Keras如何与TensorFlow一起创建神经网络?

我们将使用花卉数据集,其中包含数千朵花的图像。它包含5个子目录,每个类都有一个子目录。

我们使用Google Colaboratory运行以下代码。Google Colab或Colaboratory帮助通过浏览器运行Python代码,无需任何配置,并可免费访问GPU(图形处理单元)。Colaboratory构建在Jupyter Notebook之上。

print("A function is defined that configures the dataset for perfromance")
def configure_for_performance(ds):
   ds = ds.cache()
   ds = ds.shuffle(buffer_size=1000)
   ds = ds.batch(batch_size)
   ds = ds.prefetch(buffer_size=AUTOTUNE)
   return ds

print("The function is called on training dataset")
train_ds = configure_for_performance(train_ds)
print("The function is called on validation dataset")
val_ds = configure_for_performance(val_ds)

代码来源:https://tensorflowcn.cn/tutorials/load_data/images

输出

A function is defined that configures the dataset for perfromance
The function is called on training dataset
The function is called on validation dataset

解释

  • 需要使用数据集训练模型。
  • 模型首先被充分洗牌,然后分批处理,然后这些批次可用。
  • 这些功能是使用“tf.data”API添加的。

更新于: 2021年2月19日

135次浏览

启动您的职业生涯

通过完成课程获得认证

开始学习
广告