The GPT Architecture
Definition: GPT Architecture
The Generative Pre-trained Transformer (GPT) is a stack of decoder-only transformer blocks. Each block contains:
- Causal (masked) multi-head self-attention
- Feed-forward network (two linear layers with GELU activation)
- Residual connections and layer normalization
The forward pass for block $\ell$ (pre-norm form):

$$h_\ell = x_\ell + \mathrm{MHA}(\mathrm{LN}_1(x_\ell)), \qquad x_{\ell+1} = h_\ell + \mathrm{FFN}(\mathrm{LN}_2(h_\ell))$$

where $\mathrm{FFN}(x) = W_2\,\mathrm{GELU}(W_1 x)$.
Modern GPTs use pre-norm (LN before attention/FFN) rather than post-norm, which stabilizes training for deep networks.
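A schematic sketch of the two orderings (here `attn`, `ffn`, `ln1`, and `ln2` stand in for the sub-layers of one block):

```python
def post_norm_block(x, attn, ffn, ln1, ln2):
    # Original Transformer ordering: LayerNorm applied after each residual add.
    x = ln1(x + attn(x))
    return ln2(x + ffn(x))

def pre_norm_block(x, attn, ffn, ln1, ln2):
    # GPT-2-style ordering: LayerNorm inside each residual branch, so the
    # residual path stays an identity and gradients flow cleanly through depth.
    x = x + attn(ln1(x))
    return x + ffn(ln2(x))
```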
Definition: Positional Encoding Strategies
Since self-attention is permutation-equivariant, position must be injected explicitly:
- Learned absolute: added to token embeddings (GPT-1/2)
- Rotary (RoPE): rotates query/key vectors by position-dependent angles, encoding relative position in the dot product: $\langle R_m q, R_n k \rangle = \langle q, R_{n-m} k \rangle$, so attention scores depend only on the offset $n - m$
- ALiBi: adds a linear bias to attention scores, proportional to query-key distance
RoPE is used in LLaMA, Mistral, and most modern open LLMs; a minimal sketch of applying it follows.
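Below is a minimal sketch of the "rotate-half" formulation of RoPE, assuming inputs of shape `(batch, seq, n_heads, d_head)`; production code precomputes and caches the angles:

```python
import torch

def rope(x, base=10000.0):
    # Rotate channel pairs (i, i + d_head/2) by angle pos * base**(-2i/d_head).
    # Applied to both q and k before the dot product, this makes q.k depend
    # only on the relative position of the two tokens.
    B, T, H, D = x.shape
    half = D // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)        # (half,)
    angles = torch.arange(T, dtype=torch.float32)[:, None] * freqs[None, :]  # (T, half)
    cos = angles.cos()[None, :, None, :]   # broadcast over batch and heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = rope(torch.randn(1, 8, 4, 64))  # apply to queries (and likewise to keys)
```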
Definition: KV-Cache for Efficient Inference
During autoregressive generation, the key and value projections for all previous tokens are cached:

$$K_t = [k_1; \dots; k_t], \qquad V_t = [v_1; \dots; v_t]$$

At step $t$, only the new token's $k_t, v_t$ is computed, and attention is $\mathrm{softmax}\!\big(q_t K_t^\top / \sqrt{d_k}\big)\,V_t$.

Memory cost: $2 \cdot L \cdot T \cdot d_{\text{model}}$ cached values per sequence (the factor 2 covers K and V), i.e. $4\,L\,T\,d_{\text{model}}$ bytes in fp16.

For a 70B model ($L = 80$, $d_{\text{model}} = 8192$), that is roughly 2.6 MB per token in fp16, so the KV-cache reaches ~16 GB by about 6K tokens; across a serving batch, the aggregate cache often exceeds the memory occupied by the model weights.
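The arithmetic is easy to verify in a few lines (a sketch; the layer counts and widths below are the commonly cited LLaMA-style configurations):

```python
def kv_cache_bytes(n_layers, d_model, seq_len, bytes_per_value=2, batch=1):
    # Two tensors (K and V) per layer, each seq_len x d_model values.
    return 2 * n_layers * d_model * seq_len * bytes_per_value * batch

print(kv_cache_bytes(80, 8192, 6000) / 1e9)  # 70B-class, fp16: ~15.7 GB
print(kv_cache_bytes(32, 4096, 4096) / 1e9)  # 7B-class at 4K tokens: ~2.1 GB
```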
Definition: Parameter Count Formula
For a GPT with $L$ layers, dimension $d$, and vocabulary $V$:

$$N \approx 12\,L\,d^2 + V\,d$$

The $12d^2$ comes from: 4 attention projection matrices ($W_Q$, $W_K$, $W_V$, $W_O$, each $d \times d$) plus 2 FFN matrices ($d \times 4d$ and $4d \times d$), totaling $4d^2 + 8d^2 = 12d^2$ per layer. The $Vd$ term counts a tied input-embedding/output matrix once; biases, LayerNorms, and positional embeddings are ignored.
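Plugging in GPT-2 small's configuration ($L = 12$, $d = 768$, $V = 50257$) recovers the familiar ~124M figure:

```python
def gpt_params(n_layers, d_model, vocab_size):
    # 12 d^2 per block (4 attention projections + 2 FFN matrices)
    # plus one tied embedding/output matrix of vocab_size x d_model.
    return 12 * n_layers * d_model**2 + vocab_size * d_model

print(f"{gpt_params(12, 768, 50257):,}")  # 123,532,032 (~124M)
```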
Definition: Grouped Query Attention (GQA)
GQA reduces KV-cache size by sharing key/value heads across multiple query heads. With $H$ query heads and $G$ KV groups:
- Multi-Head Attention (MHA): $G = H$ (no sharing)
- Multi-Query Attention (MQA): $G = 1$ (all queries share one KV head)
- GQA: $1 < G < H$ (each group of $H/G$ query heads shares one KV head)
LLaMA 2 70B uses GQA with $H = 64$ and $G = 8$, reducing the KV-cache by $8\times$ relative to MHA.
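A shape-level sketch of the sharing (the dimensions are illustrative; the causal mask is omitted since only the head bookkeeping is of interest here):

```python
import torch

B, T, d_head = 2, 16, 128
H, G = 64, 8                       # query heads, KV groups (LLaMA 2 70B-style)

q = torch.randn(B, H, T, d_head)   # one set of queries per head
k = torch.randn(B, G, T, d_head)   # only G key/value heads are stored and cached
v = torch.randn(B, G, T, d_head)

# Expand each KV head to serve its group of H // G query heads.
k = k.repeat_interleave(H // G, dim=1)   # (B, H, T, d_head)
v = v.repeat_interleave(H // G, dim=1)

scores = (q @ k.transpose(-2, -1)) / d_head**0.5
out = scores.softmax(-1) @ v
print(out.shape)   # torch.Size([2, 64, 16, 128])
```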
Theorem: FLOPs per Token in Forward Pass
For a GPT with $N$ parameters, the approximate FLOPs for a single forward pass on one token is:

$$\text{FLOPs}_{\text{forward}} \approx 2N$$

For training (forward + backward), the total is approximately $6N$ FLOPs per token. Training a model on $D$ tokens costs:

$$C \approx 6\,N\,D \ \text{FLOPs}$$

Each parameter participates in one multiply-add (2 FLOPs) during the forward pass; the backward pass costs roughly 2x the forward, giving $2N + 4N = 6N$ per token.
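As a worked example with Chinchilla-scale numbers (illustrative, not a reproduction of any reported budget):

```python
N = 70e9    # parameters
D = 1.4e12  # training tokens
print(f"{6 * N * D:.2e}")  # ~5.88e+23 FLOPs for one training run
```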
Example: Building a Minimal GPT in PyTorch
Implement a GPT model with configurable depth and width. Count parameters and verify against the formula.
Implementation
```python
import torch
import torch.nn as nn

class GPTBlock(nn.Module):
    """Pre-norm decoder block: LN -> attention -> residual, LN -> FFN -> residual."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x, mask=None):
        h = self.ln1(x)
        h, _ = self.attn(h, h, h, attn_mask=mask)
        x = x + h
        x = x + self.ffn(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)   # learned absolute positions
        self.blocks = nn.ModuleList(
            [GPTBlock(d_model, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # weight tying, as the formula assumes

    def forward(self, idx):
        B, T = idx.shape
        tok = self.tok_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos
        # Boolean causal mask: True above the diagonal = "may not attend".
        mask = torch.triu(torch.ones(T, T, device=idx.device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, mask)
        return self.head(self.ln_f(x))

model = GPT(50257, d_model=768, n_heads=12, n_layers=12, max_len=1024)
n_params = sum(p.numel() for p in model.parameters())
formula = 12 * 12 * 768**2 + 50257 * 768
print(f"Actual:  {n_params:,}")   # ~124.4M: also counts pos. embeddings, biases, LayerNorms
print(f"Formula: {formula:,}")    # 123,532,032
```
Example: KV-Cache Implementation
Implement KV-caching for efficient autoregressive generation.
Key Idea
```python
import torch
import torch.nn as nn

class CachedAttention(nn.Module):
    """Self-attention that reuses cached K/V from previous decoding steps."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        B, T, d = x.shape
        qkv = self.W_qkv(x).reshape(B, T, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(2)                 # each (B, T, n_heads, d_k)
        if cache is not None:
            k_prev, v_prev = cache
            k = torch.cat([k_prev, k], dim=1)   # append along the time axis
            v = torch.cat([v_prev, v], dim=1)
        new_cache = (k, v)
        # Attention over the full (cached + new) K, V.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # (B, n_heads, ., d_k)
        scores = q @ k.transpose(-2, -1) / self.d_k**0.5   # (B, n_heads, T, S)
        if T > 1:
            # Prefill: apply a causal mask. During one-token decoding (T == 1)
            # every cached position is in the past, so no mask is needed.
            S = k.shape[2]
            causal = torch.triu(torch.ones(T, S, device=x.device, dtype=torch.bool),
                                diagonal=S - T + 1)
            scores = scores.masked_fill(causal, float("-inf"))
        attn = scores.softmax(-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, d)
        return self.W_o(out), new_cache
```
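Used in a generation loop, the prompt is processed once (prefill) and each later step feeds only the newest position (a sketch; `CachedAttention` stands in for a full block, and the inputs are random stand-ins for embeddings):

```python
attn = CachedAttention(d_model=64, n_heads=4)
prompt = torch.randn(1, 10, 64)      # embedded prompt, 10 positions

out, cache = attn(prompt)            # prefill: attends over all 10 positions
for _ in range(5):                   # decode: one position per step
    new_tok = out[:, -1:, :]         # stand-in for the next token's embedding
    out, cache = attn(new_tok, cache)
print(cache[0].shape)                # torch.Size([1, 15, 4, 16]): 10 + 5 cached keys
```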
[Interactive: GPT Parameter Calculator — calculate parameter count for different GPT configurations]
[Interactive: KV-Cache Memory Calculator — estimate KV-cache memory for different model sizes and sequence lengths]
[Figure: GPT Architecture Diagram]
Quick Check
Why do modern GPTs use pre-norm (LayerNorm before attention) instead of post-norm?
- It reduces parameter count
- It stabilizes training for very deep networks (correct)
- It improves inference speed
Explanation: pre-norm prevents gradient explosion/vanishing in deep transformers by normalizing the inputs to each sub-layer.
Common Mistake: Underestimating KV-Cache Memory
Mistake:
Planning GPU memory budget based only on model weights.
Correction:
For long sequences, KV-cache can exceed model weight memory. A 7B model at 4096 tokens uses ~2 GB for KV-cache in fp16. At 32K tokens, this grows to ~16 GB. Always account for KV-cache when planning batch sizes.
Historical Note: The Evolution of GPT
2018-2023. GPT-1 (2018, 117M params) showed that pre-training a transformer on unlabeled text, then fine-tuning on tasks, outperformed task-specific architectures. GPT-2 (2019, 1.5B) demonstrated emergent few-shot abilities. GPT-3 (2020, 175B) made in-context learning practical. GPT-4 (2023) remains largely undocumented but represents a massive leap in capability.
GPT (Generative Pre-trained Transformer)
A family of decoder-only transformer language models trained with next-token prediction, starting with GPT-1 (2018) at OpenAI.
Related: KV-Cache
KV-Cache
A cache storing previously computed key and value tensors during autoregressive generation, avoiding redundant computation at the cost of memory.