#425 · Inference · Medium
Solve on deep-ml.com. Simulate a continuous batching (in-flight batching) system for LLM inference. Unlike static batching, where every sequence in a batch must finish before new ones can start, continuous batching lets a new request join the batch as soon as a slot opens. Given a stream of requests with arrival times and output lengths, simulate the system and compute its throughput.
def _admit_arrivals(
    queue: "deque[dict]",
    active: list[dict],
    max_batch_size: int,
    now: float,
) -> None:
    """Move requests that have arrived by *now* from the front of *queue* into *active*.

    Admission stops as soon as the batch is full or the head of *queue* has
    not arrived yet. Because *queue* is sorted by arrival time, no later
    request can have arrived either. Each admitted request records *now* as
    its ``start_time``.
    """
    while (
        queue
        and len(active) < max_batch_size
        and queue[0]["arrival_time"] <= now
    ):
        r = queue.popleft()
        active.append({
            "id": r["id"],
            "tokens_remaining": r["num_tokens"],
            "start_time": now,
            "arrival_time": r["arrival_time"],
        })


def continuous_batching_sim(
    requests: list[dict],
    max_batch_size: int,
    time_per_token: float
) -> dict:
    """Simulate continuous (in-flight) batching of decode requests.

    Time advances in discrete decode steps of ``time_per_token``. At each
    step boundary, requests whose ``arrival_time`` has been reached are
    admitted, in arrival order, into any free batch slots. Every active
    request then produces one token. A request leaves the batch after the
    step that exhausts its ``num_tokens``, and its slot can be refilled at
    that same timestamp. When the batch is empty and nothing has arrived
    yet, the clock jumps to the next arrival.

    Args:
        requests: Request dicts of the form
            ``{"id": int, "arrival_time": float, "num_tokens": int}``.
            Neither the list nor its dicts are modified.
        max_batch_size: Maximum number of concurrently active requests;
            must be positive.
        time_per_token: Simulated duration of one decode step.

    Returns:
        A dict with:

        - ``completed``: number of finished requests.
        - ``total_tokens_generated``: sum of ``num_tokens`` over all requests.
        - ``total_time``: simulated clock when the last request finished.
          It is measured from t=0, so it includes idle time before the
          first arrival. It is 1 if the clock never became positive, which
          avoids division by zero.
        - ``throughput_tps``: ``total_tokens_generated / total_time``.
        - ``avg_latency``: mean of (finish time - arrival time).
        - ``avg_queue_time``: mean of (admission time - arrival time).

        Float values are rounded to 4 decimals. Averages are 0 when no
        request completed.

    Raises:
        ValueError: If ``max_batch_size`` is not positive. In that case no
            request could ever be admitted, so the simulation would never
            terminate.

    Note:
        A request with ``num_tokens <= 0`` still occupies one decode step.
    """
    from collections import deque

    if max_batch_size <= 0:
        raise ValueError(f"max_batch_size must be positive, got {max_batch_size!r}")

    # Each request: {"id": int, "arrival_time": float, "num_tokens": int}
    requests = sorted(requests, key=lambda r: r["arrival_time"])
    queue = deque(requests)  # O(1) popleft for admissions
    active: list[dict] = []  # in-flight requests, in admission order
    completed: list[dict] = []

    # The clock is always segment_start + steps * time_per_token. Recomputing
    # it by multiplication, instead of repeatedly adding time_per_token,
    # keeps float error from accumulating. Accumulated error would make a
    # request arriving exactly on a step boundary (e.g. t=1.0 with 0.1 s
    # steps) miss that boundary and wait an extra step.
    segment_start = 0.0
    steps = 0
    current_time = 0.0

    while queue or active:
        _admit_arrivals(queue, active, max_batch_size, current_time)

        if not active:
            # Idle, and nothing admissible has arrived. Admission only pops
            # when it fills a slot, so queue is still non-empty here. Jump
            # the clock to the next arrival and start a new time segment.
            segment_start = queue[0]["arrival_time"]
            steps = 0
            current_time = segment_start
            continue

        # One decode step: every active request emits one token.
        steps += 1
        current_time = segment_start + steps * time_per_token
        still_running = []
        for req in active:
            req["tokens_remaining"] -= 1
            if req["tokens_remaining"] <= 0:
                completed.append({
                    "id": req["id"],
                    "latency": current_time - req["arrival_time"],
                    "time_in_queue": req["start_time"] - req["arrival_time"],
                })
            else:
                still_running.append(req)
        active = still_running
        # Freed slots are refilled at the top of the next iteration, at this
        # same current_time.

    total_tokens = sum(r["num_tokens"] for r in requests)
    total_time = current_time if current_time > 0 else 1
    n_done = len(completed)
    avg_latency = sum(c["latency"] for c in completed) / n_done if n_done else 0
    avg_queue_time = sum(c["time_in_queue"] for c in completed) / n_done if n_done else 0
    return {
        "completed": n_done,
        "total_tokens_generated": total_tokens,
        "total_time": round(total_time, 4),
        "throughput_tps": round(total_tokens / total_time, 4),
        "avg_latency": round(avg_latency, 4),
        "avg_queue_time": round(avg_queue_time, 4),
    }