
Expert Parallelism Token Routing and Communication Cost

#439 · Machine Learning · Medium


Problem

Simulate expert parallelism token routing in a Mixture-of-Experts (MoE) model. Given the number of tokens, the number of experts, the top-k routing assignments per token, the GPU each token resides on, the expert-to-GPU placement, the hidden-state size per token in bytes, and the inter-GPU bandwidth, compute how many token transfers must cross GPUs and estimate the resulting communication time.
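Written as formulas, the quantities the solution computes are below. The notation is introduced here for illustration only: E_t is the set of experts selected for token t, s(t) is the GPU holding token t, g(e) is the GPU hosting expert e, h is the hidden-state size in bytes, B is the bandwidth in gigabits per second, and R, L, V, T correspond to remote_tokens, local_tokens, total_remote_bytes and comm_time_ms.

```latex
\begin{aligned}
R &= \sum_{t=0}^{n-1} \sum_{e \in E_t} \mathbf{1}\left[g(e) \neq s(t)\right],
\qquad L = \sum_{t=0}^{n-1} \lvert E_t \rvert - R \\
V &= R \cdot h \\
T_{\mathrm{ms}} &= 1000 \cdot \frac{V}{B \cdot 10^{9} / 8} = \frac{8 R h}{B \cdot 10^{6}}
\end{aligned}
```

For a hypothetical case with R = 3 remote transfers of h = 8192-byte hidden states over a B = 100 Gbps link, T = 8 · 3 · 8192 / 10^8 ≈ 0.00197 ms, which the solution's round(..., 4) reports as 0.002. At this scale, four decimal places of milliseconds keep only one significant digit.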

Solution

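def expert_parallel_routing_cost(
    token_assignments: list[list[int]],
    expert_to_gpu: dict[int, int],
    token_to_gpu: list[int],
    token_hidden_bytes: int,
    bandwidth_gbps: float
) -> dict:
    # Counts are per (token, expert) pair: a token routed to k experts
    # contributes k transfers, even if two of its experts share a GPU.
    local_tokens = 0
    remote_tokens = 0

    for t_idx, experts in enumerate(token_assignments):
        src_gpu = token_to_gpu[t_idx]           # GPU currently holding this token
        for expert_id in experts:
            dst_gpu = expert_to_gpu[expert_id]  # GPU hosting the selected expert
            if dst_gpu == src_gpu:
                local_tokens += 1   # co-located expert: no interconnect traffic
            else:
                remote_tokens += 1  # hidden state must cross the interconnect

    # One hidden-state copy goes over the wire per remote pair.
    total_remote_bytes = remote_tokens * token_hidden_bytes
    # Gigabits/s -> bytes/s (1 Gb = 1e9 bits, 8 bits per byte).
    bandwidth_bytes_per_sec = bandwidth_gbps * 1e9 / 8
    # Seconds -> milliseconds; treats all remote bytes as sharing one link.
    comm_time_ms = (total_remote_bytes / bandwidth_bytes_per_sec) * 1000

    return {
        "local_tokens": local_tokens,
        "remote_tokens": remote_tokens,
        "total_remote_bytes": total_remote_bytes,
        "comm_time_ms": round(comm_time_ms, 4)
    }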

Explanation

  1. Each token is routed to its top-k selected experts. For each (token, expert) pair, check whether the expert's GPU matches the token's source GPU.
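  2. If they differ, the token's hidden state must be sent across GPUs and counts as one remote transfer; otherwise it counts as local.
  3. Multiply the number of remote transfers by the hidden-state size in bytes to get the total communication volume.
  4. Divide that volume by the link bandwidth, converted from gigabits per second to bytes per second (multiply by 10^9, divide by 8), then multiply by 1000 to express the time in milliseconds.
  5. This approximates the dispatch phase of the all-to-all exchange that expert parallelism requires. It assumes every remote byte shares a single link at the given bandwidth and ignores latency and contention; the combine phase, which returns expert outputs to the tokens' source GPUs, moves a comparable volume back.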
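Step 5 can be made more concrete by breaking the remote volume down per (source GPU, destination GPU) pair, which is what an all-to-all actually moves. The sketch below is a hypothetical extension, not part of the reference solution: gpu_traffic_matrix and parallel_link_time_ms are illustrative names, and the second function replaces the single-shared-link assumption with a different assumption, namely that every GPU has its own egress link at bandwidth_gbps, so the busiest sender sets the time.

```python
from collections import defaultdict


def gpu_traffic_matrix(
    token_assignments: list[list[int]],
    expert_to_gpu: dict[int, int],
    token_to_gpu: list[int],
    token_hidden_bytes: int,
) -> dict[tuple[int, int], int]:
    """Bytes sent from each source GPU to each destination GPU during dispatch.

    Uses the same counting rule as expert_parallel_routing_cost: one
    hidden-state copy per remote (token, expert) pair.
    """
    traffic: dict[tuple[int, int], int] = defaultdict(int)
    for t_idx, experts in enumerate(token_assignments):
        src = token_to_gpu[t_idx]
        for expert_id in experts:
            dst = expert_to_gpu[expert_id]
            if dst != src:  # local pairs never touch the interconnect
                traffic[(src, dst)] += token_hidden_bytes
    return dict(traffic)


def parallel_link_time_ms(
    traffic: dict[tuple[int, int], int], bandwidth_gbps: float
) -> float:
    """Dispatch time assuming each GPU has its own egress link (hypothetical model).

    All senders transmit concurrently, so the GPU with the most outgoing
    bytes bounds the exchange.
    """
    egress: dict[int, int] = defaultdict(int)
    for (src, _dst), nbytes in traffic.items():
        egress[src] += nbytes
    busiest = max(egress.values(), default=0)
    return busiest / (bandwidth_gbps * 1e9 / 8) * 1000


# Hypothetical example: 4 tokens, top-2 routing, 4 experts split over 2 GPUs.
assignments = [[0, 2], [1, 3], [2, 3], [0, 1]]
experts_gpu = {0: 0, 1: 0, 2: 1, 3: 1}
tokens_gpu = [0, 0, 1, 1]
matrix = gpu_traffic_matrix(assignments, experts_gpu, tokens_gpu, 8192)
print(matrix)  # {(0, 1): 16384, (1, 0): 16384}
print(parallel_link_time_ms(matrix, 100.0))
```

In this example each GPU sends 16384 bytes to the other, so the per-link estimate is half of what dividing the full 32768 remote bytes by one link's bandwidth gives. Which model is closer depends on the interconnect topology, which the problem does not specify.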

Complexity

  • Time: O(n * k), where n is the number of tokens and k is the number of experts selected per token (top-k)
  • Space: O(1) beyond the input