#379 · Reinforcement Learning · Medium
Solve on deep-ml.com. Implement RAFT (Reward-rAnked Fine-Tuning), an iterative loop that generates candidate responses, ranks them with a reward model, selects the top-ranked responses, and fine-tunes the policy model on those selected responses.
import numpy as np
def raft_iteration(
    policy_logits: np.ndarray,
    prompts: list,
    reward_fn,
    n_candidates: int = 8,
    top_fraction: float = 0.25,
    lr: float = 1e-3,
    rng=None,
) -> dict:
    """Run one RAFT (Reward-rAnked Fine-Tuning) iteration on a tabular policy.

    For every prompt ``i``, row ``policy_logits[i]`` is turned into a
    categorical distribution over token ids. ``n_candidates`` tokens are
    sampled from it, and each is scored with ``reward_fn(prompts[i], token)``.
    The ``top_k`` highest-reward candidates are kept. The "fine-tuning" step
    then adds ``lr * reward`` to the logit of every kept (prompt, token) pair.

    Args:
        policy_logits: 2-D array-like of logits. Row ``i`` is the policy for
            ``prompts[i]``, so it needs at least ``len(prompts)`` rows.
        prompts: Prompts; each is passed unchanged to ``reward_fn``.
        reward_fn: Callable ``reward_fn(prompt, token) -> float``. ``token``
            is a Python ``int`` index into the prompt's logit row.
        n_candidates: Number of tokens sampled per prompt.
        top_fraction: Fraction of candidates kept per prompt. At least one is
            always kept (``top_k = max(1, int(n_candidates * top_fraction))``).
            Values above 1 keep every candidate.
        lr: Step size of the logit update.
        rng: Optional ``np.random.Generator`` (or ``RandomState``) used for
            sampling. When omitted, the global ``np.random`` state is used, so
            results remain reproducible via ``np.random.seed``.

    Returns:
        dict with:
            ``"updated_logits"``: updated copy of the logits. The input is
                never modified; integer inputs are promoted to float.
            ``"mean_reward"``: mean reward of the selected candidates, or
                ``nan`` when nothing was selected.
            ``"n_selected"``: number of selected (prompt, token) pairs.
    """
    logits = np.asarray(policy_logits)
    sampler = np.random if rng is None else rng
    top_k = max(1, int(n_candidates * top_fraction))

    selected = []  # (prompt_idx, token) pairs that survived ranking
    selected_rewards = []

    for i, prompt in enumerate(prompts):
        row = logits[i]
        # Numerically stable softmax. Shifting by the row max keeps np.exp
        # from overflowing to inf, which would produce NaN probabilities.
        # Computed once per prompt rather than once per candidate.
        weights = np.exp(row - row.max())
        probs = weights / weights.sum()

        # Generate candidates by sampling from the policy. tolist() hands
        # reward_fn plain Python ints.
        candidates = sampler.choice(len(probs), size=n_candidates, p=probs).tolist()

        # Score with the reward model.
        rewards = np.array([reward_fn(prompt, c) for c in candidates])

        # Keep the top_k candidates by reward. The stable sort makes
        # tie-breaking deterministic: later candidates win ties.
        for idx in np.argsort(rewards, kind="stable")[-top_k:]:
            selected.append((i, candidates[idx]))
            selected_rewards.append(rewards[idx])

    # Fine-tune: bump the logit of every selected response by lr * reward.
    # astype() always returns a copy, so the caller's array is untouched.
    # Promoting via result_type(..., 1.0) lets integer logits take
    # fractional updates (an in-place float add into an int array raises),
    # while float inputs keep their original precision.
    updated_logits = logits.astype(np.result_type(logits, 1.0))
    # NOTE(review): the step is weighted by the raw reward, so a selected
    # candidate with a negative reward has its logit *decreased*. Plain RAFT
    # fine-tunes on the selected samples without reward weighting; confirm
    # this reward-weighted variant is intended.
    for (prompt_idx, token), reward in zip(selected, selected_rewards):
        updated_logits[prompt_idx, token] += lr * reward

    return {
        "updated_logits": updated_logits,
        # np.mean([]) warns and returns nan; return nan explicitly instead.
        "mean_reward": (
            float(np.mean(selected_rewards)) if selected_rewards else float("nan")
        ),
        "n_selected": len(selected),
    }