如何在 PyTorch 中归一化张量?
在 TensorFlow 中,PyTorch 的 张量 可以使用 **torch.nn.functional** 模块提供的 **normalize()** 函数进行归一化。这是一个非线性激活函数。
它对给定张量在指定维度上执行 **Lp 归一化**。
它返回原始张量元素的归一化值张量。
一维张量可以在维度 0 上归一化,而二维张量可以在维度 0 和 1 上归一化,即列方向或行方向。
n 维张量可以在维度 (0, 1, 2,..., n-1) 上归一化。
语法
torch.nn.functional.normalize(input, p=2.0, dim = 1)
参数
**输入** – 输入张量
**p** – 范数公式中的幂 (指数) 值
**dim** – 对其元素进行归一化的维度。
步骤
我们可以使用以下步骤来归一化张量:
导入 **torch** 库。确保你已安装它。
import torch from torch.nn.functional import normalize
创建一个张量并打印它。
t = torch.tensor([[1.,2.,3.],[4.,5.,6.]])
print("Tensor:", t)使用不同的 p 值和不同的维度对张量进行归一化。上面定义的张量是一个二维张量,因此我们可以对其进行两个维度的归一化。
t1 = normalize(t, p=1.0, dim = 1) t2 = normalize(t, p=2.0, dim = 0)
打印上面计算出的归一化张量。
print("Normalized tensor:
", t1)
print("Normalized tensor:
", t2)示例 1
# import torch library
import torch
from torch.nn.functional import normalize
# define a torch tensor
t = torch.tensor([1., 2., 3., -2., -5.])
# print the above tensor
print("Tensor:
", t)
# normalize the tensor
t1 = normalize(t, p=1.0, dim = 0)
t2 = normalize(t, p=2.0, dim = 0)
# print normalized tensor
print("Normalized tensor with p=1:
", t1)
print("Normalized tensor with p=2:
", t2)输出
Tensor: tensor([ 1., 2., 3., -2., -5.]) Normalized tensor with p=1: tensor([ 0.0769, 0.1538, 0.2308, -0.1538, -0.3846]) Normalized tensor with p=2: tensor([ 0.1525, 0.3050, 0.4575, -0.3050, -0.7625])
示例 2
# import torch library
import torch
from torch.nn.functional import normalize
# define a 2D tensor
t = torch.tensor([[1.,2.,3.],[4.,5.,6.]])
# print the above tensor
print("Tensor:
", t)
# normalize the tensor
t0 = normalize(t, p=2.0)
# print the normalized tensor
print("Normalized tensor:
", t0)
# normalize the tensor in dim 0 or column-wise
tc = normalize(t, p=2.0, dim = 0)
# print the normalized tensor
print("Column-wise Normalized tensor:
", tc)
# normalize the tensor in dim 1 or row-wise
tr = normalize(t, p=2.0, dim = 1)
# print the normalized tensor
print("Row-wise Normalized tensor:
", tr)输出
Tensor: tensor([[1., 2., 3.], [4., 5., 6.]]) Normalized tensor: tensor([[0.2673, 0.5345, 0.8018], [0.4558, 0.5698, 0.6838]]) Column-wise Normalized tensor: tensor([[0.2425, 0.3714, 0.4472], [0.9701, 0.9285, 0.8944]]) Row-wise Normalized tensor: tensor([[0.2673, 0.5345, 0.8018], [0.4558, 0.5698, 0.6838]])
广告
数据结构
网络
关系型数据库管理系统 (RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP