如何使用TensorFlow保存和加载MNIST数据集的权重?


TensorFlow是由Google提供的机器学习框架。它是一个开源框架,与Python结合使用,可以实现算法、深度学习应用程序等等。它被用于研究和生产环境。它具有优化技术,有助于快速执行复杂的数学运算。这是因为它使用了NumPy和多维数组。这些多维数组也被称为“张量”。

可以使用以下代码行在Windows上安装“tensorflow”包:

pip install tensorflow

张量是TensorFlow中使用的数据结构。它有助于连接数据流图中的边。这个数据流图被称为“数据流图”。张量只不过是多维数组或列表。

当训练时间过长时,模型往往会过拟合,并且在测试数据上的泛化能力较差。因此,训练步骤的数量必须保持平衡。这意味着,必须考虑所有数据案例才能进行有效的训练。这样,模型在测试数据上的泛化能力更好。否则,可以进行正则化。

Keras是一个用Python编写的深度学习API。它是一个高级API,具有高效的接口,可以帮助解决机器学习问题。它运行在TensorFlow框架之上。它的构建是为了帮助快速实验。它提供了开发和封装机器学习解决方案所必需的基本抽象和构建块。

Keras已经存在于TensorFlow包中。可以使用以下代码行访问它。

import tensorflow
from tensorflow import keras

我们使用Google Colaboratory运行以下代码。Google Colab或Colaboratory有助于在浏览器上运行Python代码,无需任何配置,并且可以免费访问GPU(图形处理单元)。Colaboratory构建在Jupyter Notebook之上。以下是代码片段:

示例

!pip install -q pyyaml h5py
import os

import tensorflow as tf
from tensorflow import keras

print("The version of Tensorflow is : ")
print(tf.version.VERSION)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
print("Splitting training and test data")
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

print("Reshaping the training and test data")
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

代码来源: https://tensorflowcn.cn/tutorials/keras/save_and_load

输出

解释

  • 导入所需的包并为其设置别名。

  • 获取前1000个示例以提高执行速度。

更新于:2021年1月20日

浏览量:122

启动您的职业生涯

完成课程获得认证

开始学习
广告