GSPO: Group Sequence Policy Optimization

#209 · Reinforcement Learning · Hard

Solve on deep-ml.com

Problem

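Implement GSPO (Group Sequence Policy Optimization), a reinforcement learning objective for training a policy on groups of generated sequences. For each prompt, sample a group of responses, score each one with a scalar reward, and normalize the rewards within the group to obtain advantages for the policy gradient. Unlike token-level methods such as GRPO, GSPO computes the importance ratio and applies clipping per sequence rather than per token.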
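For reference, the objective can be written out as follows. This is our reading of the published GSPO formulation and is not stated on this page. K is the group size (matching the Complexity section), r_i is the reward of response y_i, and s_i(θ) is the length-normalized sequence ratio. The Solution code below uses the unnormalized ratio exp(log π_new - log π_old) instead, and it minimizes the negative of the bracketed average for one group, plus an optional β-weighted KL term.

```latex
\hat{A}_i = \frac{r_i - \operatorname{mean}\left(\{r_j\}_{j=1}^{K}\right)}{\operatorname{std}\left(\{r_j\}_{j=1}^{K}\right)}

s_i(\theta) = \left(\frac{\pi_\theta(y_i \mid x)}{\pi_{\theta_{\text{old}}}(y_i \mid x)}\right)^{1/|y_i|}
            = \exp\left(\frac{1}{|y_i|}\sum_{t=1}^{|y_i|}\log\frac{\pi_\theta(y_{i,t} \mid x, y_{i,<t})}{\pi_{\theta_{\text{old}}}(y_{i,t} \mid x, y_{i,<t})}\right)

\mathcal{J}_{\text{GSPO}}(\theta) = \mathbb{E}\left[\frac{1}{K}\sum_{i=1}^{K}\min\left(s_i(\theta)\,\hat{A}_i,\ \operatorname{clip}\left(s_i(\theta), 1-\epsilon, 1+\epsilon\right)\hat{A}_i\right)\right]
```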

Solution

import numpy as np

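def compute_group_advantages(rewards: np.ndarray) -> np.ndarray:
    """Normalize the rewards of one group (all responses to a single prompt)
    to zero mean and unit variance; the results are the advantages."""
    rewards = np.asarray(rewards, dtype=float)
    mean = np.mean(rewards)
    std = np.std(rewards)  # population std (ddof=0)
    if std < 1e-8:
        # Every response scored the same: no relative signal, so zero advantage
        return np.zeros_like(rewards)
    return (rewards - mean) / std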

def gspo_loss(log_probs_new: np.ndarray, log_probs_old: np.ndarray,
              rewards: np.ndarray, epsilon: float = 0.2,
              beta: float = 0.01, ref_log_probs: np.ndarray = None):
    """
    Compute the GSPO loss (to be minimized) for one group of K responses.
    log_probs_new: log pi_theta(y|x) per response, i.e. the sum of its token log-probs
    log_probs_old: log pi_old(y|x) under the policy that sampled the responses
    rewards: scalar reward for each response
    epsilon: clipping range for the importance ratio
    beta: KL penalty coefficient
    ref_log_probs: optional log pi_ref(y|x) under the reference policy
    """
    advantages = compute_group_advantages(rewards)

    # Sequence-level importance ratio pi_theta(y|x) / pi_old(y|x), one per response
    ratio = np.exp(log_probs_new - log_probs_old)

    # PPO-style clipped surrogate; negated because the objective is maximized
    surr1 = ratio * advantages
    surr2 = np.clip(ratio, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -np.mean(np.minimum(surr1, surr2))

    # Optional KL penalty: the mean log-ratio to the reference policy is a simple,
    # possibly negative, sample estimate of KL(pi_theta || pi_ref)
    kl_penalty = 0.0
    if ref_log_probs is not None:
        kl_penalty = np.mean(log_probs_new - ref_log_probs)

    total_loss = policy_loss + beta * kl_penalty
    return float(total_loss)

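def gspo_step(prompts_rewards: list[tuple], epsilon: float = 0.2,
              beta: float = 0.01) -> float:
    """
    Full GSPO step over multiple prompt groups; returns the mean group loss.
    prompts_rewards: one tuple per prompt group, either
        (log_probs_new, log_probs_old, rewards) or
        (log_probs_new, log_probs_old, rewards, ref_log_probs)
    """
    if not prompts_rewards:
        raise ValueError("prompts_rewards must contain at least one group")
    total_loss = 0.0
    for group in prompts_rewards:
        lp_new, lp_old, rewards = group[:3]
        # The KL term is only active when reference log-probs are supplied
        ref = np.asarray(group[3]) if len(group) > 3 else None
        total_loss += gspo_loss(
            np.asarray(lp_new), np.asarray(lp_old), np.asarray(rewards),
            epsilon=epsilon, beta=beta, ref_log_probs=ref,
        )
    return total_loss / len(prompts_rewards)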
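The Solution takes one log-probability per response, so the ratio of a long response is a product of many token ratios and can drift far from 1. Below is a minimal sketch of the length-normalized ratio s_i from the published GSPO formulation. It assumes per-token log-probabilities padded into a (K, T) array with a 0/1 mask; the function names `sequence_ratio` and `gspo_loss_normalized` and this input layout are hypothetical and not part of the problem.

```python
import numpy as np

def sequence_ratio(token_logp_new: np.ndarray, token_logp_old: np.ndarray,
                   mask: np.ndarray) -> np.ndarray:
    """Length-normalized sequence importance ratio s_i, one value per response.

    token_logp_new, token_logp_old: (K, T) per-token log-probs, padded to length T
    mask: (K, T), 1.0 for real tokens and 0.0 for padding
          (assumes every response has at least one real token)
    """
    lengths = mask.sum(axis=1)  # |y_i| for each response
    # Zero out padded positions so they contribute nothing to the sum
    log_diff = np.where(mask > 0, token_logp_new - token_logp_old, 0.0)
    # exp of the mean token log-ratio = geometric mean of the token ratios
    return np.exp(log_diff.sum(axis=1) / lengths)


def gspo_loss_normalized(token_logp_new: np.ndarray, token_logp_old: np.ndarray,
                         mask: np.ndarray, rewards, epsilon: float = 0.2) -> float:
    """Clipped loss for one group using the length-normalized ratio (no KL term)."""
    rewards = np.asarray(rewards, dtype=float)
    std = rewards.std()
    if std < 1e-8:
        adv = np.zeros_like(rewards)  # identical rewards carry no signal
    else:
        adv = (rewards - rewards.mean()) / std
    s = sequence_ratio(token_logp_new, token_logp_old, mask)
    surr = np.minimum(s * adv, np.clip(s, 1 - epsilon, 1 + epsilon) * adv)
    return float(-surr.mean())
```

Because s_i is a geometric mean of token ratios, a 2-token response and a 4-token response whose tokens all shifted by the same log-ratio receive the same s_i, so response length alone does not push a ratio toward the clip boundary.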

Explanation

  1. Group normalization: For each prompt, compute the mean and standard deviation of rewards across the sampled group. Normalize to get zero-mean, unit-variance advantages.
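  2. Importance sampling ratio: ratio = exp(log pi_new - log pi_old) is computed once per response from its sequence log-likelihood (the sum of its token log-probs), so clipping acts on whole sequences rather than individual tokens. It measures how much the probability of that response has changed since it was sampled. The published GSPO objective additionally divides the log-ratio by the response length |y|, which keeps the ratios of long responses from becoming extreme; the Solution code omits this normalization.
  3. Clipped objective: Apply PPO-style clipping, min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv), so a response's term stops improving once its ratio leaves [1-eps, 1+eps] in the direction its advantage favors. This bounds the size of each policy update.
  4. KL penalty: Optionally add beta times an estimate of the KL divergence from a reference policy (typically the supervised fine-tuned model), keeping the trained policy close to that baseline.
  5. Aggregation: Average the loss across all prompt groups to get the loss for the update.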
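A small worked example of steps 1 and 3 on one group of four responses with binary rewards. The ratio values are made up purely to show when clipping applies, and `clipped_surrogate` is a hypothetical helper name.

```python
import numpy as np

def clipped_surrogate(ratio: np.ndarray, adv: np.ndarray, epsilon: float = 0.2) -> np.ndarray:
    """Per-response PPO-style term: min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv)."""
    return np.minimum(ratio * adv, np.clip(ratio, 1 - epsilon, 1 + epsilon) * adv)

# Step 1: binary rewards for one group of four responses
rewards = np.array([1.0, 0.0, 0.0, 1.0])
adv = (rewards - rewards.mean()) / rewards.std()  # mean 0.5, std 0.5 -> [1, -1, -1, 1]

# Step 3: illustrative ratios (made up for this example)
ratio = np.array([1.5, 0.5, 1.1, 0.9])
terms = clipped_surrogate(ratio, adv)
print(terms)          # approximately [1.2, -0.8, -1.1, 0.9]
print(-terms.mean())  # policy loss for this group, approximately -0.05
```

Response 0 has a positive advantage and a ratio above 1 + eps, so its term is capped at 1.2. Response 1 has a negative advantage and a ratio below 1 - eps, so min picks the clipped, more pessimistic value -0.8. In both cases the objective gains nothing from moving the ratio further, which is what bounds the update.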

Complexity

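  • Time: O(G * K) for the loss itself, where G is the number of prompt groups and K is the group size; the model forward passes that produce the log-probabilities are not counted here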
  • Space: O(K) per group for advantages and ratios
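When every prompt has the same group size K, the per-group loop in gspo_step can be replaced by array operations over a (G, K) batch, which makes the O(G * K) cost explicit. This is a sketch under that equal-size assumption. It omits the KL term and keeps the Solution's unnormalized ratio; `batched_group_advantages` and `batched_gspo_loss` are hypothetical names.

```python
import numpy as np

def batched_group_advantages(rewards: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """rewards: (G, K) array, one row per prompt group; normalizes each row."""
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    safe_std = np.where(std < eps, 1.0, std)  # avoid dividing by zero
    # Rows with (near-)zero spread get zero advantage, as in the per-group version
    return np.where(std < eps, 0.0, (rewards - mean) / safe_std)


def batched_gspo_loss(lp_new: np.ndarray, lp_old: np.ndarray,
                      rewards: np.ndarray, epsilon: float = 0.2) -> float:
    """lp_new, lp_old, rewards: (G, K) arrays of sequence log-probs and rewards."""
    adv = batched_group_advantages(rewards)
    ratio = np.exp(lp_new - lp_old)
    surr = np.minimum(ratio * adv, np.clip(ratio, 1 - epsilon, 1 + epsilon) * adv)
    # Mean within each group, then across groups, like gspo_step's averaging
    return float(-surr.mean(axis=1).mean())
```

Averaging each row before averaging the row means reproduces gspo_step's per-group averaging; with ragged group sizes you would need padding plus a mask instead of a plain (G, K) array.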