
Implement Request Batching for Inference

#297 · MLOps · Medium


Problem

Implement request batching for ML model inference. Collect incoming requests over a short time window or until a batch is full, then process them together for efficient GPU utilization.

Solution

import time
from collections import deque

import numpy as np


class InferenceBatcher:
    """Queues requests and runs them through the model in batches."""

    def __init__(self, model_fn, max_batch_size: int = 32,
                 max_wait_ms: float = 50.0):
        self.model_fn = model_fn          # callable: (B, ...) array -> (B, ...) array
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()              # FIFO of pending requests

    def add_request(self, input_data: np.ndarray, request_id: int):
        # time.monotonic() is unaffected by system clock adjustments,
        # so measured waiting times cannot jump backwards or forwards.
        self.queue.append({"id": request_id, "input": input_data,
                           "timestamp": time.monotonic()})

    def should_process(self) -> bool:
        if not self.queue:
            return False
        # Flush trigger 1: a full batch is ready.
        if len(self.queue) >= self.max_batch_size:
            return True
        # Flush trigger 2: the oldest request has waited long enough.
        oldest = self.queue[0]["timestamp"]
        elapsed_ms = (time.monotonic() - oldest) * 1000
        return elapsed_ms >= self.max_wait_ms

    def process_batch(self) -> list[dict]:
        if not self.queue:
            return []                     # np.stack([]) would raise ValueError
        batch_size = min(len(self.queue), self.max_batch_size)
        batch_items = [self.queue.popleft() for _ in range(batch_size)]

        # All inputs must share one shape; np.stack adds the batch axis.
        inputs = np.stack([item["input"] for item in batch_items])
        outputs = self.model_fn(inputs)   # one model call for the whole batch

        # Scatter: row i of the output belongs to the i-th dequeued request.
        return [{"id": item["id"], "output": outputs[i]}
                for i, item in enumerate(batch_items)]


def batch_inference(model_fn, inputs: list[np.ndarray],
                    batch_size: int = 32) -> list[np.ndarray]:
    results = []
    # Fixed-size chunks; the final chunk may be smaller than batch_size.
    for i in range(0, len(inputs), batch_size):
        batch = np.stack(inputs[i:i + batch_size])
        batch_output = model_fn(batch)
        results.extend(batch_output)      # iterating an array yields its rows (axis 0)
    return results
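InferenceBatcher leaves the flushing loop to its caller, which has to keep polling should_process. A common alternative design is a background thread that forms batches on its own while each caller blocks on a handle for its result. The sketch below is a hypothetical variant, not part of the original solution: the names ThreadedBatcher, PendingResult and submit are assumptions, and it uses the standard-library queue.Queue and threading.Event. It applies the same two flush rules, a full batch or an expired wait window.

```python
import queue
import threading
import time

import numpy as np


class PendingResult:
    """Hypothetical handle a caller blocks on until its batch has run."""

    def __init__(self):
        self._done = threading.Event()
        self._value = None
        self._error = None

    def _set(self, value=None, error=None):
        self._value, self._error = value, error
        self._done.set()

    def result(self, timeout=None):
        if not self._done.wait(timeout):
            raise TimeoutError("batch result not ready")
        if self._error is not None:
            raise self._error
        return self._value


class ThreadedBatcher:
    """Hypothetical variant: a daemon thread forms and runs the batches."""

    def __init__(self, model_fn, max_batch_size: int = 32,
                 max_wait_ms: float = 50.0):
        self.model_fn = model_fn
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_ms / 1000.0
        self._requests = queue.Queue()  # thread-safe FIFO of (input, handle)
        threading.Thread(target=self._run, daemon=True).start()

    def submit(self, input_data: np.ndarray) -> PendingResult:
        handle = PendingResult()
        self._requests.put((input_data, handle))
        return handle

    def _run(self):
        while True:
            # Block until the first request of the next batch arrives.
            batch = [self._requests.get()]
            deadline = time.monotonic() + self.max_wait_s
            # Collect more until the batch is full or the window closes.
            while len(batch) < self.max_batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    batch.append(self._requests.get(timeout=remaining))
                except queue.Empty:
                    break
            try:
                outputs = self.model_fn(np.stack([x for x, _ in batch]))
            except Exception as exc:
                # Fail every request in the batch so no caller waits forever.
                for _, handle in batch:
                    handle._set(error=exc)
                continue
            # Scatter: row i of the output goes to the i-th request.
            for (_, handle), out in zip(batch, outputs):
                handle._set(value=out)


if __name__ == "__main__":
    batcher = ThreadedBatcher(lambda x: x * 2.0, max_batch_size=4,
                              max_wait_ms=20.0)
    handles = [batcher.submit(np.full(3, float(i))) for i in range(10)]
    print([h.result(timeout=1.0).tolist() for h in handles])
```

One difference from the polling version: here the wait window starts when the worker dequeues the first request of a batch, not when that request arrived. A request that queued up while the previous batch was running can therefore wait longer than max_wait_ms in total.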

Explanation

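  1. InferenceBatcher collects requests in a FIFO queue and processes them once the batch is full or the oldest request has waited max_wait_ms. It does not flush on its own: the caller must poll should_process and call process_batch when it returns True.
  2. should_process returns True if the queue holds at least max_batch_size requests or the oldest request has waited at least max_wait_ms, and False for an empty queue.
  3. process_batch dequeues up to max_batch_size requests, stacks their inputs into a single array (np.stack requires every input to have the same shape), runs the model once on that array, and maps row i of the output back to the i-th request's id.
  4. batch_inference is a simpler synchronous utility that splits a list of inputs into fixed-size batches; the last batch may be smaller than batch_size.
  5. Batching amortizes fixed per-call costs (such as kernel launches and framework dispatch) over many inputs and improves GPU utilization by processing those inputs in parallel.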
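Step 3 is only correct if the model treats each row of the batch independently, so that row i of the batched output equals what request i would get on its own. The check below uses a hypothetical linear model (the weights W and b, and the chunk size of 4, are arbitrary assumptions) and compares per-request outputs with batched outputs, including a final partial batch.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 3))  # hypothetical weights: 4 input features, 3 outputs
b = rng.standard_normal(3)


def model_fn(x: np.ndarray) -> np.ndarray:
    # Row-independent model: output row i depends only on input row i.
    return x @ W + b


requests = [rng.standard_normal(4) for _ in range(10)]

# Reference path: run each request alone as a batch of one.
single = [model_fn(r[np.newaxis, :])[0] for r in requests]

# Batched path: chunks of 4 (sizes 4, 4, 2), rows scattered back in order.
batched = []
for start in range(0, len(requests), 4):
    batched.extend(model_fn(np.stack(requests[start:start + 4])))

# Compare with a tolerance, not ==: a BLAS library may sum in a different
# order for different batch sizes, which can change the last few bits.
max_diff = max(np.abs(s - t).max() for s, t in zip(single, batched))
print(max_diff)
```

Layers that mix information across the batch dimension, such as batch normalization in training mode, would break this equivalence.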

Complexity

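  • Time: O(⌈n / B⌉ · T), where n is the total number of requests, B is the batch size, and T is the time of one batched model call; stacking the inputs adds O(n · d)
  • Space: O(B · d) for the stacked batch, where d is the input dimensionality, plus O(q · d) for the q requests still waiting in the queue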
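The time bound counts ⌈n / B⌉ model calls. The arithmetic below makes the amortization concrete under an assumed cost model in which each call pays a fixed overhead plus a per-input term. The function name total_time_ms and the 5 ms / 0.25 ms figures are illustrative assumptions, not measurements.

```python
import math


def total_time_ms(n: int, batch_size: int, overhead_ms: float,
                  per_item_ms: float) -> float:
    """Total time for n requests under an assumed affine per-call cost."""
    num_calls = math.ceil(n / batch_size)  # the ceil(n / B) batched model calls
    # Each call pays the fixed overhead once; per-item work is paid n times in total.
    return num_calls * overhead_ms + n * per_item_ms


# Hypothetical costs: 5 ms fixed per call, 0.25 ms per input.
print(total_time_ms(1000, 1, 5.0, 0.25))   # 5250.0 (one call per request)
print(total_time_ms(1000, 32, 5.0, 0.25))  # 410.0 (32 calls)
```

Real GPU cost curves are not exactly affine, and batching adds queueing latency (up to roughly max_wait_ms per request in InferenceBatcher), so max_batch_size and max_wait_ms trade throughput against latency.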