
KL Divergence Estimator for GRPO

#225 · Reinforcement Learning · Easy

Source: deep-ml.com

Problem

Compute the KL divergence estimator that GRPO uses to regularize the policy, keeping it from drifting too far from a reference policy. Given the log-probabilities that the current policy and the reference policy assign to each token in a sequence, estimate the per-token KL divergence, averaged over the sequence.
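For intuition about what is being estimated: at each position both policies define a full distribution over the vocabulary, and the exact per-token KL sums over every token in it. The estimator only needs the log-probability of the token that was actually sampled, because averaging log pi(a) - log pi_ref(a) over tokens drawn from pi recovers the exact value in expectation. A minimal sketch on a hypothetical 3-token vocabulary (the probabilities are made up, and `exact_kl` and `single_sample_estimate` are illustrative names, not part of the problem):

```python
import math

# Hypothetical next-token distributions over a 3-token vocabulary
# (made-up numbers, for illustration only).
pi = [0.7, 0.2, 0.1]      # current policy
pi_ref = [0.5, 0.3, 0.2]  # reference policy


def exact_kl(p: list[float], q: list[float]) -> float:
    """Exact KL(p || q) = sum_x p(x) * log(p(x) / q(x))."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


def single_sample_estimate(token: int, p: list[float], q: list[float]) -> float:
    """Log-ratio estimate log p(token) - log q(token) from one sampled token."""
    return math.log(p[token]) - math.log(q[token])


kl = exact_kl(pi, pi_ref)

# Expected value of the single-sample estimate when the token is drawn from pi:
# weighting each token's log-ratio by pi(token) gives back the exact KL.
weighted = sum(pi[t] * single_sample_estimate(t, pi, pi_ref) for t in range(len(pi)))

print(f"exact KL = {kl:.6f}, pi-weighted estimate = {weighted:.6f}")
for t in range(len(pi)):
    print(f"token {t}: single-sample estimate = {single_sample_estimate(t, pi, pi_ref):+.6f}")
```

Here `kl` is about 0.085, while the single-token estimate for token 1 is log(0.2/0.3), about -0.405: individual samples can be negative even though their pi-weighted average equals the exact, non-negative KL.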

Solution

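def kl_divergence_estimator(
    log_probs_current: list[float],
    log_probs_reference: list[float],
) -> float:
    """Estimate KL(pi || pi_ref) as the mean of the per-token log-ratios."""
    n = len(log_probs_current)
    if n != len(log_probs_reference):
        # zip() would otherwise silently truncate mismatched sequences
        raise ValueError("log-probability sequences must have the same length")
    if n == 0:
        return 0.0

    # KL(pi || pi_ref) = E_{a~pi}[log pi(a) - log pi_ref(a)].
    # The tokens were sampled from pi, so the mean log-ratio over them is a
    # Monte Carlo estimate of that expectation.
    kl = sum(lp - lr for lp, lr in zip(log_probs_current, log_probs_reference)) / n
    return round(kl, 6)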

Explanation

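  1. The KL divergence KL(pi || pi_ref) measures how far the current policy's token distribution has moved away from the reference policy's.
  2. The estimator is KL ≈ (1/n) * sum(log_pi(a_t|s_t) - log_pi_ref(a_t|s_t)), averaged over the n tokens in the sequence.
  3. Because the tokens are sampled from pi, this is an unbiased Monte Carlo estimate of E_{pi}[log(pi/pi_ref)]. Individual terms, and even the average over a short sequence, can be negative, although the true KL divergence is never negative.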
  4. In GRPO, this term is added to the loss as a penalty (scaled by a coefficient beta) to prevent the policy from changing too drastically.
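For comparison, the GRPO objective as written in the DeepSeekMath paper (Shao et al., 2024) does not average the raw log-ratio; it uses the per-token estimator pi_ref/pi - log(pi_ref/pi) - 1. Under sampling from pi this is also unbiased for KL(pi || pi_ref), since E_pi[pi_ref/pi] = 1 when pi is nonzero wherever pi_ref is, and every term is non-negative because x - log(x) - 1 >= 0 for x > 0. The sketch below implements that variant together with a beta-scaled penalty; the function names and the default beta are assumptions for illustration, and this is not the estimator the solution above computes.

```python
import math


def k3_kl_estimator(
    log_probs_current: list[float],
    log_probs_reference: list[float],
) -> float:
    """Mean of r - log(r) - 1 with r = pi_ref / pi (hypothetical helper name).

    Each per-token term is >= 0, so the result is never negative.
    """
    if len(log_probs_current) != len(log_probs_reference):
        raise ValueError("log-probability sequences must have the same length")
    n = len(log_probs_current)
    if n == 0:
        return 0.0
    total = 0.0
    for lp, lr in zip(log_probs_current, log_probs_reference):
        log_ratio = lr - lp  # log(pi_ref / pi) for this token
        total += math.exp(log_ratio) - log_ratio - 1.0
    return total / n


def kl_penalty(
    log_probs_current: list[float],
    log_probs_reference: list[float],
    beta: float = 0.04,  # placeholder coefficient, an assumption
) -> float:
    """beta * KL term that would be added to the loss as a penalty."""
    return beta * k3_kl_estimator(log_probs_current, log_probs_reference)


print(k3_kl_estimator([-1.0], [-0.5]))  # e**0.5 - 1.5, about 0.1487
print(kl_penalty([-1.0], [-0.5], beta=0.1))
```

With log_probs_current = [-1.0] and log_probs_reference = [-0.5], the plain log-ratio estimator gives -0.5, whereas this variant gives e^0.5 - 1.5, about 0.149. The trade-off is an extra exp per token in exchange for a penalty that can never reward the policy for drifting.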

Complexity

  • Time: O(n) where n is the sequence length
  • Space: O(1)