Monte Carlo Tree Search

#207 · Reinforcement Learning · Hard

Problem

Implement Monte Carlo Tree Search (MCTS) with the four key phases: selection, expansion, simulation (rollout), and backpropagation. Use UCB1 for balancing exploration and exploitation.

Solution

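import math
import random

class MCTSNode:
    def __init__(self, state, parent=None, action=None):
        self.state = state            # environment state at this node
        self.parent = parent          # None for the root
        self.action = action          # action that led from parent to this node
        self.children = []
        self.visits = 0               # number of backpropagations through this node
        self.value = 0.0              # sum of rollout rewards seen through this node
        self.untried_actions = None   # actions not yet expanded into children

    def ucb1(self, c=1.414):
        # c ~ sqrt(2) is the classic UCB1 constant for rewards in [0, 1]
        if self.visits == 0:
            return float('inf')       # unvisited: highest priority (also avoids division by zero)
        exploit = self.value / self.visits
        explore = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploit + explore

    def best_child(self, c=1.414):
        return max(self.children, key=lambda ch: ch.ucb1(c))

    def is_fully_expanded(self):
        return len(self.untried_actions) == 0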

def mcts(root_state, get_actions, apply_action, is_terminal,
         get_reward, n_iterations=1000, c=1.414):
    # apply_action must be pure (return a new state, never mutate its input)
    # because every iteration replays actions from root_state. get_reward
    # should score a terminal state from one fixed perspective, since the
    # same reward is added at every level during backpropagation.
    root = MCTSNode(root_state)
    root.untried_actions = list(get_actions(root_state))

    for _ in range(n_iterations):
        node = root
        state = root_state

        # Selection: descend by UCB1 while the node is fully expanded
        while node.untried_actions is not None and \
              node.is_fully_expanded() and node.children:
            node = node.best_child(c)
            state = apply_action(state, node.action)

        # Expansion: add one child for a randomly chosen untried action
        if node.untried_actions and not is_terminal(state):
            action = random.choice(node.untried_actions)
            node.untried_actions.remove(action)
            state = apply_action(state, action)
            child = MCTSNode(state, parent=node, action=action)
            child.untried_actions = list(get_actions(state))
            node.children.append(child)
            node = child

        # Simulation (rollout): uniformly random play until a terminal state
        sim_state = state
        while not is_terminal(sim_state):
            # list() so random.choice also works when get_actions returns a set or generator
            actions = list(get_actions(sim_state))
            if not actions:
                break
            sim_state = apply_action(sim_state, random.choice(actions))
        reward = get_reward(sim_state)

        # Backpropagation: update every node on the path back to the root
        while node is not None:
            node.visits += 1
            node.value += reward
            node = node.parent

    # Return the action of the most visited root child (None if the root has no actions)
    if root.children:
        best = max(root.children, key=lambda ch: ch.visits)
        return best.action
    return None
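A minimal usage sketch, assuming a hypothetical single-player toy problem that is not part of the original: choose three bits one at a time, and the reward is the fraction of bits set to 1. The names N_BITS, get_actions, apply_action, is_terminal and get_reward below are illustrative implementations of the callbacks mcts expects. States are immutable tuples, so apply_action returns a new state; this matters because mcts replays actions from root_state on every iteration and would misbehave if apply_action mutated its input.

```python
# Hypothetical toy problem used only to exercise mcts():
# pick N_BITS bits one at a time; reward = fraction of bits set to 1.
N_BITS = 3

def get_actions(state):
    # Either bit value may be appended until the state is complete.
    return [] if len(state) >= N_BITS else [0, 1]

def apply_action(state, action):
    # Pure function: builds a new tuple instead of mutating `state`.
    return state + (action,)

def is_terminal(state):
    return len(state) >= N_BITS

def get_reward(state):
    # Reward in [0, 1], the range the default c ~ sqrt(2) is designed for.
    return sum(state) / N_BITS

root_state = ()  # empty tuple: no bits chosen yet
```

With the solution's mcts in scope, a call such as mcts(root_state, get_actions, apply_action, is_terminal, get_reward, n_iterations=500) should usually return 1: any continuation after choosing 1 scores exactly 1/3 more than the same continuation after choosing 0.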

Explanation

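  1. Selection: Starting from the root, repeatedly move to the child with the highest UCB1 score while the current node is fully expanded and has children. Descent stops at a node that still has untried actions, or at a leaf such as a terminal state.
  2. Expansion: If that node is not terminal, pick one of its untried actions at random, apply it, and add the resulting state as a new child.
  3. Simulation: From the new node (or from the selected node, if nothing was expanded), play uniformly random actions until a terminal state is reached, then score it with get_reward.
  4. Backpropagation: Walk back up to the root, incrementing the visit count and adding the reward at every node on the path, including the new node. Because the same reward is added at every level, rewards are assumed to come from a single, fixed perspective.
  5. UCB1: value/visits + c * sqrt(ln(parent_visits) / visits) balances exploitation (high average reward) with exploration (rarely visited nodes). The default c = 1.414 ≈ sqrt(2) is the standard choice for rewards in [0, 1]; unvisited children score infinity.
  6. After all iterations, return the action leading to the root's most-visited child. Visit count is a more robust criterion than average value, which can be noisy for rarely visited children.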
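The UCB1 rule and the most-visited final choice described above can be checked by hand. The sketch below uses a standalone ucb1 helper that mirrors MCTSNode.ucb1, applied to made-up statistics for three children of a root visited 100 times (hypothetical numbers, not output from the solution). It shows selection favoring a less-visited child during search while the final move still goes to the most-visited one.

```python
import math

def ucb1(value, visits, parent_visits, c=1.414):
    """Same formula as MCTSNode.ucb1: mean reward + exploration bonus."""
    if visits == 0:
        return float("inf")  # unvisited children get top priority
    exploit = value / visits
    explore = c * math.sqrt(math.log(parent_visits) / visits)
    return exploit + explore

# Hypothetical root children: name -> (total value, visits); visits sum to 100.
children = {"a": (60.0, 80), "b": (4.0, 10), "c": (5.0, 10)}
parent_visits = 100

scores = {name: ucb1(v, n, parent_visits) for name, (v, n) in children.items()}

# "a" has the best mean (0.75) but the smallest exploration bonus,
# so selection descends into "c" (mean 0.5, only 10 visits).
selected = max(scores, key=scores.get)                # "c"

# The move actually played is the most visited child.
played = max(children, key=lambda k: children[k][1])  # "a"
```

This is why the solution separates the two criteria: UCB1 steers where the next simulation goes, while visit counts decide the move that is finally returned.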

Complexity

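  • Time: O(n_iterations * (b * d + D)), where d is the depth reached during selection, b is the branching factor (UCB1 is evaluated for every child at each level), and D is the average rollout length, treating each get_actions/apply_action call as constant time
  • Space: O(n_iterations) tree nodes, since each iteration adds at most one node; each node also stores its state and its list of untried actions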