# Problem #232 · Deep Learning · Medium
# Solve on deep-ml.com
#
# Implement the PTX (Pretraining Loss) term used in RLHF to prevent
# catastrophic forgetting. During RL fine-tuning, a fraction of pretraining
# data is mixed in, and the standard language modeling loss is computed on it
# alongside the RL objective.
import math
def ptx_loss(
    rl_loss: float,
    pretrain_logits: list[list[float]],
    pretrain_targets: list[int],
    ptx_coeff: float,
) -> float:
    """Return the RL loss combined with the PTX (pretraining) loss term.

    The result is ``L_total = L_rl + ptx_coeff * L_pretrain``, where
    ``L_pretrain`` is the mean per-position cross-entropy of
    ``pretrain_logits`` against ``pretrain_targets``. The value is rounded
    to 6 decimal places.

    Args:
        rl_loss: The RL objective's loss value (``L_rl``).
        pretrain_logits: Unnormalized scores, one row per sequence position
            (``[seq_len, vocab_size]``).
        pretrain_targets: Target token ID for each position (``[seq_len]``).
            Entries beyond ``len(pretrain_logits)`` are not read.
        ptx_coeff: Weight applied to the language modeling loss.

    Returns:
        ``round(rl_loss + ptx_coeff * lm_loss, 6)``; when ``pretrain_logits``
        is empty, just ``round(rl_loss, 6)``.

    Raises:
        IndexError: If a target ID is negative or not a valid index into its
            logits row, or if ``pretrain_targets`` has fewer entries than
            ``pretrain_logits``.
        ValueError: If a logits row is empty.
    """
    seq_len = len(pretrain_logits)
    if seq_len == 0:
        # No pretraining tokens: the PTX term contributes nothing.
        return round(rl_loss, 6)

    lm_loss = 0.0
    for t, logits in enumerate(pretrain_logits):
        target = pretrain_targets[t]

        # Log-sum-exp with the row maximum subtracted so exp() cannot
        # overflow for large logits.
        max_l = max(logits)
        log_sum_exp = max_l + math.log(sum(math.exp(x - max_l) for x in logits))

        # A negative ID would otherwise wrap around via Python's negative
        # indexing and silently score the wrong token.
        if target < 0:
            raise IndexError(
                f"pretrain_targets[{t}] = {target} is negative; "
                "token IDs must be >= 0"
            )

        # Negative log-likelihood of the target token at this position.
        lm_loss -= logits[target] - log_sum_exp

    # Average over sequence positions.
    lm_loss /= seq_len
    total_loss = rl_loss + ptx_coeff * lm_loss
    return round(total_loss, 6)