Implement top-p (nucleus) sampling for language-model decoding. Given a list of logits (unnormalized scores) from a language model and a threshold p, rank the tokens by probability, keep the smallest set of highest-probability tokens whose cumulative probability mass is at least p, renormalize their probabilities, and sample a token from that restricted distribution.
import numpy as np
def top_p_sampling(
    logits: list[float],
    p: float,
    temperature: float = 1.0,
    *,
    rng: "np.random.Generator | None" = None,
) -> int:
    """Sample a token index using top-p (nucleus) sampling.

    Divide the logits by ``temperature`` and convert them to probabilities
    with a numerically stable softmax. Rank tokens by descending probability
    and keep the smallest prefix whose cumulative probability is at least
    ``p`` (the nucleus). Renormalize the nucleus probabilities to sum to 1
    and draw one index from them.

    Args:
        logits: Unnormalized scores, one per token. ``-inf`` entries are
            allowed and act as masked tokens (probability 0).
        p: Cumulative-probability threshold. ``p <= 0`` keeps only the most
            likely token (greedy); ``p >= 1`` keeps every token.
        temperature: Strictly positive divisor applied to the logits before
            the softmax. Values below 1 sharpen the distribution; values
            above 1 flatten it.
        rng: Optional NumPy ``Generator`` to draw from, for reproducible
            sampling. When omitted, NumPy's global random state
            (``np.random``) is used.

    Returns:
        The index (into ``logits``) of the sampled token.

    Raises:
        ValueError: If ``logits`` is empty or not one-dimensional, if
            ``temperature`` is not strictly positive, or if the logits
            contain NaN / ``+inf`` or are all ``-inf``.
    """
    # `not > 0` also rejects NaN, which would otherwise poison every prob.
    if not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    scaled = np.asarray(logits, dtype=np.float64)
    if scaled.ndim != 1 or scaled.size == 0:
        raise ValueError("logits must be a non-empty 1-D sequence")
    scaled = scaled / temperature

    # Subtract the max before exponentiating to avoid overflow. The max is
    # non-finite when any logit is NaN / +inf or all are -inf; the softmax is
    # undefined in those cases.
    max_logit = scaled.max()
    if not np.isfinite(max_logit):
        raise ValueError(
            "logits must contain at least one finite value and no NaN or +inf"
        )
    probs = np.exp(scaled - max_logit)
    probs /= probs.sum()

    # Stable sort: tied tokens keep ascending-index order, so the nucleus is
    # deterministic.
    order = np.argsort(-probs, kind="stable")
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)

    # searchsorted (side="left") gives the first position i with
    # cumulative[i] >= p, so the nucleus is positions 0..i (i + 1 tokens).
    # Clamp for p above the total, which may round to slightly below 1.0.
    first_covering = int(np.searchsorted(cumulative, p))
    nucleus_size = min(first_covering + 1, sorted_probs.size)
    nucleus = order[:nucleus_size]
    nucleus_probs = sorted_probs[:nucleus_size]
    nucleus_probs = nucleus_probs / nucleus_probs.sum()

    sampler = np.random if rng is None else rng
    return int(sampler.choice(nucleus, p=nucleus_probs))