
Mixture of Experts Load Balancing Loss

#389 · Deep Learning · Medium

Solve on deep-ml.com

Problem

Implement the auxiliary load-balancing loss for Mixture of Experts (MoE) models. This loss encourages each expert to receive a roughly equal share of tokens, preventing expert collapse, in which the router sends most tokens to only a few experts and the rest go unused.
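For reference, the quantity the solution computes can be written as below. The notation is introduced here and is not part of the original problem: B is the batch size, E = num_experts, z_t is token t's row of gate_logits, and TopK(p_t) is the set of top_k experts chosen for token t.

```latex
\begin{aligned}
p_t &= \operatorname{softmax}(z_t) \in \mathbb{R}^{E}, \\
f_i &= \frac{1}{B} \sum_{t=1}^{B} \mathbb{1}\big[\, i \in \operatorname{TopK}(p_t) \,\big], \\
P_i &= \frac{1}{B} \sum_{t=1}^{B} p_{t,i}, \\
\mathcal{L}_{\text{aux}} &= E \sum_{i=1}^{E} f_i \, P_i .
\end{aligned}
```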

Solution

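import numpy as np

def moe_load_balancing_loss(gate_logits: np.ndarray, num_experts: int, top_k: int = 1) -> float:
    # gate_logits: shape (batch_size, num_experts), one row of router logits per token
    gate_logits = np.asarray(gate_logits, dtype=float)
    if gate_logits.ndim != 2 or gate_logits.shape[1] != num_experts:
        raise ValueError("gate_logits must have shape (batch_size, num_experts)")
    if not 1 <= top_k <= num_experts:
        raise ValueError("top_k must be between 1 and num_experts")

    # Routing probabilities: numerically stable softmax over the expert axis
    probs = np.exp(gate_logits - np.max(gate_logits, axis=1, keepdims=True))
    probs = probs / probs.sum(axis=1, keepdims=True)

    # Indices of the top-k experts selected for each token
    top_k_indices = np.argsort(-probs, axis=1)[:, :top_k]

    # f: fraction of tokens dispatched to each expert (hard 0/1 routing decisions)
    expert_mask = np.zeros_like(probs)
    np.put_along_axis(expert_mask, top_k_indices, 1.0, axis=1)  # vectorized, no per-token loop
    f = expert_mask.mean(axis=0)  # (num_experts,), sums to top_k

    # P: mean routing probability for each expert (soft term)
    P = probs.mean(axis=0)  # (num_experts,), sums to 1

    # Load-balancing loss: num_experts * sum_i f_i * P_i
    loss = num_experts * np.sum(f * P)
    return float(loss)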

Explanation

  1. Convert gate logits to routing probabilities via softmax.
  2. For each token, determine the top-k experts selected.
  3. Compute f: the fraction of tokens dispatched to each expert (from hard routing decisions).
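  4. Compute P: the mean routing probability for each expert. Unlike f, which comes from a hard top-k selection, P is a smooth function of the logits, so in an autodiff framework it is the term through which gradients reach the router.
  5. The loss is num_experts * dot(f, P). With top_k = 1 and perfectly balanced routing, f_i = P_i = 1/num_experts for every expert, giving a loss of exactly 1. With top_k > 1, f sums to top_k, so the balanced value is top_k. When routing collapses onto one expert that also receives nearly all of the probability mass, the loss approaches num_experts.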
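A quick numerical check of the values stated in step 5, as a sketch: `aux_loss` is a hypothetical compact restatement of the solution's computation (same softmax, top-k mask, f, and P), and the three logit matrices are made-up inputs chosen so the expected results can be worked out by hand.

```python
import numpy as np

def aux_loss(gate_logits: np.ndarray, top_k: int = 1) -> float:
    # Hypothetical compact restatement of the solution: softmax, top-k mask, f, P
    z = gate_logits - gate_logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    top = np.argsort(-probs, axis=1)[:, :top_k]
    mask = np.zeros_like(probs)
    np.put_along_axis(mask, top, 1.0, axis=1)
    f, P = mask.mean(axis=0), probs.mean(axis=0)
    return probs.shape[1] * float(np.sum(f * P))

E = 4
# Balanced top-1: token t strongly prefers expert t, so f_i = P_i = 1/E
balanced = 5.0 * np.eye(E)
# Collapsed: every token strongly prefers expert 0
collapsed = np.tile([10.0, 0.0, 0.0, 0.0], (E, 1))
# Balanced top-2: token t prefers expert t, then expert (t + 1) % E
two_choice = np.zeros((E, E))
for t in range(E):
    two_choice[t, t] = 3.0
    two_choice[t, (t + 1) % E] = 2.0

print(aux_loss(balanced))             # ≈ 1.0
print(aux_loss(collapsed))            # ≈ 3.9995, close to E = 4
print(aux_loss(two_choice, top_k=2))  # ≈ 2.0, i.e. top_k
```

In both balanced cases every column of the probability matrix receives the same set of values, so each P_i averages to exactly 1/E; in the top-2 case each f_i is 2/E, which is why the loss lands on top_k rather than 1.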

Complexity

  • Time: O(B * E log E), where B is the batch size and E is the number of experts, because of the full per-row argsort; the softmax, mask, and means are O(B * E)
  • Space: O(B * E) for the probability and mask matrices
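The log E factor comes from fully sorting each row when only the top_k entries are needed. A sketch of a selection-based alternative that brings the time down to O(B * E), using NumPy's `np.argpartition` and `np.put_along_axis`; `top_k_mask` is a hypothetical helper name, not part of the problem's API.

```python
import numpy as np

def top_k_mask(probs: np.ndarray, top_k: int) -> np.ndarray:
    """Return a (B, E) 0/1 mask marking each row's top_k largest entries."""
    # argpartition places the top_k largest probabilities (smallest of -probs)
    # in the first top_k columns of each row, in arbitrary order, using
    # selection instead of a full sort
    idx = np.argpartition(-probs, top_k - 1, axis=1)[:, :top_k]
    mask = np.zeros_like(probs)
    np.put_along_axis(mask, idx, 1.0, axis=1)  # scatter 1.0 at the selected indices
    return mask
```

Because the loss depends only on which experts are selected, not on their order, using this mask in place of the argsort-based selection leaves f and the loss unchanged, except possibly in how exact probability ties are broken.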