Implement a budget-constrained RL loss that subtracts a penalty from the reward whenever the inference cost (e.g., the number of generated tokens) exceeds a given budget. This encourages the model to produce efficient solutions.
def budget_constrained_rl_loss(
    rewards: list[float],
    costs: list[float],
    log_probs: list[float],
    budget: float,
    lambda_cost: float,
) -> float:
    """Compute a policy-gradient loss with a penalty for exceeding a cost budget.

    Each sample's reward is reduced by a hinge penalty on the amount by which
    its inference cost (e.g. number of generated tokens) exceeds ``budget``::

        r_adjusted = reward - lambda_cost * max(0, cost - budget)

    The adjusted rewards are normalized across the batch into advantages
    ``A = (r_adjusted - mean) / std``. The std is the population standard
    deviation. If it falls below 1e-8 it is replaced by 1.0, so a batch of
    equal adjusted rewards yields zero advantages instead of dividing by ~0.
    The loss is the negated batch mean ``-E[A * log_prob]``.

    ``lambda_cost`` controls how strongly the budget constraint is enforced.
    A value of 0 disables the penalty; larger positive values punish budget
    overruns more heavily.

    Args:
        rewards: Per-sample scalar rewards.
        costs: Per-sample inference costs, same length as ``rewards``.
        log_probs: Per-sample log-probabilities, same length as ``rewards``.
        budget: Cost threshold; only the excess above it is penalized.
        lambda_cost: Penalty weight per unit of cost over budget.

    Returns:
        The loss rounded to 6 decimal places, or 0.0 for an empty batch.
        Inputs are plain floats, so the result is a value, not a
        differentiable tensor.

    Raises:
        ValueError: If ``rewards``, ``costs`` and ``log_probs`` differ in length.
    """
    n = len(rewards)
    # zip() would silently truncate to the shortest list while the means below
    # still divide by n, producing a wrong loss; reject mismatches explicitly.
    if len(costs) != n or len(log_probs) != n:
        raise ValueError(
            "rewards, costs and log_probs must have the same length, got "
            f"{n}, {len(costs)} and {len(log_probs)}"
        )
    if n == 0:
        return 0.0

    # Hinge penalty: only the cost in excess of the budget is charged.
    adjusted = [
        r - lambda_cost * max(0.0, c - budget) for r, c in zip(rewards, costs)
    ]

    mean_adj = sum(adjusted) / n
    std_adj = (sum((a - mean_adj) ** 2 for a in adjusted) / n) ** 0.5
    if std_adj < 1e-8:
        # All adjusted rewards are (numerically) equal: avoid dividing by ~0.
        std_adj = 1.0
    advantages = [(a - mean_adj) / std_adj for a in adjusted]

    # Policy-gradient surrogate: -E[advantage * log_prob].
    loss = -sum(a * lp for a, lp in zip(advantages, log_probs)) / n
    return round(loss, 6)