Attention Mechanism & Transformer Architecture
Intuition
Pre-attention sequence-to-sequence models squeezed an entire source sentence through a single fixed vector: the *bottleneck*. Attention (Bahdanau et al., 2015) let the decoder *look back* over the encoder's full history at each step, picking out the positions relevant to the word being generated. Two years later, Vaswani et al. (2017) asked the radical question: do we need the RNN at all? Their answer was *Attention Is All You Need*, and nearly every modern vision/language model, from ViT to GPT to PaliGemma, descends from that one paper.
Explanation
The three landmark papers — memorise the chain. *(1)* Bahdanau et al., ICLR 2015 — "Neural Machine Translation by Jointly Learning to Align and Translate" — invented attention as an *add-on* to RNNs. *(2)* Xu et al., ICML 2015 — "Show, Attend and Tell" — applied Bahdanau attention to image captioning. *(3)* Vaswani et al., NeurIPS 2017 — "Attention Is All You Need" — killed the RNN entirely; the whole model is attention. This is the Transformer.
Three Seq2Seq task types you must name. *Image captioning* — image → text (single input, sequence output). *Sentiment classification* — text → label (sequence input, single output). *Machine translation* — text → text (sequence to sequence).
The encoder-decoder paradigm and its bottleneck. The encoder RNN reads the source $x_1, \dots, x_T$ into hidden states $h_1, \dots, h_T$ and takes the final state $h_T$ as a single fixed vector summarising the entire source. The decoder RNN initialises from $h_T$ and generates tokens autoregressively: at each step, take the previously generated token as input, produce a probability distribution over the vocabulary, sample or take the argmax, and repeat. The bottleneck: one vector cannot hold an entire sentence. *"The cat that sat on the mat"* and *"The cat sat on the mat"*, like any longer sentence, must each be compressed into a vector of the same fixed size (e.g. 512 dimensions), so detail is lost as the input grows. As source sentences get longer, performance degrades sharply. This is the failure attention fixes.
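Bahdanau attention: the four-step recipe. At decoder step $t$: *(1)* Compute alignment scores $e_{t,i} = \text{score}(s_{t-1}, h_i)$ for $i = 1, \dots, T$, where $s_{t-1}$ is the previous decoder state. Bahdanau's original score is the additive MLP $e_{t,i} = v^\top \tanh(W_s s_{t-1} + W_h h_i)$. *(2)* Softmax: $\alpha_{t,i} = \exp(e_{t,i}) / \sum_{j} \exp(e_{t,j})$; these are the attention weights, summing to 1 over $i$. *(3)* Context vector $c_t = \sum_i \alpha_{t,i} h_i$: a weighted average of encoder states focused on positions relevant to step $t$. *(4)* Combine and predict: $s_t = f(s_{t-1}, y_{t-1}, c_t)$; $p(y_t \mid y_{<t}, x) = g(y_{t-1}, s_t, c_t)$.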
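Steps (1) to (3) of the recipe can be sketched in a few lines of NumPy. This is a minimal illustration, not Bahdanau's implementation: the function name `bahdanau_step`, the weight names `W_s`, `W_h`, `v`, and all toy sizes are assumptions made for the example, and the weights are random rather than trained.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability; the result is unchanged.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def bahdanau_step(s_prev, H, W_s, W_h, v):
    """One attention step: scores -> weights -> context vector.

    s_prev: (d_s,)   previous decoder state
    H:      (T, d_h) encoder hidden states, one row per source position
    """
    # (1) additive (MLP) alignment scores, one per encoder position
    e = np.tanh(s_prev @ W_s + H @ W_h) @ v   # shape (T,)
    # (2) attention weights: softmax over encoder positions
    alpha = softmax(e)                         # shape (T,), sums to 1
    # (3) context vector: weighted average of encoder states
    c = alpha @ H                              # shape (d_h,)
    return alpha, c

rng = np.random.default_rng(0)
T, d_h, d_s, d_a = 5, 8, 6, 4                  # hypothetical toy sizes
H = rng.normal(size=(T, d_h))
s_prev = rng.normal(size=d_s)
W_s = rng.normal(size=(d_s, d_a))
W_h = rng.normal(size=(d_h, d_a))
v = rng.normal(size=d_a)

alpha, c = bahdanau_step(s_prev, H, W_s, W_h, v)
print(alpha.round(3), c.shape)
```

Step (4) is omitted because it depends on the decoder cell; in a full model `c` would be fed, together with the previous token, into the decoder RNN.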
Attention learns alignment as a byproduct. When you visualise the matrix of weights $\alpha_{t,i}$ for a translated sentence, you see a near-diagonal pattern: generating the French word for "cat" peaks attention on the encoder state for "cat". *No alignment supervision is given*, yet alignment emerges. This is the iconic Bahdanau heatmap and a guaranteed exam talking point.
Soft vs hard attention. *Soft*: $\alpha_t$ is a continuous distribution and $c_t$ is a weighted average, so the whole step is differentiable. *Hard*: sample one encoder position discretely (or take the argmax); not differentiable, trained with REINFORCE. Show-Attend-Tell (image captioning) compared both; soft attention, trainable end to end by plain backpropagation, became the standard.
Training: teacher vs student forcing. *Teacher forcing*: feed the *ground-truth* previous word $y_{t-1}$ to the decoder. Fast and stable, but it creates *exposure bias*: at inference the model only sees its own predictions, whose distribution may differ from what it saw in training. *Student forcing*: feed the decoder's own previous prediction $\hat{y}_{t-1}$. More realistic but harder to train (early mistakes propagate). *Scheduled sampling* mixes the two, gradually shifting from ground-truth to predicted inputs. At inference it is always student forcing, because there is no ground truth. Memorise this contrast.
Inference: greedy vs beam search. *Greedy*: pick the argmax token at each step. Fast, but myopic. *Beam search*: keep the top-$k$ partial sequences at each step, expand and score each extension, keep the top-$k$ again. After $T$ steps, return the highest-scoring complete sequence. It trades compute for quality; typical beam widths are small (roughly $k = 4$ to $10$). Standard for translation and captioning.
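A toy run makes the difference between greedy and beam decoding concrete. The sketch below is an assumption-laden illustration: the `PROBS` table is a made-up stand-in for a trained decoder (it conditions only on the previous token), and the search omits refinements such as length normalisation.

```python
import math

# Hypothetical toy "model": next-token probabilities given the previous token.
# "<s>" starts a sequence and "</s>" ends it.
PROBS = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a":   {"x": 0.4, "y": 0.3, "</s>": 0.3},
    "b":   {"x": 0.9, "</s>": 0.1},
    "x":   {"</s>": 1.0},
    "y":   {"</s>": 1.0},
}

def beam_search(k, max_steps=5):
    beams = [(0.0, ["<s>"])]             # (log-probability, tokens)
    finished = []
    for _ in range(max_steps):
        # Expand every live beam by every possible next token.
        candidates = []
        for logp, seq in beams:
            for tok, p in PROBS[seq[-1]].items():
                candidates.append((logp + math.log(p), seq + [tok]))
        # Keep only the k best extensions.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for cand in candidates[:k]:
            (finished if cand[1][-1] == "</s>" else beams).append(cand)
        if not beams:
            break
    # In this toy table every path ends within max_steps, so finished is non-empty.
    return max(finished, key=lambda c: c[0])

greedy = beam_search(k=1)                # beam width 1 is greedy decoding
beam = beam_search(k=2)
print(greedy[1], round(math.exp(greedy[0]), 2))
print(beam[1], round(math.exp(beam[0]), 2))
```

Greedy commits to "a" because it is the most likely first token, and ends with total probability 0.24; the width-2 beam keeps "b" alive and finds the sequence with probability 0.36.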
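Softmax temperature. $p_i = \exp(z_i / \tau) / \sum_j \exp(z_j / \tau)$. $\tau = 1$ is the standard softmax; $\tau > 1$ flattens the distribution (exploratory); $\tau < 1$ sharpens it (deterministic, near one-hot as $\tau \to 0$). The same idea appears in DINO's teacher sharpening (low $\tau$) and in InfoNCE.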
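The effect of the temperature is easy to see numerically. The following is a small sketch using only the standard library; the function name `softmax_t` and the example logits are assumptions chosen for illustration.

```python
import math

def softmax_t(logits, tau=1.0):
    # Divide logits by the temperature, then apply a numerically stable softmax.
    scaled = [z / tau for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]                 # hypothetical logits
for tau in (0.1, 1.0, 10.0):
    print(tau, [round(p, 3) for p in softmax_t(logits, tau)])
```

Low `tau` puts almost all mass on the largest logit; high `tau` moves the distribution towards uniform.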
The Transformer: kill the RNN. The recurrence in RNNs has a real cost: it *serialises* computation. You cannot process step 5 until step 4 finishes. If attention already lets the decoder look at any encoder position, why not let *every* encoder position look at every *other* encoder position, in parallel, with no recurrence at all? That's the Transformer. Original paper: $N = 6$ encoder + $N = 6$ decoder blocks; modern variants use 12–96+.
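Self-attention: the engine. For an input $X \in \mathbb{R}^{n \times d}$ ($n$ tokens, $d$ features each): project to Queries, Keys, Values with three learnable matrices $W_Q$, $W_K$, $W_V$ (each $d \times d_k$, with $d_k = d$ for a single head and $d_k = d/h$ per head when there are $h$ heads), giving $Q = XW_Q$, $K = XW_K$, $V = XW_V$. Then scaled dot-product attention: $\text{Attention}(Q, K, V) = \text{softmax}\!\left(QK^\top / \sqrt{d_k}\right) V$. Read it: $QK^\top$ is $n \times n$, every token's query dot-producted with every token's key (pairwise similarities). Softmax along the last axis gives, for each query, a probability distribution over keys. Multiply by $V$: each query gets a weighted average of values. Output: $n \times d_k$.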
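The formula translates almost line for line into NumPy. This is a minimal single-head sketch with random, untrained weights; the names `self_attention`, `W_q`, `W_k`, `W_v` and the toy sizes are assumptions for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Row-wise softmax, shifted by the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v     # each (n, d_k)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)         # (n, n) pairwise similarities
    weights = softmax(scores, axis=-1)      # each row is a distribution over keys
    return weights @ V, weights             # (n, d_k), (n, n)

rng = np.random.default_rng(0)
n, d, d_k = 4, 8, 8                          # hypothetical toy sizes
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.sum(axis=-1))
```

Each row of `weights` sums to 1, and `out` has one $d_k$-dimensional vector per input token.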
**Why divide by $\sqrt{d_k}$? Exam-gold.** With $q, k$ entries i.i.d. zero-mean and unit-variance, $q \cdot k$ has variance $d_k$. For $d_k = 64$, the std is 8: entries that large push softmax into the *saturated regime* where one logit dominates and gradients vanish. Dividing by $\sqrt{d_k}$ rescales the variance back to 1, keeping softmax in its useful regime. "Why scaled?" → "to prevent softmax saturation."
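Multi-head attention. Instead of one big attention with $d_k = d$, run $h$ parallel heads each with $d_k = d/h$: $\text{head}_i = \text{Attention}(X W_Q^{(i)}, X W_K^{(i)}, X W_V^{(i)})$; output $\text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_O$ with $W_O \in \mathbb{R}^{d \times d}$. Each head can specialise on a different relationship: one for syntactic dependency, one for coreference, one for local patterns. Critical detail: heads do not change the total parameter count. Whether you have 1 head with $d_k = 768$ or 12 heads with $d_k = 64$, $W_Q, W_K, W_V$ together always use $3d^2$ parameters. Heads just partition the model dimension $d$.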
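One common way to implement the heads is to project once with full $d \times d$ matrices and then split the columns into $h$ groups, which makes the "same parameters, different partition" point visible in code. The sketch below assumes that layout; the function name `multi_head_attention`, the random weights, and the toy sizes are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    n, d = X.shape
    assert d % h == 0, "model dimension must be divisible by the head count"
    d_head = d // h

    def split(M):
        # (n, d) -> (h, n, d_head): each head gets its own slice of columns.
        return M.reshape(n, h, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (h, n, n)
    heads = softmax(scores, axis=-1) @ V                   # (h, n, d_head)
    concat = heads.transpose(1, 0, 2).reshape(n, d)        # (n, d)
    return concat @ W_o

rng = np.random.default_rng(0)
n, d = 5, 16                                               # hypothetical toy sizes
X = rng.normal(size=(n, d))
W_q, W_k, W_v, W_o = (rng.normal(size=(d, d)) for _ in range(4))

out_1 = multi_head_attention(X, W_q, W_k, W_v, W_o, h=1)
out_4 = multi_head_attention(X, W_q, W_k, W_v, W_o, h=4)
n_params = sum(W.size for W in (W_q, W_k, W_v, W_o))
print(out_1.shape, out_4.shape, n_params)
```

The same four weight matrices serve both calls: changing `h` changes how the columns are grouped, not how many parameters exist.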
Three flavours of attention in the Transformer. *(1)* Encoder self-attention: $Q, K, V$ all come from the input sequence; no masking; bidirectional. *(2)* Decoder masked self-attention: $Q, K, V$ all come from the output so far; a causal mask prevents peeking at future tokens. The mask is upper-triangular with $-\infty$ above the diagonal; after softmax these entries become 0. Synonyms for the mask: *causal mask, autoregressive mask, look-ahead mask, left-to-right mask*. *(3)* Cross-attention: $Q$ from the decoder's current state, $K, V$ from the encoder's output. Cross-attention plays the role of Bahdanau attention, generalised to Q-K-V form with a dot-product score in place of the additive MLP.
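The causal mask in flavour (2) can be sketched as an additive matrix applied before the softmax. This is a minimal illustration; `causal_mask` is a hypothetical helper name, and the random `scores` stand in for $QK^\top/\sqrt{d_k}$.

```python
import numpy as np

def causal_mask(n):
    # 0 on and below the diagonal (allowed), -inf strictly above it (future tokens).
    allowed = np.tril(np.ones((n, n), dtype=bool))
    return np.where(allowed, 0.0, -np.inf)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))          # stand-in for Q K^T / sqrt(d_k)
weights = softmax(scores + causal_mask(4))
print(weights.round(2))
```

Every entry above the diagonal of `weights` is exactly 0, and the first row puts all its mass on the first token.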
Positional encoding: restoring order. Self-attention is permutation-equivariant: shuffle the input tokens and the output tokens shuffle the same way. There is no notion of "position 3 vs 5"; without help, *"the cat sat on the mat"* and *"the mat sat on the cat"* produce the same output vectors, merely reordered. Vaswani's fix: a sinusoidal positional encoding added to the input embeddings: $PE(pos, 2i) = \sin\!\left(pos / 10000^{2i/d}\right)$, $PE(pos, 2i+1) = \cos\!\left(pos / 10000^{2i/d}\right)$. Each position gets a unique $d$-dim vector built from sine/cosine pairs at exponentially decreasing frequencies. Why sinusoids? $PE(pos + k)$ is a *linear function* of $PE(pos)$ for any fixed offset $k$, so the network can learn to use relative positions naturally. Variants you'll meet later: *learned absolute PEs* (BERT, ViT; simpler, but they don't extrapolate beyond training-time lengths); RoPE (rotary, multiplicative on $Q$ and $K$; relative-position-aware by construction; tends to extrapolate better); 2D RoPE / M-RoPE for images and video.
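The sinusoidal encoding can be built directly from its definition. The sketch below assumes an even model dimension `d`; the function name `sinusoidal_pe` and the sizes are illustrative choices.

```python
import numpy as np

def sinusoidal_pe(max_len, d):
    # Assumes d is even: dimension pairs (2i, 2i+1) share one frequency.
    pos = np.arange(max_len)[:, None]             # (max_len, 1)
    i = np.arange(d // 2)[None, :]                # (1, d/2)
    angles = pos / (10000 ** (2 * i / d))         # (max_len, d/2)
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(angles)                  # even dimensions
    pe[:, 1::2] = np.cos(angles)                  # odd dimensions
    return pe

pe = sinusoidal_pe(max_len=50, d=16)
print(pe.shape, pe[0, :4])
```

In use, `pe[:n]` would be added to the `(n, d)` matrix of token embeddings before the first block.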
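The Transformer block in one snapshot. Encoder block: $x' = \text{LN}(x + \text{MSA}(x))$; $y = \text{LN}(x' + \text{MLP}(x'))$. Decoder block: masked-MSA → cross-attention → MLP, each followed by Add & Norm. The MLP is $d \to 4d \to d$; the 4× expansion ratio is the standard. The original 2017 paper used Post-Norm ($\text{LN}(x + \text{Sublayer}(x))$, the form written above); modern implementations use Pre-Norm ($x + \text{Sublayer}(\text{LN}(x))$) for stable deep stacks (this is what we'll meet again in Transformer Advances).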
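The only difference between Post-Norm and Pre-Norm is where the LayerNorm sits relative to the residual addition. The sketch below shows both orderings around a generic sub-layer; the toy linear `sublayer`, the omission of LayerNorm's learnable gain and bias, and all sizes are simplifying assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalise each token vector; learnable gain/bias omitted for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm(x, sublayer):
    # Original 2017 ordering: residual add first, then normalise.
    return layer_norm(x + sublayer(x))

def pre_norm(x, sublayer):
    # Modern ordering: normalise the sub-layer input, keep the residual path clean.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))                      # 3 tokens, d = 8 (toy sizes)
W = rng.normal(size=(8, 8)) * 0.1

def sublayer(z):
    # Hypothetical stand-in for MSA or the MLP.
    return z @ W

y_post = post_norm(x, sublayer)
y_pre = pre_norm(x, sublayer)
print(y_post.var(axis=-1).round(3), y_pre.var(axis=-1).round(3))
```

With Pre-Norm the input `x` passes through the block unchanged on the residual path, which is the property usually credited with making deep stacks easier to train.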
Practical batching: padding + masking. Sequences in a batch have different lengths → pad to the longest with [PAD]. The padding mask sets attention scores at padding positions to $-\infty$ so they don't influence real tokens. The decoder needs both the *padding mask* AND the *causal mask*; combine them element-wise before the softmax (sum or min for additive $0/-\infty$ masks, logical AND for boolean keep-masks).
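Combining the two masks is simplest with boolean "may attend" arrays that are AND-ed and then converted to an additive mask. This is a sketch under that assumption; `combined_mask` and the example lengths are hypothetical, and every sequence is assumed to contain at least one real token.

```python
import numpy as np

def combined_mask(lengths, n):
    """Additive mask of shape (batch, n, n): 0 = may attend, -inf = blocked."""
    # Causal part: query i may see keys j <= i.
    causal = np.tril(np.ones((n, n), dtype=bool))
    # Padding part: key j is real only if j < length of that sequence.
    key_is_real = np.arange(n)[None, :] < np.asarray(lengths)[:, None]  # (batch, n)
    keep = causal[None, :, :] & key_is_real[:, None, :]                 # (batch, n, n)
    return np.where(keep, 0.0, -np.inf)

# Two sequences padded to length 4; the second has only 2 real tokens.
mask = combined_mask(lengths=[4, 2], n=4)
print(mask[1])
```

The result is added to the attention scores before the softmax, so both future positions and [PAD] keys receive zero weight.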
Show-Attend-and-Tell: the image-captioning precursor. Image → CNN → grid of $L$ spatial features $a_1, \dots, a_L$. At each LSTM step $t$, compute attention weights $\alpha_{t,i}$ over the $L$ locations conditioned on the LSTM state $h_{t-1}$ → context vector $z_t = \sum_i \alpha_{t,i} a_i$; LSTM update $h_t = \text{LSTM}(h_{t-1}, [y_{t-1}; z_t])$. Visualising the attention maps shows the decoder "looking" at relevant image regions for each word, but attention reveals only *where* the model looks, not *whether it sees correctly*. The adversarial-colour experiment (caption says "red traffic light" but the model's heatmap is on an unrelated red object) exposes this.
Why the Transformer won — three reasons. *(1)* Parallelisation — self-attention computes all token interactions in parallel; training time on GPUs is dramatically faster than RNNs. *(2)* Constant path length — every token sees every other in one layer; no long-distance information decay. *(3)* Universality — the same architecture handles text, images (ViT), audio, video, even DNA — just change the tokenisation. Within 4 years (2017 → 2021), Transformers had taken over NLP, vision, speech, and multimodal everything.
Definitions
- Seq2Seq bottleneck — Pre-attention encoder-decoder RNNs compressed the entire source into a single fixed hidden vector; performance collapsed on long inputs.
- Bahdanau attention — Per-decoder-step weighted sum over encoder hidden states; weights computed by an additive MLP score; learns alignment as a byproduct.
- Q / K / V — Query / Key / Value — three learned projections of the input. Self-attention: all from same sequence. Cross-attention: Q from decoder, K, V from encoder.
- Scaled dot-product attention: $\text{Attention}(Q, K, V) = \text{softmax}(QK^\top / \sqrt{d_k})\, V$. The $1/\sqrt{d_k}$ factor keeps softmax in its useful (non-saturated) regime.
- Multi-head attention: $h$ parallel attention heads, each with $d_k = d/h$; concatenate outputs and project with $W_O$. Same total params as single-head; heads can specialise.
- Causal mask (look-ahead mask): Upper-triangular mask of $-\infty$; added to attention scores pre-softmax; prevents the decoder from attending to future tokens.
- Cross-attention — Q from decoder's current state, K, V from encoder's output. Equivalent to Bahdanau attention in Q-K-V form.
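- Sinusoidal positional encoding: Vaswani's $\sin$/$\cos$ encoding at exponentially decreasing frequencies; allows linear expression of relative position; defined for any position, so it can be computed for lengths unseen in training.
- Pre-Norm vs Post-Norm: Post-Norm (Vaswani 2017): $\text{LN}(x + \text{Sublayer}(x))$; needs warmup. Pre-Norm (modern): $x + \text{Sublayer}(\text{LN}(x))$; stable for deep stacks.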
- Teacher forcing / student forcing — Training the decoder with ground-truth previous tokens (teacher) vs predicted previous tokens (student). Inference is always student-forcing.
- Beam search: Maintain the top-$k$ partial sequences at each step; expand and score; keep the top-$k$ again. Trades compute for quality; typical $k$ is small (roughly 4 to 10).
- Soft vs hard attention — Soft: continuous weighted average, differentiable. Hard: discrete sampling of one position, requires REINFORCE.
Formulas
Derivations
**Why divide by $\sqrt{d_k}$: the variance argument.** Assume $q, k \in \mathbb{R}^{d_k}$ are independent with i.i.d. zero-mean, unit-variance components. Then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ is a sum of $d_k$ i.i.d. products; each product has variance 1, so the total has variance $d_k$ and std $\sqrt{d_k}$. For $d_k = 64$, std $= 8$: large entries push softmax into the *saturated regime* where one logit dominates and gradients flatten. Dividing by $\sqrt{d_k}$ rescales the variance back to 1, putting softmax in its useful regime.
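The variance argument can be checked empirically by sampling. The sketch below draws standard-normal vectors (one distribution that satisfies the assumptions; the argument does not require Gaussianity) and measures the spread of the dot products; the sample count is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_samples = 64, 20_000

# Independent q and k with zero-mean, unit-variance components.
q = rng.normal(size=(n_samples, d_k))
k = rng.normal(size=(n_samples, d_k))

dots = (q * k).sum(axis=1)                   # one dot product per sample
std_unscaled = dots.std()                    # should be close to sqrt(64) = 8
std_scaled = (dots / np.sqrt(d_k)).std()     # should be close to 1

print(round(std_unscaled, 2), round(std_scaled, 2))
```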
Why sinusoidal PEs encode relative position linearly. $PE(pos)$ uses sines and cosines at the same frequencies. For each frequency $\omega$, the angle-addition identity gives $\sin(\omega(pos + k)) = \sin(\omega\, pos)\cos(\omega k) + \cos(\omega\, pos)\sin(\omega k)$, a linear combination of $PE(pos)$'s components, with coefficients depending only on $k$. So a linear layer can recover relative offset from the absolute PEs.
Self-attention's path-length advantage. In an RNN, position $t$ depends on position $t - k$ through $k$ sequential cell applications: $O(k)$ path length, with vanishing-gradient risk. In self-attention, position $i$ attends directly to position $j$ in one layer: $O(1)$ path length. Cost: $O(n^2)$ pairwise scores per layer (the cost that Flash Attention targets).
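Parameter count of a Transformer block. Self-attention: 4 projections ($W_Q, W_K, W_V, W_O$), each $d \times d$ → $4d^2$. MLP: $d \to 4d$ then $4d \to d$ → $8d^2$. Total per block: $12d^2$ (ignoring biases and LayerNorm's $2d$ per norm). For $d = 512$: $\approx 3.1$M per encoder block; for $d = 768$ (ViT-B): $\approx 7.1$M.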
Multi-head parameter count is the same as single-head. $h$ heads, each with per-head $d_k = d/h$. Per-head projections are $d \times d/h$. Sum over heads: $h \cdot d \cdot (d/h) = d^2$ per projection. Identical to one head with $d_k = d$. Heads partition the dimension; they don't multiply parameters.
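Both parameter counts above are plain arithmetic and can be verified in a few lines. The helper names `block_params` and `qkv_params` are hypothetical; like the derivation, the sketch ignores biases and LayerNorm parameters.

```python
def block_params(d, expansion=4):
    attn = 4 * d * d                    # W_Q, W_K, W_V, W_O, each d x d
    mlp = 2 * d * (expansion * d)       # d -> 4d and 4d -> d
    return attn + mlp                   # 12 d^2 for expansion = 4

def qkv_params(d, h):
    d_head = d // h                     # per-head dimension
    return 3 * h * (d * d_head)         # h heads, three d x d_head matrices each

print(block_params(512), block_params(768))
print([qkv_params(768, h) for h in (1, 8, 12)])
```

The second line prints the same number three times: the head count changes how $d$ is partitioned, not the total.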
Examples
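- Bahdanau alignment heatmap. Translate "The agreement on the European Economic Area was signed in August 1992" to French. The $\alpha$ matrix shows near-diagonal peaks; generating "signé" attends to "signed", and generating "août" attends to "August". No alignment supervision was given; the model learns it.
- **Causal mask for a 4-token target $y_1, \dots, y_4$.** $M = \begin{pmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{pmatrix}$. After softmax, $-\infty$ → 0; token $y_i$ can only attend to tokens $y_j$ with $j \le i$.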
- Cross-attention in MT. Source English "the cat"; target French "le chat". Decoder query for "le" attends most to encoder K for "the"; query for "chat" attends most to "cat". The corresponding V vectors are the encoder's contextual representations, returned to the decoder.
- Scaled dot-product worked example. $d_k = 64$, $q, k$ with i.i.d. $\mathcal{N}(0, 1)$ components. Unscaled dot product: variance 64, std 8, so softmax mostly concentrates on one entry. Scaled by $1/\sqrt{64} = 1/8$: variance 1, so softmax produces a smooth distribution and gradients flow.
- Multi-head head count vs params. For $d = 768$: $W_Q, W_K, W_V$ total $3d^2 \approx 1.77$M ($4d^2 \approx 2.36$M including $W_O$), regardless of whether you split into 1, 8, or 12 heads.
- Beam search example. $k = 4$. Vocab has 30k tokens; at each step expand each of 4 beams over 30k options, score 120k extensions, keep top 4. After 20 steps, return the highest-scoring complete sequence. Trade roughly 4× compute for typically +1-2 BLEU over greedy.
- Softmax temperature in DINO. Teacher $\tau_t \approx 0.04$ → very peaky distribution → sharp pseudo-targets. Student $\tau_s = 0.1$ → softer outputs. Sharpening pushes the teacher away from collapse-to-uniform.
Diagrams
- Bahdanau attention: at decoder step $t$, alignment scores $e_{t,i}$ over all encoder positions, softmax → $\alpha_{t,i}$, weighted sum → context $c_t$, combined with decoder state to predict $y_t$.
- Transformer architecture: encoder stack ($N$ blocks, each: self-attention + FFN) + decoder stack ($N$ blocks, each: masked self-attention + cross-attention + FFN). Inputs: source + target embeddings + positional encodings.
- Scaled dot-product attention: $W_Q$, $W_K$, $W_V$ projections; $QK^\top / \sqrt{d_k}$; softmax along last axis; multiply by $V$; output $n \times d_k$.
- Multi-head attention: $h$ parallel heads with their own $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}$, concatenate, project with $W_O$.
- Causal mask matrix: upper triangle filled with $-\infty$; lower triangle and diagonal with 0; applied additively before softmax.
- Sinusoidal PE visualisation: heatmap of $PE(pos, i)$ over positions $pos$ and dimensions $i$; high-frequency sinusoids in early dimensions, low-frequency in later dimensions.
- Show-Attend-and-Tell — CNN features → spatial attention map per generated word; visualise the heatmap shifting with the caption.
Edge cases
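- Long sequences are attention-bound. $O(n^2)$ memory for the attention matrix; at sequence lengths in the thousands of tokens, FP16 attention can dominate the activation memory. Flash Attention avoids materialising the full matrix, cutting memory and wall-clock time, but the compute remains $O(n^2)$.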
- Padding interactions with the causal mask. A naive implementation that adds the padding mask after softmax (rather than before) leaks information from pad tokens. Always combine masks pre-softmax.
- Learned-PE extrapolation failure. Train ViT/BERT at sequence length 512; evaluate at 1024 → the model has no PE for positions 513–1024 and accuracy collapses. Mitigations: PE interpolation (ViT), RoPE, ALiBi.
- **Softmax saturation on large $d_k$.** Without the $1/\sqrt{d_k}$ scale, training can stall for any moderately large $d_k$: gradients vanish through the saturated softmax.
- Decoder cross-attention quality is encoder-bound. A bad encoder bottlenecks the decoder; cross-attention can only retrieve information that's actually in the K, V outputs.
- Exposure bias under teacher forcing. A decoder trained only with ground-truth previous tokens may collapse at inference once it starts seeing its own (possibly wrong) outputs. Scheduled sampling and reinforcement-learning fine-tuning mitigate.
Common mistakes
- Stating multi-head attention has "$h\times$ more parameters than single-head": no, identical. Heads partition $d$; the total is $3d^2$ for $W_Q, W_K, W_V$ (plus $d^2$ for $W_O$) regardless of $h$.
- Forgetting the causal mask in the decoder's first sub-layer — leaks future tokens; the model trivially copies and fails at inference.
- Adding positional encoding after the encoder (or only after layer 1) instead of to the input embedding, right after the embedding lookup and before block 1.
- Writing the score scaling as $1/d_k$ instead of $1/\sqrt{d_k}$; the variance argument requires the square root.
- Saying "cross-attention has $Q, K$ from the decoder and $V$ from the encoder": no. $Q$ comes from the decoder, $K$ AND $V$ from the encoder. (This asymmetry between the query source and the key/value source is the Bahdanau pattern.)
- Confusing soft attention (continuous, differentiable, weighted sum) with hard attention (discrete sample, REINFORCE).
- Claiming the original Transformer used Pre-Norm — it used Post-Norm; modern Transformers (and ViT) switched to Pre-Norm for deep-stack stability.
Shortcuts
- Sub-layer count: encoder block = 2 (MSA, FFN); decoder block = 3 (masked MSA, cross-attn, FFN).
- $N = 6$, $d = 512$, $h = 8$, $d_k = 64$ in the base Transformer.
- Multi-head params = single-head params: heads partition $d$, don't multiply.
- **Scaling: $1/\sqrt{d_k}$, not $1/d_k$.** The variance argument requires the square root.
- Causal mask synonyms: look-ahead, autoregressive, left-to-right. Same thing.
- Cross-attention = Bahdanau attention in Q-K-V form. Memorise the connection.
- Sinusoidal PE can be computed for sequences longer than those seen in training; learned PE cannot (in the trivial implementation).
- Teacher forcing = ground-truth previous tokens (training only); student forcing = predicted previous tokens; inference is always student forcing.
Proofs / Algorithms
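Self-attention is permutation-equivariant. For a permutation matrix $P$: $\text{Attn}(PX) = P\,\text{Attn}(X)$, because $(PXW_Q)(PXW_K)^\top = P\,(QK^\top)\,P^\top$, the row-wise softmax commutes with the permutation, and $P^\top P = I$. Hence shuffling inputs shuffles outputs identically; without positional encoding, the Transformer cannot tell "cat sat" from "sat cat".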
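The equivariance property is easy to confirm numerically. The sketch below uses a single-head self-attention with random weights and no positional encoding; the function names and toy sizes are assumptions for the example.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
n, d = 6, 8                                   # hypothetical toy sizes
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = rng.permutation(n)                     # a random reordering of the tokens
out = self_attention(X, W_q, W_k, W_v)
out_shuffled = self_attention(X[perm], W_q, W_k, W_v)

# Attn(PX) == P Attn(X): shuffling inputs shuffles outputs the same way.
equivariant = np.allclose(out_shuffled, out[perm])
print(equivariant)
```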
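Causal-mask correctness. After adding $M$ with $M_{ij} = -\infty$ for $j > i$ and softmaxing, the row-$i$ distribution has zero mass on $j > i$. So $\text{out}_i = \sum_{j \le i} \alpha_{ij} v_j$: the output at position $i$ depends only on positions $j \le i$. Autoregressive property preserved.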
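Sinusoidal PE: relative position is linear. For frequency $\omega$: $\sin(\omega(pos + k)) = \sin(\omega\, pos)\cos(\omega k) + \cos(\omega\, pos)\sin(\omega k)$. So $PE(pos + k)$'s sine component is a linear combination of $\sin(\omega\, pos)$ and $\cos(\omega\, pos)$, with $k$-dependent coefficients. A linear layer can extract relative position from absolute PEs.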
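The identity says that, for one frequency, moving from $pos$ to $pos + k$ is a fixed $2 \times 2$ rotation of the (sine, cosine) pair. A short numerical check, with hypothetical helper names `pe_pair` and `shift` and an arbitrary example frequency:

```python
import math

def pe_pair(pos, omega):
    # The (sin, cos) pair of the encoding at one frequency.
    return (math.sin(omega * pos), math.cos(omega * pos))

def shift(pair, k, omega):
    # Linear map whose coefficients depend only on the offset k, not on pos.
    s, c = pair
    ck, sk = math.cos(omega * k), math.sin(omega * k)
    return (s * ck + c * sk, c * ck - s * sk)

omega = 1.0 / 10000 ** (2 * 3 / 16)      # example frequency: i = 3, d = 16
for pos, k in [(0, 1), (7, 5), (40, 13)]:
    predicted = shift(pe_pair(pos, omega), k, omega)
    actual = pe_pair(pos + k, omega)
    print(pos, k, all(abs(p - a) < 1e-12 for p, a in zip(predicted, actual)))
```

The same `shift` coefficients work for every `pos`, which is the sense in which relative position is linearly recoverable.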