PyTorch – torch.linalg.solve() 方法
为了解决具有唯一解的线性方程组,我们可以使用 **torch.linalg.solve()** 方法。此方法接受两个参数:
首先,系数矩阵 **A**,以及
其次,右侧张量 **b**。
其中 **A** 是一个方阵,b 是一个向量。如果 A 可逆,则解是唯一的。我们可以求解多个线性方程组。在这种情况下,A 是一批方阵,b 是一批向量。
语法
torch.linalg.solve(A, b)
参数
**A** – 方阵或方阵批次。它是线性方程组的系数矩阵。
**b** – 向量或向量批次。它是线性系统的右侧张量。
它返回线性方程组解的张量。
**注意** – 此方法假设系数矩阵 A 是可逆的。如果它不可逆,则会引发运行时错误。
步骤
我们可以使用以下步骤来解决线性方程组。
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch**。确保您已安装它。
import torch
为给定的线性方程组定义系数矩阵和右侧张量。
A = torch.tensor([[2., 3.],[1., -2.]]) b = torch.tensor([3., 0.])
使用 torch.linalg.solve(A,b) 计算唯一解。系数矩阵 A 必须可逆。
X = torch.linalg.solve(A, b)
显示解决方案。
print("Solution:
", X)
检查计算出的解是否正确。
print(torch.allclose(A @ X, b)) # True for correct solution
示例 1
请查看以下示例:
# import required library import torch ''' Let's suppose our square system of linear equations is: 2x + 3y = 3 x - 2y = 0 ''' print("Linear equation:") print("2x + 3y = 3") print("x - 2y = 0") # define the coefficient matrix A A = torch.tensor([[2., 3.],[1., -2.]]) # define right hand side tensor b b = torch.tensor([3., 0.]) # Solve the linear equation X = torch.linalg.solve(A, b) # print the solution of above linear equation print("Solution:
", X) # check above solution to be true print(torch.allclose(A @ X, b))
输出
它将产生以下输出:
Linear equation: 2x + 3y = 3 x - 2y = 0 Solution: tensor([0.8571, 0.4286]) True
示例 2
让我们再举一个例子:
# import required library import torch # define the coefficient matrix A for a 3x3 # square system of linear equations A = torch.randn(3,3) # define right hand side tensor b b = torch.randn(3) # Solve the linear equation X = torch.linalg.solve(A, b) # print the solution of above linear equation print("Solution:
", X) # check above solution to be true print(torch.allclose(A @ X, b))
输出
它将产生以下输出:
Solution: tensor([-0.2867, -0.9850, 0.9938]) True
广告