Inference / Optimization

KV Cache

Advanced [4/5]
Also known as: Key-Value cache, Attention cache, KV store

Definition

The KV cache stores the Key and Value projections of previous tokens during autoregressive generation, avoiding redundant computation. Instead of reprocessing all tokens at each generation step, the model computes K and V only for the new token and reuses the cached projections for everything earlier.

KV caching is essential for fast LLM inference: it cuts the K and V computation per generation step from O(n) to O(1), so the total projection work for an n-token generation drops from O(n²) to O(n).
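To make the definition concrete, here is a minimal single-head sketch of one cached decode step in Python/NumPy. It is illustrative only: the function name decode_step, the cache layout, and the single-head shapes are assumptions, not any particular framework's API.

import numpy as np

def decode_step(x_new, W_q, W_k, W_v, cached_k, cached_v):
    """Attention output for ONE new token, reusing cached K/V."""
    q = x_new @ W_q                      # query for the new token
    k = x_new @ W_k                      # K projection: new token ONLY
    v = x_new @ W_v                      # V projection: new token ONLY

    # Append to the cache; past tokens are never re-projected.
    cached_k = np.vstack([cached_k, k])  # (seq_len, d)
    cached_v = np.vstack([cached_v, v])

    # Attend over the whole cache (causal by construction: the cache
    # holds only past tokens plus the current one).
    scores = q @ cached_k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ cached_v, cached_k, cached_v

d = 8
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
k_cache, v_cache = np.empty((0, d)), np.empty((0, d))
for _ in range(3):                       # e.g. "The", "cat", "sat"
    x = rng.normal(size=(1, d))
    out, k_cache, v_cache = decode_step(x, W_q, W_k, W_v, k_cache, v_cache)
print(k_cache.shape)                     # (3, 8): one cached K row per token

The key point is that W_k and W_v are applied only to the new token; every earlier K and V row comes from the cache.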

Key Concepts

  • Memory trade-off: Uses memory to save computation
  • Incremental: Each step appends the new K, V to the cache (sketched after this list)
  • Per-layer: Separate cache for each transformer layer
  • Memory bound: Cache size limits context length
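
To make the "incremental" and "per-layer" points concrete, the sketch below keeps one growing (K, V) buffer per transformer layer and appends one position per decode step. The KVCache class and its layout are hypothetical, chosen for illustration.

import numpy as np

class KVCache:
    def __init__(self, num_layers, num_heads, head_dim):
        # Per-layer: one growing (K, V) buffer per transformer layer,
        # each of shape (num_heads, seq_len_so_far, head_dim).
        shape = (num_heads, 0, head_dim)
        self.k = [np.empty(shape, dtype=np.float16) for _ in range(num_layers)]
        self.v = [np.empty(shape, dtype=np.float16) for _ in range(num_layers)]

    def append(self, layer, k_new, v_new):
        # Incremental: each decode step appends exactly one position.
        self.k[layer] = np.concatenate([self.k[layer], k_new], axis=1)
        self.v[layer] = np.concatenate([self.v[layer], v_new], axis=1)
        return self.k[layer], self.v[layer]

    def nbytes(self):
        # Memory trade-off: the cache grows linearly with context length.
        return sum(k.nbytes + v.nbytes for k, v in zip(self.k, self.v))

cache = KVCache(num_layers=32, num_heads=32, head_dim=128)
k1 = np.zeros((32, 1, 128), dtype=np.float16)
v1 = np.zeros((32, 1, 128), dtype=np.float16)
cache.append(layer=0, k_new=k1, v_new=v1)
print(cache.k[0].shape)   # (32, 1, 128): one cached position in layer 0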

Examples

Mechanism
How KV Cache Works
GENERATION WITHOUT KV CACHE (naive):

  Step 1: "The"         → compute K,V for [The]
  Step 2: "The cat"     → compute K,V for [The, cat] (redundant!)
  Step 3: "The cat sat" → compute K,V for [The, cat, sat] (redundant!)

  Work per step: O(seq_len) → total O(seq_len²)

GENERATION WITH KV CACHE:

  Step 1: "The"
    - Compute K₁,V₁ for "The"
    - Cache: {K: [K₁], V: [V₁]}
  Step 2: "cat"
    - Compute K₂,V₂ for "cat" ONLY
    - Cache: {K: [K₁,K₂], V: [V₁,V₂]}
    - Attend to cached K,V
  Step 3: "sat"
    - Compute K₃,V₃ for "sat" ONLY
    - Cache: {K: [K₁,K₂,K₃], V: [V₁,V₂,V₃]}
    - Attend to cached K,V

  Work per step: O(1) for K,V compute → total O(seq_len)

SPEEDUP: ~seq_len × faster!
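
The quadratic-versus-linear claim is plain counting, which the snippet below checks for the three-token walkthrough (no assumptions beyond the text above).

# Count K,V projection computations for the 3-token example above.
n = 3                          # "The", "cat", "sat"
naive = sum(range(1, n + 1))   # re-project the whole prefix: 1+2+3 = 6
cached = n                     # project each token exactly once: 3
print(naive, cached)           # 6 3 → O(n²) vs O(n)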
Memory
KV Cache Memory Usage
KV CACHE MEMORY CALCULATION:

For each layer, per token:
  K: (num_heads × head_dim) × bytes_per_param
  V: (num_heads × head_dim) × bytes_per_param

EXAMPLE: LLaMA 7B (FP16)
  - Layers: 32
  - Heads: 32
  - Head dim: 128
  - Bytes: 2 (FP16)

Per token per layer:
  K: 32 × 128 × 2 = 8,192 bytes
  V: 32 × 128 × 2 = 8,192 bytes
  Total: 16,384 bytes

Per token (all layers):
  32 layers × 16,384 = 524,288 bytes ≈ 0.5 MB

For 4096 token context:
  4096 × 0.5 MB = 2 GB just for KV cache!

MEMORY BREAKDOWN:
┌────────────────┬───────────────┐
│ Component      │ 7B Model      │
├────────────────┼───────────────┤
│ Model weights  │ ~14 GB (FP16) │
│ KV cache (4K)  │ ~2 GB         │
│ Activations    │ ~1 GB         │
│ Total          │ ~17 GB        │
└────────────────┴───────────────┘

Long context = huge KV cache = memory issues
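
The arithmetic above is easy to script. The helper below is a sketch (kv_cache_bytes is a made-up name) that reproduces the LLaMA 7B FP16 numbers; the factor of 2 covers K and V.

def kv_cache_bytes(num_layers, num_heads, head_dim, bytes_per_param, seq_len):
    per_token_per_layer = 2 * num_heads * head_dim * bytes_per_param  # K and V
    per_token = num_layers * per_token_per_layer
    return seq_len * per_token

per_token = kv_cache_bytes(num_layers=32, num_heads=32, head_dim=128,
                           bytes_per_param=2, seq_len=1)
print(per_token, "bytes per token")                  # 524288 ≈ 0.5 MB
total = kv_cache_bytes(32, 32, 128, 2, seq_len=4096)
print(total / 2**30, "GB for a 4096-token context")  # 2.0 GB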

Interactive Exercise

Calculate Savings

If generating 100 tokens without a KV cache requires computing 5,050 K,V pairs in total, how many K,V pairs must be computed with a KV cache?

Hint: 1+2+3+...+100 = 5050
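
A two-line check of the hint and of the cached count follows; note that it gives away the answer.

# Hint check, then the cached count: each of the 100 tokens has its K,V
# computed exactly once, so only 100 K,V pairs are needed.
assert sum(range(1, 101)) == 5050   # without cache: 1 + 2 + ... + 100
print(100)                          # with cache: 100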

Pro Tips
  • PagedAttention (vLLM) manages KV cache like virtual memory
  • GQA/MQA reduces KV cache by sharing keys/values across head groups (see the sketch after this list)
  • Quantizing the KV cache (e.g. to FP8) can halve its memory use versus FP16
  • Sliding window attention limits KV cache growth
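
To illustrate the GQA/MQA tip, the sketch below recomputes the per-token cache size when K/V are shared across head groups. The choice of 8 KV heads is an illustrative assumption; the 7B example earlier uses full multi-head attention, where the KV head count equals the query head count of 32.

# KV cache size scales with the number of KV heads, not query heads.
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_param):
    return num_layers * 2 * num_kv_heads * head_dim * bytes_per_param

mha = kv_bytes_per_token(32, num_kv_heads=32, head_dim=128, bytes_per_param=2)
gqa = kv_bytes_per_token(32, num_kv_heads=8,  head_dim=128, bytes_per_param=2)
print(mha // gqa)   # 4: sharing K/V across head groups cuts the cache 4×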

Related Terms