Fine-Tune Model Weights with RLHF Policy Gradient

#361 · Reinforcement Learning · Medium

Solve on deep-ml.com

Problem

Implement a simplified RLHF (Reinforcement Learning from Human Feedback) policy gradient update. Given a language model's log-probabilities for chosen actions and a reward signal from a reward model, compute the policy gradient loss and update model weights.

Solution

import numpy as np

def rlhf_policy_gradient(
    log_probs: np.ndarray,
    rewards: np.ndarray,
    old_log_probs: np.ndarray,
    clip_epsilon: float = 0.2,
    lr: float = 1e-4
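) -> dict:
    """Simplified RLHF policy-gradient step with a PPO-style clipped objective.

    log_probs / old_log_probs: log-probabilities of the chosen actions under the
    current and old policy; rewards: one reward-model score per action.
    """
    log_probs = np.asarray(log_probs, dtype=float)
    old_log_probs = np.asarray(old_log_probs, dtype=float)
    rewards = np.asarray(rewards, dtype=float)

    # Probability ratio r = pi / pi_old, computed in log space
    ratios = np.exp(log_probs - old_log_probs)
    # Advantages: subtract the mean reward as a baseline, then scale to unit variance
    advantages = rewards - rewards.mean()
    advantages = advantages / (advantages.std() + 1e-8)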

    # Clipped surrogate objective
    surr1 = ratios * advantages
    surr2 = np.clip(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    loss = -np.mean(np.minimum(surr1, surr2))

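    # Gradient of the loss w.r.t. log_probs. Since d(ratio)/d(log_prob) = ratio,
    # the unclipped branch contributes -advantage * ratio / n. Where the clipped
    # branch is the active minimum it is constant in log_prob, so the gradient is 0.
    unclipped_active = surr1 <= surr2
    grad = np.where(unclipped_active, -advantages * ratios, 0.0) / len(log_probs)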

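    # Simplified update: treat log_probs as the trainable parameters and take
    # one gradient-descent step on the loss (a stand-in for a weight update)
    updated_log_probs = log_probs - lr * grad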

    return {
        "loss": float(loss),
        "updated_log_probs": updated_log_probs,
        "mean_reward": float(rewards.mean()),
        "mean_advantage": float(advantages.mean()),
    }
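
The function above updates the log-probabilities themselves as a stand-in for model weights. To show how the same per-sample gradient can reach actual parameters, here is a minimal sketch assuming a hypothetical tabular softmax policy, where theta[s, a] is the logit of token a in context s. The names log_softmax and ppo_weight_step, and the tabular policy itself, are illustrative assumptions rather than part of the original problem. The sketch gives zero gradient to samples whose clipped term is the active minimum, then applies the chain rule d log pi(a|s) / d theta[s, :] = onehot(a) - softmax(theta[s]).

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def ppo_weight_step(theta, states, actions, old_log_probs, rewards,
                    clip_epsilon=0.2, lr=1e-2):
    """One clipped-PPO gradient step on a hypothetical tabular softmax policy.

    theta[s, a] is the logit of token a in context s: an illustrative
    stand-in for real model weights. Returns (new_theta, loss).
    """
    n = len(actions)
    log_pi = log_softmax(theta)                  # (n_states, vocab)
    log_probs = log_pi[states, actions]          # log pi(a_i | s_i)

    ratios = np.exp(log_probs - old_log_probs)
    adv = rewards - rewards.mean()
    adv = adv / (adv.std() + 1e-8)

    surr1 = ratios * adv
    surr2 = np.clip(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * adv
    loss = -np.mean(np.minimum(surr1, surr2))

    # dL/d log_prob_i: zero wherever the clipped (constant) branch is active.
    g = np.where(surr1 <= surr2, -adv * ratios, 0.0) / n

    # Chain rule: d log pi(a|s) / d theta[s, :] = onehot(a) - softmax(theta[s]).
    onehot = np.zeros((n, theta.shape[1]))
    onehot[np.arange(n), actions] = 1.0
    dlogp_dtheta = onehot - np.exp(log_pi)[states]   # (n, vocab)

    # Accumulate per-sample contributions; samples sharing a context add up.
    grad_theta = np.zeros_like(theta)
    np.add.at(grad_theta, states, g[:, None] * dlogp_dtheta)

    return theta - lr * grad_theta, float(loss)


# Usage: 3 contexts, a vocabulary of 5 tokens, 4 sampled (context, token) pairs.
rng = np.random.default_rng(0)
theta = rng.normal(size=(3, 5))
states = np.array([0, 1, 2, 0])
actions = np.array([1, 4, 2, 3])
rewards = np.array([1.0, -0.5, 0.3, 2.0])
old_log_probs = log_softmax(theta)[states, actions]  # on-policy, so all ratios are 1
theta, loss = ppo_weight_step(theta, states, actions, old_log_probs, rewards)
```

In a real language model this per-sample gradient would normally be backpropagated by an autodiff framework rather than derived by hand; the tabular policy only makes the chain rule explicit.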

Explanation

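  1. Compute the probability ratio between the current and old policies for each sample: r = exp(log_pi - log_pi_old).
  2. Turn rewards into advantages A by subtracting the batch-mean reward (a simple baseline) and dividing by the standard deviation, so the advantages are zero-mean and unit-variance.
  3. Use the PPO clipped surrogate objective, loss = -mean(min(r * A, clip(r, 1 - clip_epsilon, 1 + clip_epsilon) * A)). Taking the minimum removes any gain from pushing r outside [1 - clip_epsilon, 1 + clip_epsilon], which keeps each policy update small.
  4. Differentiate the loss with respect to the log-probabilities. Because d(r)/d(log_pi) = r, each sample contributes -A * r / n when the unclipped term is the active minimum and 0 when the clipped term is (the clipped value does not depend on log_pi). Then take one gradient-descent step, log_probs - lr * grad, which stands in for a weight update in this simplified setting.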
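
The gradient rule can be checked on four hand-picked samples, one for each combination of a ratio above or below the clip range and a positive or negative advantage. This is a minimal sketch: clipped_surrogate is a hypothetical helper that takes already-normalized advantages as constants and returns -A * r / n where the unclipped term is the active minimum and 0 elsewhere.

```python
import numpy as np


def clipped_surrogate(log_probs, old_log_probs, advantages, clip_epsilon=0.2):
    """Clipped PPO loss and its gradient with respect to log_probs.

    Advantages are assumed to be already normalized and are treated as
    constants when differentiating. Returns (loss, grad, unclipped_active).
    """
    n = len(log_probs)
    ratios = np.exp(log_probs - old_log_probs)
    surr1 = ratios * advantages
    surr2 = np.clip(ratios, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    loss = -np.mean(np.minimum(surr1, surr2))

    # d(ratio)/d(log_prob) = ratio, so the unclipped branch gives -A * r / n.
    # The clipped branch is constant in log_prob and contributes nothing.
    unclipped_active = surr1 <= surr2
    grad = np.where(unclipped_active, -advantages * ratios, 0.0) / n
    return loss, grad, unclipped_active


# One sample per case: ratio above or below the clip range [0.8, 1.2],
# with a positive or negative advantage.
old = np.zeros(4)
new = np.log(np.array([1.5, 0.5, 1.5, 0.5]))   # ratios 1.5, 0.5, 1.5, 0.5
adv = np.array([1.0, 1.0, -1.0, -1.0])
loss, grad, active = clipped_surrogate(new, old, adv)
print(active.tolist())   # [False, True, True, False]
```

Only the second and third samples receive a gradient. In the first and fourth, the ratio has already moved past the clip bound in the direction the advantage favors, so the clipped objective stops pushing it further.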

Complexity

  • Time: O(n) where n is the number of samples/tokens
  • Space: O(n) for ratios, advantages, and gradient arrays