"""Learning rate schedule that combines a linear warmup phase with cosine decay.

During warmup, the learning rate increases linearly from 0 to the base LR.
After warmup, it decays along a cosine curve to a minimum LR.
"""
import math
def warmup_cosine_schedule(step: int, total_steps: int,
                           warmup_steps: int, base_lr: float = 0.001,
                           min_lr: float = 0.0) -> float:
    """Return the learning rate at ``step`` for linear warmup plus cosine decay.

    For ``step < warmup_steps`` the rate rises linearly from 0 towards
    ``base_lr``. From ``warmup_steps`` onwards it follows
    ``min_lr + (base_lr - min_lr) * 0.5 * (1 + cos(pi * progress))``, where
    ``progress`` is the fraction of the post-warmup steps completed, capped
    at 1.0 so the rate stays at ``min_lr`` once ``total_steps`` is passed.

    Args:
        step: Current training step. Negative values are treated as 0.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        warmup_steps: Number of linear warmup steps; 0 disables warmup.
        base_lr: Learning rate reached at the end of warmup.
        min_lr: Learning rate at the end of the cosine decay.

    Returns:
        The learning rate for ``step``.
    """
    # Without this clamp a negative step would take the warmup branch and
    # yield a negative learning rate.
    step = max(step, 0)

    if step < warmup_steps:
        # step >= 0 here, so warmup_steps is strictly positive and the
        # division cannot fail.
        return base_lr * (step / warmup_steps)

    # max(1, ...) avoids a zero (or negative) denominator when
    # total_steps <= warmup_steps.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine_decay
def get_schedule(total_steps: int, warmup_steps: int,
                 base_lr: float = 0.001,
                 min_lr: float = 0.0) -> list[float]:
    """Return the learning rate for every step of a warmup + cosine schedule.

    Evaluates ``warmup_cosine_schedule`` at each step in
    ``range(total_steps)``, i.e. steps ``0`` through ``total_steps - 1``.

    Args:
        total_steps: Number of steps in the schedule; also the length of the
            returned list.
        warmup_steps: Number of linear warmup steps.
        base_lr: Learning rate reached at the end of warmup.
        min_lr: Learning rate the cosine decay heads towards.

    Returns:
        A list of ``total_steps`` learning rates, one per step in order.
    """
    return [
        warmup_cosine_schedule(step, total_steps, warmup_steps, base_lr, min_lr)
        for step in range(total_steps)
    ]


# Notes on the schedule. The opening words of each item were missing from the
# original text and have been reconstructed from the code.
# - Warmup scales the learning rate linearly from 0 up to base_lr by
#   multiplying base_lr * (step / warmup_steps).
# - The cosine decay factor is 0.5 * (1 + cos(pi * progress)), which smoothly
#   decays from 1 to 0.
# - That factor interpolates the learning rate between base_lr and min_lr.
# - get_schedule returns the full list of learning rates for all steps.