PyTorch 中的 backward() 有什么作用?


backward() 方法用于在神经网络的反向传播中计算梯度。

  • 执行此方法时将计算梯度。

  • 这些梯度将存储在相应的变量中。

  • 梯度相对于这些变量计算,而梯度可通过 .grad 进行访问。

  • 如果不调用 backward() 方法来计算梯度,则不会计算梯度。

  • 如果我们使用 .grad 访问梯度,则结果为

我们举几个例子来说明它的工作原理。

示例 1

在此示例中,我们尝试在不调用 backward() 方法的情况下访问梯度。我们注意到所有的梯度都是

# import torch library
import torch

# define three tensor
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function of the above defined tensors
y = w * x + b
print("y:", y)

# print the gradient w.r.t above tensors
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)

输出

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: None
b.grad: None

示例 2

在第二个示例中,调用了函数 ybackward() 方法。然后,访问了梯度。对于不需要grad的张量,相对于它们的梯度仍然是。但对于需要梯度的张量,相对于它们的梯度并非无。

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# take the backward() for y
y.backward()
# print the gradients w.r.t. above x, w, and b
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)

输出

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: tensor(2.)
b.grad: tensor(1.)

更新于: 06-12-2021

3K+ 次浏览

开始你的 职业生涯

完成课程获得认证

开始学习
广告
© . All rights reserved.