Implement a Sparse Mixture of Experts Layer

#125 · Deep Learning · Hard

Problem

Implement a Sparse Mixture of Experts (MoE) layer. The layer consists of multiple expert networks (each a small feed-forward network) and a gating network that routes each input to the top-k experts. Combine the expert outputs weighted by the gating scores.
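
The routing rule can be written compactly. The following is a sketch under two assumptions that the problem statement leaves open: the gate is a single linear layer, and the softmax is taken over the selected logits only (the choice made in the solution). The symbols W_g, b_g, T(x), and E_i are notation introduced here for illustration.

```latex
g(x) = x W_g + b_g, \qquad
\mathcal{T}(x) = \operatorname{TopK}\bigl(g(x),\, k\bigr), \qquad
y(x) = \sum_{i \in \mathcal{T}(x)}
       \frac{\exp g_i(x)}{\sum_{j \in \mathcal{T}(x)} \exp g_j(x)} \; E_i(x)
```

Here E_i(x) is the output of expert i, and experts outside T(x) contribute nothing and are never evaluated.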

Solution

import numpy as np

def softmax(x, axis=-1):
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def relu(x):
    return np.maximum(0, x)

class Expert:
    def __init__(self, input_dim, hidden_dim, output_dim):
        scale1 = np.sqrt(2.0 / input_dim)
        scale2 = np.sqrt(2.0 / hidden_dim)
        self.W1 = np.random.randn(input_dim, hidden_dim) * scale1
        self.b1 = np.zeros(hidden_dim)
        self.W2 = np.random.randn(hidden_dim, output_dim) * scale2
        self.b2 = np.zeros(output_dim)

    def forward(self, x):
        h = relu(x @ self.W1 + self.b1)
        return h @ self.W2 + self.b2

class SparseMoE:
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, top_k=2):
        self.experts = [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
        self.gate_W = np.random.randn(input_dim, num_experts) * np.sqrt(2.0 / input_dim)
        self.gate_b = np.zeros(num_experts)
        self.top_k = top_k
        self.num_experts = num_experts

    def forward(self, x):
        # x shape: (batch_size, input_dim)
        gate_logits = x @ self.gate_W + self.gate_b  # (batch, num_experts)

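        batch_size = x.shape[0]
        output_dim = self.experts[0].W2.shape[1]
        output = np.zeros((batch_size, output_dim))

        # Top-k gating: route each input to its k highest-scoring experts
        for i in range(batch_size):
            # argsort is ascending, so the last k indices are the top-k experts
            top_k_idx = np.argsort(gate_logits[i])[-self.top_k:]
            top_k_logits = gate_logits[i, top_k_idx]
            # Renormalize over the selected experts only
            top_k_weights = softmax(top_k_logits)

            # Run only the selected experts and accumulate their weighted outputs
            for j, idx in enumerate(top_k_idx):
                expert_out = self.experts[idx].forward(x[i:i+1])  # (1, output_dim)
                output[i] += top_k_weights[j] * expert_out.squeeze(0)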

        return output

Explanation

  1. Each expert is a two-layer feed-forward network with ReLU activation.
  2. The gating network projects the input to a score per expert using a linear layer.
  3. For each input, select the top-k experts by score, apply softmax to their scores, and compute a weighted sum of the expert outputs.
  4. Only k of the N experts are evaluated per input, so the expert compute scales with k rather than N.
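
The selection and weighting in step 3 can also be done for the whole batch at once. The sketch below is one possible vectorized variant, not the reference solution: `top_k_gate` is a hypothetical helper name, and it uses `np.argpartition` instead of a full sort, so the selected experts come back in no particular order.

```python
import numpy as np

def top_k_gate(gate_logits, k):
    """Return (indices, weights) of the k highest-scoring experts per row.

    gate_logits: (batch, num_experts) array of raw gate scores.
    indices, weights: (batch, k) arrays; each row of weights sums to 1.
    """
    num_experts = gate_logits.shape[1]
    # argpartition puts the k largest entries in the last k columns (unordered).
    idx = np.argpartition(gate_logits, num_experts - k, axis=1)[:, -k:]
    picked = np.take_along_axis(gate_logits, idx, axis=1)
    # Softmax over the selected logits only, shifted for numerical stability.
    picked = picked - picked.max(axis=1, keepdims=True)
    weights = np.exp(picked)
    weights /= weights.sum(axis=1, keepdims=True)
    return idx, weights

# Illustrative logits for 2 inputs and 4 experts (made-up values).
logits = np.array([[0.1, 2.0, -1.0, 1.0],
                   [3.0, 0.0, 0.5, 2.5]])
idx, w = top_k_gate(logits, k=2)
```

Each row of `idx` names the experts to run for that input, and the matching row of `w` gives the mixing weights for their outputs.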

Complexity

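  • Time: O(B D N) for the gate plus O(B k (D H + H O)) for the selected experts, where B = batch size, N = number of experts, k = top-k, D = input dim, H = hidden dim, O = output dim. The per-input argsort adds O(B N log N), which is usually small next to the expert cost.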
  • Space: O(N (D H + H O)) for all expert parameters, plus O(D N) for the gate parameters and O(B N + B O) for the gate logits and output
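
To make the bounds concrete, the sketch below plugs numbers into the same terms. It is a rough count of multiply-accumulates only (activations, softmax, and sorting are ignored), `moe_cost` is a hypothetical helper, and the sizes are arbitrary illustrative values. The dense figure is the cost if all N experts were evaluated for every input.

```python
def moe_cost(B, N, k, D, H, O):
    """Rough multiply-accumulate (MAC) and parameter counts for the layer."""
    per_expert = D * H + H * O          # MACs for one expert on one input
    gate = B * D * N                    # the gate scores every expert
    sparse = gate + B * k * per_expert  # only k experts run per input
    dense = gate + B * N * per_expert   # hypothetical: all N experts run
    # Expert weights and biases, plus gate weights and biases.
    params = N * (per_expert + H + O) + D * N + N
    return {"sparse_macs": sparse, "dense_macs": dense, "params": params}

# Arbitrary illustrative sizes.
cost = moe_cost(B=32, N=8, k=2, D=64, H=128, O=64)
print(cost)
```

With these sizes the sparse layer does about a quarter of the dense layer's expert work (k / N = 2 / 8), while the parameter count is the same in both cases.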