Python PyTorch 中的 torch.argmax() 方法
为了找到输入张量中元素最大值的索引,我们可以使用 **torch.argmax()** 函数。它只返回索引,而不是元素值。如果输入张量有多个最大值,则该函数将返回第一个最大元素的索引。我们可以应用 **torch.argmax()** 函数来计算张量跨维度最大值的索引。
语法
torch.argmax(input)
步骤
我们可以使用以下步骤来查找输入张量中所有元素最大值的索引:
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。确保您已安装它。
import torch
定义输入张量 **input**。
input = torch.randn(3,4)
计算张量 **input** 中所有元素最大值的索引。
indices = torch.argmax(input)
打印上面计算出的带有索引的张量。
print("Indices:
", indices)示例 1
# Import the required library
import torch
# define an input tensor
input = torch.tensor([0., -1., 2., 8.])
# print above defined tensor
print("Input Tensor:
", input)
# Compute indices of the maximum value
indices = torch.argmax(input)
# print the indices
print("Indices:
", indices)输出
Input Tensor: tensor([ 0., -1., 2., 8.]) Indices: tensor(3)
在上面的 Python 示例中,我们找到了输入 1D 张量中元素最大值的索引。输入张量中的最大值为 8,该元素的索引为 3。
示例 2
在这个程序中,我们计算了相对于不同矩阵范数的条件数。
# Import the required library
import torch
# define an input tensor
input = torch.randn(4,4)
# print above defined tensor
print("Input Tensor:
", input)
# Compute indices of the maximum value
indices = torch.argmax(input)
# print the indices
print("Indices:
", indices)
# Compute indices of the maximum value in dim 0
indices = torch.argmax(input, dim=0)
# print the indices
print("Indices in dim 0:
", indices)
# Compute indices of the maximum value in dim 1
indices = torch.argmax(input, dim=1)
# print the indices
print("Indices in dim 1:
", indices)输出
Input Tensor: tensor([[-1.6729, 1.2613, -1.2882, -0.8133], [ 0.9192, 0.9301, -0.2372, 0.0162], [-0.4669, 0.6604, -0.7982, 0.2621], [ 0.6436, 1.0328, 2.4573, 0.0606]]) Indices: tensor(14) Indices in dim 0: tensor([1, 0, 3, 2]) Indices in dim 1: tensor([1, 1, 1, 2])
在上面的 Python 示例中,我们找到了输入 2D 张量中元素最大值的索引,分别在不同的维度上。我们使用 **torch.randn()** 方法生成了输入张量的元素,因此您可能会注意到获取不同的输入张量和索引。
广告
数据结构
网络
关系型数据库管理系统
操作系统
Java
iOS
HTML
CSS
Android
Python
C 语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP