Implement Cosine Annealing with Warm Restarts

#392 · Machine Learning · Medium

Problem

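Implement cosine annealing with warm restarts (SGDR). Within each cycle the learning rate follows a cosine curve from lr_max down toward lr_min, then resets to lr_max when the next cycle begins. The first cycle lasts T_0 epochs, and each later cycle is T_mult times as long as the one before it. Given the current epoch, compute the learning rate.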

Solution

import numpy as np

def cosine_annealing_warm_restarts(epoch: int, T_0: int, T_mult: int = 1, lr_max: float = 0.1, lr_min: float = 0.0) -> float:
    # Find which restart cycle we are in
    if T_mult == 1:
        cycle = epoch // T_0
        epoch_in_cycle = epoch % T_0
        T_cur = T_0
    else:
        # Geometric schedule: T_0, T_0*T_mult, T_0*T_mult^2, ...
        cycle = 0
        cumulative = T_0
        T_cur = T_0
        while cumulative <= epoch:
            cycle += 1
            T_cur = T_0 * (T_mult ** cycle)
            cumulative += T_cur
        epoch_in_cycle = epoch - (cumulative - T_cur)

    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * epoch_in_cycle / T_cur))
    return float(lr)

Explanation

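  1. Determine which restart cycle the current epoch belongs to. With T_mult=1, every cycle has length T_0, so integer division and modulo suffice. Otherwise, each cycle is T_mult times as long as the previous one (T_0, T_0*T_mult, T_0*T_mult^2, ...), and the code adds up cycle lengths until the running total exceeds the epoch.
  2. Compute the position within the current cycle: the epoch minus the first epoch of that cycle.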
  3. Apply the cosine annealing formula: lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T)).
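  4. At the start of each cycle (t = 0) the LR resets to lr_max. It then decays toward lr_min, but the last epoch of a cycle has t = T - 1, so the LR there is slightly above lr_min; the formula would give exactly lr_min only at t = T, which is where the next cycle restarts.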
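
The steps above can be checked on a small case. The sketch below is a compact, standard-library restatement of the schedule rather than the solution itself; the name `sgdr_lr` and the example values T_0 = 10, T_mult = 2 are assumptions chosen only for illustration.

```python
import math

def sgdr_lr(epoch, T_0, T_mult=1, lr_max=0.1, lr_min=0.0):
    """Illustrative restatement of the schedule (hypothetical helper name)."""
    start, T_cur = 0, T_0          # first epoch and length of the current cycle
    while epoch >= start + T_cur:  # step 1: skip over every completed cycle
        start += T_cur
        T_cur *= T_mult
    t = epoch - start              # step 2: position within the current cycle
    # step 3: cosine annealing formula
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / T_cur))

# Assumed example: T_0 = 10, T_mult = 2, so cycles cover epochs 0-9, 10-29, 30-69.
for epoch in (0, 5, 9, 10, 20, 29, 30):
    print(epoch, round(sgdr_lr(epoch, T_0=10, T_mult=2), 5))
```

With these values the rate is back at lr_max at epochs 10 and 30, passes through the midpoint 0.05 halfway through each cycle (epochs 5 and 20), and sits just above lr_min at the last epoch of each cycle (epochs 9 and 29).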

Complexity

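  • Time: O(1) when T_mult = 1; otherwise O(number of completed cycles) to find the current cycle, which is O(log(epoch / T_0)) because cycle lengths grow geometrically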
  • Space: O(1)
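
To see why the search stays cheap when T_mult > 1: cycle k starts at epoch T_0 * (T_mult^k - 1) / (T_mult - 1), the sum of the first k cycle lengths, so the number of loop iterations grows only logarithmically with the epoch. A minimal sketch that counts those iterations follows; the helper names `cycle_start` and `count_cycles` are hypothetical and introduced only for illustration.

```python
def cycle_start(k: int, T_0: int, T_mult: int) -> int:
    """First epoch of cycle k for T_mult > 1 (geometric series sum, exact in integers)."""
    return T_0 * (T_mult ** k - 1) // (T_mult - 1)

def count_cycles(epoch: int, T_0: int, T_mult: int) -> int:
    """Number of completed cycles before `epoch`, using the same linear scan as the solution."""
    cycle, cumulative = 0, T_0
    while cumulative <= epoch:
        cycle += 1
        cumulative += T_0 * T_mult ** cycle
    return cycle

# Assumed example values: T_0 = 10, T_mult = 2.
for epoch in (10, 100, 1_000, 10_000, 100_000):
    k = count_cycles(epoch, T_0=10, T_mult=2)
    # The epoch lies inside cycle k: start(k) <= epoch < start(k + 1)
    assert cycle_start(k, 10, 2) <= epoch < cycle_start(k + 1, 10, 2)
    print(epoch, k)
```

Each tenfold increase in the epoch adds only three or four iterations here, which is the logarithmic growth the time bound describes.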