PyTorch – 如何检测张量是否连续?


连续张量是一种元素以连续顺序存储的张量,它们之间不会留有任何空白。最初创建的张量始终是连续张量。张量可以用不同的维度以连续的方式查看。

转置张量会创建一个遵循非连续顺序的原始张量视图。张量转置是非连续的。

语法

Tensor.is_contiguous()

如果张量连续,则返回True;否则返回False

让我们举几个例子来说明如何使用此函数来检查张量是否连续或不连续。

示例 1

# import torch library
import torch

# define a torch tensor
A = torch.tensor([1. ,2. ,3. ,4. ,5. ,6.])
print(A)

# find a view of the above tensor
B = A.view(-1,3)
print(B)

print("id(A):", id(A))
print("id(A.view):", id(A.view(-1,3)))
# check if A or A.view() are contiguous or not
print(A.is_contiguous()) # True
print(A.view(-1,3).is_contiguous()) # True
print(B.is_contiguous()) # True

输出

tensor([1., 2., 3., 4., 5., 6.])
tensor([[1., 2., 3.],
   [4., 5., 6.]])
id(A): 80673600
id(A.view): 63219712
True
True
True

示例 2

# import torch library
import torch

# create a torch tensor
A = torch.tensor([[1.,2.],[3.,4.],[5.,6.]])
print(A)

# take transpose of the above tensor
B = A.transpose(0,1)
print(B)
print("id(A):", id(A))
print("id(A.transpose):", id(A.transpose(0,1)))

# check if A or A transpose are contiguous or not
print(A.is_contiguous()) # True
print(A.transpose(0,1).is_contiguous()) # False
print(B.is_contiguous()) # False

输出

tensor([[1., 2.],
   [3., 4.],
   [5., 6.]])
tensor([[1., 3., 5.],
   [2., 4., 6.]])
id(A): 63218368
id(A.transpose): 99215808
True
False
False

更新日期:2021 年 12 月 6 日

2K+ 次浏览

开启你的 职业生涯

完成课程以获得认证

开始
广告
© . All rights reserved.