"""End-to-end simulation of speculative decoding.

A small draft model speculatively generates candidate tokens, and a larger
target model verifies them in parallel. The result has exactly the same output
distribution as sampling from the target model alone, but needs fewer
sequential target-model forward passes.
"""
import numpy as np
def speculative_decode(
    draft_model_fn,
    target_model_fn,
    prompt_tokens: list[int],
    max_tokens: int,
    num_speculative: int,
    vocab_size: int
) -> list[int]:
    """Sample ``max_tokens`` new tokens via speculative (draft-then-verify) decoding.

    Each round, the draft model proposes up to ``num_speculative`` tokens
    autoregressively. The target model then scores every proposed prefix, and
    the proposals are accepted or corrected by rejection sampling. As a result,
    the returned tokens are distributed exactly as if they had been sampled
    from the target model alone.

    Args:
        draft_model_fn: Callable taking the token context (``list[int]``) and
            returning unnormalized logits over the vocabulary.
        target_model_fn: Callable with the same contract as ``draft_model_fn``
            (context -> logits). Its softmax is the distribution being sampled.
        prompt_tokens: Conditioning tokens; they are not included in the result.
        max_tokens: Number of new tokens to produce; ``<= 0`` returns ``[]``.
        num_speculative: Draft tokens proposed per round, capped at the
            remaining budget. ``<= 0`` degenerates to plain
            one-token-per-round sampling from the target.
        vocab_size: Number of token ids. It must equal the length of the logit
            vectors both model functions return, because ``np.random.choice``
            requires ``p`` to have that length.

    Returns:
        Exactly ``max(max_tokens, 0)`` newly generated token ids.

    Note:
        This simulation calls ``target_model_fn`` once per prefix, i.e. k + 1
        calls per round. A real implementation batches these into a single
        target forward pass, which is where the speed-up comes from.
    """
    generated = list(prompt_tokens)
    tokens_generated = 0
    while tokens_generated < max_tokens:
        # Drafting past the remaining budget would only produce proposals (and
        # target calls) that get thrown away, so cap the proposal length.
        k = min(num_speculative, max_tokens - tokens_generated)

        # Step 1: the draft model proposes k tokens autoregressively. For each
        # one, keep the full distribution q it was sampled from.
        draft_tokens: list[int] = []
        draft_probs: list[np.ndarray] = []
        context = list(generated)
        for _ in range(k):
            probs = _softmax(draft_model_fn(context))
            token = int(np.random.choice(vocab_size, p=probs))
            draft_tokens.append(token)
            draft_probs.append(probs)
            context.append(token)

        # Step 2: compute the target distribution p after each prefix
        # generated + draft_tokens[:i], for i = 0..k. The last entry feeds the
        # bonus token. The target returns logits just like the draft, so its
        # output must be normalized too.
        target_probs = [
            _softmax(target_model_fn(generated + draft_tokens[:i]))
            for i in range(len(draft_tokens) + 1)
        ]

        # Step 3: verify the proposals left to right.
        all_accepted = True
        for i, token in enumerate(draft_tokens):
            p = target_probs[i][token]
            q = draft_probs[i][token]
            # Accept with probability min(1, p_target / p_draft). Because
            # r in [0, 1) and q > 0 for a token that was sampled from q,
            # "r < min(1, p/q)" is equivalent to "r * q < p", which avoids
            # dividing by q and needs no epsilon.
            if np.random.random() * q < p:
                generated.append(token)
                tokens_generated += 1
                continue
            # On rejection, sample a correction from max(0, p_target - p_draft)
            # normalized, and drop all remaining proposals.
            correction = _residual_sample(
                target_probs[i], draft_probs[i], vocab_size
            )
            generated.append(correction)
            tokens_generated += 1
            all_accepted = False
            break

        # Every proposal survived, so the target distribution after the last
        # one is already available: take one extra token from it, budget
        # permitting (k may have been capped to exactly fill the budget).
        if all_accepted and tokens_generated < max_tokens:
            generated.append(
                int(np.random.choice(vocab_size, p=target_probs[-1]))
            )
            tokens_generated += 1

    return generated[len(prompt_tokens):]


def _residual_sample(
    target_probs: np.ndarray, draft_probs: np.ndarray, vocab_size: int
) -> int:
    """Draw a correction token from normalize(max(0, p - q)).

    If the residual mass is numerically zero (p and q essentially coincide, so
    a rejection should have been impossible), fall back to sampling from the
    target distribution p.
    """
    residual = np.maximum(target_probs - draft_probs, 0.0)
    total = residual.sum()
    if total < 1e-10:
        return int(np.random.choice(vocab_size, p=target_probs))
    return int(np.random.choice(vocab_size, p=residual / total))


def _softmax(logits: np.ndarray) -> np.ndarray:
    """Return the softmax of a logit vector as float64 probabilities.

    The maximum logit is subtracted before exponentiating so ``np.exp`` cannot
    overflow. Intended for a 1-D vector: the max and the sum run over all
    elements.
    """
    logits = np.array(logits, dtype=np.float64)
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()