RNN, LSTM, and GRU
Processing Sequential Data
Many signals are sequential: time-series, text, audio, and packets in a network. Recurrent networks process sequences by maintaining a hidden state that summarises all past inputs.
Definition: Vanilla RNN
Vanilla RNN
The simplest recurrent cell:
rnn = nn.RNN(input_size=10, hidden_size=64, num_layers=1, batch_first=True)
output, h_n = rnn(x) # x: (B, T, 10) -> output: (B, T, 64)
Vanilla RNNs suffer from vanishing gradients for long sequences. Use LSTM or GRU for sequences longer than ~20 steps.
Definition: Long Short-Term Memory (LSTM)
Long Short-Term Memory (LSTM)
LSTM adds gating to control information flow:
lstm = nn.LSTM(input_size=10, hidden_size=64, num_layers=2,
               batch_first=True, dropout=0.1)
output, (h_n, c_n) = lstm(x)
Definition: Gated Recurrent Unit (GRU)
Gated Recurrent Unit (GRU)
GRU simplifies the LSTM by merging the forget and input gates into a single update gate and dropping the separate cell state.
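A minimal usage sketch, mirroring the nn.RNN example above (sizes are illustrative):
gru = nn.GRU(input_size=10, hidden_size=64, num_layers=1, batch_first=True)
output, h_n = gru(x)  # x: (B, T, 10) -> output: (B, T, 64), h_n: (1, B, 64)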
GRU has fewer parameters than LSTM (three gate weight blocks vs. four) and often performs comparably in practice.
Definition: Bidirectional RNNs
Bidirectional RNNs
Process the sequence in both directions and concatenate:
lstm = nn.LSTM(10, 64, bidirectional=True, batch_first=True)
# Output: (B, T, 128) — doubled hidden size
Definition: Packed Sequences for Variable Lengths
Packed Sequences for Variable Lengths
For batches with variable-length sequences, use packed sequences:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
output, _ = pad_packed_sequence(output, batch_first=True)
Without packing, the LSTM wastes computation on padding tokens and the hidden state absorbs padding noise.
Theorem: Vanishing Gradient in Vanilla RNNs
For a vanilla RNN, the gradient of the loss at time $T$ with respect to the hidden state at time $t$ involves:
$$\frac{\partial \mathcal{L}_T}{\partial h_t} = \frac{\partial \mathcal{L}_T}{\partial h_T} \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}, \qquad \frac{\partial h_k}{\partial h_{k-1}} = \operatorname{diag}\!\big(\tanh'(a_k)\big)\, W_{hh},$$
where $a_k$ is the pre-activation at step $k$. If $\lVert W_{hh} \rVert < 1$, the product vanishes exponentially as $T - t$ grows. LSTM's cell state path provides a gradient highway that avoids this repeated product.
The cell state in LSTM acts like a conveyor belt: information can flow unchanged through many time steps via the forget gate, without being multiplied by weight matrices at each step.
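A small probe of this effect, as a sketch under illustrative sizes (T = 200, random data): it measures how much gradient from a loss on the last output reaches the first input for an nn.RNN and an nn.LSTM whose forget-gate bias has been nudged upward so the cell state stays open. Exact magnitudes depend on the random initialization.
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d, h = 200, 10, 64
x = torch.randn(1, T, d, requires_grad=True)

rnn = nn.RNN(d, h, batch_first=True)
lstm = nn.LSTM(d, h, batch_first=True)
# Open the forget gates (PyTorch bias order is i, f, g, o) so the cell
# state behaves like the conveyor belt described above.
with torch.no_grad():
    lstm.bias_ih_l0[h:2 * h] += 3.0

for name, cell in [("RNN", rnn), ("LSTM", lstm)]:
    x.grad = None
    out, _ = cell(x)
    out[:, -1].sum().backward()              # loss depends only on the final step
    print(name, x.grad[0, 0].norm().item())  # gradient reaching t = 0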
Theorem: LSTM Parameter Count
An LSTM with input size $d$ and hidden size $h$ has
$$4\left(dh + h^2 + h\right)$$
parameters (4 gates, each with input-to-hidden weights, hidden-to-hidden weights, and a bias).
LSTM has 4x the parameters of a vanilla RNN due to its 4 gate computations.
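A quick check of the formula against PyTorch. Note that nn.LSTM keeps two bias vectors per gate (b_ih and b_hh), so its count is 4(dh + h² + 2h) rather than 4(dh + h² + h):
import torch.nn as nn

d, h = 10, 64
lstm = nn.LSTM(input_size=d, hidden_size=h, num_layers=1)
n_params = sum(p.numel() for p in lstm.parameters())
print(n_params, 4 * (d * h + h * h + 2 * h))  # 19456 19456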
Theorem: Backpropagation Through Time (BPTT)
BPTT unfolds the RNN for $T$ time steps and applies standard backpropagation to the unfolded graph. For a loss $\mathcal{L} = \sum_{t=1}^{T} \mathcal{L}_t$:
$$\frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial \mathcal{L}_t}{\partial h_t}\,\frac{\partial h_t}{\partial h_k}\,\frac{\partial h_k}{\partial W}.$$
Truncated BPTT limits the inner sum to the last $\tau$ steps for efficiency.
BPTT is expensive because it scales linearly with sequence length. Truncated BPTT trades gradient accuracy for computational efficiency.
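A sketch of truncated BPTT under illustrative sizes (chunk length k = 50; the model, head, and data are placeholders): the long sequence is processed in chunks, and the hidden state is detached between chunks so gradients flow through at most k steps.
import torch
import torch.nn as nn

k = 50                                         # truncation length
lstm = nn.LSTM(input_size=10, hidden_size=64, batch_first=True)
head = nn.Linear(64, 1)
opt = torch.optim.Adam(list(lstm.parameters()) + list(head.parameters()))

x = torch.randn(8, 1000, 10)                   # (B, T, d) with T = 1000
y = torch.randn(8, 1000, 1)
state = None
for t in range(0, x.size(1), k):
    out, state = lstm(x[:, t:t+k], state)
    loss = nn.functional.mse_loss(head(out), y[:, t:t+k])
    opt.zero_grad()
    loss.backward()
    opt.step()
    state = tuple(s.detach() for s in state)   # cut the graph at the chunk boundary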
Example: Sequence Classification with LSTM
Classify variable-length sequences using an LSTM. Use the final hidden state as the sequence representation.
Implementation
class SeqClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])
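A quick usage sketch (batch size, sequence length, and class count are illustrative):
import torch

model = SeqClassifier(input_dim=10, hidden_dim=64, n_classes=5)
x = torch.randn(32, 50, 10)   # 32 sequences of 50 steps, 10 features each
logits = model(x)             # (32, 5)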
Example: Many-to-Many: Sequence Labelling
Label each time step in a sequence (e.g., slot filling or per-symbol detection).
Implementation
class SeqLabeller(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_labels):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, n_labels)

    def forward(self, x):
        output, _ = self.lstm(x)  # (B, T, hidden_dim)
        return self.fc(output)    # (B, T, n_labels)
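A usage sketch (sizes are illustrative): per-step cross-entropy expects (N, C) logits and (N,) targets, so flatten the batch and time dimensions before computing the loss.
import torch
import torch.nn as nn

model = SeqLabeller(input_dim=10, hidden_dim=64, n_labels=7)
x = torch.randn(32, 50, 10)
labels = torch.randint(0, 7, (32, 50))
logits = model(x)                                   # (32, 50, 7)
loss = nn.functional.cross_entropy(logits.reshape(-1, 7), labels.reshape(-1))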
Example: LSTM for Time-Series Prediction
Predict the next value in a time series using an LSTM.
Window-based approach
# Create windows: x[t-L:t] -> y[t]
class TSPredictor(nn.Module):
    def __init__(self, hidden_dim=32):
        super().__init__()
        self.lstm = nn.LSTM(1, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        _, (h_n, _) = self.lstm(x.unsqueeze(-1))
        return self.fc(h_n[-1])
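A sketch of the windowing step (the helper name, window length L = 32, and the toy sine series are illustrative):
import torch

def make_windows(series, L=32):
    # series: (N,) -> windows x: (N - L, L), next-value targets y: (N - L,)
    x = torch.stack([series[t - L:t] for t in range(L, len(series))])
    y = series[L:]
    return x, y

series = torch.sin(torch.linspace(0, 20, 500))
x, y = make_windows(series)
pred = TSPredictor()(x)       # (468, 1); TSPredictor adds the feature dim itself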
Interactive: LSTM Gate Dynamics
See how LSTM gates open and close over a sequence.
Interactive: Vanishing Gradient: RNN vs LSTM
Compare gradient magnitudes over time steps.
Interactive: RNN/LSTM/GRU Parameter Comparison
Compare parameter counts across recurrent architectures.
Figure: LSTM Cell Diagram
Figure: RNN Unfolding Through Time
Quick Check
What problem does LSTM solve that vanilla RNNs cannot?
Processing variable-length sequences
Learning long-range dependencies (vanishing gradients)
Parallel computation
The cell state provides a gradient highway for long-range dependencies.
Quick Check
How many parameter matrices does a single LSTM cell have (excluding biases)?
4 (one per gate: forget, input, output, cell candidate)
8 (2 per gate x 4 gates)
2
Quick Check
What does bidirectional=True do in nn.LSTM?
Doubles the learning rate
Processes the sequence forward and backward, doubling the output hidden size
Reverses the input sequence
Common Mistake: Not Packing Variable-Length Sequences
Mistake:
Feeding padded sequences directly to LSTM without packing.
Correction:
Use pack_padded_sequence before the LSTM and pad_packed_sequence
after. Otherwise the LSTM processes padding tokens, wasting computation
and corrupting the hidden state.
Common Mistake: Using Wrong Hidden State for Classification
Mistake:
Using output[:, -1, :] instead of h_n[-1] for the last hidden state,
which gives wrong results with packed sequences.
Correction:
Use h_n[-1] from the LSTM output tuple. For bidirectional LSTMs,
concatenate h_n[-2] (forward) and h_n[-1] (backward).
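A sketch of the bidirectional case (sizes are illustrative). h_n has shape (num_layers * 2, B, hidden_dim), with the forward and backward states of the last layer in its final two slots:
import torch
import torch.nn as nn

lstm = nn.LSTM(10, 64, batch_first=True, bidirectional=True)
x = torch.randn(4, 20, 10)
output, (h_n, c_n) = lstm(x)
final = torch.cat([h_n[-2], h_n[-1]], dim=-1)   # (4, 128): forward + backward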
Key Takeaway
LSTM is the default recurrent cell. Use GRU for parameter efficiency. Always pack variable-length sequences. For very long sequences, consider Transformers (Chapter 30) instead.
Key Takeaway
The LSTM cell state acts as a gradient highway, enabling learning of dependencies across hundreds of time steps. Gradient clipping is still recommended for training stability.
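A minimal sketch of gradient clipping in a training step (the model, placeholder loss, and max_norm value of 1.0 are illustrative choices):
import torch
import torch.nn as nn

model = nn.LSTM(10, 64, batch_first=True)
optimizer = torch.optim.Adam(model.parameters())

out, _ = model(torch.randn(4, 100, 10))
loss = out.pow(2).mean()                                     # placeholder loss
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # clip before stepping
optimizer.step()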
Why This Matters: RNNs for Channel Tracking
In time-varying channels, the channel coefficients form a sequence. LSTMs can track fading channels by processing pilot-based estimates sequentially, learning the temporal correlation structure (e.g., the Jakes model) implicitly from data.
Historical Note: LSTM: A 25-Year Journey
1997-2016: Hochreiter and Schmidhuber proposed LSTM in 1997 to address the vanishing gradient problem. It remained a niche technique until Graves applied it to speech recognition (2013) and Google adopted it for machine translation (2016), making it the dominant sequence model until transformers arrived.
Historical Note: GRU: A Simplified Alternative
2014: Cho et al. introduced the GRU as a simpler alternative to LSTM, showing comparable performance with fewer parameters.
Gate (LSTM)
A sigmoid-activated layer that controls information flow: values near 0 block, near 1 pass.
BPTT
Backpropagation Through Time: unfolding the RNN and applying standard backprop.
Cell State
LSTM's long-term memory vector that flows through time with minimal transformation.
Seq2Seq
Sequence-to-sequence model: encoder compresses input to a context vector, decoder generates output.
Recurrent Cell Comparison
| Cell | Gates | Parameters (d = h) | Long-Range | Speed |
|---|---|---|---|---|
| Vanilla RNN | 0 | ≈ 2h² | Poor | Fast |
| LSTM | 3 + cell candidate | ≈ 8h² | Good | Slowest |
| GRU | 2 + candidate | ≈ 6h² | Good | Medium |