Implement a single step of the AdamW optimizer. AdamW decouples weight decay from the gradient update, applying weight decay directly to the weights rather than through the gradient as in L2 regularization.
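As a sketch of the distinction, the two update rules can be written side by side. The symbols are notation chosen here, not taken from the task: θ stands for params, η for lr, λ for weight_decay, and f for the loss being minimized.

```latex
\begin{aligned}
&\text{Shared:} && m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \quad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \quad
\hat m_t = \frac{m_t}{1-\beta_1^t}, \quad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
&\text{Adam + L2:} && g_t = \nabla f(\theta_{t-1}) + \lambda\,\theta_{t-1}, \quad
\theta_t = \theta_{t-1} - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \\
&\text{AdamW:} && g_t = \nabla f(\theta_{t-1}), \quad
\theta_t = \theta_{t-1} - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\,\theta_{t-1}\right)
\end{aligned}
```

With L2 regularization, the λθ term passes through m_t and v_t and is then divided by √v̂_t + ε, so weights with a large gradient history are decayed less. In AdamW the decay bypasses that adaptive scaling, so every weight is multiplied by the same factor (1 - lr * weight_decay) before the Adam term is subtracted.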
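```python
import numpy as np

def adamw_step(params: np.ndarray, grads: np.ndarray,
               m: np.ndarray, v: np.ndarray, t: int,
               lr: float = 0.001, beta1: float = 0.9,
               beta2: float = 0.999, epsilon: float = 1e-8,
               weight_decay: float = 0.01):
    # t is the 1-based step count; t = 0 would divide by zero in the bias correction.
    # Update the biased first moment (EMA of gradients).
    m_new = beta1 * m + (1 - beta1) * grads
    # Update the biased second moment (EMA of squared gradients).
    v_new = beta2 * v + (1 - beta2) * (grads ** 2)
    # Correct the bias toward zero caused by initializing m and v at zero.
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)
    # Adam step plus weight decay applied directly to the weights (decoupled).
    params_new = params - lr * (m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * params)
    return params_new, m_new, v_new
```

Each step:
1. Update m (exponential moving average of gradients).
2. Update v (exponential moving average of squared gradients).
3. Bias-correct both to obtain m_hat and v_hat, compensating for their zero initialization.
4. Update the parameters using m_hat / (sqrt(v_hat) + eps) plus a decoupled weight decay term weight_decay * params, both scaled by lr.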
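A minimal usage sketch follows. It assumes a toy quadratic loss 0.5 * ||theta - target||^2, whose gradient is theta - target (chosen for illustration, not part of the task). It repeats the step function so the block runs on its own, and it adds a hypothetical helper, adam_l2_step, that folds the decay into the gradient instead, so the two behaviors can be compared on a zero gradient.

```python
import numpy as np


def adamw_step(params, grads, m, v, t, lr=0.001, beta1=0.9,
               beta2=0.999, epsilon=1e-8, weight_decay=0.01):
    """One AdamW step (same math as the function in the task)."""
    m_new = beta1 * m + (1 - beta1) * grads
    v_new = beta2 * v + (1 - beta2) * grads ** 2
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)
    params_new = params - lr * (m_hat / (np.sqrt(v_hat) + epsilon)
                                + weight_decay * params)
    return params_new, m_new, v_new


def adam_l2_step(params, grads, m, v, t, lr=0.001, beta1=0.9,
                 beta2=0.999, epsilon=1e-8, weight_decay=0.01):
    """Hypothetical helper: Adam with an L2 penalty folded into the gradient."""
    g = grads + weight_decay * params  # decay flows through m and v
    m_new = beta1 * m + (1 - beta1) * g
    v_new = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)
    params_new = params - lr * m_hat / (np.sqrt(v_hat) + epsilon)
    return params_new, m_new, v_new


# Toy problem (an assumption): minimize 0.5 * ||theta - target||^2.
target = np.array([1.0, -2.0, 3.0])
theta = np.zeros(3)
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 2001):  # t starts at 1 so the bias correction is defined
    grads = theta - target
    theta, m, v = adamw_step(theta, grads, m, v, t, lr=0.01, weight_decay=0.01)
print(theta)  # ends up close to target

# Zero gradient, zero moments, first step.
w = np.array([1.0, 10.0])
zeros = np.zeros_like(w)
# AdamW: each weight is scaled by (1 - lr * weight_decay).
w_adamw, _, _ = adamw_step(w, zeros, zeros, zeros, t=1, lr=0.001, weight_decay=0.01)
# Adam + L2: the decay is normalized by sqrt(v_hat), so each weight moves by about lr.
w_l2, _, _ = adam_l2_step(w, zeros, zeros, zeros, t=1, lr=0.001, weight_decay=0.01)
print(w_adamw)
print(w_l2)
```

The comparison shows why the decoupling matters: under AdamW the shrinkage is proportional to each weight, while under Adam + L2 both weights move by roughly the same absolute amount (about lr) regardless of their size.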