PyTorch中的Tensor.detach()的作用是什么?


Tensor.detach()用于从当前计算图中分离张量。它将返回一个不需要梯度的张量。

  • 当不需要跟踪张量进行梯度计算时,我们会将张量从当前计算图中分离出来。

  • 当我们需要将张量从GPU传输到CPU时,我们也需要分离一个张量。

语法

Tensor.detach()

它将返回一个新的张量,且requires_grad = True。将不再计算与此张量有关的梯度。

步骤

  • 导入torch库。确保你已安装该库。

import torch
  • 使用requires_grad = True创建PyTorch张量并打印张量。

x = torch.tensor(2.0, requires_grad = True)
print("x:", x)
  • 计算Tensor.detach()并选择性地将此值赋给新变量。

x_detach = x.detach()
  • 在执行.detach()操作后打印张量。

print("Tensor with detach:", x_detach)

样例1

# import torch library
import torch

# create a tensor with requires_gradient=true
x = torch.tensor(2.0, requires_grad = True)

# print the tensor
print("Tensor:", x)

# tensor.detach operation
x_detach = x.detach()
print("Tensor with detach:", x_detach)

输出

Tensor: tensor(2., requires_grad=True)
Tensor with detach: tensor(2.)

请注意,在以上输出中,detach后的张量没有requires_grad = True

样例2

# import torch library
import torch

# define a tensor with requires_grad=true
x = torch.rand(3, requires_grad = True)
print("x:", x)

# apply above tensor to use detach()
y = 3 + x
z = 3 * x.detach()

print("y:", y)
print("z:", z)

输出

x: tensor([0.5656, 0.8402, 0.6661], requires_grad=True)
y: tensor([3.5656, 3.8402, 3.6661], grad_fn=<AddBackward0>)
z: tensor([1.6968, 2.5207, 1.9984])

更新于: 06-Dec-2021

已浏览超过12K次

开启您的 职业

完成课程获得认证

开始吧
广告