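Implement gradient clipping by global norm. Given a list of gradient tensors, compute the global norm (the L2 norm of all gradients concatenated into a single vector). If it exceeds a threshold max_norm, scale every gradient by the same factor max_norm / global_norm so that the global norm becomes exactly max_norm; otherwise leave the gradients unchanged. Return the (possibly scaled) gradients together with the global norm measured before clipping.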
import numpy as np
def clip_grad_by_global_norm(gradients: list[np.ndarray],
                             max_norm: float) -> tuple[list[np.ndarray], float]:
    """Rescale gradients so that their joint L2 norm is at most ``max_norm``.

    The global norm is the L2 norm of all gradients viewed as one flat
    vector, i.e. ``sqrt(sum(sum(g ** 2) for g in gradients))``. When it
    exceeds ``max_norm`` every gradient is multiplied by the same factor
    ``max_norm / global_norm``, which preserves the direction of the overall
    update; otherwise the gradients are returned with their values unchanged.

    Args:
        gradients: Gradient arrays of arbitrary (and possibly differing)
            shapes. The inputs are never modified in place.
        max_norm: Non-negative upper bound on the global norm.

    Returns:
        A ``(clipped, global_norm)`` tuple: ``clipped`` is a list of newly
        allocated arrays in the same order as ``gradients``, and
        ``global_norm`` is the norm measured *before* clipping, as a Python
        float.

    Raises:
        ValueError: If ``max_norm`` is negative (a negative bound would flip
            the sign of every gradient instead of clipping it).
    """
    if max_norm < 0:
        raise ValueError(f"max_norm must be non-negative, got {max_norm}")

    # Square and accumulate in float64 so that integer inputs cannot overflow
    # and low-precision inputs do not lose accuracy in the reduction.
    total_norm_sq = sum(
        (float(np.sum(np.square(g, dtype=np.float64))) for g in gradients),
        0.0,
    )
    global_norm = float(np.sqrt(total_norm_sq))

    # Same value as max_norm / max(global_norm, max_norm): 1.0 when the global
    # norm is within the threshold and max_norm / global_norm when it exceeds
    # the threshold. Branching avoids the 0 / 0 that the closed form hits
    # when both the norm and max_norm are zero.
    if global_norm > max_norm:
        clip_coeff = max_norm / global_norm
    else:
        clip_coeff = 1.0

    # clip_coeff is a plain Python float, so floating-point gradients keep
    # their own dtype (e.g. float32 stays float32) after the multiplication.
    clipped = [g * clip_coeff for g in gradients]
    return clipped, global_norm