PyTorch – 如何计算矩阵的QR分解?
**torch.linalg.qr()** 计算矩阵或矩阵批次的 QR 分解。它接受浮点型、双精度型、复数浮点型和复数双精度型数据的矩阵和矩阵批次。
它返回一个命名元组 **(Q, R)。Q** 在矩阵为实数值时是正交的,在矩阵为复数值时是酉的。R 是一个上三角矩阵。
语法
(Q, R) = torch.linalg.qr(mat, mode='reduced')
参数
**Mat** – 方阵或方阵批次。
**mode** – 它决定 QR 分解的模式。它设置为三种模式之一:**'reduced'**、**'complete'** 和 **'r'**。默认为 'reduced'。这是一个可选参数。
步骤
导入所需的库。在以下所有示例中,所需的 Python 库是 **torch**。确保您已安装它。
import torch
创建一个矩阵或矩阵批次。这里我们定义一个大小为 [3, 2] 的矩阵(一个 2D torch 张量)。
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
使用 **torch.linalg.qr(mat)** 计算输入矩阵或矩阵批次的 QR 分解。这里 mat 是输入矩阵。
Q, R = torch.linalg.qr(A)
显示 Q 和 R。
print("Q:
", Q)
print("R:
", R)示例 1
在这个 Python 程序中,我们计算矩阵的 QR 分解。我们没有给出 mode 参数。它默认设置为 '**reduced**'。
# import necessary libraries
import torch
# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat)
# compute QR decomposition
Q, R = torch.linalg.qr(mat)
# print Q and S matrices
print("Q:
",Q)
print("R:
",R)输出
它将产生以下输出:
Matrix: tensor([[ 1., 12.], [14., 5.], [17., -8.]]) Q: tensor([[-0.0454, 0.8038], [-0.6351, 0.4351], [-0.7711, -0.4056]]) R: tensor([[-22.0454, 2.4495], [ 0.0000, 15.0665]])
示例 2
在这个 Python 程序中,我们计算矩阵的 QR 分解。我们将 mode 设置为 'r'。
# import necessary libraries
import torch
# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat)
# compute QR decomposition
Q, R = torch.linalg.qr(mat, mode = 'r')
# print Q and S matrices
print("Q:
",Q)
print("R:
",R)输出
它将产生以下输出:
Matrix: tensor([[ 1., 12.], [14., 5.], [17., -8.]]) Q: tensor([]) R: tensor([[-22.0454, 2.4495], [ 0.0000, 15.0665]])
示例 3
在这个 Python3 程序中,我们计算矩阵的 QR 分解。我们将 mode 设置为 'complete'。
# import necessary libraries
import torch
# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat)
# compute QR decomposition
Q, R = torch.linalg.qr(mat, mode = 'complete')
# print Q and S matrices
print("Q:
", Q)
print("R:
", R)输出
它将产生以下输出:
Matrix: tensor([[ 1., 12.], [14., 5.], [17., -8.]]) Q: tensor([[-0.0454, 0.8038, 0.5931], [-0.6351, 0.4351, -0.6383], [-0.7711, -0.4056, 0.4907]]) R: tensor([[-22.0454, 2.4495], [ 0.0000, 15.0665], [ 0.0000, 0.0000]])
广告
数据结构
网络
关系数据库管理系统 (RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP