
Adam Optimizer

#87 · Deep Learning · Medium

⊣ Solve on deep-ml.com

Problem

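Implement the Adam (Adaptive Moment Estimation) optimizer from scratch. Given the current parameters, their gradients, the running first- and second-moment estimates (m and v), the timestep t, and the hyperparameters (learning rate, beta1, beta2, epsilon), perform one step of the Adam update rule and return the updated parameters together with the new m and v.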

Solution

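import numpy as np

def adam_optimizer(params, grads, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Perform one Adam step on a list of parameter arrays.

    params, grads, m, v: lists of array-likes with matching shapes.
    t: 1-based timestep (t >= 1; t = 0 would divide by zero in the bias correction).
    Returns (updated_params, new_m, new_v) as Python floats / nested lists.
    """
    updated_params = []
    new_m = []
    new_v = []

    for p, g, mi, vi in zip(params, grads, m, v):
        p = np.asarray(p, dtype=float)
        g = np.asarray(g, dtype=float)
        mi = np.asarray(mi, dtype=float)
        vi = np.asarray(vi, dtype=float)

        # Update biased first moment estimate (EMA of gradients)
        mi = beta1 * mi + (1 - beta1) * g
        # Update biased second moment estimate (EMA of squared gradients)
        vi = beta2 * vi + (1 - beta2) * (g ** 2)

        # Bias-corrected estimates (m and v start at zero)
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)

        # Element-wise parameter update; epsilon guards against division by zero
        p = p - lr * m_hat / (np.sqrt(v_hat) + epsilon)

        # NumPy arrays and NumPy scalars both provide tolist(), which returns
        # a Python float for scalars and nested lists otherwise.
        updated_params.append(p.tolist())
        new_m.append(mi.tolist())
        new_v.append(vi.tolist())

    return updated_params, new_m, new_v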
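To see the bias correction at work, here is a minimal sketch that applies the same update to a single NumPy array instead of the list-of-arrays interface above. The name `adam_step`, the sample values, and the demo learning rate of 0.05 are illustrative assumptions, not part of the problem. At t = 1 with zero-initialized moments, m_hat = g and v_hat = g^2, so the first update is roughly lr * sign(g): about lr in magnitude no matter how large or small the gradient is. The second part runs the step repeatedly on f(x) = sum(x^2), with t starting at 1.

```python
import numpy as np


def adam_step(p, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Hypothetical single-array Adam update; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g          # biased first moment
    v = beta2 * v + (1 - beta2) * g ** 2     # biased second moment
    m_hat = m / (1 - beta1 ** t)             # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)             # bias-corrected second moment
    p = p - lr * m_hat / (np.sqrt(v_hat) + epsilon)
    return p, m, v


# First step from zero moments: each entry moves by about lr,
# whether its gradient is tiny (1e-3) or large (-50).
p0 = np.array([1.0, -2.0, 3.0])
g0 = np.array([1e-3, -50.0, 4.0])
p1, m1, v1 = adam_step(p0, g0, np.zeros(3), np.zeros(3), t=1)
first_step = p0 - p1  # approximately 0.001 * sign(g0)

# Minimize f(x) = sum(x**2), whose gradient is 2x.
x = np.array([3.0, -1.5])
m = np.zeros_like(x)
v = np.zeros_like(x)
for t in range(1, 2001):  # t starts at 1 so 1 - beta**t is never zero
    x, m, v = adam_step(x, 2 * x, m, v, t, lr=0.05)
# x should now be close to the minimum at the origin.
```

Without the bias correction, the first step would be scaled by (1 - beta1) / sqrt(1 - beta2), roughly 3.2 times larger than with it, which is the kind of early distortion step 3 of the explanation removes.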

Explanation

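  1. First moment (m): an exponential moving average of the gradients, similar to momentum. Updated as m = beta1 * m + (1 - beta1) * g.
  2. Second moment (v): an exponential moving average of the squared gradients, similar to RMSProp. Updated as v = beta2 * v + (1 - beta2) * g^2.
  3. Bias correction: because m and v are initialized to zero, they are biased toward zero during the first steps. Dividing by (1 - beta^t) corrects this: m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t), where t is the 1-based step count.
  4. Parameter update: p = p - lr * m_hat / (sqrt(v_hat) + epsilon), applied element-wise; epsilon prevents division by zero when v_hat is near zero.
  5. Adam combines the benefits of momentum (a smoothed update direction) and RMSProp (a per-parameter step size scaled down for parameters with consistently large gradients).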
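The (1 - beta^t) factor in the bias-correction step comes from unrolling the moving average. Assuming m_0 = 0 and, for the sake of the argument, that the gradient distribution does not change over time (E[g_i] = E[g] for every step i), the geometric sum shows that E[m_t] is only a (1 - beta1^t) fraction of E[g]; the same argument with beta2 and g^2 applies to v:

```latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t
    = (1-\beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i} g_i

\mathbb{E}[m_t] = \mathbb{E}[g]\,(1-\beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i}
               = \mathbb{E}[g]\,(1-\beta_1^{t}),
\qquad
\mathbb{E}[v_t] = \mathbb{E}[g^2]\,(1-\beta_2^{t})

\hat{m}_t = \frac{m_t}{1-\beta_1^{t}}, \qquad
\hat{v}_t = \frac{v_t}{1-\beta_2^{t}}, \qquad
p_t = p_{t-1} - \mathrm{lr}\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
```

As t grows, beta^t goes to 0 and the correction factor approaches 1, so it mainly matters in the early steps.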

Complexity

  • Time: O(n) per update step, where n is the total number of parameters
  • Space: O(n) for storing the m and v moment estimates (two values per parameter)
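To make the space bound concrete: Adam keeps two extra values (m and v) for every parameter, so its state is 2n numbers. A rough sketch, assuming 4-byte float32 storage by default and ignoring any container overhead; `adam_state_bytes` is a hypothetical helper. (The solution above uses dtype=float, which is 8-byte float64 inside NumPy.)

```python
def adam_state_bytes(n_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for Adam's m and v buffers: two stored values per parameter."""
    return 2 * n_params * bytes_per_value


# A 1-million-parameter model in float32 needs 8,000,000 bytes (about 8 MB)
# of optimizer state on top of the parameters themselves.
print(adam_state_bytes(1_000_000))
# The same model in float64 doubles that.
print(adam_state_bytes(1_000_000, bytes_per_value=8))
```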