
Draft-Target Speculative Decoding Simulation

#430 · Machine Learning · Medium

⊣ Solve on deep-ml.com

Problem

Simulate draft-target speculative decoding. A small draft model quickly generates K candidate tokens, and the larger target model then scores all of them in a single forward pass. Tokens are checked from left to right: each one is accepted with probability min(1, p_target / p_draft). At the first rejection, a replacement token is sampled from the normalized residual distribution max(0, p_target - p_draft) at that position, and the remaining draft tokens are discarded.
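To make the accept/reject step concrete, the sketch below computes the exact distribution of the token emitted at a single position under the standard rule: accept the drafted token x with probability min(1, p(x) / q(x)), otherwise resample from the normalized residual max(0, p - q). The function name `emitted_distribution` and the example distributions are illustrative assumptions, not part of the original problem.

```python
def emitted_distribution(p: list[float], q: list[float]) -> list[float]:
    """Exact distribution of the token emitted at one position.

    p: target distribution, q: draft distribution (each sums to 1).
    """
    # P(draft proposes x and it is accepted)
    #   = q(x) * min(1, p(x) / q(x)) = min(p(x), q(x))
    accepted_mass = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_prob = 1.0 - sum(accepted_mass)

    # Residual distribution that a rejection resamples from.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    z = sum(residual)
    if z > 0:
        residual = [r / z for r in residual]

    # Total probability of emitting each token: accepted or resampled.
    return [a + reject_prob * r for a, r in zip(accepted_mass, residual)]


p = [0.5, 0.3, 0.2]  # illustrative target distribution (assumed values)
q = [0.2, 0.5, 0.3]  # illustrative draft distribution (assumed values)
out = emitted_distribution(p, q)
print([round(x, 6) for x in out])  # [0.5, 0.3, 0.2]
```

The result equals `p` up to floating-point error, which is why the scheme samples from the target distribution even though most tokens come from the draft model.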

Solution

import random

def speculative_decode_sim(
    draft_probs: list[list[float]],
    target_probs: list[list[float]],
    draft_tokens: list[int],
    vocab_size: int
) -> dict:
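    """
    Simulate one round of draft-target speculative decoding.

    draft_probs[i]: draft model probability distribution at position i
    target_probs[i]: target model probability distribution at position i
        (an optional extra row at index K is used for the bonus token)
    draft_tokens[i]: token sampled by draft model at position i
    """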
    K = len(draft_tokens)
    accepted = []

    for i in range(K):
        token = draft_tokens[i]
        p = target_probs[i][token]  # target prob
        q = draft_probs[i][token]   # draft prob

        if q == 0:
            # Degenerate input: the draft gave its own token zero probability,
            # so p / q is undefined. Stop here without a correction token.
            break

        acceptance_ratio = min(1.0, p / q)
        r = random.random()
        if r < acceptance_ratio:
            accepted.append(token)
        else:
            # Reject: sample from adjusted distribution
            # p'(x) = max(0, p(x) - q(x)) / Z
            adjusted = [max(0.0, target_probs[i][v] - draft_probs[i][v]) for v in range(vocab_size)]
            z = sum(adjusted)
            if z > 0:
                adjusted = [a / z for a in adjusted]
            else:
                adjusted = [1.0 / vocab_size] * vocab_size
            # Sample from adjusted
            r2 = random.random()
            cum = 0.0
            # Fallback to the last token if rounding keeps cum below r2
            correction_token = vocab_size - 1
            for v in range(vocab_size):
                cum += adjusted[v]
                if r2 < cum:
                    correction_token = v
                    break
            accepted.append(correction_token)
            break

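    else:
        # The loop ended without a break, so all K draft tokens were accepted.
        # (Checking len(accepted) == K is not enough: a rejection at the last
        # position also leaves K tokens, K - 1 accepted plus one correction.)
        # Sample one bonus token from the target at position K, if provided.
        if len(target_probs) > K:
            r = random.random()
            cum = 0.0
            bonus_token = vocab_size - 1  # fallback if rounding keeps cum below r
            for v in range(vocab_size):
                cum += target_probs[K][v]
                if r < cum:
                    bonus_token = v
                    break
            accepted.append(bonus_token)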

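    return {
        "num_draft_tokens": K,
        # Counts every emitted token, including a correction or bonus token
        "num_accepted": len(accepted),
        # Emitted tokens as a fraction of the maximum possible K + 1
        "acceptance_rate": round(len(accepted) / (K + 1), 4) if K > 0 else 0,
        "tokens": accepted,
        # Rough proxy: emitted tokens over 2 steps (1 draft phase + 1 verify pass)
        "speedup_over_sequential": round(len(accepted) / 2, 2)
    }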
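The solution samples both the correction token and the bonus token with the same inverse-CDF loop. One possible refactor is a shared helper, sketched below. The name `sample_categorical` and its optional `r` argument (which lets a caller supply the uniform draw for deterministic testing) are hypothetical and not part of the original solution.

```python
import random
from typing import Optional


def sample_categorical(probs: list[float], r: Optional[float] = None) -> int:
    """Inverse-CDF sampling from a discrete distribution.

    probs: probabilities that should sum to (approximately) 1.
    r: optional uniform draw in [0, 1); drawn with random.random() if omitted.
    """
    if r is None:
        r = random.random()
    cum = 0.0
    for v, pv in enumerate(probs):
        cum += pv
        # Strict comparison: a zero-probability token is never chosen,
        # even when r is exactly 0.0.
        if r < cum:
            return v
    # Rounding can leave the cumulative sum slightly below r;
    # fall back to the last token that has nonzero probability.
    for v in range(len(probs) - 1, -1, -1):
        if probs[v] > 0:
            return v
    return len(probs) - 1


print(sample_categorical([0.2, 0.5, 0.3], r=0.1))   # 0
print(sample_categorical([0.2, 0.5, 0.3], r=0.65))  # 1
print(sample_categorical([0.2, 0.5, 0.3], r=0.95))  # 2
```

Passing `r` explicitly keeps unit tests reproducible without seeding the global random state.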

Explanation

  1. Speculative decoding uses a small, fast draft model to propose K tokens, then verifies them with the large target model in one parallel forward pass.
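  2. Each draft token is accepted with probability min(1, p_target / p_draft), where both probabilities are evaluated at the drafted token. Together with the residual resampling in step 3, this makes the output distribution exactly match the target model.
  3. On rejection, a correction token is sampled from the residual distribution max(0, p_target - p_draft), normalized to sum to 1, and verification stops at that position.
  4. If all K tokens are accepted, a bonus (K+1)th token is sampled from the target model's distribution at the next position, when target_probs provides one.
  5. The simulation reports speedup as emitted tokens / 2, counting the draft phase and the verification pass as one step each. This is a rough proxy: it ignores that the draft model runs K times and that each draft step is cheaper than a target step. Speedups of 2-3x are commonly reported for well-matched draft models.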
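The speedup in step 5 depends on how many tokens a round emits on average. A minimal sketch, under the simplifying assumption that every draft token is accepted independently with the same probability `alpha`: the expected number of emitted tokens (accepted tokens plus one correction or bonus token) is then the geometric sum 1 + alpha + ... + alpha^K. The function name and the `alpha` parameter are illustrative, not part of the original solution.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Expected tokens emitted in one draft-and-verify round.

    Assumes each of the k draft tokens is accepted independently with
    probability alpha (a simplification; real acceptance rates vary by
    position and context).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if k < 0:
        raise ValueError("k must be non-negative")
    if alpha == 1.0:
        # Every draft token is accepted, plus the bonus token.
        return float(k + 1)
    # Geometric sum: 1 + alpha + alpha^2 + ... + alpha^k
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


for alpha in (0.5, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_round(alpha, 4), 3))
```

Dividing this count by the cost of a round, measured in target-model steps, gives an expected speedup. The fixed cost of 2 used in the solution is one simple choice for that denominator.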

Complexity

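  • Time: O(K + V), where K is the draft length and V is the vocabulary size. Each accepted token costs O(1), and the O(V) residual or bonus sampling runs at most once per call
  • Space: O(V) for the adjusted distribution, plus O(K) for the output token list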