Implement a Sparse Mixture of Experts (MoE) layer. The layer consists of multiple expert networks (each a small feed-forward network) and a gating network that routes each input to the top-k experts. Combine the expert outputs weighted by the gating scores.
import numpy as np
def softmax(x, axis=-1):
    """Compute a numerically stable softmax along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. This
    does not change the mathematical result but keeps ``np.exp`` from
    overflowing on large inputs.

    Args:
        x: Array of scores.
        axis: Axis along which the result is normalised. Defaults to the
            last axis.

    Returns:
        Array of the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    numerators = np.exp(x - peak)
    normalizer = np.sum(numerators, axis=axis, keepdims=True)
    return numerators / normalizer
def relu(x):
    """Apply the rectified linear unit element-wise, i.e. ``max(0, x)``.

    Args:
        x: Scalar or array input.

    Returns:
        ``x`` with every negative entry replaced by zero.
    """
    floor = 0
    return np.maximum(floor, x)
class Expert:
    """One expert: a two-layer feed-forward network with a ReLU in between.

    Attributes:
        W1: Hidden-layer weights, shape ``(input_dim, hidden_dim)``.
        b1: Hidden-layer bias, shape ``(hidden_dim,)``.
        W2: Output-layer weights, shape ``(hidden_dim, output_dim)``.
        b2: Output-layer bias, shape ``(output_dim,)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Initialise weights with He-scaled Gaussian noise and zero biases.

        Weights are drawn from NumPy's global random state
        (``np.random.randn``), hidden layer first and output layer second.

        Args:
            input_dim: Size of each input vector.
            hidden_dim: Width of the hidden layer.
            output_dim: Size of each output vector.
        """
        # Standard deviation sqrt(2 / fan_in) for each layer.
        hidden_std = np.sqrt(2.0 / input_dim)
        output_std = np.sqrt(2.0 / hidden_dim)

        self.W1 = hidden_std * np.random.randn(input_dim, hidden_dim)
        self.b1 = np.zeros(hidden_dim)

        self.W2 = output_std * np.random.randn(hidden_dim, output_dim)
        self.b2 = np.zeros(output_dim)

    def forward(self, x):
        """Run the expert on a batch of inputs.

        Args:
            x: Array whose last dimension is ``input_dim``.

        Returns:
            ``relu(x @ W1 + b1) @ W2 + b2``.
        """
        pre_activation = x @ self.W1 + self.b1
        hidden = relu(pre_activation)
        return hidden @ self.W2 + self.b2
class SparseMoE:
    """Sparse Mixture-of-Experts layer with top-k gating.

    A linear gating network scores every expert for every input row. Each
    row is sent to its ``top_k`` highest-scoring experts only, and the
    outputs of those experts are summed with weights given by a softmax
    over the selected gate logits (so the weights of one row sum to 1).

    Attributes:
        experts: List of ``num_experts`` ``Expert`` instances.
        gate_W: Gating weights, shape ``(input_dim, num_experts)``.
        gate_b: Gating bias, shape ``(num_experts,)``.
        top_k: Number of experts each input row is routed to.
        num_experts: Total number of experts.
        output_dim: Size of each output vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, top_k=2):
        """Build the experts and the gating network.

        Args:
            input_dim: Size of each input vector.
            hidden_dim: Hidden width of every expert.
            output_dim: Size of each output vector.
            num_experts: Number of experts; must be at least 1.
            top_k: Experts used per input row; must be at least 1. A value
                larger than ``num_experts`` is tolerated and means "use
                every expert".

        Raises:
            ValueError: If ``num_experts`` or ``top_k`` is smaller than 1.
        """
        # Validate before anything is allocated. Without this check
        # ``top_k=0`` would silently select *all* experts (``[-0:]`` is the
        # full slice) and negative values would select an arbitrary subset.
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        self.experts = [
            Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)
        ]
        self.gate_W = np.random.randn(input_dim, num_experts) * np.sqrt(2.0 / input_dim)
        self.gate_b = np.zeros(num_experts)
        self.top_k = top_k
        self.num_experts = num_experts
        # Stored so that forward() need not inspect expert internals.
        self.output_dim = output_dim

    def forward(self, x):
        """Route each input row to its top-k experts and mix their outputs.

        Args:
            x: Array of shape ``(batch_size, input_dim)``.

        Returns:
            Array of shape ``(batch_size, output_dim)``.

        Raises:
            ValueError: If ``x`` is not two-dimensional.
        """
        x = np.asarray(x)
        if x.ndim != 2:
            raise ValueError(
                f"x must have shape (batch_size, input_dim), got {x.shape}"
            )

        gate_logits = x @ self.gate_W + self.gate_b  # (batch, num_experts)

        # Select the k largest logits of every row in one vectorised step.
        k = min(self.top_k, self.num_experts)
        top_idx = np.argsort(gate_logits, axis=-1)[:, -k:]  # (batch, k)
        top_logits = np.take_along_axis(gate_logits, top_idx, axis=-1)
        # Normalise over the selected experts only, not over all experts.
        top_weights = softmax(top_logits, axis=-1)  # (batch, k)

        output = np.zeros((x.shape[0], self.output_dim))
        # Dispatch per expert rather than per sample: every expert runs once
        # on the sub-batch of rows routed to it.
        for expert_id, expert in enumerate(self.experts):
            rows, slots = np.nonzero(top_idx == expert_id)
            if rows.size == 0:
                continue  # no row selected this expert
            expert_out = expert.forward(x[rows])  # (n_routed, output_dim)
            # A row picks a given expert at most once, so ``rows`` has no
            # duplicates and the in-place indexed add is safe.
            output[rows] += top_weights[rows, slots][:, None] * expert_out
        return output