Python – PyTorch clamp() 方法
**torch.clamp()** 用于将输入张量中的所有元素限制在 **[min, max]** 范围内。它接受三个参数:**输入**张量、**min** 值和 **max** 值。小于 min 的值将被替换为 **min**,大于 max 的值将被替换为 **max**。
如果未指定 **min**,则没有下界。如果未指定 **max**,则没有上界。例如,如果我们设置 **min=-0.5** 和 **max=0.4**,则小于 -0.5 的值将被替换为 -0.5,大于 0.4 的值将被替换为 0.4。介于这两个值之间的值将保持不变。它只支持实数值输入。
语法
torch.clamp(input, min=None, max=None)
参数
**input** - 输入张量。
**min** - 下界;可以是数字或张量。
**max** - 上界;可以是数字或张量。
它返回一个新的张量,其中所有元素都被限制在 **[min, max]** 范围内。
步骤
导入所需的库。在以下所有示例中,所需的 Python 库是 **torch**。确保您已安装它。
import torch
创建一个输入张量并打印它。
a = torch.tensor([0.73, 0.35, -0.39, -1.53]) print("input tensor:
", a)
限制输入张量的元素。这里我们使用 **min=-0.5, max=0.5**。
t1 = torch.clamp(a, min=-0.5, max=0.5)
打印 clamp 之后获得的张量。
print(t1)
示例 1
在下面的 Python 程序中,我们限制一维输入张量的元素。请注意当 **min** 或 **max** 为 **None** 时,**clamp()** 方法是如何工作的。
# Import the required library import torch # define a 1D tensor a = torch.tensor([ 0.73, 0.35, -0.39, -1.53]) print("input tensor:
", a) print("clamp the tensor:") print("into range [-0.5, 0.5]:") t1 = torch.clamp(a, min=-0.5, max=0.5) print(t1) print("if min is None:") t2 = torch.clamp(a, max=0.5) print(t2) print("if max is None:") t3 = torch.clamp(a, min=0.5) print(t3) print("if min is greater than max:") t4 = torch.clamp(a, min=0.6, max=.5) print(t4)
输出
input tensor: tensor([ 0.7300, 0.3500, -0.3900, -1.5300]) clamp the tensor: into range [-0.5, 0.5]: tensor([ 0.5000, 0.3500, -0.3900, -0.5000]) if min is None: tensor([ 0.5000, 0.3500, -0.3900, -1.5300]) if max is None: tensor([0.7300, 0.5000, 0.5000, 0.5000]) if min is greater than max: tensor([0.5000, 0.5000, 0.5000, 0.5000])
示例 2
在下面的 Python 程序中,我们限制二维输入张量的元素。请注意当 **min** 或 **max** 为 **None** 时,**clamp()** 方法是如何工作的。
# Import the required library import torch # define a 2D tensor of size [3, 4] a = torch.randn(3,4) print("input tensor:
", a) print("clamp the tensor:") print("into range [-0.6, 0.4]:") t1 = torch.clamp(a, min=-0.6, max=0.4) print(t1) print("if min is None (max=0.4):") t2 = torch.clamp(a, max=0.4) print(t2) print("if max is None (min=-0.6):") t3 = torch.clamp(a, min=-0.6) print(t3) print("if min is greater than max (min=0.6, max=0.4):") t4 = torch.clamp(a, min=0.6, max=0.4) print(t4)
输出
input tensor: tensor([[ 1.2133, 0.2199, -0.0864, -0.1143], [ 0.4205, 1.0258, 0.4022, -1.3172], [ 1.5405, 0.8545, 0.7009, 0.5874]]) clamp the tensor: into range [-0.6, 0.4]: tensor([[ 0.4000, 0.2199, -0.0864, -0.1143], [ 0.4000, 0.4000, 0.4000, -0.6000], [ 0.4000, 0.4000, 0.4000, 0.4000]]) if min is None (max=0.4): tensor([[ 0.4000, 0.2199, -0.0864, -0.1143], [ 0.4000, 0.4000, 0.4000, -1.3172], [ 0.4000, 0.4000, 0.4000, 0.4000]]) if max is None (min=-0.6): tensor([[ 1.2133, 0.2199, -0.0864, -0.1143], [ 0.4205, 1.0258, 0.4022, -0.6000], [ 1.5405, 0.8545, 0.7009, 0.5874]]) if min is greater than max (min=0.6, max=0.4): tensor([[0.4000, 0.4000, 0.4000, 0.4000], [0.4000, 0.4000, 0.4000, 0.4000], [0.4000, 0.4000, 0.4000, 0.4000]])
广告