MoE with Shared Expert Forward Pass

#409 · Deep Learning · Medium

Problem

Implement a Mixture of Experts (MoE) layer with a shared expert that always participates. In addition to routing tokens to top-k selected experts, one expert processes all tokens unconditionally, improving model quality and training stability.
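Written out per token, the layer described above computes the following. The symbols are our own notation, chosen to match the solution code: x_t is one token as a row vector, W_g is the gate matrix, T_t is that token's set of top-k routed experts, and every expert (shared or routed) is a bias-free two-layer ReLU feed-forward network. The renormalization of g follows this solution; not every MoE variant renormalizes the selected weights.

```latex
\begin{aligned}
\mathrm{FFN}(x) &= \max(0,\, x W_1)\, W_2 \\
p_t &= \operatorname{softmax}(x_t W_g) \\
\mathcal{T}_t &= \operatorname{TopK}(p_t,\, k) \\
g_{t,i} &= \frac{p_{t,i}}{\sum_{j \in \mathcal{T}_t} p_{t,j}}, \quad i \in \mathcal{T}_t \\
y_t &= \mathrm{FFN}_{\mathrm{shared}}(x_t) + \sum_{i \in \mathcal{T}_t} g_{t,i}\, \mathrm{FFN}_i(x_t)
\end{aligned}
```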

Solution

import numpy as np

class MoEWithSharedExpert:
    def __init__(self, d_model: int, d_ff: int, num_routed_experts: int, top_k: int = 2):
        self.num_routed_experts = num_routed_experts
        self.top_k = top_k

        # Shared expert parameters
        self.shared_W1 = np.random.randn(d_model, d_ff) * 0.02
        self.shared_W2 = np.random.randn(d_ff, d_model) * 0.02

        # Routed expert parameters
        self.expert_W1 = [np.random.randn(d_model, d_ff) * 0.02 for _ in range(num_routed_experts)]
        self.expert_W2 = [np.random.randn(d_ff, d_model) * 0.02 for _ in range(num_routed_experts)]

        # Router
        self.gate = np.random.randn(d_model, num_routed_experts) * 0.02

    def forward(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (N, d_model)
        N = x_flat.shape[0]

        # Shared expert output (always active)
        h_shared = np.maximum(0, x_flat @ self.shared_W1) @ self.shared_W2

        # Router logits and probabilities
        logits = x_flat @ self.gate
        probs = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probs = probs / probs.sum(axis=1, keepdims=True)

        # Select top-k experts
        top_k_idx = np.argsort(-probs, axis=1)[:, :self.top_k]
        top_k_probs = np.take_along_axis(probs, top_k_idx, axis=1)
        # Renormalize the k selected weights so each token's weights sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(axis=1, keepdims=True)

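        # Routed expert outputs (per-token loop, kept simple for clarity).
        # zeros_like(h_shared) keeps a float dtype even if x is an integer array.
        h_routed = np.zeros_like(h_shared)
        for i in range(N):                  # each token
            for j in range(self.top_k):     # each of its k selected experts
                eidx = top_k_idx[i, j]
                # Two-layer ReLU FFN of expert eidx applied to token i
                h = np.maximum(0, x_flat[i] @ self.expert_W1[eidx]) @ self.expert_W2[eidx]
                # Scale by the renormalized gate weight and accumulate
                h_routed[i] += top_k_probs[i, j] * h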

        output = h_shared + h_routed
        return output.reshape(batch, seq_len, d_model)
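The per-token double loop in forward is easy to read, but it runs each expert's matrix multiplies one token at a time. A common alternative, sketched here under our own assumptions, gathers all tokens assigned to a routed expert and runs that expert once on the whole group. The function name moe_forward_grouped and its explicit-weights signature are hypothetical, not part of the original solution; given the same weights it is meant to produce the same output as forward.

```python
import numpy as np

def moe_forward_grouped(x, shared_W1, shared_W2, expert_W1, expert_W2, gate, top_k=2):
    """Hypothetical batched variant of MoEWithSharedExpert.forward: each routed
    expert runs once on all tokens assigned to it instead of once per token."""
    batch, seq_len, d_model = x.shape
    x_flat = x.reshape(-1, d_model)

    # Shared expert: ReLU FFN applied to every token
    out = np.maximum(0, x_flat @ shared_W1) @ shared_W2

    # Router: stable softmax, top-k selection, renormalized weights
    logits = x_flat @ gate
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    top_idx = np.argsort(-probs, axis=1)[:, :top_k]
    top_w = np.take_along_axis(probs, top_idx, axis=1)
    top_w /= top_w.sum(axis=1, keepdims=True)

    for e in range(len(expert_W1)):
        # rows: tokens that selected expert e; slots: which of their k choices it was
        rows, slots = np.nonzero(top_idx == e)
        if rows.size == 0:
            continue  # no tokens routed to this expert
        h = np.maximum(0, x_flat[rows] @ expert_W1[e]) @ expert_W2[e]
        # A token selects expert e at most once, so rows has no duplicates
        out[rows] += top_w[rows, slots][:, None] * h
    return out.reshape(batch, seq_len, d_model)
```

The loop now runs over the E routed experts instead of over N × k token-expert pairs. The arithmetic is the same, but it is batched into larger matrix products. Because argsort selects k distinct experts per token, np.nonzero yields at most one slot per token for a given expert, which is why the fancy-indexed += is safe here.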

Explanation

  1. Shared expert: a bias-free two-layer ReLU feed-forward network (W1, ReLU, W2) that processes every token, providing a stable baseline representation.
  2. Routed experts: the router (gating network) computes softmax probabilities over the routed experts for each token, keeps the top-k, and renormalizes those k weights so they sum to 1.
  3. Combining: each token's output is the shared expert's output plus the weighted sum of its k selected routed experts' outputs.
  4. Because the shared expert sees every token, each token receives at least some processing even when routing is imperfect, which is intended to improve training stability. DeepSeek-MoE uses shared experts in this way.
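Step 2 can be isolated as a small routing function. This is a hypothetical helper (route_top_k is our name) written to mirror the router lines in forward. The one deliberate difference is kind="stable" in np.argsort, which breaks exact probability ties in favor of the lower expert index; the original's default sort does not guarantee a particular tie order.

```python
import numpy as np

def route_top_k(logits: np.ndarray, top_k: int):
    """Softmax over experts, keep the top_k per token, renormalize their weights."""
    # Numerically stable softmax along the expert axis
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    # Top-k expert indices per token, highest probability first (ties: lower index)
    idx = np.argsort(-probs, axis=1, kind="stable")[:, :top_k]
    weights = np.take_along_axis(probs, idx, axis=1)
    weights /= weights.sum(axis=1, keepdims=True)
    return idx, weights

# Two tokens, four routed experts; the second token has an exact tie
logits = np.array([[2.0, 0.0, 1.0, -1.0],
                   [0.0, 3.0, 3.0, 0.0]])
idx, w = route_top_k(logits, top_k=2)
print(idx.tolist())                # [[0, 2], [1, 2]]
print(np.round(w, 4).tolist())     # [[0.7311, 0.2689], [0.5, 0.5]]
```

For the first token the two kept weights are e/(e + 1) and 1/(e + 1), since renormalizing discards the shared softmax denominator and leaves only the ratio of exp(2) to exp(1).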

Complexity

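  • Time: O(N ((k + 1) d_model d_ff + d_model E + E log E)), where N = batch × seq_len, k = top_k, and E = num_routed_experts. The shared expert and the k routed experts each cost O(d_model d_ff) per token, the router costs O(d_model E), and the argsort over experts costs O(E log E).
  • Space: O((E + 1) d_model d_ff + d_model E) for the shared, routed, and router parameters, plus O(N (d_model + d_ff + E)) for intermediate activations.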
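To make the parameter term concrete, the hypothetical helper below counts weights for an illustrative configuration (d_model = 512, d_ff = 2048, 8 routed experts, top_k = 2; these values are our assumption, not from the problem). It also counts the weights each token actually uses, which corresponds to the shared-plus-k-routed term in the time bound.

```python
def expert_param_counts(d_model: int, d_ff: int, num_routed_experts: int, top_k: int):
    """Hypothetical helper: (total weights stored, weights used per token)."""
    ffn = 2 * d_model * d_ff                          # W1 (d_model x d_ff) + W2 (d_ff x d_model), no biases
    router = d_model * num_routed_experts             # gate matrix
    total = (num_routed_experts + 1) * ffn + router   # all routed experts + shared expert
    active = (top_k + 1) * ffn + router               # shared + k routed experts + router
    return total, active

total, active = expert_param_counts(d_model=512, d_ff=2048, num_routed_experts=8, top_k=2)
print(total, active)  # 18878464 6295552
```

With these numbers the layer stores about 18.9M weights, but each token touches only about 6.3M: capacity grows with E while per-token compute grows only with k + 1.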