"""Implement request batching for ML model inference.

Collect incoming requests over a short time window or until a batch is full,
then process them together for efficient GPU utilization.
"""
import numpy as np
import time
from collections import deque
class InferenceBatcher:
    """Queue single inference requests and run them through a model in batches.

    Requests are appended to a FIFO queue by ``add_request``. A batch is due
    (see ``should_process``) once the queue holds ``max_batch_size`` requests
    or the oldest queued request has waited at least ``max_wait_ms``
    milliseconds. ``process_batch`` then stacks up to ``max_batch_size``
    inputs into one array and makes a single ``model_fn`` call.

    The class performs no locking and never triggers ``process_batch`` on its
    own; the caller is expected to poll ``should_process``.
    """

    def __init__(self, model_fn, max_batch_size: int = 32,
                 max_wait_ms: float = 50.0):
        """Create an empty batcher.

        Args:
            model_fn: Callable invoked with the stacked batch array; its
                result is indexed by batch position to get per-request output.
            max_batch_size: Largest number of requests sent to ``model_fn``
                in one call. Must be at least 1.
            max_wait_ms: How long, in milliseconds, the oldest request may
                sit in the queue before a partial batch becomes due.

        Raises:
            ValueError: If ``max_batch_size`` is less than 1.
        """
        # A non-positive size would make process_batch pop nothing and then
        # fail inside np.stack; reject it up front with a clear message.
        if max_batch_size < 1:
            raise ValueError(
                f"max_batch_size must be >= 1, got {max_batch_size}"
            )
        self.model_fn = model_fn
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()

    def add_request(self, input_data: np.ndarray, request_id: int) -> None:
        """Append one request to the back of the queue.

        Args:
            input_data: Input array for this request. It is later combined
                with the other queued inputs via ``np.stack``.
            request_id: Identifier echoed back in the matching result dict.
        """
        self.queue.append({
            "id": request_id,
            "input": input_data,
            # Monotonic clock reading (not wall-clock time): it is only used
            # to measure how long the request has waited, and must not be
            # affected by system clock adjustments.
            "timestamp": time.monotonic(),
        })

    def should_process(self) -> bool:
        """Return True when a batch is due.

        A batch is due when the queue is non-empty and either it already
        holds ``max_batch_size`` requests or the oldest request has waited
        at least ``max_wait_ms`` milliseconds.
        """
        if not self.queue:
            return False
        if len(self.queue) >= self.max_batch_size:
            return True
        # The queue is FIFO, so the head is always the longest-waiting item.
        waited_ms = (time.monotonic() - self.queue[0]["timestamp"]) * 1000
        return waited_ms >= self.max_wait_ms

    def process_batch(self) -> list[dict]:
        """Run the model on up to ``max_batch_size`` queued requests.

        Requests are removed from the front of the queue before ``model_fn``
        is called, so they are not re-queued if stacking or the model call
        raises.

        Returns:
            One ``{"id": ..., "output": ...}`` dict per processed request, in
            queue order. An empty list if the queue is empty.
        """
        # np.stack rejects an empty sequence; nothing queued means no work.
        if not self.queue:
            return []
        batch_size = min(len(self.queue), self.max_batch_size)
        batch_items = [self.queue.popleft() for _ in range(batch_size)]
        inputs = np.stack([item["input"] for item in batch_items])
        outputs = self.model_fn(inputs)
        # Position i of the model output belongs to the i-th popped request.
        return [
            {"id": item["id"], "output": outputs[i]}
            for i, item in enumerate(batch_items)
        ]
def batch_inference(model_fn, inputs: list[np.ndarray],
                    batch_size: int = 32) -> list[np.ndarray]:
    """Run ``model_fn`` over ``inputs`` in consecutive fixed-size batches.

    Each slice of up to ``batch_size`` inputs is combined with ``np.stack``
    and passed to ``model_fn`` in one call; the per-item outputs of every
    call are collected into a single flat list in input order. The final
    batch may be smaller than ``batch_size``.

    Args:
        model_fn: Callable invoked with each stacked batch array; its result
            is iterated to obtain one output per batch item.
        inputs: Input arrays to process.
        batch_size: Maximum number of inputs per ``model_fn`` call. Must be
            at least 1.

    Returns:
        The outputs of all batches, flattened into one list. Empty if
        ``inputs`` is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A negative step would make the range below empty and silently drop
    # every input; zero would fail inside range() with an unrelated message.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = []
    for start in range(0, len(inputs), batch_size):
        batch = np.stack(inputs[start:start + batch_size])
        results.extend(model_fn(batch))
    return results