如何在 PyTorch 中逐元素应用修正线性单元函数?
要对输入张量逐元素应用修正线性单元 (ReLU) 函数,我们使用 **torch.nn.ReLU()**。它将输入张量中所有负元素替换为 0(零),所有非负元素保持不变。它仅支持实值输入张量。**ReLU** 用作神经网络中的激活函数。
语法
relu = torch.nn.ReLU() output = relu(input)
步骤
您可以使用以下步骤逐元素应用修正线性单元 (ReLU) 函数:
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。确保您已安装它。
import torch import torch.nn as nn
定义 **输入** 张量并打印它。
input = torch.randn(2,3) print("Input Tensor:
",input)
使用 **torch.nn.ReLU()** 定义 ReLU 函数 **relu**。
relu = torch.nn.ReLU()
将上面定义的 ReLU 函数 **relu** 应用于输入张量。并可以选择将输出分配给一个新变量
output = relu(input)
打印包含 ReLU 函数值的张量。
print("ReLU Tensor:
",output)
让我们看几个例子,以便更好地理解它的工作原理。
示例 1
# Import the required library import torch import torch.nn as nn relu = torch.nn.ReLU() input = torch.tensor([[-1., 8., 1., 13., 9.], [ 0., 1., 0., 5., -5.], [ 3., -5., 8., -1., 5.], [ 0., 3., -1., 13., 12.]]) print("Input Tensor:
",input) print("Size of Input Tensor:
",input.size()) # Compute the rectified linear unit (ReLU) function element-wise output = relu(input) print("ReLU Tensor:
",output) print("Size of ReLU Tensor:
",output.size())
输出
Input Tensor: tensor([[-1., 8., 1., 13., 9.], [ 0., 1., 0., 5., -5.], [ 3., -5., 8., -1., 5.], [ 0., 3., -1., 13., 12.]]) Size of Input Tensor: torch.Size([4, 5]) ReLU Tensor: tensor([[ 0., 8., 1., 13., 9.], [ 0., 1., 0., 5., 0.], [ 3., 0., 8., 0., 5.], [ 0., 3., 0., 13., 12.]]) Size of ReLU Tensor: torch.Size([4, 5])
在上面的示例中,请注意输出张量中输入张量的负元素被替换为零。
示例 2
# Import the required library import torch import torch.nn as nn relu = torch.nn.ReLU(inplace=True) input = torch.randn(4,5) print("Input Tensor:
",input) print("Size of Input Tensor:
",input.size()) # Compute the rectified linear unit (ReLU) function element-wise output = relu(input) print("ReLU Tensor:
",output) print("Size of ReLU Tensor:
",output.size())
输出
Input Tensor: tensor([[ 0.4217, 0.4151, 1.3292, -1.3835, -0.0086], [-0.7693, -1.7736, -0.3401, -0.7179, -0.0196], [ 1.0918, -0.9426, 2.1496, -0.4809, -1.2254], [-0.3198, -0.2231, 1.2043, 1.1222, 0.7905]]) Size of Input Tensor: torch.Size([4, 5]) ReLU Tensor: tensor([[0.4217, 0.4151, 1.3292, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [1.0918, 0.0000, 2.1496, 0.0000, 0.0000], [0.0000, 0.0000, 1.2043, 1.1222, 0.7905]]) Size of ReLU Tensor: torch.Size([4, 5])
广告