
Mixed Precision Training

#160 · Machine Learning · Medium


Problem

Implement Mixed Precision Training utilities: convert a model's forward pass to use half precision (float16) for speed, while keeping the master weights in full precision (float32) for accuracy. Include loss scaling to prevent gradient underflow.
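Both failure modes named in the problem can be reproduced with NumPy scalars. This sketch uses an arbitrary gradient of 1e-8 and an arbitrary per-step update of 1e-4. Only the scale of 1024 comes from the solution's default.

```python
import numpy as np

# 1) Gradient underflow: float16's smallest positive (subnormal) value is about 6e-8,
#    so a gradient of 1e-8 rounds to zero when cast directly.
tiny_grad = 1e-8
print(np.float16(tiny_grad))                # 0.0

# Multiplying by a loss scale first keeps the value representable in float16,
# and dividing in float32 afterwards recovers it to within rounding error.
loss_scale = 1024.0
scaled = np.float16(tiny_grad * loss_scale)
recovered = np.float32(scaled) / loss_scale
print(scaled != 0, recovered)

# 2) Lost updates: near 1.0, adjacent float16 values are about 1e-3 apart,
#    so adding 1e-4 to a float16 weight rounds back to the same value.
w16 = np.float16(1.0)
w32 = np.float32(1.0)
for _ in range(1000):
    w16 = w16 + np.float16(1e-4)            # float16 "master weight": update is lost
    w32 = w32 + np.float32(1e-4)            # float32 master weight: update accumulates
print(w16, w32)                              # w16 stays 1.0; w32 ends near 1.1
```

The second case is why the solution keeps the master weights in float32 even though the arithmetic runs in float16.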

Solution

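import numpy as np

def mixed_precision_step(weights_fp32: np.ndarray, inputs: np.ndarray,
                         targets: np.ndarray, lr: float,
                         loss_scale: float = 1024.0):
    # Forward pass in half precision on float16 copies of the master weights and inputs.
    weights_fp16 = weights_fp32.astype(np.float16)
    inputs_fp16 = inputs.astype(np.float16)
    predictions = inputs_fp16 @ weights_fp16

    # Mean squared error computed in float32 to avoid float16 overflow and rounding error.
    diff = predictions.astype(np.float32) - targets.astype(np.float32)
    loss = np.mean(diff ** 2)

    # dLoss/dPredictions. Dividing by diff.size (the number of elements np.mean averages)
    # rather than len(targets) keeps the gradient consistent with the loss for 2-D targets.
    error = 2.0 * diff / diff.size

    # Scale before casting to float16 so small gradient values do not underflow to zero.
    scaled_error = (error * loss_scale).astype(np.float16)
    grads_fp16 = inputs_fp16.T @ scaled_error

    # Unscale in float32, then update the float32 master weights.
    grads_fp32 = grads_fp16.astype(np.float32) / loss_scale
    weights_fp32 = weights_fp32 - lr * grads_fp32
    return weights_fp32, float(loss)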
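Here is a minimal usage sketch on a hypothetical noise-free linear regression problem. The function repeats the solution's update so the block runs on its own. The data, seed, learning rate, and step count are illustrative assumptions. Because the forward pass runs in float16, the learned weights should match `true_w` only to roughly float16 precision, not exactly.

```python
import numpy as np

def mixed_precision_step(weights_fp32, inputs, targets, lr, loss_scale=1024.0):
    # float16 forward pass on copies of the float32 master weights and inputs.
    weights_fp16 = weights_fp32.astype(np.float16)
    inputs_fp16 = inputs.astype(np.float16)
    predictions = inputs_fp16 @ weights_fp16
    # Loss in float32.
    diff = predictions.astype(np.float32) - targets.astype(np.float32)
    loss = np.mean(diff ** 2)
    # Scaled float16 backward pass, then unscale in float32.
    error = 2.0 * diff / diff.size
    scaled_error = (error * loss_scale).astype(np.float16)
    grads_fp16 = inputs_fp16.T @ scaled_error
    grads_fp32 = grads_fp16.astype(np.float32) / loss_scale
    # Update the float32 master weights.
    return weights_fp32 - lr * grads_fp32, float(loss)

# Hypothetical synthetic data: y = X @ true_w with no noise.
rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype=np.float32)
X = rng.standard_normal((256, 4)).astype(np.float32)
y = X @ true_w

w = np.zeros(4, dtype=np.float32)  # float32 master weights
losses = []
for _ in range(200):
    w, loss = mixed_precision_step(w, X, y, lr=0.1)
    losses.append(loss)

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.6f}")
print("learned weights:", w)
```

The master weights `w` stay float32 throughout. Only the temporary copies that pass through the matrix products are float16.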

Explanation

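  1. Make float16 copies of the float32 master weights and of the inputs, and run the forward pass in float16. This is cheaper on hardware with fast float16 arithmetic (e.g. GPU tensor cores); NumPy on a CPU reproduces the numerics but not the speedup.
  2. Compute the loss in float32, since squaring and averaging errors in float16 can overflow (its largest finite value is 65504) or lose precision.
  3. Scale the error by a large factor before casting it to float16 and computing gradients. Because the gradient is linear in the loss, this is equivalent to scaling the loss. It keeps small gradient values from underflowing to zero in float16, but the factor must not be so large that the scaled gradients overflow to inf.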
  4. Cast gradients back to float32 and unscale by dividing by the loss scale factor.
  5. Update the float32 master weights with the unscaled float32 gradients.
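A fixed scale of 1024 works when gradient magnitudes stay in a known range. Too small a scale still lets gradients underflow, and too large a scale overflows float16 to inf. A common extension, not part of the original solution, is dynamic loss scaling, the idea behind PyTorch's `GradScaler`. When non-finite gradients appear, it skips the update and shrinks the scale; after a run of clean steps, it grows the scale again. The sketch below assumes a plain dict `state` holding `scale` and `good_steps`. The function name `dynamic_scale_step` and the growth/backoff defaults are hypothetical choices, not values from the problem.

```python
import numpy as np

def dynamic_scale_step(weights_fp32, inputs, targets, lr, state,
                       growth_interval=100, growth_factor=2.0, backoff_factor=0.5):
    """One mixed-precision step with an adaptive loss scale (hypothetical helper).

    `state` is an assumed dict with keys 'scale' (float) and 'good_steps' (int).
    Returns (weights, loss, applied) where `applied` is False if the step was skipped.
    """
    weights_fp16 = weights_fp32.astype(np.float16)
    inputs_fp16 = inputs.astype(np.float16)
    predictions = inputs_fp16 @ weights_fp16
    diff = predictions.astype(np.float32) - targets.astype(np.float32)
    loss = float(np.mean(diff ** 2))

    error = 2.0 * diff / diff.size
    # Values above ~65504 become inf in float16; silence the expected warnings.
    with np.errstate(over="ignore", invalid="ignore"):
        scaled_error = (error * state["scale"]).astype(np.float16)
        grads_fp16 = inputs_fp16.T @ scaled_error

    if not np.all(np.isfinite(grads_fp16)):
        # Overflow: leave the master weights untouched and back off the scale.
        state["scale"] *= backoff_factor
        state["good_steps"] = 0
        return weights_fp32, loss, False

    grads_fp32 = grads_fp16.astype(np.float32) / state["scale"]
    weights_fp32 = weights_fp32 - lr * grads_fp32
    state["good_steps"] += 1
    if state["good_steps"] >= growth_interval:
        # A run of overflow-free steps: try a larger scale for more underflow headroom.
        state["scale"] *= growth_factor
        state["good_steps"] = 0
    return weights_fp32, loss, True

# Hypothetical data; the starting scale is deliberately too large.
rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype=np.float32)
X = rng.standard_normal((256, 4)).astype(np.float32)
y = X @ true_w

w = np.zeros(4, dtype=np.float32)
state = {"scale": 2.0 ** 20, "good_steps": 0}
skipped = 0
for _ in range(300):
    w, loss, applied = dynamic_scale_step(w, X, y, lr=0.1, state=state)
    if not applied:
        skipped += 1

print(f"skipped {skipped} steps, final scale {state['scale']:.0f}, final loss {loss:.6f}")
```

The first few steps overflow and are skipped while the scale halves. Training then proceeds normally, with the scale periodically probing larger values.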

Complexity

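  • Time: O(n * d) for the forward and backward matrix products, with n samples and d features (1-D weight vector)
  • Space: O(n * d) for the float16 copy of the inputs; the float16 weight copy and the gradients add O(d)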