"""Implement the verification step of speculative decoding.

A draft model generates candidate tokens quickly, and a target model verifies
them. For each draft token, the target and draft probabilities are compared to
decide whether to accept or reject it, so that the resulting tokens follow the
exact target distribution.
"""
import numpy as np
def speculative_decoding_verify(
    draft_tokens: list[int],
    draft_probs: list[np.ndarray],
    target_probs: list[np.ndarray],
    vocab_size: int
) -> list[int]:
    """Verify draft-model tokens against the target model.

    Each draft token is accepted with probability
    ``min(1, p_target / p_draft)``. At the first rejection a replacement
    token is sampled from ``max(0, p_target - p_draft)`` normalized, and
    verification stops. This ensures the overall distribution matches the
    target model exactly.

    If every draft token is accepted and ``target_probs`` holds an extra
    distribution at index ``len(draft_tokens)``, one bonus token is sampled
    from it.

    Randomness comes from NumPy's global ``np.random`` state.

    Args:
        draft_tokens: Token ids proposed by the draft model, in order.
        draft_probs: ``draft_probs[i]`` is the draft distribution at
            position ``i``, indexed by token id.
        target_probs: ``target_probs[i]`` is the target distribution at
            position ``i``; may contain one more entry than ``draft_tokens``
            for the bonus token.
        vocab_size: Number of token ids; passed to ``np.random.choice`` as
            the population size, so it has to match the length of the
            probability vectors.

    Returns:
        The accepted prefix of ``draft_tokens`` followed by at most one
        extra token: the resampled token after a rejection, or the bonus
        token when everything was accepted.

    Raises:
        IndexError: If ``draft_probs`` or ``target_probs`` has fewer entries
            than ``draft_tokens``.
    """
    accepted_tokens = []
    for i, token in enumerate(draft_tokens):
        # Plain Python floats keep the comparison below in double precision
        # even when the distributions are stored as float32.
        p = float(target_probs[i][token])
        q = float(draft_probs[i][token])

        # Accept with probability min(1, p / q). Written as r * q < p so no
        # division (and no epsilon) is needed: since r < 1, a token with
        # p >= q > 0 is always accepted, and q == 0 is handled without
        # special-casing.
        if np.random.random() * q < p:
            accepted_tokens.append(token)
            continue

        # Reject: sample from adjusted distribution max(0, p - q) normalized.
        residual = np.maximum(target_probs[i] - draft_probs[i], 0.0)
        total = residual.sum()
        if total < 1e-10:
            # Degenerate residual (distributions numerically identical):
            # fall back to sampling from the target itself.
            resample_dist = target_probs[i]
        else:
            resample_dist = residual / total
        accepted_tokens.append(int(np.random.choice(vocab_size, p=resample_dist)))
        break  # Stop verifying after first rejection
    else:
        # Reached only when no rejection occurred (including an empty draft).
        # A length comparison is not enough here: a rejection at the last
        # position also yields len(draft_tokens) entries, but must NOT get a
        # bonus token because target_probs[len(draft_tokens)] was conditioned
        # on the rejected draft token.
        bonus_index = len(draft_tokens)
        if len(target_probs) > bonus_index:
            bonus = int(np.random.choice(vocab_size, p=target_probs[bonus_index]))
            accepted_tokens.append(bonus)
    return accepted_tokens