"""Implement the Nesterov Accelerated Gradient (NAG) optimizer.

NAG improves upon standard momentum by computing the gradient at the
"lookahead" position (the parameters after the pending momentum step is
applied) rather than at the current position, which typically gives better
convergence.
"""
import numpy as np
class NesterovOptimizer:
    """Stateful Nesterov Accelerated Gradient (NAG) optimizer.

    The optimizer keeps a velocity buffer between calls. On every step the
    gradient is evaluated at the lookahead point
    ``params + momentum * velocity`` instead of at ``params`` itself.

    Attributes:
        lr: Step size applied to the gradient.
        momentum: Decay factor applied to the previous velocity.
        velocity: Running velocity; ``None`` until the first ``update`` call.
    """

    def __init__(self, learning_rate: float = 0.01, momentum: float = 0.9):
        """Store the hyperparameters; the velocity is created lazily."""
        self.lr = learning_rate
        self.momentum = momentum
        # Allocated on the first update so it can mirror the shape of params.
        self.velocity = None

    def update(self, params: np.ndarray, grad_fn) -> np.ndarray:
        """Perform one NAG step and return the updated parameters.

        Args:
            params: Current parameter values.
            grad_fn: Callable that receives the lookahead point and returns
                the gradient evaluated there.

        Returns:
            ``params + velocity`` computed with the refreshed velocity. The
            result is built with out-of-place arithmetic, so an ndarray
            passed as ``params`` is not modified.
        """
        if self.velocity is None:
            self.velocity = np.zeros_like(params)

        # Look ahead: evaluate the gradient where momentum alone would
        # carry the parameters.
        lookahead = params + self.momentum * self.velocity
        grads = grad_fn(lookahead)

        # Decay the old velocity, add the lookahead gradient step, then move.
        self.velocity = self.momentum * self.velocity - self.lr * grads
        params = params + self.velocity
        return params
def nesterov_update(params: np.ndarray, grads: np.ndarray, velocity: np.ndarray,
                    lr: float = 0.01, momentum: float = 0.9) -> tuple[np.ndarray, np.ndarray]:
    """Apply one step of the reformulated Nesterov update.

    This variant uses the gradient at the current position and applies the
    Nesterov correction in the parameter step, so no separate lookahead
    gradient evaluation is needed.

    Args:
        params: Current parameter values.
        grads: Gradient evaluated at ``params``.
        velocity: Velocity carried over from the previous step.
        lr: Step size applied to the gradient.
        momentum: Decay factor applied to the previous velocity.

    Returns:
        A ``(new_params, new_velocity)`` tuple. The inputs are not modified.
    """
    # Keeping a reference is enough: ``velocity`` is rebound to a new value
    # below and never mutated in place, so no defensive copy is needed.
    v_prev = velocity
    velocity = momentum * velocity - lr * grads
    # Nesterov correction; algebraically the step equals
    # ``momentum * velocity - lr * grads`` with the refreshed velocity.
    params = params + (-momentum * v_prev + (1 + momentum) * velocity)
    return params, velocity
def nesterov_update_standard(params: np.ndarray, grads_at_lookahead: np.ndarray,
                             velocity: np.ndarray, lr: float = 0.01,
                             momentum: float = 0.9) -> tuple[np.ndarray, np.ndarray]:
    """Apply one step of Nesterov momentum in its standard form.

    The caller must supply the gradient evaluated at the lookahead point
    ``params + momentum * velocity``; this function only performs the
    velocity and parameter updates.

    Args:
        params: Current parameter values.
        grads_at_lookahead: Gradient evaluated at
            ``params + momentum * velocity``.
        velocity: Velocity carried over from the previous step.
        lr: Step size applied to the gradient.
        momentum: Decay factor applied to the previous velocity.

    Returns:
        A ``(new_params, new_velocity)`` tuple. The inputs are not modified.
    """
    # Decay the old velocity and add the lookahead gradient step.
    velocity = momentum * velocity - lr * grads_at_lookahead
    # Move the parameters along the refreshed velocity.
    params = params + velocity
    return params, velocity