Linear Learning Rate Decay

#377 · Machine Learning · Easy

Problem

Implement a linear learning rate decay schedule. Starting from an initial learning rate, linearly decrease the learning rate to zero (or a minimum value) over a specified number of training steps, optionally with a warmup phase.

Solution

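def linear_lr_decay(
    step: int,
    total_steps: int,
    initial_lr: float,
    min_lr: float = 0.0,
    warmup_steps: int = 0
) -> float:
    """Return the learning rate for a zero-indexed training step."""
    # Warmup: ramp linearly so the last warmup step reaches initial_lr.
    if step < warmup_steps:
        return initial_lr * (step + 1) / warmup_steps
    # Past the end of the schedule, hold the learning rate at the floor.
    if step >= total_steps:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    # Decay linearly toward zero, but never drop below min_lr.
    return max(min_lr, initial_lr * (1.0 - progress))

def get_lr_schedule(
    total_steps: int,
    initial_lr: float,
    min_lr: float = 0.0,
    warmup_steps: int = 0
) -> list[float]:
    """Return the learning rate for every step in range(total_steps)."""
    return [
        linear_lr_decay(step, total_steps, initial_lr, min_lr, warmup_steps)
        for step in range(total_steps)
    ]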

Explanation

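  1. During warmup (step < warmup_steps), increase the learning rate linearly: step 0 uses initial_lr / warmup_steps, and the last warmup step reaches initial_lr.
  2. After warmup, decrease the learning rate linearly from initial_lr over the remaining total_steps - warmup_steps steps.
  3. The progress fraction is (step - warmup_steps) / (total_steps - warmup_steps), and the unclamped LR is initial_lr * (1 - progress), which heads toward zero.
  4. Clamp the result with max(min_lr, ...) so the learning rate never drops below the minimum. For step >= total_steps the function returns min_lr directly.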
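
One consequence of clamping is that, with a nonzero min_lr, the schedule flattens before the final step instead of reaching min_lr exactly at the end. The sketch below restates the solution's rule as clamped_decay and compares it with interpolated_decay, a hypothetical variant (not part of the solution above) that decays toward min_lr itself. The configuration values are assumptions chosen only for illustration.

```python
def clamped_decay(step, total_steps, initial_lr, min_lr=0.0, warmup_steps=0):
    """Same rule as the solution: decay toward 0, then clamp at min_lr."""
    if step < warmup_steps:
        return initial_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return max(min_lr, initial_lr * (1.0 - progress))


def interpolated_decay(step, total_steps, initial_lr, min_lr=0.0, warmup_steps=0):
    """Hypothetical variant: decay from initial_lr toward min_lr itself."""
    if step < warmup_steps:
        return initial_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + (initial_lr - min_lr) * (1.0 - progress)


# Assumed example configuration, for illustration only.
config = dict(total_steps=10, initial_lr=0.1, min_lr=0.03, warmup_steps=2)

# Round only for display; the raw values carry floating-point noise.
clamped = [round(clamped_decay(s, **config), 4) for s in range(10)]
interpolated = [round(interpolated_decay(s, **config), 5) for s in range(10)]

print("clamped:     ", clamped)
print("interpolated:", interpolated)
```

With these values the clamped schedule sits at 0.03 for steps 8 and 9, while the interpolated variant is still at 0.03875 on step 9 and only returns min_lr once step reaches total_steps. Which behavior is wanted depends on the problem's expected outputs, so the variant is best treated as an illustration rather than a drop-in replacement.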

Complexity

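  • Time: O(1) per step, so O(total_steps) to build the full schedule
  • Space: O(1) per step; get_lr_schedule returns a list of total_steps values, so O(total_steps)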
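
If the full list is not needed, the same values can be produced lazily so that memory stays constant. The following is a minimal sketch under that assumption. iter_linear_lr is a hypothetical helper name, and it applies the same per-step rule as the solution.

```python
from typing import Iterator


def iter_linear_lr(
    total_steps: int,
    initial_lr: float,
    min_lr: float = 0.0,
    warmup_steps: int = 0,
) -> Iterator[float]:
    """Yield one learning rate per step without building a list."""
    decay_steps = total_steps - warmup_steps
    for step in range(total_steps):
        if step < warmup_steps:
            # Warmup: ramp up so the last warmup step reaches initial_lr.
            yield initial_lr * (step + 1) / warmup_steps
        else:
            # Decay toward zero, clamped at min_lr.
            progress = (step - warmup_steps) / decay_steps
            yield max(min_lr, initial_lr * (1.0 - progress))


# Values are consumed one at a time, e.g. inside a training loop.
for step, lr in enumerate(iter_linear_lr(total_steps=4, initial_lr=1.0)):
    print(step, lr)
```

The loop prints 1.0, 0.75, 0.5, and 0.25 for steps 0 through 3. Only the current value is held in memory at any point.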