Implement a linear learning rate decay schedule. Starting from an initial learning rate, linearly decrease the learning rate to zero (or a minimum value) over a specified number of training steps, optionally with a warmup phase.
def linear_lr_decay(
    step: int,
    total_steps: int,
    initial_lr: float,
    min_lr: float = 0.0,
    warmup_steps: int = 0
) -> float:
    """Return the learning rate for ``step`` under linear warmup + linear decay.

    Phases:
      * Warmup (``step < warmup_steps``): the rate ramps linearly as
        ``initial_lr * (step + 1) / warmup_steps``, so the last warmup step
        reaches ``initial_lr``.
      * Decay (``warmup_steps <= step < total_steps``): the rate is linearly
        interpolated from ``initial_lr`` toward ``min_lr`` using
        ``progress = (step - warmup_steps) / (total_steps - warmup_steps)``.
        It never returns less than ``min_lr``.
      * Finished (``step >= total_steps``): ``min_lr``.

    Args:
        step: Zero-based training step; must be non-negative.
        total_steps: Total number of steps in the schedule, warmup included.
        initial_lr: Rate reached at the end of warmup / start of decay.
        min_lr: Rate the decay converges to and the floor for the decay
            phase; defaults to 0.0.
        warmup_steps: Number of leading warmup steps; must be non-negative.

    Returns:
        The learning rate to use at ``step``.

    Raises:
        ValueError: If ``step`` or ``warmup_steps`` is negative.
    """
    if step < 0:
        raise ValueError(f"step must be non-negative, got {step}")
    if warmup_steps < 0:
        raise ValueError(f"warmup_steps must be non-negative, got {warmup_steps}")
    # Check the end of the schedule first so a warmup longer than the schedule
    # cannot keep returning warmup rates past total_steps.
    if step >= total_steps:
        return min_lr
    if step < warmup_steps:
        # (step + 1) so the very first step already gets a non-zero rate.
        return initial_lr * (step + 1) / warmup_steps
    # Here warmup_steps <= step < total_steps, so the denominator is positive.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    # Interpolate between initial_lr and min_lr so the decay spans the whole
    # post-warmup phase instead of hitting min_lr early and flattening.
    # With the default min_lr=0.0 this equals initial_lr * (1 - progress).
    return max(min_lr, min_lr + (initial_lr - min_lr) * (1.0 - progress))
def get_lr_schedule(
    total_steps: int,
    initial_lr: float,
    min_lr: float = 0.0,
    warmup_steps: int = 0
) -> list[float]:
    """Materialize the per-step learning-rate schedule as a list.

    Evaluates ``linear_lr_decay`` once for every step ``0 .. total_steps - 1``
    with the given settings and returns the rates in step order. Under that
    schedule, the first ``warmup_steps`` steps ramp the rate linearly up toward
    ``initial_lr``. After warmup the rate decreases linearly in
    ``progress = (step - warmup_steps) / (total_steps - warmup_steps)``, and is
    never allowed to drop below ``min_lr``.

    Args:
        total_steps: Number of steps to generate (warmup included).
        initial_lr: Learning rate at the end of warmup / start of decay.
        min_lr: Lower bound for the post-warmup rate; defaults to 0.0.
        warmup_steps: Number of leading warmup steps.

    Returns:
        A list with one learning rate per step; empty if ``total_steps <= 0``.
    """
    all_steps = range(total_steps)
    return [
        linear_lr_decay(
            step=s,
            total_steps=total_steps,
            initial_lr=initial_lr,
            min_lr=min_lr,
            warmup_steps=warmup_steps,
        )
        for s in all_steps
    ]