
Gradient Clipping by Global Norm

#197 · Optimization · Medium

Solve on deep-ml.com

Problem

Implement Gradient Clipping by Global Norm. Given a list of gradient tensors, compute the global norm (the L2 norm of all gradients flattened and concatenated into a single vector). If it exceeds a threshold max_norm, scale every gradient by the same factor so that the global norm becomes exactly max_norm.
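
In symbols, writing g_1, ..., g_K for the gradient tensors, g_{k,j} for the j-th element of tensor k, and c for max_norm (notation introduced here for illustration), the global norm and the rescaled gradients are:

$$
\|g\|_{\text{global}} = \sqrt{\sum_{k=1}^{K} \sum_{j} g_{k,j}^{2}}, \qquad
g_k \leftarrow g_k \cdot \frac{c}{\max\left(\|g\|_{\text{global}},\, c\right)}
$$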

Solution

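import numpy as np

def clip_grad_by_global_norm(gradients: list[np.ndarray],
                             max_norm: float) -> tuple[list[np.ndarray], float]:
    # Global norm: L2 norm of all gradient elements taken together,
    # i.e. the square root of the sum of squares over every tensor
    total_norm_sq = 0.0
    for g in gradients:
        total_norm_sq += np.sum(g ** 2)
    global_norm = float(np.sqrt(total_norm_sq))

    # Clipping coefficient: 1.0 if global_norm <= max_norm, otherwise
    # max_norm / global_norm (assumes max_norm > 0, so no division by zero)
    clip_coeff = max_norm / max(global_norm, max_norm)
    clipped = [g * clip_coeff for g in gradients]

    # Return the clipped gradients and the pre-clipping norm (for logging)
    return clipped, global_norm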
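
As a cross-check, the same result can be computed literally from the problem's definition by flattening and concatenating all tensors into one vector. The helper reference_clip below is a hypothetical name used only for this sketch; the worked numbers show the arithmetic: 3^2 + 4^2 + 0^2 + 12^2 = 169, so the global norm is 13, and with max_norm = 6.5 the coefficient is 6.5 / 13 = 0.5.

```python
import numpy as np

def reference_clip(gradients: list[np.ndarray],
                   max_norm: float) -> tuple[list[np.ndarray], float]:
    # Hypothetical reference helper: flatten every tensor, concatenate them into
    # one vector, and take its L2 norm, exactly as the problem statement defines it
    flat = np.concatenate([g.ravel() for g in gradients])
    global_norm = float(np.linalg.norm(flat))
    # Same coefficient as the solution: 1.0 below the threshold, else max_norm / global_norm
    coeff = max_norm / max(global_norm, max_norm)
    return [g * coeff for g in gradients], global_norm

# Worked example: the squared global norm is 9 + 16 + 0 + 144 = 169, so the norm is 13
grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]
clipped, norm = reference_clip(grads, max_norm=6.5)
# coefficient 6.5 / 13 = 0.5, so clipped is [1.5, 2.0] and [0.0, 6.0]
```

Concatenation makes the definition explicit but allocates one extra vector holding every element; summing per-tensor squares, as the solution does, gives the same norm without that copy.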

Explanation

  1. Compute the global norm: square each gradient element, sum across all tensors, then take the square root.
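  2. Compute the clipping coefficient: max_norm / max(global_norm, max_norm). This equals 1.0 when the global norm is at or below the threshold and max_norm / global_norm when it exceeds it. It is the same as min(1, max_norm / global_norm), but it never divides by a possibly zero global norm.
  3. Multiply every gradient by the clipping coefficient. Because all tensors share one coefficient, the combined gradient keeps its direction; when clipping is active its global norm becomes exactly max_norm, and otherwise the gradients are left unchanged.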
  4. Return the clipped gradients and the original global norm for logging.
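
Steps 2 and 3 can be checked in isolation with a small sketch. The helper clip_coefficient is a hypothetical name, and the random gradients are arbitrary test data; the check confirms the two regimes of the coefficient and that one shared coefficient preserves the direction of the concatenated gradient (cosine similarity 1).

```python
import numpy as np

def clip_coefficient(global_norm: float, max_norm: float) -> float:
    # Hypothetical helper isolating step 2; the denominator is at least
    # max_norm, so it is never zero (assumes max_norm > 0)
    return max_norm / max(global_norm, max_norm)

# Within the threshold the coefficient is exactly 1.0: gradients pass through unchanged
below = clip_coefficient(global_norm=2.0, max_norm=5.0)
# Above the threshold the coefficient is max_norm / global_norm
above = clip_coefficient(global_norm=10.0, max_norm=5.0)

# Scale every tensor by one shared coefficient (step 3) and compare directions
rng = np.random.default_rng(0)
grads = [rng.normal(size=(3, 4)), rng.normal(size=5)]
before = np.concatenate([g.ravel() for g in grads])
coeff = clip_coefficient(float(np.linalg.norm(before)), max_norm=0.5)
after = np.concatenate([(g * coeff).ravel() for g in grads])
# Cosine similarity between the original and clipped concatenated vectors
cosine = before @ after / (np.linalg.norm(before) * np.linalg.norm(after))
```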

Complexity

  • Time: O(N) where N is the total number of gradient elements across all tensors
  • Space: O(N) for the clipped gradient copies
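
If the O(N) copies matter, the gradients can be scaled in place instead. The sketch below is a hypothetical variant, not part of the original solution; the trailing underscore follows the naming convention of PyTorch's in-place torch.nn.utils.clip_grad_norm_. It assumes floating-point arrays, because NumPy refuses to multiply an integer array in place by a float.

```python
import numpy as np

def clip_grad_by_global_norm_(gradients: list[np.ndarray], max_norm: float) -> float:
    # Hypothetical in-place variant: mutates the caller's arrays and returns the
    # pre-clipping global norm. Assumes float arrays and max_norm > 0.
    # np.vdot(g, g) gives the sum of squares of a real array without an explicit g ** 2 array
    global_norm = float(np.sqrt(sum(float(np.vdot(g, g)) for g in gradients)))
    clip_coeff = max_norm / max(global_norm, max_norm)
    if clip_coeff < 1.0:
        for g in gradients:
            g *= clip_coeff  # scales the existing buffer; no per-tensor copy
    return global_norm
```

Calling it on [np.array([3.0, 4.0]), np.array([0.0, 12.0])] with max_norm = 6.5 returns 13.0 and leaves the list holding [1.5, 2.0] and [0.0, 6.0]. The trade-off is that the caller's original gradients are overwritten.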