Mixed Precision
Why Mixed Precision?
Modern GPUs have specialized hardware (Tensor Cores) that execute half-precision (FP16/BF16) operations 2-8x faster than FP32. Mixed precision training uses lower precision for most computations while keeping a master copy of weights in FP32 for numerical stability.
This approach provides:
- 2x memory reduction for activations and gradients
- 2-3x speedup from Tensor Core acceleration
- Minimal accuracy loss when done correctly
This section covers the floating-point formats, PyTorch's autocast,
loss scaling, and when lower precision is safe vs dangerous.
Definition: IEEE 754 Floating-Point Formats for GPU Computing
Three formats dominate GPU computing:
| Format | Bits | Exponent | Mantissa | Range |
|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | $\approx \pm 3.4 \times 10^{38}$ |
| FP16 | 16 | 5 | 10 | $\approx \pm 65504$ |
| BF16 | 16 | 8 | 7 | $\approx \pm 3.4 \times 10^{38}$ |
BFloat16 uses the same exponent range as FP32 (no overflow risk) but fewer mantissa bits (lower precision). It was designed by Google Brain specifically for deep learning.
import torch
x = torch.tensor(1.0)
print(torch.finfo(torch.float32).eps) # 1.19e-07
print(torch.finfo(torch.float16).eps) # 9.77e-04
print(torch.finfo(torch.bfloat16).eps) # 7.81e-03
The key difference: FP16 can overflow at 65504, while BF16 matches FP32's range. Gradient values often exceed 65504, making BF16 safer for training without loss scaling.
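You can confirm the range difference directly with `torch.finfo`; a minimal check that runs on CPU:

```python
import torch

# Largest representable value per format
print(torch.finfo(torch.float32).max)   # 3.4028e+38
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # 3.3895e+38
```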
Definition: Automatic Mixed Precision (AMP) with torch.autocast
PyTorch's torch.autocast context manager automatically selects
the optimal precision for each operation:
with torch.autocast('cuda', dtype=torch.float16):
    output = model(input)              # matmuls run in FP16
    loss = criterion(output, target)   # loss kept in FP32

# Backward runs outside autocast; gradients are computed in mixed precision
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Operations that benefit from FP16 (matmul, conv) are cast down; operations that need FP32 precision (softmax, layer norm, loss) remain in FP32. This is called the autocast eligibility list.
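To see the eligibility list in action, you can inspect the dtypes that different ops produce under autocast. A minimal sketch (assumes a CUDA device is available):

```python
import torch

device = torch.device('cuda')
x = torch.randn(8, 16, device=device)
layer = torch.nn.Linear(16, 16).to(device)

with torch.autocast('cuda', dtype=torch.float16):
    y = layer(x)                      # linear/matmul: cast down to FP16
    p = torch.softmax(y, dim=-1)      # softmax: kept in FP32 by autocast
    print(y.dtype, p.dtype)           # torch.float16 torch.float32
```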
With BF16, you typically do not need GradScaler because BF16
has the same exponent range as FP32, eliminating overflow risk.
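A minimal BF16 training step, assuming an Ampere-or-newer GPU; note that no GradScaler appears anywhere:

```python
import torch
import torch.nn as nn

device = torch.device('cuda')
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters())

x = torch.randn(32, 128, device=device)
target = torch.randint(10, (32,), device=device)

# BF16 autocast: no loss scaling needed because BF16 shares FP32's exponent range
with torch.autocast('cuda', dtype=torch.bfloat16):
    loss = nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```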
Definition: Gradient Scaling for FP16 Training
In FP16, small gradient values underflow to zero, causing training
to diverge. torch.amp.GradScaler addresses this by:
- Scaling up the loss before backward: $\text{loss} \times S$
- Computing gradients (now scaled by $S$)
- Unscaling gradients before the optimizer step: $g / S$
- Skipping the optimizer step if any gradients are inf/nan
- Adapting the scale factor $S$ dynamically
scaler = torch.amp.GradScaler()
for batch in loader:
    with torch.autocast('cuda'):
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
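The scale factor adapts over training: it grows after runs of successful steps and shrinks when inf/nan gradients force a skipped step. A sketch of how you might monitor it, mirroring the loop above (`loader`, `model`, and `optimizer` are assumed to be defined as before; `scaler.get_scale()` reports the current dynamic scale):

```python
scaler = torch.amp.GradScaler()

for step, batch in enumerate(loader):
    with torch.autocast('cuda', dtype=torch.float16):
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)      # skipped internally if gradients contain inf/nan
    scaler.update()             # grows or shrinks the scale based on this step
    optimizer.zero_grad()
    if step % 100 == 0:
        print(f"step {step}: loss scale = {scaler.get_scale():.0f}")
```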
Definition: Tensor Cores and Precision Requirements
Tensor Cores are specialized matrix-multiply-accumulate units on NVIDIA GPUs (Volta and later) that compute:

$$D = A \times B + C$$

where $A, B$ are FP16/BF16/TF32 and $C, D$ are FP32. To use Tensor Cores:
- Matrix dimensions must be multiples of 8 (FP16) or 4 (TF32)
- Input tensors must be in the correct dtype
- Use torch.autocast or explicit casting
A100 Tensor Core throughput: 312 TFLOPS (FP16) vs 19.5 TFLOPS (FP32) -- a 16x difference.
Always pad your dimensions to multiples of 8 for maximum Tensor Core utilization. A hidden size of 768 is better than 700.
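A rough way to see the alignment effect is to time an FP16 matmul at an aligned versus an unaligned inner dimension. Treat this as a sketch: exact numbers depend on the GPU and cuBLAS version, and recent libraries may pad internally so the penalty can be small.

```python
import time
import torch

def time_matmul(k, iters=50):
    a = torch.randn(4096, k, device='cuda', dtype=torch.float16)
    b = torch.randn(k, 4096, device='cuda', dtype=torch.float16)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    return time.perf_counter() - t0

print(f"k=4096 (multiple of 8): {time_matmul(4096):.3f}s")
print(f"k=4094 (not aligned):   {time_matmul(4094):.3f}s")  # often slower
```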
Definition: TensorFloat-32 (TF32) Mode
TF32 uses 19 bits (8 exponent + 10 mantissa + 1 sign) and is the default on Ampere+ GPUs for FP32 matrix multiplications. On an A100 it provides roughly an 8x speedup over true FP32 (156 vs 19.5 TFLOPS) with minimal precision loss.
# TF32 is enabled by default on Ampere+
torch.backends.cuda.matmul.allow_tf32 = True # default
torch.backends.cudnn.allow_tf32 = True # default
# Disable for full FP32 precision (e.g., numerical validation)
torch.backends.cuda.matmul.allow_tf32 = False
TF32 is transparent: your code uses torch.float32 tensors, but
the hardware internally uses TF32 for matmul. This can cause
surprising numerical differences when comparing CPU vs GPU results.
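You can observe this effect by comparing a GPU matmul against a float64 reference with the flag on and off; a minimal sketch (exact error magnitudes depend on the matrices):

```python
import torch

a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')
ref = a.double() @ b.double()  # float64 reference

torch.backends.cuda.matmul.allow_tf32 = True
err_tf32 = (a @ b - ref).abs().max().item()

torch.backends.cuda.matmul.allow_tf32 = False
err_fp32 = (a @ b - ref).abs().max().item()

print(f"max abs error with TF32: {err_tf32:.2e}")  # noticeably larger
print(f"max abs error full FP32: {err_fp32:.2e}")
```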
Theorem: Rounding Error in Mixed-Precision Matrix Multiply
For matrix multiplication $C = AB$ with $A \in \mathbb{R}^{m \times k}$, $B \in \mathbb{R}^{k \times n}$, with products computed in precision $u$ (machine epsilon) and accumulation in precision $u_{\text{acc}}$, the element-wise error satisfies, to first order:

$$|\hat{C}_{ij} - C_{ij}| \;\lesssim\; \left(u + k\, u_{\text{acc}}\right) \sum_{l=1}^{k} |A_{il}|\,|B_{lj}|$$

For FP16 ($u \approx 10^{-3}$) with $k = 4096$: accumulating in FP16 ($u_{\text{acc}} = u$) gives a relative error bound of $k\,u \approx 4$, so the result can be meaningless, whereas FP32 accumulation ($u_{\text{acc}} \approx 10^{-7}$) keeps the bound near $10^{-3}$. This is why accumulation in FP32 is essential for FP16 matmul.
Each product contributes a rounding error of order $u$, and each addition contributes one of order $u_{\text{acc}}$. With $k$ such operations per output element, the accumulation errors grow linearly in $k$; FP32 accumulation keeps that linear term small despite FP16 inputs.
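The effect of the accumulator precision can be simulated directly: accumulate a long dot product step by step in FP16 versus FP32 and compare both against a float64 reference (the FP16 input quantization is shared by both). A sketch that runs on CPU:

```python
import torch

k = 4096
a = torch.randn(k, dtype=torch.float64)
b = torch.randn(k, dtype=torch.float64)
ref = torch.dot(a, b)

a16, b16 = a.half(), b.half()

# Accumulate in FP16: every partial sum is rounded back to FP16
acc16 = torch.tensor(0.0, dtype=torch.float16)
for ai, bi in zip(a16, b16):
    acc16 = acc16 + ai * bi

# Same FP16 inputs, but products and accumulation in FP32
acc32 = (a16.float() * b16.float()).sum()

print(f"FP16 accumulation error: {abs(acc16.item() - ref.item()):.4f}")
print(f"FP32 accumulation error: {abs(acc32.item() - ref.item()):.4f}")
```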
Example: Complete AMP Training Loop
Implement a training loop with automatic mixed precision, including gradient scaling for FP16 and measuring the speedup.
Setup
import torch
import torch.nn as nn
import time
device = torch.device('cuda')
model = nn.Sequential(
nn.Linear(1024, 4096), nn.ReLU(),
nn.Linear(4096, 4096), nn.ReLU(),
nn.Linear(4096, 10),
).to(device)
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.amp.GradScaler()
FP32 baseline
x = torch.randn(256, 1024, device=device)
target = torch.randint(10, (256,), device=device)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
    output = model(x)
    loss = nn.functional.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
torch.cuda.synchronize()
t_fp32 = time.perf_counter() - t0
AMP version
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
    with torch.autocast('cuda'):
        output = model(x)
        loss = nn.functional.cross_entropy(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
torch.cuda.synchronize()
t_amp = time.perf_counter() - t0
print(f"FP32: {t_fp32:.2f}s, AMP: {t_amp:.2f}s, "
f"Speedup: {t_fp32/t_amp:.2f}x")
Example: Numerical Failures in Low Precision
Demonstrate scenarios where FP16 causes numerical issues and show how to detect and fix them.
Overflow in FP16
import torch
# FP16 max is 65504
x = torch.tensor(300.0, dtype=torch.float16)
print(x * x) # tensor(inf, dtype=torch.float16) -- overflow!
# BF16 handles this (the result is rounded to the nearest representable value)
x_bf = torch.tensor(300.0, dtype=torch.bfloat16)
print(x_bf * x_bf) # tensor(90112., dtype=torch.bfloat16) -- no overflow
Precision loss in summation
# Sum of 10000 small values
vals = torch.ones(10000, dtype=torch.float16) * 0.01
print(f"FP16 sum: {vals.sum():.4f}") # May show ~100 (OK here)
# But accumulation of 1 + small fails
x = torch.tensor(1.0, dtype=torch.float16)
for _ in range(10000):
    x = x + torch.tensor(1e-4, dtype=torch.float16)
print(f"Expected 2.0, got {x.item():.4f}") # Shows 1.0: 1e-4 is below FP16's epsilon relative to 1.0, so each add is lost
Safe operations outside autocast
# Force critical operations to FP32
with torch.autocast('cuda'):
    logits = model(x)  # FP16 ok
    # Loss computed in FP32 (autocast does this automatically for built-in losses)
    loss = nn.functional.cross_entropy(
        logits.float(), target)  # explicit FP32
Floating-Point Precision Comparison
Compare the representation error of FP32, FP16, and BF16 across different value ranges. Observe where each format loses precision or overflows.
Full code: ch13/python/mixed_precision.py
Quick Check
What is the main advantage of BFloat16 over Float16 for training?
- BF16 is faster on all GPUs
- BF16 has more mantissa bits, giving higher precision
- BF16 has the same exponent range as FP32, avoiding overflow (correct)
- BF16 uses less memory than FP16
BF16 uses 8 exponent bits (same as FP32), so gradient values that would overflow FP16 (max 65504) are representable. This eliminates the need for loss scaling.
Common Mistake: Autocast Outside Forward Pass
Mistake:
Using torch.autocast around the entire training loop including
the loss computation, metric calculation, and logging. This can
cause loss values to be in FP16, leading to inf losses that
crash training.
Correction:
Wrap only the forward pass and loss computation in autocast.
Keep metric computation, logging, and validation in FP32.
PyTorch's autocast automatically handles common losses in FP32,
but custom loss functions may not be protected.
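A sketch of the corrected scoping: only the forward pass and loss sit inside autocast, while metrics and logging run on FP32 copies outside it (`loader`, `model`, `criterion`, `scaler`, and `optimizer` are assumed to be set up as in the earlier training-loop example):

```python
for batch, target in loader:
    # Forward + loss under autocast
    with torch.autocast('cuda', dtype=torch.float16):
        logits = model(batch)
        loss = criterion(logits, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    # Metrics and logging outside autocast, in FP32
    with torch.no_grad():
        acc = (logits.float().argmax(dim=-1) == target).float().mean()
    print(f"loss={loss.item():.4f} acc={acc.item():.3f}")
```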
Key Takeaway
Use torch.autocast('cuda', dtype=torch.bfloat16) as the default
for mixed precision training on Ampere+ GPUs. BF16 eliminates the
need for GradScaler and handles gradient magnitudes that overflow
FP16. Reserve FP16 + GradScaler for older GPUs without BF16 support.
Why This Matters: Mixed Precision in Digital Signal Processing
In wireless DSP, precision requirements vary by operation. Channel
estimation (computing a least-squares estimate such as
$\hat{H} = Y X^{H} (X X^{H})^{-1}$) needs FP32 for the matrix inverse,
but the subsequent beamforming multiplication can safely use FP16.
OFDM FFTs are sensitive to precision in the twiddle factors but
tolerant in the data path. Mixed precision strategies mirror the
autocast approach: use high precision where needed, low precision
where safe.
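A sketch of this split with hypothetical, real-valued toy shapes (64 antennas, 16 streams; the pilot/data tensors and the matched-filter combining step are illustrative, not a full receiver):

```python
import torch

device = torch.device('cuda')

X = torch.randn(16, 256, device=device)    # known pilot symbols
Y = torch.randn(64, 256, device=device)    # received pilots
R = torch.randn(64, 1000, device=device)   # received data symbols

# Channel estimation: the least-squares solve stays in FP32 (inverse is precision-sensitive)
H = torch.linalg.lstsq(X.T, Y.T).solution.T   # 64 x 16 channel estimate

# Combining/beamforming: a large, well-conditioned matmul, safe in FP16
streams = (H.half().T @ R.half()).float()     # 16 x 1000
```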
Historical Note: The Mixed Precision Training Paper
Micikevicius et al. (2018) at NVIDIA published "Mixed Precision Training," demonstrating that neural networks can be trained in FP16 with FP32 accumulation and loss scaling, achieving identical accuracy with 2x memory savings and significant speedup. Google later introduced BFloat16, which simplified the approach by eliminating the need for loss scaling. Today, mixed precision is the default training mode for most large-scale AI systems.
Float16 (Half Precision)
IEEE 754 half-precision format: 1 sign + 5 exponent + 10 mantissa bits. Range up to 65504, machine epsilon $2^{-10} \approx 9.77 \times 10^{-4}$.
Related: BFloat16 (Brain Floating Point)
BFloat16 (Brain Floating Point)
Google's 16-bit format: 1 sign + 8 exponent + 7 mantissa bits. Same range as FP32, lower precision than FP16, but no overflow risk for gradients.
Related: Float16 (Half Precision)
Tensor Core
Specialized hardware units on NVIDIA GPUs (Volta+) that perform mixed-precision matrix-multiply-accumulate operations at up to 16x the throughput of standard FP32 CUDA cores.
Floating-Point Format Comparison
| Property | FP32 | FP16 | BF16 | TF32 |
|---|---|---|---|---|
| Total bits | 32 | 16 | 16 | 19 |
| Exponent bits | 8 | 5 | 8 | 8 |
| Mantissa bits | 23 | 10 | 7 | 10 |
| Max value | $\approx 3.4 \times 10^{38}$ | $65504$ | $\approx 3.4 \times 10^{38}$ | $\approx 3.4 \times 10^{38}$ |
| Machine epsilon | $1.2 \times 10^{-7}$ | $9.8 \times 10^{-4}$ | $7.8 \times 10^{-3}$ | $9.8 \times 10^{-4}$ |
| Needs GradScaler | No | Yes | No | No |
| A100 TFLOPS | 19.5 | 312 | 312 | 156 |