# Knowledge distillation loss: combines the standard cross-entropy loss on the hard labels with a soft-target loss, i.e. the KL divergence between the teacher's and the student's softmax outputs at a given temperature.
import math
def knowledge_distillation_loss(
    student_logits: list[float],
    teacher_logits: list[float],
    hard_labels: list[float],
    temperature: float,
    alpha: float,
) -> float:
    """Return the knowledge-distillation loss for a single example.

    The loss mixes two terms:

    * a hard-label cross-entropy between ``hard_labels`` and the student's
      softmax output at temperature 1, and
    * a soft-target KL divergence ``KL(teacher || student)`` between the
      teacher's and student's softmax outputs at ``temperature``.

    ``loss = alpha * temperature**2 * KL + (1 - alpha) * CE``

    Args:
        student_logits: Raw (pre-softmax) scores produced by the student.
        teacher_logits: Raw scores produced by the teacher; must have the
            same length as ``student_logits``.
        hard_labels: Target weights per class (e.g. a one-hot vector); must
            have the same length as ``student_logits``.
        temperature: Softening temperature for the soft-target term; must
            be strictly positive.
        alpha: Weight of the soft-target term; ``1 - alpha`` weights the
            hard-label term.

    Returns:
        The combined loss, rounded to 6 decimal places.

    Raises:
        ValueError: If the logits are empty, the three sequences differ in
            length, or ``temperature`` is not strictly positive.
    """
    n_classes = len(student_logits)
    if n_classes == 0:
        raise ValueError("logits must be non-empty")
    # zip() would silently truncate mismatched inputs and return a wrong loss.
    if len(teacher_logits) != n_classes or len(hard_labels) != n_classes:
        raise ValueError(
            "student_logits, teacher_logits and hard_labels must have the same "
            f"length (got {n_classes}, {len(teacher_logits)}, {len(hard_labels)})"
        )
    # `not > 0` also rejects NaN; T == 0 would divide by zero and T < 0 would
    # invert the softmax ordering.
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")

    eps = 1e-12  # keeps log() finite when a probability underflows to 0

    def softmax(logits, T=1.0):
        # Shift by the max logit so math.exp cannot overflow.
        max_l = max(logits)
        exps = [math.exp((x - max_l) / T) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]

    # Hard label loss: cross-entropy between the true labels and the student
    # distribution at temperature 1.
    student_probs = softmax(student_logits, T=1.0)
    hard_loss = -sum(
        y * math.log(p + eps) for y, p in zip(hard_labels, student_probs)
    )

    # Soft target loss: KL(teacher || student) at the distillation temperature.
    teacher_soft = softmax(teacher_logits, T=temperature)
    student_soft = softmax(student_logits, T=temperature)
    soft_loss = sum(
        t * math.log((t + eps) / (s + eps))
        for t, s in zip(teacher_soft, student_soft)
    )

    # Combined loss. The T**2 factor compensates for the gradient magnitude
    # reduction caused by temperature scaling; alpha balances the two loss
    # components.
    loss = alpha * (temperature ** 2) * soft_loss + (1 - alpha) * hard_loss
    return round(loss, 6)