Implement the Noisy Top-K Gating function used in Mixture of Experts models. Given input logits and the number of top experts k, add tunable Gaussian noise to the logits before selecting the top-k experts, then apply softmax over the selected experts.
import numpy as np
def noisy_top_k_gating(
    logits: np.ndarray,
    k: int,
    noise_stddev: float = 1.0,
    training: bool = True,
) -> np.ndarray:
    """Compute sparse gate weights via noisy top-k selection and softmax.

    Gaussian noise is optionally added to ``logits``; along the last axis
    the ``k`` largest (noisy) values are kept, everything else is masked
    out, and a softmax is taken over the kept values.

    Parameters
    ----------
    logits
        Array-like of gating scores whose last axis indexes the experts.
        Any number of leading axes is accepted; non-float input is
        converted to ``float64``.
    k
        Number of experts to keep per position. Must be at least 1;
        values larger than the number of experts keep every expert.
    noise_stddev
        Standard deviation of the zero-mean Gaussian noise. Noise is only
        added when ``training`` is true and this value is positive.
    training
        When false, no noise is added and no random numbers are drawn.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``logits``. Along the last axis the
        selected experts carry softmax weights that sum to 1 and every
        other expert has weight exactly 0.

    Raises
    ------
    ValueError
        If ``logits`` is 0-dimensional, has no experts on its last axis,
        or ``k`` is smaller than 1.
    """
    logits = np.asarray(logits)
    if logits.ndim == 0:
        raise ValueError("logits must have at least one dimension")

    n_experts = logits.shape[-1]
    if n_experts == 0:
        raise ValueError("logits must contain at least one expert")
    # A slice of [-0:] would silently select *all* experts, and a negative k
    # would select an arbitrary subset, so reject both explicitly.
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    k = min(k, n_experts)

    # The -inf mask below cannot be represented in an integer dtype.
    if not np.issubdtype(logits.dtype, np.floating):
        logits = logits.astype(np.float64)

    if training and noise_stddev > 0:
        noisy_logits = logits + np.random.normal(0, noise_stddev, logits.shape)
    else:
        # No copy needed: the input is never modified in place.
        noisy_logits = logits

    # argsort is ascending, so the last k indices along the expert axis are
    # the k largest noisy logits.
    top_k_idx = np.argsort(noisy_logits, axis=-1)[..., -k:]
    top_k_values = np.take_along_axis(noisy_logits, top_k_idx, axis=-1)

    # Unselected experts stay at -inf so that exp() maps them to exactly 0.
    gates = np.full_like(noisy_logits, -np.inf)
    np.put_along_axis(gates, top_k_idx, top_k_values, axis=-1)

    # Subtract the per-position maximum before exponentiating for
    # numerical stability.
    shifted = gates - np.max(gates, axis=-1, keepdims=True)
    exp_gates = np.exp(shifted)
    return exp_gates / np.sum(exp_gates, axis=-1, keepdims=True)