Mixed Precision

Why Mixed Precision?

Modern GPUs have specialized hardware (Tensor Cores) that execute half-precision (FP16/BF16) operations 2-8x faster than FP32. Mixed precision training uses lower precision for most computations while keeping a master copy of weights in FP32 for numerical stability.

This approach provides:

  • 2x memory reduction for activations and gradients
  • 2-3x speedup from Tensor Core acceleration
  • Minimal accuracy loss when done correctly

This section covers the floating-point formats, PyTorch's autocast, loss scaling, and when lower precision is safe vs dangerous.

Definition:

IEEE 754 Floating-Point Formats for GPU Computing

Three formats dominate GPU computing:

Format   Bits   Exponent bits   Mantissa bits   Range                      $\epsilon_{\text{mach}}$
FP32     32     8               23              $\pm 3.4 \times 10^{38}$   $1.19 \times 10^{-7}$
FP16     16     5               10              $\pm 6.5 \times 10^{4}$    $9.77 \times 10^{-4}$
BF16     16     8               7               $\pm 3.4 \times 10^{38}$   $7.81 \times 10^{-3}$

BFloat16 keeps FP32's 8 exponent bits, and therefore its dynamic range (values that fit in FP32 do not overflow), but has far fewer mantissa bits (lower precision). It was designed by Google Brain specifically for deep learning.

import torch
print(torch.finfo(torch.float32).eps)   # 1.19e-07
print(torch.finfo(torch.float16).eps)   # 9.77e-04
print(torch.finfo(torch.bfloat16).eps)  # 7.81e-03

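The key difference is range. FP16's largest finite value is 65504 and its smallest positive (subnormal) value is about $6 \times 10^{-8}$, so large activations or losses can overflow to inf while small gradients underflow to zero. BF16 matches FP32's range, so neither problem arises for values that fit in FP32, which makes BF16 safer for training without loss scaling.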
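Both failure modes can be seen by casting FP32 values to each 16-bit format. A minimal CPU sketch (it only uses dtype conversions, so no GPU is assumed):

```python
import torch

# One value too large for FP16, one too small, one needing more mantissa bits
values = torch.tensor([70000.0, 1e-8, 1.0001], dtype=torch.float32)

fp16 = values.to(torch.float16)
bf16 = values.to(torch.bfloat16)

print(fp16)  # 70000 overflows to inf, 1e-8 underflows to 0, 1.0001 rounds to 1
print(bf16)  # 70000 and 1e-8 stay finite and nonzero, 1.0001 rounds to 1
```

The last column shows the price BF16 pays for its range: with only 7 mantissa bits it rounds away small relative differences that FP32 keeps.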

Definition:

Automatic Mixed Precision (AMP) with torch.autocast

PyTorch's torch.autocast context manager selects a precision for each operation according to a built-in per-op policy:

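import torch

# Assumes model, criterion, optimizer, inputs and target already live on the GPU
scaler = torch.amp.GradScaler('cuda')

optimizer.zero_grad()
with torch.autocast('cuda', dtype=torch.float16):
    output = model(inputs)               # matmuls/convs run in FP16
    loss = criterion(output, target)     # common losses are autocast to FP32

# Backward runs outside autocast; each op reuses the dtype chosen in forward
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()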

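Operations that benefit from lower precision (matmul, conv, linear) are cast down; on CUDA, operations that need FP32 range or precision (for example softmax, layer norm, and common losses such as cross_entropy) run in FP32, with their inputs upcast if needed. PyTorch's AMP documentation lists these per-op rules, including which ops are eligible, in its Autocast Op Reference.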
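The per-op policy is easy to observe by checking output dtypes. The sketch below uses CPU autocast with bfloat16, an assumption made so it runs without a GPU; the inputs remain ordinary FP32 tensors, and only the eligible op's computation is cast down.

```python
import torch

# On CUDA the same pattern applies with torch.autocast('cuda', dtype=torch.float16)
a = torch.randn(64, 64)
b = torch.randn(64, 64)

with torch.autocast('cpu', dtype=torch.bfloat16):
    inside = a @ b        # matmul is on the lower-precision list

outside = a @ b           # the same op outside the context stays float32
print(a.dtype, inside.dtype, outside.dtype)
# torch.float32 torch.bfloat16 torch.float32
```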

With BF16, you typically do not need GradScaler: BF16 has the same exponent range as FP32, so small gradients do not underflow and large values do not overflow the way they do in FP16.

Definition:

Gradient Scaling for FP16 Training

In FP16, small gradient values underflow to zero, so the affected parameters stop receiving updates and training can stall or degrade. torch.amp.GradScaler addresses this by:

  1. Scaling up the loss before backward: $\tilde{L} = s \cdot L$
  2. Computing gradients of $\tilde{L}$ (now scaled by $s$)
  3. Unscaling gradients before the optimizer step: $\nabla_\theta L = \nabla_\theta \tilde{L} / s$
  4. Skipping the optimizer step if any gradients are inf/nan
  5. Adapting the scale factor $s$ dynamically
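
import torch

# Assumes model (returning a scalar loss), optimizer and loader already exist
scaler = torch.amp.GradScaler('cuda')
for batch in loader:
    optimizer.zero_grad(set_to_none=True)      # clear gradients from the last step
    with torch.autocast('cuda', dtype=torch.float16):
        loss = model(batch)
    scaler.scale(loss).backward()  # steps 1-2: backward on s * L
    scaler.step(optimizer)         # steps 3-4: unscale, skip the step on inf/nan
    scaler.update()                # step 5: adjust s for the next iteration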
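The arithmetic behind the steps above can be checked on the CPU. In this minimal sketch, the first part shows a tiny value surviving FP16 storage only after scaling; the second is a hypothetical `update_scale` helper that mirrors the dynamic rule GradScaler documents (defaults init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). It illustrates the rule and is not GradScaler's actual implementation.

```python
import torch

# 1) Why scaling helps: a tiny gradient value underflows when stored in FP16.
grad = 1e-8
print(torch.tensor(grad).to(torch.float16).item())   # 0.0: lost in FP16

s = 2.0 ** 16                                        # GradScaler's default init_scale
scaled = torch.tensor(grad * s).to(torch.float16)    # about 6.6e-4: representable
recovered = scaled.float().item() / s                # unscale in FP32
print(recovered)                                     # close to 1e-8

# 2) Hypothetical helper mirroring GradScaler's documented update rule.
def update_scale(scale, found_inf, good_steps,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Return (new_scale, new_good_steps) after one iteration."""
    if found_inf:
        return scale * backoff_factor, 0     # inf/nan seen: shrink s, step was skipped
    good_steps += 1
    if good_steps >= growth_interval:
        return scale * growth_factor, 0      # long run of finite grads: grow s
    return scale, good_steps
```

Growing $s$ slowly and shrinking it immediately keeps the scale near the largest value that does not overflow, which preserves as many small gradients as possible.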

Definition:

Tensor Cores and Precision Requirements

Tensor Cores are specialized matrix-multiply-accumulate units on NVIDIA GPUs (Volta and later) that compute:

$$\mathbf{D} = \mathbf{A} \times \mathbf{B} + \mathbf{C}$$

where $\mathbf{A}, \mathbf{B}$ are FP16 (Volta and later) or BF16/TF32 (Ampere and later), and the accumulator $\mathbf{C}, \mathbf{D}$ is typically FP32. To use Tensor Cores:

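  • Matrix dimensions should be multiples of 8 (FP16/BF16) or 4 (TF32); older cuBLAS/cuDNN releases required this, and newer ones run misaligned shapes less efficiently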
  • Input tensors must be in the correct dtype
  • Use torch.autocast or explicit casting

A100 Tensor Core throughput: 312 TFLOPS (FP16) vs 19.5 TFLOPS (FP32) -- a 16x difference.

Always pad your dimensions to multiples of 8 for maximum Tensor Core utilization. A hidden size of 768 is better than 700.
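A minimal sketch of the padding arithmetic, using a hypothetical `round_up` helper. It also checks that zero-padding the shared inner dimension of both operands leaves the product unchanged, which is why padding is safe when applied consistently to inputs and weights.

```python
import torch
import torch.nn.functional as F

def round_up(n: int, multiple: int = 8) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple

print(round_up(700), round_up(768))      # 704 768

# Zero-padding the shared (inner) dimension of both operands leaves x @ W.T
# unchanged: the extra columns only contribute 0 * 0 terms.
torch.manual_seed(0)
x = torch.randn(32, 700)
W = torch.randn(256, 700)
pad = round_up(x.shape[-1]) - x.shape[-1]      # 4 extra columns
x_p = F.pad(x, (0, pad))                       # shape (32, 704)
W_p = F.pad(W, (0, pad))                       # shape (256, 704)
same = torch.allclose(x @ W.T, x_p @ W_p.T, atol=1e-3)
print(x_p.shape, same)
```

In practice it is usually simpler to choose padded sizes when defining the model (for example a vocabulary or hidden size) than to pad tensors at runtime.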

Definition:

TensorFloat-32 (TF32) Mode

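TF32 uses 19 bits (1 sign + 8 exponent + 10 mantissa) and is available on Ampere and later GPUs for FP32 matrix multiplications and convolutions. On an A100 its peak throughput is 156 TFLOPS versus 19.5 TFLOPS for true FP32 (8x), with a mantissa as short as FP16's. In PyTorch 1.12 and later, TF32 is enabled by default for cuDNN convolutions but disabled by default for matmuls.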

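import torch

# PyTorch 1.12+ defaults: TF32 off for matmuls, on for cuDNN convolutions
torch.backends.cuda.matmul.allow_tf32 = True   # opt in for matmuls on Ampere+
torch.backends.cudnn.allow_tf32 = True          # already the default
# Equivalent matmul setting: torch.set_float32_matmul_precision('high')

# Disable both for full FP32 precision (e.g., numerical validation)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False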

When enabled, TF32 is transparent: your code still uses torch.float32 tensors, but the hardware rounds matmul inputs to TF32 internally. This can cause surprising numerical differences when comparing CPU and GPU results.
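The size of those differences can be estimated on the CPU by emulating TF32 inputs. This sketch is an approximation under a stated assumption: the hypothetical helper `truncate_to_tf32` truncates, rather than rounds, each FP32 input to TF32's 10 mantissa bits, and the product is accumulated in FP32. Both results are compared against an FP64 reference.

```python
import torch

def truncate_to_tf32(x: torch.Tensor) -> torch.Tensor:
    """Keep sign, 8 exponent bits and the top 10 of FP32's 23 mantissa bits.
    Assumption: truncation toward zero; real hardware conversion may round."""
    bits = x.contiguous().view(torch.int32)
    mask = ~((1 << 13) - 1)                  # clears the low 13 mantissa bits
    return (bits & mask).view(torch.float32)

torch.manual_seed(0)
A = torch.randn(256, 256)
B = torch.randn(256, 256)
ref = A.double() @ B.double()                # FP64 reference

fp32 = A @ B                                 # true FP32 on the CPU
tf32_like = truncate_to_tf32(A) @ truncate_to_tf32(B)   # TF32-like inputs, FP32 accumulate

def rel_err(C: torch.Tensor) -> float:
    return ((C.double() - ref).norm() / ref.norm()).item()

e32, etf = rel_err(fp32), rel_err(tf32_like)
print(f"FP32 inputs:      {e32:.1e}")
print(f"TF32-like inputs: {etf:.1e}")
```

The TF32-like error is set by the 10-bit mantissa (on the order of $10^{-3}$), several orders of magnitude above the FP32 result, which is the gap that shows up in CPU vs GPU comparisons.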

Theorem: Rounding Error in Mixed-Precision Matrix Multiply

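For matrix multiplication $\mathbf{C} = \mathbf{A}\mathbf{B}$ with $\mathbf{A} \in \mathbb{R}^{m \times k}$ and $\mathbf{B} \in \mathbb{R}^{k \times n}$, computed entirely in precision $u$ (machine epsilon), with products and running sums both rounded to that format, the element-wise error satisfies:

$$|\hat{C}_{ij} - C_{ij}| \le k \cdot u \cdot \sum_{l=1}^{k} |A_{il}| \cdot |B_{lj}| + O(u^2)$$

For FP16 ($u \approx 10^{-3}$) with $k = 1024$, the factor $k \cdot u \approx 1$: relative to $\sum_l |A_{il}| \cdot |B_{lj}|$, the worst-case bound no longer guarantees a single correct digit. This is why accumulation in FP32 is essential for FP16 matmul.

Each multiply-add introduces a rounding error of order $u$. With $k$ such operations per output element, errors can accumulate linearly in the worst case (independent rounding errors typically grow more like $\sqrt{k}$). With FP32 accumulation, $k \cdot u_{\text{FP32}} \approx 1024 \times 1.19 \times 10^{-7} \approx 1.2 \times 10^{-4}$, so the accumulated result stays accurate despite FP16 inputs, and the remaining error is dominated by rounding the inputs to FP16.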
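A scalar emulation makes the gap concrete for one output element with $k = 1024$ positive FP16 inputs. In this sketch, the hypothetical helpers `to_fp16` and `to_fp32` round a running sum to each format after every multiply-add (one rounding per step, like a fused multiply-add); the products of two FP16 values are exact in Python floats, so an FP64 sum serves as the reference.

```python
import torch

def to_fp16(v: float) -> float:
    """Round a Python float to FP16 and return it as a Python float."""
    return torch.tensor(v, dtype=torch.float16).item()

def to_fp32(v: float) -> float:
    """Round a Python float to FP32 and return it as a Python float."""
    return torch.tensor(v, dtype=torch.float32).item()

torch.manual_seed(0)
k = 1024
a = torch.rand(k).to(torch.float16).tolist()   # FP16 inputs, as Python floats
b = torch.rand(k).to(torch.float16).tolist()

exact = sum(x * y for x, y in zip(a, b))        # FP64 reference (products are exact)

acc16, acc32 = 0.0, 0.0
for x, y in zip(a, b):
    p = x * y                                    # exact product of two FP16 values
    acc16 = to_fp16(acc16 + p)                   # running sum kept in FP16
    acc32 = to_fp32(acc32 + p)                   # running sum kept in FP32

rel16 = abs(acc16 - exact) / exact
rel32 = abs(acc32 - exact) / exact
print(f"FP16 accumulation: relative error {rel16:.1e}")
print(f"FP32 accumulation: relative error {rel32:.1e}")
```

The observed FP16 error is far below the worst-case bound because rounding errors partly cancel, but it is still orders of magnitude larger than with FP32 accumulation.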

Example: Complete AMP Training Loop

Implement a training loop with automatic mixed precision, including gradient scaling for FP16 and measuring the speedup.

Example: Numerical Failures in Low Precision

Demonstrate scenarios where FP16 causes numerical issues and show how to detect and fix them.
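A minimal sketch of three common FP16 failure modes (overflow, underflow, and swamping of small increments) together with a simple non-finite check. It runs on the CPU and assumes a PyTorch build with FP16 arithmetic there (recent releases have it); `has_nonfinite` is a hypothetical helper.

```python
import torch

h = torch.float16

# 1) Overflow: squaring moderately large activations exceeds FP16's max (65504).
x = torch.tensor([300.0, 10.0], dtype=h)
sq16 = x * x                        # 90000 -> inf
sq32 = x.float() * x.float()        # fix: compute in FP32
print(sq16, sq32)

# 2) Underflow: 1e-4 * 1e-4 = 1e-8 is below FP16's smallest subnormal (~6e-8).
small = torch.tensor(1e-4, dtype=h)
print(small * small)                # flushes to 0

# 3) Swamping: FP16 values near 4096 are 4.0 apart, so adding 1.0 is lost.
acc = torch.tensor(4096.0, dtype=h)
print(acc + 1)                      # still 4096 in FP16
print(acc.float() + 1)              # 4097 in FP32

# Detection: check for non-finite values before they reach the loss.
def has_nonfinite(t: torch.Tensor) -> bool:
    return not torch.isfinite(t.float()).all().item()

print(has_nonfinite(sq16), has_nonfinite(sq32))   # True False
```

All three fixes follow the same idea as autocast: keep storage or inputs in 16 bits where that is safe, and move the sensitive computation (large products, tiny products, long accumulations) to FP32.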

Floating-Point Precision Comparison

Compare the representation error of FP32, FP16, and BF16 across different value ranges. Observe where each format loses precision or overflows.
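One way to run this comparison, as a sketch with a hypothetical `roundtrip_error` helper: round powers of ten from $10^{-9}$ to $10^{6}$ to each 16-bit format and back, then report the relative error, or whether the value overflowed or underflowed.

```python
import torch

def roundtrip_error(v: float, dtype: torch.dtype) -> str:
    """Relative error after rounding an FP32 value to `dtype` and back."""
    x = torch.tensor(v, dtype=torch.float32)
    rt = x.to(dtype).float()
    if torch.isinf(rt):
        return "overflow"
    if rt.item() == 0.0:
        return "underflow"
    return f"{abs(rt.item() - x.item()) / x.item():.1e}"

print(f"{'value':>8} {'FP16':>10} {'BF16':>10}")
for e in range(-9, 7):                        # 1e-9 ... 1e6
    v = 10.0 ** e
    print(f"{v:>8.0e} {roundtrip_error(v, torch.float16):>10} "
          f"{roundtrip_error(v, torch.bfloat16):>10}")
```

FP16 loses small values to underflow (and to coarse subnormal spacing just above that) and overflows above 65504, while BF16 stays finite across the whole range at the cost of a larger relative error everywhere.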


Mixed Precision Training

AMP training loop, GradScaler, BF16 vs FP16, and numerical pitfalls (code from ch13/python/mixed_precision.py).

Quick Check

What is the main advantage of BFloat16 over Float16 for training?

BF16 is faster on all GPUs

BF16 has more mantissa bits, giving higher precision

BF16 has the same exponent range as FP32, avoiding overflow

BF16 uses less memory than FP16

Common Mistake: Autocast Outside Forward Pass

Mistake:

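Wrapping the entire training step in torch.autocast, including the backward pass, optimizer step, metric calculation, and logging. PyTorch does not recommend running backward under autocast, and metrics or custom reductions computed on FP16 tensors can overflow to inf (for example, an inf loss that crashes training) or silently lose precision.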

Correction:

Wrap only the forward pass and loss computation in autocast. Keep backward, metric computation, logging, and validation bookkeeping in FP32. Autocast runs common losses (such as cross_entropy and mse_loss) in FP32, but custom loss functions built from other ops are not protected, so upcast their inputs with .float() first.
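A minimal sketch of the difference, using CPU autocast with bfloat16 (an assumption so it runs without a GPU). The two custom loss functions are hypothetical; on CUDA with float16, the unprotected one is the variant that can overflow once squared values pass 65504.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
x = torch.randn(8, 16)

def custom_loss_unprotected(pred):
    return (pred * pred).mean()          # runs in whatever dtype pred has

def custom_loss_fp32(pred):
    pred = pred.float()                  # upcast before squaring and reducing
    return (pred * pred).mean()

with torch.autocast('cpu', dtype=torch.bfloat16):
    pred = model(x)                      # Linear output is bfloat16 under autocast
    loss_low = custom_loss_unprotected(pred)
    loss_fp32 = custom_loss_fp32(pred)

print(pred.dtype, loss_low.dtype, loss_fp32.dtype)
```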

Key Takeaway

Use torch.autocast('cuda', dtype=torch.bfloat16) as the default for mixed precision training on Ampere+ GPUs. BF16 removes the need for GradScaler because gradients that would underflow or overflow in FP16 stay representable. Reserve FP16 + GradScaler for older GPUs without BF16 support (e.g., Volta and Turing).

Why This Matters: Mixed Precision in Digital Signal Processing

In wireless DSP, precision requirements vary by operation. Channel estimation (computing $\hat{\mathbf{H}} = \mathbf{Y}\mathbf{X}^{-1}$) needs FP32 for the matrix inverse, but the subsequent beamforming step, which forms weights such as $\mathbf{W} = \hat{\mathbf{H}}^H$ and applies them to the data, can safely use FP16. OFDM FFTs are sensitive to precision in the twiddle factors but tolerant in the data path. Mixed precision strategies mirror the autocast approach: use high precision where needed, low precision where safe.

See full treatment in Image Plots

Historical Note: The Mixed Precision Training Paper

21st century

Micikevicius et al. (2018), a team from NVIDIA and Baidu Research, published "Mixed Precision Training," demonstrating that neural networks can be trained in FP16 using FP32 master weights, FP32 accumulation, and loss scaling, matching FP32 accuracy while roughly halving memory requirements and delivering significant speedups. Google Brain's BFloat16, used on TPUs, simplified the approach further by removing the need for loss scaling. Today, mixed precision is the default training mode for most large-scale AI systems.

Float16 (Half Precision)

IEEE 754 half-precision format: 1 sign + 5 exponent + 10 mantissa bits. Range up to 65504, machine epsilon $\approx 10^{-3}$.

Related: BFloat16 (Brain Floating Point)

BFloat16 (Brain Floating Point)

Google's 16-bit format: 1 sign + 8 exponent + 7 mantissa bits. Same range as FP32 and lower precision than FP16, so gradients rarely overflow or underflow the way they can in FP16.

Related: Float16 (Half Precision)

Tensor Core

Specialized hardware units on NVIDIA GPUs (Volta+) that perform mixed-precision matrix-multiply-accumulate operations at up to 16x the throughput of standard FP32 CUDA cores.

Floating-Point Format Comparison

Property            FP32                    FP16                    BF16                    TF32
Total bits          32                      16                      16                      19
Exponent bits       8                       5                       8                       8
Mantissa bits       23                      10                      7                       10
Max value           $3.4 \times 10^{38}$    $6.5 \times 10^{4}$     $3.4 \times 10^{38}$    $3.4 \times 10^{38}$
Machine epsilon     $1.2 \times 10^{-7}$    $9.8 \times 10^{-4}$    $7.8 \times 10^{-3}$    $9.8 \times 10^{-4}$
Needs GradScaler    No                      Yes                     No                      No
A100 peak TFLOPS    19.5                    312                     312                     156