Training / Model Compression

Distillation

Intermediate [3/5]
Knowledge distillation · Model distillation · Teacher-student training

Definition

Distillation trains a smaller "student" model to mimic a larger "teacher" model's behavior. Instead of learning only from hard labels, the student learns from the teacher's soft probability distributions, which capture the teacher's "dark knowledge" about relationships between classes.

Distillation enables deploying smaller, faster models that retain much of the larger model's capability.

Key Concepts

  • Soft targets: Teacher's probability distribution over outputs
  • Temperature: Softens probabilities to reveal dark knowledge
  • Dark knowledge: Relationships between outputs (e.g., "cat" similar to "dog")
  • Combined loss: Mix of hard label loss and soft target loss

Examples

Concept
Hard vs Soft Labels
WHY SOFT LABELS HELP:

HARD LABEL (one-hot):
  Image of cat → [0, 0, 1, 0, 0]
                  dog bird cat fish car
  Tells student: "It's a cat. Period."

TEACHER'S SOFT PREDICTION:
  Image of cat → [0.15, 0.02, 0.75, 0.05, 0.03]
                  dog   bird  cat   fish  car
  Tells student:
  - It's mostly a cat (0.75)
  - Has some dog features (0.15): "kind of similar to a cat!"
  - Not car-like at all (0.03)

DARK KNOWLEDGE:
  The teacher learned:
  - Cats and dogs share features
  - Cats and cars are very different
  - Some birds look like cats (maybe the ears?)
  This structural knowledge transfers to the student!

WITH TEMPERATURE (τ):
  softmax(logits / τ)
  τ = 1: normal probabilities
  τ = 2: [0.20, 0.10, 0.50, 0.12, 0.08] (softer)
  τ = 5: [0.22, 0.18, 0.28, 0.18, 0.14] (very soft)
  Higher temperature reveals more relationships.
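To make the temperature effect concrete, here is a minimal PyTorch sketch that applies temperature-scaled softmax to hypothetical logits. The logit values are illustrative choices (not from the original text), picked so that τ = 1 roughly reproduces the distribution above; the softened distributions will be close to, but not exactly, the rounded numbers shown.

import torch
import torch.nn.functional as F

# Hypothetical teacher logits for [dog, bird, cat, fish, car];
# softmax at tau = 1 gives roughly [0.15, 0.02, 0.75, 0.05, 0.03].
logits = torch.tensor([2.0, 0.0, 3.6, 0.9, 0.4])

for tau in (1.0, 2.0, 5.0):
    probs = F.softmax(logits / tau, dim=-1)
    print(f"tau = {tau}:", [round(p, 2) for p in probs.tolist()])

As τ grows, the gap between the top class and the rest shrinks, exposing the relative similarities the teacher has learned.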
Implementation
Distillation Loss Function
DISTILLATION LOSS:

  L = α × L_hard + (1 - α) × L_soft

  Where:
    L_hard = CrossEntropy(student_logits, true_labels)
    L_soft = KL_Div(softmax(student_logits / T),
                    softmax(teacher_logits / T)) × T²

PYTORCH IMPLEMENTATION:

import torch.nn.functional as F

def distillation_loss(
    student_logits,
    teacher_logits,
    labels,
    temperature=4.0,
    alpha=0.5,
):
    # Soft target loss: KL divergence between the temperature-scaled
    # student and teacher distributions.
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean')
    soft_loss = soft_loss * (temperature ** 2)  # rescale gradient magnitude

    # Hard target loss against the true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combined loss
    return alpha * hard_loss + (1 - alpha) * soft_loss

LLM DISTILLATION APPROACHES:
  1. Logit matching (as above)
  2. Response-based (generate text, compare outputs)
  3. Feature-based (match hidden states)
  4. Synthetic data (teacher generates training data)

EXAMPLE: Distilling GPT-4 into a smaller model
  - Generate 1M examples from GPT-4
  - Fine-tune a 7B model on those examples
  - The student mimics the teacher's outputs
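For context, here is a minimal sketch of a training loop that uses the distillation_loss function above. The toy models, random data, and hyperparameters are placeholders chosen for illustration, not part of the original text.

import torch
import torch.nn as nn

# Toy setup: distill a wide MLP teacher into a narrow student on
# random data (all sizes here are illustrative).
torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 5))
student = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 5))

x = torch.randn(64, 20)              # fake input batch
labels = torch.randint(0, 5, (64,))  # fake hard labels

teacher.eval()  # the teacher stays frozen during distillation
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)

for step in range(100):
    with torch.no_grad():
        teacher_logits = teacher(x)
    student_logits = student(x)
    loss = distillation_loss(student_logits, teacher_logits, labels,
                             temperature=4.0, alpha=0.5)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Note that only the student's parameters are updated; the teacher runs under torch.no_grad() because it serves purely as a source of soft targets.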

Interactive Exercise

Temperature Effect

The teacher outputs logits [5.0, 1.0, 0.5]. Compute the softmax at temperature T = 1 and T = 2. Which temperature reveals more information about the relationships between classes?
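A short script to check your answer, written in PyTorch for consistency with the examples above:

import torch
import torch.nn.functional as F

# Softmax of the teacher logits at the two temperatures.
logits = torch.tensor([5.0, 1.0, 0.5])
for T in (1.0, 2.0):
    probs = F.softmax(logits / T, dim=-1)
    print(f"T = {T}:", [round(p, 3) for p in probs.tolist()])
# T = 1.0: [0.971, 0.018, 0.011]
# T = 2.0: [0.806, 0.109, 0.085]

At T = 2 the non-top classes receive noticeably more probability mass, so the softened distribution reveals more about how the teacher relates the classes.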

Pro Tips
  • Temperature 2-4 typically works well for distillation
  • A larger teacher-student capacity gap may require more training data
  • For LLMs, synthetic data distillation is very effective
  • DistilBERT retained 97% of BERT's performance at about 60% of its size

Related Terms