Compute a KL divergence estimate of the kind GRPO uses as a penalty to keep the current policy from drifting too far from a reference policy. Given the log-probabilities that the current policy and the reference policy assign to each token of a sequence, estimate the per-token KL divergence.
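The estimator averages the per-token log-ratio over the $n$ tokens in the sequence:

$$\widehat{\mathrm{KL}} = \frac{1}{n}\sum_{t=1}^{n}\left(\log \pi(a_t \mid s_t) - \log \pi_{\mathrm{ref}}(a_t \mid s_t)\right)$$

When the tokens are sampled from the current policy $\pi$, this is a Monte Carlo estimate of $\mathbb{E}_{\pi}\left[\log \frac{\pi}{\pi_{\mathrm{ref}}}\right]$.

```python
def kl_divergence_estimator(
    log_probs_current: list[float],
    log_probs_reference: list[float],
) -> float:
    """Estimate the per-token KL divergence KL(pi || pi_ref).

    Both lists hold log-probabilities of the same sampled tokens, one entry
    per token position, under the current and reference policies.
    """
    # zip() would silently truncate mismatched inputs, so reject them explicitly.
    if len(log_probs_current) != len(log_probs_reference):
        raise ValueError("log-probability sequences must have the same length")
    n = len(log_probs_current)
    if n == 0:
        return 0.0
    # Sample mean of log pi - log pi_ref, estimating E_pi[log(pi / pi_ref)].
    kl = sum(lp - lr for lp, lr in zip(log_probs_current, log_probs_reference)) / n
    return round(kl, 6)
```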
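The plain log-ratio average above is unbiased but can be negative for individual tokens and for short sequences, even though the true KL divergence is never negative. For comparison, the GRPO objective as written in the DeepSeekMath paper uses a different per-token estimator, $\frac{\pi_{\mathrm{ref}}}{\pi} - \log\frac{\pi_{\mathrm{ref}}}{\pi} - 1$ (often called "k3"). Each term of k3 is non-negative because $\log r \le r - 1$ for every $r > 0$. The sketch below is a minimal illustration that assumes the same input convention as `kl_divergence_estimator`. The function name `k3_kl_estimator` is hypothetical.

```python
import math


def k3_kl_estimator(
    log_probs_current: list[float],
    log_probs_reference: list[float],
) -> float:
    """Hypothetical helper: mean per-token k3 estimate of KL(pi || pi_ref).

    Per token: r - log(r) - 1 with r = pi_ref / pi = exp(log_pi_ref - log_pi).
    Each term is >= 0 because log(r) <= r - 1 for all r > 0.
    """
    if len(log_probs_current) != len(log_probs_reference):
        raise ValueError("log-probability sequences must have the same length")
    n = len(log_probs_current)
    if n == 0:
        return 0.0
    total = 0.0
    for lp, lr in zip(log_probs_current, log_probs_reference):
        log_ratio = lr - lp  # log(pi_ref / pi) for this token
        total += math.exp(log_ratio) - log_ratio - 1.0
    return round(total / n, 6)


# Two tokens scored by both policies. The plain log-ratio mean here would be
# ((-1.0 + 1.5) + (-2.0 + 1.0)) / 2 = -0.25, a negative "KL" estimate.
current = [-1.0, -2.0]
reference = [-1.5, -1.0]
print(k3_kl_estimator(current, reference))  # → 0.412406
```

Both estimators have expectation $\mathrm{KL}(\pi \,\|\, \pi_{\mathrm{ref}})$ when tokens are sampled from $\pi$, because $\mathbb{E}_{\pi}[\pi_{\mathrm{ref}}/\pi] = 1$. The k3 form trades the possibility of negative values for a per-token guarantee of non-negativity.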