
RAFT: Iterative Reward-Ranked Fine-Tuning Loop

#379 · Reinforcement Learning · Medium

Source: deep-ml.com

Problem

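Implement RAFT (Reward-rAnked Fine-Tuning), an iterative loop that generates candidate responses, ranks them with a reward model, selects the top-ranked responses, and fine-tunes the policy model on those selected responses. In this simplified setting the policy is a matrix of logits with one row per prompt, and each candidate response is a single token sampled from the softmax of that row.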
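
The ranking-and-selection step works like a best-of-n filter. The following is a minimal sketch that assumes rewards arrive as a flat array per prompt. `select_top_fraction` is a hypothetical helper name, and k is clamped so that it always stays between 1 and n.

```python
import numpy as np

def select_top_fraction(rewards, top_fraction: float = 0.25) -> np.ndarray:
    """Return indices of the highest-reward candidates, best first (hypothetical helper)."""
    rewards = np.asarray(rewards, dtype=float)
    n = rewards.shape[0]
    # Keep at least one candidate and never more than n.
    k = min(n, max(1, int(n * top_fraction)))
    # argsort is ascending: take the last k indices and reverse so the best comes first.
    return np.argsort(rewards)[-k:][::-1]

rewards = [0.1, 0.9, 0.4, 0.7, 0.2, 0.8, 0.3, 0.5]
print(select_top_fraction(rewards, 0.25))  # [1 5]
```

With 8 candidates and `top_fraction = 0.25`, only the two best candidates (rewards 0.9 and 0.8) survive to the fine-tuning step.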

Solution

import numpy as np

def raft_iteration(
    policy_logits: np.ndarray,
    prompts: list,
    reward_fn,
    n_candidates: int = 8,
    top_fraction: float = 0.25,
    lr: float = 1e-3,
) -> dict:
    """Run one RAFT round on a tabular policy.

    policy_logits has shape (n_prompts, vocab_size); each candidate
    response is a single token sampled from the softmax of its prompt's row.
    """
    n_prompts = len(prompts)
    # Number of candidates kept per prompt (at least 1, at most n_candidates).
    top_k = min(n_candidates, max(1, int(n_candidates * top_fraction)))

    selected = []           # (prompt_idx, token) pairs chosen for fine-tuning
    selected_rewards = []   # reward of each selected pair

    for i in range(n_prompts):
        # Numerically stable softmax, computed once per prompt.
        z = policy_logits[i] - policy_logits[i].max()
        probs = np.exp(z) / np.exp(z).sum()

        # Generate candidates by sampling from the current policy.
        candidates = np.random.choice(len(probs), size=n_candidates, p=probs)

        # Score each candidate with the reward model.
        rewards = np.array([reward_fn(prompts[i], c) for c in candidates], dtype=float)

        # Keep the top-k candidates by reward.
        top_indices = np.argsort(rewards)[-top_k:]
        for idx in top_indices:
            selected.append((i, candidates[idx]))
            selected_rewards.append(rewards[idx])

    # "Fine-tune": nudge the logit of each selected token by lr * reward.
    # With a positive reward this raises the token's log-probability; it is a
    # simplified update, not the exact gradient of log-softmax.
    updated_logits = np.array(policy_logits, dtype=float)  # copy, as float
    for (prompt_idx, token), reward in zip(selected, selected_rewards):
        updated_logits[prompt_idx, token] += lr * reward

    return {
        "updated_logits": updated_logits,
        "mean_reward": float(np.mean(selected_rewards)),
        "n_selected": len(selected),
    }
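
The solution's update adds `lr * reward` to a single logit. With a positive reward, this raises that token's probability, but it is not the gradient of the token's log-probability. RAFT's fine-tuning stage is usually described as ordinary supervised fine-tuning on the selected samples, without reward weighting. The following is a sketch of that variant for this tabular setting. It uses the identity d log softmax(z)[t] / dz = one_hot(t) - softmax(z). The names `log_softmax` and `sft_step`, and the choice to average the gradient over the selected pairs, are assumptions of this sketch.

```python
import numpy as np

def log_softmax(z: np.ndarray) -> np.ndarray:
    # Subtract the max before exponentiating for numerical stability.
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def sft_step(logits: np.ndarray, selected, lr: float = 0.1) -> np.ndarray:
    """Hypothetical helper: one gradient-ascent step on the mean log-likelihood
    of the selected (prompt_idx, token) pairs, with no reward weighting."""
    updated = np.array(logits, dtype=float)  # float copy of the policy
    grad = np.zeros_like(updated)
    for prompt_idx, token in selected:
        probs = np.exp(log_softmax(updated[prompt_idx]))
        # d log p(token) / d logits = one_hot(token) - softmax(logits)
        g = -probs
        g[token] += 1.0
        grad[prompt_idx] += g
    # Average over the selected pairs (an assumption of this sketch).
    updated += lr * grad / max(1, len(selected))
    return updated

logits = np.zeros((2, 4))
selected = [(0, 2), (0, 2), (1, 0)]
new = sft_step(logits, selected, lr=1.0)
print(log_softmax(new[0])[2] > log_softmax(logits[0])[2])  # True
```

Because each per-sample gradient sums to zero across the vocabulary, this step moves probability mass toward the selected tokens without shifting the overall logit level.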

Explanation

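  1. For each prompt, sample n_candidates candidate responses from the current policy.
  2. Score each candidate with the reward model (reward_fn).
  3. Keep the top fraction of candidates ranked by reward (top_fraction of n_candidates, and at least one per prompt).
  4. Fine-tune the policy toward the selected responses. The solution approximates this step by adding lr * reward to the logit of each selected token, so a selected token with a positive reward gains probability.
  5. Repeat the loop, sampling each round from the updated policy, so that probability mass shifts progressively toward high-reward responses.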
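
Step 5 is the part the single-iteration solution leaves to the caller. The following is a toy sketch of the outer loop. It reuses the solution's reward-weighted logit update, but it has several assumptions that are not from the problem: `run_raft` is a hypothetical name, `reward_fn` receives the prompt index instead of the prompt, the reward is 1 for a fixed target token and 0 otherwise, and a seeded `np.random.default_rng` generator is used for reproducibility.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()

def run_raft(logits, reward_fn, n_iters=20, n_candidates=8,
             top_fraction=0.25, lr=0.5, seed=0):
    """Hypothetical outer loop: repeat sample -> rank -> select -> update.
    reward_fn(prompt_idx, token) is an assumed signature for this toy."""
    rng = np.random.default_rng(seed)
    logits = np.array(logits, dtype=float)
    top_k = max(1, int(n_candidates * top_fraction))
    history = []  # mean reward of the selected candidates per iteration
    for _ in range(n_iters):
        selected_rewards = []
        updated = logits.copy()
        for i in range(logits.shape[0]):
            probs = softmax(logits[i])
            candidates = rng.choice(len(probs), size=n_candidates, p=probs)
            rewards = np.array([reward_fn(i, c) for c in candidates], dtype=float)
            for idx in np.argsort(rewards)[-top_k:]:
                # Same reward-weighted logit update as the solution.
                updated[i, candidates[idx]] += lr * rewards[idx]
                selected_rewards.append(rewards[idx])
        logits = updated  # the next round samples from the fine-tuned policy
        history.append(float(np.mean(selected_rewards)))
    return logits, history

V, target = 6, 3
init = np.zeros((2, V))
final, history = run_raft(init, lambda i, c: 1.0 if c == target else 0.0)
print(softmax(final[0])[target] > softmax(init[0])[target])  # True
```

Only the target token ever receives a non-zero reward here, so its logit can only grow, and each round samples it more often than the last.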

Complexity

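  • Time: O(P · C · V) per iteration, where P is the number of prompts, C the number of candidates per prompt, and V the vocabulary size, plus the cost of P · C reward-model calls.
  • Space: O(P · V) for the copied logit matrix, plus at most O(P · C) for the selected (prompt, token) pairs and their rewards.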