First-Visit Monte Carlo Prediction

#272 · Reinforcement Learning · Medium

Solve on deep-ml.com

Problem

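Implement first-visit Monte Carlo prediction to estimate the state-value function V(s). Given a list of episodes, each a sequence of (state, reward) pairs, compute for each state the discounted return that follows its first visit in every episode where it appears, then average those returns across episodes.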
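
To make the input shape concrete, here is a small illustrative data set. The state names, rewards, and the `Episode` alias are made up for this sketch and are not part of the problem statement; the only assumption carried over from the solution is that each step is a `(state, reward)` pair whose reward counts toward the return from that same step.

```python
# Hypothetical toy data: two episodes over the states "A", "B", "C".
# Each step is (state, reward); the reward belongs to the return from that step.
Episode = list[tuple[str, float]]  # illustrative alias, requires Python 3.9+

episodes: list[Episode] = [
    [("A", 1.0), ("B", 0.0), ("A", 2.0)],  # "A" occurs twice; only t=0 is a first visit
    [("B", 3.0), ("C", 1.0)],              # episodes may differ in length
]

# Collect the distinct states that appear anywhere in the data.
states = sorted({state for episode in episodes for state, _ in episode})
```

Here `states` contains every state that will receive an estimate; a state absent from all episodes gets none.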

Solution

For each episode, identify the first visit to each state and compute the discounted return from that point forward. Accumulate a per-state sum and count of these returns, then divide at the end to obtain the averages.

def first_visit_mc_prediction(
    episodes: list[list[tuple]],
    gamma: float = 0.99,
) -> dict[str, float]:
    returns_sum: dict[str, float] = {}
    returns_count: dict[str, int] = {}

    for episode in episodes:
        # episode is a list of (state, reward) tuples
        if not episode:
            continue  # an empty episode has no returns; indexing T - 1 would fail
        visited: set[str] = set()
        T = len(episode)

        # Pre-compute the discounted return from each timestep, working backwards
        G = [0.0] * T
        G[T - 1] = episode[T - 1][1]
        for t in range(T - 2, -1, -1):
            G[t] = episode[t][1] + gamma * G[t + 1]

        for t in range(T):
            state = str(episode[t][0])
            if state not in visited:
                visited.add(state)
                returns_sum[state] = returns_sum.get(state, 0.0) + G[t]
                returns_count[state] = returns_count.get(state, 0) + 1

    value_function: dict[str, float] = {}
    for state in returns_sum:
        value_function[state] = round(returns_sum[state] / returns_count[state], 6)

    return value_function

Explanation

  1. For each episode, compute the discounted return backwards from the final step: G_{T-1} = r_{T-1}, then G_t = r_t + gamma * G_{t+1} for t = T-2, ..., 0.
  2. Track which states have been visited. Only the first visit to each state in each episode contributes to the estimate.
  3. Accumulate the returns for each state and count the number of first visits.
  4. The estimated value V(s) = average of all first-visit returns for state s.
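  5. As the number of episodes grows, the estimate for every state that keeps being visited converges to its true value under the policy that generated the episodes, because each first-visit return is an independent, unbiased sample of V(s). States that never appear receive no estimate.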
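
The steps above can be traced by hand on a single short episode. The sketch below uses a hypothetical helper, `discounted_returns`, and made-up data with gamma = 0.5 so that every value is exact in floating point. It also collects the every-visit returns, purely to show what the first-visit rule leaves out.

```python
def discounted_returns(rewards: list[float], gamma: float) -> list[float]:
    """Compute G_t for every timestep via G_t = r_t + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    running = 0.0  # return after the final step is zero
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# Illustrative episode: state "A" is visited at t=0 and again at t=2.
episode = [("A", 1.0), ("B", 0.0), ("A", 2.0)]
gamma = 0.5

G = discounted_returns([reward for _, reward in episode], gamma)
# G[2] = 2.0
# G[1] = 0.0 + 0.5 * 2.0 = 1.0
# G[0] = 1.0 + 0.5 * 1.0 = 1.5

# First-visit: keep only the return from the earliest occurrence of each state.
first_visit: dict[str, float] = {}
for t, (state, _) in enumerate(episode):
    first_visit.setdefault(state, G[t])

# Every-visit (for contrast only): keep the return from every occurrence.
every_visit: dict[str, list[float]] = {}
for t, (state, _) in enumerate(episode):
    every_visit.setdefault(state, []).append(G[t])
```

With this episode, `first_visit` is `{"A": 1.5, "B": 1.0}`: the second visit to "A" (return 2.0) is ignored, whereas an every-visit estimate would average 1.5 and 2.0.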

Complexity

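  • Time: O(N), where N is the total number of steps across all episodes; each step is processed a constant number of times
  • Space: O(|S| + T_max), where |S| is the number of unique states (per-state sums and counts) and T_max is the length of the longest episode (the per-episode return list G and visited set)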
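
The solution above stores one return per timestep of the current episode. That buffer is not essential. As a sketch of one possible alternative (not the reference solution), the hypothetical `first_visit_mc_incremental` below records where each state first occurs, sweeps backwards with a single scalar return, and updates each estimate with the incremental mean V <- V + (G - V) / N instead of keeping sums. It assumes every step is exactly a 2-tuple and, unlike the solution, does not round the results.

```python
def first_visit_mc_incremental(
    episodes: list[list[tuple]],
    gamma: float = 0.99,
) -> dict[str, float]:
    """Hypothetical variant: scalar return plus incremental mean updates."""
    value: dict[str, float] = {}
    count: dict[str, int] = {}

    for episode in episodes:
        # Index of the first occurrence of each state in this episode.
        first_index: dict[str, int] = {}
        for t, (state, _) in enumerate(episode):
            first_index.setdefault(str(state), t)

        g = 0.0  # discounted return from the current timestep onward
        for t in range(len(episode) - 1, -1, -1):
            state, reward = episode[t]
            g = reward + gamma * g
            key = str(state)
            if first_index[key] == t:  # only the first visit contributes
                count[key] = count.get(key, 0) + 1
                old = value.get(key, 0.0)
                value[key] = old + (g - old) / count[key]

    return value


# Usage with made-up data and gamma = 0.5.
toy_episodes = [
    [("A", 1.0), ("B", 0.0), ("A", 2.0)],
    [("B", 3.0), ("C", 1.0)],
]
estimates = first_visit_mc_incremental(toy_episodes, gamma=0.5)
```

On this toy data the first-visit returns are 1.5 for "A", 1.0 and 3.5 for "B", and 1.0 for "C", so the estimates come out as 1.5, 2.25, and 1.0. Per-episode working memory is now bounded by the number of distinct states in the episode rather than its length, and empty episodes are skipped naturally because neither loop runs.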