Training / Optimization

Cross-Entropy Loss

Difficulty: Intermediate (3/5)
Also known as: log loss, negative log likelihood (NLL), CE loss

Definition

Cross-entropy loss measures the difference between two probability distributions—the model's predicted distribution and the true distribution. It's the standard loss function for classification tasks and language modeling, heavily penalizing confident wrong predictions.

Virtually every LLM is pretrained with cross-entropy loss on next-token prediction.

Key Concepts

  • Information theory: the expected number of bits (or nats) needed to encode samples from the true distribution using a code built from the predicted one
  • Logarithmic penalty: confident wrong predictions incur very large loss
  • Bounded below: the minimum is 0 for one-hot targets (probability 1 on the correct class); for soft targets the minimum is the entropy of the true distribution
  • Softmax pairing: usually computed together with softmax via the log-sum-exp trick for numerical stability, as the sketch below shows
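A minimal sketch (assuming PyTorch is installed) of why the softmax/log pairing matters: naively taking log(softmax(x)) underflows for large logits, while log_softmax stays finite.

import torch

logits = torch.tensor([1000.0, 0.0, -1000.0])

naive = torch.log(torch.softmax(logits, dim=0))   # softmax saturates, log blows up
stable = torch.log_softmax(logits, dim=0)         # log-sum-exp trick keeps values finite

print(naive)   # tensor([0., -inf, -inf])
print(stable)  # tensor([0., -1000., -2000.])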

Examples

Mathematics
Cross-Entropy Formula
CROSS-ENTROPY LOSS:

  H(p, q) = -Σ p(x) × log(q(x))

Where:
  - p = true distribution (one-hot for classification)
  - q = predicted distribution (from softmax)

SIMPLIFIED FOR CLASSIFICATION (one-hot p):

  Loss = -log(q_correct)

Only the probability assigned to the correct class matters.

EXAMPLE - 3-class classification:

  True label: Class B (one-hot: [0, 1, 0])

  Model prediction (after softmax):
    Class A: 0.1
    Class B: 0.7  ← correct class
    Class C: 0.2

  Loss = -log(0.7) = 0.357

  More confident:  Class B: 0.95 → Loss = -log(0.95) = 0.051
  Nearly wrong:    Class B: 0.05 → Loss = -log(0.05) = 2.996

The penalty grows without bound as the probability assigned to the correct class approaches zero.
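A quick check of the worked example above, in plain Python (standard library only):

import math

for p_correct in (0.7, 0.95, 0.05):
    print(f"P(correct) = {p_correct:.2f} -> loss = {-math.log(p_correct):.3f}")
# P(correct) = 0.70 -> loss = 0.357
# P(correct) = 0.95 -> loss = 0.051
# P(correct) = 0.05 -> loss = 2.996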
LLM Training
Language Model Loss Computation
TRAINING A LANGUAGE MODEL:

  Sentence:  "The quick brown fox"
  Tokenized: ["The", "quick", "brown", "fox"]

Model predicts the next token at each position:

  Position 0: Given ""                 → actual "The"
              P("The")   = 0.02 → Loss = -log(0.02) = 3.91
  Position 1: Given "The"              → actual "quick"
              P("quick") = 0.15 → Loss = -log(0.15) = 1.90
  Position 2: Given "The quick"        → actual "brown"
              P("brown") = 0.40 → Loss = -log(0.40) = 0.92
  Position 3: Given "The quick brown"  → actual "fox"
              P("fox")   = 0.60 → Loss = -log(0.60) = 0.51

  TOTAL LOSS = average = (3.91 + 1.90 + 0.92 + 0.51) / 4 = 1.81

# PyTorch
import torch.nn.functional as F

# F.cross_entropy expects classes in dim 1, so flatten the sequence
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),   # [batch * seq_len, vocab_size]
    target_ids.reshape(-1),           # [batch * seq_len]
    ignore_index=pad_token_id,        # don't count padding
)
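The same arithmetic checked in plain Python; this is a sketch, not tied to any particular model. The last line also demonstrates the perplexity relationship mentioned under Pro Tips.

import math

probs = [0.02, 0.15, 0.40, 0.60]        # P(actual next token) at each position
losses = [-math.log(p) for p in probs]  # per-token cross-entropy
avg = sum(losses) / len(losses)

print([round(l, 2) for l in losses])    # [3.91, 1.9, 0.92, 0.51]
print(round(avg, 2))                    # 1.81
print(round(math.exp(avg), 1))          # 6.1  (perplexity = exp(cross-entropy))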

Interactive Exercise

Calculate Cross-Entropy

A model predicts P(correct_class) = 0.5. What is the cross-entropy loss? What about P = 0.9?

Hint: Loss = -log(P), use natural log
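If you want to verify your arithmetic, a two-line check in plain Python (standard library only):

import math
print(-math.log(0.5))  # 0.693
print(-math.log(0.9))  # 0.105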

Pro Tips
  • Cross-entropy is the log of perplexity (perplexity = exp(loss)), so the two are directly related
  • Label smoothing spreads a small amount of probability over the wrong classes, reducing overconfidence
  • Use log_softmax + NLLLoss (or F.cross_entropy directly) for numerical stability in PyTorch (see the sketch after this list)
  • Focal loss down-weights easy examples to focus training on hard ones
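A short sketch of two of the tips above (PyTorch; the shapes are illustrative assumptions). Note that label_smoothing is a built-in argument of F.cross_entropy in PyTorch 1.10 and later.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)             # [batch, num_classes]
targets = torch.randint(0, 10, (8,))

# log_softmax + NLLLoss is exactly cross_entropy, computed stably
ce = F.cross_entropy(logits, targets)
nll = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
assert torch.allclose(ce, nll)

# Label smoothing: spread 10% of the target mass over all classes
smoothed = F.cross_entropy(logits, targets, label_smoothing=0.1)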
