- PyTorch 教程
- PyTorch - 主页
- PyTorch - 介绍
- PyTorch - 安装
- 神经网络的数学基础知识
- PyTorch - 神经网络基础
- 机器学习的通用工作流程
- 机器学习与深度学习
- 实现第一个神经网络
- 将神经网络应用到功能模块
- PyTorch - 术语
- PyTorch - 加载数据
- PyTorch - 线性回归
- PyTorch - 卷积神经网络
- PyTorch - 循环神经网络
- PyTorch - 数据集
- PyTorch - 卷积神经网络简介
- 从头开始训练卷积神经网络
- PyTorch - 卷积神经网络中的特征提取
- PyTorch - 卷积神经网络可视化
- 使用卷积神经网络处理序列
- PyTorch - 词嵌入
- PyTorch - 递归神经网络
- PyTorch 有用资源
- PyTorch - 快速指南
- PyTorch - 有用资源
- PyTorch - 讨论
PyTorch - 从头开始训练卷积神经网络
在本章中,我们将重点介绍从头开始创建卷积神经网络。这包括使用 torch 创建相应的卷积神经网络或样本神经网络。
步骤 1
使用相应参数创建必要的类。这些参数包括具有随机值的权重。
class Neural_Network(nn.Module): def __init__(self, ): super(Neural_Network, self).__init__() self.inputSize = 2 self.outputSize = 1 self.hiddenSize = 3 # weights self.W1 = torch.randn(self.inputSize, self.hiddenSize) # 3 X 2 tensor self.W2 = torch.randn(self.hiddenSize, self.outputSize) # 3 X 1 tensor
步骤 2
创建带有 sigmoid 函数的前馈模式函数。
def forward(self, X): self.z = torch.matmul(X, self.W1) # 3 X 3 ".dot" does not broadcast in PyTorch self.z2 = self.sigmoid(self.z) # activation function self.z3 = torch.matmul(self.z2, self.W2) o = self.sigmoid(self.z3) # final activation function return o def sigmoid(self, s): return 1 / (1 + torch.exp(-s)) def sigmoidPrime(self, s): # derivative of sigmoid return s * (1 - s) def backward(self, X, y, o): self.o_error = y - o # error in output self.o_delta = self.o_error * self.sigmoidPrime(o) # derivative of sig to error self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2)) self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2) self.W1 + = torch.matmul(torch.t(X), self.z2_delta) self.W2 + = torch.matmul(torch.t(self.z2), self.o_delta)
步骤 3
创建如下所示的训练和预测模型 -
def train(self, X, y): # forward + backward pass for training o = self.forward(X) self.backward(X, y, o) def saveWeights(self, model): # Implement PyTorch internal storage functions torch.save(model, "NN") # you can reload model with all the weights and so forth with: # torch.load("NN") def predict(self): print ("Predicted data based on trained weights: ") print ("Input (scaled): \n" + str(xPredicted)) print ("Output: \n" + str(self.forward(xPredicted)))
广告