如何使用Python编译Tensorflow模型?


在Tensorflow中创建的模型可以使用`compile`方法进行编译。损失函数使用`SparseCategoricalCrossentropy`方法计算。

了解更多: 什么是TensorFlow以及Keras如何与TensorFlow一起创建神经网络?

我们使用Google Colaboratory运行以下代码。Google Colab或Colaboratory帮助在浏览器上运行Python代码,无需任何配置,并可免费访问GPU(图形处理单元)。Colaboratory构建在Jupyter Notebook之上。

print("The model is being compiled")
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
   metrics=['accuracy'])
print("The architecture of the model")
model.summary()

代码来源:https://tensorflowcn.cn/tutorials/images/classification

输出

The model is being compiled
The architecture of the model
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #  
=================================================================
rescaling_1 (Rescaling)      (None, 180, 180, 3)       0        
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 180, 180, 16)      448      
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 90, 90, 16)        0        
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 90, 90, 32)        4640    
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 45, 45, 32)        0        
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 45, 45, 64)        18496    
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 22, 22, 64)        0        
_________________________________________________________________
flatten_1 (Flatten)          (None, 30976)             0        
_________________________________________________________________
dense_2 (Dense)              (None, 128)               3965056  
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 645      
=================================================================
Total params: 3,989,285
Trainable params: 3,989,285
Non-trainable params: 0
_________________________________________________________________

解释

  • 使用了`optimizers.Adam`优化器和`losses.SparseCategoricalCrossentropy`损失函数。
  • 可以通过传递`metrics`参数来查看每个训练周期的训练和验证准确率。
  • 编译模型后,可以使用`summary`方法显示模型架构的摘要。

更新于: 2021年2月20日

204 次浏览

启动你的职业生涯

完成课程获得认证

开始学习
广告
© . All rights reserved.