
Implement Prediction Distribution Monitoring

#295 · MLOps · Medium

Source: deep-ml.com

Problem

Implement prediction distribution monitoring for a deployed ML model. Track the distribution of model predictions over time and detect drift by comparing against a reference distribution using statistical tests.

Solution

import numpy as np

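def compute_psi(reference: np.ndarray, current: np.ndarray, bins: int = 10) -> float:
    """Population Stability Index (PSI) of `current` relative to `reference`.

    Bin edges are equal-width over the reference range; the outer edges are
    opened to -inf/+inf so current values outside that range still land in
    the first or last bin.
    """
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    if reference.size == 0 or current.size == 0:
        raise ValueError("reference and current must be non-empty")

    breakpoints = np.linspace(np.min(reference), np.max(reference), bins + 1)
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf

    # Shared edges: both samples are binned identically
    ref_counts = np.histogram(reference, bins=breakpoints)[0]
    cur_counts = np.histogram(current, bins=breakpoints)[0]

    # Convert counts to the proportion of each sample per bin
    ref_pct = ref_counts / len(reference)
    cur_pct = cur_counts / len(current)

    # Floor empty bins at 1e-4 so the ratio and the log stay finite
    ref_pct = np.clip(ref_pct, 1e-4, None)
    cur_pct = np.clip(cur_pct, 1e-4, None)

    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return float(psi)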

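def monitor_predictions(reference: np.ndarray, current: np.ndarray,
                        psi_threshold: float = 0.2) -> dict:
    """Compare current predictions with the reference sample and flag drift."""
    psi = compute_psi(reference, current)
    # PSI above the threshold (default 0.2, the "significant" band) counts as drift
    drift_detected = psi > psi_threshold

    # Summary statistics for dashboards, reported alongside the PSI score
    ref_mean = float(np.mean(reference))
    cur_mean = float(np.mean(current))
    ref_std = float(np.std(reference))
    cur_std = float(np.std(current))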

    return {
        "psi": round(psi, 4),
        "drift_detected": drift_detected,
        "reference_stats": {"mean": ref_mean, "std": ref_std},
        "current_stats": {"mean": cur_mean, "std": cur_std},
    }
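
The problem statement also mentions statistical tests, while PSI is a binned divergence score judged against rule-of-thumb thresholds rather than a hypothesis test. One common test-based complement, not part of the solution above, is the two-sample Kolmogorov-Smirnov (KS) test. The sketch below computes the KS statistic with NumPy only and compares it with the standard large-sample critical value c(α)·sqrt((n + m) / (n·m)), where c(α) = sqrt(-ln(α/2) / 2). The helper names ks_statistic and ks_drift are hypothetical.

```python
import numpy as np

def ks_statistic(reference, current):
    """Two-sample KS statistic: largest gap between the two empirical CDFs."""
    ref = np.sort(np.asarray(reference, dtype=float))
    cur = np.sort(np.asarray(current, dtype=float))
    # The ECDF gap only changes at observed values, so evaluate it there
    grid = np.concatenate([ref, cur])
    ref_cdf = np.searchsorted(ref, grid, side="right") / ref.size
    cur_cdf = np.searchsorted(cur, grid, side="right") / cur.size
    return float(np.max(np.abs(ref_cdf - cur_cdf)))

def ks_drift(reference, current, alpha=0.05):
    """Hypothetical helper: reject "same distribution" at level alpha using
    the asymptotic (large-sample) KS critical value."""
    n, m = len(reference), len(current)
    d = ks_statistic(reference, current)
    c_alpha = np.sqrt(-np.log(alpha / 2) / 2)  # about 1.358 for alpha = 0.05
    critical = float(c_alpha * np.sqrt((n + m) / (n * m)))
    return {"ks_statistic": d, "critical_value": critical, "drift_detected": d > critical}

rng = np.random.default_rng(0)
reference = rng.normal(0.0, 1.0, 2000)
shifted = rng.normal(0.5, 1.0, 2000)   # predictions whose mean has moved
print(ks_drift(reference, shifted))
```

The asymptotic critical value is only reliable for reasonably large windows; for small samples or an explicit p-value, scipy.stats.ks_2samp offers exact and asymptotic modes.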

Explanation

  1. PSI (Population Stability Index) compares two distributions by binning their values and summing a symmetrized KL divergence over the bins: the formula in step 3 equals KL(cur% || ref%) + KL(ref% || cur%).
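  2. Bin the reference and current predictions with the same breakpoints: equal-width edges over the reference range, with the outer edges opened to -inf and +inf so out-of-range current predictions are still counted.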
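  3. Compute PSI = sum((cur% - ref%) * ln(cur% / ref%)), with empty bins floored at 1e-4 so the ratio and the log stay finite. Common rules of thumb read PSI < 0.1 as no significant drift, 0.1-0.2 as moderate drift, and > 0.2 as significant drift, which is why monitor_predictions defaults to psi_threshold = 0.2.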
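  4. Also report the mean and standard deviation of both samples for quick monitoring dashboards. They are cheap to plot, but a change in the shape of the distribution can leave them almost unchanged, so they complement PSI rather than replace it.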
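
A small worked example of step 3, using hand-picked bin proportions rather than real predictions. The helpers psi_from_proportions, kl, and psi_band are hypothetical names for this sketch; psi_band encodes the rule-of-thumb bands above, and the final assert checks the identity from step 1 that PSI equals the sum of the two KL directions.

```python
import numpy as np

def psi_from_proportions(ref_pct, cur_pct):
    """PSI for already-binned proportions (no zero entries, each sums to 1)."""
    ref_pct = np.asarray(ref_pct, dtype=float)
    cur_pct = np.asarray(cur_pct, dtype=float)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

def kl(p, q):
    """Kullback-Leibler divergence KL(p || q) for discrete proportions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

def psi_band(psi):
    """Hypothetical helper: map a PSI value to the rule-of-thumb bands."""
    if psi < 0.1:
        return "no significant drift"
    if psi <= 0.2:
        return "moderate drift"
    return "significant drift"

ref = [0.25, 0.25, 0.25, 0.25]   # reference: uniform over 4 bins
cur = [0.10, 0.20, 0.30, 0.40]   # current: mass shifted toward the upper bins

psi = psi_from_proportions(ref, cur)
print(round(psi, 4), psi_band(psi))  # prints: 0.2282 significant drift

# PSI is the symmetrized KL divergence: KL(cur || ref) + KL(ref || cur)
assert np.isclose(psi, kl(cur, ref) + kl(ref, cur))
```

Each term (cur% - ref%) * ln(cur% / ref%) is non-negative, because the two factors always share a sign, so PSI is zero only when the binned proportions match exactly.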

Complexity

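  • Time: O((n + m) · log bins) when each value is placed by binary search over the bin edges, i.e. O(n + m) for a fixed number of bins, where n and m are the sizes of the reference and current samples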
  • Space: O(bins) for the bin edges and counts, beyond the input arrays themselves
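
The problem asks for tracking predictions over time, and the O(bins) space bound hints at how: once the breakpoints are fixed from the reference, each new time window only needs its bin counts. Below is a minimal sketch, assuming predictions arrive in mini-batches and the caller decides when a window closes. The StreamingPSIMonitor class and its update, psi, and close_window methods are hypothetical; it mirrors compute_psi's equal-width edges and 1e-4 floor.

```python
import numpy as np

class StreamingPSIMonitor:
    """Hypothetical helper: PSI per time window from running bin counts."""

    def __init__(self, reference, bins=10, floor=1e-4):
        reference = np.asarray(reference, dtype=float)
        # Same edges as compute_psi: equal-width over the reference range,
        # with open-ended outer bins
        edges = np.linspace(reference.min(), reference.max(), bins + 1)
        edges[0], edges[-1] = -np.inf, np.inf
        self.edges = edges
        self.floor = floor
        ref_counts = np.histogram(reference, bins=edges)[0]
        self.ref_pct = np.clip(ref_counts / reference.size, floor, None)
        # Only these counts persist, so memory is O(bins) after construction
        self.cur_counts = np.zeros(bins, dtype=np.int64)

    def update(self, batch):
        """Add a mini-batch of predictions to the open window."""
        self.cur_counts += np.histogram(np.asarray(batch, dtype=float), bins=self.edges)[0]

    def psi(self):
        """PSI of the open window against the reference."""
        total = self.cur_counts.sum()
        if total == 0:
            raise ValueError("no predictions in the current window")
        cur_pct = np.clip(self.cur_counts / total, self.floor, None)
        return float(np.sum((cur_pct - self.ref_pct) * np.log(cur_pct / self.ref_pct)))

    def close_window(self):
        """Return the window's PSI and start a new, empty window."""
        value = self.psi()
        self.cur_counts[:] = 0
        return value

rng = np.random.default_rng(0)
monitor = StreamingPSIMonitor(rng.normal(0.0, 1.0, 5000))
window_psi = []
for shift in [0.0, 0.0, 1.0]:      # the third window drifts by one standard deviation
    for _ in range(10):             # ten mini-batches of 100 predictions per window
        monitor.update(rng.normal(shift, 1.0, 100))
    window_psi.append(monitor.close_window())
print([round(p, 3) for p in window_psi])
```

Each window's value can be compared against psi_threshold just as monitor_predictions does; here the first two windows come from the reference distribution and the third is shifted, so only the last should cross 0.2.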