Implement Neural Memory Update with Surprise and Momentum

#267 · Deep Learning · Medium

Problem

Implement a Neural Memory Update mechanism that uses surprise (prediction error) and momentum to decide how much to update stored memory vectors. High surprise should trigger larger updates.

Solution

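Compute surprise as the L2 distance between the predicted and the actual input. Smooth it with an exponential moving average (EMA), which serves as the momentum term, and pass the smoothed value through a sigmoid to get an update gate. Update memory as a weighted combination of the old memory and the new input, with the weight set by the gate.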
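Written out as equations, the update looks roughly like the following. The symbols are notation introduced here for illustration: x_t is `input_vec`, x̂_t is `predicted`, m_{t-1} is `memory`, s̄_{t-1} is `prev_surprise_ema`, α is `alpha`, and η_base is `base_lr`.

```latex
\begin{aligned}
s_t       &= \lVert x_t - \hat{x}_t \rVert_2 \\
\bar{s}_t &= \alpha\, s_t + (1 - \alpha)\, \bar{s}_{t-1} \\
g_t       &= \sigma(\bar{s}_t) = \frac{1}{1 + e^{-\bar{s}_t}} \\
\eta_t    &= \eta_{\text{base}}\, g_t \\
m_t       &= (1 - \eta_t)\, m_{t-1} + \eta_t\, x_t
\end{aligned}
```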

import math

def neural_memory_update(
    memory: list[float],
    input_vec: list[float],
    predicted: list[float],
    momentum: float,
    prev_surprise_ema: float,
    alpha: float = 0.1,
    base_lr: float = 0.5,
) -> dict:
    """Blend input_vec into memory at a rate gated by smoothed surprise.

    Note: `momentum` is kept in the signature but is not used here;
    smoothing is controlled by `alpha` and `prev_surprise_ema`.
    """
    if not (len(memory) == len(input_vec) == len(predicted)):
        raise ValueError("memory, input_vec, and predicted must have the same length")

    # Surprise: L2 distance between the predicted and the actual input
    surprise = math.sqrt(sum((x - p) ** 2 for x, p in zip(input_vec, predicted)))

    # Exponential moving average of surprise (the momentum term)
    surprise_ema = alpha * surprise + (1 - alpha) * prev_surprise_ema

    # Update gate: sigmoid of the smoothed surprise.
    # For surprise_ema >= 0 this lies in [0.5, 1).
    gate = 1.0 / (1.0 + math.exp(-surprise_ema))

    # Effective learning rate, scaled by the gate
    lr = base_lr * gate

    # Move memory toward the input by a fraction lr
    new_memory = [(1.0 - lr) * m + lr * x for m, x in zip(memory, input_vec)]

    return {
        "memory": [round(v, 6) for v in new_memory],
        "surprise": round(surprise, 6),
        "surprise_ema": round(surprise_ema, 6),
        "gate": round(gate, 6),
    }
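The function is stateless, so the caller has to carry the surprise EMA from one call to the next. A minimal sketch of that loop follows. `step` is a condensed, unrounded restatement of the update written only to keep the example self-contained, and `run_stream` is a hypothetical driver. Using the current memory as the prediction is an assumption made here for illustration; the problem does not say where `predicted` comes from.

```python
import math

def step(memory, x, predicted, prev_ema, alpha=0.1, base_lr=0.5):
    """Condensed restatement of the update (no rounding, no validation)."""
    surprise = math.sqrt(sum((xi - pi) ** 2 for xi, pi in zip(x, predicted)))
    ema = alpha * surprise + (1 - alpha) * prev_ema
    lr = base_lr / (1.0 + math.exp(-ema))  # base_lr * sigmoid(ema)
    new_memory = [(1 - lr) * m + lr * xi for m, xi in zip(memory, x)]
    return new_memory, ema

def run_stream(memory, inputs, alpha=0.1, base_lr=0.5):
    """Feed inputs one at a time, carrying the surprise EMA between steps.

    Assumption: the current memory doubles as the prediction of the next input.
    """
    ema = 0.0
    for x in inputs:
        memory, ema = step(memory, x, memory, ema, alpha, base_lr)
    return memory, ema

# Repeating the same input pulls the memory toward it step by step.
final_memory, final_ema = run_stream([0.0, 0.0], [[1.0, 1.0]] * 20)
```

With the default `base_lr` of 0.5, each step closes at least a quarter of the remaining gap between memory and input, so after 20 identical inputs the memory sits very close to `[1.0, 1.0]`.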

Explanation

  1. Surprise is quantified as the L2 norm of the prediction error (difference between expected and actual input).
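  2. The surprise value is smoothed with an exponential moving average (EMA) to provide momentum, so a noisy single-step spike cannot dominate the update. The smoothing factor is alpha; the momentum argument is accepted but not used by this solution.
  3. A sigmoid gate maps the surprise EMA to an update strength in (0, 1). Because surprise is non-negative, the gate stays in [0.5, 1) as long as prev_surprise_ema is non-negative.
  4. Memory is updated by linear interpolation: new = (1 - lr) * old + lr * input, where lr = base_lr * gate.
  5. High surprise pushes lr toward base_lr, giving aggressive updates. Low surprise drops lr toward base_lr / 2, so memory changes more slowly but is never fully frozen.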
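Points 2 and 3 can be checked numerically. The sketch below assumes a quiet stream with a single surprise spike of 10.0; `sigmoid` and `ema_trace` are illustrative helper names, not part of the solution.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def ema_trace(surprises, alpha=0.1, start=0.0):
    """Return the running EMA after each surprise value."""
    ema, trace = start, []
    for s in surprises:
        ema = alpha * s + (1 - alpha) * ema
        trace.append(ema)
    return trace

# One spike of 10.0 in an otherwise silent stream.
trace = ema_trace([0.0, 10.0, 0.0, 0.0])

# Gate value at each step; zero smoothed surprise gives exactly 0.5.
gates = [sigmoid(e) for e in trace]
```

With alpha = 0.1 the spike of 10.0 raises the EMA to only about 1.0, and it then decays by a factor of 0.9 per quiet step. The gate never falls below 0.5, which is why low surprise slows updates rather than stopping them.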

Complexity

  • Time: O(d) where d is the dimension of the memory vector
  • Space: O(d) for the new memory vector