Implement the auxiliary load-balancing loss for Mixture of Experts (MoE) models. This loss encourages each expert to receive a roughly equal share of tokens, preventing expert collapse where only a few experts are used.
import numpy as np
def moe_load_balancing_loss(gate_logits: np.ndarray, num_experts: int, top_k: int = 1) -> float:
    """Compute the auxiliary load-balancing loss for a Mixture-of-Experts router.

    The loss is ``num_experts * sum_i(f_i * P_i)``, where

    * ``f_i`` is the fraction of tokens dispatched to expert ``i``, taken from
      the hard top-``k`` routing decisions (not differentiable), and
    * ``P_i`` is the mean routing probability the router assigns to expert
      ``i`` (soft, and therefore the differentiable part of the loss).

    When routing is perfectly balanced, every ``P_i`` equals
    ``1 / num_experts`` and every ``f_i`` equals ``top_k / num_experts``, so
    the loss equals ``top_k`` (i.e. ``1`` for the default ``top_k=1``).
    Routing that sends most tokens and most probability mass to the same few
    experts raises the loss above that value.

    Args:
        gate_logits: Router logits of shape ``(..., num_experts)``. Every
            leading dimension (e.g. batch, or batch and sequence) is treated
            as a token axis. Any array-like accepted by ``np.asarray`` works.
        num_experts: Number of experts; must equal the size of the last axis
            of ``gate_logits``.
        top_k: Number of experts each token is dispatched to, in
            ``[1, num_experts]``.

    Returns:
        The scalar load-balancing loss as a Python ``float``.

    Raises:
        ValueError: If the last axis of ``gate_logits`` does not have size
            ``num_experts``, if ``top_k`` is outside ``[1, num_experts]``, or
            if ``gate_logits`` contains no tokens.
    """
    logits = np.asarray(gate_logits)
    # The ndim guard short-circuits before shape[-1] is read on a 0-d array.
    if logits.ndim < 1 or logits.shape[-1] != num_experts:
        raise ValueError(
            f"gate_logits must have a last axis of size num_experts={num_experts}, "
            f"got shape {logits.shape}"
        )
    # top_k <= 0 would yield an empty (or, for negatives, wrong) selection slice.
    if not 1 <= top_k <= num_experts:
        raise ValueError(f"top_k must be in [1, num_experts={num_experts}], got {top_k}")

    # Flatten all leading dimensions into a single token axis.
    logits = logits.reshape(-1, num_experts)
    if logits.shape[0] == 0:
        # The means below would be nan for an empty token set.
        raise ValueError("gate_logits must contain at least one token")

    # Numerically stable softmax over the expert axis.
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)

    # Hard routing: the top_k most probable experts for each token.
    top_k_indices = np.argsort(-probs, axis=1)[:, :top_k]
    # Scatter 1.0 at each selected (token, expert) pair in one vectorised call.
    expert_mask = np.zeros_like(probs)
    np.put_along_axis(expert_mask, top_k_indices, 1.0, axis=1)

    f = expert_mask.mean(axis=0)  # (num_experts,) fraction of tokens per expert
    P = probs.mean(axis=0)  # (num_experts,) mean routing probability per expert
    return float(num_experts * np.sum(f * P))