如何在 PyTorch 中对张量挤压和展开?
要挤压一个张量,我们使用 **torch.squeeze()** 方法。它返回一个新张量,该张量包含输入张量的所有维度,但会移除大小 1。例如,如果输入张量的形状为 (M ☓ 1 ☓ N ☓ 1 ☓ P),则挤压后的张量形状为 (M ☓ M ☓ P)。
要展开一个张量,我们使用 **torch.unsqueeze()** 方法。它返回一个新张量,在特定位置插入大小为 1 的新维度。
步骤
导入所需的库。在以下所有 Python 示例中,所需的 Python 库为 **torch**。确保你已经安装它。
创建一个张量并打印它。
计算 **torch.squeeze(input)**。它将挤压(移除)大小 1,并返回一个包含 **input** 张量所有其他维度的张量。
计算 **torch.unsqueeze(input, dim)**。它在给定的 dim 处插入大小为 1 的新维度,并返回该张量。
打印挤压和/或展开的张量。
示例 1
# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch
# Create a tensor of all one
T = torch.ones(2,1,2) # size 2x1x2
print("Original Tensor T:\n", T )
print("Size of T:", T.size())
# Squeeze the dimension of the tensor
squeezed_T = torch.squeeze(T) # now size 2x2
print("Squeezed_T\n:", squeezed_T )
print("Size of Squeezed_T:", squeezed_T.size())输出
Original Tensor T: tensor([[[1., 1.]], [[1., 1.]]]) Size of T: torch.Size([2, 1, 2]) Squeezed_T : tensor([[1., 1.], [1., 1.]]) Size of Squeezed_T: torch.Size([2, 2])
示例 2
# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch
# create a tensor
T = torch.Tensor([1,2,3]) # size 3
print("Original Tensor T:\n", T )
print("Size of T:", T.size())
# Squeeze the tensor in dimension o or column dim
unsqueezed_T = torch.unsqueeze(T, dim = 0) # now size 1x3
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of UnSqueezed T:", unsqueezed_T.size())
# Squeeze the tensor in dimension 1 or row dim
unsqueezed_T = torch.unsqueeze(T, dim = 1) # now size 3x1
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of Unsqueezed T:", unsqueezed_T.size())输出
Original Tensor T: tensor([1., 2., 3.]) Size of T: torch.Size([3]) Unsqueezed T : tensor([[1., 2., 3.]]) Size of UnSqueezed T: torch.Size([1, 3]) Unsqueezed T : tensor([[1.], [2.], [3.]]) Size of Unsqueezed T: torch.Size([3, 1])
广告
数据结构
网络
RDBMS
操作系统
Java
iOS
HTML
CSS
安卓
Python
C 编程
C++
C#
MongoDB
MySQL
Javascript
PHP