Implement a forward pass with gradient checkpointing. Instead of storing all intermediate activations, only store activations at designated checkpoint layers and recompute the others during the backward pass. For this problem, implement the forward pass that selectively stores checkpoints.
import numpy as np
def forward_with_checkpointing(x: np.ndarray,
                               layers: list,
                               checkpoint_indices: set) -> tuple:
    """Run ``x`` through ``layers`` in order, storing only selected activations.

    Activations are numbered so that index 0 is the input and index ``i + 1``
    is the output of ``layers[i]``. The input is always stored; a layer
    output is stored only when its index appears in ``checkpoint_indices``.

    Args:
        x: Input array passed to the first layer.
        layers: Callables applied in sequence; each receives the previous
            layer's output and must return an object with a ``copy()`` method
            whenever its output index is checkpointed.
        checkpoint_indices: Activation indices (``i + 1`` for ``layers[i]``)
            whose outputs should be kept. Indices that never occur, such as
            values greater than ``len(layers)``, are simply ignored.

    Returns:
        A ``(output, activations)`` tuple: the output of the last layer (``x``
        itself when ``layers`` is empty) and a dict mapping each stored
        activation index to a copy of that activation.
    """
    # Snapshot the input the same way as every other checkpoint. Storing the
    # bare reference would let a layer that works in place overwrite
    # checkpoint 0, and any later recomputation from index 0 would be wrong.
    activations = {0: x.copy()}
    current = x
    for index, layer in enumerate(layers, start=1):
        current = layer(current)
        if index in checkpoint_indices:
            # Copy so that later layers cannot mutate the stored checkpoint.
            activations[index] = current.copy()
    return current, activations
def recompute_segment(activations: dict, layers: list,
                      start: int, end: int) -> list:
    """Re-run the forward pass from a stored checkpoint over ``layers[start:end]``.

    Args:
        activations: Mapping from activation index to stored activation.
            ``start`` must be one of its keys.
        layers: Callables making up the network; ``layers[i]`` maps
            activation ``i`` to activation ``i + 1``.
        start: Index of the stored activation to resume from.
        end: Index of the last activation to reconstruct.

    Returns:
        A list of the recomputed activations ``start + 1`` through ``end``, in
        order. The list is empty when ``end <= start``.

    Raises:
        KeyError: If no activation is stored at index ``start``.
    """
    results = []
    x = activations[start]
    for i in range(start, end):
        x = layers[i](x)
        results.append(x)
    return results


# Original trailing note (its first sentence is truncated in the source):
# "... checkpoint_indices. recompute_segment takes the nearest stored
# checkpoint and re-runs the forward pass from that point to reconstruct
# intermediate activations."