如何在 PyTorch 中通过重塑输入张量来展平它?
可以使用 **torch.flatten()** 方法通过重塑将张量展平为一维张量。此方法支持实值和复值输入张量。它以 torch 张量作为输入,并返回展平为一维的 torch 张量。
它有两个可选参数,**start_dim** 和 **end_dim**。如果传递了这些参数,则仅展平从 start_dim 开始到 end_dim 结束的那些维度。
输入张量中元素的顺序不会改变。此函数可能会返回原始对象、视图或副本。在以下示例中,我们涵盖了使用和不使用 **start_dim** 和 **end_dim** 展平张量的所有方面。
语法
torch.flatten(input, star_dim=0, end_dim=-1)
参数
**input** - 要展平的 torch 张量。
**start_dim** - 要展平的第一个维度。这是一个可选参数。默认设置为 0。
**end_dim** - 要展平的最后一个维度。这是一个可选参数。默认设置为 -1。
步骤
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。确保您已安装它。
import torch
创建一个 PyTorch 张量并打印该张量。
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:
", t)使用上面定义的语法展平上述张量,并可选地将值赋给一个新变量。
flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
打印展平后的张量。
print("Flattened Tensor:
", flatten_t)示例 1
在此程序中,我们将张量展平为一维张量。我们还使用 **start_dim** 展平张量。
Import the required library
import torch
# define a torch tensor
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using start_dims
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, start_dim=0)
flatten_t1 = torch.flatten(t, start_dim=1)
flatten_t2 = torch.flatten(t, start_dim=2)
# print the flatten tensors
print("Flatten tensor:
", flatten_t)
print("Flatten tensor (start_dim=0):
", flatten_t0)
print("Flatten tensor (start_dim=1):
", flatten_t1)
print("Flatten tensor (start_dim=2):
", flatten_t2)输出
Tensor: tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) Size of Tensor: torch.Size([2, 2, 3]) Flatten tensor: tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) Flatten tensor (start_dim=0): tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) Flatten tensor (start_dim=1): tensor([[ 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12]]) Flatten tensor (start_dim=2): tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]])
示例 2
在此程序中,我们将张量展平为一维张量。我们还使用 **end_dim** 展平张量。
import torch
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using end_dims
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, end_dim=0)
flatten_t1 = torch.flatten(t, end_dim=1)
flatten_t2 = torch.flatten(t, end_dim=2)
# print the flatten tensors
print("Flatten tensor:
", flatten_t)
print("Flatten tensor (end_dim=0):
", flatten_t0)
print("Flatten tensor (end_dim=1):
", flatten_t1)
print("Flatten tensor (end_dim=2):
", flatten_t2)输出
Tensor: tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) Size of Tensor: torch.Size([2, 2, 3]) Flatten tensor: tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) Flatten tensor (end_dim=0): tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) Flatten tensor (end_dim=1): tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]]) Flatten tensor (end_dim=2): tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
示例 3
在此程序中,我们将张量展平为一维张量。我们还使用 **start_dim** 和 **end_dim** 展平张量。
import torch
t = torch.empty(2,2,3,3).random_(30)
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using end_dims
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)
# print the flatten tensors
print("Flatten tensor (start_dim=2,end_dim=3):
", flatten_t0)输出
Tensor: tensor([[[[27., 13., 29.], [ 1., 23., 15.], [15., 7., 19.]], [[ 4., 14., 24.], [ 6., 4., 7.], [ 6., 18., 11.]]], [[[ 0., 27., 3.], [25., 12., 25.], [10., 23., 9.]], [[ 3., 1., 28.], [19., 7., 28.], [23., 14., 21.]]]]) Size of Tensor: torch.Size([2, 2, 3, 3]) Flatten tensor (start_dim=2,end_dim=3): tensor([[[27., 13., 29., 1., 23., 15., 15., 7., 19.], [ 4., 14., 24., 6., 4., 7., 6., 18., 11.]], [[ 0., 27., 3., 25., 12., 25., 10., 23., 9.], [ 3., 1., 28., 19., 7., 28., 23., 14., 21.]]])
广告
数据结构
网络
关系数据库管理系统
操作系统
Java
iOS
HTML
CSS
Android
Python
C 编程
C++
C#
MongoDB
MySQL
Javascript
PHP