Model Architecture / Attention

Causal Masking

Advanced [4/5]
Causal attention mask · Autoregressive mask · Future masking

Definition

Causal masking is the technique of preventing tokens from attending to future positions in the sequence during self-attention. This ensures the model can only "see" past context when predicting each token, maintaining the autoregressive property.

Without causal masking, the model could "cheat" during training by reading the very tokens it is supposed to predict, learning nothing useful for generation.

Key Concepts

  • Lower triangular: The mask is a lower triangular matrix of 1s; each row allows attention only up to the current position
  • Applied to attention: Scores at masked positions are set to -∞ before the softmax, so their weights become exactly 0
  • Training efficiency: Enables parallel training on full sequences in a single forward pass
  • Causality: The present depends only on the past, never the future (demonstrated in the sketch after this list)
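
A small sketch of the causality property using PyTorch's built-in scaled_dot_product_attention (available in PyTorch ≥ 2.0); the shapes and seed here are illustrative, not from the original:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 8)            # 4 tokens, dim 8 (Q = K = V = x for brevity)
out1 = F.scaled_dot_product_attention(x, x, x, is_causal=True)

x2 = x.clone()
x2[0, 3] += 1.0                     # perturb the LAST token only
out2 = F.scaled_dot_product_attention(x2, x2, x2, is_causal=True)

# Outputs at positions 0-2 are unchanged: they never attend to position 3
print(torch.allclose(out1[0, :3], out2[0, :3]))   # True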

Examples

Visualization
Causal Attention Mask
CAUSAL MASK VISUALIZATION:

Sequence: "The cat sat on"

ATTENTION MASK (lower triangular):

         The  cat  sat  on
  The  [  1    0    0    0 ]
  cat  [  1    1    0    0 ]
  sat  [  1    1    1    0 ]
  on   [  1    1    1    1 ]

  1 = can attend (allowed)
  0 = masked (blocked)

HOW IT'S APPLIED:
  1. Compute attention scores: QK^T
  2. Add the mask (allowed stays, masked → -∞)
  3. Softmax: e^(-∞) = 0

ATTENTION SCORES BEFORE MASK:

         The  cat  sat  on
  The  [ 5.2  3.1  2.4  1.8 ]
  cat  [ 4.1  4.5  3.2  2.1 ]
  sat  [ 3.8  4.2  5.1  2.9 ]
  on   [ 2.5  3.1  3.8  4.2 ]

AFTER MASK (-∞ added to the upper triangle):

         The  cat  sat  on
  The  [ 5.2   -∞   -∞   -∞ ]
  cat  [ 4.1  4.5   -∞   -∞ ]
  sat  [ 3.8  4.2  5.1   -∞ ]
  on   [ 2.5  3.1  3.8  4.2 ]

AFTER SOFTMAX (future positions become 0):

          The   cat   sat   on
  The  [ 1.00  0     0     0    ]
  cat  [ 0.40  0.60  0     0    ]
  sat  [ 0.16  0.24  0.60  0    ]
  on   [ 0.08  0.15  0.31  0.46 ]
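
The weights in the last table can be reproduced (rounded to two decimals) by running the mask-then-softmax pipeline on the scores above:

import torch
import torch.nn.functional as F

scores = torch.tensor([
    [5.2, 3.1, 2.4, 1.8],
    [4.1, 4.5, 3.2, 2.1],
    [3.8, 4.2, 5.1, 2.9],
    [2.5, 3.1, 3.8, 4.2],
])

# Mask the strict upper triangle (future positions) with -inf
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
masked = scores.masked_fill(mask, float('-inf'))

# Each row matches the AFTER SOFTMAX table above
print(F.softmax(masked, dim=-1).round(decimals=2))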
Implementation
Causal Mask in Code
PYTORCH IMPLEMENTATION:

import math
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V, d_k):
    seq_len = Q.size(-2)

    # 1. Compute scaled attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # 2. Create the causal mask (True above the diagonal = future positions)
    mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool),
        diagonal=1,
    )

    # 3. Apply the mask (set future positions to -inf)
    scores = scores.masked_fill(mask, float('-inf'))

    # 4. Softmax (future positions get weight 0)
    attn_weights = F.softmax(scores, dim=-1)

    # 5. Weighted sum of values
    return torch.matmul(attn_weights, V)

WHY THIS ENABLES PARALLEL TRAINING:

Without a causal mask:
  - Tokens would have to be generated sequentially to avoid leaking the future
  - Training could not be parallelized across positions

With a causal mask:
  - The entire sequence is processed in a single forward pass
  - Each position sees only valid (past) context
  - Attention cost per step is O(seq_len²), not the O(seq_len³) of seq_len sequential passes
  - Produces the same attention pattern as sequential generation
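
A quick sanity check of the function above, with illustrative shapes (batch of 2, 4 tokens, d_k = 8):

Q = torch.randn(2, 4, 8)
K = torch.randn(2, 4, 8)
V = torch.randn(2, 4, 8)

out = causal_attention(Q, K, V, d_k=8)
print(out.shape)   # torch.Size([2, 4, 8])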

Interactive Exercise

Understand Masking

In a sequence of 4 tokens, what percentage of attention connections are masked (blocked) by the causal mask?
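
One way to check the arithmetic (the answer follows, so pause here to work it out first): the mask blocks the strict upper triangle, n(n−1)/2 of the n² connections.

n = 4
blocked = n * (n - 1) // 2    # strict upper triangle of the n x n mask
print(blocked / (n * n))      # 0.375 → 37.5% of connections are masked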

Pro Tips
  • Flash Attention optimizes causal attention for GPU memory
  • Prefix LM allows bidirectional attention on prompt, causal on generation
  • Sliding window attention limits context but maintains causality (a mask sketch follows this list)
  • Causal mask is typically fused with attention kernel for efficiency
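
A sliding-window causal mask can be built by combining the causal constraint with a locality band. This is a minimal sketch; the window parameter and function name are illustrative, not from the original:

import torch

def sliding_window_causal_mask(seq_len, window):
    # True = blocked. Block future positions...
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # ...and past positions more than `window` steps back
    too_old = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=-window)
    return future | too_old

print(sliding_window_causal_mask(5, window=2).int())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [1, 0, 0, 1, 1],
#         [1, 1, 0, 0, 1],
#         [1, 1, 1, 0, 0]])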
