Python PyTorch 中的 torch.argmax() 方法


为了找到输入张量中元素最大值的索引,我们可以使用 **torch.argmax()** 函数。它只返回索引,而不是元素值。如果输入张量有多个最大值,则该函数将返回第一个最大元素的索引。我们可以应用 **torch.argmax()** 函数来计算张量跨维度最大值的索引。

语法

torch.argmax(input)

步骤

我们可以使用以下步骤来查找输入张量中所有元素最大值的索引:

  • 导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。确保您已安装它。

import torch
  • 定义输入张量 **input**。

input = torch.randn(3,4)
  • 计算张量 **input** 中所有元素最大值的索引。

indices = torch.argmax(input)
  • 打印上面计算出的带有索引的张量。

print("Indices:
", indices)

示例 1

# Import the required library
import torch

# define an input tensor
input = torch.tensor([0., -1., 2., 8.])

# print above defined tensor
print("Input Tensor:
", input) # Compute indices of the maximum value indices = torch.argmax(input) # print the indices print("Indices:
", indices)

输出

Input Tensor:
   tensor([ 0., -1., 2., 8.])
Indices:
   tensor(3)

在上面的 Python 示例中,我们找到了输入 1D 张量中元素最大值的索引。输入张量中的最大值为 8,该元素的索引为 3。

示例 2

在这个程序中,我们计算了相对于不同矩阵范数的条件数。

# Import the required library
import torch

# define an input tensor
input = torch.randn(4,4)

# print above defined tensor
print("Input Tensor:
", input) # Compute indices of the maximum value indices = torch.argmax(input) # print the indices print("Indices:
", indices) # Compute indices of the maximum value in dim 0 indices = torch.argmax(input, dim=0) # print the indices print("Indices in dim 0:
", indices) # Compute indices of the maximum value in dim 1 indices = torch.argmax(input, dim=1) # print the indices print("Indices in dim 1:
", indices)

输出

Input Tensor:
   tensor([[-1.6729, 1.2613, -1.2882, -0.8133],
   [ 0.9192, 0.9301, -0.2372, 0.0162],
   [-0.4669, 0.6604, -0.7982, 0.2621],
   [ 0.6436, 1.0328, 2.4573, 0.0606]])
Indices:
   tensor(14)
Indices in dim 0:
   tensor([1, 0, 3, 2])
Indices in dim 1:
   tensor([1, 1, 1, 2])

在上面的 Python 示例中,我们找到了输入 2D 张量中元素最大值的索引,分别在不同的维度上。我们使用 **torch.randn()** 方法生成了输入张量的元素,因此您可能会注意到获取不同的输入张量和索引。

更新于: 2022年1月27日

9K+ 浏览量

开启你的 职业生涯

通过完成课程获得认证

开始学习
广告

© . All rights reserved.