Implement AdamW Optimizer Step

#169 · Optimization · Medium

⊣ Solve on deep-ml.com

Problem

Implement a single step of the AdamW optimizer. AdamW decouples weight decay from the gradient update, applying weight decay directly to the weights rather than through the gradient as in L2 regularization.
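
Written out, one step can be sketched as follows. The symbols are chosen here for illustration and are not part of the problem statement: \theta for the parameters (params), g_t for the gradient (grads), \eta for the learning rate (lr), and \lambda for the decay coefficient (weight_decay).

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} + \lambda\,\theta_{t-1}\right)
\end{aligned}
```

The decay term \lambda\,\theta_{t-1} appears only in the last line, outside the moment estimates, which is what "decoupled" refers to.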

Solution

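import numpy as np

def adamw_step(params: np.ndarray, grads: np.ndarray,
               m: np.ndarray, v: np.ndarray, t: int,
               lr: float = 0.001, beta1: float = 0.9,
               beta2: float = 0.999, epsilon: float = 1e-8,
               weight_decay: float = 0.01):
    """Run one AdamW step (t is the 1-based step count); return (params, m, v)."""
    # Exponential moving averages of the gradient and the squared gradient.
    m_new = beta1 * m + (1 - beta1) * grads
    v_new = beta2 * v + (1 - beta2) * (grads ** 2)

    # Bias correction: m and v start at zero, so early averages are scaled up.
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)

    # Adam update plus decoupled weight decay applied directly to the weights.
    params_new = params - lr * (m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * params)

    return params_new, m_new, v_new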

Explanation

  1. Update the first moment estimate m (exponential moving average of gradients).
  2. Update the second moment estimate v (exponential moving average of squared gradients).
  3. Apply bias correction to both moments to account for their initialization at zero.
  4. Update parameters: subtract lr times the sum of the Adam update m_hat / (sqrt(v_hat) + epsilon) and the decoupled weight decay term weight_decay * params.
  5. The key difference from Adam + L2: weight decay is applied to the parameters directly, not added to the gradient before moment computation, so the decay term is not rescaled by sqrt(v_hat).
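
To make the last point concrete, here is a minimal sketch that runs one step of each variant with a zero gradient, so that only the weight decay acts. The helper adam_l2_step is hypothetical and written here only for comparison; adamw_step repeats the solution's update so the block runs on its own.

```python
import numpy as np

def adam_l2_step(params, grads, m, v, t, lr=0.001, beta1=0.9,
                 beta2=0.999, epsilon=1e-8, weight_decay=0.01):
    # Hypothetical comparison helper: Adam with L2 regularization, where the
    # decay term is folded into the gradient before the moments are computed.
    g = grads + weight_decay * params
    m_new = beta1 * m + (1 - beta1) * g
    v_new = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)
    return params - lr * m_hat / (np.sqrt(v_hat) + epsilon), m_new, v_new

def adamw_step(params, grads, m, v, t, lr=0.001, beta1=0.9,
               beta2=0.999, epsilon=1e-8, weight_decay=0.01):
    # Same update as the solution: decay is applied outside the moments.
    m_new = beta1 * m + (1 - beta1) * grads
    v_new = beta2 * v + (1 - beta2) * grads ** 2
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)
    step = m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * params
    return params - lr * step, m_new, v_new

params = np.array([1.0, -2.0])
grads = np.zeros(2)            # zero gradient isolates the weight decay
m = np.zeros(2)
v = np.zeros(2)

p_w, _, _ = adamw_step(params, grads, m, v, t=1)
p_l2, _, _ = adam_l2_step(params, grads, m, v, t=1)

print("AdamW    :", p_w)
print("Adam + L2:", p_l2)
```

In this first step AdamW shrinks each weight by the factor 1 - lr * weight_decay, so the shrinkage is proportional to the weight. The L2 variant instead moves each weight by roughly lr toward zero regardless of its magnitude, because the decay term passes through the division by sqrt(v_hat).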

Complexity

  • Time: O(n) where n is the number of parameters
  • Space: O(n) for the moment estimates m and v, plus O(n) for the newly allocated arrays the function returns
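
As an illustration of the space bound, the sketch below uses a hypothetical in-place variant, adamw_step_inplace, which is not part of the original solution. It assumes float arrays of identical shape and overwrites params, m, and v, so the state kept between steps is the parameters plus two buffers of the same size. NumPy still allocates O(n) temporaries inside each step.

```python
import numpy as np

def adamw_step_inplace(params, grads, m, v, t, lr=0.001, beta1=0.9,
                       beta2=0.999, epsilon=1e-8, weight_decay=0.01):
    # Hypothetical in-place variant: overwrites params, m, and v.
    # Assumes float arrays of identical shape and t >= 1.
    m *= beta1
    m += (1 - beta1) * grads
    v *= beta2
    v += (1 - beta2) * grads ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # The right-hand side is fully evaluated from the old params first,
    # then subtracted in place.
    params -= lr * (m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * params)
    return params, m, v

# Persistent state: the parameters plus two moment buffers of the same size.
params = np.array([1.0, -2.0, 3.0])
m = np.zeros_like(params)
v = np.zeros_like(params)

for t in range(1, 4):            # t starts at 1 so the bias correction is defined
    grads = 2 * params           # gradient of the toy loss sum(params ** 2)
    adamw_step_inplace(params, grads, m, v, t)
```

Each step touches every element a constant number of times, which matches the O(n) time bound.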