如何在PyTorch中测量目标和输入概率之间的二元交叉熵?
我们应用 **BCELoss()** 方法来计算输入和目标(预测和实际)概率之间的 *二元交叉熵* 损失。 **BCELoss()** 来自 **torch.nn** 模块。它创建一个衡量二元交叉熵损失的标准。它是 **torch.nn** 模块提供的损失函数的一种类型。
损失函数用于通过最小化损失来优化深度神经网络。输入和目标都应该是具有类概率的张量。确保目标在 0 和 1 之间。输入和目标张量都可以具有任意数量的维度。例如,在自动编码器中,**BCELoss()** 用于测量重建误差。
语法
torch.nn.BCELoss()
步骤
要计算二元交叉熵损失,可以按照以下步骤操作:
导入所需的库。在以下所有示例中,所需的Python库是 **torch**。确保您已安装它。
import torch
创建输入和目标张量并打印它们。
input = torch.rand(3, 5) target = torch.randn(3, 5).softmax(dim=1)
创建一个标准来衡量二元交叉熵损失。
bce_loss = nn.BCELoss()
计算二元交叉熵损失并打印它。
output = bce_loss(input, target) print('Binary Cross Entropy Loss:
', output)
**注意** - 在以下示例中,我们使用随机数来生成输入和目标张量。因此,您可能会得到这些张量的不同值。
示例 1
在下面的Python程序中,我们计算输入和目标概率之间的二元交叉熵损失。
import torch import torch.nn as nn input = torch.rand(6, requires_grad=True) target = torch.rand(6) # create a criterion to measure binary cross entropy bce_loss = nn.BCELoss() # compute the binary cross entropy output = bce_loss(input, target) output.backward() print('input:
', input) print('target:\ n ', target) print('Binary Cross Entropy Loss:
', output)
输出
input: tensor([0.3440, 0.7944, 0.8919, 0.3551, 0.9817, 0.8871], requires_grad=True) target: tensor([0.1639, 0.4745, 0.1537, 0.5444, 0.6933, 0.1129]) Binary Cross Entropy Loss: tensor(1.2200, grad_fn=<BinaryCrossEntropyBackward>)
请注意,输入和目标张量的元素都在 0 和 1 之间。
示例 2
在这个程序中,我们计算输入和目标张量之间的BCE损失。两个张量都是二维的。请注意,对于目标张量,我们使用 **softmax()** 函数使其元素在 0 和 1 之间。
import torch import torch.nn as nn input = torch.rand(3, 5, requires_grad=True) target = torch.randn(3, 5).softmax(dim=1) loss = nn.BCELoss() output = loss(input, target) output.backward() print("Input:
",input) print("Target:
",target) print("Binary Cross Entropy Loss:
",output)
输出
Input: tensor([[0.5080, 0.5674, 0.1960, 0.7617, 0.9675], [0.8497, 0.4167, 0.4464, 0.6646, 0.7448], [0.4477, 0.6700, 0.0358, 0.8317, 0.9484]], requires_grad=True) Target: tensor([[0.0821, 0.2900, 0.1864, 0.1480, 0.2935], [0.1719, 0.3426, 0.0729, 0.3616, 0.0510], [0.1284, 0.1542, 0.1338, 0.1779, 0.4057]]) Cross Entropy Loss: tensor(1.0689, grad_fn=<BinaryCrossEntropyBackward>)
请注意,输入和目标张量的元素都在 0 和 1 之间。
广告