Implement a Mixture of Experts (MoE) layer with a shared expert that always participates. In addition to routing tokens to top-k selected experts, one expert processes all tokens unconditionally, improving model quality and training stability.
import numpy as np
class MoEWithSharedExpert:
    """Mixture-of-Experts feed-forward layer with an always-active shared expert.

    Every token is processed by the shared expert. In addition, a linear
    router scores the routed experts per token, the ``top_k`` highest-scoring
    experts are evaluated, and their outputs are mixed using the softmax
    router probabilities renormalized over the selected experts. The layer
    output is the sum of the shared and routed contributions.

    Each expert is a two-layer ReLU MLP without biases:
    ``relu(x @ W1) @ W2``.

    Attributes:
        num_routed_experts: Number of experts the router chooses between.
        top_k: Number of routed experts evaluated per token.
        shared_W1, shared_W2: Shared-expert weights, shapes
            ``(d_model, d_ff)`` and ``(d_ff, d_model)``.
        expert_W1, expert_W2: Lists (one entry per routed expert) of weights
            with the same shapes as the shared expert's.
        gate: Router weights, shape ``(d_model, num_routed_experts)``.
    """

    def __init__(self, d_model: int, d_ff: int, num_routed_experts: int, top_k: int = 2):
        """Create the layer with weights drawn from ``N(0, 0.02**2)``.

        Args:
            d_model: Token embedding size (input and output width).
            d_ff: Hidden width of every expert MLP.
            num_routed_experts: Number of routed experts.
            top_k: Routed experts evaluated per token. ``0`` disables routing
                so only the shared expert contributes.

        Raises:
            ValueError: If ``top_k`` is negative or exceeds
                ``num_routed_experts``.
        """
        # Fail here rather than with an IndexError deep inside forward().
        if not 0 <= top_k <= num_routed_experts:
            raise ValueError(
                f"top_k must be between 0 and num_routed_experts "
                f"({num_routed_experts}), got {top_k}"
            )

        self.num_routed_experts = num_routed_experts
        self.top_k = top_k

        # Weights are drawn in a fixed order (shared, routed W1, routed W2,
        # gate) so a given np.random seed always yields the same layer.
        # Shared expert parameters
        self.shared_W1 = np.random.randn(d_model, d_ff) * 0.02
        self.shared_W2 = np.random.randn(d_ff, d_model) * 0.02
        # Routed expert parameters
        self.expert_W1 = [np.random.randn(d_model, d_ff) * 0.02 for _ in range(num_routed_experts)]
        self.expert_W2 = [np.random.randn(d_ff, d_model) * 0.02 for _ in range(num_routed_experts)]
        # Router
        self.gate = np.random.randn(d_model, num_routed_experts) * 0.02

    @staticmethod
    def _ffn(tokens: np.ndarray, w_in: np.ndarray, w_out: np.ndarray) -> np.ndarray:
        """Apply one expert MLP, ``relu(tokens @ w_in) @ w_out``, to a token batch."""
        return np.maximum(0, tokens @ w_in) @ w_out

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Run the layer on a batch of token embeddings.

        Args:
            x: Array of shape ``(batch, seq_len, d_model)``.

        Returns:
            Array of shape ``(batch, seq_len, d_model)`` holding the shared
            expert output plus the probability-weighted outputs of each
            token's ``top_k`` routed experts.
        """
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (N, d_model)

        # Shared expert output (always active)
        h_shared = self._ffn(x_flat, self.shared_W1, self.shared_W2)

        if self.top_k == 0:
            # No routed experts requested: only the shared expert contributes.
            return h_shared.reshape(batch, seq_len, d_model)

        # Router probabilities: softmax with the row max subtracted for
        # numerical stability.
        logits = x_flat @ self.gate
        probs = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probs = probs / probs.sum(axis=1, keepdims=True)

        # Select the top-k experts per token and renormalize their
        # probabilities so the mixing weights sum to 1.
        top_k_idx = np.argsort(-probs, axis=1)[:, :self.top_k]
        top_k_probs = np.take_along_axis(probs, top_k_idx, axis=1)
        top_k_probs = top_k_probs / top_k_probs.sum(axis=1, keepdims=True)

        # Accumulate in the dtype of the computed values, not the input's:
        # an integer or low-precision input must not truncate (or reject)
        # the floating-point expert outputs.
        h_routed = np.zeros(
            x_flat.shape, dtype=np.result_type(h_shared.dtype, top_k_probs.dtype)
        )

        # Dispatch per expert: gather the tokens routed to it and run them
        # through the expert in one batched matmul.
        for eidx in range(self.num_routed_experts):
            token_idx, slot_idx = np.nonzero(top_k_idx == eidx)
            if token_idx.size == 0:
                continue
            h = self._ffn(x_flat[token_idx], self.expert_W1[eidx], self.expert_W2[eidx])
            # An expert appears at most once in a token's top-k, so
            # token_idx has no duplicates and the fancy-indexed += is exact.
            h_routed[token_idx] += top_k_probs[token_idx, slot_idx][:, None] * h

        output = h_shared + h_routed
        return output.reshape(batch, seq_len, d_model)