如何使用 Python 和 Estimator 编译 TensorFlow 模型?


可以使用 Estimator 和 ‘train’ 方法来编译 TensorFlow 模型。

阅读更多: 什么是 TensorFlow,以及 Keras 如何与 TensorFlow 协作创建神经网络?

我们将使用 Keras Sequential API,它有助于构建顺序模型,用于处理简单的层堆栈,其中每一层只有一个输入张量和一个输出张量。

包含至少一层卷积层的神经网络称为卷积神经网络。我们可以使用卷积神经网络构建学习模型。

TensorFlow Text 包含一系列与文本相关的类和运算符,可用于 TensorFlow 2.0。TensorFlow Text 可用于预处理序列建模。

我们使用 Google Colaboratory 来运行以下代码。Google Colab 或 Colaboratory 帮助在浏览器上运行 Python 代码,无需任何配置,并可免费访问 GPU(图形处理单元)。Colaboratory 基于 Jupyter Notebook 构建。

Estimator 是 TensorFlow 中对完整模型的高级表示。它设计用于轻松扩展和异步训练。

该模型使用虹膜数据集进行训练。共有 4 个特征和一个标签。

  • 萼片长度
  • 萼片宽度
  • 花瓣长度
  • 花瓣宽度

示例

print("The model is being trained")
classifier.train(input_fn=lambda: input_fn(train, train_y, training=True), steps=5000)

代码来源 −https://tensorflowcn.cn/tutorials/estimator/premade#first_things_first

