Muon Optimizer Step with Matrix Preconditioning

#170 · Optimization · Medium

Solve on deep-ml.com

Problem

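Implement a single step of the Muon optimizer with matrix preconditioning. Muon applies Newton-Schulz iterations to the momentum matrix G to approximate G (G^T G)^(-1/2), that is, G preconditioned by the inverse matrix square root of its Gram (uncentered covariance) matrix G^T G. For G = U S V^T (SVD), this product is the orthogonal factor U V^T.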
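The target that the Newton-Schulz iterations approximate can be computed exactly for small matrices. A minimal NumPy sketch, using hypothetical helper names polar_factor_svd and polar_factor_gram and an assumed random full-rank 6 x 4 test matrix, checks that the SVD form U V^T and the preconditioned form G (G^T G)^(-1/2) agree:

```python
import numpy as np

def polar_factor_svd(G: np.ndarray) -> np.ndarray:
    """Exact orthogonal polar factor U V^T of G via a thin SVD."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def polar_factor_gram(G: np.ndarray) -> np.ndarray:
    """Same factor written as G (G^T G)^(-1/2), using an eigendecomposition
    of the symmetric matrix G^T G (assumes G has full column rank)."""
    evals, evecs = np.linalg.eigh(G.T @ G)
    inv_sqrt = evecs @ np.diag(evals ** -0.5) @ evecs.T
    return G @ inv_sqrt

rng = np.random.default_rng(0)
G = rng.standard_normal((6, 4))          # assumed test matrix (full rank)
O = polar_factor_svd(G)

print(np.allclose(O, polar_factor_gram(G)))   # True: both forms agree
print(np.allclose(O.T @ O, np.eye(4)))        # True: orthonormal columns
```

The solution below reaches approximately the same matrix using only matrix products, with no SVD or eigendecomposition.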

Solution

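import numpy as np

def muon_step(params: np.ndarray, grads: np.ndarray,
              momentum_buffer: np.ndarray, lr: float = 0.02,
              momentum: float = 0.95, ns_steps: int = 5):
    # Momentum accumulation; the buffer, not the raw gradient, is orthogonalized
    buf = momentum * momentum_buffer + grads

    # Work on a 2D matrix: a 1D (or scalar) buffer becomes an n x 1 column
    G = buf.reshape(-1, 1) if buf.ndim < 2 else buf

    # Frobenius normalization, sqrt(trace(G^T G)): all singular values <= 1,
    # inside the (0, sqrt(3)) range where the cubic iteration converges
    X = G / (np.linalg.norm(G) + 1e-7)

    # Newton-Schulz iterations toward the orthogonal polar factor U V^T,
    # always using the smaller of the two Gram matrices
    if G.shape[0] >= G.shape[1]:
        I = np.eye(G.shape[1])
        for _ in range(ns_steps):
            A = X.T @ X                  # (n, n) Gram matrix
            X = X @ (3 * I - A) / 2
    else:
        I = np.eye(G.shape[0])
        for _ in range(ns_steps):
            A = X @ X.T                  # (m, m) Gram matrix
            X = (3 * I - A) @ X / 2

    # Rescale the (approximately) orthogonal update
    preconditioned = X * np.sqrt(max(G.shape[0], G.shape[1]))

    # Restore the original shape (undoes the 1D -> column reshape)
    preconditioned = preconditioned.reshape(buf.shape)

    params_new = params - lr * preconditioned
    return params_new, buf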
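
The reason for dividing by the Frobenius norm before iterating shows up when looking at a single singular value: each Newton-Schulz step maps s to (3s - s^3) / 2 and leaves the singular vectors alone. In this sketch, ns_scalar is a hypothetical helper and the starting values are arbitrary choices. Values in (0, 1] are driven to 1, while a value of 2, outside (0, sqrt(3)), is sent to -1 and stays there:

```python
def ns_scalar(s: float, steps: int) -> float:
    """Apply the cubic Newton-Schulz map s <- (3s - s^3) / 2 to one singular value."""
    for _ in range(steps):
        s = (3 * s - s ** 3) / 2
    return s

# After dividing G by its Frobenius norm, every singular value lies in (0, 1].
for s0 in (0.05, 0.3, 1.0):
    print(s0, ns_scalar(s0, 10))

# Outside (0, sqrt(3)) the map does not converge to +1.
print(ns_scalar(2.0, 5))   # -1.0
```

Small singular values initially grow only by about a factor of 1.5 per step (0.05 reaches only about 0.37 after 5 steps), so with ns_steps = 5 a weak direction can remain noticeably below 1.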

Explanation

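  1. Accumulate momentum: buf = momentum * old_buf + grads. The orthogonalization is applied to this buffer, not to the raw gradient.
  2. Reshape the buffer to 2D if needed for matrix operations (a 1D vector becomes an n x 1 column).
  3. Normalize the matrix by its Frobenius norm, sqrt(trace(G^T G)) (equivalently sqrt(trace(G G^T))), so that every singular value is at most 1, inside the range (0, sqrt(3)) where the iteration converges.
  4. Apply Newton-Schulz iterations to approximate the orthogonal factor U V^T of the polar decomposition G = (U V^T)(V S V^T): X <- X (3I - X^T X) / 2 for tall matrices, or X <- (3I - X X^T) X / 2 for wide ones, so the Gram matrix is always the smaller of the two. Each step maps every singular value s to (3s - s^3) / 2, pushing it toward 1 while keeping the singular vectors unchanged.
  5. Scale the result by sqrt(max(m, n)) and update the parameters: params_new = params - lr * update; buf is returned as the new momentum buffer. Because all singular values of the update are pushed toward the same value, the step has roughly equal magnitude along every singular direction instead of being dominated by the few largest ones.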
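Steps 3 and 4 can be checked against an exact SVD. This sketch reimplements only the normalization and Newton-Schulz loop of muon_step as a standalone function; the name newton_schulz and the test matrix with singular values 4, 2, 1, 0.5 are assumptions for illustration. It prints the Frobenius distance to the exact polar factor U V^T as the number of steps grows:

```python
import numpy as np

def newton_schulz(G: np.ndarray, steps: int) -> np.ndarray:
    """Cubic Newton-Schulz orthogonalization with the same tall/wide branches as muon_step."""
    X = G / (np.linalg.norm(G) + 1e-7)          # Frobenius normalization
    if G.shape[0] >= G.shape[1]:
        I = np.eye(G.shape[1])
        for _ in range(steps):
            X = X @ (3 * I - X.T @ X) / 2
    else:
        I = np.eye(G.shape[0])
        for _ in range(steps):
            X = (3 * I - X @ X.T) @ X / 2
    return X

# Test matrix with known singular values; orthonormal U and V come from QR.
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((6, 4)))   # 6 x 4, orthonormal columns
V, _ = np.linalg.qr(rng.standard_normal((4, 4)))   # 4 x 4 orthogonal
G = U @ np.diag([4.0, 2.0, 1.0, 0.5]) @ V.T
target = U @ V.T                                   # exact polar factor of G

for steps in (1, 5, 10, 20):
    err = np.linalg.norm(newton_schulz(G, steps) - target)
    print(steps, err)
```

On this matrix the error falls monotonically, but after the default 5 steps it is still clearly nonzero: the smallest normalized singular value (about 0.108) has only reached roughly 0.7, so the default step count yields an approximately orthogonal update rather than an exact one.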

Complexity

  • Time: O(ns_steps * m * n * min(m, n)) for an m x n momentum matrix, since each iteration forms the smaller min(m, n) x min(m, n) Gram matrix and multiplies it back into X
  • Space: O(m * n) for G and X, plus O(min(m, n)^2) for the Gram matrix
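The time bound comes from the two matrix products per iteration. In this back-of-the-envelope sketch, ns_flops is a hypothetical helper that counts a multiply-add as 2 flops and ignores the cheap 3I - A subtraction. It shows why the solution forms the smaller Gram matrix: for a 4096 x 1024 matrix, using the 4096 x 4096 Gram matrix instead would cost 4 times as much.

```python
def ns_flops(m: int, n: int, steps: int, use_smaller_gram: bool = True) -> int:
    """Rough flop count for the cubic Newton-Schulz loop on an m x n matrix X.

    Each step forms a k x k Gram matrix (2*m*n*k flops) and multiplies it
    back into X (another 2*m*n*k flops), where k is the Gram matrix side.
    """
    k = min(m, n) if use_smaller_gram else max(m, n)
    return steps * 4 * m * n * k

m, n = 4096, 1024
cheap = ns_flops(m, n, 5)
costly = ns_flops(m, n, 5, use_smaller_gram=False)
print(cheap, costly, costly // cheap)   # ratio is max(m, n) / min(m, n) = 4
```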