如何在 PyTorch 中通过重塑输入张量来展平它?


可以使用 **torch.flatten()** 方法通过重塑将张量展平为一维张量。此方法支持实值和复值输入张量。它以 torch 张量作为输入,并返回展平为一维的 torch 张量。

它有两个可选参数,**start_dim** 和 **end_dim**。如果传递了这些参数,则仅展平从 start_dim 开始到 end_dim 结束的那些维度。

输入张量中元素的顺序不会改变。此函数可能会返回原始对象、视图或副本。在以下示例中,我们涵盖了使用和不使用 **start_dim** 和 **end_dim** 展平张量的所有方面。

语法

torch.flatten(input, star_dim=0, end_dim=-1)

参数

  • **input** - 要展平的 torch 张量。

  • **start_dim** - 要展平的第一个维度。这是一个可选参数。默认设置为 0。

  • **end_dim** - 要展平的最后一个维度。这是一个可选参数。默认设置为 -1。

步骤

  • 导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。确保您已安装它。

import torch
  • 创建一个 PyTorch 张量并打印该张量。

t = torch.tensor([[[1, 2, 3],
   [4, 5, 6]],
   [[7, 8, 9],
   [10, 11, 12]]])
print("Tensor:
", t)
  • 使用上面定义的语法展平上述张量,并可选地将值赋给一个新变量。

flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
  • 打印展平后的张量。

print("Flattened Tensor:
", flatten_t)

示例 1

在此程序中,我们将张量展平为一维张量。我们还使用 **start_dim** 展平张量。

Import the required library
import torch
# define a torch tensor
t = torch.tensor([[[1, 2, 3],
   [4, 5, 6]],
   [[7, 8, 9],
   [10, 11, 12]]])
print("Tensor:
", t) print("Size of Tensor:", t.size()) # flatten the above tensor using start_dims flatten_t = torch.flatten(t) flatten_t0 = torch.flatten(t, start_dim=0) flatten_t1 = torch.flatten(t, start_dim=1) flatten_t2 = torch.flatten(t, start_dim=2) # print the flatten tensors print("Flatten tensor:
", flatten_t) print("Flatten tensor (start_dim=0):
", flatten_t0) print("Flatten tensor (start_dim=1):
", flatten_t1) print("Flatten tensor (start_dim=2):
", flatten_t2)

输出

Tensor:
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],
      [[ 7, 8, 9],
      [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (start_dim=0):
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (start_dim=1):
   tensor([[ 1, 2, 3, 4, 5, 6],
      [ 7, 8, 9, 10, 11, 12]])
Flatten tensor (start_dim=2):
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],
      [[ 7, 8, 9],
      [10, 11, 12]]])

示例 2

在此程序中,我们将张量展平为一维张量。我们还使用 **end_dim** 展平张量。

import torch
t = torch.tensor([[[1, 2, 3],
   [4, 5, 6]],
   [[7, 8, 9],
   [10, 11, 12]]])
print("Tensor:
", t) print("Size of Tensor:", t.size()) # flatten the above tensor using end_dims flatten_t = torch.flatten(t) flatten_t0 = torch.flatten(t, end_dim=0) flatten_t1 = torch.flatten(t, end_dim=1) flatten_t2 = torch.flatten(t, end_dim=2) # print the flatten tensors print("Flatten tensor:
", flatten_t) print("Flatten tensor (end_dim=0):
", flatten_t0) print("Flatten tensor (end_dim=1):
", flatten_t1) print("Flatten tensor (end_dim=2):
", flatten_t2)

输出

Tensor:
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],

      [[ 7, 8, 9],
      [10, 11, 12]]])
Size of Tensor: torch.Size([2, 2, 3])
Flatten tensor:
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Flatten tensor (end_dim=0):
   tensor([[[ 1, 2, 3],
      [ 4, 5, 6]],

      [[ 7, 8, 9],
      [10, 11, 12]]])
Flatten tensor (end_dim=1):
   tensor([[ 1, 2, 3],
      [ 4, 5, 6],

      [ 7, 8, 9],
      [10, 11, 12]])
Flatten tensor (end_dim=2):
   tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

示例 3

在此程序中,我们将张量展平为一维张量。我们还使用 **start_dim** 和 **end_dim** 展平张量。

import torch
t = torch.empty(2,2,3,3).random_(30)
print("Tensor:
", t) print("Size of Tensor:", t.size()) # flatten the above tensor using end_dims flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3) # print the flatten tensors print("Flatten tensor (start_dim=2,end_dim=3):
", flatten_t0)

输出

Tensor:
   tensor([[[[27., 13., 29.],
      [ 1., 23., 15.],
      [15., 7., 19.]],

      [[ 4., 14., 24.],
      [ 6., 4., 7.],
      [ 6., 18., 11.]]],

      [[[ 0., 27., 3.],
      [25., 12., 25.],
      [10., 23., 9.]],

      [[ 3., 1., 28.],
      [19., 7., 28.],
      [23., 14., 21.]]]])
Size of Tensor: torch.Size([2, 2, 3, 3])
Flatten tensor (start_dim=2,end_dim=3):
   tensor([[[27., 13., 29., 1., 23., 15., 15., 7., 19.],
      [ 4., 14., 24., 6., 4., 7., 6., 18., 11.]],

      [[ 0., 27., 3., 25., 12., 25., 10., 23., 9.],
      [ 3., 1., 28., 19., 7., 28., 23., 14., 21.]]])

更新于: 2022年1月20日

6K+ 阅读量

开启您的 职业生涯

通过完成课程获得认证

立即开始
广告

© . All rights reserved.