如何在 PyTorch 中查找张量的第 k 个元素和前 k 个元素?
PyTorch 提供了一个方法 **torch.kthvalue()** 来查找张量的第 k 个元素。它返回按升序排序的张量中第 k 个元素的值,以及该元素在原始张量中的索引。
**torch.topk()** 方法用于查找前 k 个元素。它返回张量中前 k 个或最大的 k 个元素。
步骤
导入所需的库。在以下所有 Python 示例中,所需的 Python 库是 **torch**。确保您已安装它。
创建一个 PyTorch 张量并打印它。
计算 **torch.kthvalue(input, k)**。它返回两个张量。将这两个张量赋值给两个新的变量 **"value"** 和 **"index"**。这里,input 是一个张量,k 是一个整数。
计算 **torch.topk(input, k)**。它返回两个张量。第一个张量包含前 k 个元素的值,第二个张量包含这些元素在原始张量中的索引。将这两个张量赋值给新的变量 **"values"** 和 **"indices"**。
打印张量中第 k 个元素的值和索引,以及张量中前 k 个元素的值和索引。
示例 1
此 Python 程序演示如何查找张量的第 k 个元素。
# Python program to find k-th element of a tensor
# import necessary library
import torch
# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)
# Find the 3rd element in sorted tensor. First it sorts the
# tensor in ascending order then returns the kth element value
# from sorted tensor and the index of element in original tensor
value, index = torch.kthvalue(T, 3)
# print 3rd element with value and index
print("3rd element value:", value)
print("3rd element index:", index)输出
Original Tensor: tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430]) 3rd element value: tensor(2.3340) 3rd element index: tensor(0)
示例 2
以下 Python 程序演示如何查找张量的前 k 个或最大的 k 个元素。
# Python program to find to top k elements of a tensor
# import necessary library
import torch
# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)
# Find the top k=2 or 2 largest elements of the tensor
# returns the 2 largest values and their indices in original
# tensor
values, indices = torch.topk(T, 2)
# print top 2 elements with value and index
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)输出
Original Tensor: tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430]) Top 2 element values: tensor([5.0000, 4.4430]) Top 2 element indices: tensor([4, 5])
广告
数据结构
网络
关系数据库管理系统 (RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C 编程
C++
C#
MongoDB
MySQL
Javascript
PHP