
Sparse MoE Top-K Routing

#229 · Deep Learning · Medium


Problem

Implement the Sparse Mixture of Experts (MoE) top-k routing mechanism. Given input tokens and a set of expert networks, route each token to the top-k experts based on a gating network, then combine their outputs.
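The problem statement leaves the tensor layout implicit. The sketch below assumes the layout that the solution's signature suggests: `inputs` is T x d (one row per token), `gate_weights` is the gating network's d x E weight matrix (one column per expert), and `expert_weights` holds E square d x d matrices. `check_moe_shapes` is a hypothetical helper, not part of the original problem.

```python
def check_moe_shapes(inputs, gate_weights, expert_weights, k):
    """Hypothetical validator for the assumed layout:
    inputs T x d, gate_weights d x E, expert_weights E x d x d, 1 <= k <= E.
    Returns (T, d, E) or raises ValueError."""
    if not inputs:
        raise ValueError("need at least one token")
    d = len(inputs[0])
    if any(len(row) != d for row in inputs):
        raise ValueError("every token must have the same dimension d")
    if len(gate_weights) != d:
        raise ValueError("gate_weights must have d rows (one per input feature)")
    num_experts = len(gate_weights[0])
    if any(len(row) != num_experts for row in gate_weights):
        raise ValueError("every gate_weights row must have E columns")
    if len(expert_weights) != num_experts:
        raise ValueError("need exactly one weight matrix per expert")
    for w in expert_weights:
        if len(w) != d or any(len(row) != d for row in w):
            raise ValueError("each expert matrix must be d x d")
    if not 1 <= k <= num_experts:
        raise ValueError("k must satisfy 1 <= k <= E")
    return len(inputs), d, num_experts


I2 = [[1.0, 0.0], [0.0, 1.0]]
G = [[0.5, -1.0, 0.0], [0.25, 0.5, 2.0]]  # d = 2, E = 3
print(check_moe_shapes([[1.0, 2.0]], G, [I2, I2, I2], k=2))  # (1, 2, 3)
```

Validating shapes up front turns a confusing `IndexError` deep inside the routing loop into an explicit message about which argument has the wrong layout.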

Solution

import math


def sparse_moe_top_k(
    inputs: list[list[float]],
    gate_weights: list[list[float]],
    expert_weights: list[list[list[float]]],
    k: int,
) -> list[list[float]]:
    """Route each token to its top-k experts and mix their outputs.

    Shapes: inputs is T x d, gate_weights is d x E (one column per expert),
    expert_weights is E x d x d. Assumes k >= 1.
    """
    num_tokens = len(inputs)
    num_experts = len(gate_weights[0]) if gate_weights else 0
    dim = len(inputs[0]) if inputs else 0
    k = min(k, num_experts)  # cannot select more experts than exist

    def softmax(vals):
        m = max(vals)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in vals]
        s = sum(exps)
        return [e / s for e in exps]

    def matvec(mat, vec):
        # mat @ vec for a row-major matrix
        return [sum(mat[i][j] * vec[j] for j in range(len(vec))) for i in range(len(mat))]

    outputs = []
    for t in range(num_tokens):
        x = inputs[t]

        # Gate logits for this token: logit_e = sum_j x[j] * gate_weights[j][e]
        gate_scores = [
            sum(x[j] * gate_weights[j][e] for j in range(dim)) for e in range(num_experts)
        ]

        # Select the top-k experts (stable sort: ties go to the lower expert index)
        indexed = sorted(enumerate(gate_scores), key=lambda p: -p[1])
        top_k = indexed[:k]

        # Softmax over the selected scores only, so the routing weights sum to 1
        weights = softmax([v for _, v in top_k])

        # Weighted combination of the selected experts' outputs
        out = [0.0] * dim
        for w, (expert_id, _) in zip(weights, top_k):
            expert_out = matvec(expert_weights[expert_id], x)
            for d in range(dim):
                out[d] += w * expert_out[d]

        outputs.append([round(v, 4) for v in out])

    return outputs
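What makes the layer "sparse" is that only k of the E experts run per token. One way to sanity-check a sparse implementation is to compare it with a dense reference that runs every expert and multiplies by a gate vector that is zero outside the top-k; the two must agree. The sketch below assumes the same layout as the solution (gate_weights d x E, experts applied as `mat @ x`); `routing`, `sparse_token`, and `dense_token` are hypothetical names used only for this check.

```python
import math


def softmax(vals):
    m = max(vals)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in vals]
    s = sum(exps)
    return [e / s for e in exps]


def matvec(mat, vec):
    # mat @ vec, the same convention as the solution's expert step
    return [sum(row[j] * vec[j] for j in range(len(vec))) for row in mat]


def routing(x, gate_weights, k):
    # Assumed layout: gate_weights is d x E
    num_experts = len(gate_weights[0])
    logits = [sum(x[j] * gate_weights[j][e] for j in range(len(x))) for e in range(num_experts)]
    top = sorted(range(num_experts), key=lambda e: -logits[e])[:k]
    return top, softmax([logits[e] for e in top])


def sparse_token(x, gate_weights, expert_weights, k):
    # Only the k selected experts are evaluated
    top, weights = routing(x, gate_weights, k)
    out = [0.0] * len(expert_weights[0])
    for e, w in zip(top, weights):
        y = matvec(expert_weights[e], x)
        for i in range(len(out)):
            out[i] += w * y[i]
    return out


def dense_token(x, gate_weights, expert_weights, k):
    # Every expert is evaluated; gates outside the top-k are exactly zero
    num_experts = len(gate_weights[0])
    top, weights = routing(x, gate_weights, k)
    gates = [0.0] * num_experts
    for e, w in zip(top, weights):
        gates[e] = w
    out = [0.0] * len(expert_weights[0])
    for e in range(num_experts):
        y = matvec(expert_weights[e], x)
        for i in range(len(out)):
            out[i] += gates[e] * y[i]
    return out


x = [1.0, 2.0]
G = [[0.5, -1.0, 0.0], [0.25, 0.5, 2.0]]  # logits for x: [1.0, 0.0, 4.0]
experts = [
    [[1.0, 0.0], [0.0, 1.0]],  # identity
    [[2.0, 0.0], [0.0, 2.0]],  # scale by 2
    [[0.0, 1.0], [1.0, 0.0]],  # swap the two coordinates
]
for k in (1, 2, 3):
    s_out = sparse_token(x, G, experts, k)
    d_out = dense_token(x, G, experts, k)
    assert all(abs(a - b) < 1e-12 for a, b in zip(s_out, d_out))
print(sparse_token(x, G, experts, 1))  # [2.0, 1.0]
```

With k = 1 the single selected expert (expert 2, logit 4.0) gets routing weight 1.0, so the output is just that expert's output. The dense version does E matrix-vector products per token instead of k, which is exactly the work the sparse router avoids.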

Explanation

  1. For each token, compute one gate logit per expert by multiplying the token vector by the gating matrix.
  2. Select the k experts with the highest logits.
  3. Apply softmax to the selected logits only, giving routing weights that sum to 1.
  4. Evaluate only the selected experts; each applies its own weight matrix to the token. The other E - k experts are skipped, which is what makes the layer sparse.
  5. The final output is the routing-weighted sum of the k expert outputs.
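Steps 2 and 3 can be traced on a single token. The sketch below uses hypothetical gate logits for E = 4 experts and k = 2. It also shows that taking a softmax over only the selected logits gives the same weights as taking a softmax over all E logits, keeping the top-k entries, and renormalizing them, since the shared denominator cancels.

```python
import math


def softmax(vals):
    m = max(vals)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in vals]
    s = sum(exps)
    return [e / s for e in exps]


# Hypothetical gate logits for one token over E = 4 experts
logits = [2.0, -1.0, 0.5, 1.0]
k = 2

# Step 2: indices of the k largest logits
top = sorted(range(len(logits)), key=lambda e: -logits[e])[:k]

# Step 3: softmax over the selected logits only
weights = softmax([logits[e] for e in top])

# Equivalent view: full softmax, keep the top-k entries, renormalize
full = softmax(logits)
kept = [full[e] for e in top]
renorm = [p / sum(kept) for p in kept]

print(top)      # [0, 3]
print(weights)  # approximately [0.7311, 0.2689]
assert all(abs(a - b) < 1e-12 for a, b in zip(weights, renorm))
```

Here the weights are e/(e + 1) and 1/(e + 1), because only the logits 2.0 and 1.0 survive the top-2 cut; the probability mass the full softmax assigned to experts 1 and 2 is redistributed over the two selected experts.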

Complexity

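  • Time: O(T (d * E + E log E + k * d^2)), where T is the number of tokens, E the number of experts, and d the model dimension: d * E for the gate logits, E log E for sorting them, and k * d^2 for the k expert matrix-vector products
  • Space: O(T * d) for the outputs, plus O(E + d) scratch per token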
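The E log E term comes from fully sorting the gate logits. A minimal alternative, assuming only the standard library, is `heapq.nlargest`, which keeps at most k candidates and selects in O(E log k). The sketch below checks that it picks the same experts as the full sort, and counts multiply-accumulates per token to make the k * d^2 versus E * d^2 saving concrete. `top_k_sorted`, `top_k_heap`, and `expert_macs_per_token` are hypothetical helper names.

```python
import heapq
import random


def top_k_sorted(logits, k):
    # O(E log E): full sort, as in the solution
    return sorted(range(len(logits)), key=lambda e: -logits[e])[:k]


def top_k_heap(logits, k):
    # O(E log k): heapq.nlargest maintains a heap of at most k candidates
    return heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)


def expert_macs_per_token(d, k, num_experts):
    # Multiply-accumulates per token: gating (d * E) plus expert matvecs (d * d each)
    gating = d * num_experts
    sparse = gating + k * d * d            # only the top-k experts run
    dense = gating + num_experts * d * d   # if every expert ran
    return sparse, dense


random.seed(0)
logits = [random.gauss(0.0, 1.0) for _ in range(64)]
assert top_k_heap(logits, 2) == top_k_sorted(logits, 2)
print(expert_macs_per_token(d=512, k=2, num_experts=8))  # (528384, 2101248)
```

For small E, as in typical MoE layers, the difference between the two selection methods is negligible next to the k * d^2 expert cost; the heap version matters mainly when E is large relative to k.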