nn.Module and Model Definition
Why nn.Module Is the Foundation
Every neural network in PyTorch is an nn.Module. Understanding this
base class (how it registers parameters, composes sub-modules, and
manages device placement) is essential before writing any training code.
This section covers the mechanics that the rest of Part VI builds upon.
Definition: nn.Module
nn.Module
torch.nn.Module is the base class for all neural network components.
A module encapsulates:
- Parameters: learnable tensors registered via nn.Parameter
- Sub-modules: child nn.Module instances (set as attributes)
- Forward logic: the forward() method defining the computation
import torch
import torch.nn as nn
class LinearModel(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)
Calling model(x) invokes model.forward(x) through __call__,
which also runs registered hooks.
Never call model.forward(x) directly: always use model(x) so
hooks and gradient tracking work correctly.
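A minimal usage sketch of the class above (the feature sizes are illustrative):

model = LinearModel(in_features=4, out_features=2)
x = torch.randn(8, 4)   # batch of 8 samples
y = model(x)            # __call__ dispatches to forward() and runs hooks
print(y.shape)          # torch.Size([8, 2])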
Definition: nn.Parameter
nn.Parameter
nn.Parameter is a Tensor subclass that, when assigned as a module
attribute, is automatically added to the list of parameters:
class ManualLinear(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d_out, d_in))
        self.b = nn.Parameter(torch.zeros(d_out))

    def forward(self, x):
        return x @ self.W.T + self.b
Use model.parameters() to iterate over all parameters and
model.named_parameters() for (name, param) pairs.
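For instance, iterating over the ManualLinear module above (a quick sketch):

m = ManualLinear(3, 5)
for name, p in m.named_parameters():
    print(name, tuple(p.shape))   # W (5, 3), then b (5,)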
Definition: nn.Sequential
nn.Sequential
nn.Sequential chains modules so each output feeds the next:
mlp = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
For named layers, use OrderedDict:
from collections import OrderedDict
mlp = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(784, 256)),
    ("act", nn.ReLU()),
    ("fc2", nn.Linear(256, 10)),
]))
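Named layers become attributes, and any Sequential also supports integer indexing, as this small sketch shows:

x = torch.randn(32, 784)
logits = mlp(x)   # runs fc1 -> act -> fc2 in order
print(mlp.fc2)    # access by name: Linear(in_features=256, out_features=10, bias=True)
print(mlp[0])     # access by position: the fc1 layer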
Definition: Common Activation Functions
Common Activation Functions
Activation functions introduce non-linearity:
ReLU: $\mathrm{ReLU}(x) = \max(0, x)$
Sigmoid: $\sigma(x) = \dfrac{1}{1 + e^{-x}}$
GELU (used in transformers): $\mathrm{GELU}(x) = x\,\Phi(x)$, where $\Phi$ is the standard Gaussian CDF.
In PyTorch: nn.ReLU(), nn.Sigmoid(), nn.GELU(), or their
functional forms F.relu(x), torch.sigmoid(x), etc.
ReLU is the default choice for hidden layers. Use GELU for transformer-based architectures and sigmoid/softmax for output layers.
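A quick sketch comparing the functional forms side by side:

import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
print(F.relu(x))          # zeros for x < 0, identity for x >= 0
print(torch.sigmoid(x))   # squashes into (0, 1)
print(F.gelu(x))          # smooth, slightly negative for small x < 0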
Definition: Weight Initialisation Strategies
Weight Initialisation Strategies
Proper initialisation prevents vanishing/exploding activations. For a layer with fan-in $n_{\text{in}}$ and fan-out $n_{\text{out}}$:
Kaiming (He) for ReLU networks: $W \sim \mathcal{N}\!\left(0,\; 2/n_{\text{in}}\right)$
Xavier (Glorot) for tanh/sigmoid: $W \sim \mathcal{U}\!\left(-\sqrt{6/(n_{\text{in}}+n_{\text{out}})},\; \sqrt{6/(n_{\text{in}}+n_{\text{out}})}\right)$
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
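A common pattern for initialising a whole model is model.apply, which visits every sub-module; a sketch (the init_weights name is illustrative):

def init_weights(m):
    # Kaiming init for every Linear layer; other modules are left untouched
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model.apply(init_weights)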
Definition: Forward and Backward Hooks
Forward and Backward Hooks
Hooks let you inspect or modify data flowing through modules:
- Forward hook: module.register_forward_hook(fn), called after forward(), receives (module, input, output)
- Backward hook: module.register_full_backward_hook(fn), called during the backward pass, receives (module, grad_input, grad_output)
activations = {}

def save_activation(name):
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

model.fc1.register_forward_hook(save_activation('fc1'))
Hooks are invaluable for debugging (checking for NaN gradients), feature extraction, and gradient visualization.
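For the NaN-gradient use case, a backward-hook sketch (check_nan is an illustrative name; fc1 as in the snippet above):

def check_nan(module, grad_input, grad_output):
    # flag any NaN flowing backwards through this module
    for g in grad_output:
        if g is not None and torch.isnan(g).any():
            print(f"NaN gradient in {module.__class__.__name__}")

handle = model.fc1.register_full_backward_hook(check_nan)
# ... run forward and backward passes ...
handle.remove()   # detach the hook when done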
Theorem: Universal Approximation Theorem
A feed-forward network with a single hidden layer containing a finite number of neurons can approximate any continuous function on a compact subset of $\mathbb{R}^n$ to arbitrary precision, provided the activation function is non-constant, bounded, and continuous (Cybenko, 1989).
More generally, for ReLU networks (Hanin, 2019): a network with width $d_{\text{in}} + d_{\text{out}}$ and sufficient depth can approximate any continuous function on a compact subset of $\mathbb{R}^{d_{\text{in}}}$.
A single hidden layer can represent any function, but may require exponentially many neurons. Depth provides exponentially more efficient representations.
Theorem: Parameter Count for Fully Connected Networks
An MLP with layer widths $d_0, d_1, \ldots, d_L$ has total parameter count:
$$N = \sum_{l=1}^{L} \left( d_{l-1}\, d_l + d_l \right)$$
The first term counts weights ($d_{l-1} d_l$ per layer) and the second counts biases ($d_l$ per layer).
Each neuron connects to every neuron in the previous layer (weights) plus one bias. This quadratic scaling motivates architectures like CNNs and transformers that share parameters.
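A quick numeric check of the formula, using the widths of the 3-layer MLP built later in this section:

widths = [784, 256, 128, 10]   # d_0 .. d_3
total = sum(widths[l - 1] * widths[l] + widths[l] for l in range(1, len(widths)))
print(total)   # 235146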
Theorem: Backpropagation via the Chain Rule
For a composition $f = f_L \circ f_{L-1} \circ \cdots \circ f_1$ with loss $\mathcal{L}$, the gradient with respect to the parameters $\theta_l$ of layer $l$ is:
$$\frac{\partial \mathcal{L}}{\partial \theta_l} = \frac{\partial \mathcal{L}}{\partial f_L} \cdot \frac{\partial f_L}{\partial f_{L-1}} \cdots \frac{\partial f_{l+1}}{\partial f_l} \cdot \frac{\partial f_l}{\partial \theta_l}$$
PyTorch's autograd computes this automatically using a dynamic computational graph built during the forward pass.
Backpropagation is just the chain rule applied layer by layer from
output to input. PyTorch records operations on tensors with
requires_grad=True and replays them in reverse during .backward().
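A minimal autograd sketch showing the recorded graph being replayed:

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = (x ** 2).sum()   # forward pass records the operations
y.backward()         # reverse pass applies the chain rule
print(x.grad)        # tensor([2., 4.]), since dy/dx = 2x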
Example: Building a 3-Layer MLP for Classification
Build a 3-layer MLP that maps 784-dimensional input (flattened 28x28 image) to 10 class logits, with hidden layers of size 256 and 128.
Define the model
import torch.nn as nn
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.net(x.flatten(1))

model = MLP()
print(sum(p.numel() for p in model.parameters()))
# 784*256+256 + 256*128+128 + 128*10+10 = 235,146
Verify with a dummy input
x = torch.randn(4, 1, 28, 28) # batch of 4 images
logits = model(x)
assert logits.shape == (4, 10)
Example: Custom Module with Residual (Skip) Connection
Implement a module where the output is $y = x + F(x)$, i.e., a residual block that adds the input to the transformed output.
Implementation
class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return x + self.block(x)

# Stack 4 residual blocks
model = nn.Sequential(*[ResidualBlock(128) for _ in range(4)])
Why it works
The skip connection ensures the gradient flows directly through the identity path, preventing vanishing gradients in deep networks.
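You can see the identity path at work in a small sketch: zero out the block's weights so the transform branch contributes nothing, and the gradient with respect to the input is still exactly one everywhere.

blk = ResidualBlock(8)
for p in blk.parameters():
    nn.init.zeros_(p)   # the transform branch now outputs zeros

x = torch.randn(1, 8, requires_grad=True)
blk(x).sum().backward()
print(x.grad)           # all ones: the gradient flows through the skip path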
Example: Inspecting Parameters and Module Tree
Given a model, enumerate all sub-modules and their parameter shapes.
Using named_modules and named_parameters
model = MLP()
for name, module in model.named_modules():
    print(f"{name}: {module.__class__.__name__}")

for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
Count parameters
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters()
                if p.requires_grad)
print(f"Total: {total:,} Trainable: {trainable:,}")
Example: Moving Models and Data to GPU
Move a model and its input tensors to GPU for accelerated computation.
Device-agnostic code
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MLP().to(device)
x = torch.randn(32, 784).to(device)
y = model(x) # computed on GPU if available
Common mistake
If model is on GPU but input is on CPU, PyTorch raises a
RuntimeError. Always ensure both are on the same device.
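A quick sanity check before debugging further (a sketch; assumes the model has at least one parameter):

print(next(model.parameters()).device)   # where the model's weights live
print(x.device)                          # where the input lives
assert next(model.parameters()).device == x.device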
[Interactive widget: Activation Function Explorer. Compare ReLU, LeakyReLU, GELU, Sigmoid, Tanh, and Swish.]
[Interactive widget: MLP Parameter Count Calculator. See how width and depth affect parameter count.]
[Interactive widget: Weight Initialisation Comparison. Compare activation distributions through layers with different init strategies.]
[Interactive widget: Forward Pass Through an MLP. Watch activations propagate layer by layer through a network.]
[Interactive widget: nn.Module Composition Tree.]
[Interactive widget: Autograd Computational Graph.]
Quick Check
What happens if you assign a torch.Tensor (not nn.Parameter) as a module attribute?
- It is automatically registered as a parameter
- It is ignored by model.parameters() and the optimizer (correct)
- PyTorch raises an error
Plain tensors are not tracked. Use nn.Parameter or register_buffer for non-learnable tensors.
Quick Check
Which initialisation is most appropriate for a network using ReLU activations?
- Xavier (Glorot) initialisation
- Kaiming (He) initialisation (correct)
- All-zeros initialisation
Kaiming accounts for the ReLU zeroing half the distribution.
Common Mistake: Forgetting super().__init__()
Mistake:
Defining an nn.Module subclass without calling super().__init__() in __init__.
Correction:
Always start with super().__init__(). Without it, the module's internal registries for parameters and sub-modules are never created: recent PyTorch versions raise an AttributeError as soon as you assign a sub-module, and .parameters() finds nothing.
Common Mistake: Using Python List Instead of nn.ModuleList
Mistake:
Storing sub-modules in a plain Python list: self.layers = [nn.Linear(64, 64) for _ in range(3)]
Correction:
Use nn.ModuleList: self.layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(3)]).
Plain lists are invisible to .parameters(), .to(device), and .state_dict().
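A minimal sketch of the corrected pattern in context:

class Stack(nn.Module):
    def __init__(self):
        super().__init__()
        # registered container: visible to .parameters(), .to(), .state_dict()
        self.layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(3)])

    def forward(self, x):
        for layer in self.layers:   # ModuleList supports iteration and indexing
            x = layer(x)
        return x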
Common Mistake: Calling .forward() Directly
Mistake:
Writing output = model.forward(x) instead of output = model(x).
Correction:
Always use model(x). The __call__ method runs registered forward and backward hooks and other internal bookkeeping that a bare forward() call skips.
Key Takeaway
nn.Module is a tree: compose complex architectures by nesting modules. Use nn.Sequential for simple chains, nn.ModuleList for indexed access, and nn.ModuleDict for named dynamic architectures.
Key Takeaway
Weight initialisation determines whether gradients flow through deep networks. Use Kaiming for ReLU, Xavier for tanh/sigmoid, and always initialise biases to zero.
Why This Matters: Neural Networks for Channel Estimation
In 5G NR, neural networks can learn the mapping from received pilot signals to channel estimates. An MLP with input dimension equal to the number of pilot subcarriers and output equal to the full channel dimension replaces traditional LS/MMSE estimators, learning non-linear dependencies in the channel structure.
See full treatment in Chapter 33
Historical Note: From Perceptrons to Deep Learning
1958–2012: Rosenblatt's Perceptron (1958) was a single-layer linear classifier. Minsky and Papert (1969) showed it could not learn XOR, triggering the first AI winter. Backpropagation (Rumelhart, Hinton, Williams, 1986) enabled training multi-layer networks, but deep networks only became practical with GPU computing (Krizhevsky et al., 2012).
Historical Note: PyTorch: From Lua Torch to Python
2017–present: PyTorch emerged from the Lua-based Torch framework. Released by
Facebook AI Research in 2017, its define-by-run (eager) execution
model and Pythonic API quickly made it the dominant research framework.
PyTorch 2.0 (2023) introduced torch.compile for graph-mode
optimization without sacrificing the eager programming model.
nn.Module
Base class for all neural network components in PyTorch.
Related: nn.Parameter, nn.Sequential
nn.Parameter
A Tensor subclass automatically registered as a learnable parameter when assigned to a Module.
Related: nn.Module
nn.Sequential
A container that chains modules sequentially, passing each output as input to the next.
Related: nn.Module
Autograd
PyTorch's automatic differentiation engine that records operations on tensors and computes gradients via reverse-mode AD.
ReLU
Rectified Linear Unit activation: $\mathrm{ReLU}(x) = \max(0, x)$. The default activation for hidden layers.
Kaiming Initialisation
Weight initialisation that accounts for ReLU activations: $W \sim \mathcal{N}(0,\; 2/n_{\text{in}})$.
Activation Function Comparison
| Activation | Formula | Range | Gradient at 0 | Best For |
|---|---|---|---|---|
| ReLU | $\max(0, x)$ | $[0, \infty)$ | undefined (0.5 subgrad) | Default hidden layers |
| LeakyReLU | $\max(\alpha x, x)$ | $(-\infty, \infty)$ | 1 | Avoiding dead neurons |
| GELU | $x\,\Phi(x)$ | $\approx (-0.17, \infty)$ | 0.5 | Transformers |
| Sigmoid | $1/(1 + e^{-x})$ | $(0, 1)$ | 0.25 | Binary output / gates |
| Tanh | $\tanh(x)$ | $(-1, 1)$ | 1 | Bounded output / RNNs |