Implement the On-Policy SARSA Algorithm

#175 · Reinforcement Learning · Medium

Problem

Implement SARSA (State-Action-Reward-State-Action), an on-policy temporal-difference (TD) learning algorithm. SARSA updates Q-values using the action that the agent's own epsilon-greedy policy actually selects in the next state, which is what makes it on-policy.
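
Written out, this is the standard one-step SARSA update. The symbols are a conventional choice rather than notation from the problem page: α and γ correspond to the `alpha` and `gamma` parameters of the solution, and π_ε stands for the ε-greedy policy derived from the current Q-table. When s_{t+1} is terminal, its Q-values are treated as zero, so the target reduces to the reward alone.

```latex
Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \bigl[\, r_{t+1} + \gamma \, Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \,\bigr],
\qquad a_{t+1} \sim \pi_\epsilon(\cdot \mid s_{t+1}),
\qquad Q(s_{\text{terminal}}, \cdot) \equiv 0
```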

Solution

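import numpy as np

def sarsa(n_states: int, n_actions: int, episodes: int,
          env_step, alpha: float = 0.1, gamma: float = 0.99,
          epsilon: float = 0.1):
    # env_step contract used by this solution:
    #   env_step(None, None, reset=True) -> initial state (int index)
    #   env_step(state, action)          -> (next_state, reward, done)
    Q = np.zeros((n_states, n_actions))

    def choose_action(state):
        # Epsilon-greedy: explore uniformly with probability epsilon,
        # otherwise take the greedy action (argmax breaks ties by lowest index).
        if np.random.rand() < epsilon:
            return np.random.randint(n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(episodes):
        state = env_step(None, None, reset=True)
        action = choose_action(state)  # first action from the same epsilon-greedy policy

        done = False
        while not done:
            next_state, reward, done = env_step(state, action)
            # On-policy: A' is chosen by the policy the agent will actually follow.
            next_action = choose_action(next_state)

            # TD target bootstraps from Q(S', A'); a terminal S' contributes nothing.
            target = reward if done else reward + gamma * Q[next_state, next_action]
            Q[state, action] += alpha * (target - Q[state, action])

            # The chosen A' becomes the action actually executed on the next step.
            state = next_state
            action = next_action

    return Q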
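
To try the solution, it needs an `env_step` that follows the calling convention above. The following is a minimal sketch of a hypothetical corridor environment (`make_corridor_env` is not part of the problem): states 0 to n_states-1, the episode starts in state 0 and ends on reaching the last state, action 0 moves left, action 1 moves right, and every step costs a reward of -1.

```python
def make_corridor_env(n_states: int = 5):
    """Hypothetical 1-D corridor matching the env_step contract used by sarsa()."""
    goal = n_states - 1

    def env_step(state, action, reset=False):
        if reset:
            return 0  # every episode starts at the left end
        if action == 1:
            next_state = min(state + 1, goal)  # move right, clipped at the goal
        else:
            next_state = max(state - 1, 0)     # move left, clipped at the wall
        done = next_state == goal
        reward = -1.0                          # constant step cost favours short paths
        return next_state, reward, done

    return env_step
```

Calling `Q = sarsa(5, 2, 500, make_corridor_env(5))` should, after enough episodes, leave `np.argmax(Q, axis=1)` preferring action 1 in the non-terminal states; exact values vary with the random seed, and the terminal row stays at zero because no update is ever made from it.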

Explanation

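  1. Initialize the Q-table to zeros for all state-action pairs.
  2. At the start of each episode, reset the environment and choose the initial action with the epsilon-greedy policy.
  3. At each step, take the action, observe the reward and next state, and choose the next action with the same epsilon-greedy policy.
  4. Update: Q(s,a) += alpha * [r + gamma * Q(s',a') - Q(s,a)], where a' is the next action actually chosen. If s' is terminal, the gamma * Q(s',a') term is dropped and the target is just r.
  5. The key difference from Q-learning: SARSA bootstraps from the next action a' drawn from the policy being followed (on-policy), while Q-learning bootstraps from max_a' Q(s',a') regardless of which action is taken next (off-policy).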
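
The difference in step 5 shows up directly in the TD targets. This sketch uses two illustrative helper functions (`sarsa_target` and `q_learning_target` are hypothetical names, not from the problem) and a made-up two-state Q-table in which exploration picked the worse action in the next state.

```python
import numpy as np

def sarsa_target(Q, reward, next_state, next_action, done, gamma=0.99):
    # On-policy: bootstrap from the action the epsilon-greedy policy actually picked.
    return reward + (0.0 if done else gamma * Q[next_state, next_action])

def q_learning_target(Q, reward, next_state, done, gamma=0.99):
    # Off-policy: bootstrap from the greedy action, whatever is actually taken.
    return reward + (0.0 if done else gamma * np.max(Q[next_state]))

Q = np.array([[0.0, 0.0],
              [2.0, 10.0]])
# Exploration chose action 0 in state 1, even though action 1 is greedy there.
s_target = sarsa_target(Q, reward=1.0, next_state=1, next_action=0, done=False, gamma=0.9)
q_target = q_learning_target(Q, reward=1.0, next_state=1, done=False, gamma=0.9)
print(s_target, q_target)  # 2.8 10.0
```

Because SARSA's target reflects the exploratory action, its Q-values estimate the return of the epsilon-greedy policy itself, whereas Q-learning's estimate the return of the greedy policy.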

Complexity

  • Time: O(episodes * episode_length * |A|) overall; each step does an O(|A|) argmax for action selection plus an O(1) Q-table update
  • Space: O(|S| * |A|) for the Q-table
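
To put the space bound in concrete terms: `np.zeros` allocates float64 by default, so the Q-table costs 8 bytes per state-action pair. The helper below and the 10,000-state, 4-action sizes are hypothetical, chosen only for illustration.

```python
import numpy as np

def q_table_bytes(n_states: int, n_actions: int, dtype=np.float64) -> int:
    # One entry per (state, action) pair, each taking itemsize bytes.
    return n_states * n_actions * np.dtype(dtype).itemsize

print(q_table_bytes(10_000, 4))  # 320000
```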