Implement Early Stopping Based on Validation Loss

#135 · Machine Learning · Easy

Problem

Implement early stopping for model training. Monitor the validation loss over epochs and stop training when the validation loss has not improved for a specified number of consecutive epochs (patience). Optionally restore the best model weights.

Solution

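import copy


class EarlyStopping:
    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience        # epochs to wait without improvement
        self.min_delta = min_delta      # minimum decrease that counts as an improvement
        self.best_loss = float('inf')
        self.counter = 0                # consecutive epochs without improvement
        self.best_weights = None
        self.should_stop = False

    def step(self, val_loss: float, model_weights=None) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember the new best and reset the patience counter.
            self.best_loss = val_loss
            self.counter = 0
            if model_weights is not None:
                # Deep copy so later in-place weight updates don't alter the snapshot.
                self.best_weights = copy.deepcopy(model_weights)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
                return True
        return False


def train_with_early_stopping(train_losses: list[float], val_losses: list[float],
                              patience: int = 5, min_delta: float = 0.0) -> int:
    # train_losses is unused: the stopping decision depends only on val_losses.
    early_stop = EarlyStopping(patience=patience, min_delta=min_delta)
    for epoch, val_loss in enumerate(val_losses):
        if early_stop.step(val_loss):
            return epoch  # 0-based index of the epoch where patience ran out
    return len(val_losses)  # never triggered: all epochs were trained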
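
The class above stores best_weights but never puts them back, although the problem mentions optionally restoring them. The following is a minimal sketch of one way to do that, not the reference solution. The class name EarlyStoppingWithRestore, the restore_best_weights flag, the best_epoch attribute, and the final_weights method are all hypothetical additions for illustration, and the "weights" are a plain dict of lists standing in for real model parameters.

```python
import copy
from typing import Any, Optional


class EarlyStoppingWithRestore:
    """Hypothetical variant of EarlyStopping that can hand back the best weights."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights  # hypothetical flag
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.counter = 0
        self.best_weights: Optional[Any] = None

    def step(self, epoch: int, val_loss: float, model_weights: Any = None) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            if model_weights is not None:
                # Snapshot, so later in-place updates don't overwrite it.
                self.best_weights = copy.deepcopy(model_weights)
            return False
        self.counter += 1
        return self.counter >= self.patience

    def final_weights(self, current_weights: Any) -> Any:
        """Weights to keep once training ends (hypothetical helper)."""
        if self.restore_best_weights and self.best_weights is not None:
            return copy.deepcopy(self.best_weights)
        return current_weights


# Toy loop: the single "weight" just records which epoch last updated it.
val_losses = [0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.50]
weights = {"w": [0.0]}
stopper = EarlyStoppingWithRestore(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    weights["w"][0] = float(epoch)  # pretend an optimizer updated the weights in place
    if stopper.step(epoch, loss, weights):
        stopped_at = epoch
        break
weights = stopper.final_weights(weights)
print(stopped_at, stopper.best_epoch, weights)  # prints: 5 2 {'w': [2.0]}
```

Training stops at epoch 5 (three epochs after the best loss of 0.60), so the later 0.50 is never seen, and the returned weights are the epoch-2 snapshot rather than the epoch-5 values.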

Explanation

  1. Track the best validation loss seen so far and a counter for epochs without improvement.
  2. Each epoch, count the current validation loss as an improvement only if it is below best_loss - min_delta, so that tiny fluctuations do not reset the counter.
  3. If improved, reset the counter and optionally save model weights.
  4. If not improved, increment the counter. When the counter reaches patience, signal to stop training.
  5. The functional version, train_with_early_stopping, returns the 0-based index of the epoch at which training would stop, or len(val_losses) if patience never runs out.
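
To make steps 1 to 4 concrete, the sketch below replays the same comparison as EarlyStopping.step on a short, made-up loss sequence and records the counter after every epoch. The helper trace_early_stopping and the loss values are hypothetical, chosen only to show how min_delta changes the outcome.

```python
from typing import List, Tuple


def trace_early_stopping(val_losses: List[float], patience: int,
                         min_delta: float = 0.0) -> List[Tuple[int, float, bool, int, bool]]:
    """Return (epoch, loss, improved, counter, stop) for each epoch until stopping."""
    best = float("inf")
    counter = 0
    rows = []
    for epoch, loss in enumerate(val_losses):
        improved = loss < best - min_delta  # same test as EarlyStopping.step
        if improved:
            best = loss
            counter = 0
        else:
            counter += 1
        stop = counter >= patience
        rows.append((epoch, loss, improved, counter, stop))
        if stop:
            break
    return rows


losses = [1.0, 0.8, 0.795, 0.792, 0.85]
plain = trace_early_stopping(losses, patience=2)
strict = trace_early_stopping(losses, patience=2, min_delta=0.01)
print(plain[-1])   # (4, 0.85, False, 1, False): never stops
print(strict[-1])  # (3, 0.792, False, 2, True): stops at epoch 3
```

With min_delta=0, the small drops to 0.795 and 0.792 still count as improvements and training runs to the end. With min_delta=0.01 they do not, so the counter reaches patience at epoch 3.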

Complexity

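  • Time: O(1) per epoch check, plus O(W) to copy the weights on epochs that improve (if saving best weights)
  • Space: O(W), where W is the size of the model weights (if saving best weights); O(1) otherwise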
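
The O(W) space term comes from the copy.deepcopy call in step. Keeping a plain reference would cost O(1), but it would silently track every later in-place update. The short sketch below illustrates the difference, assuming the weights are a hypothetical dict of lists standing in for model parameters.

```python
import copy

weights = {"layer1": [0.5, -0.2]}
alias = weights                    # O(1): just another name for the same object
snapshot = copy.deepcopy(weights)  # O(W): copies every stored value

weights["layer1"][0] = 9.9  # in-place update, as an optimizer step would do

print(alias["layer1"][0])     # 9.9: the reference followed the update
print(snapshot["layer1"][0])  # 0.5: the deep copy kept the old value
```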