#209 · Reinforcement Learning · Hard
# Solve on deep-ml.com
#
# Implement GSPO (Group Sequence Policy Optimization), a reinforcement
# learning objective that optimizes a policy over groups of generated
# sequences. For each prompt, sample a group of responses, compute their
# rewards, and use a group-normalized advantage for the policy gradient.
import numpy as np
def compute_group_advantages(rewards: np.ndarray) -> np.ndarray:
    """Normalize one group's rewards into zero-mean, unit-variance advantages.

    The mean and population standard deviation are taken over every element
    of ``rewards``, so pass the rewards of a single group per call.

    Args:
        rewards: Scalar rewards for the responses of one group.

    Returns:
        A floating-point array shaped like ``rewards`` holding
        ``(rewards - mean) / std``. The result is all zeros when the group is
        empty or when its standard deviation is below ``1e-8`` (every
        response scored the same, so there is no learning signal).
    """
    rewards = np.asarray(rewards)
    # Integer/bool rewards would otherwise make the degenerate branches return
    # an integer array while the normal branch returns floats.
    if not np.issubdtype(rewards.dtype, np.floating):
        rewards = rewards.astype(np.float64)

    # np.mean / np.std of an empty array emit RuntimeWarnings and yield NaN.
    if rewards.size == 0:
        return np.zeros_like(rewards)

    mean = np.mean(rewards)
    std = np.std(rewards)
    if std < 1e-8:
        # Dividing by a (near-)zero std would blow the advantages up.
        return np.zeros_like(rewards)
    return (rewards - mean) / std
def gspo_loss(log_probs_new: np.ndarray, log_probs_old: np.ndarray,
              rewards: np.ndarray, epsilon: float = 0.2,
              beta: float = 0.01, ref_log_probs: np.ndarray = None,
              seq_lengths: np.ndarray = None):
    """Compute the GSPO loss for one group of responses.

    The loss is the negated mean of the clipped surrogate
    ``min(ratio * adv, clip(ratio, 1 - epsilon, 1 + epsilon) * adv)``,
    plus ``beta`` times an optional penalty against a reference policy.

    Args:
        log_probs_new: log pi_theta(y|x) for each response in the group.
        log_probs_old: log pi_old(y|x) from the sampling policy.
        rewards: Scalar reward for each response.
        epsilon: Clipping parameter for the importance ratio.
        beta: Coefficient of the reference-policy penalty.
        ref_log_probs: log pi_ref(y|x) from the reference policy. When
            ``None`` no penalty is added.
        seq_lengths: Optional positive token count of each response. When
            given, the sequence log-ratio is divided by it before
            exponentiation (the length-normalized sequence-level ratio used
            by GSPO). When ``None`` the raw sequence ratio is used.

    Returns:
        The scalar loss as a Python ``float``; ``0.0`` for an empty group.
    """
    # Accept lists/tuples as well as arrays, and work in float64 throughout.
    log_probs_new = np.asarray(log_probs_new, dtype=np.float64)
    log_probs_old = np.asarray(log_probs_old, dtype=np.float64)
    rewards = np.asarray(rewards, dtype=np.float64)

    # An empty group carries no signal; the mean of nothing would be NaN.
    if rewards.size == 0:
        return 0.0

    advantages = compute_group_advantages(rewards)

    # Sequence-level importance ratio.
    log_ratio = log_probs_new - log_probs_old
    if seq_lengths is not None:
        log_ratio = log_ratio / np.asarray(seq_lengths, dtype=np.float64)
    # Cap the exponent: exp() overflows to inf just above ~709 in float64, and
    # inf * 0 (a zero advantage) would turn the whole loss into NaN.
    ratio = np.exp(np.minimum(log_ratio, 700.0))

    # Clipped surrogate objective (pessimistic bound of the two terms).
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -np.mean(np.minimum(unclipped, clipped))

    # Optional penalty: mean log-ratio between the new and reference policy.
    kl_penalty = 0.0
    if ref_log_probs is not None:
        ref_log_probs = np.asarray(ref_log_probs, dtype=np.float64)
        kl_penalty = np.mean(log_probs_new - ref_log_probs)

    return float(policy_loss + beta * kl_penalty)
def gspo_step(prompts_rewards: list[tuple], log_probs_fn,
              epsilon=0.2, beta=0.01):
    """Average the GSPO loss over multiple prompt groups.

    Args:
        prompts_rewards: One tuple per group, either
            ``(log_probs_new, log_probs_old, rewards)`` or
            ``(log_probs_new, log_probs_old, rewards, ref_log_probs)``.
            Without the fourth element no reference penalty is applied, so
            ``beta`` has no effect for that group.
        log_probs_fn: Currently unused; kept so existing callers keep working.
        epsilon: Clipping parameter forwarded to ``gspo_loss``.
        beta: Reference-penalty coefficient forwarded to ``gspo_loss``.

    Returns:
        The mean of the per-group losses, or ``0.0`` when there are no groups.

    Raises:
        ValueError: If a group tuple has fewer than 3 or more than 4 elements.
    """
    # Nothing to average; avoid dividing by zero.
    if len(prompts_rewards) == 0:
        return 0.0

    total_loss = 0.0
    for group in prompts_rewards:
        lp_new, lp_old, rewards, *rest = group
        if len(rest) > 1:
            raise ValueError(
                "each group must have 3 or 4 elements, got %d" % len(group)
            )
        ref = rest[0] if rest else None
        total_loss += gspo_loss(
            np.array(lp_new), np.array(lp_old), np.array(rewards),
            epsilon, beta,
            ref_log_probs=None if ref is None else np.array(ref),
        )
    return total_loss / len(prompts_rewards)


# Notes:
# - ratio = exp(log pi_new - log pi_old) measures how much the policy has
#   changed since sampling.
# - The clipped surrogate is
#   min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv).