The Attention Mechanism

Beyond Fixed-Size Bottlenecks

A plain seq2seq model compresses the entire input into a single fixed-size vector, which loses information on long sequences. Attention lets the decoder look back at all encoder hidden states and select what is relevant at each decoding step.

Definition:

Scaled Dot-Product Attention

\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}

- \mathbf{Q} \in \mathbb{R}^{n \times d_k}: queries
- \mathbf{K} \in \mathbb{R}^{m \times d_k}: keys
- \mathbf{V} \in \mathbb{R}^{m \times d_v}: values

The \sqrt{d_k} scaling prevents softmax saturation.

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        # Blocked positions (mask == 0) get ~0 weight after softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)  # attention weights, rows sum to 1
    return weights @ V
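
A quick shape check (a minimal sketch; the sizes are illustrative):

# 2 sequences, 5 queries attending over 7 keys/values
Q = torch.randn(2, 5, 16)   # (batch, n, d_k)
K = torch.randn(2, 7, 16)   # (batch, m, d_k)
V = torch.randn(2, 7, 32)   # (batch, m, d_v)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # torch.Size([2, 5, 32]): one d_v-dimensional output per query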

Definition:

Multi-Head Attention

Multi-head attention runs h attention functions in parallel:

\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O

where \text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B, T, _ = Q.shape
        # Project, then split d_model into n_heads heads of size d_k:
        # (B, T, d_model) -> (B, n_heads, T, d_k)
        Q = self.W_q(Q).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        # All heads attend in parallel (mask must broadcast over the head dim)
        attn = scaled_dot_product_attention(Q, K, V, mask)
        # Merge heads: (B, n_heads, T, d_k) -> (B, T, d_model)
        attn = attn.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(attn)
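
A minimal self-attention usage check (sizes are illustrative):

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
out = mha(x, x, x)           # self-attention: Q = K = V = x
print(out.shape)             # torch.Size([2, 10, 512])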

Definition:

Positional Encoding

Since attention is permutation-equivariant, positional information must be injected explicitly:

\text{PE}(pos, 2i) = \sin(pos / 10000^{2i/d}), \quad \text{PE}(pos, 2i+1) = \cos(pos / 10000^{2i/d})

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        # 1 / 10000^(2i/d_model) for each even dimension index 2i
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the first seq_len encodings
        return x + self.pe[:, :x.size(1)]
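
Usage is a single additive step (a minimal check with illustrative sizes):

pe = PositionalEncoding(d_model=64)
x = torch.randn(2, 20, 64)  # token embeddings: (batch, seq_len, d_model)
print(pe(x).shape)          # torch.Size([2, 20, 64]), now position-aware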

Definition:

Causal (Autoregressive) Mask

For decoder self-attention, a causal mask prevents attending to future positions:

T = 5  # illustrative sequence length
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()  # True above the diagonal
# apply scores.masked_fill(mask, float('-inf')) before the softmax
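
A small sketch of the effect: after masking, each position has zero weight on all later positions. (Note the convention here is True = blocked; scaled_dot_product_attention above blocks where mask == 0, so pass ~mask to that function.)

import torch
import torch.nn.functional as F

T = 4
scores = torch.randn(T, T)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(weights)  # zeros above the diagonal: no attention to future tokens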

Definition:

Cross-Attention

In encoder-decoder models, cross-attention uses queries from the decoder and keys/values from the encoder:

\text{CrossAttn}(\mathbf{Q}_{\text{dec}}, \mathbf{K}_{\text{enc}}, \mathbf{V}_{\text{enc}})

This allows the decoder to attend to relevant parts of the input.
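
In code, cross-attention is the same module called with different inputs (a minimal sketch reusing the MultiHeadAttention class above; sizes are illustrative):

import torch

enc_out = torch.randn(2, 12, 512)     # encoder states: (batch, src_len, d_model)
dec_x = torch.randn(2, 7, 512)        # decoder states: (batch, tgt_len, d_model)
cross = MultiHeadAttention(d_model=512, n_heads=8)
out = cross(dec_x, enc_out, enc_out)  # Q from decoder; K, V from encoder
print(out.shape)                      # torch.Size([2, 7, 512])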

Theorem: Attention Complexity

Scaled dot-product attention has O(n^2 d) time and O(n^2) memory complexity for sequence length n and dimension d. This quadratic scaling is the main computational bottleneck for long sequences.

Every token attends to every other token, creating an n \times n attention matrix. Linear attention variants reduce this to O(n) but often sacrifice quality.
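
A rough back-of-envelope sketch (counts only the attention matrix itself, ignoring the projections):

def attention_cost(n, d):
    flops = 2 * n * n * d  # Q K^T and weights @ V each cost ~n^2 d multiply-adds
    mem_bytes = 4 * n * n  # the n x n attention matrix in fp32
    return flops, mem_bytes

for n in (512, 4096, 32768):
    flops, mem = attention_cost(n, d=64)
    print(f"n={n:6d}: ~{flops:.1e} mul-adds, {mem / 1e6:8.1f} MB per head")
# Quadrupling n quadruples the memory: 32k tokens already need ~4.3 GB per head.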

Theorem: Universality of Transformers

Transformers with sufficient depth and width are universal function approximators for sequence-to-sequence functions (Yun et al., 2020). Moreover, they can approximate any continuous permutation-equivariant function on sets.

The combination of attention (global mixing) and feed-forward layers (local processing) gives transformers sufficient expressiveness to approximate any reasonable sequence transformation.

Theorem: Attention as Soft Dictionary Lookup

Attention can be interpreted as a differentiable dictionary: queries look up keys, and the corresponding values are retrieved with soft weights. In the limit T \to 0 of \text{softmax}(\mathbf{QK}^T/T), attention becomes a hard lookup (argmax), selecting the single most relevant key-value pair.

Each query asks "which keys am I most similar to?" and retrieves a weighted combination of the corresponding values.
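
The temperature limit is easy to verify numerically (a small sketch with made-up scores):

import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.5])
for temp in (1.0, 0.1, 0.01):
    print(temp, F.softmax(scores / temp, dim=-1))
# As temp -> 0 the weights collapse onto the largest score:
# soft lookup becomes hard (argmax) lookup.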

Example: Self-Attention from Scratch

Implement scaled dot-product self-attention from scratch.

Example: Visualising Attention Weights

Compute and plot the attention weight matrix for a short sequence.
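
One way to do this (a minimal sketch assuming matplotlib; the tokens and embeddings are made up for illustration):

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

tokens = ["the", "cat", "sat", "down"]  # hypothetical sequence
x = torch.randn(1, len(tokens), 16)     # random embeddings for illustration
scores = x @ x.transpose(-2, -1) / 16 ** 0.5
weights = F.softmax(scores, dim=-1)[0]  # (4, 4) attention matrix

plt.imshow(weights.numpy(), cmap="viridis")
plt.xticks(range(len(tokens)), tokens)
plt.yticks(range(len(tokens)), tokens)
plt.xlabel("key position")
plt.ylabel("query position")
plt.colorbar(label="attention weight")
plt.show()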

Example: KV-Cache for Efficient Inference

During autoregressive generation, avoid recomputing K and V for all previous tokens.
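
A minimal sketch of the idea (the KVCache class is hypothetical; a real implementation keeps one cache per layer and handles batch and head dimensions):

import torch

class KVCache:
    """Append-only store of past keys and values for one attention layer."""
    def __init__(self):
        self.K = None
        self.V = None

    def update(self, k_new, v_new):
        # k_new, v_new: (B, 1, d) projections of the newest token only
        self.K = k_new if self.K is None else torch.cat([self.K, k_new], dim=1)
        self.V = v_new if self.V is None else torch.cat([self.V, v_new], dim=1)
        return self.K, self.V

# Each generation step projects only the new token and attends over the cached
# K, V: O(n) work per step instead of re-encoding all n previous tokens.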

Interactive demos accompany this section:

- Attention Weight Heatmap: visualise attention weights between query and key positions.
- Positional Encoding Visualiser: see the sinusoidal positional encoding patterns.
- Attention Complexity Calculator: see how memory and compute scale with sequence length.
- Attention Pattern Evolution: watch how attention patterns change during training.

Figure: Transformer Architecture

The encoder-decoder transformer with multi-head attention, feed-forward, and residual connections.

Figure: Attention Variants

Self-attention, cross-attention, and causal attention patterns.

Quick Check

Why divide by sqrt(d_k) in scaled dot-product attention?

To normalise the output magnitude

To prevent softmax from saturating when d_k is large

To make training faster

Quick Check

What is the memory complexity of standard self-attention?

O(n \cdot d)

O(n^2)

O(n \log n)

Quick Check

Why are positional encodings needed in transformers?

To make the model run faster

Because self-attention is permutation-equivariant and cannot distinguish positions

To reduce memory usage

Common Mistake: Forgetting the Causal Mask

Mistake:

Not applying a causal mask in decoder self-attention, allowing the model to see future tokens.

Correction:

Always apply an upper-triangular mask in decoder self-attention for autoregressive generation.

Common Mistake: d_model Not Divisible by n_heads

Mistake:

Setting d_model=256, n_heads=6, so d_k is not an integer.

Correction:

Ensure d_model is divisible by n_heads. A common choice is d_model=512 with n_heads=8, giving d_k=64.

Common Mistake: Omitting Residual Connections in Transformer

Mistake:

Not adding residual connections around attention and FFN sublayers.

Correction:

Always use x + sublayer(LayerNorm(x)) (pre-norm) or LayerNorm(x + sublayer(x)) (post-norm).
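
A pre-norm block sketch (a minimal sketch assuming the MultiHeadAttention class above; d_ff is a free choice, conventionally 4 * d_model):

import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x, mask=None):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, mask)  # residual around self-attention
        x = x + self.ffn(self.norm2(x))   # residual around feed-forward
        return x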

Key Takeaway

Attention computes a weighted sum of values based on query-key similarity. The O(n^2) cost is the price for global context. Multi-head attention learns diverse attention patterns in parallel.

Key Takeaway

The transformer = multi-head attention + feed-forward + residual + layer norm. This simple recipe scales to billions of parameters and dominates NLP, vision, and increasingly scientific computing.

Why This Matters: Transformers for MIMO Detection

Attention can model the interactions between transmitted symbols in MIMO systems. Each received antenna's signal attends to all others, learning the interference structure. This has been shown to approach MMSE detection performance with lower complexity than iterative methods.

See full treatment in Chapter 33

Historical Note: Attention Is All You Need

2017

Vaswani et al. (2017) introduced the Transformer, replacing recurrence entirely with attention. This paper is the foundation of GPT, BERT, and all modern large language models.

Historical Note: The First Attention Mechanism

2014

Bahdanau et al. (2014) introduced attention for machine translation, allowing the decoder to focus on different source words at each step. This additive attention mechanism was the precursor to the dot-product attention used in transformers.

Attention

Mechanism that computes weighted combinations of values based on query-key similarity.

Related: Multi-Head Attention

Multi-Head Attention

Running h parallel attention functions with different learned projections, then concatenating.

Transformer

Architecture built from attention and feed-forward layers with residual connections and layer norm.

Positional Encoding

Signal added to embeddings to inject position information into the permutation-equivariant attention.

Vision Transformer (ViT)

Transformer applied to images by splitting them into patches and treating patches as tokens.

Attention Mechanism Variants

Variant            | Complexity                 | Key Idea                   | Use Case
Scaled Dot-Product | O(n^2 d)                   | Standard softmax attention | Default
Multi-Head         | O(h \cdot n^2 \cdot d/h)   | Parallel attention heads   | Transformers
Flash Attention    | O(n^2 d) time, O(n) memory | IO-aware tiling            | Long sequences
Linear Attention   | O(n d^2)                   | Kernel approximation       | Very long sequences