Implement the SGTM (Stochastic Gradient with Truncated Momentum) parameter update step. This optimizer combines SGD momentum with a truncation mechanism that clips the momentum buffer to prevent large, stale updates.
def sgtm_update(
    params: list[float],
    grads: list[float],
    momentum_buffer: list[float],
    lr: float,
    momentum: float,
    truncation_threshold: float,
) -> tuple[list[float], list[float]]:
    """Apply one SGTM (Stochastic Gradient with Truncated Momentum) step.

    For every coordinate the update is::

        m_t     = beta * m_{t-1} + g_t
        m_t     = clip(m_t, -truncation_threshold, +truncation_threshold)
        theta_t = theta_{t-1} - lr * m_t

    where ``beta`` is ``momentum``. The parameter step uses the truncated
    (but not yet rounded) momentum; both outputs are then rounded to six
    decimal places.

    Args:
        params: Current parameter values (theta_{t-1}).
        grads: Gradient for each parameter (g_t).
        momentum_buffer: Previous momentum for each parameter (m_{t-1}).
        lr: Learning rate.
        momentum: Momentum coefficient (beta).
        truncation_threshold: Largest magnitude the momentum may keep.
            NOTE(review): presumably expected to be non-negative; a
            negative value is not rejected here.

    Returns:
        A ``(new_params, new_momentum_buffer)`` tuple of freshly built
        lists; the input lists are not modified.

    Raises:
        ValueError: If the three input lists differ in length.
    """
    # Fail loudly on mismatched inputs instead of silently ignoring extra
    # entries or surfacing an opaque IndexError halfway through the loop.
    if not (len(params) == len(grads) == len(momentum_buffer)):
        raise ValueError(
            "params, grads and momentum_buffer must have the same length; "
            f"got {len(params)}, {len(grads)} and {len(momentum_buffer)}"
        )

    new_params: list[float] = []
    new_momentum: list[float] = []
    for param, grad, prev_m in zip(params, grads, momentum_buffer):
        m = momentum * prev_m + grad
        # Truncate: keep the sign, cap the magnitude at the threshold.
        if abs(m) > truncation_threshold:
            m = truncation_threshold if m > 0 else -truncation_threshold
        new_momentum.append(round(m, 6))
        new_params.append(round(param - lr * m, 6))
    return new_params, new_momentum