PyTorch Linear Algebra

Definition: The torch.linalg Module

PyTorch provides torch.linalg, a namespace that closely mirrors NumPy's numpy.linalg, with matching function names and near-identical signatures:

import torch

A = torch.randn(4, 4, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(A)                # SVD
eigenvalues = torch.linalg.eigvalsh(A @ A.T)   # Hermitian eigenvalues
x = torch.linalg.solve(A, torch.randn(4, dtype=torch.float64))  # solve
n = torch.linalg.norm(A)                       # Frobenius norm

All torch.linalg functions:

  • Support GPU tensors (CUDA via cuSOLVER/cuBLAS).
  • Are differentiable via autograd.
  • Support batched operations on leading dimensions.
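
For instance, gradients flow through a decomposition like through any other operation. A minimal sketch (runs on CPU; move A to a CUDA device for the GPU path):

import torch

A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
S = torch.linalg.svd(A).S            # singular values, shape (4,)
loss = S.sum()                       # any scalar function of the decomposition
loss.backward()                      # autograd computes d(loss)/dA
print(A.grad.shape)                  # torch.Size([4, 4])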

Definition: Batched Linear Algebra

Most torch.linalg functions support batched inputs: if the input has shape $(\ldots, m, n)$, the operation is applied independently to each matrix in the batch.

# Batch of 100 channel matrices, each 4x4
H = torch.randn(100, 4, 4, dtype=torch.complex128)

# SVD of all 100 matrices in one call
U, S, Vh = torch.linalg.svd(H)
# U: (100, 4, 4), S: (100, 4), Vh: (100, 4, 4)

# Solve 100 systems simultaneously
b = torch.randn(100, 4, 1, dtype=torch.complex128)
x = torch.linalg.solve(H, b)   # (100, 4, 1)

Batching is vastly more efficient than Python loops because the GPU kernel processes all matrices in parallel.

In MIMO wireless, you often have thousands of channel realizations to process. Batched SVD on the GPU can process tens of thousands of $4 \times 4$ matrices in a single kernel launch, orders of magnitude faster than looping over them one at a time on the CPU.
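
A rough way to verify the batching payoff yourself (the batch size is illustrative and timings vary by hardware; on CUDA, call torch.cuda.synchronize() before reading the clock):

import time
import torch

H = torch.randn(10_000, 4, 4, dtype=torch.complex128)

t0 = time.perf_counter()
for i in range(H.shape[0]):          # one LAPACK/kernel call per matrix
    torch.linalg.svd(H[i])
t_loop = time.perf_counter() - t0

t0 = time.perf_counter()
torch.linalg.svd(H)                  # one batched call over the leading dim
t_batch = time.perf_counter() - t0

print(f"loop: {t_loop:.3f} s, batched: {t_batch:.3f} s")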

Theorem: Differentiating Through SVD

For $\mathbf{A} = \mathbf{U}\boldsymbol{\Sigma}\mathbf{V}^H$, the gradient of a scalar loss $L$ with respect to $\mathbf{A}$ involves:

$$\frac{\partial L}{\partial \mathbf{A}} = \mathbf{U} \left(\frac{\partial L}{\partial \boldsymbol{\Sigma}} + \mathbf{F} \circ \left(\mathbf{U}^H \frac{\partial L}{\partial \mathbf{U}} - \frac{\partial L}{\partial \mathbf{V}} \mathbf{V}^H\right)\right) \mathbf{V}^H$$

where $F_{ij} = (\sigma_j^2 - \sigma_i^2)^{-1}$ for $i \neq j$. PyTorch handles this automatically, but the formula reveals a numerical instability when singular values are close or repeated.

Differentiating through eigenvalue decompositions is tricky because the eigenvectors are not unique (sign flips, rotations in degenerate subspaces). The $\mathbf{F}$ matrix has poles where singular values collide.
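
You can trigger the failure mode directly: the identity matrix has all singular values equal, so any loss that depends on the singular vectors hits the poles of $\mathbf{F}$. A minimal sketch:

import torch

A = torch.eye(3, dtype=torch.float64, requires_grad=True)   # all sigma_i = 1
U, S, Vh = torch.linalg.svd(A)
U.sum().backward()        # loss depends on the singular *vectors*
print(A.grad)             # typically NaN: backward divides by sigma_j^2 - sigma_i^2

A loss that depends only on S (e.g. S.sum()) avoids the poles and stays finite even with repeated singular values.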

Example: Batched SVD on GPU

Compute the SVD of 1000 random $8 \times 8$ complex matrices on the GPU and verify the reconstruction $\mathbf{A} \approx \mathbf{U}\boldsymbol{\Sigma}\mathbf{V}^H$.
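
A sketch of this example (the tolerance in the final check is an arbitrary choice):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(1000, 8, 8, dtype=torch.complex128, device=device)

U, S, Vh = torch.linalg.svd(A)   # U: (1000, 8, 8), S: (1000, 8), Vh: (1000, 8, 8)

# Reassemble U @ diag(S) @ Vh; S is real, so cast it to the complex dtype first
A_rec = U @ torch.diag_embed(S.to(A.dtype)) @ Vh
print(torch.allclose(A, A_rec, atol=1e-10))   # True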

Example: Optimizing Eigenvalues via Autograd

Find a symmetric matrix $\mathbf{A}$ that maximizes the smallest eigenvalue while keeping $\mathrm{tr}(\mathbf{A}) = 1$, using gradient ascent through torch.linalg.eigvalsh.
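
One way to set this up (a sketch: the optimizer, learning rate, and step count are arbitrary choices, and the trace constraint is enforced by projection):

import torch

n = 4
B = torch.randn(n, n, dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([B], lr=0.05)

for step in range(300):
    A = (B + B.T) / 2                          # symmetrize the parameter
    A = A - (torch.trace(A) - 1) / n * torch.eye(n, dtype=torch.float64)  # tr(A) = 1
    loss = -torch.linalg.eigvalsh(A)[0]        # eigvalsh sorts ascending: ascend on lambda_min
    opt.zero_grad()
    loss.backward()
    opt.step()

print(torch.linalg.eigvalsh(A.detach()))       # approaches 1/n = 0.25 for every eigenvalue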

SVD Low-Rank Approximation

Adjust the rank $k$ and see how the truncated SVD approximation $\mathbf{A}_k = \sum_{i=1}^k \sigma_i \mathbf{u}_i \mathbf{v}_i^T$ converges to the original matrix.
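
A non-interactive version of the same experiment (the matrix size and rank schedule are arbitrary choices):

import torch

A = torch.randn(64, 64, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(A)

for k in (1, 4, 16, 64):
    A_k = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]           # rank-k truncation
    err = torch.linalg.norm(A - A_k) / torch.linalg.norm(A)  # relative Frobenius error
    print(f"k={k:2d}  relative error = {err:.4f}")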


NumPy vs. PyTorch Linear Algebra

Operation                  NumPy                   PyTorch                  Differentiable?
SVD                        np.linalg.svd           torch.linalg.svd         Yes
Eigenvalues (Hermitian)    np.linalg.eigh          torch.linalg.eigh        Yes
Solve Ax = b               np.linalg.solve         torch.linalg.solve       Yes
Cholesky                   np.linalg.cholesky      torch.linalg.cholesky    Yes
QR                         np.linalg.qr            torch.linalg.qr          Yes
Norm                       np.linalg.norm          torch.linalg.norm        Yes
Pseudoinverse              np.linalg.pinv          torch.linalg.pinv        Yes
Batched inputs             Stacked arrays (CPU)    Yes (leading dims)       Yes

Quick Check

What is the shape of S when calling torch.linalg.svd on a tensor of shape (batch, 4, 6)?

  • (batch, 4, 6)
  • (batch, 4)
  • (batch, 6)
  • (batch, 4, 4)

Answer: (batch, 4). Each $4 \times 6$ matrix has $\min(4, 6) = 4$ singular values, and S stores one vector of singular values per batch element.

Common Mistake: NaN Gradients from Degenerate SVD

Mistake:

Differentiating through SVD when singular values are repeated or near-zero produces NaN or Inf gradients because the backward formula involves $(\sigma_i^2 - \sigma_j^2)^{-1}$.

Correction:

Add a small perturbation to the matrix before the SVD when differentiating, or check the resulting gradients for NaN/Inf after the backward pass:

# Symmetric jitter breaks exact ties between singular values (square A only)
A_perturbed = A + 1e-6 * torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
U, S, Vh = torch.linalg.svd(A_perturbed)
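
To make the "check for NaN gradients" step concrete, one can validate the gradient right after the backward pass (a minimal sketch; the loss here is a stand-in):

import torch

A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
U, S, Vh = torch.linalg.svd(A)
U.sum().backward()                     # stand-in loss that depends on singular vectors

if not torch.isfinite(A.grad).all():   # catches both NaN and Inf
    raise RuntimeError("non-finite gradient from SVD backward; perturb A or change the loss")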

Why This Matters: Batched SVD for MIMO Channel Analysis

In MIMO-OFDM systems, each subcarrier has its own channel matrix Hk\mathbf{H}_k. With 1024 subcarriers and 100 time slots, you need SVDs of 102,400 matrices. Batched torch.linalg.svd on GPU handles this in milliseconds, enabling real-time precoding and waterfilling capacity computation.
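
A sketch of that workload under assumed dimensions (the tensor shapes, SNR, and equal-power allocation are illustrative; waterfilling would replace the fixed power split):

import torch

n_sc, n_slots, nr, nt = 1024, 100, 4, 4
H = torch.randn(n_sc, n_slots, nr, nt, dtype=torch.complex64)   # 102,400 channels

# Both leading dims are batch dims: all 102,400 SVDs in a single call
S = torch.linalg.svdvals(H)            # (1024, 100, 4), real singular values

snr = 10.0
capacity = torch.log2(1 + snr * S**2 / nt).sum(dim=-1)   # (1024, 100) bits/s/Hz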

See full treatment in Chapter 35

Key Takeaway

The torch.linalg module mirrors NumPy's linear algebra API but adds GPU acceleration, autograd support, and batched operations. Batching is the key to performance: process thousands of matrices in a single kernel launch instead of Python loops.

Batched Operation

A linear algebra operation applied independently to each matrix in a batch, indexed by the leading dimensions of the tensor.

PyTorch Linear Algebra Operations

Comprehensive examples of torch.linalg: SVD, eigendecomposition, solve, Cholesky, QR, and batched operations with timing benchmarks.

Code: ch12/python/pytorch_linalg.py