Knowledge Distillation Loss

#227 · Deep Learning · Medium

Solve on deep-ml.com

Problem

Implement the knowledge distillation loss that combines the standard cross-entropy loss with a soft target loss (KL divergence between teacher and student softmax outputs at a given temperature).
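Written out as a formula, the combination computed by the solution code looks like this. The symbols are notation introduced here, not taken from the problem statement: z are the student logits, v the teacher logits, y the label vector, T the temperature, and α the weight on the soft term. The T^2 scaling and the (1 - α) weight on the hard term follow the solution code; other formulations weight the two terms differently.

```latex
p^{(T)}_i = \frac{\exp(v_i / T)}{\sum_j \exp(v_j / T)}, \qquad
q^{(T)}_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

\mathcal{L} = \alpha \, T^2 \sum_i p^{(T)}_i \log \frac{p^{(T)}_i}{q^{(T)}_i}
  \;+\; (1 - \alpha) \left( -\sum_i y_i \log q^{(1)}_i \right)
```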

Solution

import math

def knowledge_distillation_loss(
    student_logits: list[float],
    teacher_logits: list[float],
    hard_labels: list[float],
    temperature: float,
    alpha: float,
) -> float:
    # Temperature-scaled softmax; subtracting the max logit keeps exp() from overflowing
    def softmax(logits, T=1.0):
        max_l = max(logits)
        exps = [math.exp((x - max_l) / T) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]

    # Hard label loss: cross-entropy with the true labels at T = 1 (the 1e-12 avoids log(0))
    student_probs = softmax(student_logits, T=1.0)
    hard_loss = -sum(
        y * math.log(p + 1e-12) for y, p in zip(hard_labels, student_probs)
    )

    # Soft target loss: KL(teacher || student) at temperature T
    teacher_soft = softmax(teacher_logits, T=temperature)
    student_soft = softmax(student_logits, T=temperature)
    soft_loss = sum(
        t * math.log((t + 1e-12) / (s + 1e-12))
        for t, s in zip(teacher_soft, student_soft)
    )

    # Combined loss
    loss = alpha * (temperature ** 2) * soft_loss + (1 - alpha) * hard_loss
    return round(loss, 6)

Explanation

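  1. Hard loss: standard cross-entropy between the student's predictions (softmax at T = 1) and the ground-truth labels.
  2. Soft loss: KL(teacher || student), the KL divergence between the teacher's and the student's softened probability distributions, both computed with softmax at temperature T.
  3. Dividing the logits by a higher temperature flattens both distributions, so the probabilities the teacher assigns to incorrect classes (its "dark knowledge" about which classes resemble each other) contribute more to the soft loss.
  4. The gradient of the soft loss with respect to the student logits shrinks roughly as 1/T^2 for large T; multiplying by T^2 keeps its magnitude comparable to the hard loss when T is changed.
  5. alpha weights the soft loss and (1 - alpha) weights the hard loss: alpha = 0 reduces to plain cross-entropy, and alpha = 1 uses only the T^2-scaled soft-target term.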
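To make points 3 and 4 concrete, the sketch below computes the teacher's entropy and the norm of the soft-loss gradient with respect to the student logits at several temperatures. The logits are hypothetical example values, and the gradient uses the closed form (q_i - p_i) / T of the KL term rather than automatic differentiation.

```python
import math

def softmax(logits, T=1.0):
    # Temperature-scaled softmax; subtracting the max logit avoids overflow in exp()
    m = max(logits)
    exps = [math.exp((x - m) / T) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    # Shannon entropy in nats: larger values mean a flatter (softer) distribution
    return -sum(p * math.log(p) for p in probs if p > 0)

def soft_loss_grad(student_logits, teacher_logits, T):
    # Analytic gradient of KL(teacher_T || student_T) w.r.t. the student logits z:
    #   dKL/dz_i = (q_i - p_i) / T, with q = softmax(z / T) and p = softmax(v / T)
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def l2_norm(vec):
    # Euclidean length of a gradient vector
    return math.sqrt(sum(x * x for x in vec))

# Hypothetical logits for a 3-class problem
teacher = [5.0, 2.0, 0.5]
student = [3.0, 2.5, 1.0]
temps = [1.0, 2.0, 5.0, 10.0, 20.0]

entropies = [entropy(softmax(teacher, T)) for T in temps]
grad_norms = [l2_norm(soft_loss_grad(student, teacher, T)) for T in temps]
scaled_norms = [T * T * g for T, g in zip(temps, grad_norms)]

for T, h, g, s in zip(temps, entropies, grad_norms, scaled_norms):
    print(f"T={T:5.1f}  teacher entropy={h:.4f}  |grad|={g:.6f}  T^2*|grad|={s:.4f}")
```

For these logits the teacher's entropy rises steadily with T, the raw gradient norm drops by roughly a factor of four each time T doubles at the high end (5 to 10 to 20), and the T^2-scaled norm stays of order one, which is the effect the T^2 factor in the loss relies on.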

Complexity

  • Time: O(n) where n is the number of classes
  • Space: O(n) for the probability distributions