Attention Mechanism

Advanced [4/5]
Scaled dot-product attention

Definition

Attention is a mechanism that allows a model to focus on relevant parts of the input when producing each part of the output. It computes weighted combinations of input representations, where weights indicate the "relevance" or "attention" given to each input element.

Attention is the core innovation behind transformers and modern LLMs, enabling them to capture long-range dependencies in text.

Key Concepts

  • Query, Key, Value: Three projections of input used to compute attention
  • Attention weights: How much to "attend" to each position
  • Weighted sum: Output is weighted combination of values
  • Softmax normalization: Weights sum to 1 for each query
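
To make the projections concrete, here is a minimal sketch of where Q, K, and V typically come from: three learned linear projections of the same input. The layer sizes and variable names are illustrative assumptions, not taken from the original.

import torch
import torch.nn as nn

d_model = 16                      # illustrative embedding size (assumption)
x = torch.randn(1, 10, d_model)   # (batch, seq_len, d_model)

# Three learned projections of the same input
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

Q, K, V = W_q(x), W_k(x), W_v(x)  # each (1, 10, d_model)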

Examples

Intuition
Attention as Relevance
Sentence: "The cat sat on the mat because it was tired"

When processing "it", attention asks: "What does 'it' refer to?"

Attention weights:
  "The"     → 0.05 (low relevance)
  "cat"     → 0.60 ← HIGH (likely referent!)
  "sat"     → 0.08
  "on"      → 0.02
  "the"     → 0.03
  "mat"     → 0.15 (possible but less likely)
  "because" → 0.02
  "it"      → 0.03
  "was"     → 0.01
  "tired"   → 0.01

The attention mechanism learns that "it" should attend strongly to "cat" to understand the reference. This enables:
  • Coreference resolution
  • Long-range dependencies
  • Contextual understanding
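
One sanity check on the numbers above: attention weights come from a softmax, so they always sum to 1. A toy sketch with made-up raw scores (the values are invented for illustration, not computed by a model):

import torch
import torch.nn.functional as F

# Invented raw similarity scores for the 10 tokens above (illustration only)
scores = torch.tensor([1.0, 3.5, 1.5, 0.1, 0.5, 2.1, 0.1, 0.5, -0.5, -0.5])
weights = F.softmax(scores, dim=-1)
print(weights)        # larger scores receive larger weights
print(weights.sum())  # tensor(1.), by construction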
Mathematics
Scaled Dot-Product Attention
ATTENTION FORMULA:

  Attention(Q, K, V) = softmax(QK^T / √d_k) × V

WHERE:
  Q   = Query matrix (what we're looking for)
  K   = Key matrix (what we're looking at)
  V   = Value matrix (information to retrieve)
  d_k = dimension of keys (for scaling)

STEP BY STEP:
  1. Compute similarity: QK^T  → how similar each query is to each key
  2. Scale by 1/√d_k           → prevents dot products from growing too large
  3. Normalize: softmax(...)   → converts scores to a probability distribution
  4. Aggregate: × V            → weighted sum of values

# PyTorch implementation
import math

import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    # Steps 1-2: similarity scores, scaled by √d_k
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get -1e9 so softmax assigns them ~0 weight
        scores = scores.masked_fill(mask == 0, -1e9)
    # Step 3: normalize scores into attention weights
    weights = F.softmax(scores, dim=-1)
    # Step 4: weighted sum of values
    return torch.matmul(weights, V), weights
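
A quick shape check of the function above, using arbitrary example tensors (batch of 2, 5 positions, d_k = 8; all shapes are illustrative):

Q = torch.randn(2, 5, 8)   # (batch, seq_len, d_k)
K = torch.randn(2, 5, 8)
V = torch.randn(2, 5, 8)

out, weights = attention(Q, K, V)
print(out.shape)            # torch.Size([2, 5, 8])
print(weights.shape)        # torch.Size([2, 5, 5])
print(weights.sum(dim=-1))  # every row of weights sums to 1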

Interactive Exercise

Predict Attention Patterns

In the sentence "The lawyer asked the witness what she saw", when processing "she", which word should receive the highest attention?

Pro Tips
  • Attention patterns can be visualized to interpret model behavior
  • Causal attention (used for generation) masks future tokens; see the mask sketch after this list
  • O(n²) complexity limits context length, which motivates research on efficient attention
  • Different heads learn different attention patterns
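
To illustrate the causal-attention tip, a lower-triangular mask can be passed to the attention function from the Mathematics section above (the sequence length here is arbitrary):

import torch

seq_len = 5
# Lower-triangular mask: position i may attend only to positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len))

Q = K = V = torch.randn(1, seq_len, 8)  # illustrative shapes
out, weights = attention(Q, K, V, mask=mask)
print(weights[0])  # upper triangle is ~0: no attention to future tokens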
