
Nesterov Accelerated Gradient Optimizer

#150 · Deep Learning · Easy

Source: deep-ml.com

Problem

Implement the Nesterov Accelerated Gradient (NAG) optimizer. NAG improves on classical momentum by evaluating the gradient at a "lookahead" position, where the accumulated momentum is about to carry the parameters, rather than at the current position. This often reduces overshooting and speeds up convergence.
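Written in the notation of the solution code (learning rate η = lr, momentum coefficient μ = momentum, velocity v, parameters θ, loss f), the two update rules differ only in where the gradient is evaluated:

```latex
% Classical (heavy-ball) momentum: gradient at the current point
v_{t+1} = \mu v_t - \eta \nabla f(\theta_t), \qquad \theta_{t+1} = \theta_t + v_{t+1}

% Nesterov accelerated gradient: gradient at the lookahead point \theta_t + \mu v_t
v_{t+1} = \mu v_t - \eta \nabla f(\theta_t + \mu v_t), \qquad \theta_{t+1} = \theta_t + v_{t+1}
```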

Solution

import numpy as np

class NesterovOptimizer:
    def __init__(self, learning_rate: float = 0.01, momentum: float = 0.9):
        self.lr = learning_rate
        self.momentum = momentum
        self.velocity = None

    def update(self, params: np.ndarray, grad_fn) -> np.ndarray:
        """Take one NAG step; grad_fn(x) must return the loss gradient at x."""
        if self.velocity is None:
            self.velocity = np.zeros_like(params)

        # Look ahead: compute gradient at params + momentum * velocity
        lookahead = params + self.momentum * self.velocity
        grads = grad_fn(lookahead)

        # Update velocity and params
        self.velocity = self.momentum * self.velocity - self.lr * grads
        params = params + self.velocity
        return params

def nesterov_update(params: np.ndarray, grads: np.ndarray, velocity: np.ndarray,
                    lr: float = 0.01, momentum: float = 0.9) -> tuple[np.ndarray, np.ndarray]:
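    # Reformulated NAG: `params` stores the lookahead point (theta + momentum * velocity)
    # instead of theta, so `grads` is simply the gradient at the stored params and no
    # separate evaluation at a shifted point is needed.
    v_prev = velocity.copy()
    velocity = momentum * velocity - lr * grads
    # theta_new = theta + velocity, rewritten in terms of the lookahead point
    params = params + (-momentum * v_prev + (1 + momentum) * velocity)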
    return params, velocity

def nesterov_update_standard(params: np.ndarray, grads_at_lookahead: np.ndarray,
                             velocity: np.ndarray, lr: float = 0.01,
                             momentum: float = 0.9) -> tuple[np.ndarray, np.ndarray]:
    # Standard form: grads computed at (params + momentum * velocity)
    velocity = momentum * velocity - lr * grads_at_lookahead
    params = params + velocity
    return params, velocity
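As a sanity check on the reformulation, the sketch below runs the standard form (gradient at θ + μv, as in nesterov_update_standard) and the reformulated form (gradient at a stored lookahead point φ, as in nesterov_update) side by side and confirms that φ = θ + μv holds after every step. The quadratic objective, the matrix A, the starting point and the hyperparameters are illustrative assumptions, and the two step functions are local copies so the block runs on its own.

```python
import numpy as np

# Illustrative objective (assumption): f(x) = 0.5 * x^T A x, so grad f(x) = A @ x.
A = np.diag([1.0, 10.0])


def grad_fn(x: np.ndarray) -> np.ndarray:
    return A @ x


def standard_step(theta, v, lr, mu):
    # Local copy of the standard form: gradient taken at the lookahead theta + mu * v.
    v_new = mu * v - lr * grad_fn(theta + mu * v)
    return theta + v_new, v_new


def reformulated_step(phi, v, lr, mu):
    # Local copy of the reformulated form: phi already is the lookahead point,
    # so the gradient is taken at phi itself.
    v_new = mu * v - lr * grad_fn(phi)
    return phi - mu * v + (1 + mu) * v_new, v_new


def max_invariant_gap(steps: int = 50, lr: float = 0.05, mu: float = 0.9):
    theta = np.array([1.0, 1.0])
    phi = theta.copy()  # phi_0 = theta_0 + mu * v_0 = theta_0 because v_0 = 0
    v_std = np.zeros_like(theta)
    v_ref = np.zeros_like(theta)
    gap = 0.0
    for _ in range(steps):
        theta, v_std = standard_step(theta, v_std, lr, mu)
        phi, v_ref = reformulated_step(phi, v_ref, lr, mu)
        # Track the largest deviation from the invariant phi == theta + mu * v.
        gap = max(gap, float(np.max(np.abs(phi - (theta + mu * v_std)))))
    return gap, v_std, v_ref


gap, v_std, v_ref = max_invariant_gap()
print(f"largest |phi - (theta + mu*v)| over the run: {gap:.1e}")
```

The gap stays at floating-point rounding level, and both forms produce the same velocities, because they evaluate the gradient at the same point on every step. The only bookkeeping difference is that the reformulated form must start from θ0 + μ·v0, which equals θ0 when the velocity starts at zero.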

Explanation

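  1. Classical momentum evaluates the gradient at the current position, folds it into the accumulated velocity, and then takes the step.
  2. Nesterov first takes a provisional "lookahead" step along the accumulated momentum, evaluates the gradient there, and then makes the actual update.
  3. The lookahead acts as a correction: if the momentum step would overshoot the minimum, the gradient at the lookahead position already points back, shrinking the velocity before the overshoot happens.
  4. The reformulated version (nesterov_update) avoids a separate gradient evaluation at a shifted point by changing variables: its params store the lookahead point φ = θ + μ·v rather than θ, so the gradient at the stored params is already the lookahead gradient. Substituting θ = φ − μ·v into θ_{t+1} = θ_t + v_{t+1} gives φ_{t+1} = φ_t − μ·v_t + (1 + μ)·v_{t+1}, which is exactly its params update.
  5. NAG often converges faster and oscillates less than classical momentum, especially near the optimum, where classical momentum tends to circle around the minimum. The speedup depends on the problem and the hyperparameters and is not guaranteed in every case.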
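The convergence claim above can be illustrated on a small problem. The sketch below runs classical momentum and NAG with identical hyperparameters on an ill-conditioned 2-D quadratic and compares how far each is from the optimum over its last 30 steps. The objective, starting point, learning rate and step count are assumptions chosen for illustration; this is a single toy run, not a general benchmark.

```python
import numpy as np

# Illustrative ill-conditioned quadratic (assumption): f(x) = 0.5 * x^T A x, minimum at 0.
A = np.diag([1.0, 10.0])


def grad_fn(x: np.ndarray) -> np.ndarray:
    return A @ x


def run(nesterov: bool, steps: int = 200, lr: float = 0.05, mu: float = 0.9) -> list[float]:
    """Return the distance ||theta_t|| to the optimum after each step."""
    theta = np.array([1.0, 1.0])
    v = np.zeros_like(theta)
    norms = []
    for _ in range(steps):
        # Classical momentum takes the gradient at theta; NAG at theta + mu * v.
        point = theta + mu * v if nesterov else theta
        v = mu * v - lr * grad_fn(point)
        theta = theta + v
        norms.append(float(np.linalg.norm(theta)))
    return norms


heavy_ball = run(nesterov=False)
nag = run(nesterov=True)
# Largest distance from the optimum over the last 30 steps of each run
print(f"classical momentum: {max(heavy_ball[-30:]):.2e}")
print(f"NAG:                {max(nag[-30:]):.2e}")
```

On a quadratic, whenever the iteration oscillates along a direction with curvature λ, the per-step contraction factor along that direction works out to √μ for classical momentum and √(μ(1 − ηλ)) for NAG, where η is the learning rate. That smaller factor is the damping effect the lookahead gradient provides in this run.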

Complexity

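  • Time: O(P) per update, where P = number of parameters, excluding the cost of computing the gradient
  • Space: O(P) for the velocity vector (the reformulated version also keeps an O(P) copy of the previous velocity)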