Calculate model inference statistics for monitoring. Given a list of inference latencies, compute key metrics: mean, median, p95, p99, throughput, and error rate.
def _nearest_rank(sorted_values: list[float], pct: float) -> float:
    """Return the nearest-rank percentile of an ascending, non-empty list.

    The nearest-rank percentile is the value at 1-based rank
    ``ceil(pct / 100 * N)``; the rank is clamped to ``[1, N]``.
    """
    import math

    count = len(sorted_values)
    # Multiply before dividing so exact ranks (e.g. 95% of 100 samples) are
    # not nudged across an integer boundary by floating-point error.
    rank = math.ceil(pct * count / 100)
    index = max(0, min(rank - 1, count - 1))
    return sorted_values[index]


def inference_statistics(
    latencies_ms: list[float],
    errors: int = 0,
    total_requests: int = 0,
) -> dict:
    """Summarize inference latencies for monitoring.

    Args:
        latencies_ms: Latencies of successful inferences, in milliseconds.
            The list may be unsorted; it is not modified.
        errors: Number of failed requests.
        total_requests: Total number of requests. When 0, it is inferred as
            ``len(latencies_ms) + errors``.

    Returns:
        A dict with the keys ``mean_ms``, ``median_ms``, ``p95_ms``,
        ``p99_ms``, ``throughput_rps`` and ``error_rate``. Latency and
        throughput values are rounded to 4 decimal places, the error rate to
        6. When ``latencies_ms`` is empty, every latency/throughput metric is
        0.0 and only the error rate is computed.

        ``throughput_rps`` is the number of successful requests divided by
        the *sum* of their latencies, i.e. it assumes the requests were
        processed one after another.
    """
    if total_requests == 0:
        total_requests = len(latencies_ms) + errors

    n = len(latencies_ms)
    if n == 0:
        return {
            "mean_ms": 0.0,
            "median_ms": 0.0,
            "p95_ms": 0.0,
            "p99_ms": 0.0,
            "throughput_rps": 0.0,
            # max(..., 1) guards against division by zero when there were
            # no requests at all.
            "error_rate": round(errors / max(total_requests, 1), 6),
        }

    sorted_lat = sorted(latencies_ms)
    total_ms = sum(sorted_lat)

    mean_ms = total_ms / n

    mid = n // 2
    if n % 2 == 1:
        median_ms = sorted_lat[mid]
    else:
        median_ms = (sorted_lat[mid - 1] + sorted_lat[mid]) / 2

    # Percentiles use the nearest-rank method (ceiling of the fractional
    # rank), so small samples report the tail value rather than one below it.
    p95 = _nearest_rank(sorted_lat, 95)
    p99 = _nearest_rank(sorted_lat, 99)

    # Throughput: requests per second over the summed latency.
    total_time_s = total_ms / 1000.0
    throughput = n / total_time_s if total_time_s > 0 else 0.0

    error_rate = errors / total_requests if total_requests > 0 else 0.0

    return {
        "mean_ms": round(mean_ms, 4),
        "median_ms": round(median_ms, 4),
        "p95_ms": round(p95, 4),
        "p99_ms": round(p99, 4),
        "throughput_rps": round(throughput, 4),
        "error_rate": round(error_rate, 6),
    }