Model Architecture / Transformer Components

Multi-Head Attention

Advanced [4/5]
MHA (multi-head attention)

Definition

Multi-head attention runs multiple attention operations in parallel, each with different learned projections. The outputs are concatenated and projected, allowing the model to attend to information from different representation subspaces simultaneously.

Different heads can learn to focus on different aspects: one might capture syntax, another semantics, another positional patterns.
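
As a concrete illustration, here is a minimal NumPy sketch of the forward pass just described, assuming d_model = 512 and 8 heads; the weight names (Wq, Wk, Wv, Wo) and shapes are illustrative, not taken from any particular library.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, num_heads=8):
    seq_len, d_model = X.shape
    d_head = d_model // num_heads                        # e.g. 512 // 8 = 64

    # Project once, then split the last dimension into heads.
    Q = (X @ Wq).reshape(seq_len, num_heads, d_head)
    K = (X @ Wk).reshape(seq_len, num_heads, d_head)
    V = (X @ Wv).reshape(seq_len, num_heads, d_head)

    outputs = []
    for h in range(num_heads):
        scores = Q[:, h] @ K[:, h].T / np.sqrt(d_head)   # scaled dot-product
        weights = softmax(scores, axis=-1)               # attention weights
        outputs.append(weights @ V[:, h])                # per-head output (seq_len, d_head)

    concat = np.concatenate(outputs, axis=-1)            # (seq_len, d_model)
    return concat @ Wo                                   # final output projection

# Usage with random (untrained) weights:
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 512))
Wq, Wk, Wv, Wo = (rng.normal(size=(512, 512)) * 0.02 for _ in range(4))
print(multi_head_attention(X, Wq, Wk, Wv, Wo).shape)     # (10, 512)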

Key Concepts

  • Multiple heads: Typically 8-96 parallel attention operations
  • Subspace projections: Each head projects to lower dimensions
  • Concatenation: Head outputs combined into single representation
  • Specialization: Different heads learn different patterns

Examples

Architecture
Multi-Head Attention Structure
INPUT: X (sequence of token embeddings, dim=512)

SINGLE-HEAD ATTENTION:
  X → Q, K, V (each dim=512) → Attention → Output (512)

MULTI-HEAD ATTENTION (8 heads):
  X → split into 8 parallel streams
    Head 1: X → Q₁, K₁, V₁ (dim=64) → Attention₁ → Out₁ (64)
    Head 2: X → Q₂, K₂, V₂ (dim=64) → Attention₂ → Out₂ (64)
    Head 3: X → Q₃, K₃, V₃ (dim=64) → Attention₃ → Out₃ (64)
    ...
    Head 8: X → Q₈, K₈, V₈ (dim=64) → Attention₈ → Out₈ (64)

CONCATENATE: [Out₁ | Out₂ | ... | Out₈] = 512 dimensions
FINAL PROJECTION: W_O × Concat → Output (512)

512 / 8 = 64 dimensions per head, so the total parameter count is the same as single-head attention.
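
A quick sanity check of the dimension arithmetic above, assuming the same 512-dimensional model with 8 heads (variable names are illustrative):

d_model, num_heads = 512, 8
d_head = d_model // num_heads
print(d_head)  # 64 dimensions per head
# The Q/K/V projections across all heads stack to the same size as one
# full-width head, so splitting into heads adds no parameters.
qkv_params_multi = 3 * num_heads * (d_model * d_head)
qkv_params_single = 3 * (d_model * d_model)
print(qkv_params_multi == qkv_params_single)  # True
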
Specialization
What Different Heads Learn
Sentence: "The cat that sat on the mat was fluffy" HEAD 1 (Syntactic - Subject-Verb): "cat" ←→ "was" (subject-verb agreement) "cat" ←→ "sat" (relative clause verb) HEAD 2 (Positional - Adjacent): Each token → previous/next token Local context patterns HEAD 3 (Semantic - Noun-Adjective): "cat" ←→ "fluffy" (attribute relationship) HEAD 4 (Coreference): "that" ←→ "cat" (reference resolution) HEAD 5 (Prepositional): "sat" ←→ "on" ←→ "mat" (PP attachment) INSIGHT: Different heads specialize automatically during training! This emerges from the learning process, not explicit design. Research shows heads learn interpretable patterns: - Syntax heads - Semantic heads - Positional heads - Copy heads (for repetition)

Interactive Exercise

Calculate Head Dimensions

If a model has embedding dimension 768 and uses 12 attention heads, what is the dimension per head?
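
To verify your answer, the general rule is d_head = d_model / num_heads; a two-line check (this prints the result, so work it out yourself first):

d_model, num_heads = 768, 12
assert d_model % num_heads == 0  # head dim must divide the model dim evenly
print(d_model // num_heads)      # dimension per head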

Pro Tips
  • More heads ≠ always better; diminishing returns after ~12-16
  • Head dimension must divide evenly into model dimension
  • Some heads can be pruned with minimal quality loss
  • Grouped-query attention (GQA) shares K,V across head groups
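
To make the last tip concrete, here is a minimal NumPy sketch of grouped-query attention, assuming 8 query heads that share 2 K/V heads (shapes and names are illustrative, not any specific model's configuration):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(X, Wq, Wk, Wv, num_q_heads=8, num_kv_heads=2):
    seq_len, d_model = X.shape
    d_head = d_model // num_q_heads          # e.g. 512 // 8 = 64
    group = num_q_heads // num_kv_heads      # 4 query heads share each K/V head

    Q = (X @ Wq).reshape(seq_len, num_q_heads, d_head)
    K = (X @ Wk).reshape(seq_len, num_kv_heads, d_head)   # fewer K/V heads
    V = (X @ Wv).reshape(seq_len, num_kv_heads, d_head)

    outs = []
    for h in range(num_q_heads):
        kv = h // group                      # query head h reads its group's K/V
        scores = Q[:, h] @ K[:, kv].T / np.sqrt(d_head)
        outs.append(softmax(scores) @ V[:, kv])
    return np.concatenate(outs, axis=-1)     # (seq_len, d_model); W_O omitted

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 512))
Wq = rng.normal(size=(512, 512)) * 0.02
Wk = rng.normal(size=(512, 128)) * 0.02      # 2 K/V heads × 64 dims each
Wv = rng.normal(size=(512, 128)) * 0.02
print(grouped_query_attention(X, Wq, Wk, Wv).shape)  # (10, 512)

The practical benefit is a smaller K/V cache at inference time: only 2 sets of keys and values are stored instead of 8, while the 8 query heads are kept.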

Related Terms