Implement Focal Loss for Imbalanced Classification

#255 · Machine Learning · Medium

Solve on deep-ml.com

Problem

Implement focal loss for imbalanced binary classification. Focal loss down-weights easy, well-classified examples so that training focuses on hard, misclassified ones. Given the true labels (0 or 1) and, for each sample, the predicted probability of the positive class, compute the mean focal loss.

Solution

Focal loss extends binary cross-entropy by multiplying it by a modulating factor (1 - p_t)^gamma, where p_t is the predicted probability of the true class. An optional alpha parameter reweights the positive and negative classes to counter class imbalance.
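Written per sample, with p the predicted probability of the positive class and y the 0/1 label, the quantity the solution code averages is shown below (alpha_t is taken as 1 when alpha is omitted, and N is the number of samples):

```latex
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{if } y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{if } y = 0 \end{cases}

\mathrm{FL}(p_t) = -\,\alpha_t \,(1 - p_t)^{\gamma} \log(p_t),
\qquad
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{FL}(p_{t,i})
```

With gamma = 0 and no alpha, every term reduces to the ordinary binary cross-entropy -log(p_t).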

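import math


def focal_loss(
    y_true: list[int],
    y_pred: list[float],
    gamma: float = 2.0,
    alpha: float | None = None,
    epsilon: float = 1e-7,
) -> float:
    """Mean binary focal loss, rounded to 6 decimal places.

    y_true holds 0/1 labels; y_pred holds predicted probabilities of class 1.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    n = len(y_true)
    if n == 0:
        raise ValueError("inputs must be non-empty")

    total_loss = 0.0
    for label, prob in zip(y_true, y_pred):
        # Clip to [epsilon, 1 - epsilon] so log() never receives 0.
        p = max(min(prob, 1.0 - epsilon), epsilon)

        # p_t: the probability the model assigns to the true class.
        p_t = p if label == 1 else 1.0 - p

        # Modulating factor shrinks the loss of easy examples (p_t near 1).
        modulating = (1.0 - p_t) ** gamma
        ce = -math.log(p_t)
        loss = modulating * ce

        # Optional class weighting: alpha for positives, 1 - alpha for negatives.
        if alpha is not None:
            alpha_t = alpha if label == 1 else 1.0 - alpha
            loss *= alpha_t

        total_loss += loss

    return round(total_loss / n, 6)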
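A quick sanity check on the solution: with gamma = 0 the modulating factor is always 1, so focal loss without alpha should equal the mean binary cross-entropy. The block restates focal_loss in condensed form so it runs on its own; the binary_cross_entropy helper and the sample data are hypothetical, added only for this check.

```python
import math


def focal_loss(y_true, y_pred, gamma=2.0, alpha=None, epsilon=1e-7):
    # Condensed copy of focal_loss above: same clipping, weighting, and rounding.
    total = 0.0
    for label, prob in zip(y_true, y_pred):
        p = max(min(prob, 1.0 - epsilon), epsilon)
        p_t = p if label == 1 else 1.0 - p
        loss = (1.0 - p_t) ** gamma * -math.log(p_t)
        if alpha is not None:
            loss *= alpha if label == 1 else 1.0 - alpha
        total += loss
    return round(total / len(y_true), 6)


def binary_cross_entropy(y_true, y_pred, epsilon=1e-7):
    # Hypothetical reference: plain mean BCE with the same clipping and rounding.
    total = 0.0
    for label, prob in zip(y_true, y_pred):
        p = max(min(prob, 1.0 - epsilon), epsilon)
        total += -math.log(p if label == 1 else 1.0 - p)
    return round(total / len(y_true), 6)


y_true = [1, 0, 1, 0]
y_pred = [0.9, 0.2, 0.3, 0.7]

print(focal_loss(y_true, y_pred, gamma=0.0))  # same value as the next line
print(binary_cross_entropy(y_true, y_pred))
print(focal_loss(y_true, y_pred, gamma=2.0))  # smaller: easy samples are down-weighted
```

Raising gamma to 2 on the same data lowers the mean loss, because the two well-classified samples (p_t = 0.9 and 0.8) contribute almost nothing once the modulating factor is applied.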

Explanation

  1. For each sample, compute p_t — the model's predicted probability for the true class.
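  2. The modulating factor (1 - p_t)^gamma approaches 0 as p_t approaches 1 (a confident, correct prediction, i.e. an easy example) and approaches 1 as p_t approaches 0 (a confident but wrong prediction, i.e. a hard example). Larger gamma strengthens this effect; gamma = 0 removes it.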
  3. Multiply this factor by the standard cross-entropy loss -log(p_t).
  4. Optionally apply an alpha weighting: alpha for positives, (1 - alpha) for negatives.
  5. Average the per-sample losses over all samples and round the result to six decimal places.
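To see the modulating factor at work, the sketch below compares plain cross-entropy with the focal term for one easy, one borderline, and one hard sample at gamma = 2. The helper name per_sample_focal is hypothetical, introduced only for illustration.

```python
import math


def per_sample_focal(p_t: float, gamma: float = 2.0) -> tuple[float, float, float]:
    """Return (cross-entropy, modulating factor, focal loss) for one sample."""
    ce = -math.log(p_t)
    modulating = (1.0 - p_t) ** gamma
    return ce, modulating, modulating * ce


# p_t = 0.9 is an easy example, 0.5 is borderline, 0.1 is a hard example.
for p_t in (0.9, 0.5, 0.1):
    ce, factor, fl = per_sample_focal(p_t)
    print(f"p_t={p_t:.1f}  CE={ce:.4f}  factor={factor:.4f}  focal={fl:.4f}")
```

With gamma = 2 the easy sample keeps only 1% of its cross-entropy, while the hard sample keeps 81%, so hard examples dominate the averaged loss.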

Complexity

  • Time: O(n) where n is the number of samples
  • Space: O(1) — constant extra storage