Training / Optimization

Gradient Descent

Intermediate [3/5]
GD · Optimization · Weight update

Definition

Gradient descent is the optimization algorithm that updates model weights to minimize the loss function. It moves the weights in the direction opposite to the gradient (the direction of steepest decrease), iteratively finding better parameter values.

All modern neural network training relies on variants of gradient descent to learn from data.

Key Concepts

  • Gradient: Direction of steepest increase in loss
  • Learning rate: Step size for each update
  • Stochastic (SGD): Uses mini-batches, not the full dataset (see the sketch after this list)
  • Optimizers: Adam, AdamW improve on basic gradient descent
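
To ground these terms, here is a minimal NumPy sketch of mini-batch SGD on a toy linear-regression loss; the data, learning rate, and batch size are all illustrative, not a recommended recipe:

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))                 # toy inputs
true_w = np.array([2.0, -1.0, 0.5])
y = X @ true_w + 0.1 * rng.normal(size=1000)   # toy targets

w = np.zeros(3)
lr, batch_size = 0.1, 32
for epoch in range(20):
    idx = rng.permutation(len(X))              # reshuffle each epoch
    for start in range(0, len(X), batch_size):
        b = idx[start:start + batch_size]      # sample a mini-batch
        err = X[b] @ w - y[b]
        grad = 2 * X[b].T @ err / len(b)       # gradient of mean squared error
        w -= lr * grad                         # step opposite the gradient

print(w)  # ≈ [2.0, -1.0, 0.5]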

Examples

Intuition
Finding the Valley
GRADIENT DESCENT INTUITION: Imagine standing on a hilly landscape in fog. Goal: find the lowest point (minimum loss).

Strategy:
1. Feel the slope under your feet (compute the gradient)
2. Step downhill (opposite to the gradient)
3. Repeat until the ground is flat (convergence)

LOSS LANDSCAPE: you start high on the slope, where loss is large. The gradient points uphill, so each step in the opposite direction moves you to lower loss, then lower still, until you reach the minimum and stop.

Basic update rule: w_new = w_old - learning_rate × gradient
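
A minimal runnable sketch of this loop, using a toy one-dimensional loss L(w) = (w - 3)²; the loss function, starting point, and learning rate are illustrative:

def gradient(w):
    return 2.0 * (w - 3.0)        # dL/dw for the toy loss L(w) = (w - 3)^2

w = 10.0                          # start high on the hill
learning_rate = 0.1
for step in range(100):
    g = gradient(w)
    if abs(g) < 1e-6:             # flat ground: converged
        break
    w = w - learning_rate * g     # step opposite the gradient

print(w)  # ≈ 3.0, the minimum
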
Variants
Types of Gradient Descent
GRADIENT DESCENT VARIANTS:

1. BATCH GRADIENT DESCENT
   - Uses the entire dataset for each update
   - Accurate but slow and memory-heavy
   - gradient = (1/N) × Σ ∇L(xᵢ)

2. STOCHASTIC GRADIENT DESCENT (SGD)
   - Uses a single sample per update
   - Noisy but fast
   - gradient = ∇L(xᵢ)

3. MINI-BATCH SGD (most common)
   - Uses a batch of samples (e.g., 32, 64, 256)
   - Balances speed and accuracy
   - gradient = (1/B) × Σ ∇L(xᵢ)

4. SGD WITH MOMENTUM
   - Accumulates past gradients for smoother updates
   - velocity = β × velocity + gradient
   - w = w - lr × velocity

5. ADAM (Adaptive Moment Estimation)
   - Adapts the learning rate per parameter
   - Tracks first and second moments of the gradient
   - Most popular for transformers
   - m = β₁ × m + (1 - β₁) × gradient     (first moment: momentum)
   - v = β₂ × v + (1 - β₂) × gradient²    (second moment: squared gradient)
   - w = w - lr × m / (√v + ε)            (bias-correction terms omitted for brevity)

PyTorch usage:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss.backward()        # compute gradients
optimizer.step()       # update weights
optimizer.zero_grad()  # clear gradients
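
To make the Adam equations concrete, here is a from-scratch sketch on the same toy loss L(w) = (w - 3)² used above, with the bias-correction terms included; all hyperparameter values and the step count are illustrative:

import math

def gradient(w):
    return 2.0 * (w - 3.0)                    # dL/dw for L(w) = (w - 3)^2

w, m, v = 0.0, 0.0, 0.0
lr, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-8
for t in range(1, 201):
    g = gradient(w)
    m = beta1 * m + (1 - beta1) * g           # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g       # second moment (squared grad)
    m_hat = m / (1 - beta1 ** t)              # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    w -= lr * m_hat / (math.sqrt(v_hat) + eps)

print(w)  # ≈ 3.0, close to the minimum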

Interactive Exercise

Learning Rate Impact

What happens if the learning rate is too high? Too low? What's a good strategy for choosing it?
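
One way to explore the answer yourself: run a toy loop with different step sizes on L(w) = w² and watch the behavior (the specific values are illustrative):

for lr in (0.01, 0.5, 1.1):
    w = 10.0
    for _ in range(50):
        w = w - lr * 2.0 * w   # gradient of L(w) = w^2 is 2w
    print(lr, w)
# lr = 0.01 -> crawls toward 0     (too low: slow convergence)
# lr = 0.5  -> reaches 0 at once   (well-tuned for this loss)
# lr = 1.1  -> blows up            (too high: divergence)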

Pro Tips
  • AdamW (Adam with decoupled weight decay) is the default for transformer training
  • Learning rate schedulers (cosine, linear decay) improve convergence
  • Gradient accumulation simulates larger batch sizes (sketch after this list)
  • Loss plateaus often indicate learning rate should decrease
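
A minimal PyTorch sketch combining a cosine schedule with gradient accumulation, assuming a `model` and a `train_loader` defined elsewhere (both hypothetical here, as are the hyperparameter values):

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
accum_steps = 4                          # effective batch = 4 × loader batch size

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(train_loader):
    loss = F.cross_entropy(model(inputs), targets)
    (loss / accum_steps).backward()      # scale so accumulated grads average out
    if (step + 1) % accum_steps == 0:
        optimizer.step()                 # one update per accum_steps batches
        scheduler.step()                 # cosine decay of the learning rate
        optimizer.zero_grad()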

Related Terms