#175 · Reinforcement Learning · Medium
Implement the SARSA (State-Action-Reward-State-Action) on-policy temporal difference learning algorithm. SARSA updates Q-values using the action actually taken in the next state, making it on-policy.
import numpy as np

def sarsa(n_states: int, n_actions: int, episodes: int,
          env_step, alpha: float = 0.1, gamma: float = 0.99,
          epsilon: float = 0.1):
    Q = np.zeros((n_states, n_actions))

    def choose_action(state):
        # Epsilon-greedy: explore with probability epsilon, otherwise act greedily.
        if np.random.rand() < epsilon:
            return np.random.randint(n_actions)
        return int(np.argmax(Q[state]))

    for _ in range(episodes):
        # Calling env_step with reset=True is expected to return the initial state.
        state = env_step(None, None, reset=True)
        action = choose_action(state)
        done = False
        while not done:
            next_state, reward, done = env_step(state, action)
            next_action = choose_action(next_state)
            # On-policy TD update; (1 - done) drops the bootstrap term at terminal states.
            Q[state, action] += alpha * (
                reward + gamma * Q[next_state, next_action] * (1 - done) - Q[state, action]
            )
            state = next_state
            action = next_action
    return Q

Q(s,a) += alpha * [r + gamma * Q(s',a') - Q(s,a)], where a' is the actual next action chosen.

SARSA uses a' from the policy (on-policy), while Q-learning uses max_a' Q(s',a') (off-policy).
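
To make the on-policy versus off-policy difference concrete, here is a minimal sketch that computes both TD targets for a single transition. The helper names `sarsa_target` and `q_learning_target`, the 2x2 Q-table, and all numeric values are assumptions made up for illustration; they are not part of the original problem.

```python
import numpy as np

def sarsa_target(Q, reward, next_state, next_action, gamma, done):
    """TD target using the action the policy actually chose in s'."""
    return reward + gamma * Q[next_state, next_action] * (1 - done)

def q_learning_target(Q, reward, next_state, gamma, done):
    """TD target using the highest-valued action in s'."""
    return reward + gamma * np.max(Q[next_state]) * (1 - done)

# Hypothetical 2-state, 2-action table; the values are illustrative only.
Q = np.array([[0.0, 0.0],
              [1.0, 5.0]])
alpha, gamma = 0.1, 0.99
state, action, reward = 0, 0, 1.0

# Assume epsilon-greedy explored in s' = 1 and picked action 0,
# even though action 1 has the larger Q-value.
next_state, next_action = 1, 0

t_sarsa = sarsa_target(Q, reward, next_state, next_action, gamma, done=False)
t_qlearn = q_learning_target(Q, reward, next_state, gamma, done=False)

# Apply Q(s,a) += alpha * [target - Q(s,a)] with each target.
new_q_sarsa = Q[state, action] + alpha * (t_sarsa - Q[state, action])
new_q_qlearn = Q[state, action] + alpha * (t_qlearn - Q[state, action])

print("SARSA target:", t_sarsa, "-> updated Q(s,a):", new_q_sarsa)
print("Q-learning target:", t_qlearn, "-> updated Q(s,a):", new_q_qlearn)
```

Under these assumed numbers the SARSA target is about 1.99 and the Q-learning target about 5.95, so the exploratory choice in s' lowers the SARSA estimate while Q-learning ignores it. When the policy happens to pick the greedy action in s', the two targets coincide.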