2026-01-22 Sequential Multi-Stage Knowledge Distillation
Sequential Knowledge Distillation is the answer.
arXiv: 2601.15657 • Published: January 22, 2026
The Knowledge Distillation Integration Problem
Knowledge distillation compresses large teacher models into compact student models by transferring “knowledge”—but defining “knowledge” turns out to be surprisingly complex:
Response-based KD: Student mimics teacher’s output logits (soft labels)
Feature-based KD: Student matches teacher’s intermediate layer activations
Relation-based KD: Student replicates teacher’s pairwise sample relationships
Each method captures different aspects of the teacher’s behavior. The natural question: Can we combine multiple KD methods to transfer more complete knowledge?
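The three knowledge types above can be sketched as loss functions. This is a minimal NumPy illustration, not the paper's implementation; the feature- and relation-based losses are simplified (e.g., FitNets-style setups typically add a learned regressor when teacher and student feature dimensions differ):

```python
import numpy as np

def softmax(z, T=1.0):
    """Temperature-scaled softmax over the last axis."""
    e = np.exp((z - z.max(axis=-1, keepdims=True)) / T)
    return e / e.sum(axis=-1, keepdims=True)

def response_kd_loss(student_logits, teacher_logits, T=4.0):
    """Response-based KD: KL divergence between softened output distributions."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    return float(np.mean(np.sum(p_t * (np.log(p_t + 1e-12) - np.log(p_s + 1e-12)), axis=-1)))

def feature_kd_loss(student_feat, teacher_feat):
    """Feature-based KD (FitNets-style): MSE between intermediate activations."""
    return float(np.mean((student_feat - teacher_feat) ** 2))

def relation_kd_loss(student_emb, teacher_emb):
    """Relation-based KD (RKD-style): match normalized pairwise distance structure."""
    def pdist(x):
        d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
        return d / (d.mean() + 1e-12)  # normalize by mean distance
    return float(np.mean((pdist(student_emb) - pdist(teacher_emb)) ** 2))
```

Each loss is zero when student and teacher agree on the relevant signal (outputs, features, or pairwise relations) and positive otherwise, which is what makes them candidates for staged integration.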
Previous attempts:
Joint training: Mix multiple distillation losses (response + feature + relation)
Problem: Hyperparameter tuning nightmare (balancing 3+ loss terms)
Result: Often worse than single-method KD
Multi-teacher distillation: Ensemble of teachers, each using different KD method
Problem: Requires training multiple teachers (expensive)
Result: Limited by teacher diversity, not method diversity
The core failure mode: Catastrophic forgetting—when learning knowledge from Method B, the student forgets knowledge from Method A.
The Sequential Multi-Stage Framework (SMSKD)
This paper’s breakthrough: Learn KD methods sequentially, not jointly—but anchor each stage to prevent forgetting.
The architecture:
Stage 1: Train student with KD Method A (e.g., response-based)
Teacher: Large model (e.g., ResNet-50)
Student: Compact model (e.g., MobileNet-v2)
Loss: Standard KD loss (teacher-student divergence)
Output: Student model S₁
Stage 2: Train S₁ with KD Method B (e.g., feature-based)
Teacher: Same large model
Reference model: Frozen copy of S₁ from Stage 1
Loss: L_total = L_KD(teacher) + λ_ref * L_anchor(reference)
Output: Student model S₂
Stage 3: Train S₂ with KD Method C (e.g., relation-based)
Teacher: Same large model
Reference model: Frozen copy of S₂ from Stage 2
Loss: L_total = L_KD(teacher) + λ_ref * L_anchor(reference)
Output: Student model S₃
Key insight: The reference model from the previous stage acts as an anchor—preventing the student from forgetting knowledge learned in earlier stages while integrating new knowledge from the current stage.
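The per-stage objective can be sketched as follows. This assumes a logit-space anchor loss for illustration (the paper's exact anchor formulation may differ); `stage_loss` and the smoothing constant are illustrative choices, not the authors' API:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kl(p, q):
    """Per-sample KL divergence KL(p || q)."""
    return np.sum(p * (np.log(p + 1e-12) - np.log(q + 1e-12)), axis=-1)

def stage_loss(student_logits, teacher_logits, reference_logits, lam_ref=0.5):
    """L_total = L_KD(teacher) + lam_ref * L_anchor(reference).

    The reference model is the frozen student from the previous stage;
    L_anchor pulls the current student toward its own earlier predictions,
    preventing catastrophic forgetting while L_KD transfers new knowledge."""
    p_t = softmax(teacher_logits)
    p_r = softmax(reference_logits)  # frozen copy: no gradients flow here
    p_s = softmax(student_logits)
    l_kd = float(np.mean(kl(p_t, p_s)))
    l_anchor = float(np.mean(kl(p_r, p_s)))
    return l_kd + lam_ref * l_anchor
```

In Stage 1 there is no previous stage, so `lam_ref` is effectively zero; later stages trade off new teacher knowledge against fidelity to the frozen snapshot.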
The Adaptive Weighting Mechanism
Not all samples benefit equally from distillation. SMSKD introduces Teacher Class Probability (TCP) weighting to dynamically adjust the reference loss per sample:
Intuition: When the teacher is confident about a sample (high max probability), the student should prioritize learning from the teacher. When the teacher is uncertain, the student should prioritize preserving previous knowledge (high reference weight).
The weighting formula:
λ_ref(sample) = (1 - TCP) * λ_base
where TCP = max(softmax(teacher_logits)), the teacher's softmax probability for its top predicted class.
Example scenarios:
Easy sample: Teacher predicts class A with 95% confidence
TCP = 0.95
λ_ref = (1 − 0.95) * λ_base = 0.05 * λ_base
Effect: Low reference weight → prioritize learning from teacher
Hard sample: Teacher predicts class A with 55% confidence (ambiguous)
TCP = 0.55
λ_ref = (1 − 0.55) * λ_base = 0.45 * λ_base
Effect: High reference weight → preserve previous knowledge
Rationale: For ambiguous samples, the teacher’s guidance is noisy—better to trust the student’s previous learned behavior than to overfit to uncertain teacher predictions.
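The weighting rule is easy to state in code; `tcp_weight` is a hypothetical helper name, but it follows the formula above exactly:

```python
def tcp_weight(teacher_probs, lam_base=0.5):
    """Per-sample reference weight: lam_ref = (1 - TCP) * lam_base,
    where TCP is the teacher's top-class softmax probability.
    High teacher confidence -> low reference weight (learn from teacher);
    low teacher confidence -> high reference weight (preserve prior knowledge)."""
    tcp = max(teacher_probs)
    return (1.0 - tcp) * lam_base
```

For the two scenarios above, `tcp_weight([0.95, 0.03, 0.02], lam_base=1.0)` gives 0.05 and `tcp_weight([0.55, 0.30, 0.15], lam_base=1.0)` gives 0.45, matching the example arithmetic.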
Ablation Study: What Actually Works?
The paper systematically tests each component:
Experiment 1: Sequential vs. Joint Training
Joint (all losses together): 72.4% accuracy
Sequential (3 stages): 74.8% accuracy
Gain: +2.4% from sequential integration
Experiment 2: With vs. Without Reference Model
Sequential, no reference: 71.2% accuracy (catastrophic forgetting!)
Sequential, with reference: 74.8% accuracy
Gain: +3.6% from reference anchoring
Experiment 3: Fixed vs. Adaptive (TCP) Weighting
Fixed λ_ref = 0.5: 73.9% accuracy
Adaptive (TCP-based): 74.8% accuracy
Gain: +0.9% from adaptive weighting
Experiment 4: Number of Stages
1 stage (baseline): 70.1%
2 stages: 73.2%
3 stages: 74.8%
4 stages: 74.6% (diminishing returns)
Optimal: 3 stages for this task
Key finding: Reference model supervision is the primary contributor (+3.6%), sequential staging is secondary (+2.4%), adaptive weighting is complementary (+0.9%).
Experimental Results Across Architectures
Test Setup:
Dataset: CIFAR-100 (100-class image classification)
Teachers: ResNet-50, WideResNet-40-2, VGG-19
Students: MobileNetV2, ShuffleNetV2, ResNet-18
KD Methods: Response-based (KD), Feature-based (FitNets), Relation-based (RKD)
Accuracy Comparison (ResNet-50 Teacher → MobileNetV2 Student):
Method                      Test Accuracy
Student (no distillation)   68.2%
Single-method KD            70.1%
Joint multi-method          72.4%
Multi-teacher ensemble      72.8%
SMSKD (3 stages)            74.8%
Teacher (upper bound)       78.4%
Cross-Architecture Results (CIFAR-100):
Teacher → Student           Baseline   Joint   SMSKD   Gain
ResNet-50 → MobileNetV2     68.2%      72.4%   74.8%   +2.4%
WRN-40-2 → ShuffleNetV2     65.7%      71.2%   73.9%   +2.7%
VGG-19 → ResNet-18          71.8%      75.1%   77.2%   +2.1%
Training Cost:
Joint training: 150 epochs (all methods simultaneously)
SMSKD: 3 × 50 epochs = 150 epochs (sequential stages)
Computational cost: Identical (same total epochs)
Memory overhead: Minimal (freeze reference model, only store student gradients)
Why This Matters
For Model Compression Pipelines: SMSKD provides a systematic way to integrate multiple KD methods without hyperparameter tuning nightmares. Each stage is independent—add or remove methods without redesigning the entire pipeline.
For On-Device Deployment: Better student accuracy at the same model size means either (a) smaller models at a target accuracy, or (b) higher accuracy at the deployment size. The jump from 72.4% to 74.8% is roughly what 15-20% more parameters would buy in a single-stage model.
For Transfer Learning: The sequential + reference architecture generalizes beyond distillation. Same pattern works for domain adaptation, task-incremental learning, multi-task learning—any scenario where you want to integrate heterogeneous training objectives without forgetting.
The conceptual breakthrough: Treating KD method integration as a continual learning problem—where each method is a “task” and reference models prevent catastrophic forgetting. This reframing unlocks techniques from continual learning (EWC, rehearsal, knowledge anchoring) for distillation.
Design Principles for Integration
The paper establishes general principles for sequential KD:
1. Stage Ordering Matters (Sometimes)
Theory: Order by knowledge specificity (coarse → fine)
Practice: Response (logits) → Feature (layers) → Relation (samples)
Empirical: +0.3-0.8% gain from optimal ordering
2. Reference Model Is Essential
Without anchoring: Catastrophic forgetting dominates (3.6% accuracy drop)
With frozen reference: Knowledge accumulates across stages
Alternative: EWC penalty (similar results, more compute)
3. Adaptive Weighting Is Optional
Fixed weighting: 73.9% (simple, works well)
TCP-based adaptive: 74.8% (+0.9% gain)
Trade-off: Complexity vs. marginal improvement
4. Diminishing Returns After 3 Stages
1 stage: 70.1%
2 stages: 73.2% (+3.1% gain)
3 stages: 74.8% (+1.6% gain)
4 stages: 74.6% (−0.2%, overfitting to reference)
Recommended recipe:
Pick 2-3 complementary KD methods (response + feature, or response + relation)
Train sequentially with frozen reference models
Use fixed reference weight (λ_ref = 0.5) unless you have compute budget for TCP tuning
Stop at 3 stages (more stages = diminishing returns)
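The recipe above can be expressed as a stage schedule and a small driver; `run_stage` is a placeholder for your own training loop, not an API from the paper, and the method names simply label which KD loss each stage uses:

```python
# A minimal stage schedule following the recommended recipe; the dicts and
# the run_stage callback are illustrative placeholders, not the paper's code.
STAGES = [
    {"method": "response", "epochs": 50, "lam_ref": 0.0},  # stage 1: no reference yet
    {"method": "feature",  "epochs": 50, "lam_ref": 0.5},
    {"method": "relation", "epochs": 50, "lam_ref": 0.5},
]

def run_pipeline(run_stage, student):
    """Train stages sequentially, snapshotting the student after each stage
    to serve as the frozen reference model for the next stage."""
    reference = None
    for cfg in STAGES:
        student = run_stage(student, reference, cfg)
        reference = student  # freeze this snapshot as the next stage's anchor
    return student
```

Because each stage only needs the previous snapshot, adding or removing a method means editing one entry in the schedule, which is the modularity the paper emphasizes.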
Open Questions and Future Work
Scalability: Results are on CIFAR-100 (small images, 100 classes). Does SMSKD work on ImageNet-1K or larger datasets? Early experiments suggest yes, but not published.
LLM Distillation: Can sequential integration apply to language models? Response-based (next-token logits) + Relation-based (attention patterns) + Feature-based (hidden states) could combine orthogonally.
Online Learning: SMSKD assumes fixed teacher and staged training. What about continual distillation where teacher updates over time?
Hardware Awareness: Can stage ordering optimize for inference hardware? Example: Stage 1 learns quantization-friendly features, Stage 2 learns hardware-specific optimizations.
But the core contribution is validated: Sequential integration with reference anchoring achieves 2-3% accuracy gains over joint training at zero additional cost—and provides a modular framework for combining heterogeneous KD methods.
Authors: Yinxi Tian, Changwu Huang, Ke Tang, Xin Yao
Primary Category: cs.LG (Machine Learning)

