#255 · Machine Learning · Medium
Implement Focal Loss for imbalanced classification. Focal loss down-weights easy (well-classified) examples and focuses training on hard, misclassified ones. Given predicted probabilities and true binary labels, compute the mean focal loss over the batch.
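To make the down-weighting concrete, this sketch compares plain cross-entropy with focal loss for one easy and one hard example. It assumes gamma = 2 (the default of the `focal_loss` function) and takes p_t, the probability the model assigns to the true class, directly as input. The helper name `ce_and_focal` is hypothetical.

```python
import math

GAMMA = 2.0  # assumed: same value as the focal_loss default


def ce_and_focal(p_t: float, gamma: float = GAMMA) -> tuple[float, float]:
    """Return (cross-entropy, focal loss) for one example, given p_t."""
    ce = -math.log(p_t)
    return ce, (1.0 - p_t) ** gamma * ce


easy_ce, easy_fl = ce_and_focal(0.9)  # confident and correct: easy example
hard_ce, hard_fl = ce_and_focal(0.1)  # confident and wrong: hard example

print(f"easy: CE={easy_ce:.4f}  FL={easy_fl:.4f}")
print(f"hard: CE={hard_ce:.4f}  FL={hard_fl:.4f}")
print(f"hard/easy ratio: CE {hard_ce / easy_ce:.1f}, FL {hard_fl / easy_fl:.1f}")
```

Under plain cross-entropy the hard example's loss is roughly 22 times the easy one's; under focal loss it is roughly 1770 times, because the easy example is scaled by (0.1)^2 = 0.01 while the hard one is scaled by (0.9)^2 = 0.81.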
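Focal loss extends binary cross-entropy by multiplying it by a modulating factor (1 - p_t)^gamma, where p_t is the predicted probability for the true class (p_t = p when the label is 1 and 1 - p when it is 0). An optional alpha parameter handles class imbalance by weighting positive examples by alpha and negative examples by 1 - alpha.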
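Written out per example, with p the predicted probability of class 1 clipped to [ε, 1 - ε], y ∈ {0, 1} the label, and α_t taken as 1 when alpha is not given, the quantity the code computes is:

$$
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{if } y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{if } y = 0 \end{cases}
$$

$$
\mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t),
\qquad
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \mathrm{FL}\big(p_t^{(i)}\big)
$$

Setting γ = 0 and α_t = 1 makes the modulating factor equal to 1, which recovers ordinary binary cross-entropy.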
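```python
import math


def focal_loss(
    y_true: list[int],
    y_pred: list[float],
    gamma: float = 2.0,
    alpha: float | None = None,
    epsilon: float = 1e-7,
) -> float:
    """Mean binary focal loss; labels are 0/1 and y_pred holds P(y = 1)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    n = len(y_true)
    if n == 0:
        raise ValueError("inputs must be non-empty")
    total_loss = 0.0
    for i in range(n):
        # Clip to [epsilon, 1 - epsilon] so log() never receives 0.
        p = max(min(y_pred[i], 1.0 - epsilon), epsilon)
        # p_t: predicted probability of the true class.
        if y_true[i] == 1:
            p_t = p
        else:
            p_t = 1.0 - p
        # Modulating factor: near 0 for easy examples, near 1 for hard ones.
        modulating = (1.0 - p_t) ** gamma
        ce = -math.log(p_t)
        loss = modulating * ce
        # Optional class weighting: alpha for positives, 1 - alpha for negatives.
        if alpha is not None:
            alpha_t = alpha if y_true[i] == 1 else (1.0 - alpha)
            loss *= alpha_t
        total_loss += loss
    return round(total_loss / n, 6)
```

In the code above:
- p_t is the model's predicted probability for the true class.
- (1 - p_t)^gamma is small when the model is confident and correct (an easy example) and approaches 1 when it is wrong (a hard example), so easy examples contribute little to the total.
- -log(p_t) is the ordinary cross-entropy term that the modulating factor scales.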