Speculative Decoding End-to-End Simulation

#410 · LLM · Hard

Problem

Implement an end-to-end simulation of speculative decoding, where a small draft model proposes candidate tokens and a larger target model verifies them in parallel. The output follows the same distribution as sampling from the target model directly, but it needs fewer sequential target model passes.

Solution

import numpy as np

def speculative_decode(
    draft_model_fn,
    target_model_fn,
    prompt_tokens: list[int],
    max_tokens: int,
    num_speculative: int,
    vocab_size: int
) -> list[int]:
    generated = list(prompt_tokens)
    tokens_generated = 0

    while tokens_generated < max_tokens:
        # Step 1: Draft model generates K candidate tokens autoregressively
        draft_tokens = []
        draft_probs = []
        context = list(generated)
        for _ in range(num_speculative):
            logits = draft_model_fn(context)
            probs = _softmax(logits)
            token = int(np.random.choice(vocab_size, p=probs))
            draft_tokens.append(token)
            draft_probs.append(probs.copy())
            context.append(token)

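        # Step 2: Target model scores all K+1 positions. A real system would do this
        # in one batched forward pass; this simulation calls target_model_fn once per prefix.
        # Note: target_model_fn's output is used directly as a probability vector
        # (no softmax is applied), so it is assumed to be normalized already.
        target_probs_list = []
        for i in range(len(draft_tokens) + 1):
            ctx = generated + draft_tokens[:i]
            target_probs_list.append(np.asarray(target_model_fn(ctx), dtype=np.float64))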

        # Step 3: Verify each draft token
        num_accepted = 0
        for i in range(len(draft_tokens)):
            token = draft_tokens[i]
            p = target_probs_list[i][token]
            q = draft_probs[i][token]

            r = np.random.random()
            if r < min(1.0, p / (q + 1e-10)):
                generated.append(token)
                num_accepted += 1
                tokens_generated += 1
                if tokens_generated >= max_tokens:
                    break
            else:
                # Reject: sample from adjusted distribution
                adjusted = np.maximum(target_probs_list[i] - draft_probs[i], 0.0)
                total = adjusted.sum()
                if total < 1e-10:
                    new_token = int(np.random.choice(vocab_size, p=target_probs_list[i]))
                else:
                    adjusted = adjusted / total
                    new_token = int(np.random.choice(vocab_size, p=adjusted))
                generated.append(new_token)
                tokens_generated += 1
                break

        # If all accepted and budget remains, sample bonus token
        if num_accepted == len(draft_tokens) and tokens_generated < max_tokens:
            bonus_probs = target_probs_list[len(draft_tokens)]
            bonus_token = int(np.random.choice(vocab_size, p=bonus_probs))
            generated.append(bonus_token)
            tokens_generated += 1

    return generated[len(prompt_tokens):]


def _softmax(logits: np.ndarray) -> np.ndarray:
    logits = np.array(logits, dtype=np.float64)
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()
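
To try the solution locally, a pair of toy models is enough. The sketch below is illustrative only: `VOCAB_SIZE`, `toy_draft_model_fn`, `toy_target_model_fn`, and `run_demo` are hypothetical names, and the models are a made-up bigram table rather than anything from the problem statement. It follows the convention the solution code appears to use: the draft model returns logits, while the target model returns a normalized probability vector.

```python
import numpy as np

VOCAB_SIZE = 8  # hypothetical toy vocabulary size

# Fixed random table of next-token logits indexed by the last token (a toy bigram "model").
_TABLE = np.random.default_rng(0).normal(size=(VOCAB_SIZE, VOCAB_SIZE))


def _last_token(tokens: list[int]) -> int:
    return tokens[-1] if tokens else 0


def toy_draft_model_fn(tokens: list[int]) -> np.ndarray:
    """Return logits: a flattened (temperature 2) copy of the target's logits."""
    return _TABLE[_last_token(tokens)] / 2.0


def toy_target_model_fn(tokens: list[int]) -> np.ndarray:
    """Return a normalized probability vector (assumed to be what the solution expects)."""
    logits = _TABLE[_last_token(tokens)]
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()


def run_demo(decode_fn, seed: int = 0) -> list[int]:
    """Call a speculative_decode-style function with the toy models."""
    np.random.seed(seed)  # the solution draws from NumPy's global RNG
    return decode_fn(
        toy_draft_model_fn,
        toy_target_model_fn,
        prompt_tokens=[1, 2],
        max_tokens=10,
        num_speculative=4,
        vocab_size=VOCAB_SIZE,
    )
```

Calling `run_demo(speculative_decode)` with the function defined above should return a list of exactly 10 token IDs, each in `range(VOCAB_SIZE)`, because every loop iteration adds at least one token and never exceeds `max_tokens`.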

Explanation

  1. Draft phase: the small, fast draft model generates K candidate tokens autoregressively. The code stores each token together with the full draft distribution at that position, since the correction step needs the whole vector.
  2. Verification phase: the large target model scores all K+1 positions. A real system does this in a single batched forward pass; this simulation calls target_model_fn once per prefix instead.
  3. Accept/reject: for each draft token x, accept with probability min(1, p_target(x) / p_draft(x)). On the first rejection, sample a correction token from max(0, p_target - p_draft), normalized to sum to 1, and discard the remaining draft tokens.
  4. Bonus token: if all K tokens are accepted, sample one additional token from the target model's distribution at position K+1.
  5. Guarantee: the output distribution is mathematically identical to sampling directly from the target model. The speedup comes from producing several tokens per target verification pass instead of one.
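
The guarantee in step 5 can be checked for a single position. With target p and draft q, the probability of emitting token x is q(x) * min(1, p(x)/q(x)) plus the rejection probability times the normalized residual, which works out to min(p(x), q(x)) + max(0, p(x) - q(x)) = p(x), since the rejection probability 1 - sum(min(p, q)) equals sum(max(0, p - q)). The sketch below tests this empirically; the distributions `p` and `q` and both function names are made up for illustration.

```python
import numpy as np


def speculative_sample_one(p: np.ndarray, q: np.ndarray, rng: np.random.Generator) -> int:
    """One accept/reject step: draft x ~ q, accept with probability min(1, p[x]/q[x]),
    otherwise resample from the normalized residual max(0, p - q)."""
    x = int(rng.choice(len(q), p=q))
    if rng.random() < min(1.0, p[x] / q[x]):  # q[x] > 0 because x was drawn from q
        return x
    residual = np.maximum(p - q, 0.0)  # nonzero whenever a rejection is possible
    return int(rng.choice(len(p), p=residual / residual.sum()))


def empirical_distribution(p: np.ndarray, q: np.ndarray,
                           n_samples: int = 30_000, seed: int = 0) -> np.ndarray:
    """Frequency of each token over many independent accept/reject steps."""
    rng = np.random.default_rng(seed)
    counts = np.zeros(len(p))
    for _ in range(n_samples):
        counts[speculative_sample_one(p, q, rng)] += 1
    return counts / n_samples


p = np.array([0.5, 0.3, 0.15, 0.05])    # hypothetical target distribution
q = np.array([0.25, 0.25, 0.25, 0.25])  # hypothetical draft distribution
freq = empirical_distribution(p, q)
print(np.round(freq, 3))  # should be close to p, not to q
```

The printed frequencies should track `p` up to sampling noise even though every proposal came from the uniform `q`.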

Complexity

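  • Time: O(K * C_draft + C_target) per iteration when the target scores all K+1 positions in one batched pass, where K is the speculation length and C_draft and C_target are the per-call model costs. This simulation calls the target K+1 times, so it costs O(K * C_draft + (K+1) * C_target) per iteration.
  • Space: O(K * V) for the stored draft and target distributions, where V is the vocabulary size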
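
To connect the per-iteration cost to throughput, it helps to estimate how many tokens one iteration yields. The sketch below rests on a simplifying assumption that is not part of the problem: each draft token is accepted independently with the same probability `alpha`. Under that assumption, an iteration produces the accepted prefix plus one correction or bonus token, for an expected (1 - alpha^(K+1)) / (1 - alpha) tokens. Both function names are hypothetical, and the `max_tokens` cap is ignored.

```python
import numpy as np


def expected_tokens_per_iteration(alpha: float, k: int) -> float:
    """Closed form under the i.i.d.-acceptance assumption (alpha in [0, 1])."""
    if alpha >= 1.0:
        return float(k + 1)  # every draft token accepted, plus the bonus token
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def simulate_tokens_per_iteration(alpha: float, k: int,
                                  n_iters: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same quantity."""
    rng = np.random.default_rng(seed)
    accepts = rng.random((n_iters, k)) < alpha
    # cumprod stays 1 until the first rejection, so its row sum counts leading accepts
    leading_accepts = np.cumprod(accepts, axis=1).sum(axis=1)
    # +1 for the correction token (on rejection) or the bonus token (all accepted)
    return float((leading_accepts + 1).mean())


closed_form = expected_tokens_per_iteration(0.8, 4)
estimate = simulate_tokens_per_iteration(0.8, 4)
print(closed_form, estimate)
```

With a low acceptance rate the yield approaches one token per iteration, at which point the draft model's cost is pure overhead.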