如何在 PyTorch 中连接张量?


我们可以使用 **torch.cat()** 和 **torch.stack()** 连接两个或多个张量。**torch.cat()** 用于连接两个或多个张量,而 **torch.stack()** 用于堆叠张量。我们可以在不同的维度上连接张量,例如 0 维度、-1 维度。

**torch.cat()** 和 **torch.stack()** 都用于连接张量。那么,这两种方法的基本区别是什么呢?

  • **torch.cat()** 沿着现有维度连接一系列张量,因此不会改变张量的维度。

  • **torch.stack()** 沿着新维度堆叠张量,因此会增加维度。

步骤

  • 导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。请确保您已安装它。

  • 创建两个或多个 PyTorch 张量并打印它们。

  • 使用 **torch.cat()** 或 **torch.stack()** 连接上面创建的张量。提供维度,例如 0、-1,以在特定维度上连接张量。

  • 最后,打印连接或堆叠的张量。

示例 1

# Python program to join tensors in PyTorch # import necessary library import torch # create tensors T1 = torch.Tensor([1,2,3,4]) T2 = torch.Tensor([0,3,4,1]) T3 = torch.Tensor([4,3,2,5]) # print above created tensors print("T1:", T1) print("T2:", T2) print("T3:", T3) # join (concatenate) above tensors using torch.cat() T = torch.cat((T1,T2,T3)) # print final tensor after concatenation print("T:",T)

Explore our latest online courses and learn new skills at your own pace. Enroll and become a certified expert to boost your career.

输出

运行以上 Python 3 代码时,将产生以下输出

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
T: tensor([1., 2., 3., 4., 0., 3., 4., 1., 4., 3., 2., 5.])

示例 2

# import necessary library import torch # create tensors T1 = torch.Tensor([[1,2],[3,4]]) T2 = torch.Tensor([[0,3],[4,1]]) T3 = torch.Tensor([[4,3],[2,5]]) # print above created tensors print("T1:\n", T1) print("T2:\n", T2) print("T3:\n", T3) print("join(concatenate) tensors in the 0 dimension") T = torch.cat((T1,T2,T3), 0) print("T:\n", T) print("join(concatenate) tensors in the -1 dimension") T = torch.cat((T1,T2,T3), -1) print("T:\n", T)

输出

运行以上 Python 3 代码时,将产生以下输出

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
join(concatenate) tensors in the 0 dimension
T:
tensor([[1., 2.],
         [3., 4.],
         [0., 3.],
         [4., 1.],
         [4., 3.],
         [2., 5.]])
join(concatenate) tensors in the -1 dimension
T:
tensor([[1., 2., 0., 3., 4., 3.],
         [3., 4., 4., 1., 2., 5.]])

在以上示例中,2D 张量沿 0 和 -1 维度连接。沿 0 维度连接会增加行数,而列数保持不变。

示例 3

# Python program to join tensors in PyTorch # import necessary library import torch # create tensors T1 = torch.Tensor([1,2,3,4]) T2 = torch.Tensor([0,3,4,1]) T3 = torch.Tensor([4,3,2,5]) # print above created tensors print("T1:", T1) print("T2:", T2) print("T3:", T3) # join above tensor using "torch.stack()" print("join(stack) tensors") T = torch.stack((T1,T2,T3)) # print final tensor after join print("T:\n",T) print("join(stack) tensors in the 0 dimension") T = torch.stack((T1,T2,T3), 0) print("T:\n", T) print("join(stack) tensors in the -1 dimension") T = torch.stack((T1,T2,T3), -1) print("T:\n", T)

输出

运行以上 Python 3 代码时,将产生以下输出

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
join(stack) tensors
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
join(stack) tensors in the 0 dimension
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
join(stack) tensors in the -1 dimension
T:
tensor([[1., 0., 4.],
         [2., 3., 3.],
         [3., 4., 2.],
         [4., 1., 5.]])

在以上示例中,您可以注意到 1D 张量被堆叠,最终张量为 2D 张量。

示例 4

# import necessary library import torch # create tensors T1 = torch.Tensor([[1,2],[3,4]]) T2 = torch.Tensor([[0,3],[4,1]]) T3 = torch.Tensor([[4,3],[2,5]]) # print above created tensors print("T1:\n", T1) print("T2:\n", T2) print("T3:\n", T3) print("Join (stack)tensors in the 0 dimension") T = torch.stack((T1,T2,T3), 0) print("T:\n", T) print("Join(stack) tensors in the -1 dimension") T = torch.stack((T1,T2,T3), -1) print("T:\n", T)

输出

运行以上 Python 3 代码时,将产生以下输出。

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
Join (stack)tensors in the 0 dimension
T:
tensor([[[1., 2.],
         [3., 4.]],
         [[0., 3.],
         [4., 1.]],
         [[4., 3.],
         [2., 5.]]])
Join(stack) tensors in the -1 dimension
T:
tensor([[[1., 0., 4.],
         [2., 3., 3.]],
         [[3., 4., 2.],
         [4., 1., 5.]]])

在以上示例中,您可以注意到 2D 张量被连接(堆叠)以创建 3D 张量。

更新于: 2023年9月14日

32K+ 浏览量

开启您的 职业生涯

通过完成课程获得认证

开始学习
广告