Gradient Checkpointing

#188 · Machine Learning · Easy

Problem

Implement a forward pass with gradient checkpointing. Instead of storing all intermediate activations, only store activations at designated checkpoint layers and recompute the others during the backward pass. For this problem, implement the forward pass that selectively stores checkpoints.

Solution

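import numpy as np

def forward_with_checkpointing(x: np.ndarray,
                                layers: list,
                                checkpoint_indices: set) -> tuple:
    """Run x through layers, keeping only the input and checkpointed activations."""
    # Index 0 is the input; index i + 1 is the output of layers[i].
    activations = {0: x}
    current = x
    for i, layer in enumerate(layers):
        current = layer(current)
        if (i + 1) in checkpoint_indices:
            # Copy so later in-place changes cannot alter the stored checkpoint.
            activations[i + 1] = current.copy()
    output = current
    return output, activations

def recompute_segment(activations: dict, layers: list,
                      start: int, end: int) -> list:
    """Recompute activations start + 1 .. end from the checkpoint stored at start."""
    results = []
    x = activations[start]  # raises KeyError if start is not a stored checkpoint
    for i in range(start, end):
        x = layers[i](x)
        results.append(x)
    return results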
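
A quick way to sanity-check the two functions is to run them on a tiny network and confirm that a recomputed segment ends at the same output as the original forward pass. The sketch below restates both functions compactly so it runs on its own; the four toy layers and the input values are made up for illustration.

```python
import numpy as np

# Compact restatement of the solution's two functions so this snippet is standalone.
def forward_with_checkpointing(x, layers, checkpoint_indices):
    activations = {0: x}
    current = x
    for i, layer in enumerate(layers):
        current = layer(current)
        if (i + 1) in checkpoint_indices:
            activations[i + 1] = current.copy()
    return current, activations

def recompute_segment(activations, layers, start, end):
    results = []
    x = activations[start]
    for i in range(start, end):
        x = layers[i](x)
        results.append(x)
    return results

# Hypothetical toy network: four elementwise layers.
layers = [
    lambda a: a + 1.0,
    lambda a: a * 2.0,
    lambda a: a - 3.0,
    lambda a: a ** 2,
]
x = np.array([1.0, 2.0])

# Checkpoint only activation 2 (the output of layers[1]).
output, activations = forward_with_checkpointing(x, layers, {2})
stored_keys = sorted(activations)  # the input (0) plus checkpoint 2

# Rebuild activations 3 and 4 from checkpoint 2.
recomputed = recompute_segment(activations, layers, start=2, end=4)
```

Here activations 1 and 3 are never stored; activation 3 only reappears when the segment from checkpoint 2 is recomputed.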

Explanation

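  1. Forward pass: Run the input through each layer in order. Activation i + 1 is the output of layers[i]; it is stored only if i + 1 is in checkpoint_indices.
  2. The input (index 0) is always stored, so there is always a checkpoint to restart from.
  3. Recomputation: When the backward pass needs the activations of a segment, recompute_segment starts from a stored checkpoint at index start (normally the nearest one at or before the segment) and re-runs layers[start] through layers[end - 1], returning activations start + 1 through end.
  4. This trades compute for memory: instead of keeping all L activations, you keep only the k checkpoints plus the one segment currently being recomputed. Spacing checkpoints about every sqrt(L) layers gives O(sqrt(L)) memory at the cost of roughly one extra forward pass.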
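
The problem only asks for the forward pass, but it may help to see how the stored checkpoints would be consumed during a backward pass. The following is a minimal sketch, not the reference solution: it assumes each layer is an object with hypothetical forward(x) and backward(x_in, grad_out) methods, and the Scale and Tanh classes and both helper functions are invented for this illustration. It walks the segments from last to first, recomputes each segment's layer inputs from its checkpoint, and backpropagates through that segment.

```python
import numpy as np

class Scale:
    """Hypothetical layer y = w * x with a hand-written backward rule."""
    def __init__(self, w):
        self.w = w
    def forward(self, x):
        return self.w * x
    def backward(self, x_in, grad_out):
        return self.w * grad_out

class Tanh:
    """Hypothetical layer y = tanh(x); backward needs the layer's input."""
    def forward(self, x):
        return np.tanh(x)
    def backward(self, x_in, grad_out):
        return (1.0 - np.tanh(x_in) ** 2) * grad_out

def checkpointed_input_grad(x, layers, checkpoint_indices, grad_output):
    # Forward: keep only the input and the checkpointed activations.
    stored = {0: x}
    current = x
    for i, layer in enumerate(layers):
        current = layer.forward(current)
        if (i + 1) in checkpoint_indices:
            stored[i + 1] = current.copy()

    # Backward: process segments from the last checkpoint back to the input.
    grad = grad_output
    end = len(layers)
    for start in sorted(stored, reverse=True):
        if start >= end:
            continue  # checkpoint sits at the output; no layers after it
        # Recompute the inputs of layers[start] .. layers[end - 1].
        inputs = [stored[start]]
        for i in range(start, end - 1):
            inputs.append(layers[i].forward(inputs[-1]))
        # Backpropagate through this segment in reverse order.
        for i in range(end - 1, start - 1, -1):
            grad = layers[i].backward(inputs[i - start], grad)
        end = start
    return grad

def full_input_grad(x, layers, grad_output):
    # Reference: store every layer input, then backpropagate.
    inputs = [x]
    for layer in layers[:-1]:
        inputs.append(layer.forward(inputs[-1]))
    grad = grad_output
    for layer, x_in in zip(reversed(layers), reversed(inputs)):
        grad = layer.backward(x_in, grad)
    return grad

layers = [Scale(2.0), Tanh(), Scale(0.5), Tanh()]
x = np.array([0.1, -0.3])
grad_output = np.ones_like(x)

grad_ckpt = checkpointed_input_grad(x, layers, {2}, grad_output)
grad_full = full_input_grad(x, layers, grad_output)
```

Both functions should return the same gradient with respect to the input; the checkpointed version just holds fewer activations at any one time.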

Complexity

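  • Time: O(L) layer evaluations for the forward pass; if each segment is recomputed once during backward, that adds at most O(L) more, i.e. about one extra forward pass.
  • Space: O(k) for the stored checkpoints, where k = |checkpoint_indices| (plus the input). During backward, add the longest segment being recomputed. With checkpoints every sqrt(L) layers, both terms are O(sqrt(L)).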
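
To make the sqrt(L) figure concrete, the sketch below places a checkpoint about every sqrt(L) layers and gives a rough count of how many activations are alive at once during backward (stored checkpoints plus the longest recomputed segment). Both helper names are hypothetical, and the count ignores gradients and per-layer size differences.

```python
import math

def sqrt_checkpoint_indices(num_layers: int) -> set:
    """Hypothetical helper: checkpoint roughly every sqrt(L) layers."""
    stride = max(1, round(math.sqrt(num_layers)))
    return set(range(stride, num_layers, stride))

def peak_activation_count(num_layers: int, checkpoint_indices: set) -> int:
    """Rough count: stored checkpoints (incl. input) + longest recomputed segment.

    Assumes num_layers >= 1 and all indices lie in 1..num_layers.
    """
    boundaries = sorted({0, *checkpoint_indices, num_layers})
    longest_segment = max(b - a for a, b in zip(boundaries, boundaries[1:]))
    stored = len({0, *checkpoint_indices})
    return stored + longest_segment

indices_16 = sqrt_checkpoint_indices(16)          # checkpoints at 4, 8, 12
peak_16 = peak_activation_count(16, indices_16)   # 4 stored + segment of 4
peak_100 = peak_activation_count(100, sqrt_checkpoint_indices(100))
```

For 16 layers this gives a peak of 8 activations instead of 16, and for 100 layers a peak of 20 instead of 100, which matches the O(sqrt(L)) estimate up to a constant factor.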