#172 · Deep Learning · Medium
# Solve on deep-ml.com
#
# Implement the Muon optimizer update with Newton-Schulz iteration for
# orthogonalizing the gradient. This involves computing an approximation of
# the polar decomposition of the gradient matrix using iterative refinement.
import numpy as np
def newton_schulz_orthogonalize(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Approximately orthogonalize a matrix with the cubic Newton-Schulz iteration.

    The matrix is first divided by its Frobenius norm (plus a small epsilon),
    then refined ``steps`` times with ``X <- X @ (3I - X^T X) / 2``.

    Args:
        G: 2-D matrix to orthogonalize.
        steps: Number of Newton-Schulz iterations. With ``0`` the result is
            just the norm-scaled input.

    Returns:
        A floating-point array with the same shape as ``G``.

    Raises:
        ValueError: If ``G`` is not 2-D.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if G.ndim != 2:
        raise ValueError(f"G must be a 2-D matrix, got {G.ndim} dimension(s)")

    # Work on the tall orientation so that X^T X is the smaller of the two
    # possible Gram matrices; the result is transposed back at the end.
    rows, cols = G.shape
    transposed = rows < cols
    if transposed:
        G = G.T

    # Frobenius norm taken directly: avoids an extra G^T G product and, for
    # integer inputs, the silent overflow that product can suffer.
    # The epsilon keeps the division finite for an all-zero matrix.
    X = G / (np.linalg.norm(G) + 1e-7)

    # 3I does not change between iterations, so build it once.
    three_identity = 3.0 * np.eye(X.shape[1])
    for _ in range(steps):
        X = X @ (three_identity - X.T @ X) / 2.0

    return X.T if transposed else X
def muon_update(params: np.ndarray, grads: np.ndarray,
                buf: np.ndarray, lr: float = 0.02,
                momentum: float = 0.95, ns_steps: int = 5):
    """Perform one Muon optimizer step.

    Steps:
        1. Momentum accumulation: ``buf = momentum * buf + grads``.
        2. Orthogonalize the buffer with Newton-Schulz: start from
           ``X = G / ||G||_F`` and iterate ``X <- X @ (3I - X^T X) / 2``.
        3. Multiply by ``sqrt(max_dim)`` to preserve the update magnitude,
           then apply the learning rate.

    Args:
        params: Current parameters.
        grads: Gradient of the loss with respect to ``params``.
        buf: Momentum buffer from the previous step.
        lr: Learning rate.
        momentum: Momentum coefficient applied to ``buf``.
        ns_steps: Number of Newton-Schulz iterations.

    Returns:
        A tuple ``(params_new, buf_new)``: the updated parameters and the
        updated momentum buffer. The buffer is returned in the shape that
        ``momentum * buf + grads`` produced, i.e. it is not reshaped.
    """
    buf_new = momentum * buf + grads

    # A flat buffer paired with a 2-D parameter is viewed in the parameter's
    # shape so it can be orthogonalized as a matrix.
    if buf_new.ndim == 1 and params.ndim == 2:
        matrix = buf_new.reshape(params.shape)
    else:
        matrix = buf_new
    G_orth = newton_schulz_orthogonalize(matrix, ns_steps)

    # sqrt of the larger matrix dimension rescales the orthogonalized update.
    scale = np.sqrt(float(max(G_orth.shape)))
    update = G_orth.reshape(params.shape) * scale
    params_new = params - lr * update
    return params_new, buf_new