PyTorch 中的“with torch no_grad”的作用是什么?


“with torch.no_grad()” 的使用就像一个循环,其中循环内的每一个张量都将 requires_grad 设置为 False。这意味着当前与当前计算图关联的任何带梯度的张量现在都从当前图中分离出来。我们不再能够计算关于此张量的梯度。

一个张量一直从当前图中分离,直到它在循环中。一旦它脱离了循环,如果使用梯度定义了张量,就会再次将它附加到当前图中。

我们来举几个例子,以更好地理解它是如何工作的。

示例 1

在这个示例中,我们创建了一个张量 x,其 requires_grad = true。接下来,我们定义这个张量 x 的函数 y,并将函数置于 torch.no_grad() 循环中。现在 x 在循环中,所以它的 requires_grad 被设置为 False

在循环中,不能针对 x 计算 y 的梯度。所以,y.requires_grad 返回 False

# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad = True)
print("x:", x)

# define a function y
with torch.no_grad():
   y = x ** 2
print("y:", y)

# check gradient for Y
print("y.requires_grad:", y.requires_grad)

输出

x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False

示例 2

在此示例中,我们在循环外定义了函数 z。所以,z.requires_grad 返回 True

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)

print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# define a function z
with torch.no_grad():
   z = w * x + b

print("z:", z)

# check if requires grad is true or not
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)

输出

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False

更新日期:06-12-2021

6K+ 浏览

启动你的 职业生涯

完成课程并获得认证

开始
广告