Implement Weight Decay as L2 Regularization

#198 · Machine Learning · Easy

Problem

Implement Weight Decay as L2 regularization. Given model weights, a loss value, and a weight decay factor, compute the regularized loss and the gradient contribution from the L2 penalty.
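In symbols, the quantities to compute are shown below. The notation is assumed here and does not come from the original: L is the given loss value, \lambda is the weight decay factor, and w_1, ..., w_K are the model's weight tensors.

```latex
\tilde{L} = L + \frac{\lambda}{2} \sum_{k=1}^{K} \lVert w_k \rVert_2^2,
\qquad
\frac{\partial}{\partial w_k}\left( \frac{\lambda}{2} \sum_{j=1}^{K} \lVert w_j \rVert_2^2 \right) = \lambda \, w_k
```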

Solution

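import numpy as np

def l2_regularization_loss(weights: list[np.ndarray],
                           weight_decay: float = 0.01) -> float:
    # L2 penalty only: 0.5 * lambda * (sum of squared entries over all tensors).
    # Add this value to the original loss to get the regularized loss.
    l2_sum = 0.0
    for w in weights:
        l2_sum += np.sum(w ** 2)
    return float(0.5 * weight_decay * l2_sum)

def apply_weight_decay(weights: list[np.ndarray],
                       gradients: list[np.ndarray],
                       weight_decay: float = 0.01) -> list[np.ndarray]:
    # The penalty's gradient w.r.t. each tensor is lambda * w; add it
    # elementwise to the gradient of the original loss.
    updated_grads = []
    for w, g in zip(weights, gradients):
        updated_grads.append(g + weight_decay * w)
    return updated_grads

def step_with_weight_decay(weights: list[np.ndarray],
                           gradients: list[np.ndarray],
                           lr: float = 0.01,
                           weight_decay: float = 0.01) -> list[np.ndarray]:
    # One plain SGD step on the regularized loss:
    # w_new = w - lr * (grad + lambda * w). Inputs are not modified in place.
    new_weights = []
    for w, g in zip(weights, gradients):
        new_weights.append(w - lr * (g + weight_decay * w))
    return new_weights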
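The problem statement asks for the regularized loss given a loss value, while l2_regularization_loss returns only the penalty term. A minimal sketch that takes the loss value as input and returns both the regularized loss and the per-tensor penalty gradients; the name regularized_loss_and_grad is hypothetical and not part of the original solution.

```python
import numpy as np

def regularized_loss_and_grad(loss: float,
                              weights: list[np.ndarray],
                              weight_decay: float = 0.01,
                              ) -> tuple[float, list[np.ndarray]]:
    # Penalty summed over every weight tensor: 0.5 * lambda * sum(w^2).
    penalty = 0.5 * weight_decay * sum(float(np.sum(w ** 2)) for w in weights)
    # Per-tensor gradient of the penalty: lambda * w, same shape as w.
    penalty_grads = [weight_decay * w for w in weights]
    return loss + penalty, penalty_grads

weights = [np.array([1.0, 2.0]), np.array([[3.0]])]
total, grads = regularized_loss_and_grad(2.0, weights, weight_decay=0.1)
print(total)  # ≈ 2.7, i.e. 2.0 + 0.5 * 0.1 * (1 + 4 + 9)
print(grads)  # lambda * w for each tensor: [0.1, 0.2] and [[0.3]]
```

Returning the penalty gradients separately, rather than adding them to existing gradients, keeps the helper usable whether or not the caller already has the data-loss gradients in hand.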

Explanation

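  1. L2 regularization loss: Add 0.5 * lambda * sum(w^2), summed over every weight tensor, to the original loss. The factor 0.5 cancels the 2 produced by differentiating w^2, so the penalty's gradient is lambda * w rather than 2 * lambda * w.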
  2. Gradient update: The gradient of the L2 term with respect to the weights is lambda * w, which is added elementwise to the gradient of the original loss.
  3. Weight update: Combine the gradient step and weight decay: w_new = w - lr * (grad + lambda * w).
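  4. For plain SGD this is algebraically the same as applying weight decay directly to the weights: w_new = (1 - lr * lambda) * w - lr * grad. The equivalence does not hold for adaptive optimizers such as Adam, where the lambda * w term is rescaled along with the rest of the gradient; decoupled weight decay (as in AdamW) applies the shrinkage to the weights outside that rescaling.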
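The claims in the list can be checked numerically. The sketch below verifies with central finite differences that the penalty's gradient is lambda * w, confirms that the SGD forms in items 3 and 4 coincide, and shows that they diverge under Adam. All helper names are hypothetical, and adam_first_step models only a single bias-corrected Adam step from zero moment estimates (a simplification; real training carries the moments across steps). The decoupled Adam variant assumes the decay is scaled by the learning rate, which is one common convention.

```python
import numpy as np

def penalty(w: np.ndarray, lam: float) -> float:
    # L2 penalty for one tensor: 0.5 * lambda * sum(w^2).
    return 0.5 * lam * float(np.sum(w ** 2))

def numeric_grad(f, w: np.ndarray, h: float = 1e-6) -> np.ndarray:
    # Central finite differences, one coordinate at a time.
    grad = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e.flat[i] = h
        grad.flat[i] = (f(w + e) - f(w - e)) / (2 * h)
    return grad

def sgd_l2_step(w, g, lr, lam):
    # Item 3: fold lambda * w into the gradient, then take a plain SGD step.
    return w - lr * (g + lam * w)

def sgd_decoupled_step(w, g, lr, lam):
    # Item 4: shrink the weights by (1 - lr * lambda), then step on the loss gradient.
    return (1.0 - lr * lam) * w - lr * g

def adam_first_step(w, g, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    # One bias-corrected Adam step starting from zero moment estimates.
    m_hat = ((1 - beta1) * g) / (1 - beta1)
    v_hat = ((1 - beta2) * g ** 2) / (1 - beta2)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps)

def adam_l2_step(w, g, lr, lam):
    # "Adam + L2": the penalty gradient is rescaled by Adam like any other gradient.
    return adam_first_step(w, g + lam * w, lr)

def adam_decoupled_step(w, g, lr, lam):
    # Decoupled (AdamW-style) decay, applied to the weights outside Adam's rescaling.
    return adam_first_step(w, g, lr) - lr * lam * w

w = np.array([1.0, -2.0])
g = np.array([0.5, 0.1])
lr, lam = 0.1, 0.1

# Finite-difference gradient of the penalty matches lambda * w.
print(np.allclose(numeric_grad(lambda x: penalty(x, lam), w), lam * w))  # True
# Plain SGD: both forms give approximately [0.94, -1.99].
print(sgd_l2_step(w, g, lr, lam), sgd_decoupled_step(w, g, lr, lam))
# Adam: roughly [0.9, -1.9] vs [0.89, -2.08], so the two no longer agree.
print(adam_l2_step(w, g, lr, lam), adam_decoupled_step(w, g, lr, lam))
```

In the Adam case the L2 term is normalized away along with the loss gradient, so large weights are not shrunk any harder than small ones; the decoupled form keeps the shrinkage proportional to w.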

Complexity

  • Time: O(N) where N is the total number of weight parameters
  • Space: O(N) for the updated gradients/weights