如何在PyTorch中计算两个张量的余弦相似度?


为了计算两个张量之间的余弦相似度,我们使用**torch.nn**模块提供的**CosineSimilarity()**函数。它返回沿**dim**计算的余弦相似度值。

**dim**是此函数的一个可选参数,沿其计算余弦相似度。

  • 对于一维张量,我们只能沿**dim=0**计算余弦相似度。

  • 对于二维张量,我们可以沿**dim=0**或**1**计算余弦相似度。

  • 为了计算余弦相似度,两个张量的尺寸必须相同。两个张量必须是实数值的。余弦相似度常用于文本分析中度量文档相似度。

语法

torch.nn.CosineSimilarity(dim=1)

默认的**dim**设置为1。但是,如果您测量**一维张量**之间的余弦相似度,则我们将**dim**设置为0。

步骤

  • 导入所需的库。在以下所有示例中,所需的Python库是**torch**。确保您已安装它。

import torch
  • 创建两个张量并打印它们。两个张量必须是实数值的。

tensor1 = torch.randn(3,4)
tensor2 = torch.randn(3,4)
  • 定义一个沿维度**dim**测量余弦相似度的方法。

cos = torch.nn.CosineSimilarity(dim=0)
  • 使用上面定义的方法计算余弦相似度。

output = cos(tensor1, tensor2)
  • 打印计算出的包含余弦相似度值的张量。

print("Cosine Similarity:",output)

示例1

下面的Python程序计算两个一维张量之间的**余弦相似度**。

# Import the required library
import torch

# define two input tensors
tensor1 = torch.tensor([0.1, 0.3, 2.3, 0.45])
tensor2 = torch.tensor([0.13, 0.23, 2.33, 0.45])

# print above defined two tensors
print("Tensor 1:
", tensor1) print("Tensor 2:
", tensor2) # define a method to measure cosine similarity cos = torch.nn.CosineSimilarity(dim=0) output = cos(tensor1, tensor2) # display the output tensor print("Cosine Similarity:",output)

输出

Tensor 1:
   tensor([0.1000, 0.3000, 2.3000, 0.4500])
Tensor 2:
   tensor([0.1300, 0.2300, 2.3300, 0.4500])
Cosine Similarity: tensor(0.9995)

示例2

在这个Python程序中,我们沿不同的**dim**计算两个二维张量之间的余弦相似度。

# Import the required library
import torch

# define two input tensors
tensor1 = torch.randn(3,4)
tensor2 = torch.randn(3,4)

# print above defined two tensors
print("Tensor 1:
", tensor1) print("Tensor 2:
", tensor2) # define a method to measure cosine similarity in dim 0 cos0 = torch.nn.CosineSimilarity(dim=0) output0 = cos0(tensor1, tensor2) print("Cosine Similarity in dim 0:
",output0) # define a method to measure cosine similarity in dim 1 cos1 = torch.nn.CosineSimilarity(dim=1) output1 = cos1(tensor1, tensor2) print("Cosine Similarity in dim 1:
",output1)

输出

Tensor 1:
   tensor([[ 0.2714, 1.1430, 1.3997, 0.8788],
      [-2.2268, 1.9799, 1.5682, 0.5850],
      [ 1.2289, 0.5043, -0.1625, 1.1403]])
Tensor 2:
   tensor([[-0.3299, 0.6360, -0.2014, 0.5989],
      [-0.6679, 0.0793, -2.5842, -1.5123],
      [ 1.1110, -0.1212, 0.0324, 1.1277]])
Cosine Similarity in dim 0:
   tensor([ 0.8076, 0.5388, -0.7941, 0.3016])
Cosine Similarity in dim 1:
   tensor([ 0.4553, -0.3140, 0.9258])

更新于:2022年1月20日

12K+浏览量

启动您的职业生涯

通过完成课程获得认证

开始学习
广告