如何在PyTorch中测量目标和输入概率之间的二元交叉熵?
我们应用 **BCELoss()** 方法来计算输入和目标(预测和实际)概率之间的 *二元交叉熵* 损失。 **BCELoss()** 来自 **torch.nn** 模块。它创建一个衡量二元交叉熵损失的标准。它是 **torch.nn** 模块提供的损失函数的一种类型。
损失函数用于通过最小化损失来优化深度神经网络。输入和目标都应该是具有类概率的张量。确保目标在 0 和 1 之间。输入和目标张量都可以具有任意数量的维度。例如,在自动编码器中,**BCELoss()** 用于测量重建误差。
语法
torch.nn.BCELoss()
步骤
要计算二元交叉熵损失,可以按照以下步骤操作:
导入所需的库。在以下所有示例中,所需的Python库是 **torch**。确保您已安装它。
import torch
创建输入和目标张量并打印它们。
input = torch.rand(3, 5) target = torch.randn(3, 5).softmax(dim=1)
创建一个标准来衡量二元交叉熵损失。
bce_loss = nn.BCELoss()
计算二元交叉熵损失并打印它。
output = bce_loss(input, target)
print('Binary Cross Entropy Loss:
', output)**注意** - 在以下示例中,我们使用随机数来生成输入和目标张量。因此,您可能会得到这些张量的不同值。
示例 1
在下面的Python程序中,我们计算输入和目标概率之间的二元交叉熵损失。
import torch
import torch.nn as nn
input = torch.rand(6, requires_grad=True)
target = torch.rand(6)
# create a criterion to measure binary cross entropy
bce_loss = nn.BCELoss()
# compute the binary cross entropy
output = bce_loss(input, target)
output.backward()
print('input:
', input)
print('target:\ n ', target)
print('Binary Cross Entropy Loss:
', output)输出
input: tensor([0.3440, 0.7944, 0.8919, 0.3551, 0.9817, 0.8871], requires_grad=True) target: tensor([0.1639, 0.4745, 0.1537, 0.5444, 0.6933, 0.1129]) Binary Cross Entropy Loss: tensor(1.2200, grad_fn=<BinaryCrossEntropyBackward>)
请注意,输入和目标张量的元素都在 0 和 1 之间。
示例 2
在这个程序中,我们计算输入和目标张量之间的BCE损失。两个张量都是二维的。请注意,对于目标张量,我们使用 **softmax()** 函数使其元素在 0 和 1 之间。
import torch
import torch.nn as nn
input = torch.rand(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
loss = nn.BCELoss()
output = loss(input, target)
output.backward()
print("Input:
",input)
print("Target:
",target)
print("Binary Cross Entropy Loss:
",output)输出
Input: tensor([[0.5080, 0.5674, 0.1960, 0.7617, 0.9675], [0.8497, 0.4167, 0.4464, 0.6646, 0.7448], [0.4477, 0.6700, 0.0358, 0.8317, 0.9484]], requires_grad=True) Target: tensor([[0.0821, 0.2900, 0.1864, 0.1480, 0.2935], [0.1719, 0.3426, 0.0729, 0.3616, 0.0510], [0.1284, 0.1542, 0.1338, 0.1779, 0.4057]]) Cross Entropy Loss: tensor(1.0689, grad_fn=<BinaryCrossEntropyBackward>)
请注意,输入和目标张量的元素都在 0 和 1 之间。
广告
数据结构
网络
关系数据库管理系统(RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP