如何在 PyTorch 中执行扩展操作?
Tensor.expand() 属性用于执行扩展操作。它沿着单例维度将张量扩展到新的维度。
扩展张量只会创建一个原始张量的新视图;它不会复制原始张量。
如果将特定维度设置为 -1,则不会沿着此维度扩展张量。
例如,如果我们有一个大小为 (3,1) 的张量,我们可以沿着大小为 1 的维度扩展此张量。
步骤
要扩展张量,可以按照以下步骤操作:
导入 torch 库。确保您已安装它。
import torch
定义至少有一个维度为单例的张量。
t = torch.tensor([[1],[2],[3]])
沿着单例维度扩展张量。沿着非单例维度扩展将引发运行时错误(参见示例 3)。
t_exp = t.expand(3,2)
显示扩展后的张量。
print("Tensor after expand:
", t_exp)
示例 1
以下 Python 程序演示了如何将大小为 (3,1) 的张量扩展为大小为 (3,2) 的张量。它沿着大小为 1 的维度扩展张量。大小为 3 的另一个维度保持不变。
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1],[2],[3]])
# display the tensor
print("Tensor:
", t)
print("Size of Tensor:
", t.size())
# expand the tensor
exp = t.expand(3,2)
print("Tensor after expansion:
", exp)输出
Tensor:
tensor([[1],
[2],
[3]])
Size of Tensor:
torch.Size([3, 1])
Tensor after expansion:
tensor([[1, 1],
[2, 2],
[3, 3]])示例 2
以下 Python 程序将大小为 (1,3) 的张量扩展为大小为 (3,3) 的张量。它沿着大小为 1 的维度扩展张量。
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1,2,3]])
# display the tensor
print("Tensor:
", t)
# size of tensor is [1,3]
print("Size of Tensor:
", t.size())
# expand the tensor
expandedTensor = t.expand(3,-1)
print("Expanded Tensor:
", expandedTensor)
print("Size of expanded tensor:
", expandedTensor.size())输出
Tensor:
tensor([[1, 2, 3]])
Size of Tensor:
torch.Size([1, 3])
Expanded Tensor:
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
Size of expanded tensor:
torch.Size([3, 3])示例 3
在以下 Python 程序中,我们尝试沿着非单例维度扩展张量,因此它引发了运行时错误。
# import required libraries
import torch
# create a tensor
t = torch.tensor([[1,2,3]])
# display the tensor
print("Tensor:
", t)
# size of tensor is [1,3]
print("Size of Tensor:
", t.size())
t.expand(3,4)输出
Tensor: tensor([[1, 2, 3]]) Size of Tensor: torch.Size([1, 3]) RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 1. Target sizes: [3, 4]. Tensor sizes: [1, 3]
广告
数据结构
网络
关系型数据库管理系统
操作系统
Java
iOS
HTML
CSS
Android
Python
C 编程
C++
C#
MongoDB
MySQL
Javascript
PHP