Early Stopping Based on Validation Loss Plateau

#199 · Machine Learning · Easy

Problem

Implement Early Stopping based on validation loss plateau detection. Monitor the validation loss across epochs and stop training if the loss has not improved by at least a minimum delta for a specified number of consecutive epochs (patience).
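To make the rule concrete, here is a minimal sketch that replays a fixed sequence of validation losses through the same comparison the solution uses (an epoch counts as an improvement only if `val_loss < best_loss - min_delta`). The function name `find_stop_epoch` and the loss values are hypothetical, chosen only for illustration.

```python
from typing import List, Optional, Tuple


def find_stop_epoch(val_losses: List[float], patience: int,
                    min_delta: float = 0.0) -> Tuple[Optional[int], int, float]:
    """Return (stop_epoch or None, best_epoch, best_loss) for a loss sequence.

    Hypothetical helper: applies the plateau rule to a precomputed list
    instead of a live training loop.
    """
    best_loss, best_epoch, counter = float("inf"), -1, 0  # -1: no epoch seen yet
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss - min_delta:
            # Improved by more than min_delta: new best, reset patience.
            best_loss, best_epoch, counter = val_loss, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch, best_loss  # patience exhausted
    return None, best_epoch, best_loss  # ran out of epochs without stopping


losses = [0.90, 0.70, 0.60, 0.59, 0.595, 0.585, 0.61]  # made-up curve
print(find_stop_epoch(losses, patience=3, min_delta=0.02))  # (5, 2, 0.6)
print(find_stop_epoch(losses, patience=3))                  # (None, 5, 0.585)
```

With `min_delta=0.02`, the drops to 0.59 and 0.585 are real but smaller than the threshold, so the counter keeps climbing and training stops at epoch 5 with the best checkpoint at epoch 2. With `min_delta=0.0`, those same small drops reset the counter and patience is never exhausted.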

Solution

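class EarlyStopping:
    """Signal a stop when validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best: bool = True):
        self.patience = patience          # epochs to wait without improvement
        self.min_delta = min_delta        # minimum decrease that counts as improvement
        self.restore_best = restore_best  # stored only; the class never restores weights itself
        self.best_loss = float('inf')
        self.counter = 0                  # consecutive epochs without improvement
        self.best_epoch = 0
        self.best_weights = None
        self.stopped = False

    def __call__(self, val_loss: float, epoch: int = 0,
                 weights=None) -> bool:
        # Improvement only if the loss drops strictly below best_loss - min_delta.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            if weights is not None:
                # Copy each array so later in-place updates don't overwrite the snapshot.
                self.best_weights = [w.copy() for w in weights]
            return False  # improved: keep training
        self.counter += 1
        if self.counter >= self.patience:
            self.stopped = True
            return True  # patience exhausted: stop training
        return False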

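def train_with_early_stopping(train_fn, val_fn, epochs: int,
                              patience: int = 5,
                              min_delta: float = 0.0) -> dict:
    stopper = EarlyStopping(patience=patience, min_delta=min_delta)
    history = {"train_loss": [], "val_loss": []}
    last_epoch = -1  # stays -1 if epochs == 0 (avoids an unbound `epoch`)

    for epoch in range(epochs):
        train_loss = train_fn(epoch)
        val_loss = val_fn(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        last_epoch = epoch

        if stopper(val_loss, epoch):
            break

    history["stopped_epoch"] = last_epoch        # last epoch actually run
    history["stopped_early"] = stopper.stopped   # False if all epochs ran
    history["best_epoch"] = stopper.best_epoch
    return history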

Explanation

  1. Track the best validation loss seen so far and a patience counter.
  2. Each epoch, if the validation loss drops below best_loss - min_delta (that is, it improves by more than min_delta), reset the counter and record the new best loss and epoch.
  3. If the loss does not improve, increment the counter.
  4. When the counter reaches patience, signal to stop training.
  5. Optionally pass the model weights; they are copied at each new best epoch so the caller can restore them after stopping (the class only stores them, it does not restore them).
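Step 5 depends on copying the arrays rather than keeping references to them. A minimal NumPy sketch, assuming the parameters are a list of arrays that the training step updates in place, shows the snapshot and the restore the caller would perform after stopping. The scripted losses and the `+= 1.0` "update" are placeholders, not a real optimizer.

```python
import numpy as np

# Placeholder "model": a list of parameter arrays mutated in place each epoch.
weights = [np.zeros(3), np.zeros((2, 2))]

# Scripted validation losses (assumed for illustration); the best is at epoch 2.
val_losses = [1.0, 0.8, 0.6, 0.65, 0.7, 0.72]
patience, min_delta = 3, 0.0

best_loss, best_weights, best_epoch, counter = float("inf"), None, -1, 0
for epoch, val_loss in enumerate(val_losses):
    for w in weights:
        w += 1.0  # placeholder in-place update: every parameter grows by 1
    if val_loss < best_loss - min_delta:
        best_loss, best_epoch, counter = val_loss, epoch, 0
        # Copy, not reference: the in-place updates above would otherwise
        # overwrite the snapshot on the next epoch.
        best_weights = [w.copy() for w in weights]
    else:
        counter += 1
        if counter >= patience:
            break

# Restore the best parameters in place so existing references see them.
for w, best in zip(weights, best_weights):
    w[...] = best

# best_epoch == 2, the loop broke at epoch 5, every parameter is back to 3.0
print(best_epoch, epoch, weights[0])
```

Writing into the existing arrays with `w[...] = best` keeps any outside references to `weights` valid; rebinding `weights = best_weights` would leave those references pointing at the last-epoch values.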

Complexity

  • Time: O(1) per epoch for the early stopping check itself; an improving epoch that snapshots weights adds an O(N) copy
  • Space: O(N) when storing best weights, where N is the total number of parameters; O(1) otherwise