#452 · Machine Learning · Medium
⊣ Solve on deep-ml.com
Compare the throughput of continuous batching versus static batching for an LLM inference server. Given a list of request sequence lengths, a maximum batch size, and per-token latency, compute the total tokens processed per second under each batching strategy.
def batching_throughput(
    seq_lengths: list[int],
    max_batch_size: int,
    per_token_latency_ms: float
) -> dict[str, float]:
    """Compare token throughput of static and continuous batching.

    Static batching groups requests, in the order given, into chunks of at
    most ``max_batch_size``; each chunk runs for as long as its longest
    request (shorter requests are effectively padded).

    Continuous batching is modelled as ``max_batch_size`` parallel slots:
    requests are taken longest first and each one is placed on the slot that
    becomes free earliest, so the total time is the finish time of the
    busiest slot.

    Args:
        seq_lengths: Number of tokens in each request.
        max_batch_size: Maximum number of requests processed together.
        per_token_latency_ms: Time to process one token step, in milliseconds.

    Returns:
        A dict with keys ``"static_throughput"`` and
        ``"continuous_throughput"``, each in tokens per second rounded to two
        decimals. A strategy whose total time is not positive (for example
        when there are no requests) reports ``0.0``.

    Raises:
        ValueError: If ``max_batch_size`` is not a positive number.
    """
    import heapq

    if max_batch_size <= 0:
        raise ValueError("max_batch_size must be a positive integer")

    # No requests means no work: report zero instead of failing on max([]).
    if not seq_lengths:
        return {"static_throughput": 0.0, "continuous_throughput": 0.0}

    n = len(seq_lengths)
    per_token_s = per_token_latency_ms / 1000.0
    total_tokens = sum(seq_lengths)

    # Static batching: every batch lasts as long as its longest request.
    static_total_time = 0.0
    for start in range(0, n, max_batch_size):
        batch = seq_lengths[start:start + max_batch_size]
        static_total_time += max(batch) * per_token_s

    # Continuous batching: each request occupies a slot only for its own
    # length, so there is no padding waste.
    if n <= max_batch_size:
        # Every request gets its own slot; the longest one sets the time.
        continuous_total_time = max(seq_lengths) * per_token_s
    else:
        # Longest-first greedy assignment onto the earliest-free slot.
        # A min-heap keeps the earliest-free slot at index 0.
        slots = [0.0] * max_batch_size
        for length in sorted(seq_lengths, reverse=True):
            heapq.heapreplace(slots, slots[0] + length * per_token_s)
        continuous_total_time = max(slots)

    static_throughput = (
        total_tokens / static_total_time if static_total_time > 0 else 0.0
    )
    continuous_throughput = (
        total_tokens / continuous_total_time
        if continuous_total_time > 0
        else 0.0
    )

    return {
        "static_throughput": round(static_throughput, 2),
        "continuous_throughput": round(continuous_throughput, 2),
    }