The Attention Mechanism
Beyond Fixed-Size Bottlenecks
Seq2Seq compresses the entire input into a single vector, losing information for long sequences. Attention lets the decoder look back at all encoder states, selecting what is relevant at each step.
Definition: Scaled Dot-Product Attention
Given queries $\mathbf{Q} \in \mathbb{R}^{n \times d_k}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d_k}$, and values $\mathbf{V} \in \mathbb{R}^{m \times d_v}$,
$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}.$$
The $\sqrt{d_k}$ scaling prevents softmax saturation.
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    # Query-key similarity scores, scaled by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    if mask is not None:
        # Positions with mask == 0 get a large negative score -> ~0 weight after softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return weights @ V
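As a quick shape sanity check, here is a minimal usage sketch; the batch size, sequence lengths, and d_k = 64 below are arbitrary illustrative choices:

Q = torch.randn(2, 5, 64)   # (batch, n queries, d_k)
K = torch.randn(2, 7, 64)   # (batch, m keys, d_k)
V = torch.randn(2, 7, 64)   # (batch, m values, d_v)
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)            # torch.Size([2, 5, 64]) -- one output per query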
Definition: Multi-Head Attention
Multi-head attention runs $h$ attention functions in parallel:
$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,\mathbf{W}^O, \quad \text{where } \text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V).$$
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B, T, _ = Q.shape
        # Project, then split d_model into n_heads heads of size d_k: (B, h, T, d_k)
        Q = self.W_q(Q).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        attn = scaled_dot_product_attention(Q, K, V, mask)
        # Merge heads back: (B, T, d_model)
        attn = attn.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(attn)
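A minimal self-attention usage sketch for the module above (the sizes are illustrative, not prescribed):

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)   # (batch, sequence, d_model)
out = mha(x, x, x)            # self-attention: Q = K = V = x
print(out.shape)              # torch.Size([2, 10, 512])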
Definition: Positional Encoding
Since attention is permutation-equivariant, positional information must be injected explicitly. The sinusoidal encoding computes
$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right).$$
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        # Geometric progression of frequencies across the even dimensions
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)   # even dimensions
        pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions
        return x + self.pe[:, :x.size(1)]
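A sketch of how this module typically sits on top of a token embedding; the vocabulary size and dimensions here are hypothetical:

emb = nn.Embedding(10000, 512)               # hypothetical vocabulary of 10,000 tokens
pos_enc = PositionalEncoding(d_model=512)
tokens = torch.randint(0, 10000, (2, 20))    # (batch, seq_len)
x = pos_enc(emb(tokens))                     # token embeddings + positional signal
print(x.shape)                               # torch.Size([2, 20, 512])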
Definition: Causal (Autoregressive) Mask
For decoder self-attention, a causal mask prevents attending to future positions:
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()   # True above the diagonal = future positions
# scores = scores.masked_fill(mask, float('-inf'))  # apply before softmax
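A small sketch of wiring a causal mask into the scaled_dot_product_attention function defined earlier. That function expects a keep-mask (nonzero = may attend), so the lower triangle is passed rather than the upper-triangular block-mask above:

T = 6
x = torch.randn(1, T, 64)
# Keep-mask: position i may attend only to positions j <= i
causal_keep = torch.tril(torch.ones(T, T)).bool()
out = scaled_dot_product_attention(x, x, x, mask=causal_keep)
print(out.shape)   # torch.Size([1, 6, 64])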
Definition: Cross-Attention
In encoder-decoder models, cross-attention uses queries from the decoder and keys/values from the encoder:
$$\text{CrossAttention}(\mathbf{H}_{\text{dec}}, \mathbf{H}_{\text{enc}}) = \text{Attention}(\mathbf{Q} = \mathbf{H}_{\text{dec}}\mathbf{W}^Q,\ \mathbf{K} = \mathbf{H}_{\text{enc}}\mathbf{W}^K,\ \mathbf{V} = \mathbf{H}_{\text{enc}}\mathbf{W}^V).$$
This allows the decoder to attend to relevant parts of the input.
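With the MultiHeadAttention module above, cross-attention is simply a matter of passing decoder states as queries and encoder states as keys/values; the shapes below are illustrative:

cross_attn = MultiHeadAttention(d_model=512, n_heads=8)
enc_out = torch.randn(2, 15, 512)   # encoder states: (batch, src_len, d_model)
dec_x = torch.randn(2, 10, 512)     # decoder states: (batch, tgt_len, d_model)
out = cross_attn(dec_x, enc_out, enc_out)   # Q from decoder, K/V from encoder
print(out.shape)                            # torch.Size([2, 10, 512])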
Theorem: Attention Complexity
Scaled dot-product attention has $O(n^2 d)$ time and $O(n^2)$ memory complexity for sequence length $n$ and dimension $d$. This quadratic scaling is the main computational bottleneck for long sequences.
Every token attends to every other token, creating an $n \times n$ attention matrix. Linear attention variants reduce this to $O(n)$ but often sacrifice quality.
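A rough back-of-the-envelope sketch of how just the attention weight matrix grows with sequence length, assuming fp16 storage, a single head, and batch size 1:

def attention_matrix_bytes(n, bytes_per_element=2):
    # n x n attention weights, one head, batch size 1
    return n * n * bytes_per_element

for n in [1_000, 10_000, 100_000]:
    print(f"n = {n:>7}: {attention_matrix_bytes(n) / 1e9:.2f} GB")
# n =    1000: 0.00 GB
# n =   10000: 0.20 GB
# n =  100000: 20.00 GB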
Theorem: Universality of Transformers
Transformers with sufficient depth and width are universal function approximators for sequence-to-sequence functions (Yun et al., 2020). Moreover, they can approximate any continuous permutation-equivariant function on sets.
The combination of attention (global mixing) and feed-forward layers (local processing) gives transformers sufficient expressiveness to approximate any reasonable sequence transformation.
Theorem: Attention as Soft Dictionary Lookup
Attention can be interpreted as a differentiable dictionary: queries look up keys, and the corresponding values are retrieved with soft weights. In the zero-temperature limit of the softmax, attention becomes a hard lookup (argmax), selecting the single most relevant key-value pair.
Each query asks "which keys am I most similar to?" and retrieves a weighted combination of the corresponding values.
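A small numerical sketch of this view: dividing the scores by a temperature before the softmax interpolates between a nearly uniform average (high temperature) and a nearly hard argmax lookup (temperature near zero):

scores = torch.tensor([2.0, 1.0, 0.5])
for temp in [10.0, 1.0, 0.01]:
    weights = F.softmax(scores / temp, dim=-1)
    print(temp, weights.round(decimals=3))
# High temperature -> nearly uniform weights; low temperature -> nearly one-hot on the best key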
Example: Self-Attention from Scratch
Implement scaled dot-product self-attention from scratch.
Implementation
def self_attention(x, W_q, W_k, W_v):
    # x: (seq_len, d_model); W_q, W_k, W_v: (d_model, d_k) projection matrices
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    d_k = Q.size(-1)
    scores = Q @ K.T / d_k ** 0.5        # (seq_len, seq_len) similarity scores
    weights = F.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V
Example: Visualising Attention Weights
Compute and plot the attention weight matrix for a short sequence.
Approach
The attention weight matrix shows which positions attend to which. Plot it as a heatmap, as in the sketch below.
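One possible sketch using matplotlib; the query/key tensors are random toy data purely for illustration:

import matplotlib.pyplot as plt

T, d_k = 8, 32
Q, K = torch.randn(T, d_k), torch.randn(T, d_k)
weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)

plt.imshow(weights.numpy(), cmap='viridis')
plt.xlabel('Key position')
plt.ylabel('Query position')
plt.colorbar(label='Attention weight')
plt.title('Attention weights (random toy data)')
plt.show()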
Example: KV-Cache for Efficient Inference
During autoregressive generation, avoid recomputing K and V for all previous tokens.
Key idea
Cache K and V from previous steps. At step t, only compute the new query and append the new key/value to the cache.
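A minimal single-head sketch of the idea, with no projections or batching shown; the KVCache class is a hypothetical helper, not a standard API:

class KVCache:
    def __init__(self):
        self.K = None  # (t, d_k) keys seen so far
        self.V = None  # (t, d_v) values seen so far

    def append(self, k_new, v_new):
        # k_new, v_new: (1, d) for the newly generated token
        self.K = k_new if self.K is None else torch.cat([self.K, k_new], dim=0)
        self.V = v_new if self.V is None else torch.cat([self.V, v_new], dim=0)
        return self.K, self.V

cache = KVCache()
d_k = 64
for step in range(5):
    q_new = torch.randn(1, d_k)                       # query for the current token only
    K, V = cache.append(torch.randn(1, d_k), torch.randn(1, d_k))
    out = scaled_dot_product_attention(q_new, K, V)   # per-step cost grows with t, not t^2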
Attention Weight Heatmap
Visualise attention weights between query and key positions.
Positional Encoding Visualiser
See the sinusoidal positional encoding patterns.
Attention Complexity Calculator
See how memory and compute scale with sequence length.
Attention Pattern Evolution
Watch how attention patterns change during training.
Transformer Architecture
Attention Variants
Quick Check
Why divide by sqrt(d_k) in scaled dot-product attention?
To normalise the output magnitude
To prevent softmax from saturating when d_k is large
To make training faster
Without scaling, the variance of the dot products grows in proportion to d_k, pushing the softmax into saturated regions with near one-hot outputs and vanishing gradients.
Quick Check
What is the memory complexity of standard self-attention?
O(n^2) for sequence length n: the n x n attention weight matrix must be stored.
Quick Check
Why are positional encodings needed in transformers?
To make the model run faster
Because self-attention is permutation-equivariant and cannot distinguish positions
To reduce memory usage
Common Mistake: Forgetting the Causal Mask
Mistake:
Not applying a causal mask in decoder self-attention, allowing the model to see future tokens.
Correction:
Always apply an upper-triangular mask in decoder self-attention for autoregressive generation.
Common Mistake: d_model Not Divisible by n_heads
Mistake:
Setting d_model=256, n_heads=6, so d_k is not an integer.
Correction:
Ensure d_model is divisible by n_heads. Common: d_model=512, n_heads=8 -> d_k=64.
Common Mistake: Omitting Residual Connections in Transformer
Mistake:
Not adding residual connections around attention and FFN sublayers.
Correction:
Always use x + sublayer(LayerNorm(x)) (pre-norm) or LayerNorm(x + sublayer(x)) (post-norm).
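For concreteness, a minimal pre-norm block sketch built from the MultiHeadAttention module above; the 4x feed-forward expansion and GELU are conventional choices, not requirements:

class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Pre-norm residual connections around each sublayer
        h = self.norm1(x)
        x = x + self.attn(h, h, h, mask)
        x = x + self.ffn(self.norm2(x))
        return x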
Key Takeaway
Attention computes a weighted sum of values based on query-key similarity. The $O(n^2)$ cost is the price for global context. Multi-head attention learns diverse attention patterns in parallel.
Key Takeaway
The transformer = multi-head attention + feed-forward + residual + layer norm. This simple recipe scales to billions of parameters and dominates NLP, vision, and increasingly scientific computing.
Why This Matters: Transformers for MIMO Detection
Attention can model the interactions between transmitted symbols in MIMO systems. Each received antenna's signal attends to all others, learning the interference structure. This has been shown to approach MMSE detection performance with lower complexity than iterative methods.
See full treatment in Chapter 33
Historical Note: Attention Is All You Need
Vaswani et al. (2017) introduced the Transformer, replacing recurrence entirely with attention. This paper is the foundation of GPT, BERT, and all modern large language models.
Historical Note: The First Attention Mechanism
Bahdanau et al. (2014) introduced attention for machine translation, allowing the decoder to focus on different source words at each step. This additive attention mechanism was the precursor to the dot-product attention used in transformers.
Attention
Mechanism that computes weighted combinations of values based on query-key similarity.
Related: Multi-Head Attention
Multi-Head Attention
Running h parallel attention functions with different learned projections, then concatenating.
Transformer
Architecture built from attention and feed-forward layers with residual connections and layer norm.
Positional Encoding
Signal added to embeddings to inject position information into the permutation-equivariant attention.
Vision Transformer (ViT)
Transformer applied to images by splitting them into patches and treating patches as tokens.
Attention Mechanism Variants
| Variant | Complexity | Key Idea | Use Case |
|---|---|---|---|
| Scaled Dot-Product | $O(n^2)$ | Standard softmax attention | Default |
| Multi-Head | $O(n^2)$ | Parallel attention heads | Transformers |
| Flash Attention | $O(n^2)$ time, $O(n)$ memory | IO-aware tiling | Long sequences |
| Linear Attention | $O(n)$ | Kernel approximation | Very long sequences |