Warmup + Cosine Decay Schedule

#196 · Optimization · Medium

Problem

Implement a learning rate schedule that combines a linear warmup phase with cosine decay. During warmup, the learning rate linearly increases from 0 to the base LR. After warmup, it decays following a cosine curve to a minimum LR.
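One way to write the schedule described above as a single piecewise formula is sketched below. The notation is an assumption introduced here, not part of the problem statement: t is the current step, T_w the number of warmup steps, T the total number of steps, and η_base, η_min the base and minimum learning rates. It assumes T > T_w > 0.

```latex
\eta(t) =
\begin{cases}
\eta_{\text{base}} \cdot \dfrac{t}{T_w}, & t < T_w \\[6pt]
\eta_{\min} + \tfrac{1}{2}\left(\eta_{\text{base}} - \eta_{\min}\right)
\left(1 + \cos\!\left(\pi \cdot \dfrac{t - T_w}{T - T_w}\right)\right), & T_w \le t \le T
\end{cases}
```

At t = T_w the cosine term equals 1, so the second branch gives η_base. At t = T it equals -1, which gives η_min.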

Solution

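import math

def warmup_cosine_schedule(step: int, total_steps: int,
                           warmup_steps: int, base_lr: float = 0.001,
                           min_lr: float = 0.0) -> float:
    """Return the learning rate for a single 0-indexed step."""
    if step < warmup_steps:
        # Linear warmup: 0 at step 0, rising toward base_lr.
        # max(1, ...) guards against division by zero.
        return base_lr * (step / max(1, warmup_steps))
    else:
        # Fraction of the post-warmup phase completed, clamped to 1
        # so steps at or beyond total_steps stay at min_lr.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(progress, 1.0)
        # Cosine factor falls from 1 (progress = 0) to 0 (progress = 1).
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Interpolate between base_lr and min_lr.
        return min_lr + (base_lr - min_lr) * cosine_decay

def get_schedule(total_steps: int, warmup_steps: int,
                 base_lr: float = 0.001,
                 min_lr: float = 0.0) -> list[float]:
    """Return the learning rates for steps 0 through total_steps - 1."""
    return [warmup_cosine_schedule(s, total_steps, warmup_steps,
                                   base_lr, min_lr)
            for s in range(total_steps)]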

Explanation

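  1. Warmup phase (step < warmup_steps): Increase the LR linearly from 0 by returning base_lr * (step / warmup_steps). The LR first equals base_lr at step == warmup_steps, which is the first step of the decay phase.
  2. Cosine decay phase (step >= warmup_steps): Compute progress as the fraction of post-warmup steps completed, (step - warmup_steps) / (total_steps - warmup_steps), clamped to 1. The factor 0.5 * (1 + cos(pi * progress)) then decays smoothly from 1 to 0.
  3. Scale the cosine factor to interpolate between base_lr and min_lr: min_lr + (base_lr - min_lr) * cosine_decay.
  4. get_schedule returns the learning rates for steps 0 through total_steps - 1 as a list. Because progress reaches 1 only at step == total_steps, the last entry is typically slightly above min_lr.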
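The steps above can be checked with a small worked example. The sketch below restates warmup_cosine_schedule in compact form so it runs on its own. The settings (10 total steps, 4 warmup steps, base_lr = 1.0, min_lr = 0.1) are illustrative assumptions chosen to make the numbers easy to read, not values from the problem.

```python
import math

def warmup_cosine_schedule(step, total_steps, warmup_steps,
                           base_lr=0.001, min_lr=0.0):
    # Restated from the solution so this snippet is self-contained.
    if step < warmup_steps:
        return base_lr * (step / max(1, warmup_steps))
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Illustrative settings (assumed for this example only).
TOTAL, WARMUP, BASE, MIN = 10, 4, 1.0, 0.1

# Probe the start, mid-warmup, the warmup boundary, mid-decay, and the end.
for step in (0, 2, 4, 7, 10):
    lr = warmup_cosine_schedule(step, TOTAL, WARMUP, BASE, MIN)
    print(f"step {step:2d}: lr = {lr:.4f}")
```

With these settings, steps 0 and 2 fall in warmup and give 0.0000 and 0.5000, step 4 is the warmup boundary and gives 1.0000, step 7 is halfway through the decay and gives 0.5500, and step 10 reaches min_lr at 0.1000.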

Complexity

  • Time: O(1) per step for computing the learning rate; O(T) for the full schedule
  • Space: O(1) per step; O(T) for storing the full schedule
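If the full list is not needed, the O(T) storage can be avoided by producing one learning rate at a time. The generator below is a minimal sketch of that idea. The name iter_warmup_cosine is hypothetical and not part of the original solution, and the formula is inlined so the snippet runs on its own.

```python
import math
from typing import Iterator

def iter_warmup_cosine(total_steps: int, warmup_steps: int,
                       base_lr: float = 0.001,
                       min_lr: float = 0.0) -> Iterator[float]:
    """Yield the learning rate for steps 0 .. total_steps - 1, one at a time."""
    decay_steps = max(1, total_steps - warmup_steps)
    for step in range(total_steps):
        if step < warmup_steps:
            # step < warmup_steps implies warmup_steps >= 1, so no zero division.
            yield base_lr * (step / warmup_steps)
        else:
            progress = min((step - warmup_steps) / decay_steps, 1.0)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            yield min_lr + (base_lr - min_lr) * cosine_decay

# Usage: consume values lazily, e.g. once per training step.
for step, lr in enumerate(iter_warmup_cosine(total_steps=10, warmup_steps=4,
                                             base_lr=1.0, min_lr=0.1)):
    print(step, round(lr, 4))
```

Each value is computed in O(1) time and only the current one is held in memory. The trade-off is that the schedule can no longer be indexed by an arbitrary step without iterating up to it.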