
Computer Vision

CSE471
Prof. Makarand Tapaswi + Prof. Charu Sharma · Spring 2025-26 · 4 credits

Modern Transformer Upgrades


Intuition

The 2017 Transformer is still the architecture every state-of-the-art LLM and ViT runs on today. *Broadly speaking* — but "broadly speaking" hides a lot of work. Almost every component has been quietly replaced between 2017 and 2025. Not torn down, refined. Same skeleton, modernised organs. The slide pack is named *"ViT-5: Vision Transformers for The Mid-2020s"* — and the framing is exactly that. Nine specific upgrades, grouped into five themes: stability (Pre-Norm, LayerScale), normalisation (RMSNorm, QK-Norm), efficiency (Flash Attention, KV-cache, GQA), positional encoding (RoPE), and attention hygiene (Registers). Knowing what each one does and why is the entire exam.

Explanation

Why does any of this matter? When you train a 100-billion-parameter Transformer on a trillion tokens, every one of these is the difference between *"it diverges after 50k steps"* and *"it converges to GPT-4."* These are not nice-to-haves — they are the engineering reality of modern Transformers.

Change #1 — Post-Norm → Pre-Norm. *2017 original:* LayerNorm *after* the residual addition, $x_{l+1} = \mathrm{LN}(x_l + \mathrm{Sublayer}(x_l))$. *Modern:* LayerNorm *before* the sublayer, inside the residual branch: $x_{l+1} = x_l + \mathrm{Sublayer}(\mathrm{LN}(x_l))$. Why this matters: in Pre-Norm, the residual stream flows directly from input to output, *untouched by any normalisation*. This direct path is the residual stream — the highway gradients use to flow back to early layers. Post-Norm squashes every residual through a LayerNorm before passing it on; gradients have to push through a normaliser at every layer, and very deep Transformers (24+ layers) become unstable to train without careful warmup. Exam line: *Pre-Norm improves training stability by giving gradients a direct, unnormalised path through the residual stream.* Every modern LLM (GPT, LLaMA, PaLM, Gemma) is Pre-Norm.
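
A minimal PyTorch sketch of the two layouts (an illustration, not any particular library's implementation; `sublayer` stands for the attention or MLP branch):

```python
import torch.nn as nn

class PostNormBlock(nn.Module):
    """2017 original: LayerNorm AFTER the residual addition."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(dim)

    def forward(self, x):
        # The residual stream itself is squashed through LN every block.
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    """Modern default: LayerNorm BEFORE the sublayer, inside the branch."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer, self.norm = sublayer, nn.LayerNorm(dim)

    def forward(self, x):
        # x reaches the output untouched; gradients get a direct path.
        return x + self.sublayer(self.norm(x))
```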

Change #2 — LayerNorm → RMSNorm. *LayerNorm* centres then scales: $\mu = \frac{1}{d}\sum_i x_i$; $\sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2 + \epsilon}$; $\hat{x}_i = (x_i - \mu)/\sigma$; output $y_i = \gamma_i \hat{x}_i + \beta_i$. Two learnable params per dim. *RMSNorm* drops the centring step: $\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}$; $\hat{x}_i = x_i/\mathrm{RMS}(x)$; output $y_i = \gamma_i \hat{x}_i$. **No mean subtraction, no bias $\beta$.** Slightly faster (fewer ops) and empirically performs as well as or better than LayerNorm. LLaMA, Gemma, and many recent LLMs use RMSNorm.
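
A minimal sketch in the style of LLaMA-family implementations (the exact `eps` placement is an assumption; implementations vary slightly):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm: scale by the root-mean-square. No centring, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # gamma only, no beta

    def forward(self, x):
        # RMS(x) = sqrt(mean(x_i^2) + eps), taken over the feature dim
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
```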

The geometric quiz — the slide-7 exam target. *Picture a 2D scatter of points $x = (x_1, x_2)$.* LayerNorm: centring ($x_i \mapsto x_i - \mu$) shifts every point so the projection onto the all-ones vector is zero — every point lands on the line $x_1 + x_2 = 0$ (the anti-diagonal). Scaling ($\hat{x}_i = (x_i - \mu)/\sigma$) puts each point at a fixed distance from the origin on that line. Result: all points lie on two specific symmetric points, $(1, -1)$ and $(-1, 1)$. RMSNorm: no centring. Scaling ($\hat{x}_i = x_i/\mathrm{RMS}(x)$) sends every point to the full circle of radius $\sqrt{2}$ (radius $\sqrt{d}$ in $d$ dims).

The geometric quiz answer in one box. *LayerNorm:* points → a small circle (a $(d-2)$-sphere) on the $(d-1)$-dim zero-mean hyperplane (1 magnitude DOF + 1 mean DOF removed = 2 constraints). *RMSNorm:* points → the full $(d-1)$-dim sphere of radius $\sqrt{d}$ (only 1 magnitude constraint). Memorise: *"LayerNorm collapses onto the anti-diagonal line; RMSNorm spreads around the full unit circle."* Dropping the mean doesn't hurt because the LLM has plenty of other ways to centre activations (e.g. bias terms in linear layers).
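
A NumPy check of the quiz (the sample points are illustrative, not from the slides):

```python
import numpy as np

def layernorm(x):
    mu = x.mean()
    sigma = np.sqrt(((x - mu) ** 2).mean())
    return (x - mu) / sigma

def rmsnorm(x):
    return x / np.sqrt((x ** 2).mean())

points = np.array([[3.0, 1.0], [5.0, 0.0], [2.0, -7.0], [-1.0, 4.0]])
for p in points:
    print(p, "LN ->", layernorm(p), "RMS ->", rmsnorm(p))
# Every point with p[0] > p[1] maps under LayerNorm to exactly (1, -1);
# points with p[0] < p[1] map to (-1, 1): the whole scatter collapses onto
# two points on the anti-diagonal. RMSNorm outputs all have norm sqrt(2)
# but keep their individual directions around the circle.
```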

Change #3 — LayerScale (Touvron et al., CaiT, 2021). In very deep Transformers, the residual contribution from each sublayer can either explode or vanish. LayerScale inserts a per-channel learnable diagonal scaling into the residual branch: $x \leftarrow x + \mathrm{diag}(\lambda)\,\mathrm{Sublayer}(\mathrm{LN}(x))$, where $\lambda$ is initialised to a tiny value ($\approx 10^{-5}$). At the start of training, every sublayer contributes almost nothing — the network behaves like a stack of identity functions, and gradients flow trivially. As training proceeds, useful directions in $\lambda$ grow. Exam line: *LayerScale gives each sublayer a learnable "volume knob" per channel, initialised near zero so training begins from a near-identity network.*
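
A minimal module sketch (the `init_value` of 1e-5 is one of the CaiT choices; the paper tunes it with depth):

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Per-channel learnable gain on a sublayer's residual contribution."""
    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        self.lam = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        return self.lam * x  # elementwise = diag(lambda) applied per channel

# Usage inside a Pre-Norm residual block:
#   x = x + layerscale(sublayer(norm(x)))
# At init every sublayer contributes ~1e-5 of its output: near-identity.
```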

Change #4 — QK-Norm (Henry et al., 2020). Attention has a hidden numerical hazard: the logits $q^\top k$ can grow large in magnitude (especially at long sequences or with large head dims), pushing softmax into saturation. The original paper's fix — divide by $\sqrt{d_k}$ — is *approximately correct but not enough at scale*. QK-Norm applies LayerNorm (or RMSNorm) to $q$ and $k$ independently before the dot product: $\tilde{q} = \mathrm{Norm}(q)$, $\tilde{k} = \mathrm{Norm}(k)$, then $A = \mathrm{softmax}(\tilde{q}^\top \tilde{k}/\sqrt{d_k})$. Result: $\tilde{q}, \tilde{k}$ are normalised along their head dim, dot products are bounded per pair (Cauchy–Schwarz) → softmax sees a well-conditioned input — no extreme values, no saturation, no exploding gradients.
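
A sketch of the placement (RMS normalisation chosen here for concreteness; real implementations often add a learnable gain per head):

```python
import torch
import torch.nn.functional as F

def qk_norm_attention(q, k, v, eps=1e-6):
    """Attention with QK-Norm: normalise q and k along the head dim
    BEFORE the dot product, so scores cannot blow up."""
    q = q / torch.sqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps)
    k = k / torch.sqrt(k.pow(2).mean(dim=-1, keepdim=True) + eps)
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5  # bounded: no saturation
    return F.softmax(scores, dim=-1) @ v
```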

Change #5 — Registers (Darcet et al., 2023: "Vision Transformers Need Registers"). Researchers training large ViTs noticed strange high-norm activations at uninformative patches — chunks of sky, blurred backgrounds, low-content regions. The model was using these patches as scratch space, dumping unrelated global information there. With no dedicated place to put global state, the network hijacks the least informative tokens and corrupts them. Attention maps then look chaotic.

The register fix. Append extra trainable tokens (typically 4–8) to the input sequence: $[\mathrm{CLS};\, p_1, \dots, p_N;\, r_1, \dots, r_R]$. These have no associated patch — just learnable vectors *appended at the input and discarded at the output*. The network now has dedicated scratch space; high-norm activations cluster on the register tokens rather than corrupting real patches. Attention maps become clean; dense-prediction downstream tasks improve. Slide line verbatim: *Registers = "garbage collector" tokens; they prevent attention peaks at blank areas by giving the model dedicated scratch space.*
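
A sketch of the mechanism (hypothetical wrapper; `encoder` stands for any stack of Transformer blocks, not a specific library API):

```python
import torch
import torch.nn as nn

class ViTWithRegisters(nn.Module):
    """Append R learnable register tokens to the sequence, run the
    encoder, discard the registers at the output."""
    def __init__(self, dim, encoder, num_registers=4):
        super().__init__()
        self.encoder = encoder
        self.registers = nn.Parameter(0.02 * torch.randn(1, num_registers, dim))

    def forward(self, tokens):              # tokens: (B, 1+N, dim) = [CLS; patches]
        B = tokens.shape[0]
        regs = self.registers.expand(B, -1, -1)
        x = torch.cat([tokens, regs], dim=1)  # [CLS; patches; registers]
        x = self.encoder(x)
        return x[:, : tokens.shape[1]]        # registers dropped at the output
```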

Change #6 — Flash Attention (Dao et al., NeurIPS 2022). The single most consequential efficiency improvement to Transformers in the last 5 years. An EXACT attention algorithm — same outputs as standard attention, no approximation — that is dramatically faster and uses far less memory.

The attention bottleneck Flash Attention solves. Standard attention computes $\mathrm{softmax}(QK^\top/\sqrt{d_k})\,V$. The intermediate $QK^\top$ has shape $N \times N$. For $N = 8192$, this is a 64M-entry matrix. The bottleneck on modern GPUs isn't compute — it's moving the $N \times N$ matrix between HBM (large, slow) and SRAM (small, ~100× faster). Standard attention writes the full matrix to HBM, reads it back for softmax, writes again, reads for the value multiplication. Lots of slow memory traffic.

The Flash Attention trick — tile and never materialise. Process $Q, K, V$ in small blocks that fit in SRAM; compute softmax block-by-block using online streaming; never write the full $N \times N$ matrix to HBM. The mathematical trick is the *online softmax algorithm*: keep a running max $m$ and running denominator $\ell$, rescale the partial output when the max changes. Loop: load tile of $K, V$ → compute partial scores → update $m, \ell, O$ → discard tile. **Memory cost drops from $O(N^2)$ to $O(N)$. Wall-clock 2–4× faster, sometimes more. Two exam emphases:** *(1) exact* — bitwise close to standard attention; *(2) IO-aware* — the speedup comes from minimising HBM ↔ SRAM data movement, not from doing less computation.
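
The streaming trick in isolation, as a NumPy sketch of one attention row (tile size and shapes are illustrative; real Flash Attention does this blockwise on GPU):

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, tile=4):
    """Compute softmax(scores) @ values one tile at a time, never holding
    all the exponentials at once -- the core of Flash Attention."""
    m = -np.inf                        # running max
    ell = 0.0                          # running denominator
    out = np.zeros(values.shape[-1])   # running (unnormalised) output
    for i in range(0, len(scores), tile):
        s, v = scores[i:i + tile], values[i:i + tile]
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)      # rescale old stats for the new max
        ell = ell * scale + np.exp(s - m_new).sum()
        out = out * scale + np.exp(s - m_new) @ v
        m = m_new
    return out / ell

scores, values = np.random.randn(16), np.random.randn(16, 8)
w = np.exp(scores - scores.max())
ref = (w / w.sum()) @ values           # standard (materialised) softmax
assert np.allclose(online_softmax_weighted_sum(scores, values), ref)
```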

Change #7 — RoPE (Rotary Position Embedding, Su et al., 2021). For a token at position $m$, embedding dim $d$, frequencies $\theta_i = 10000^{-2i/d}$: a $2 \times 2$ rotation matrix acts on each pair of dimensions: $R(m\theta_i) = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}$. Rotated query/key: $\tilde{q}_m = R_{\Theta,m}\, q_m$, $\tilde{k}_n = R_{\Theta,n}\, k_n$. Critical property: $\tilde{q}_m^\top \tilde{k}_n = q_m^\top R_{\Theta,n-m}\, k_n$ — the attention score depends only on *relative* position. Three advantages over learned absolute PEs: extrapolates to longer sequences than training (no fixed-size lookup table); encodes relative position directly (depends only on $n - m$); multiplicative, not additive (rotation inside attention, not added to the input embedding). Variants: 2D-RoPE for images, M-RoPE for video.
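
A minimal RoPE sketch plus a check of the relative-position property (pairing of consecutive dimensions is one common convention; implementations differ in the pairing):

```python
import torch

def apply_rope(x, positions, base=10000.0):
    """Rotate consecutive dimension pairs of x by position-dependent angles.
    x: (seq, dim) with dim even; positions: (seq,)."""
    seq, dim = x.shape
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = positions[:, None].float() * freqs[None, :]   # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin    # 2x2 rotation on each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

# Relative-position check: shifting ALL positions by the same offset
# leaves the q-k scores unchanged -- absolute position never appears.
q, k = torch.randn(5, 64), torch.randn(5, 64)
pos = torch.arange(5)
s0 = apply_rope(q, pos) @ apply_rope(k, pos).T
s1 = apply_rope(q, pos + 100) @ apply_rope(k, pos + 100).T
assert torch.allclose(s0, s1, atol=1e-4)
```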

Change #8 — KV Caching. A *pure-inference* optimisation. When generating autoregressively, naïvely you re-compute attention over the entire history at every step — $O(t^2)$ work per token, $O(N^3)$ for length-$N$ generation. **Observation: the $K$ and $V$ projections for past tokens never change** — once you compute $k_t = W_K x_t$, $v_t = W_V x_t$, they're fixed. So *store them in a cache*; at each new step: compute $q_t, k_t, v_t$ only for the new token, append the new $(k_t, v_t)$ to the cache, compute attention of the new $q_t$ against all cached $k$s and $v$s. Per-token work: $O(t)$. Total generation: $O(N^2)$. This is the reason chatbot streaming is fast. *Trade-off:* memory — the cache grows linearly with sequence length, per layer, per head, per batch element. For long contexts and large models, KV-cache memory dominates.
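
A single-head, unbatched sketch of the generation loop (hypothetical weights; real stacks keep one cache per layer and head):

```python
import torch
import torch.nn.functional as F

def generate_step(x_new, Wq, Wk, Wv, cache):
    """One autoregressive step with a KV-cache: only the NEW token's
    q, k, v are computed; past k, v are reused from the cache."""
    q = x_new @ Wq                                      # (1, d), new token only
    cache["k"] = torch.cat([cache["k"], x_new @ Wk])    # append new k
    cache["v"] = torch.cat([cache["v"], x_new @ Wv])    # append new v
    K, V = cache["k"], cache["v"]                       # (t, d): never recomputed
    attn = F.softmax(q @ K.T / K.shape[-1] ** 0.5, dim=-1)
    return attn @ V                                     # O(t) work at step t

d = 16
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
cache = {"k": torch.zeros(0, d), "v": torch.zeros(0, d)}
for t in range(5):                                      # 5 generation steps
    out = generate_step(torch.randn(1, d), Wq, Wk, Wv, cache)
print(cache["k"].shape)   # torch.Size([5, 16]) -- cache grows linearly
```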

Change #9 — Grouped-Query Attention (GQA). In MHA you have $H$ query heads, $H$ key heads, $H$ value heads → KV cache is proportional to $H$. For a 70B-param LLM with $H = 64$, the KV cache becomes enormous at long contexts. Two extremes: *MHA* — $H$ separate heads (max capacity, max cache). *MQA (Multi-Query Attention)* — all query heads share one $K$ and one $V$ head (massive savings, factor of $H$, but quality drops). GQA — the middle ground: divide the $H$ query heads into $G$ groups; each group shares one $K$ head and one $V$ head. $H$ query heads, $G$ key heads, $G$ value heads with $1 < G < H$. **KV cache is $H/G$ times smaller** than full MHA. Quality nearly as good as MHA. LLaMA-2, Gemma, Mistral, and most recent open LLMs use GQA.

The GQA trade-off in one line. *MHA:* max quality, max KV cache. *MQA:* min KV cache, some quality loss. *GQA:* tunable middle ground via $G$. Typical: $G = 8$, $8\times$ smaller cache.
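
A shape-level GQA sketch (tiny illustrative sizes; the point is that the cache only ever stores $G$ KV heads, while queries keep all $H$ heads):

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """GQA: H query heads share G KV heads (H % G == 0).
    q: (H, N, d); k, v: (G, N, d) -- only G heads are cached."""
    H, G = q.shape[0], k.shape[0]
    # Each group of H//G query heads attends to the same K/V head.
    # repeat_interleave is just a view for the math; storage stays G heads.
    k = k.repeat_interleave(H // G, dim=0)   # (H, N, d)
    v = v.repeat_interleave(H // G, dim=0)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

H, G, N, d = 8, 2, 10, 16                    # 4 query heads per KV group
out = grouped_query_attention(torch.randn(H, N, d),
                              torch.randn(G, N, d), torch.randn(G, N, d))
print(out.shape)                             # torch.Size([8, 10, 16])
```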

The modern Transformer block — putting it all together. Self-attention sublayer: Pre-Norm (RMSNorm) → $Q, K, V$ projections (GQA: $K, V$ have fewer heads) → QK-Norm → apply RoPE to $Q, K$ → append $K, V$ to the KV-cache → Flash Attention → LayerScale → residual add. MLP sublayer: the same RMSNorm + LayerScale wrap around SwiGLU. For ViTs: add register tokens to the input, discard at output. Use 2D-RoPE for image patches. Every single line is different from the 2017 original.

Definitions

  • Residual stream. The unbroken identity path in a Pre-Norm Transformer; gradients flow directly through it from output back to input.
  • Pre-Norm. LayerNorm placed before the sublayer; the residual stream is never normalised. Modern default.
  • Post-Norm. The 2017 original: LayerNorm after the residual addition; needs careful warmup at depth.
  • RMSNorm. Drop-the-mean variant of LayerNorm: $y = \gamma \odot x/\mathrm{RMS}(x)$, $\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}$. No mean subtraction, no bias $\beta$. Cheaper and as effective.
  • LayerScale. Per-channel learnable diagonal $\mathrm{diag}(\lambda)$ on each sublayer's residual contribution, initialised $\approx 10^{-5}$. Makes deep nets train as near-identity at init.
  • QK-Norm. Apply LayerNorm (or RMSNorm) to $Q$ and $K$ separately before the attention dot product. Prevents softmax saturation at long sequences / large head dims.
  • Registers. Extra learnable tokens appended to the input with no positional encoding; act as a global scratchpad. Without them, the model corrupts uninformative patches. Discarded at output.
  • Flash Attention. Exact, IO-aware attention algorithm. Tiles $Q, K, V$ into SRAM-sized blocks; computes a streaming online softmax; never materialises the $N \times N$ matrix in HBM. Memory $O(N)$. 2–4× faster.
  • Online softmax. Stream the softmax: keep a running max $m$, denominator $\ell$, output $O$; rescale by $e^{m_{\text{old}} - m_{\text{new}}}$ when the max grows. Mathematically equivalent to standard softmax, computable tile-by-tile.
  • RoPE (Rotary Position Embedding). Rotate pairs of dimensions of $Q$ and $K$ by an angle proportional to position. The attention dot product depends only on relative position $n - m$.
  • KV-cache. Cache of $(K, V)$ for past tokens during autoregressive generation. Per-step compute $O(t)$; total generation $O(N^2)$. Memory grows linearly.
  • MHA / GQA / MQA. Multi-Head Attention: $H$ separate KV heads. Grouped-Query: $G$ shared KV groups. Multi-Query: one KV head for all query heads ($G = 1$). KV cache ratio: $H : G : 1$.

Formulas

  • Pre-Norm block: $x_{l+1} = x_l + \mathrm{Sublayer}(\mathrm{LN}(x_l))$; Post-Norm: $x_{l+1} = \mathrm{LN}(x_l + \mathrm{Sublayer}(x_l))$.
  • LayerNorm: $y = \gamma \odot \frac{x - \mu}{\sigma} + \beta$, with $\mu = \frac{1}{d}\sum_i x_i$, $\sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2 + \epsilon}$.
  • RMSNorm: $y = \gamma \odot \frac{x}{\mathrm{RMS}(x)}$, $\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}$.
  • LayerScale: $x \leftarrow x + \mathrm{diag}(\lambda)\,\mathrm{Sublayer}(\mathrm{LN}(x))$, $\lambda$ init $\approx 10^{-5}$.
  • QK-Norm attention: $A = \mathrm{softmax}\big(\mathrm{Norm}(Q)\,\mathrm{Norm}(K)^\top/\sqrt{d_k}\big)V$.
  • RoPE property: $(R_{m\theta}\, q)^\top (R_{n\theta}\, k) = q^\top R_{(n-m)\theta}\, k$.
  • KV-cache size: $2 \cdot L \cdot H_{kv} \cdot T \cdot d_h$ elements; GQA reduction ratio $H/G$.

Derivations

Why Pre-Norm is more stable. Under Pre-Norm, $x_{l+1} = x_l + \mathrm{Sublayer}(\mathrm{LN}(x_l))$. The residual is *never* normalised — it passes through untouched. Variance of the residual stream grows additively across layers (bounded growth). Under Post-Norm, $x_{l+1} = \mathrm{LN}(x_l + \mathrm{Sublayer}(x_l))$. The combined signal is normalised every layer; for very deep stacks, the LayerNorm interacts with the residual in ways that produce ill-conditioned Jacobians — gradients can explode or vanish at depth. Pre-Norm trains 100+-layer Transformers without warmup; Post-Norm needs careful LR warmup even at 24 layers.

The geometric LN-vs-RMSNorm derivation. LayerNorm in $\mathbb{R}^d$: projects onto the hyperplane $\sum_i x_i = 0$ (dimension $d-1$); normalises onto the sphere of radius $\sqrt{d}$ within that hyperplane. So the output lives on a $(d-2)$-sphere embedded in a $(d-1)$-dim hyperplane. RMSNorm: skip the projection; scales so $\|\hat{x}\| = \sqrt{d}$. Output lives on the full $(d-1)$-sphere in $\mathbb{R}^d$. *Two DOFs removed by LayerNorm vs one by RMSNorm.*

RoPE preserves relative position. After rotating $q$ by angle $m\theta$ and $k$ by angle $n\theta$: $(R_{m\theta}\, q)^\top (R_{n\theta}\, k) = q^\top R_{(n-m)\theta}\, k$. So the dot product depends only on the difference $n - m$, not absolute positions. This is why RoPE generalises to longer sequences than seen at training — absolute position never appears.

The online softmax trick (Flash Attention's mathematical core). Want $O = \mathrm{softmax}(s)\,V$ without storing the full score vector $s$. Process in tiles. Keep a running max $m$, a running sum $\ell$, and a running output $O$. When a new tile produces a larger max $m' > m$: rescale $\ell$ and $O$ by $e^{m - m'}$. Add the new tile's contributions. The final output equals the standard-softmax output, byte-for-byte (up to floating-point reduction order).

KV-cache asymptotics. Without cache: at step $t$, compute attention over all $t$ tokens → $O(t^2)$ per step (full-attention recompute) → $O(N^3)$ total generation. With cache: at step $t$, compute $q_t, k_t, v_t$ for the new token + attention of 1 query against $t$ cached keys → $O(t)$ per step → $O(N^2)$ total. *Per-token: quadratic → linear; total: cubic to quadratic.*

GQA cache scaling. KV-cache size $\propto$ #KV heads $\times$ sequence length. MHA: $H$ KV heads. GQA: $G$ KV heads. Ratio: $H/G$. For LLaMA-2-70B with $H = 64$, $G = 8$: $8\times$ smaller cache. At long contexts (32k+), this is the difference between *fits on one GPU* and *needs model parallelism*.
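
A back-of-envelope check in Python, assuming LLaMA-2-70B-style shapes (80 layers, head dim 128, 64 query heads, 8 KV heads under GQA; per single sequence, FP16 — batch size multiplies this further):

```python
layers, head_dim, seq_len, bytes_fp16 = 80, 128, 32_768, 2

def kv_cache_gb(kv_heads):
    # factor 2: one K and one V entry per token, per layer, per KV head
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_fp16 / 1e9

print(kv_cache_gb(64))   # MHA: ~86 GB per sequence at 32k context
print(kv_cache_gb(8))    # GQA: ~10.7 GB -- the 8x reduction
```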

Examples

  • LLaMA 3 stack. Pre-Norm + RMSNorm + RoPE + GQA + Flash Attention + SwiGLU MLP. No LayerScale (decoder-only LLMs are not as deep as some ViT variants).
  • Geometric quiz in 2D. Points with $x_1 > x_2$, e.g. $(3,1)$, $(5,0)$, $(2,-7)$. LayerNorm: for the first, $\mu = 2$, centred $(1,-1)$, $\sigma = 1$ → $(1,-1)$. Second: $\mu = 2.5$, centred $(2.5,-2.5)$, $\sigma = 2.5$ → $(1,-1)$. Third: $\mu = -2.5$, centred $(4.5,-4.5)$, $\sigma = 4.5$ → $(1,-1)$. **All three points map to the same $(1,-1)$** on the anti-diagonal! (Modulo the learnable $\gamma, \beta$, which can recover individuality.) RMSNorm: scale to RMS 1 — each point goes to a different location on the circle of radius $\sqrt{2}$.
  • GQA in LLaMA-2-70B. $H = 64$ query heads grouped into $G = 8$ KV groups → KV cache shrinks $8\times$ with negligible quality loss.
  • Flash Attention v2 memory. At sequence length 32k, FP16 attention: standard attention needs $\sim 2$ GB per head just for the $32768 \times 32768$ attention matrix. Flash Attention keeps only $O(N)$ running statistics — a few MB per layer. >100× memory savings.
  • LayerScale init. $\lambda_0 = 10^{-5}$ means the sublayer initially contributes $10^{-5}$ of its output to the residual — effectively zero. After training, $\lambda$ grows by orders of magnitude on useful channels.
  • Registers visual. Pre-register: heatmap of patch attention norms shows 5–10 spikes on patches of blank sky in a ViT-L. Post-register: spikes are absorbed by the 4 register tokens; patch attention is uniform on the object.
  • RoPE wavelengths. For $d = 128$: $\theta_0 = 1$ (period $2\pi \approx 6$ tokens), $\theta_{63} = 10000^{-126/128} \approx 1.2 \times 10^{-4}$ (period $\approx 5 \times 10^4$ tokens). Different pairs encode position at different frequencies — fine-to-coarse positional decomposition.

Diagrams

  • Pre-Norm vs Post-Norm block. Two side-by-side block diagrams. Pre-Norm: residual line shown as a straight unbroken arrow from input to output; sublayer + LN forms a side branch. Post-Norm: residual goes into a final LN block; the LN is *on* the residual path.
  • LayerNorm vs RMSNorm geometric. 2D scatter: LayerNorm collapses points onto the anti-diagonal at a fixed distance from the origin (2 points: $(1,-1)$ and $(-1,1)$). RMSNorm: same points spread around the full circle.
  • LayerScale architecture. A standard residual block with a $\mathrm{diag}(\lambda)$ element between the sublayer and the residual add; annotation: $x + \mathrm{diag}(\lambda)\,\mathrm{Sublayer}(\mathrm{LN}(x))$.
  • QK-Norm placement. Q and K projection → LN block → dot product → softmax → ×V.
  • Registers in a ViT sequence. Input: $[\mathrm{CLS};\, p_1, \dots, p_N;\, r_1, \dots, r_R]$; output: take CLS for classification; registers discarded.
  • Flash Attention tiling. Show the $Q, K, V$ matrices partitioned into blocks; each iteration loads one tile into SRAM, computes a partial output, updates running stats, moves on. Annotate: *the full $N \times N$ matrix never lives in HBM.*
  • RoPE rotation. A pair of embedding dimensions shown as a 2D vector; rotation by angle $m\theta_i$ shown as an arc.
  • GQA visualisation. Query heads split into groups (e.g. $H = 8$ heads → $G = 4$ groups of 2); each group shares one $K$ head and one $V$ head. Tree diagram showing the sharing.
  • KV-cache buildup. Three autoregressive steps: at $t=1$ the cache holds $(k_1, v_1)$; at $t=2$ it holds $(k_{1:2}, v_{1:2})$; at $t=3$ it adds $(k_3, v_3)$. Only the new token's $(q_t, k_t, v_t)$ is computed fresh each step.

Edge cases

  • Pre-Norm output magnitude can drift with depth (the unbroken residual accumulates); a final LayerNorm before the classification head is still required.
  • Flash Attention requires GPU support — Ampere+ for v1, Hopper+ for v3. Older GPUs cannot run it.
  • **GQA with $G = 1$ (MQA limit)** loses quality on tasks requiring fine-grained discrimination. GQA with $G = 8$ is the empirical sweet spot.
  • KV-cache memory dominates at long contexts. LLaMA-2-70B at 32k context: KV cache > 20 GB; serving memory growth is dominated by the KV cache, not by weights or activations.
  • LayerScale on shallow models is unnecessary — only useful for very deep nets where sublayer outputs can blow up the residual.
  • **RoPE base frequency $10000$** can be insufficient for very long contexts; later models in the LLaMA line increased it; some long-context models use even larger bases.
  • Register tokens have no PE — adding positional embedding to them defeats the point (they're position-free scratchpad).

Common mistakes

  • Calling RMSNorm *"LayerNorm without bias"* — it also drops the mean subtraction. Both ops are removed.
  • Treating RoPE as *additive* — it's multiplicative (a rotation matrix applied to $Q$ and $K$).
  • Conflating Flash Attention with KV-cache — different problems (memory layout vs incremental compute).
  • Stating registers are positional — they have no positional encoding; that's the point.
  • Saying Flash Attention is *approximate* — it's exact. Same outputs as standard attention.
  • Writing the Pre-Norm equation as $\mathrm{LN}(x + \mathrm{Sublayer}(x))$ — that's Post-Norm. Pre-Norm puts LN on the sublayer's input.
  • Treating GQA as a quality-degraded MQA — it's the sweet spot between MHA (max quality) and MQA (min cache).
  • Confusing the *online softmax* trick with *online learning* — they're unrelated; "online" here means *streaming* (process one tile at a time).

Shortcuts

  • Nine modern upgrades: Pre-Norm, RMSNorm, LayerScale, QK-Norm, Registers, Flash Attention, RoPE, KV-cache, GQA.
  • Pre-Norm = unbroken residual; Post-Norm normalises the combined signal.
  • RMSNorm = LayerNorm minus mean minus bias. Same final shape, cheaper.
  • Geometric quiz answer: LayerNorm → small circle on anti-diagonal; RMSNorm → full sphere.
  • RoPE is multiplicative (rotation). Old PE is additive. RoPE attention depends only on relative position.
  • Flash Attention is exact + IO-aware. Speed comes from avoiding HBM↔SRAM traffic, not from doing less compute.
  • **KV-cache:** $O(N^2)$ total generation instead of $O(N^3)$; memory grows linearly.
  • **GQA reduces the KV cache by $H/G$** (typically $8\times$). LLaMA-2/3, Gemma, Mistral all use GQA.

Proofs / Algorithms

RoPE encodes relative position. $\tilde{q}_m = R_{m\theta}\, q$, $\tilde{k}_n = R_{n\theta}\, k$. Then $\tilde{q}_m^\top \tilde{k}_n = q^\top R_{m\theta}^\top R_{n\theta}\, k = q^\top R_{(n-m)\theta}\, k$ because $R_{m\theta}^\top = R_{-m\theta}$ and rotation matrices compose by adding angles. So the dot product depends only on the *difference* $n - m$, hence on *relative* position. Absolute positions never appear in the score — RoPE generalises beyond training-time lengths.

Online softmax correctness. Given scores $s_1, \dots, s_N$, define $m = \max_i s_i$, $\ell = \sum_i e^{s_i - m}$, $\mathrm{softmax}(s)_i = e^{s_i - m}/\ell$. Processing in tiles: maintain $(m_j, \ell_j, O_j)$ after seeing the first $j$ elements. When element $s_{j+1}$ arrives: $m_{j+1} = \max(m_j, s_{j+1})$; $\ell_{j+1} = \ell_j\, e^{m_j - m_{j+1}} + e^{s_{j+1} - m_{j+1}}$. The rescaling factor $e^{m_j - m_{j+1}}$ corrects the previously accumulated terms for the new max. After processing all elements, $(m_N, \ell_N)$ equal the global max and denominator, and the output equals the standard softmax exactly (up to floating-point reduction order).

GQA KV-cache reduction. Cache size $= 2 \cdot L \cdot H_{kv} \cdot T \cdot d_h$ elements ($L$ layers, $H_{kv}$ KV heads, $T$ tokens, $d_h$ per-head dim, factor 2 for $K$ and $V$). MHA: $H_{kv} = H$. GQA: $H_{kv} = G$. Ratio: $H/G$. For $H = 64$, $G = 8$: 8× smaller cache, all other factors equal.