Implement the Noisy Top-K Gating Function

#124 · Deep Learning · Medium

Source: deep-ml.com

Problem

Implement the Noisy Top-K Gating function used in Mixture of Experts (MoE) models. Given input logits and the number of top experts k, add Gaussian noise with a tunable standard deviation to the logits (during training), select the top-k experts from the noisy logits, and then apply a softmax over the selected experts only.
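Stated as equations, the gate computes the following. This is a sketch of the statement above. The symbols z (logits for one sample), σ (noise_stddev), and T (the selected index set) are labels introduced here, and the KeepTopK name follows a convention common in the MoE literature rather than anything defined on this page.

```latex
\tilde{z}_i = z_i + \varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, \sigma^2) \ \text{(training only; otherwise } \tilde{z} = z\text{)}

\mathrm{KeepTopK}(\tilde{z}, k)_i =
\begin{cases}
  \tilde{z}_i & \text{if } i \in T = \text{indices of the } k \text{ largest entries of } \tilde{z} \\
  -\infty     & \text{otherwise}
\end{cases}

g = \mathrm{softmax}\big(\mathrm{KeepTopK}(\tilde{z}, k)\big), \qquad
g_i =
\begin{cases}
  \dfrac{e^{\tilde{z}_i}}{\sum_{j \in T} e^{\tilde{z}_j}} & i \in T \\
  0 & i \notin T
\end{cases}
```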

Solution

import numpy as np

def noisy_top_k_gating(logits: np.ndarray, k: int, noise_stddev: float = 1.0, training: bool = True) -> np.ndarray:
    # Copy into floating point so -inf masking is valid even for integer inputs
    noisy_logits = np.array(logits, dtype=float)
    n_experts = noisy_logits.shape[-1]
    if not 1 <= k <= n_experts:
        # Guard: k = 0 would otherwise select every expert via the [-0:] slice
        raise ValueError(f"k must be in [1, {n_experts}], got {k}")

    # Add Gaussian noise to the logits only during training
    if training and noise_stddev > 0:
        noisy_logits += np.random.normal(0.0, noise_stddev, noisy_logits.shape)

    # Indices of the k largest noisy logits along the expert (last) axis;
    # works for a single sample (1-D) or any number of leading batch dims
    top_k_idx = np.argsort(noisy_logits, axis=-1)[..., -k:]

    # Start fully masked with -inf, then copy back the top-k noisy logits
    gates = np.full_like(noisy_logits, -np.inf)
    np.put_along_axis(gates, top_k_idx, np.take_along_axis(noisy_logits, top_k_idx, axis=-1), axis=-1)

    # Numerically stable softmax: exp(-inf) = 0, so masked experts get zero weight
    gates_shifted = gates - np.max(gates, axis=-1, keepdims=True)
    exp_gates = np.exp(gates_shifted)
    return exp_gates / np.sum(exp_gates, axis=-1, keepdims=True)

Explanation

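  1. During training (training=True and noise_stddev > 0), add Gaussian noise with standard deviation noise_stddev to the gating logits. The noise perturbs which experts win the top-k selection, encouraging exploration across experts and helping balance the load. At inference the logits are used unchanged, so routing is deterministic.
  2. Select the top-k experts by ranking the noisy logits and setting every other logit to negative infinity.
  3. Apply a numerically stable softmax (subtracting the row maximum first) over the masked logits. Because exp(-inf) = 0, only the top-k experts receive non-zero gating weights, and those weights are computed from the noisy logits.
  4. The result is a sparse gating vector per sample whose k selected entries sum to 1. All k are non-zero unless the selected logits differ by several hundred, in which case the smallest can underflow to 0 in floating point.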
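Step 1 claims that noise spreads selection across experts. One way to see this is to count how often each expert wins top-1 selection with and without N(0, 1) noise. This is a minimal sketch: the logits, the seed, and the trial count are arbitrary illustrative assumptions, not values from the problem.

```python
import numpy as np

rng = np.random.default_rng(0)            # seeded so the counts are reproducible
logits = np.array([2.0, 1.5, 1.0, 0.0])   # illustrative logits, assumed for this demo
n_trials = 10_000

# Without noise, top-1 selection always picks the same expert
clean_choice = int(np.argmax(logits))

# With N(0, 1) noise drawn independently per trial, other experts sometimes win
noisy = logits + rng.normal(0.0, 1.0, size=(n_trials, logits.size))
counts = np.bincount(np.argmax(noisy, axis=-1), minlength=logits.size)

print("noise-free choice:", clean_choice)
print("noisy selection frequency:", counts / n_trials)
```

The highest-logit expert still wins most often, but every expert is selected in some fraction of trials, which is what lets under-used experts receive gradient signal during training.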

Complexity

  • Time: O(B · E log E), where B is the batch size and E the number of experts; the per-row sort dominates.
  • Space: O(B · E) for the noise, index, and gate arrays.
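The E log E factor comes from fully sorting each row even though only the k largest entries are needed. A possible alternative, swapped in here rather than taken from the solution above, is np.argpartition, which selects the top k in roughly linear time per row. The helper name top_k_mask_argpartition and the sample array are hypothetical.

```python
import numpy as np

def top_k_mask_argpartition(noisy_logits: np.ndarray, k: int) -> np.ndarray:
    """Boolean mask marking the k largest entries along the last axis (hypothetical helper)."""
    n_experts = noisy_logits.shape[-1]
    # Partitioning at position n_experts - k puts the k largest values in the last k slots,
    # in no particular order (order does not matter for a softmax)
    top_k_idx = np.argpartition(noisy_logits, n_experts - k, axis=-1)[..., -k:]
    mask = np.zeros(noisy_logits.shape, dtype=bool)
    np.put_along_axis(mask, top_k_idx, True, axis=-1)
    return mask

x = np.array([[0.1, 2.0, -1.0, 0.5],
              [3.0, -2.0, 1.0, 1.5]])
mask = top_k_mask_argpartition(x, k=2)
# Row 0 keeps experts 1 and 3; row 1 keeps experts 0 and 3
gates = np.where(mask, x, -np.inf)  # masked logits, ready for the softmax step
print(mask)
```

The mask-then-np.where form replaces the argsort step; the softmax that follows is unchanged.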