#430 · Machine Learning · Medium
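# Solve on deep-ml.com
# Simulate draft-target speculative decoding. A small draft model quickly generates K candidate tokens; the larger target model then verifies them all in a single forward pass. Tokens are checked from left to right: each draft token x is accepted with probability min(1, p_target(x) / p_draft(x)). At the first rejection, a replacement token is sampled from the target side at that position (the normalized residual max(0, p_target - p_draft)) and the round stops; if all K draft tokens are accepted, one bonus token is sampled from the target model at position K.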
import random
import math
def _sample_index(probs: list[float], r: float) -> int:
    """Return the first index whose cumulative probability reaches ``r``.

    Entries with non-positive probability are skipped, so they are never
    chosen. If floating-point rounding (or a row that does not sum to
    exactly 1) leaves the running total below ``r``, fall back to the last
    index with positive probability instead of defaulting to index 0, which
    may have zero probability. Returns 0 only if no entry is positive.
    """
    cum = 0.0
    fallback = 0
    for idx, prob in enumerate(probs):
        if prob <= 0:
            continue
        cum += prob
        fallback = idx
        if r <= cum:
            return idx
    return fallback


def _residual_sample(
    target_row: list[float], draft_row: list[float], vocab_size: int
) -> int:
    """Sample a correction token from the normalized residual max(0, p - q).

    If the residual is identically zero (p <= q at every token), fall back
    to the target row itself: after a rejection, sampling resumes from the
    target model, and a uniform fallback could pick tokens the target
    assigns zero probability.
    """
    adjusted = [max(0.0, target_row[v] - draft_row[v]) for v in range(vocab_size)]
    z = sum(adjusted)
    weights = [a / z for a in adjusted] if z > 0 else target_row[:vocab_size]
    return _sample_index(weights, random.random())


def speculative_decode_sim(
    draft_probs: list[list[float]],
    target_probs: list[list[float]],
    draft_tokens: list[int],
    vocab_size: int
) -> dict:
    """Simulate one round of draft-target speculative decoding.

    Draft tokens are checked left to right. The token ``x`` at position
    ``i`` is accepted with probability ``min(1, p(x) / q(x))``, where
    ``p = target_probs[i]`` and ``q = draft_probs[i]``. At the first
    rejection, a correction token is sampled from the normalized residual
    ``max(0, p - q)`` and the round stops. A token the draft gave zero
    probability is treated as rejected. Only if every draft token is
    accepted, and ``target_probs`` has a row at index K, one bonus token is
    sampled from ``target_probs[K]``.

    Args:
        draft_probs: draft_probs[i] is the draft model's probability
            distribution at position i.
        target_probs: target_probs[i] is the target model's probability
            distribution at position i. An optional extra row at index
            K = len(draft_tokens) is used for the bonus token.
        draft_tokens: draft_tokens[i] is the token sampled by the draft
            model at position i.
        vocab_size: number of token ids considered when sampling a
            correction or bonus token.

    Returns:
        A dict with keys:
            "num_draft_tokens": K.
            "num_accepted": number of tokens emitted this round (accepted
                draft tokens plus any correction or bonus token).
            "acceptance_rate": num_accepted / (K + 1), rounded to 4 places
                (0 when K == 0).
            "tokens": the emitted token ids, in order.
            "speedup_over_sequential": num_accepted / 2, rounded to 2
                places (one draft pass plus one verify pass).
    """
    K = len(draft_tokens)
    accepted = []
    for i, token in enumerate(draft_tokens):
        p = target_probs[i][token]  # target prob
        q = draft_probs[i][token]  # draft prob
        # The ratio test is only defined for q > 0. A zero-probability draft
        # token falls through to the rejection path below, so this position
        # still yields a correction token instead of nothing.
        if q > 0:
            acceptance_ratio = min(1.0, p / q)
            if random.random() < acceptance_ratio:
                accepted.append(token)
                continue
        # Rejected: emit a correction token and end the round.
        accepted.append(_residual_sample(target_probs[i], draft_probs[i], vocab_size))
        break
    else:
        # Runs only when no draft token was rejected. Checking
        # len(accepted) == K instead would also fire after a rejection at the
        # last position, where the correction token brings the count to K.
        if len(target_probs) > K:
            accepted.append(_sample_index(target_probs[K][:vocab_size], random.random()))
    return {
        "num_draft_tokens": K,
        "num_accepted": len(accepted),
        "acceptance_rate": round(len(accepted) / (K + 1), 4) if K > 0 else 0,
        "tokens": accepted,
        "speedup_over_sequential": round(len(accepted) / 2, 2)  # 1 draft + 1 verify step
    }


# Hints (from the problem statement):
# - The acceptance probability is min(1, p_target / p_draft). Combined with
#   residual resampling on rejection, this makes the output distribution
#   exactly match the target model's.
# - On rejection, resample from max(0, p_target - p_draft), normalized; this
#   guarantees correctness.
# - Speedup = accepted_tokens / (1 draft pass + 1 verify pass). The problem
#   statement notes this can be 2-3x for well-matched draft models.