#135 · Machine Learning · Easy
Solve on deep-ml.com. Implement early stopping for model training: monitor the validation loss over epochs and stop training when it has not improved for a specified number of consecutive epochs (the patience). Optionally restore the best model weights.
class EarlyStopping:
    """Track validation loss across epochs and signal when training should stop.

    A step counts as an improvement only when the loss falls below the best
    loss seen so far by more than ``min_delta``. Once ``patience`` consecutive
    non-improving steps have occurred, :meth:`step` returns ``True``.

    Attributes:
        patience: Number of consecutive non-improving steps tolerated.
        min_delta: Minimum decrease required for a step to count as an
            improvement (stored as a non-negative value).
        best_loss: Lowest validation loss accepted as an improvement so far.
        counter: Consecutive non-improving steps since the last improvement.
        best_weights: Deep copy of the weights passed with the most recent
            improving step that supplied weights, or ``None``. If an improving
            step is called without weights, this keeps the older snapshot.
        should_stop: Result of the most recent :meth:`step` call.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        # A negative threshold would let a *worse* loss count as an
        # improvement and push best_loss upward; use its magnitude instead
        # (the same normalization Keras' EarlyStopping applies).
        self.min_delta = abs(min_delta)
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None
        self.should_stop = False

    def step(self, val_loss: float, model_weights=None) -> bool:
        """Record one epoch's validation loss and report whether to stop.

        Args:
            val_loss: Validation loss for the current epoch. A NaN never
                compares as an improvement, so it counts against patience.
            model_weights: Optional weights object to snapshot (deep-copied)
                when this step is an improvement.

        Returns:
            ``True`` if the loss has failed to improve for at least
            ``patience`` consecutive steps, otherwise ``False``.
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            # Keep the attribute in agreement with the value returned here;
            # previously it stayed True after a later improvement.
            self.should_stop = False
            if model_weights is not None:
                import copy
                # Deep copy so later in-place updates to the caller's weights
                # cannot silently overwrite the saved snapshot.
                self.best_weights = copy.deepcopy(model_weights)
            return False

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop
def train_with_early_stopping(train_losses: list[float], val_losses: list[float], patience: int = 5, min_delta: float = 0.0) -> int:
    """Replay a sequence of per-epoch validation losses through EarlyStopping.

    Args:
        train_losses: Per-epoch training losses. Not consulted by the stopping
            rule; accepted only for interface compatibility.
        val_losses: Per-epoch validation losses, in epoch order.
        patience: Consecutive non-improving epochs tolerated before stopping.
        min_delta: Minimum decrease in validation loss that counts as an
            improvement; forwarded to EarlyStopping. The default of 0.0
            reproduces the previous behavior.

    Returns:
        The 0-based index of the epoch at which early stopping triggered, or
        ``len(val_losses)`` if it never triggered. The two cases differ in
        meaning: a returned index ``i`` means ``i + 1`` epochs were run.
    """
    early_stop = EarlyStopping(patience=patience, min_delta=min_delta)
    for epoch, val_loss in enumerate(val_losses):
        if early_stop.step(val_loss):
            return epoch  # 0-based index of the epoch that triggered the stop
    return len(val_losses)  # never triggered: every epoch was run