Training / Normalization

Batch Normalization

Intermediate [3/5]
Also known as: BatchNorm, BN

Definition

Batch normalization normalizes layer inputs by subtracting the batch mean and dividing by the batch standard deviation. This stabilizes training by reducing internal covariate shift—the change in input distributions as parameters update.

Note: Transformers typically use Layer Normalization instead, which normalizes across features rather than across the batch.

Key Concepts

  • Normalize across batch: Uses statistics computed from the current mini-batch
  • Learnable parameters: Scale (γ) and shift (β) restore the layer's expressiveness
  • Running statistics: Tracks running mean/variance for use at inference (a minimal sketch follows this list)
  • Batch dependency: Requires sufficiently large batch sizes for stable statistics
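
A minimal NumPy sketch of the forward pass described above, including the running statistics used at inference. The function name batchnorm_forward and the momentum-based update rule are illustrative assumptions, not a specific library's API.

# Python (NumPy) - illustrative sketch
import numpy as np

def batchnorm_forward(x, gamma, beta, running_mean, running_var,
                      momentum=0.1, eps=1e-5, training=True):
    """BatchNorm over the batch axis of a [batch, features] array (sketch)."""
    if training:
        mu = x.mean(axis=0)                    # per-feature batch mean
        var = x.var(axis=0)                    # per-feature batch variance
        # Track running statistics for later use at inference
        running_mean = (1 - momentum) * running_mean + momentum * mu
        running_var = (1 - momentum) * running_var + momentum * var
    else:
        mu, var = running_mean, running_var    # use tracked statistics

    x_hat = (x - mu) / np.sqrt(var + eps)      # normalize
    y = gamma * x_hat + beta                   # learnable scale and shift
    return y, running_mean, running_var

# 4 samples, 2 features
x = np.array([[2.0, 1.0], [4.0, 2.0], [6.0, 3.0], [8.0, 4.0]])
y, rm, rv = batchnorm_forward(x, gamma=np.ones(2), beta=np.zeros(2),
                              running_mean=np.zeros(2), running_var=np.ones(2))
print(y[:, 0])  # ≈ [-1.34, -0.45, 0.45, 1.34]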

Examples

Mathematics
Batch Normalization Formula
BATCH NORMALIZATION FORMULA:

Given: mini-batch of m examples {x₁, x₂, ..., xₘ}

STEP 1: Compute batch statistics
  μ_B  = (1/m) × Σ xᵢ             # batch mean
  σ²_B = (1/m) × Σ (xᵢ - μ_B)²    # batch variance

STEP 2: Normalize
  x̂ᵢ = (xᵢ - μ_B) / √(σ²_B + ε)

STEP 3: Scale and shift (learnable)
  yᵢ = γ × x̂ᵢ + β

Where:
  - ε = small constant for numerical stability (e.g., 1e-5)
  - γ = learned scale parameter (initialized to 1)
  - β = learned shift parameter (initialized to 0)

WHY SCALE AND SHIFT?
  Without γ, β: output forced to mean = 0, var = 1
  With γ, β: network can learn any mean/variance
  If the original distribution is optimal: γ = σ, β = μ recovers it

EXAMPLE:
  Batch activations: [2.0, 4.0, 6.0, 8.0]
  μ_B = 5.0
  σ_B = √5 ≈ 2.24
  Normalized: [-1.34, -0.45, 0.45, 1.34]   (mean = 0, std ≈ 1)
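
A quick sketch reproducing the worked example with PyTorch's nn.BatchNorm1d; in training mode it normalizes with the biased batch variance, which is what the numbers above assume.

# PyTorch check of the example above
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=1)    # γ initialized to 1, β to 0
bn.train()                             # normalize with batch statistics

x = torch.tensor([[2.0], [4.0], [6.0], [8.0]])  # batch of 4, one feature
print(bn(x).flatten())                 # ≈ [-1.34, -0.45, 0.45, 1.34]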
Comparison
BatchNorm vs LayerNorm
NORMALIZATION COMPARISON:

Input shape: [batch=4, seq=3, dim=2]

BATCH NORMALIZATION: normalize across the batch dimension for each feature
  Batch 1: [[1, 2], [3, 4], [5, 6]]  ↓
  Batch 2: [[2, 3], [4, 5], [6, 7]]  ↓  average
  Batch 3: [[3, 4], [5, 6], [7, 8]]  ↓  across
  Batch 4: [[4, 5], [6, 7], [8, 9]]  ↓  batches

  Problems for sequences:
  - Different sequence lengths
  - Small batches → noisy statistics
  - Can't use with batch size 1

LAYER NORMALIZATION (used in transformers): normalize across the feature dimension for each sample
  Batch 1: [[1, 2], [3, 4], [5, 6]] → normalize each position's mean/std across dim=2

  Benefits:
  - Independent of batch size
  - Works with variable-length sequences
  - Consistent train/inference behavior

# PyTorch
nn.BatchNorm2d(num_features)      # BatchNorm (CNNs)
nn.LayerNorm(normalized_shape)    # LayerNorm (Transformers)

RULE OF THUMB:
  - CNNs, vision: BatchNorm
  - Transformers, NLP: LayerNorm
  - RNNs: LayerNorm
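
A small sketch contrasting the two on the same [batch, seq, dim] tensor. Note that nn.BatchNorm1d expects the feature/channel dimension second, so the tensor is transposed for it; the tensor values here are arbitrary.

# PyTorch
import torch
import torch.nn as nn

x = torch.randn(4, 3, 2)               # [batch=4, seq=3, dim=2]

# LayerNorm: statistics over the last (feature) dimension, per position
ln = nn.LayerNorm(normalized_shape=2)
y_ln = ln(x)                           # works even with batch size 1

# BatchNorm1d: statistics over batch and sequence positions, per channel;
# expects [batch, channels, length], so move dim into the channel slot
bn = nn.BatchNorm1d(num_features=2)
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)

print(y_ln.shape, y_bn.shape)          # both torch.Size([4, 3, 2])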

Interactive Exercise

Choose Normalization

Why do transformers use LayerNorm instead of BatchNorm? What problems would BatchNorm cause?

Pro Tips
  • BatchNorm's running statistics can cause train/eval mismatches (see the sketch after this list)
  • RMSNorm (which skips mean subtraction) is gaining popularity for efficiency
  • Pre-norm (normalizing before attention) vs post-norm (after) affects training stability
  • Group Normalization is a compromise that works with small batches
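
A minimal sketch of the first tip: in train() mode BatchNorm normalizes with the current batch's statistics, while eval() uses the running estimates, so the same input can produce noticeably different outputs early in training. The input values are arbitrary.

# PyTorch
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=3)
x = torch.randn(8, 3) * 5 + 10         # statistics far from the default running values

bn.train()
y_train = bn(x)                        # uses this batch's mean/variance

bn.eval()
y_eval = bn(x)                         # uses running_mean / running_var

# After a single update the running estimates still lag the batch statistics,
# so the two outputs differ
print((y_train - y_eval).abs().max())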

Related Terms