如何在PyTorch中计算输入中每个元素的海维赛德阶跃函数?


要计算输入张量中每个元素的海维赛德阶跃函数,我们使用 **torch.heaviside()** 方法。它接受两个参数 - 输入和值。它返回一个包含计算出的 **heaviside** 阶跃函数的新张量。

如果 **input=0**,则 **heaviside** 函数的值与值相同。如果输入小于零,则heaviside的值为零。如果输入大于零,则 **heaviside** 的值为1。它接受任何维度的torch张量。它也称为 **单位阶跃函数**。

语法

torch.heaviside(input, values)

步骤

我们可以使用以下步骤计算海维赛德阶跃函数:

  • 导入所需的库。在以下所有示例中,所需的Python库是 **torch**。确保你已经安装了它。

import torch
  • 创建两个张量 - **input** 和 **values**。

input = torch.randn(3,3)
values = torch.tensor([0.5, 0.3, 0.7])
  • 使用 **torch.heaviside(input, values)** 计算上述定义的张量的海维赛德阶跃函数。可以选择将此值赋值给一个新变量。

hssf = torch.heaviside(input, values)
  • 打印上述计算出的海维赛德阶跃函数。

print("Heaviside Step Function:
", hssf)

示例1

在这个Python示例中,我们计算一维张量的海维赛德阶跃函数。

import torch
# define input and values tensors
input = torch.tensor([-1.5, 0, 2.0])
values = torch.tensor([0.5])

# display above defined tensors
print("Input Tensor:
", input) print("Values Tensor:
", values) # compute heaviside step function hssf = torch.heaviside(input, values) print("Heaviside Step Function:
", hssf)

输出

Input Tensor:
   tensor([-1.5000, 0.0000, 2.0000])
Values Tensor:
   tensor([0.5000])
Heaviside Step Function:
   tensor([0.0000, 0.5000, 1.0000])

示例2

在这个示例中,我们计算二维张量的海维赛德阶跃函数。

import torch
# define input and values tensors
input = torch.tensor([[0.2, 0.0, -0.7, -0.2],
   [0.0, 0.6, 0.6, -0.9],
   [0.0, 0.0, 0.0, 0.4],
   [-1.2, 0.0, 0.8, 0.0]])
values = torch.tensor([0.5,0.3, 0.7, 0.8])

# display above defined tensors
print("Input Tensor:
", input) print("Values Tensor:
", values) # compute heaviside step function hssf = torch.heaviside(input, values) print("Heaviside Step Function:
", hssf)

输出

Input Tensor:
   tensor([[ 0.2000, 0.0000, -0.7000, -0.2000],
      [ 0.0000, 0.6000, 0.6000, -0.9000],
      [ 0.0000, 0.0000, 0.0000, 0.4000],
      [-1.2000, 0.0000, 0.8000, 0.0000]])
Values Tensor:
   tensor([5.0000, 0.3000, 0.7000, 0.8000])
Heaviside Step Function:
   tensor([[1.0000, 0.3000, 0.0000, 0.0000],
      [5.0000, 1.0000, 1.0000, 0.0000],
      [5.0000, 0.3000, 0.7000, 1.0000],
      [0.0000, 0.3000, 1.0000, 0.8000]])

更新于:2022年1月27日

310 次浏览

启动您的 职业生涯

完成课程获得认证

开始
广告