Memory Management on GPU

Why GPU Memory Management Matters

GPU memory is a scarce resource: a typical research GPU has 16--80 GB of VRAM, compared to hundreds of GB of system RAM. When training large models or running big simulations, out-of-memory (OOM) errors are the single most common failure mode. Understanding how PyTorch allocates, caches, and frees GPU memory is essential for writing production code that does not crash at 3 AM.

This section covers the PyTorch caching allocator, memory tracking tools, gradient checkpointing, and practical strategies for reducing peak memory usage.

Definition: GPU Memory Hierarchy

A modern NVIDIA GPU has multiple levels of memory:

  • Global Memory (VRAM): 16--80 GB, high bandwidth (~2 TB/s on A100), but high latency (~400 cycles). This is where tensors live.
  • L2 Cache: 6--50 MB, hardware-managed.
  • Shared Memory (SMEM): 48--228 KB per Streaming Multiprocessor (SM), low latency (~20 cycles), programmer-managed.
  • Registers: ~256 KB per SM, fastest access (1 cycle).

PyTorch users primarily manage global memory; shared memory is relevant only when writing custom CUDA kernels.

The memory bandwidth of modern GPUs (2+ TB/s on A100) is often the bottleneck, not compute. Many operations are memory-bound: they spend more time moving data than computing on it.

Definition: PyTorch Caching Memory Allocator

PyTorch does not call cudaMalloc/cudaFree for every tensor creation and deletion. Instead, it maintains a caching allocator that:

  1. Requests large blocks from CUDA and subdivides them
  2. Keeps freed blocks in a cache for reuse
  3. Only returns memory to CUDA when explicitly asked via torch.cuda.empty_cache()

This is why nvidia-smi shows more memory used than torch.cuda.memory_allocated(): the difference is cached-but-unused memory held by the allocator, plus the fixed overhead of the CUDA context.

import torch
# Allocated: memory currently used by tensors
alloc = torch.cuda.memory_allocated()
# Reserved: total memory held by the caching allocator
reserved = torch.cuda.memory_reserved()
# Cached but free = reserved - allocated

The caching allocator dramatically reduces allocation overhead. A cudaMalloc call can take 1--10 ms; a cached allocation takes microseconds.
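The difference between allocated and reserved memory is easy to observe directly. A minimal sketch (assumes a CUDA-capable machine; the exact numbers will vary):

import torch

x = torch.empty(1024, 1024, 256, device="cuda")   # ~1 GiB of FP32
print(torch.cuda.memory_allocated() / 1e9)        # ~1.07 GB allocated
print(torch.cuda.memory_reserved() / 1e9)         # >= allocated

del x                                             # tensor freed, but the block stays cached
print(torch.cuda.memory_allocated() / 1e9)        # ~0 GB
print(torch.cuda.memory_reserved() / 1e9)         # still ~1 GB (cached)

torch.cuda.empty_cache()                          # return cached blocks to CUDA
print(torch.cuda.memory_reserved() / 1e9)         # drops back toward 0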

Definition: GPU Memory Fragmentation

Fragmentation occurs when free memory exists but is split into non-contiguous blocks too small to satisfy a large allocation request.

Symptoms:

  • torch.cuda.memory_reserved() minus torch.cuda.memory_allocated() shows plenty of free (cached) memory
  • Yet a new allocation raises CUDA out of memory
  • torch.cuda.memory_stats()['inactive_split_bytes.all.current'] is high, indicating many cached blocks that have been split into fragments

Mitigation strategies:

  • Pre-allocate tensors and reuse them (avoid dynamic allocation patterns)
  • Use torch.cuda.empty_cache() to return fragmented blocks to CUDA
  • Set max_split_size_mb via PYTORCH_CUDA_ALLOC_CONF to control splitting

Setting PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 prevents the allocator from splitting blocks larger than 512 MB, which can reduce fragmentation at the cost of slightly higher peak usage.
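The setting must take effect before the first CUDA allocation, so export it in the shell (PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 python train.py) or set it at the very top of the script. A minimal sketch:

import os
# Must be set before the CUDA allocator initializes (safest: before importing torch)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch
x = torch.zeros(1000, 1000, device="cuda")  # allocator now honors the setting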

Definition: Memory Tracking and Snapshots

PyTorch provides detailed memory tracking:

# Summary statistics
print(torch.cuda.memory_summary())

# Detailed allocation history
torch.cuda.memory._record_memory_history()
# ... run your code ...
snapshot = torch.cuda.memory._snapshot()
torch.cuda.memory._record_memory_history(enabled=None)

The snapshot contains every allocation and free event with stack traces, enabling you to identify exactly which line of code is responsible for peak memory usage.
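In recent PyTorch (2.1+), the snapshot can also be written to disk with torch.cuda.memory._dump_snapshot() and inspected in the interactive viewer at https://pytorch.org/memory_viz. A minimal sketch (train_one_step is a placeholder for your own code):

import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)

train_one_step()  # placeholder: the code you want to profile

# Drag the resulting file into https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording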

Definition: Gradient Checkpointing

During the forward pass, PyTorch stores every intermediate activation so it can be reused during backpropagation. For a network with $L$ layers and activation size $a$, this costs $O(L \cdot a)$ memory.

Gradient checkpointing trades compute for memory by:

  1. Storing activations only at selected checkpoint layers
  2. Recomputing intermediate activations during the backward pass

With $\sqrt{L}$ checkpoints, memory drops to $O(\sqrt{L} \cdot a)$ at the cost of one extra forward pass (~33% more compute).

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
        self.head = nn.Linear(512, 10)

    def forward(self, x):
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return self.head(x)

Always pass use_reentrant=False (the recommended setting in PyTorch 2.x). The legacy use_reentrant=True implementation has subtle limitations, e.g. it misbehaves when multiple inputs have requires_grad or when no input requires gradients at all.

Definition: In-Place Operations and Memory Savings

In-place operations modify a tensor without allocating new memory:

x.add_(1)       # in-place add
x.relu_()       # in-place ReLU
x.mul_(0.99)    # in-place multiply (weight decay)

This saves one tensor allocation per operation. However, in-place operations on tensors that require gradients can corrupt the autograd graph. PyTorch raises a RuntimeError if it detects this.

In-place operations are safe in the optimizer step (e.g., SGD parameter updates) because parameter updates happen outside the computation graph.
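Both cases can be demonstrated in a few lines. A minimal sketch (the exact error message varies across PyTorch versions):

import torch

# Unsafe: in-place op on a value autograd saved for the backward pass
x = torch.randn(3, requires_grad=True)
y = x.sigmoid()          # sigmoid's backward needs the output y
try:
    y.add_(1)            # overwrites the saved value in place
    y.sum().backward()
except RuntimeError as e:
    print("autograd detected an in-place modification:", e)

# Safe: parameter-style update outside the graph
w = torch.randn(3, requires_grad=True)
with torch.no_grad():
    w.mul_(0.99)         # nothing is being recorded, so in-place is fine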

Theorem: Optimal Checkpointing Memory Bound

For a sequential network with $L$ layers, each producing activations of size $a$, the minimum peak memory with gradient checkpointing is

$$M_{\text{peak}} = O\!\left(\sqrt{L} \cdot a\right)$$

achieved by placing checkpoints every $\sqrt{L}$ layers. Without checkpointing, $M_{\text{peak}} = O(L \cdot a)$.

With $k$ checkpoints, you store $k$ checkpointed activations plus at most $L/k$ intermediate activations between consecutive checkpoints during recomputation. The total is $k + L/k$, minimized at $k = \sqrt{L}$.
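A quick numeric check of the $k + L/k$ trade-off for $L = 100$:

L = 100
for k in (1, 2, 5, 10, 20, 50, 100):
    print(f"k={k:3d}  activations stored ~ {k + L // k}")
# Minimum at k = sqrt(L) = 10, storing ~20 activations instead of 100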

Theorem: Autograd Memory Overhead

For a computation $y = f_L \circ f_{L-1} \circ \cdots \circ f_1(x)$ with gradient computation, the peak memory is

$$M_{\text{peak}} = \underbrace{\sum_{i=1}^{L} |a_i|}_{\text{saved activations}} + \underbrace{|\theta|}_{\text{parameters}} + \underbrace{|\nabla \theta|}_{\text{gradients}}$$

where $|a_i|$ is the size of activation $a_i = f_i(a_{i-1})$. The activations typically dominate: for a ResNet-50 with batch size 32, activations use ~6 GB while parameters use only ~100 MB.

Parameters are fixed-size. Activations scale with batch size and spatial resolution, making them the primary memory consumer.

Theorem: Practical Peak Memory Estimation

For a model with $P$ parameters in FP32, batch size $B$, and average activation size $A$ bytes per sample per layer across $L$ layers, the approximate peak GPU memory during training is

$$M_{\text{peak}} \approx \underbrace{4P}_{\text{params}} + \underbrace{4P}_{\text{gradients}} + \underbrace{8P}_{\text{optimizer (Adam)}} + \underbrace{B \cdot L \cdot A}_{\text{activations}}$$

where the factor $8P$ for Adam comes from storing both first and second moment estimates (each 4 bytes per parameter in FP32).

Adam stores two extra FP32 buffers the size of the parameters (the first and second moment estimates). The $16P$ bytes for parameters, gradients, and optimizer state is fixed; activations scale linearly with batch size.
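The formula translates directly into a back-of-the-envelope estimator. A sketch (the example argument values are illustrative, not measured):

def estimate_peak_gb(P, B, L, A):
    # P: parameter count, B: batch size, L: layers,
    # A: average activation bytes per sample per layer (FP32 training, Adam)
    fixed = 16 * P            # 4P params + 4P gradients + 8P Adam moments
    activations = B * L * A
    return (fixed + activations) / 1e9

# e.g. a 350M-parameter model, batch 32, 48 layers, ~2 MB act/sample/layer
print(f"{estimate_peak_gb(350e6, 32, 48, 2e6):.1f} GB")   # ~8.7 GB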

Example: Tracking GPU Memory During Training

Monitor GPU memory allocation through a training step to identify the peak memory usage and which operation causes it.
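One possible implementation using the peak-memory counters (model, batch, and optimizer are assumed to be defined elsewhere):

import torch

torch.cuda.reset_peak_memory_stats()

out = model(batch)                        # forward: activations accumulate
fwd_peak = torch.cuda.max_memory_allocated()

loss = out.mean()
loss.backward()                           # backward: gradients allocated
bwd_peak = torch.cuda.max_memory_allocated()

optimizer.step()                          # Adam state allocated on the first step
optimizer.zero_grad(set_to_none=True)     # release gradient memory
step_peak = torch.cuda.max_memory_allocated()

for stage, peak in (("forward", fwd_peak), ("backward", bwd_peak), ("step", step_peak)):
    print(f"peak after {stage}: {peak / 1e9:.2f} GB")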

Example: Applying Gradient Checkpointing to Reduce Memory

Compare peak GPU memory of a deep network with and without gradient checkpointing. Measure the memory savings and compute overhead.
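A sketch of the comparison using torch.utils.checkpoint.checkpoint_sequential on a deep MLP (layer count and sizes are arbitrary):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(64)]).cuda()
x = torch.randn(4096, 1024, device="cuda", requires_grad=True)

def peak_gb(forward):
    torch.cuda.reset_peak_memory_stats()
    forward().sum().backward()
    return torch.cuda.max_memory_allocated() / 1e9

plain = peak_gb(lambda: model(x))
ckpt = peak_gb(lambda: checkpoint_sequential(model, 8, x, use_reentrant=False))
print(f"plain: {plain:.2f} GB   checkpointed (8 segments): {ckpt:.2f} GB")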

Example: Debugging CUDA Out-of-Memory Errors

Your training script crashes with CUDA out of memory. Systematically diagnose the cause and fix it without simply reducing batch size.
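A useful first step is to catch the error and dump the allocator's view of memory before the process dies. A minimal sketch (train_step is a placeholder; torch.cuda.OutOfMemoryError exists in PyTorch 1.13+):

import torch

try:
    train_step()  # placeholder: the failing code
except torch.cuda.OutOfMemoryError:
    print(torch.cuda.memory_summary())    # where is the memory going?
    stats = torch.cuda.memory_stats()
    print("alloc retries:", stats["num_alloc_retries"])  # a high count suggests fragmentation
    raise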

GPU Memory Allocation Timeline

Interactive visualization: GPU memory usage through the forward pass, backward pass, and optimizer step. Adjusting model depth and batch size shows how each factor affects peak memory.

GPU Memory Hierarchy

The GPU memory hierarchy from registers (fastest, smallest) to global memory (slowest, largest). PyTorch tensors reside in global memory; CUDA kernels stage data through shared memory and registers.

GPU Memory Management (Python)

Memory tracking, gradient checkpointing, and OOM debugging strategies. Full code: ch13/python/memory_management.py

Quick Check

After deleting a large tensor with del tensor, you notice that nvidia-smi still shows the same GPU memory usage. Why?

  • Python's garbage collector has not run yet
  • PyTorch's caching allocator keeps the memory reserved for reuse
  • The tensor is still referenced somewhere in the computation graph
  • nvidia-smi has a reporting delay

Common Mistake: Accumulating Loss Without .item()

Mistake:

Writing total_loss += loss in a training loop. This keeps the entire computation graph alive because loss is a tensor with grad_fn. After many iterations, GPU memory grows without bound.

Correction:

Use total_loss += loss.item() to extract the scalar value and discard the graph. Alternatively, total_loss += loss.detach() keeps the running total on the GPU without retaining the gradient connection.
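Side by side (model and loader are assumed to be defined elsewhere):

# Wrong: loss has a grad_fn, so the running total chains every iteration's graph
total_loss = 0.0
for batch in loader:
    loss = model(batch).mean()
    loss.backward()
    total_loss += loss        # graph references accumulate; memory grows

# Right: .item() returns a Python float and releases the graph
total_loss = 0.0
for batch in loader:
    loss = model(batch).mean()
    loss.backward()
    total_loss += loss.item()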

Key Takeaway

Activations dominate training memory, not parameters. For a typical model, activations use 10-100x more memory than the model weights. Gradient checkpointing reduces activation memory from $O(L)$ to $O(\sqrt{L})$ at the cost of ~33% more compute.

Why This Matters: Memory Management in Massive MIMO Simulations

Simulating a massive MIMO system with 64 antennas, 16 users, and 1000 subcarriers requires storing channel matrices of shape $(B, 64, 16, 1000)$ in complex64, consuming $B \times 8$ MB per batch. With $B = 128$, that is already 1 GB for channels alone. Memory management techniques from this section (checkpointing, mixed precision, pre-allocation) are essential for fitting these simulations on a single GPU.
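Checking the arithmetic (complex64 stores two float32 values, so 8 bytes per entry):

B, antennas, users, subcarriers = 128, 64, 16, 1000
channel_gb = B * antennas * users * subcarriers * 8 / 1e9
print(f"{channel_gb:.2f} GB for channel matrices alone")   # ~1.05 GB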

See full treatment in JAX: Functional Numerical Computing

Historical Note: GPU Memory Growth

The original NVIDIA GeForce 256 (1999) had 32 MB of DDR memory. The A100 (2020) has 80 GB of HBM2e -- a 2500x increase in 21 years. Yet model sizes have grown even faster: GPT-3 (2020) requires 350 GB just for parameters in FP16, far exceeding any single GPU. This gap between model size and GPU memory is the fundamental driver behind memory optimization techniques like gradient checkpointing, mixed precision, and model parallelism.

VRAM (Video RAM)

The dedicated high-bandwidth memory on a GPU, used to store tensors, activations, and model parameters. Typically HBM2/HBM3 on data center GPUs.

Related: Caching Allocator

Caching Allocator

PyTorch's memory management layer that pools GPU memory blocks for reuse, avoiding expensive cudaMalloc/cudaFree calls on every tensor allocation.

Related: VRAM (Video RAM)

Memory Fragmentation

A state where free GPU memory is split into small non-contiguous blocks, preventing large allocations even when total free memory is sufficient.

Gradient Checkpointing

A technique that reduces memory usage by recomputing intermediate activations during the backward pass instead of storing them all from the forward pass.

OOM (Out of Memory)

A CUDA error raised when the GPU cannot allocate the requested memory. The most common error in GPU-accelerated deep learning.