The Training Loop
The Training Loop Is the Heart of Deep Learning
Unlike scikit-learn's .fit(), PyTorch requires you to write the
training loop explicitly. This gives you complete control over
every aspect: data loading, loss computation, gradient accumulation,
mixed precision, and logging.
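For instance, gradient accumulation, one of the knobs mentioned above, takes only a few extra lines once you own the loop. The sketch below is a minimal illustration; the linear model, batch sizes, and accum_steps value are hypothetical choices, not a prescribed recipe:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4  # hypothetical: sum gradients over 4 mini-batches per update

for i in range(8):  # stand-in for iterating a DataLoader
    x, y = torch.randn(16, 10), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y) / accum_steps  # scale so the sum averages
    loss.backward()  # gradients accumulate across calls until zeroed
    if (i + 1) % accum_steps == 0:
        optimizer.step()       # one update per accum_steps mini-batches
        optimizer.zero_grad()  # clear for the next accumulation window

This simulates a larger effective batch size (here 4 × 16 = 64) without holding the larger batch in memory at once.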
Definition: Stochastic Gradient Descent (SGD)
SGD updates parameters using the gradient of the loss on a mini-batch $B$:

$$\theta_{t+1} = \theta_t - \eta\, \nabla_\theta \mathcal{L}_B(\theta_t)$$

where $\eta$ is the learning rate. With momentum $\mu$ (PyTorch's convention):

$$v_{t+1} = \mu v_t + \nabla_\theta \mathcal{L}_B(\theta_t), \qquad \theta_{t+1} = \theta_t - \eta\, v_{t+1}$$
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
Definition: Adam Optimiser
Adam combines momentum with per-parameter adaptive learning rates. With mini-batch gradient $g_t = \nabla_\theta \mathcal{L}_B(\theta_t)$:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_{t+1} = \theta_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

Default: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
AdamW decouples weight decay from the gradient update, giving better generalisation. Prefer AdamW over Adam for most tasks.
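To make "decoupled" concrete, here is a minimal sketch of a single AdamW-style update on a bare tensor; the tensors and hyperparameter values are illustrative, not PyTorch's internal implementation:

import torch

lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.01
theta = torch.randn(5)   # stand-in parameter vector
grad = torch.randn(5)    # stand-in gradient
m, v, t = torch.zeros(5), torch.zeros(5), 1

m = beta1 * m + (1 - beta1) * grad        # first moment (momentum)
v = beta2 * v + (1 - beta2) * grad ** 2   # second moment (adaptive scale)
m_hat = m / (1 - beta1 ** t)              # bias correction
v_hat = v / (1 - beta2 ** t)

theta = theta - lr * wd * theta                    # decoupled decay: the "W" in AdamW
theta = theta - lr * m_hat / (v_hat.sqrt() + eps)  # the usual Adam step

With plain Adam plus L2 regularisation, the wd * theta term would instead be added to grad and so get rescaled by the adaptive denominator; decoupling avoids that interaction.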
Theorem: SGD Convergence for Convex Objectives
For an $L$-smooth convex function $f$ with bounded gradient variance $\sigma^2$, SGD with learning rate $\eta_t \propto 1/\sqrt{t}$ satisfies:

$$\mathbb{E}\big[f(\bar\theta_T)\big] - f(\theta^*) \le O\!\left(\frac{L\,\lVert\theta_0 - \theta^*\rVert^2}{T} + \frac{\sigma\,\lVert\theta_0 - \theta^*\rVert}{\sqrt{T}}\right)$$

where $\bar\theta_T$ is the average iterate. This gives an $O(1/\sqrt{T})$ convergence rate.
SGD converges more slowly than full-batch GD ($O(1/T)$) because of gradient noise, but each step is $N/B$ times cheaper, where $N$ is the dataset size and $B$ the mini-batch size. Concretely, halving the optimality gap costs roughly twice as many GD steps but four times as many SGD steps.
Example: The Standard PyTorch Training Loop
Write a complete training loop for an MLP on synthetic regression data.
Setup
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic regression data; the sizes and split here are illustrative assumptions
X, y = torch.randn(1000, 10), torch.randn(1000, 1)
train_loader = DataLoader(TensorDataset(X[:800], y[:800]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[800:], y[800:]), batch_size=32)

model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
Training loop
for epoch in range(100):
    model.train()
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()            # 1. Clear gradients
        y_pred = model(x_batch)          # 2. Forward pass
        loss = loss_fn(y_pred, y_batch)  # 3. Compute loss
        loss.backward()                  # 4. Backward pass
        optimizer.step()                 # 5. Update parameters

    # Validation
    model.eval()
    with torch.no_grad():
        val_loss = sum(loss_fn(model(xv), yv).item()
                       for xv, yv in val_loader) / len(val_loader)
Training Dynamics: Loss vs Epoch
[Interactive plot: watch how learning rate and optimizer choice affect convergence.]
Quick Check
What happens if you forget optimizer.zero_grad() in the training loop?
(a) Gradients from successive batches accumulate (add up)
(b) The model does not learn at all
(c) PyTorch raises an error
Answer: (a). PyTorch accumulates gradients by default. Without zeroing, each .backward() adds to existing gradients.
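A two-line experiment confirms this (a minimal sketch; the tensor values are arbitrary):

import torch

w = torch.ones(1, requires_grad=True)
(2 * w).sum().backward()
print(w.grad)  # tensor([2.])
(2 * w).sum().backward()
print(w.grad)  # tensor([4.]) -- the second backward added to, rather than replaced, the first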
Common Mistake: Forgetting model.eval() During Validation
Mistake:
Running validation without calling model.eval(), causing BatchNorm
and Dropout to behave as in training mode.
Correction:
Always call model.eval() before validation and model.train() before
the next training epoch. Also wrap validation in torch.no_grad().
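The difference is easy to see with a bare Dropout layer (a minimal sketch):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()    # training mode: each element zeroed with prob 0.5, survivors scaled by 2
print(drop(x))  # e.g. tensor([[2., 0., 2., 2., 0., 0., 2., 0.]]) -- random each call

drop.eval()     # eval mode: identity, no masking
print(drop(x))  # tensor([[1., 1., 1., 1., 1., 1., 1., 1.]]) -- deterministic

Validating in training mode therefore adds noise to the metrics you are trying to trust.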
Why This Matters: End-to-End Learning of Communication Systems
The training loop framework applies directly to learning communication systems end-to-end: the transmitter and receiver are neural networks, the channel is a differentiable layer, and the loss is the bit error rate or mutual information. This autoencoder approach (O'Shea & Hoydis, 2017) can discover novel modulation schemes that outperform hand-designed ones.
See full treatment in Chapter 28
Historical Note: Adam: Adaptive Moment Estimation
Kingma and Ba introduced Adam in 2014, combining ideas from AdaGrad (per-parameter rates) and RMSProp (exponential moving average of squared gradients). Despite theoretical concerns about non-convergence in some cases (Reddi et al., 2018), Adam and its variant AdamW remain the most widely used optimizers in deep learning.
Epoch
One complete pass through the entire training dataset.
Mini-batch
A subset of training examples used to compute one gradient update. Typical sizes: 32-512.
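For example, with 50,000 training examples and a batch size of 128, one epoch consists of ⌈50,000/128⌉ = 391 gradient updates.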