PyTorch – 如何计算矩阵的QR分解?


**torch.linalg.qr()** 计算矩阵或矩阵批次的 QR 分解。它接受浮点型、双精度型、复数浮点型和复数双精度型数据的矩阵和矩阵批次。

它返回一个命名元组 **(Q, R)。Q** 在矩阵为实数值时是正交的,在矩阵为复数值时是酉的。R 是一个上三角矩阵。

语法

(Q, R) = torch.linalg.qr(mat, mode='reduced')

参数

  • **Mat** – 方阵或方阵批次。

  • **mode** – 它决定 QR 分解的模式。它设置为三种模式之一:**'reduced'**、**'complete'** 和 **'r'**。默认为 'reduced'。这是一个可选参数。

步骤

  • 导入所需的库。在以下所有示例中,所需的 Python 库是 **torch**。确保您已安装它。

import torch
  • 创建一个矩阵或矩阵批次。这里我们定义一个大小为 [3, 2] 的矩阵(一个 2D torch 张量)。

mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
  • 使用 **torch.linalg.qr(mat)** 计算输入矩阵或矩阵批次的 QR 分解。这里 mat 是输入矩阵。

Q, R = torch.linalg.qr(A)
  • 显示 Q 和 R。

print("Q:
", Q) print("R:
", R)

示例 1

在这个 Python 程序中,我们计算矩阵的 QR 分解。我们没有给出 mode 参数。它默认设置为 '**reduced**'。

# import necessary libraries
import torch

# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat) # compute QR decomposition Q, R = torch.linalg.qr(mat) # print Q and S matrices print("Q:
",Q) print("R:
",R)

输出

它将产生以下输出:

Matrix:
   tensor([[ 1., 12.],
      [14., 5.],
      [17., -8.]])
Q:
   tensor([[-0.0454, 0.8038],
      [-0.6351, 0.4351],
      [-0.7711, -0.4056]])
R:
   tensor([[-22.0454, 2.4495],
      [ 0.0000, 15.0665]])

示例 2

在这个 Python 程序中,我们计算矩阵的 QR 分解。我们将 mode 设置为 'r'。

# import necessary libraries
import torch

# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat) # compute QR decomposition Q, R = torch.linalg.qr(mat, mode = 'r') # print Q and S matrices print("Q:
",Q) print("R:
",R)

输出

它将产生以下输出:

Matrix:
   tensor([[ 1., 12.],
      [14., 5.],
      [17., -8.]])
Q:
   tensor([])
R:
   tensor([[-22.0454, 2.4495],
      [ 0.0000, 15.0665]])

示例 3

在这个 Python3 程序中,我们计算矩阵的 QR 分解。我们将 mode 设置为 'complete'。

# import necessary libraries
import torch

# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat) # compute QR decomposition Q, R = torch.linalg.qr(mat, mode = 'complete') # print Q and S matrices print("Q:
", Q) print("R:
", R)

输出

它将产生以下输出:

Matrix:
   tensor([[ 1., 12.],
      [14., 5.],
      [17., -8.]])
Q:
   tensor([[-0.0454, 0.8038, 0.5931],
      [-0.6351, 0.4351, -0.6383],
      [-0.7711, -0.4056, 0.4907]])
R:
   tensor([[-22.0454, 2.4495],
      [ 0.0000, 15.0665],
      [ 0.0000, 0.0000]])

更新于:2022年1月7日

287 次浏览

启动您的职业生涯

完成课程获得认证

开始学习
广告