Muon Optimizer Update with Newton-Schulz Iteration

#172 · Deep Learning · Medium

⊣ Solve on deep-ml.com

Problem

Implement the Muon optimizer update, which uses Newton-Schulz iteration to orthogonalize the gradient. The iteration approximates the orthogonal factor of the gradient matrix's polar decomposition by repeated refinement, without computing an SVD.

Solution

import numpy as np

def newton_schulz_orthogonalize(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Approximate the orthogonal polar factor of a 2-D matrix G."""
    assert G.ndim == 2
    rows, cols = G.shape
    # Work on the tall orientation so that X.T @ X is the smaller Gram matrix.
    transpose = rows < cols
    if transpose:
        G = G.T

    # Frobenius normalization bounds every singular value of X by 1.
    norm = np.linalg.norm(G) + 1e-7
    X = G / norm

    identity = np.eye(X.shape[1])
    for _ in range(steps):
        # Cubic Newton-Schulz step: X <- X (3I - X^T X) / 2
        A = X.T @ X
        X = X @ (3.0 * identity - A) / 2.0

    return X.T if transpose else X

def muon_update(params: np.ndarray, grads: np.ndarray,
                buf: np.ndarray, lr: float = 0.02,
                momentum: float = 0.95, ns_steps: int = 5):
    # Momentum accumulation.
    buf_new = momentum * buf + grads
    # Newton-Schulz needs a matrix: restore the 2-D shape if the buffer is flat.
    G = buf_new
    if G.ndim == 1 and params.ndim == 2:
        G = G.reshape(params.shape)
    G_orth = newton_schulz_orthogonalize(G, ns_steps)
    # Rescale the orthogonalized direction by sqrt(max(rows, cols)).
    scale = np.sqrt(float(max(G_orth.shape)))
    update = G_orth.reshape(params.shape) * scale
    params_new = params - lr * update
    return params_new, buf_new

Explanation

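  1. Accumulate the gradient into the momentum buffer: buf = momentum * buf + grads.
  2. Newton-Schulz iteration approximates the orthogonal polar factor of a matrix. Starting from X = G / ||G||_F, which bounds every singular value of X by 1, repeat X <- X (3I - X^T X) / 2.
  3. Each step moves every nonzero singular value of X toward 1 and leaves the singular vectors unchanged. For a full-rank G = U S V^T the iteration therefore converges to U V^T, the nearest semi-orthogonal matrix to G in Frobenius norm. With only a few steps (ns_steps = 5 by default) the result is an approximation, and the smallest singular values are the slowest to converge.
  4. Scale by sqrt(max(rows, cols)), then apply the learning rate. A fully orthogonalized d x k matrix has Frobenius norm sqrt(min(d, k)), so this scaling gives the update a root-mean-square entry of about 1, regardless of the gradient's magnitude.
  5. The orthogonalization discards the singular values (the scale information) and keeps only the singular vectors, so the update depends only on the directional structure of the accumulated gradient.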
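
To make steps 2 to 4 concrete, here is a minimal sketch that compares the Newton-Schulz result with the polar factor computed from an SVD. The helper names newton_schulz and polar_factor_svd, the 6 x 4 test matrix, and the step counts are assumptions chosen for illustration; they are not part of the solution above. The printed distances should shrink as the number of steps grows.

```python
import numpy as np

def newton_schulz(G: np.ndarray, steps: int) -> np.ndarray:
    # Cubic iteration from the explanation: X <- X (3I - X^T X) / 2,
    # started from the Frobenius-normalized matrix. Assumes rows >= cols.
    X = G / (np.linalg.norm(G) + 1e-7)
    identity = np.eye(G.shape[1])
    for _ in range(steps):
        X = X @ (3.0 * identity - X.T @ X) / 2.0
    return X

def polar_factor_svd(G: np.ndarray) -> np.ndarray:
    # Reference answer: for G = U S V^T, the orthogonal polar factor is U V^T.
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(0)
G = rng.standard_normal((6, 4))  # arbitrary tall test matrix (illustrative)
Q = polar_factor_svd(G)

# Frobenius distance to the reference after different numbers of steps.
errors = {steps: float(np.linalg.norm(newton_schulz(G, steps) - Q))
          for steps in (5, 20, 50)}
print(errors)

# After scaling by sqrt(max(rows, cols)), the entries have RMS 1.
scaled = Q * np.sqrt(max(Q.shape))
rms = float(np.sqrt(np.mean(scaled ** 2)))
print(rms)
```

The SVD route gives the exact factor but is usually avoided in the optimizer itself; the iteration needs only matrix products, which is why a small fixed number of steps is accepted as an approximation.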

Complexity

  • Time: O(ns_steps * d * k * min(d, k)) for the Newton-Schulz iterations on a d x k matrix, since each step performs two matrix products of cost d * k * min(d, k)
  • Space: O(d * k) for the gradient and intermediate matrices
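
The time bound can be checked by counting the two matrix products in each step. The sketch below uses a hypothetical helper, newton_schulz_multiply_adds, that counts scalar multiply-adds in the iteration loop only (it ignores the normalization and the parameter update), assuming the matrix is first put in its tall orientation as in the solution.

```python
def newton_schulz_multiply_adds(d: int, k: int, ns_steps: int = 5) -> int:
    # Hypothetical helper: scalar multiply-adds in the iteration loop only.
    small, large = min(d, k), max(d, k)
    gram = small * large * small   # X.T @ X: (small x large) @ (large x small)
    apply = large * small * small  # X @ (3I - A): (large x small) @ (small x small)
    return ns_steps * (gram + apply)

# Illustrative shapes: the count is the same for a matrix and its transpose.
print(newton_schulz_multiply_adds(1024, 256))
print(newton_schulz_multiply_adds(256, 1024))
```

Both products cost d * k * min(d, k), so the total is 2 * ns_steps * d * k * min(d, k), matching the bound above up to the constant factor.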