
PTX Loss for Catastrophic Forgetting Prevention (RLHF)

#232 · Deep Learning · Medium

Source: deep-ml.com

Problem

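Implement the PTX (Pretraining Loss) term used in RLHF to reduce catastrophic forgetting. During RL fine-tuning, a fraction of pretraining data is mixed in. The standard language modeling (token-level cross-entropy) loss is computed on that data and added to the RL objective, scaled by a coefficient ptx_coeff. Given the scalar RL loss, the pretraining logits and target token IDs, and ptx_coeff, return the combined loss rounded to 6 decimal places.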
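
Written out, the combined loss uses notation chosen here for illustration, not taken from the original problem. T is seq_len, V is vocab_size, z_{t,v} is pretrain_logits[t][v], y_t is pretrain_targets[t], and γ is ptx_coeff. The bracketed term is the log-probability of the target token, so the parenthesized expression is the mean negative log-likelihood over the sequence:

```latex
\mathcal{L}_{\text{total}}
  = \mathcal{L}_{\text{RL}}
  + \gamma \left( -\frac{1}{T} \sum_{t=1}^{T}
      \Big[ z_{t,y_t} - \log \sum_{v=1}^{V} \exp\big(z_{t,v}\big) \Big] \right)
```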

Solution

import math

def ptx_loss(
    rl_loss: float,
    pretrain_logits: list[list[float]],
    pretrain_targets: list[int],
    ptx_coeff: float,
) -> float:
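    """
    Combined loss = rl_loss + ptx_coeff * LM_loss(pretrain_data)

    pretrain_logits:  [seq_len, vocab_size] unnormalized scores per position
    pretrain_targets: [seq_len] target token IDs, each in [0, vocab_size)
    LM_loss is the mean per-token cross-entropy; the result is rounded to 6 decimals.
    """
    seq_len = len(pretrain_logits)
    if len(pretrain_targets) != seq_len:
        raise ValueError("pretrain_logits and pretrain_targets must have the same length")
    if seq_len == 0:
        # No pretraining tokens, so the PTX term contributes nothing
        return round(rl_loss, 6)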

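    # Cross-entropy language modeling loss on the pretraining tokens
    lm_loss = 0.0
    for t in range(seq_len):
        logits = pretrain_logits[t]
        target = pretrain_targets[t]
        # Reject out-of-range IDs (a negative index would silently wrap around)
        if not 0 <= target < len(logits):
            raise ValueError(f"target {target} at position {t} is outside the vocabulary")
        # Log-softmax via the log-sum-exp trick: subtracting the max logit
        # keeps math.exp from overflowing on large logits
        max_l = max(logits)
        log_sum_exp = max_l + math.log(sum(math.exp(x - max_l) for x in logits))
        log_prob = logits[target] - log_sum_exp
        lm_loss -= log_prob  # accumulate the negative log-likelihood

    lm_loss /= seq_len  # average over token positions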

    total_loss = rl_loss + ptx_coeff * lm_loss
    return round(total_loss, 6)
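
As a hand-checkable sanity check, here is a minimal sketch using illustrative values chosen here (rl_loss = 0.5, ptx_coeff = 0.1, all-zero logits over a 4-token vocabulary). With equal logits the softmax is uniform, so every token's cross-entropy is ln 4 regardless of the target, and the combined loss should be 0.5 + 0.1 · ln 4 ≈ 0.638629. The snippet repeats the loop body of ptx_loss so that it runs on its own.

```python
import math

# Illustrative inputs: all-equal logits give a uniform softmax over 4 tokens
vocab_size = 4
pretrain_logits = [[0.0] * vocab_size, [0.0] * vocab_size]
pretrain_targets = [1, 2]
rl_loss = 0.5
ptx_coeff = 0.1

# Mean per-token negative log-likelihood, same log-sum-exp form as ptx_loss
nll = 0.0
for logits, target in zip(pretrain_logits, pretrain_targets):
    max_l = max(logits)
    log_sum_exp = max_l + math.log(sum(math.exp(x - max_l) for x in logits))
    nll -= logits[target] - log_sum_exp
lm_loss = nll / len(pretrain_targets)

total = round(rl_loss + ptx_coeff * lm_loss, 6)
print(lm_loss, math.log(vocab_size))  # both are ln(4), about 1.386294
print(total)                          # 0.638629
```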

Explanation

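  1. During RLHF, fine-tuning the model on RL rewards can cause it to forget knowledge and capabilities acquired during pretraining (catastrophic forgetting).
  2. The PTX term counteracts this by adding a standard language modeling cross-entropy loss, computed on a small batch of pretraining data, to the RL loss.
  3. For each token position t, compute the log-probability of the target token with a numerically stable log-softmax: log p(y_t) = z[y_t] - logsumexp(z), where logsumexp(z) = max(z) + log Σ_v exp(z_v - max(z)). Subtracting the max keeps exp from overflowing.
  4. Average the negative log-probabilities over the sequence to get L_pretrain, then combine: L_total = L_rl + ptx_coeff * L_pretrain.
  5. A suitable ptx_coeff depends on the relative scale of the two losses; a small value (e.g., 0.01-0.1) regularizes gently without letting the LM term dominate the RL signal.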
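
To see why step 3 subtracts the max logit, the sketch below compares a naive log-softmax with the shifted version on deliberately large logits (values chosen here for illustration). In CPython, math.exp(1000.0) raises OverflowError, while the shifted form only ever exponentiates values ≤ 0. Because log-softmax is invariant to adding a constant to every logit, the stable result equals the log-softmax of [2, 1, 0] at index 0.

```python
import math

logits = [1000.0, 999.0, 998.0]  # large logits, chosen to force overflow
target = 0

# Naive log-softmax: exponentiate the raw logits directly
try:
    naive = logits[target] - math.log(sum(math.exp(x) for x in logits))
except OverflowError:
    naive = None  # math.exp(1000.0) exceeds the double-precision range

# Stable log-softmax: shift by the max so the largest exponent is exp(0) = 1
max_l = max(logits)
log_sum_exp = max_l + math.log(sum(math.exp(x - max_l) for x in logits))
stable = logits[target] - log_sum_exp

print(naive)   # None
print(stable)  # about -0.4076
```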

Complexity

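  • Time: O(seq_len * vocab_size), since each position takes one pass for the max and one for the sum of exponentials
  • Space: O(1) auxiliary; the log-sum-exp is streamed through a generator, so no softmax vector is materialized (the inputs themselves occupy O(seq_len * vocab_size))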