Batched Operations
Why Batching Matters for GPU Performance
A GPU has thousands of cores but each CUDA kernel launch has overhead (~5-50 microseconds). If you launch one kernel per matrix multiply in a loop, the launch overhead dominates and the GPU sits idle most of the time. Batching combines many independent operations into a single kernel launch, keeping the GPU fully utilized.
This section covers torch.bmm, batched linear algebra, and how to
restructure scientific computations for maximum GPU throughput.
Definition: Batched Matrix Multiplication (BMM)
Given two batched tensors $A \in \mathbb{R}^{B \times m \times k}$ and $B \in \mathbb{R}^{B \times k \times n}$, the batched matrix multiplication computes $B$ independent matrix products
$$C_b = A_b B_b, \qquad b = 1, \dots, B,$$
yielding $C \in \mathbb{R}^{B \times m \times n}$.
import torch

A = torch.randn(64, 128, 256) # 64 matrices of shape 128x256
B = torch.randn(64, 256, 64) # 64 matrices of shape 256x64
C = torch.bmm(A, B) # 64 matrices of shape 128x64
# Equivalent using @ operator:
C = A @ B
Under the hood, torch.bmm dispatches to a cuBLAS batched GEMM routine
(cublasSgemmBatched, or its strided variant for regularly laid-out inputs),
which processes all $B$ multiplications in a single kernel launch.
The speedup over a Python loop is typically 10-100x.
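As a quick sanity check (reusing A, B, and C from the snippet above; the tolerance is an illustrative float32 choice), the batched result matches a per-matrix Python loop:
C_loop = torch.stack([A[i] @ B[i] for i in range(A.shape[0])])
print(torch.allclose(C, C_loop, atol=1e-4))  # expected: True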
Definition: Batched Operations via Einstein Summation
torch.einsum provides a concise notation for batched operations:
# Batched matrix multiply: A (b, i, j), B (b, j, k) -> (b, i, k)
C = torch.einsum('bij,bjk->bik', A, B)
# Batched dot product: X (b, i), Y (b, i) -> (b,)
dots = torch.einsum('bi,bi->b', X, Y)
# Batched outer product: u (b, i), v (b, j) -> (b, i, j)
outer = torch.einsum('bi,bj->bij', u, v)
# Batched trace: M (b, i, i) -> (b,)
traces = torch.einsum('bii->b', M)
Einstein summation is particularly useful when the batch dimension is not the first axis or when operations involve contractions over multiple indices.
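For example, here is a minimal sketch (tensor names and shapes are purely illustrative) of a non-leading batch axis and a contraction over two indices at once:
import torch
# Batch axis in the middle: X (m, b, k) times a shared weight W (k, n)
X = torch.randn(8, 32, 16)
W = torch.randn(16, 4)
Y = torch.einsum('mbk,kn->mbn', X, W)      # (8, 32, 4)
# Per-batch quadratic form x^T M x, contracting over i and j simultaneously
x = torch.randn(32, 16)
M = torch.randn(32, 16, 16)
q = torch.einsum('bi,bij,bj->b', x, M, x)  # (32,)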
Definition: Batched Singular Value Decomposition
PyTorch's torch.linalg.svd natively supports batched input:
A = torch.randn(100, 64, 32) # 100 matrices of shape 64x32
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
# U: (100, 64, 32), S: (100, 32), Vh: (100, 32, 32)
All batched torch.linalg functions (svd, solve, eigh,
cholesky, inv) follow the same convention: batch dimensions
are all dimensions except the last two.
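For instance, a minimal sketch (shapes chosen only for illustration) with two leading batch dimensions handled in one call:
import torch
# 10 time slots x 50 subcarriers of 8x8 positive definite matrices
A = torch.randn(10, 50, 8, 8)
R = A @ A.mT + torch.eye(8)                            # (10, 50, 8, 8)
L = torch.linalg.cholesky(R)                           # (10, 50, 8, 8): one factor per (slot, subcarrier)
x = torch.linalg.solve(R, torch.randn(10, 50, 8, 1))   # (10, 50, 8, 1)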
Batched eigh is especially useful for simultaneously diagonalizing
covariance matrices across frequency bins or time slots in MIMO-OFDM.
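A hedged sketch of that pattern (antenna and subcarrier counts are placeholders, and a random snapshot model stands in for real channel data):
import torch
n_sub, n_rx, n_snap = 1024, 4, 64
H = torch.randn(n_sub, n_rx, n_snap, dtype=torch.complex64)
R = H @ H.mH / n_snap                     # (1024, 4, 4) per-subcarrier sample covariances
eigvals, eigvecs = torch.linalg.eigh(R)   # (1024, 4), (1024, 4, 4): all subcarriers in one call
# eigh returns eigenvalues in ascending order; principal eigenvector per subcarrier:
principal = eigvecs[..., -1]              # (1024, 4)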
Theorem: Kernel Launch Overhead and Batching Speedup
Let $t_{\text{exec}}$ be the kernel execution time for a single operation and $t_{\text{launch}}$ be the kernel launch overhead. For $B$ independent operations:
- Sequential (loop): $T_{\text{loop}} = B\,(t_{\text{launch}} + t_{\text{exec}})$
- Batched: $T_{\text{batch}} = t_{\text{launch}} + t_{\text{batch}}$

where $t_{\text{batch}}$ is the execution time of the single batched kernel with batch parallelism. For small $t_{\text{exec}}$ (small matrices), the speedup approaches
$$\frac{T_{\text{loop}}}{T_{\text{batch}}} = \frac{B\,(t_{\text{launch}} + t_{\text{exec}})}{t_{\text{launch}} + t_{\text{batch}}} \approx B.$$
For large $t_{\text{exec}}$, the speedup saturates as compute dominates launch overhead.
Batching replaces $B$ kernel launches with one. The larger the ratio of launch overhead to compute, the greater the benefit: operations on small matrices (e.g., small covariance matrices) benefit enormously.
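Plugging illustrative round numbers into these formulas (not measurements): with $B = 1000$, $t_{\text{launch}} = 10\,\mu\text{s}$, $t_{\text{exec}} = 2\,\mu\text{s}$, and a batched kernel time $t_{\text{batch}} = 100\,\mu\text{s}$,
$$T_{\text{loop}} = 1000 \times 12\,\mu\text{s} = 12\ \text{ms}, \qquad T_{\text{batch}} = 10\,\mu\text{s} + 100\,\mu\text{s} = 110\,\mu\text{s},$$
a speedup of roughly $110\times$.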
Theorem: Arithmetic Intensity and the Roofline Model
The arithmetic intensity of an operation is
$$I = \frac{\text{FLOPs performed}}{\text{bytes moved to and from memory}}.$$
For the matrix multiplication $C = AB$ with $A, B \in \mathbb{R}^{n \times n}$ in float32 (read $A$ and $B$, write $C$):
$$I = \frac{2n^3}{3 \cdot 4n^2} = \frac{n}{6}\ \text{FLOP/byte}.$$
An operation is compute-bound if $I$ exceeds the hardware's ridge point $I_{\text{crit}} = \text{peak FLOP/s} / \text{peak memory bandwidth}$, and memory-bound otherwise. On an A100, $I_{\text{crit}} \approx 13$ FLOP/byte for float32.
Large matrix multiplications have high arithmetic intensity and are compute-bound (GPU cores are the bottleneck). Element-wise operations have $I \ll 1$ FLOP/byte and are memory-bound (bandwidth is the bottleneck).
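A small helper that classifies square matrix multiplies by the formula above (the ridge-point figure is approximate and for illustration only):
def matmul_arithmetic_intensity(n: int, bytes_per_elem: int = 4) -> float:
    """Arithmetic intensity of an n x n x n matmul: read A and B, write C."""
    return (2 * n**3) / (3 * n * n * bytes_per_elem)

ridge = 13.0  # approximate A100 float32 ridge point, FLOP/byte
for n in (16, 64, 256, 1024):
    I = matmul_arithmetic_intensity(n)
    print(f"n={n:5d}  I={I:7.1f} FLOP/byte  ->", "compute-bound" if I > ridge else "memory-bound")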
Example: Loop vs Batched Matrix Multiply Benchmark
Compare the speed of computing 1000 independent matrix
multiplications using a Python loop vs torch.bmm.
Setup
import torch
import time
device = torch.device('cuda')
B, M, K, N = 1000, 32, 32, 32
A = torch.randn(B, M, K, device=device)
B_mat = torch.randn(B, K, N, device=device)
# Warm-up so one-time cuBLAS/kernel initialization is not counted in either timing
_ = torch.bmm(A, B_mat)
_ = A[0] @ B_mat[0]
Loop approach
torch.cuda.synchronize()
t0 = time.perf_counter()
results = [A[i] @ B_mat[i] for i in range(B)]
C_loop = torch.stack(results)
torch.cuda.synchronize()
t_loop = time.perf_counter() - t0
print(f"Loop: {t_loop*1000:.1f} ms")
Batched approach
torch.cuda.synchronize()
t0 = time.perf_counter()
C_batch = torch.bmm(A, B_mat)
torch.cuda.synchronize()
t_batch = time.perf_counter() - t0
print(f"Batched: {t_batch*1000:.1f} ms")
print(f"Speedup: {t_loop/t_batch:.1f}x")
# Typical result: 20-100x speedup
Example: Batched Cholesky Solve for Multiple Systems
Solve 500 independent linear systems where each is a positive definite covariance matrix (as in MIMO detection across subcarriers).
Generate positive definite matrices
import torch
B, N = 500, 16
device = torch.device('cuda')
# Random PD matrices: A @ A^T + I
A = torch.randn(B, N, N, device=device)
R = A @ A.mT + torch.eye(N, device=device)
b = torch.randn(B, N, 1, device=device)
Batched Cholesky solve
L = torch.linalg.cholesky(R) # (500, 16, 16)
x = torch.cholesky_solve(b, L) # (500, 16, 1)
# Verify: R @ x should equal b
residual = (R @ x - b).norm() / b.norm()
print(f"Relative residual: {residual.item():.2e}")
Batching Speedup Explorer
Compare execution time of loop-based vs batched operations as a function of batch size and matrix dimension. Observe how small matrices benefit most from batching due to kernel launch overhead amortization.
Full code: ch13/python/batched_operations.py
Quick Check
You need to compute $C_b = A_b^{\top} B_b$ for $b = 1, \dots, B$. Which is the correct einsum expression?
torch.einsum('bij,bjk->bik', A, B)
torch.einsum('bji,bjk->bik', A, B)
torch.einsum('bij,bkj->bik', A, B)
torch.einsum('ibj,bjk->bik', A, B)
Answer: torch.einsum('bji,bjk->bik', A, B). Contracting over j with A's indices written as ji instead of ij implements the transpose $A_b^{\top}$.
Common Mistake: Python Loops Over Batch Dimension on GPU
Mistake:
Writing for i in range(B): result[i] = func(A[i]) on CUDA tensors.
Each iteration launches a separate kernel, and the Python loop adds
interpreter overhead. With B=1000 and small matrices, this can be
50-100x slower than a batched call.
Correction:
Use batched operations: result = func(A) where A carries a leading batch
dimension. torch.linalg routines and many torch.nn.functional operations
accept arbitrary leading batch dimensions.
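As a concrete instance of this correction (torch.linalg.inv and the shapes are chosen only for illustration):
import torch
A = torch.randn(1000, 16, 16)
R = A @ A.mT + torch.eye(16)  # well-conditioned positive definite matrices
# Anti-pattern: one kernel launch (plus Python overhead) per matrix
inv_loop = torch.stack([torch.linalg.inv(R[i]) for i in range(1000)])
# Fix: a single batched call over the leading dimension
inv_batch = torch.linalg.inv(R)
print(torch.allclose(inv_loop, inv_batch, atol=1e-4))  # expected: True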
Key Takeaway
Always batch independent operations into a single tensor operation.
The speedup from eliminating kernel launch overhead is 10-100x for
small matrices. Use torch.bmm, torch.einsum, or batched
torch.linalg functions instead of Python loops over the batch dimension.
Historical Note: From BLAS to Batched BLAS
The Basic Linear Algebra Subprograms (BLAS) were first published in 1979 by Lawson, Hanson, Kincaid, and Krogh. BLAS Level 3 (matrix-matrix operations) was added in 1990. The batched extension came much later, with cuBLAS in CUDA 4.1 (2012), driven by deep learning workloads that require thousands of small matrix multiplications per layer. Today, batched GEMM is arguably the single most important operation in GPU computing.
Batched Matrix Multiply (BMM)
Computing independent matrix products in a single GPU kernel launch via torch.bmm or the @ operator on 3D tensors.
Related: GEMM (General Matrix Multiply)
GEMM (General Matrix Multiply)
The fundamental dense matrix-matrix operation $C \leftarrow \alpha AB + \beta C$. The batched variant (batched GEMM) is the workhorse of GPU computing.
Related: Batched Matrix Multiply (BMM)
Batched Operations in PyTorch
| Operation | Function | Input Shapes | Output Shape |
|---|---|---|---|
| Matrix multiply | torch.bmm(A, B) | (B, m, k) + (B, k, n) | (B, m, n) |
| Solve Ax=b | torch.linalg.solve(A, b) | (B, n, n) + (B, n, p) | (B, n, p) |
| SVD | torch.linalg.svd(A, full_matrices=False) | (B, m, n) | U(B,m,k), S(B,k), Vh(B,k,n), k=min(m,n) |
| Cholesky | torch.linalg.cholesky(A) | (B, n, n) | (B, n, n) |
| Eigenvalues | torch.linalg.eigh(A) | (B, n, n) | vals(B,n), vecs(B,n,n) |
| Inverse | torch.linalg.inv(A) | (B, n, n) | (B, n, n) |