RNN, LSTM, and GRU

Processing Sequential Data

Many signals are sequential: time-series, text, audio, and packets in a network. Recurrent networks process sequences by maintaining a hidden state that summarises all past inputs.

Definition:

Vanilla RNN

The simplest recurrent cell:

\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b})

rnn = nn.RNN(input_size=10, hidden_size=64, num_layers=1, batch_first=True)
output, h_n = rnn(x)  # x: (B, T, 10) -> output: (B, T, 64)

Vanilla RNNs suffer from vanishing gradients for long sequences. Use LSTM or GRU for sequences longer than ~20 steps.

Definition:

Long Short-Term Memory (LSTM)

LSTM adds gating to control information flow:

\mathbf{f}_t = \sigma(\mathbf{W}_f[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) \quad \text{(forget gate)}
\mathbf{i}_t = \sigma(\mathbf{W}_i[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) \quad \text{(input gate)}
\tilde{\mathbf{c}}_t = \tanh(\mathbf{W}_c[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c) \quad \text{(candidate)}
\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t \quad \text{(cell update)}
\mathbf{o}_t = \sigma(\mathbf{W}_o[\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o) \quad \text{(output gate)}
\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)

lstm = nn.LSTM(input_size=10, hidden_size=64, num_layers=2,
               batch_first=True, dropout=0.1)
output, (h_n, c_n) = lstm(x)

Definition:

Gated Recurrent Unit (GRU)

GRU simplifies LSTM by merging the forget and input gates:

\mathbf{z}_t = \sigma(\mathbf{W}_z[\mathbf{h}_{t-1}, \mathbf{x}_t]) \quad \text{(update gate)}
\mathbf{r}_t = \sigma(\mathbf{W}_r[\mathbf{h}_{t-1}, \mathbf{x}_t]) \quad \text{(reset gate)}
\tilde{\mathbf{h}}_t = \tanh(\mathbf{W}[\mathbf{r}_t \odot \mathbf{h}_{t-1}, \mathbf{x}_t])
\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t
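
In PyTorch, nn.GRU exposes the same interface as nn.RNN and nn.LSTM; a minimal sketch mirroring the earlier snippets (sizes are illustrative):

import torch
import torch.nn as nn

x = torch.randn(8, 20, 10)   # (B, T, input_size)
gru = nn.GRU(input_size=10, hidden_size=64, num_layers=1, batch_first=True)
output, h_n = gru(x)         # output: (8, 20, 64); h_n: (1, 8, 64)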

GRU has fewer parameters than LSTM (three weight blocks instead of four) and often performs comparably.

Definition:

Bidirectional RNNs

Process the sequence in both directions and concatenate:

\overrightarrow{\mathbf{h}}_t = \text{RNN}_\text{fwd}(\mathbf{x}_t, \overrightarrow{\mathbf{h}}_{t-1}), \qquad \overleftarrow{\mathbf{h}}_t = \text{RNN}_\text{bwd}(\mathbf{x}_t, \overleftarrow{\mathbf{h}}_{t+1})

\mathbf{h}_t = [\overrightarrow{\mathbf{h}}_t; \overleftarrow{\mathbf{h}}_t]

lstm = nn.LSTM(10, 64, bidirectional=True, batch_first=True)
# Output: (B, T, 128)  — doubled hidden size

Definition:

Packed Sequences for Variable Lengths

For batches with variable-length sequences, use packed sequences:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
output, _ = pad_packed_sequence(output, batch_first=True)

Without packing, the LSTM wastes computation on padding tokens and the hidden state absorbs padding noise.

Theorem: Vanishing Gradient in Vanilla RNNs

For a vanilla RNN, the gradient of the loss at time T with respect to the hidden state at time t involves:

\frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_t} = \prod_{k=t+1}^{T} \text{diag}(\tanh'(\cdot)) \cdot \mathbf{W}_{hh}

If \|\mathbf{W}_{hh}\| < 1/\|\tanh'\|_{\max} = 1, the product vanishes exponentially as T - t grows. LSTM's cell state path \mathbf{c}_t provides a gradient highway that avoids this product.

The cell state in LSTM acts like a conveyor belt: information can flow unchanged through many time steps via the forget gate, without being multiplied by weight matrices at each step.

Theorem: LSTM Parameter Count

An LSTM with input size d and hidden size h has:

4 \times (d \cdot h + h \cdot h + h) = 4h(d + h + 1)

parameters (4 gates, each with input-to-hidden, hidden-to-hidden weights, and bias).

LSTM has 4x the parameters of a vanilla RNN due to its 4 gate computations.
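
As a sanity check, the count can be compared against nn.LSTM directly; a small sketch (note that PyTorch stores two bias vectors per gate, bias_ih and bias_hh, so it reports 4h(d + h + 2) rather than the classical 4h(d + h + 1) above):

import torch.nn as nn

d, h = 10, 64
lstm = nn.LSTM(input_size=d, hidden_size=h, num_layers=1)
total = sum(p.numel() for p in lstm.parameters())
print(total)                 # 19456 = 4h(d + h + 2), with two biases per gate
print(4 * h * (d + h + 1))   # 19200, the classical count with one bias per gate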

Theorem: Backpropagation Through Time (BPTT)

BPTT unfolds the RNN for T time steps and applies standard backpropagation to the unfolded graph. For a loss L = \sum_{t=1}^{T} \ell_t:

\frac{\partial L}{\partial \mathbf{W}} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial \ell_t}{\partial \mathbf{h}_t} \left(\prod_{j=k+1}^{t} \frac{\partial \mathbf{h}_j}{\partial \mathbf{h}_{j-1}}\right) \frac{\partial \mathbf{h}_k}{\partial \mathbf{W}}

Truncated BPTT limits the inner sum to the last K steps for efficiency.

BPTT is expensive because it scales linearly with sequence length. Truncated BPTT trades gradient accuracy for computational efficiency.
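
A common way to realise truncated BPTT in practice is to process the sequence in fixed-size chunks and detach the recurrent state at each chunk boundary; a minimal sketch (the chunk length K, sizes, and loss are illustrative, and the optimizer step is omitted):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=64, batch_first=True)
x = torch.randn(8, 200, 10)        # (B, T, input_size)
target = torch.randn(8, 200, 64)   # illustrative per-step regression target
K = 20                             # truncation length
state = None
for start in range(0, x.size(1), K):
    chunk = x[:, start:start + K]
    output, state = lstm(chunk, state)
    loss = ((output - target[:, start:start + K]) ** 2).mean()
    loss.backward()                              # gradients flow at most K steps back
    state = tuple(s.detach() for s in state)     # cut the graph at the chunk boundary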

Example: Sequence Classification with LSTM

Classify variable-length sequences using an LSTM. Use the final hidden state as the sequence representation.
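
A minimal sketch of such a classifier, combining packing with the final hidden state (the sizes, class count, and names are illustrative):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class LSTMClassifier(nn.Module):
    def __init__(self, input_size=10, hidden_size=64, num_classes=5):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, lengths):
        # Pack so padded steps are skipped and h_n reflects each true last step.
        packed = pack_padded_sequence(x, lengths, batch_first=True,
                                      enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        return self.fc(h_n[-1])      # final hidden state -> class logits

# Example: batch of 4 padded sequences with true lengths 30, 25, 12, 7
logits = LSTMClassifier()(torch.randn(4, 30, 10), torch.tensor([30, 25, 12, 7]))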

Example: Many-to-Many: Sequence Labelling

Label each time step in a sequence (e.g., slot filling or per-symbol detection).
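
A sketch of a per-step tagger: every output step of a bidirectional LSTM is passed through a shared linear layer (the sizes and tag count are illustrative):

import torch
import torch.nn as nn

class SequenceTagger(nn.Module):
    def __init__(self, input_size=10, hidden_size=64, num_tags=8):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True,
                            bidirectional=True)
        self.fc = nn.Linear(2 * hidden_size, num_tags)  # 2x for both directions

    def forward(self, x):                # x: (B, T, input_size)
        output, _ = self.lstm(x)         # (B, T, 2 * hidden_size)
        return self.fc(output)           # per-step tag logits: (B, T, num_tags)

tags = SequenceTagger()(torch.randn(4, 30, 10))   # (4, 30, 8)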

Example: LSTM for Time-Series Prediction

Predict the next value in a time series using an LSTM.
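
A sketch of one-step-ahead forecasting: the model reads a window of past values and predicts the next one (window length and sizes are illustrative; the model is untrained here, so only the shapes are meaningful):

import torch
import torch.nn as nn

class Forecaster(nn.Module):
    def __init__(self, hidden_size=32):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):                # x: (B, T, 1) window of past values
        _, (h_n, _) = self.lstm(x)
        return self.fc(h_n[-1])          # (B, 1) predicted next value

t = torch.linspace(0, 10, 50)
window = torch.sin(t).view(1, 50, 1)     # one sine-wave window
y_next = Forecaster()(window)            # shape (1, 1)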

LSTM Gate Dynamics

See how LSTM gates open and close over a sequence.


Vanishing Gradient: RNN vs LSTM

Compare gradient magnitudes over time steps.


RNN/LSTM/GRU Parameter Comparison

Compare parameter counts across recurrent architectures.


Hidden State Evolution

Watch the hidden state evolve as the LSTM processes a sequence.


LSTM Cell Diagram

The LSTM cell with forget gate, input gate, cell state update, and output gate.

RNN Unfolding Through Time

An RNN unfolded for T time steps, showing how the same weights are applied at each step.

Quick Check

What problem does LSTM solve that vanilla RNNs cannot?

Processing variable-length sequences

Learning long-range dependencies (vanishing gradients)

Parallel computation

Quick Check

How many parameter matrices does a single LSTM cell have (excluding biases)?

4 (one per gate: forget, input, output, cell candidate)

8 (2 per gate x 4 gates)

2

Quick Check

What does bidirectional=True do in nn.LSTM?

Doubles the learning rate

Processes the sequence forward and backward, doubling the output hidden size

Reverses the input sequence

Common Mistake: Not Packing Variable-Length Sequences

Mistake:

Feeding padded sequences directly to LSTM without packing.

Correction:

Use pack_padded_sequence before the LSTM and pad_packed_sequence after. Otherwise the LSTM processes padding tokens, wasting computation and corrupting the hidden state.

Common Mistake: Forgetting to Reset Hidden State Between Sequences

Mistake:

Not initialising h_0 and c_0 for each new batch, causing the hidden state from the previous batch to leak into the next.

Correction:

Pass h_0=None (default zeros) or explicitly reset between batches. For stateful processing (e.g., streaming), consciously manage state.
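
A minimal sketch of both options (the shapes assume a single-layer, unidirectional LSTM; sizes are illustrative):

import torch
import torch.nn as nn

lstm = nn.LSTM(10, 64, batch_first=True)
x = torch.randn(4, 30, 10)

# Omitting the state (or passing None) starts from zeros on every call:
output, (h_n, c_n) = lstm(x)

# Equivalent explicit form; carry (h_n, c_n) forward deliberately only for
# stateful/streaming processing, detaching it between updates:
h0 = torch.zeros(1, 4, 64)   # (num_layers * num_directions, B, hidden_size)
c0 = torch.zeros(1, 4, 64)
output, (h_n, c_n) = lstm(x, (h0, c0))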

Common Mistake: Using Wrong Hidden State for Classification

Mistake:

Using output[:, -1, :] instead of h_n[-1] for the last hidden state, which gives wrong results with packed sequences.

Correction:

Use h_n[-1] from the LSTM output tuple. For bidirectional LSTMs, concatenate h_n[-2] (forward) and h_n[-1] (backward).
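
A minimal sketch for the bidirectional case (sizes are illustrative):

import torch
import torch.nn as nn

lstm = nn.LSTM(10, 64, bidirectional=True, batch_first=True)
x = torch.randn(4, 30, 10)
output, (h_n, c_n) = lstm(x)
# h_n: (num_layers * num_directions, B, hidden_size) = (2, 4, 64)
final = torch.cat([h_n[-2], h_n[-1]], dim=1)   # (4, 128): forward then backward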

Key Takeaway

LSTM is the default recurrent cell. Use GRU for parameter efficiency. Always pack variable-length sequences. For very long sequences, consider Transformers (Chapter 30) instead.

Key Takeaway

The LSTM cell state acts as a gradient highway, enabling learning of dependencies across hundreds of time steps. Gradient clipping is still recommended for training stability.

Why This Matters: RNNs for Channel Tracking

In time-varying channels, the channel coefficients form a sequence. LSTMs can track fading channels by processing pilot-based estimates sequentially, learning the temporal correlation structure (the Jakes model) implicitly from data.

Historical Note: LSTM: A 25-Year Journey

1997-2016

Hochreiter and Schmidhuber proposed LSTM in 1997 to address the vanishing gradient problem. It remained a niche technique until Graves applied it to speech recognition (2013) and Google adopted it for machine translation (2016), making it the dominant sequence model until transformers arrived.

Historical Note: GRU: A Simplified Alternative

2014

Cho et al. (2014) introduced the GRU as a simpler alternative to LSTM, showing comparable performance with fewer parameters.

Hidden State

The internal memory vector of an RNN that summarises all past inputs.

Gate (LSTM)

A sigmoid-activated layer that controls information flow: values near 0 block, near 1 pass.

BPTT

Backpropagation Through Time: unfolding the RNN and applying standard backprop.

Cell State

LSTM's long-term memory vector that flows through time with minimal transformation.

Seq2Seq

Sequence-to-sequence model: encoder compresses input to a context vector, decoder generates output.

Recurrent Cell Comparison

Cell          Gates      Parameters (d = h)   Long-Range   Speed
Vanilla RNN   0          2h^2 + h             Poor         Fast
LSTM          3 + cell   4(2h^2 + h)          Good         Slowest
GRU           2          3(2h^2 + h)          Good         Medium