
Implement Speculative Decoding Verification

#394 · Deep Learning · Hard


Problem

Implement the verification step of speculative decoding. A fast draft model proposes candidate tokens, and the target model verifies them. For each draft token, compare its probability under the target model with its probability under the draft model to decide whether to accept or reject it, so that the emitted tokens follow the target model's distribution exactly.
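
To make the input layout concrete, here is a minimal sketch of a draft phase that would produce the function's arguments. It assumes toy "models" that are fixed logit vectors and ignore context, so every position shares one distribution; real models would be re-run on the growing prefix. `softmax` and `toy_draft_phase` are hypothetical helpers, not part of the problem's API.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over a 1-D logit vector
    e = np.exp(logits - logits.max())
    return e / e.sum()

def toy_draft_phase(draft_logits, target_logits, k, rng):
    """Sample k draft tokens and collect both models' distributions.

    Assumption: both "models" are context-free logit vectors, so each
    position reuses the same distribution.
    """
    q = softmax(draft_logits)   # draft distribution
    p = softmax(target_logits)  # target distribution
    draft_tokens, draft_probs = [], []
    for _ in range(k):
        draft_tokens.append(int(rng.choice(len(q), p=q)))
        draft_probs.append(q)   # distribution each draft token was sampled from
    # The target scores the k draft positions plus one extra position
    # (used for the bonus token), giving k + 1 rows.
    target_probs = [p] * (k + 1)
    return draft_tokens, draft_probs, target_probs

rng = np.random.default_rng(0)
tokens, dq, tp = toy_draft_phase(
    np.array([2.0, 1.0, 0.5, 0.1]), np.array([1.5, 1.2, 0.3, 0.4]), k=3, rng=rng
)
print(len(tokens), len(dq), len(tp))  # → 3 3 4
```

The extra target row is what lets the verifier emit a bonus token when every draft token is accepted.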

Solution

import numpy as np

def speculative_decoding_verify(
    draft_tokens: list[int],
    draft_probs: list[np.ndarray],
    target_probs: list[np.ndarray],
    vocab_size: int
) -> list[int]:
    """Verify draft tokens against the target model.

    draft_probs[i] is the draft distribution the i-th draft token was sampled from.
    target_probs[i] is the target distribution at the same position; an optional
    extra row at index len(draft_tokens) is used for the bonus token.
    """
    accepted_tokens = []

    for i, token in enumerate(draft_tokens):
        p = target_probs[i][token]  # target probability of the draft token
        q = draft_probs[i][token]   # draft probability (> 0, since the token was sampled from it)

        # Accept with probability min(1, p / q); written as r * q < p to avoid dividing by q
        r = np.random.random()
        if r * q < p:
            accepted_tokens.append(token)
            continue

        # Reject: resample from the residual distribution max(0, p - q), normalized
        adjusted = np.maximum(target_probs[i] - draft_probs[i], 0.0)
        total = adjusted.sum()
        if total < 1e-10:
            # The residual is empty only when p and q coincide; fall back to the target
            new_token = int(np.random.choice(vocab_size, p=target_probs[i]))
        else:
            new_token = int(np.random.choice(vocab_size, p=adjusted / total))
        accepted_tokens.append(new_token)
        break  # Later draft tokens were conditioned on the rejected token, so discard them
    else:
        # Runs only if the loop finished without a rejection (no break).
        # Checking len(accepted_tokens) instead would wrongly add a bonus token
        # after a rejection at the last position.
        if len(target_probs) > len(draft_tokens):
            bonus = int(np.random.choice(vocab_size, p=target_probs[len(draft_tokens)]))
            accepted_tokens.append(bonus)

    return accepted_tokens
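
The exactness claim can be checked empirically. The sketch below restates only the single-position accept/resample rule as a hypothetical helper `verify_one`, draws many draft tokens from q, and compares the histogram of emitted tokens with the target p. The distributions are made up for illustration; with enough samples the histogram should land close to p, not to q.

```python
import numpy as np

def verify_one(token: int, p: np.ndarray, q: np.ndarray, rng) -> int:
    """Single-position accept/resample rule (hypothetical helper).

    Assumes q[token] > 0, which holds whenever `token` was sampled from q.
    """
    if rng.random() * q[token] < p[token]:  # accept with prob min(1, p/q)
        return token
    residual = np.maximum(p - q, 0.0)       # nonempty whenever a rejection is possible
    return int(rng.choice(len(p), p=residual / residual.sum()))

def empirical_output_dist(p: np.ndarray, q: np.ndarray, n: int, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    counts = np.zeros(len(p))
    for _ in range(n):
        draft_token = int(rng.choice(len(q), p=q))  # draft proposes from q
        counts[verify_one(draft_token, p, q, rng)] += 1
    return counts / n

p = np.array([0.5, 0.3, 0.2])  # target distribution
q = np.array([0.2, 0.3, 0.5])  # draft distribution
print(np.round(empirical_output_dist(p, q, 20_000), 3))
```

Without the resampling step, rejected samples would simply be lost and the emitted tokens would follow min(p, q) renormalized rather than p.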

Explanation

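  1. For each draft token x, draw r uniformly from [0, 1) and accept x if r < min(1, p_target(x) / p_draft(x)), with both probabilities evaluated at x.
  2. If x is accepted, keep it and move on to the next draft token.
  3. If x is rejected, sample a correction token from the residual distribution max(0, p_target - p_draft), normalized to sum to 1, and emit it in place of x. Together with the acceptance rule, this makes each emitted token distributed exactly as if it had been sampled from the target model.
  4. Stop after the first rejection: the remaining draft tokens were generated conditioned on the rejected token, so they no longer extend a valid prefix.
  5. If every draft token is accepted, sample one bonus token from the target distribution at position len(draft_tokens), so a fully accepted draft yields one more token than was proposed.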
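
A short derivation of why steps 1 and 3 together reproduce the target distribution at one position. Write p for the target distribution, q for the draft distribution, x ~ q for the draft token, and y for any token in the vocabulary:

```latex
\begin{aligned}
\Pr[\text{emit } y]
  &= q(y)\,\min\!\Big(1, \tfrac{p(y)}{q(y)}\Big)
   + \Pr[\text{reject}]\cdot
     \frac{\max\big(0,\, p(y)-q(y)\big)}{\sum_z \max\big(0,\, p(z)-q(z)\big)} \\
  &= \min\big(p(y), q(y)\big) + \max\big(0,\, p(y)-q(y)\big) \\
  &= p(y),
\end{aligned}
\qquad\text{using}\qquad
\Pr[\text{reject}] = 1 - \sum_x \min\big(p(x), q(x)\big)
                   = \sum_z \max\big(0,\, p(z)-q(z)\big).
```

The identity for Pr[reject] follows from the sum of p being 1, since p - min(p, q) = max(0, p - q). Each later position is handled the same way conditioned on the tokens already emitted, so the whole output follows the target model.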

Complexity

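  • Time: O(K + V), where K is the number of draft tokens and V is the vocabulary size. Each acceptance test is O(1), and at most one O(V) step runs per call: either the residual resample after a rejection or the bonus sample after full acceptance.
  • Space: O(V) for the residual distribution, plus O(K) for the output list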