Implement a chunked prefill scheduler that interleaves prefill chunks with ongoing decode steps. Given a set of pending prefill requests (each with a prompt length), a number of active decode requests, a per-step token budget, and a prefill chunk size, simulate scheduling across multiple steps; each prefill chunk is chunk_size tokens (or smaller for the tail of a prompt). Return the step-by-step schedule showing which prefill chunks and how many decode tokens were processed in each step.
def chunked_prefill_schedule(
    prefill_requests: list[dict],
    num_active_decodes: int,
    token_budget: int,
    chunk_size: int,
    *,
    max_steps: int = 1001,
) -> list[dict]:
    """Simulate a chunked-prefill scheduler that interleaves prefill with decode.

    Each step has ``token_budget`` tokens to spend. Decode requests are served
    first (one token each, up to the budget). The remaining budget goes to
    prefill requests in FIFO order, at most one chunk per request per step.
    A chunk is ``chunk_size`` tokens (or smaller for the tail of a prompt).
    If a request's chunk does not fit in the leftover budget, it blocks the
    requests queued behind it for that step, so later requests never jump
    ahead. A prefill that finishes becomes an active decode starting with
    the next step.

    Decode requests carry no length here, so they are modelled as running
    indefinitely. The simulation therefore ends as soon as every prefill has
    completed, and an empty ``prefill_requests`` yields an empty schedule.

    Args:
        prefill_requests: Pending prefills, each a dict with ``"id"`` and
            ``"prompt_length"`` keys. The dicts are not modified.
        num_active_decodes: Decode requests already running at step 0.
        token_budget: Tokens that may be processed per step.
        chunk_size: Maximum prefill tokens taken from one request per step.
        max_steps: Safety cap on the number of steps simulated. It only
            matters when prefills cannot make progress (e.g. decodes consume
            the whole budget, or a chunk never fits). The default allows
            step indices 0 through 1000.

    Returns:
        One dict per step, shaped ``{"step": int, "prefill_chunks":
        [{"id": ..., "tokens_processed": int}, ...], "decode_tokens": int}``.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if ``token_budget``,
            ``num_active_decodes`` or any ``prompt_length`` is negative.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if token_budget < 0:
        raise ValueError(f"token_budget must be non-negative, got {token_budget}")
    if num_active_decodes < 0:
        raise ValueError(
            f"num_active_decodes must be non-negative, got {num_active_decodes}"
        )

    # Private working copies, so the caller's request dicts stay untouched.
    queue = []
    for req in prefill_requests:
        if req["prompt_length"] < 0:
            raise ValueError(
                f"prompt_length must be non-negative, got {req['prompt_length']}"
            )
        queue.append({"id": req["id"], "remaining": req["prompt_length"]})

    schedule = []
    step = 0
    while queue and step < max_steps:
        decode_tokens = min(num_active_decodes, token_budget)
        budget_left = token_budget - decode_tokens
        chunks = []
        still_pending = []
        finished = 0
        for pos, req in enumerate(queue):
            chunk = min(chunk_size, req["remaining"])
            # Compare against the actual chunk rather than chunk_size, so a
            # short prompt tail can use leftover budget.
            if chunk > budget_left:
                # Head-of-line blocking keeps service FIFO: everything from
                # here on waits for the next step, in its original order.
                still_pending.extend(queue[pos:])
                break
            chunks.append({"id": req["id"], "tokens_processed": chunk})
            req["remaining"] -= chunk
            budget_left -= chunk
            if req["remaining"] <= 0:
                finished += 1
            else:
                still_pending.append(req)
        schedule.append(
            {"step": step, "prefill_chunks": chunks, "decode_tokens": decode_tokens}
        )
        queue = still_pending
        # This step's decode count is already recorded, so prefills that
        # finished here start decoding on the next step.
        num_active_decodes += finished
        step += 1
    return schedule