n-Step TD Prediction

#273 · Reinforcement Learning · Medium

Problem

Implement n-Step TD Prediction for estimating the state-value function. Instead of using a single-step bootstrap (TD(0)) or a full episode (Monte Carlo), n-step TD uses the next n rewards plus a bootstrap from the value estimate n steps ahead.

Solution

For each time step, compute the n-step return: sum of discounted rewards over n steps plus the discounted value of the state n steps ahead. Update the value function towards this target.

def n_step_td_prediction(
    episodes: list[list[tuple]],
    n: int = 3,
    alpha: float = 0.1,
    gamma: float = 0.99,
    initial_value: float = 0.0,
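) -> dict[str, float]:
    """Estimate state values from recorded episodes with n-step TD.

    Each episode is a list of tuples whose first element is the state and
    whose second element is the reward for that time step. States are
    converted to strings and used as keys of the returned dictionary.
    """
    V: dict[str, float] = {}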

    def get_v(state: str) -> float:
        return V.get(state, initial_value)

    for episode in episodes:
        T = len(episode)
        states = [str(episode[t][0]) for t in range(T)]
        rewards = [episode[t][1] for t in range(T)]

        # Initialize values for unseen states
        for s in states:
            if s not in V:
                V[s] = initial_value

        for t in range(T):
            # Compute n-step return G
            G = 0.0
            end = min(t + n, T)
            for k in range(t, end):
                G += (gamma ** (k - t)) * rewards[k]

            # Bootstrap from V(S_{t+n}) if not at terminal
            if end < T:
                G += (gamma ** n) * get_v(states[end])

            # Update V(S_t)
            state = states[t]
            V[state] = V[state] + alpha * (G - V[state])

    return {s: round(v, 6) for s, v in V.items()}

Explanation

  1. For each time step t, the n-step return is: G_t = r_t + gamma*r_{t+1} + ... + gamma^{n-1}*r_{t+n-1} + gamma^n * V(S_{t+n}).
  2. If the episode ends within n steps of t (t + n >= T, where T is the episode length), G_t is the discounted sum of the remaining rewards only, with no bootstrap term.
  3. Update: V(S_t) <- V(S_t) + alpha * (G_t - V(S_t)).
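  4. n = 1 recovers TD(0). Any n at least as large as the episode length (n = infinity in the limit) recovers Monte Carlo, because no update bootstraps.
  5. Intermediate values of n often give the best bias-variance tradeoff: small n leans on value estimates that may be biased, while large n accumulates variance from many sampled rewards.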
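
The steps above can be checked by hand on a tiny episode. The following is a minimal sketch, not part of the original solution: the helper name `n_step_return` and the reward and value numbers are made up for illustration, and `values[k]` stands for the current estimate V(S_k).

```python
def n_step_return(rewards, values, t, n, gamma):
    """n-step return from time t; values[k] is the current estimate V(S_k)."""
    T = len(rewards)
    end = min(t + n, T)
    # Discounted sum of up to n rewards starting at time t.
    G = sum(gamma ** (k - t) * rewards[k] for k in range(t, end))
    # Bootstrap only if S_{t+n} is still inside the episode.
    if end < T:
        G += gamma ** n * values[end]
    return G


rewards = [1.0, 0.0, 2.0]  # hypothetical rewards r_0, r_1, r_2
values = [0.0, 2.0, 8.0]   # hypothetical estimates V(S_0), V(S_1), V(S_2)
gamma = 0.5

# n = 1: r_0 + gamma * V(S_1), the TD(0) target
td0_target = n_step_return(rewards, values, t=0, n=1, gamma=gamma)
# n = 2: r_0 + gamma * r_1 + gamma^2 * V(S_2)
two_step = n_step_return(rewards, values, t=0, n=2, gamma=gamma)
# n >= T: no bootstrap, so this is the Monte Carlo return
mc_return = n_step_return(rewards, values, t=0, n=10, gamma=gamma)

print(td0_target, two_step, mc_return)  # 2.0 3.0 1.5
```

With these numbers the three targets differ (2.0, 3.0, 1.5), which shows how the choice of n changes what V(S_0) is pulled towards in step 3.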

Complexity

  • Time: O(sum of episode lengths * n) across all episodes
  • Space: O(|S|) for the value function, plus O(T) for the states and rewards lists of the episode being processed
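
The factor of n in the time bound comes from re-summing the reward window at every step. As a hedged aside, the reward part of each return does not depend on V, so it could be precomputed in O(T) per episode with a backward recurrence. The solution above does not do this; the function names below are hypothetical, and the sketch assumes n >= 1.

```python
def n_step_reward_sums(rewards, n, gamma):
    """R[t] = sum of gamma^(k-t) * rewards[k] for k in [t, min(t+n, T))."""
    T = len(rewards)
    R = [0.0] * (T + 1)  # R[T] = 0.0 is a sentinel past the end of the episode
    gamma_n = gamma ** n
    for t in range(T - 1, -1, -1):
        R[t] = rewards[t] + gamma * R[t + 1]
        # The window starting at t+1 reaches one reward further than the
        # window starting at t, so remove that extra term.
        if t + n < T:
            R[t] -= gamma_n * rewards[t + n]
    return R[:T]


def brute_force_reward_sums(rewards, n, gamma):
    """Reference version that re-sums each window, O(T * n)."""
    T = len(rewards)
    return [
        sum(gamma ** (k - t) * rewards[k] for k in range(t, min(t + n, T)))
        for t in range(T)
    ]


rewards = [1.0, 0.0, 2.0, 4.0]  # made-up rewards for illustration
fast = n_step_reward_sums(rewards, n=2, gamma=0.5)
slow = brute_force_reward_sums(rewards, n=2, gamma=0.5)
print(fast, slow)
```

The bootstrap term gamma^n * V(S_{t+n}) would still be added during the forward update sweep, since V changes as updates are applied. The subtraction can accumulate floating-point error on long episodes, so the direct summation is the safer default when n is small.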