Chunked Prefill Scheduling Alongside Decode

#441 · MLOps · Medium

Problem

Implement a chunked prefill scheduler that interleaves prefill chunks with ongoing decode steps. Given a list of pending prefill requests (each with a prompt length), the number of active decode requests, a per-step token budget, and a prefill chunk size, simulate scheduling across multiple steps. Return the step-by-step schedule: for each step, which prefill chunks were processed and how many decode tokens were generated.
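The statement leaves the record formats implicit. The solution reads id and prompt_length from each request and returns one record per step with step, prefill_chunks, and decode_tokens. A minimal sketch of those shapes as TypedDicts follows; the class names and the integer ids are assumptions for illustration, not part of the problem.

```python
from typing import TypedDict


class PrefillRequest(TypedDict):
    id: int             # request identifier (int is an assumption; any hashable works)
    prompt_length: int  # number of prompt tokens to prefill


class PrefillChunk(TypedDict):
    id: int
    tokens_processed: int  # chunk_size, or fewer for the tail of a prompt


class StepRecord(TypedDict):
    step: int
    prefill_chunks: list[PrefillChunk]
    decode_tokens: int  # one token per active decode, capped by the budget


# Example input: two prompts waiting to be prefilled.
requests: list[PrefillRequest] = [
    {"id": 0, "prompt_length": 300},
    {"id": 1, "prompt_length": 50},
]

# Step 0 with token_budget=300, chunk_size=128 and 2 active decodes:
# 2 decode tokens leave 298; request 0 takes a full 128-token chunk (170 left),
# and request 1's whole 50-token prompt still fits.
step0: StepRecord = {
    "step": 0,
    "prefill_chunks": [
        {"id": 0, "tokens_processed": 128},
        {"id": 1, "tokens_processed": 50},
    ],
    "decode_tokens": 2,
}
```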

Solution

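def chunked_prefill_schedule(
    prefill_requests: list[dict],
    num_active_decodes: int,
    token_budget: int,
    chunk_size: int,
    max_steps: int = 1000,
) -> list[dict]:
    if token_budget <= 0 or chunk_size <= 0:
        raise ValueError("token_budget and chunk_size must be positive")

    # FIFO queue of prompts, tracking how many prompt tokens each still needs.
    prefill_queue = [
        {"id": req["id"], "remaining": req["prompt_length"]}
        for req in prefill_requests
    ]

    schedule = []
    step = 0

    # Decodes have no output length in this model, so they never finish on their
    # own; the simulation ends once every prompt has been prefilled. max_steps
    # guards against stalls (e.g. decodes consuming the entire budget).
    while prefill_queue and step < max_steps:
        step_info = {"step": step, "prefill_chunks": [], "decode_tokens": 0}
        budget_left = token_budget

        # 1. Decodes first: one token per active decode, capped by the budget.
        decode_tokens = min(num_active_decodes, budget_left)
        step_info["decode_tokens"] = decode_tokens
        budget_left -= decode_tokens

        # 2. Fill the rest of the budget with at most one chunk per queued prompt.
        completed_prefills = []
        for i, req in enumerate(prefill_queue):
            # A chunk is chunk_size tokens, or fewer for the tail of a prompt.
            actual_chunk = min(chunk_size, req["remaining"])
            if actual_chunk > budget_left:
                break  # preserve FIFO order: stop at the first chunk that does not fit
            step_info["prefill_chunks"].append({
                "id": req["id"],
                "tokens_processed": actual_chunk,
            })
            req["remaining"] -= actual_chunk
            budget_left -= actual_chunk
            if req["remaining"] <= 0:
                completed_prefills.append(i)
                num_active_decodes += 1  # starts decoding on the next step

        # 3. Drop finished prompts; popping in reverse keeps earlier indices valid.
        for idx in reversed(completed_prefills):
            prefill_queue.pop(idx)

        schedule.append(step_info)
        step += 1

    return schedule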

Explanation

  1. Each step has a fixed token budget shared between decode tokens (one per active decode request) and prefill chunks.
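  2. Decode tokens are allocated first because they are latency-sensitive. Each active decode request needs exactly one token per step; if there are more active decodes than token_budget, only token_budget of them are served that step.
  3. The remaining budget is filled with prefill chunks taken from the queue in FIFO order, at most one chunk per request per step. Each chunk is chunk_size tokens (or fewer for the tail of a prompt), and scheduling stops at the first chunk that does not fit in the remaining budget.
  4. When a prefill request is fully processed, it transitions to an active decode, incrementing the decode count.
  5. This interleaving bounds the work in every step by token_budget, so a long prompt cannot stall ongoing generation: active decodes keep producing one token per step while new prompts make steady progress toward their first token.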
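The rules in the list above can be checked mechanically. The sketch below is a hypothetical helper (validate_schedule is not part of the problem) that raises ValueError when a schedule overspends the per-step budget, emits a chunk larger than chunk_size, gives one request two chunks in the same step, or fails to prefill every prompt exactly once in full. It assumes prompt lengths are passed as a dict from request id to length.

```python
def validate_schedule(
    schedule: list[dict],
    prompt_lengths: dict,
    token_budget: int,
    chunk_size: int,
) -> None:
    """Raise ValueError if the schedule violates a chunked-prefill rule."""
    prefilled = {rid: 0 for rid in prompt_lengths}
    for entry in schedule:
        step = entry["step"]
        chunks = entry["prefill_chunks"]
        ids = [c["id"] for c in chunks]
        # At most one chunk per request per step.
        if len(ids) != len(set(ids)):
            raise ValueError(f"step {step}: a request received more than one chunk")
        for c in chunks:
            if c["id"] not in prefilled:
                raise ValueError(f"step {step}: unknown request {c['id']}")
            # Chunks never exceed chunk_size (prompt tails may be smaller).
            if c["tokens_processed"] > chunk_size:
                raise ValueError(f"step {step}: chunk for {c['id']} exceeds chunk_size")
            prefilled[c["id"]] += c["tokens_processed"]
        # Decodes and prefill chunks share the per-step budget.
        used = entry["decode_tokens"] + sum(c["tokens_processed"] for c in chunks)
        if used > token_budget:
            raise ValueError(f"step {step}: used {used} tokens, budget is {token_budget}")
    # Every prompt must be prefilled exactly once, in full.
    if prefilled != prompt_lengths:
        raise ValueError(f"prefilled {prefilled}, expected {prompt_lengths}")


# Hand-worked trace: prompts of 300 and 50 tokens, token_budget=300,
# chunk_size=128, 2 active decodes at the start.
example_schedule = [
    {"step": 0, "prefill_chunks": [{"id": 0, "tokens_processed": 128},
                                   {"id": 1, "tokens_processed": 50}], "decode_tokens": 2},
    {"step": 1, "prefill_chunks": [{"id": 0, "tokens_processed": 128}], "decode_tokens": 3},
    {"step": 2, "prefill_chunks": [{"id": 0, "tokens_processed": 44}], "decode_tokens": 3},
]
validate_schedule(example_schedule, {0: 300, 1: 50}, token_budget=300, chunk_size=128)
```

In the trace, request 1 finishes in step 0, so the decode count rises to 3 from step 1 onward; request 0 needs three steps because it receives at most one 128-token chunk per step.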

Complexity

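  • Time: O(S · n), where S is the number of steps and n is the number of prefill requests: each step does O(1) work for decodes (a single min) and at most one pass over the prefill queue. Removing finished requests with list.pop adds O(n) per completion, O(n²) in the worst case over the whole run.
  • Space: O(n) for the prefill queue, plus the returned schedule (one entry per step and one record per scheduled chunk).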
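The time bound depends on the step count S, which the simulation only learns by running. A cheap lower bound on S follows from two properties of the scheduler: a request receives at most one chunk per step, and because decodes are served first and their count never decreases, no step has more room for prefill than the first. The helper below is a sketch of that bound; min_steps_lower_bound and ceil_div are hypothetical names, not part of the problem, and the bound is not tight in general because the FIFO stopping rule can leave budget unused.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling of a / b for a >= 0, b > 0."""
    return -(-a // b)


def min_steps_lower_bound(
    prompt_lengths: list[int],
    num_active_decodes: int,
    token_budget: int,
    chunk_size: int,
) -> int:
    """Lower bound on the number of steps S before every prompt is prefilled.

    Assumes the scheduling rules of chunked_prefill_schedule: at most one chunk
    per request per step, and decodes served first with a non-decreasing count.
    """
    total = sum(prompt_lengths)
    if total == 0:
        return 0
    # Each request advances by at most chunk_size tokens per step.
    per_request = max(ceil_div(n, chunk_size) for n in prompt_lengths)
    # Decodes only ever grow, so no step has more prefill room than step 0.
    prefill_room = token_budget - min(num_active_decodes, token_budget)
    if prefill_room == 0:
        raise ValueError("decodes fill the whole budget; prefills can never progress")
    by_budget = ceil_div(total, prefill_room)
    return max(per_request, by_budget)


print(min_steps_lower_bound([300, 50], num_active_decodes=2, token_budget=300, chunk_size=128))
# prints 3
```

For two prompts of 300 and 50 tokens with token_budget=300, chunk_size=128 and 2 active decodes, the per-request limit dominates: the 300-token prompt needs three 128-token chunks and gets at most one per step, while the budget limit alone would allow ceil(350 / 298) = 2 steps.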