
Implement the SGTM Parameter Update Step

#235 · Deep Learning · Medium

Source: deep-ml.com

Problem

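Implement the parameter update step of SGTM (Stochastic Gradient with Truncated Momentum). This optimizer combines SGD with momentum and a truncation mechanism that clips each element of the momentum buffer to a fixed threshold, so that gradients accumulated over many steps (possibly stale ones) cannot produce arbitrarily large updates.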
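
As a worked illustration, one SGTM step for a single scalar parameter can be sketched as below. The helper name `sgtm_step_scalar` and the numbers are hypothetical, chosen here only to show truncation taking effect; they are not part of the problem's test cases.

```python
def sgtm_step_scalar(theta, grad, m_prev, lr, beta, tau):
    """One SGTM step for a single parameter (hypothetical helper)."""
    m = beta * m_prev + grad      # momentum accumulation
    m = max(-tau, min(tau, m))    # truncate m to [-tau, tau]
    return theta - lr * m, m      # step along the truncated momentum


# Hypothetical values: 0.9 * 0.8 + 0.5 = 1.22 exceeds tau = 1.0, so m is clipped.
theta, m = sgtm_step_scalar(theta=1.0, grad=0.5, m_prev=0.8, lr=0.1, beta=0.9, tau=1.0)
print(m, theta)  # 1.0 0.9
```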

Solution

def sgtm_update(
    params: list[float],
    grads: list[float],
    momentum_buffer: list[float],
    lr: float,
    momentum: float,
    truncation_threshold: float,
) -> tuple[list[float], list[float]]:
    """
    Apply one SGTM step element-wise.

    Returns (new_params, new_momentum_buffer), each rounded to 6 decimals.
    """
    n = len(params)
    new_momentum = [0.0] * n
    new_params = [0.0] * n

    for i in range(n):
        # Accumulate momentum: m_t = momentum * m_{t-1} + g_t
        m = momentum * momentum_buffer[i] + grads[i]

        # Truncate: clip m to [-truncation_threshold, truncation_threshold]
        if abs(m) > truncation_threshold:
            m = truncation_threshold if m > 0 else -truncation_threshold

        # Store the truncated momentum and step the parameter along it;
        # rounding to 6 decimals matches the expected output format
        new_momentum[i] = round(m, 6)
        new_params[i] = round(params[i] - lr * m, 6)

    return new_params, new_momentum
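
The same step can be written more compactly with `zip` and a `min`/`max` clamp. This is a sketch of an equivalent formulation for finite inputs, not the reference solution; the name `sgtm_update_compact` is hypothetical, and rounding to 6 decimals is kept only so the outputs match the solution above.

```python
def sgtm_update_compact(
    params: list[float],
    grads: list[float],
    momentum_buffer: list[float],
    lr: float,
    momentum: float,
    truncation_threshold: float,
) -> tuple[list[float], list[float]]:
    tau = truncation_threshold
    new_params: list[float] = []
    new_momentum: list[float] = []
    for p, g, m_prev in zip(params, grads, momentum_buffer):
        m = max(-tau, min(tau, momentum * m_prev + g))  # accumulate, then clip
        new_momentum.append(round(m, 6))
        new_params.append(round(p - lr * m, 6))
    return new_params, new_momentum


# Usage: the first gradient (3.0) is clipped to 1.0; the second (-0.5) is not.
p, m = sgtm_update_compact([1.0, -2.0], [3.0, -0.5], [0.0, 0.0],
                           lr=0.1, momentum=0.9, truncation_threshold=1.0)
print(p, m)  # [0.9, -1.95] [1.0, -0.5]
```

One design difference: `zip` silently stops at the shortest input, whereas the indexed loop above raises `IndexError` if `grads` or `momentum_buffer` is shorter than `params`.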

Explanation

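  1. Momentum accumulation: m_t = beta * m_{t-1} + g_t, where beta is the `momentum` argument and g_t is the current gradient.
  2. Truncation: clip each element of m_t to the interval [-tau, tau], where tau is `truncation_threshold`, so that gradients accumulated over many steps (possibly stale ones) cannot drive runaway updates.
  3. Parameter update: theta_t = theta_{t-1} - lr * m_t, using the truncated m_t. The truncated value is also what is stored as the new momentum buffer.
  4. Because |m_t| <= tau, each parameter moves by at most lr * tau per step. This bound can be useful in distributed or asynchronous training, where gradient staleness is a concern.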
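
To see why truncation bounds the update, the sketch below compares the plain momentum recurrence with the truncated one under an illustrative constant gradient. The helper `momentum_trajectories` and its inputs are hypothetical. Starting from m_0 = 0 with constant gradient g and 0 <= beta < 1, the plain recurrence gives m_t = g * (1 - beta^t) / (1 - beta), which approaches g / (1 - beta), ten times the gradient for beta = 0.9; the truncated buffer never leaves [-tau, tau].

```python
def momentum_trajectories(grad, beta, tau, steps):
    """Return (plain, truncated) momentum after each step, starting from 0."""
    plain, truncated = 0.0, 0.0
    history = []
    for _ in range(steps):
        plain = beta * plain + grad                               # no clipping
        truncated = max(-tau, min(tau, beta * truncated + grad))  # SGTM-style clip
        history.append((plain, truncated))
    return history


hist = momentum_trajectories(grad=1.0, beta=0.9, tau=2.0, steps=50)
plain_final, truncated_final = hist[-1]
print(plain_final, truncated_final)  # plain is just under 10.0; truncated is exactly 2.0
```

Here the truncated buffer reaches the threshold at the third step (0.9 * 1.9 + 1.0 = 2.71 is clipped to 2.0) and stays there, while the plain buffer keeps growing toward 10.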

Complexity

  • Time: O(n) where n is the number of parameters
  • Space: O(n) for the new parameter and momentum lists