#199 · Machine Learning · Easy
Solve on deep-ml.com
Implement early stopping based on validation-loss plateau detection. Monitor the validation loss across epochs and stop training if the loss has not improved by at least a minimum delta (`min_delta`) for a specified number of consecutive epochs (`patience`).
class EarlyStopping:
    """Signal when training should stop because validation loss has plateaued.

    Call the instance once per epoch with that epoch's validation loss. An
    epoch counts as an improvement only if its loss is strictly lower than
    ``best_loss - min_delta``. An improvement resets the patience counter;
    any other epoch increments it. Once the counter reaches ``patience``, the
    call returns ``True`` and ``stopped`` is set.

    A NaN loss never compares lower than ``best_loss``, so it is counted as a
    non-improving epoch.

    Args:
        patience: Number of consecutive non-improving epochs tolerated
            before signalling a stop.
        min_delta: Minimum decrease in loss that counts as an improvement.
        restore_best: If True, the weights passed on each improving epoch are
            copied into ``best_weights`` so the caller can restore them after
            training. If False, no snapshot is taken and ``best_weights``
            stays ``None``.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best: bool = True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.reset()

    def reset(self) -> None:
        """Clear all tracking state so the instance can be reused for a new run."""
        self.best_loss = float('inf')   # lowest loss counted as an improvement
        self.counter = 0                # consecutive non-improving epochs
        self.best_epoch = 0             # epoch of the last improvement
        self.best_weights = None        # snapshot taken at the last improvement
        self.stopped = False            # set once patience is exhausted

    def __call__(self, val_loss: float, epoch: int = 0,
                 weights=None) -> bool:
        """Record one epoch's validation loss and report whether to stop.

        Args:
            val_loss: Validation loss for this epoch.
            epoch: Epoch index, stored as ``best_epoch`` on improvement.
            weights: Optional iterable of weight objects. Each element must
                provide ``.copy()`` (e.g. a NumPy array or a list). It is
                snapshotted on improvement when ``restore_best`` is True.

        Returns:
            True if training should stop, False otherwise.
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            if self.restore_best and weights is not None:
                # Copy each element so later in-place mutation of the caller's
                # weights does not alter the stored snapshot.
                self.best_weights = [w.copy() for w in weights]
            return False  # improved: keep training

        self.counter += 1
        if self.counter >= self.patience:
            self.stopped = True
            return True  # patience exhausted: stop training
        return False
def train_with_early_stopping(train_fn, val_fn, epochs: int,
                              patience: int = 5,
                              min_delta: float = 0.0) -> dict:
    """Run a train/validate loop that halts when validation loss plateaus.

    Each epoch calls ``train_fn(epoch)`` and then ``val_fn(epoch)``, and
    records both returned losses. An ``EarlyStopping`` monitor resets its
    counter whenever the validation loss improves by more than ``min_delta``.
    It signals a stop once ``patience`` consecutive epochs pass without such
    an improvement.

    Args:
        train_fn: Callable taking the epoch index and returning the training
            loss for that epoch.
        val_fn: Callable taking the epoch index and returning the validation
            loss for that epoch.
        epochs: Maximum number of epochs to run.
        patience: Consecutive non-improving epochs tolerated before stopping.
        min_delta: Minimum decrease in validation loss counted as an
            improvement.

    Returns:
        A dict containing:

        - ``"train_loss"``: per-epoch training losses.
        - ``"val_loss"``: per-epoch validation losses.
        - ``"stopped_epoch"``: index of the last epoch executed, or ``None``
          if ``epochs <= 0`` and no epoch ran.
        - ``"best_epoch"``: the monitor's ``best_epoch``.
    """
    stopper = EarlyStopping(patience=patience, min_delta=min_delta)
    history = {"train_loss": [], "val_loss": []}

    # Pre-bind so the summary below is valid even when the loop never runs
    # (epochs <= 0); the loop variable would otherwise be unbound.
    last_epoch = None
    for epoch in range(epochs):
        last_epoch = epoch
        train_loss = train_fn(epoch)
        val_loss = val_fn(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        if stopper(val_loss, epoch):
            break

    history["stopped_epoch"] = last_epoch
    history["best_epoch"] = stopper.best_epoch
    return history