
Feature Drift Detection using Population Stability Index

#253 · MLOps · Medium

⊣ Solve on deep-ml.com

Problem

Detect feature drift between a reference (training) distribution and a current (production) distribution using the Population Stability Index (PSI). Given two arrays of values, bucket them and compute PSI to quantify distribution shift.

Solution

Bin both distributions into equal-width buckets, compute the proportion in each bucket, and apply the PSI formula: sum of (p_current - p_ref) * ln(p_current / p_ref) across all bins.
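Written out for b bins, with p_i^cur and p_i^ref the proportions in bin i, the formula is shown below. Regrouping the terms gives the sum of the two KL divergences (the symmetric, or Jeffreys, divergence). Every term is non-negative because p_i^cur - p_i^ref and ln(p_i^cur / p_i^ref) always share a sign, so PSI is zero only when the binned proportions match.

```latex
\mathrm{PSI}
  = \sum_{i=1}^{b} \left(p_i^{\mathrm{cur}} - p_i^{\mathrm{ref}}\right)
    \ln\frac{p_i^{\mathrm{cur}}}{p_i^{\mathrm{ref}}}
  = \sum_{i=1}^{b} p_i^{\mathrm{cur}} \ln\frac{p_i^{\mathrm{cur}}}{p_i^{\mathrm{ref}}}
  + \sum_{i=1}^{b} p_i^{\mathrm{ref}} \ln\frac{p_i^{\mathrm{ref}}}{p_i^{\mathrm{cur}}}
  = D_{\mathrm{KL}}\!\left(p^{\mathrm{cur}} \,\|\, p^{\mathrm{ref}}\right)
  + D_{\mathrm{KL}}\!\left(p^{\mathrm{ref}} \,\|\, p^{\mathrm{cur}}\right)
```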

import math


def compute_psi(
    reference: list[float],
    current: list[float],
    n_bins: int = 10,
    epsilon: float = 1e-4,
) -> dict:
    if not reference or not current:
        raise ValueError("reference and current must both be non-empty")

    # Common equal-width bin edges spanning both samples.
    combined_min = min(min(reference), min(current))
    combined_max = max(max(reference), max(current))
    bin_edges = [
        combined_min + i * (combined_max - combined_min) / n_bins
        for i in range(n_bins + 1)
    ]

    def bucket_proportions(values):
        counts = [0] * n_bins
        for v in values:
            for i in range(n_bins):
                # Bins are [left, right); the last bin also takes whatever is
                # left over, so the maximum value is always counted (a fixed
                # offset like +1e-9 on the last edge is lost at large magnitudes).
                if bin_edges[i] <= v < bin_edges[i + 1] or i == n_bins - 1:
                    counts[i] += 1
                    break
        n = len(values)
        # Floor at epsilon so empty bins never cause division by zero or ln(0).
        return [max(c / n, epsilon) for c in counts]

    ref_props = bucket_proportions(reference)
    cur_props = bucket_proportions(current)

    psi = 0.0
    bucket_psis = []
    for p_cur, p_ref in zip(cur_props, ref_props):
        # Each term is >= 0: (p_cur - p_ref) and ln(p_cur / p_ref) share a sign.
        bucket_psi = (p_cur - p_ref) * math.log(p_cur / p_ref)
        bucket_psis.append(round(bucket_psi, 6))
        psi += bucket_psi

    psi = round(psi, 6)
    # Conventional rule-of-thumb thresholds.
    if psi < 0.1:
        drift_level = "no_drift"
    elif psi < 0.2:
        drift_level = "moderate_drift"
    else:
        drift_level = "significant_drift"

    return {"psi": psi, "drift_level": drift_level, "bucket_psis": bucket_psis}
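Equal-width bins can leave many buckets nearly empty when a feature is skewed. A common alternative in PSI practice takes the bin edges from quantiles of the reference sample, so each reference bucket holds roughly the same share. The sketch below assumes a simple nearest-rank quantile rule; `quantile_edges` and `psi_quantile` are hypothetical names, not part of the solution above. Looking buckets up with `bisect.bisect_right` on the interior edges also means production values outside the reference range land in the first or last bucket instead of being dropped.

```python
import bisect
import math


def quantile_edges(reference: list[float], n_bins: int = 10) -> list[float]:
    """Hypothetical helper: n_bins - 1 interior edges at reference quantiles."""
    s = sorted(reference)
    n = len(s)
    # Edge k sits at the k / n_bins empirical quantile (nearest-rank style).
    return [s[(k * n) // n_bins] for k in range(1, n_bins)]


def psi_quantile(
    reference: list[float],
    current: list[float],
    n_bins: int = 10,
    epsilon: float = 1e-4,
) -> float:
    """Hypothetical variant of PSI using reference-quantile bins."""
    edges = quantile_edges(reference, n_bins)

    def props(values: list[float]) -> list[float]:
        counts = [0] * n_bins
        for v in values:
            # bisect_right counts the interior edges <= v, which is the bin index
            # (always in 0 .. n_bins - 1, even for out-of-range values).
            counts[bisect.bisect_right(edges, v)] += 1
        # Same epsilon floor as the equal-width version.
        return [max(c / len(values), epsilon) for c in counts]

    p_ref, p_cur = props(reference), props(current)
    return sum((c - r) * math.log(c / r) for c, r in zip(p_cur, p_ref))
```

For example, comparing `list(range(100))` with itself gives a PSI of 0.0, while shifting every value up by 50 pushes it well above 0.2.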

Explanation

  1. Compute common bin edges spanning both distributions.
  2. Calculate the proportion of values in each bin for both reference and current.
  3. Floor every proportion at a small epsilon so that empty bins cannot cause a division by zero or ln(0).
  4. PSI for each bin: (p_cur - p_ref) * ln(p_cur / p_ref).
  5. Sum across bins. PSI < 0.1 indicates no significant drift, 0.1 ≤ PSI < 0.2 indicates moderate drift, and PSI ≥ 0.2 indicates significant drift; these are conventional rule-of-thumb thresholds.
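A two-bucket example makes the arithmetic in steps 4 and 5 concrete. The proportions below are made up for illustration: the reference splits 50/50 and the current sample 70/30.

```python
import math

# Hypothetical two-bucket proportions: reference 50/50, current 70/30.
p_ref = [0.5, 0.5]
p_cur = [0.7, 0.3]

# Per-bucket terms (p_cur - p_ref) * ln(p_cur / p_ref); each is non-negative.
terms = [(c - r) * math.log(c / r) for c, r in zip(p_cur, p_ref)]
psi = sum(terms)

print([round(t, 4) for t in terms], round(psi, 4))
```

The terms come to about 0.067 and 0.102, for a total PSI of roughly 0.169, which falls in the moderate band.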

Complexity

  • Time: O(n * b), where n is the total number of values in both arrays and b is the number of bins, since each value is compared against up to b bin ranges
  • Space: O(b) for the bin edges and bucket counts
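The inner loop over bins is what makes the time O(n * b). For equal-width bins the bucket can instead be computed directly from a value's offset, bringing the time down to O(n + b). The helper below is a hypothetical sketch of that idea, with `lo` and `hi` standing for the combined minimum and maximum. A value lying exactly on an interior edge may land in the neighboring bucket because of floating-point rounding, so counts can differ slightly from the edge-comparison version.

```python
def bucket_counts(values: list[float], lo: float, hi: float, n_bins: int) -> list[int]:
    """Hypothetical helper: equal-width bucket counts via direct index arithmetic."""
    counts = [0] * n_bins
    width = (hi - lo) / n_bins
    for v in values:
        # Index from the offset; a zero-width range puts everything in bin 0.
        i = int((v - lo) / width) if width > 0 else 0
        # Clamp so v == hi lands in the last bin instead of index n_bins.
        counts[min(max(i, 0), n_bins - 1)] += 1
    return counts
```

Dividing each count by the sample size and applying the same epsilon floor gives the proportions used in the PSI sum.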