Implement Precision-Recall Curve

#278 · Machine Learning · Medium

⊣ Solve on deep-ml.com

Problem

Implement the Precision-Recall Curve calculation. Given true binary labels and predicted probabilities, compute precision and recall at various thresholds.

Solution

Sort the predictions by score in descending order, then sweep the threshold downward, updating the TP and FP counts and recording precision and recall at each distinct score.

def precision_recall_curve(
    y_true: list[int],
    y_scores: list[float],
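) -> dict:
    """Return precision, recall, and thresholds, swept from the highest score down.

    The precision and recall lists are one element longer than thresholds
    because an extra starting point is placed at the front.
    """
    # zip() would silently truncate mismatched inputs and break the tie check below
    if len(y_true) != len(y_scores):
        raise ValueError("y_true and y_scores must have the same length")

    n = len(y_true)
    total_pos = sum(y_true)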

    if total_pos == 0:
        return {"precision": [0.0], "recall": [0.0], "thresholds": []}

    # Sort by score descending
    paired = sorted(zip(y_scores, y_true), key=lambda x: -x[0])

    precisions = []
    recalls = []
    thresholds = []

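    tp = 0  # running count of true positives among the samples seen so far
    fp = 0  # running count of false positives among the samples seen so far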

    for i, (score, label) in enumerate(paired):
        if label == 1:
            tp += 1
        else:
            fp += 1

        if i + 1 < n and paired[i + 1][0] == score:
            continue  # process ties together

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / total_pos

        precisions.append(round(precision, 6))
        recalls.append(round(recall, 6))
        thresholds.append(round(score, 6))

    # Add the start point (recall=0, precision=1)
    precisions.insert(0, 1.0)
    recalls.insert(0, 0.0)

    return {
        "precision": precisions,
        "recall": recalls,
        "thresholds": thresholds,
    }
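
As a sanity check on the sweep, the same points can be recomputed directly from the definitions, one threshold at a time. The sketch below is not part of the solution: `pr_points_bruteforce` is a hypothetical helper, and it assumes that a sample counts as predicted positive when its score is greater than or equal to the threshold and that at least one label is positive.

```python
def pr_points_bruteforce(y_true, y_scores):
    """Recompute PR points from the definitions, one threshold at a time.

    Hypothetical checking helper. Assumes a sample is predicted positive
    when score >= threshold, and that y_true contains at least one 1.
    """
    total_pos = sum(y_true)
    points = []
    # Visit each distinct score as a threshold, highest first.
    for t in sorted(set(y_scores), reverse=True):
        predicted = [s >= t for s in y_scores]
        tp = sum(1 for p, y in zip(predicted, y_true) if p and y == 1)
        fp = sum(1 for p, y in zip(predicted, y_true) if p and y == 0)
        # tp + fp >= 1 here because the sample scoring exactly t is predicted positive.
        points.append((t, tp / (tp + fp), tp / total_pos))
    return points


y_true = [1, 0, 1, 1, 0]
y_scores = [0.9, 0.8, 0.8, 0.4, 0.2]
for threshold, precision, recall in pr_points_bruteforce(y_true, y_scores):
    print(threshold, round(precision, 6), round(recall, 6))
```

This version costs O(n) per threshold, so O(n²) in the worst case, which is why the solution sorts once and sweeps instead. The two samples tied at 0.8 produce a single point here, matching the tie handling in the sweep.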

Explanation

  1. Sort all predictions by score in descending order.
  2. Sweep the threshold from high to low. Each step down classifies one more sample as positive; samples with tied scores are added together and produce a single curve point.
  3. Precision = TP / (TP + FP) — fraction of positive predictions that are correct.
  4. Recall = TP / total positives — fraction of actual positives that are detected.
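  5. The point (recall=0, precision=1) is prepended by convention, since precision is undefined when nothing is predicted positive. As the threshold drops, recall never decreases, while precision typically (though not always) falls.
  6. The area under the PR curve, commonly summarized as Average Precision, is especially useful for imbalanced datasets because neither precision nor recall counts true negatives.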
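
Average Precision can be computed from the returned lists with a step-wise sum, AP = Σ (R_k − R_{k−1}) · P_k. The sketch below is one common definition, not the only one (trapezoidal interpolation gives a different value); `average_precision` is a hypothetical helper name, and it assumes the lists are ordered by non-decreasing recall and begin with the (recall=0, precision=1) point, as in the solution's output.

```python
def average_precision(precision, recall):
    """Step-wise area under a PR curve: sum of (R_k - R_{k-1}) * P_k.

    Hypothetical helper. Assumes both lists have the same length, are
    ordered by non-decreasing recall, and start at recall = 0.
    """
    ap = 0.0
    for k in range(1, len(recall)):
        # Each recall increment is weighted by the precision reached at that step.
        ap += (recall[k] - recall[k - 1]) * precision[k]
    return ap


# Curve for y_true = [1, 0, 1, 1, 0], y_scores = [0.9, 0.8, 0.8, 0.4, 0.2]
precision = [1.0, 1.0, 0.666667, 0.75, 0.6]
recall = [0.0, 0.333333, 0.666667, 1.0, 1.0]
print(round(average_precision(precision, recall), 4))
```

Steps where recall does not change, such as the last point here, contribute nothing to the sum, so only thresholds that pick up new positives affect the result.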

Complexity

  • Time: O(n log n), dominated by the sort; the sweep itself is O(n)
  • Space: O(n) for the sorted pairs and the curve points