输出

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Layer dnn is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because its dtype defaults to floatx.
If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.
To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/adagrad.py:83: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpbhg2uvbr/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.1140382, step = 0
INFO:tensorflow:global_step/sec: 312.415
INFO:tensorflow:loss = 0.8781501, step = 100 (0.321 sec)
INFO:tensorflow:global_step/sec: 375.535
INFO:tensorflow:loss = 0.80712265, step = 200 (0.266 sec)
INFO:tensorflow:global_step/sec: 372.712
INFO:tensorflow:loss = 0.7615077, step = 300 (0.268 sec)
INFO:tensorflow:global_step/sec: 368.782
INFO:tensorflow:loss = 0.733555, step = 400 (0.271 sec)
INFO:tensorflow:global_step/sec: 372.689
INFO:tensorflow:loss = 0.6983943, step = 500 (0.268 sec)
INFO:tensorflow:global_step/sec: 370.308
INFO:tensorflow:loss = 0.67940104, step = 600 (0.270 sec)
INFO:tensorflow:global_step/sec: 373.374
INFO:tensorflow:loss = 0.65386146, step = 700 (0.268 sec)
INFO:tensorflow:global_step/sec: 368.335
INFO:tensorflow:loss = 0.63730353, step = 800 (0.272 sec)
INFO:tensorflow:global_step/sec: 371.575
INFO:tensorflow:loss = 0.61313766, step = 900 (0.269 sec)
INFO:tensorflow:global_step/sec: 371.975
INFO:tensorflow:loss = 0.6123625, step = 1000 (0.269 sec)
INFO:tensorflow:global_step/sec: 369.615
INFO:tensorflow:loss = 0.5957534, step = 1100 (0.270 sec)
INFO:tensorflow:global_step/sec: 374.054
INFO:tensorflow:loss = 0.57203, step = 1200 (0.267 sec)
INFO:tensorflow:global_step/sec: 369.713
INFO:tensorflow:loss = 0.56556034, step = 1300 (0.270 sec)
INFO:tensorflow:global_step/sec: 366.202
INFO:tensorflow:loss = 0.547443, step = 1400 (0.273 sec)
INFO:tensorflow:global_step/sec: 361.407
INFO:tensorflow:loss = 0.53326523, step = 1500 (0.277 sec)
INFO:tensorflow:global_step/sec: 367.461
INFO:tensorflow:loss = 0.51837724, step = 1600 (0.272 sec)
INFO:tensorflow:global_step/sec: 364.181
INFO:tensorflow:loss = 0.5281174, step = 1700 (0.275 sec)
INFO:tensorflow:global_step/sec: 368.139
INFO:tensorflow:loss = 0.5139683, step = 1800 (0.271 sec)
INFO:tensorflow:global_step/sec: 366.277
INFO:tensorflow:loss = 0.51073176, step = 1900 (0.273 sec)
INFO:tensorflow:global_step/sec: 366.634
INFO:tensorflow:loss = 0.4949246, step = 2000 (0.273 sec)
INFO:tensorflow:global_step/sec: 364.732
INFO:tensorflow:loss = 0.49381495, step = 2100 (0.274 sec)
INFO:tensorflow:global_step/sec: 365.006
INFO:tensorflow:loss = 0.48916715, step = 2200 (0.274 sec)
INFO:tensorflow:global_step/sec: 366.902
INFO:tensorflow:loss = 0.48790723, step = 2300 (0.273 sec)
INFO:tensorflow:global_step/sec: 362.232
INFO:tensorflow:loss = 0.47671652, step = 2400 (0.276 sec)
INFO:tensorflow:global_step/sec: 368.592
INFO:tensorflow:loss = 0.47324088, step = 2500 (0.271 sec)
INFO:tensorflow:global_step/sec: 371.611
INFO:tensorflow:loss = 0.46822113, step = 2600 (0.269 sec)
INFO:tensorflow:global_step/sec: 362.345
INFO:tensorflow:loss = 0.4621966, step = 2700 (0.276 sec)
INFO:tensorflow:global_step/sec: 362.788
INFO:tensorflow:loss = 0.47817266, step = 2800 (0.275 sec)
INFO:tensorflow:global_step/sec: 368.473
INFO:tensorflow:loss = 0.45853442, step = 2900 (0.271 sec)
INFO:tensorflow:global_step/sec: 360.944
INFO:tensorflow:loss = 0.44062576, step = 3000 (0.277 sec)
INFO:tensorflow:global_step/sec: 370.982
INFO:tensorflow:loss = 0.4331399, step = 3100 (0.269 sec)
INFO:tensorflow:global_step/sec: 366.248
INFO:tensorflow:loss = 0.45120597, step = 3200 (0.273 sec)
INFO:tensorflow:global_step/sec: 371.703
INFO:tensorflow:loss = 0.4403404, step = 3300 (0.269 sec)
INFO:tensorflow:global_step/sec: 362.176
INFO:tensorflow:loss = 0.42405623, step = 3400 (0.276 sec)
INFO:tensorflow:global_step/sec: 363.283
INFO:tensorflow:loss = 0.41672814, step = 3500 (0.275 sec)
INFO:tensorflow:global_step/sec: 363.529
INFO:tensorflow:loss = 0.42626005, step = 3600 (0.275 sec)
INFO:tensorflow:global_step/sec: 367.348
INFO:tensorflow:loss = 0.4089098, step = 3700 (0.272 sec)
INFO:tensorflow:global_step/sec: 363.067
INFO:tensorflow:loss = 0.41276374, step = 3800 (0.275 sec)
INFO:tensorflow:global_step/sec: 364.771
INFO:tensorflow:loss = 0.4112524, step = 3900 (0.274 sec)
INFO:tensorflow:global_step/sec: 363.167
INFO:tensorflow:loss = 0.39261794, step = 4000 (0.275 sec)
INFO:tensorflow:global_step/sec: 362.082
INFO:tensorflow:loss = 0.41160905, step = 4100 (0.276 sec)
INFO:tensorflow:global_step/sec: 364.979
INFO:tensorflow:loss = 0.39620766, step = 4200 (0.274 sec)
INFO:tensorflow:global_step/sec: 363.323
INFO:tensorflow:loss = 0.39696264, step = 4300 (0.275 sec)
INFO:tensorflow:global_step/sec: 361.25
INFO:tensorflow:loss = 0.38196522, step = 4400 (0.277 sec)
INFO:tensorflow:global_step/sec: 365.666
INFO:tensorflow:loss = 0.38667366, step = 4500 (0.274 sec)
INFO:tensorflow:global_step/sec: 361.202
INFO:tensorflow:loss = 0.38149032, step = 4600 (0.277 sec)
INFO:tensorflow:global_step/sec: 365.038
INFO:tensorflow:loss = 0.37832782, step = 4700 (0.274 sec)
INFO:tensorflow:global_step/sec: 366.375
INFO:tensorflow:loss = 0.3726803, step = 4800 (0.273 sec)
INFO:tensorflow:global_step/sec: 366.474
INFO:tensorflow:loss = 0.37167495, step = 4900 (0.273 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000...
INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpbhg2uvbr/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000...
INFO:tensorflow:Loss for final step: 0.36297452.
<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7fc9983ed470>

解释

  • 创建 Estimator 对象后,可以调用以下方法:
  • 训练模型。
  • 评估训练后的模型。
  • 使用此模型进行预测。
  • 再次训练模型。
  • 这是通过调用 Estimator 的 train 方法来完成的。

更新于: 2021年2月22日

93 次查看

启动您的职业生涯

完成课程并获得认证

开始学习
广告