Top-p (Nucleus) Sampling

#383 · LLM · Medium

Problem

Implement top-p (nucleus) sampling for language model decoding. Given a list of logits (unnormalized scores) from a language model and a threshold p, select the smallest set of tokens whose cumulative probability mass is at least p, then sample from that restricted distribution.
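
To make the selection rule concrete, here is a minimal brute-force sketch that works directly on already-normalized probabilities. The function name nucleus_by_loop is hypothetical. It is meant as a readable reference (for example, a test oracle for a vectorized version), not as the deep-ml solution. It visits tokens from most to least probable and stops as soon as the running mass reaches p.

```python
def nucleus_by_loop(probs: list[float], p: float) -> list[int]:
    """Smallest set of token ids whose total probability is at least p."""
    # Visit token ids from most to least probable
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    chosen, mass = [], 0.0
    for i in order:
        chosen.append(i)
        mass += probs[i]
        if mass >= p:  # stop as soon as the threshold is reached
            break
    return chosen


probs = [0.05, 0.5, 0.15, 0.3]
print(nucleus_by_loop(probs, 0.9))  # [1, 3, 2]
```

With p = 0.9 the running mass goes 0.5, 0.8, 0.95, so the nucleus is tokens 1, 3 and 2. Sampling then uses their probabilities renormalized by 0.95, and token 0 can never be chosen.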

Solution

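```python
import numpy as np

def top_p_sampling(logits: list[float], p: float, temperature: float = 1.0) -> int:
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Temperature-scaled softmax; subtracting the max keeps exp() from overflowing
    logits = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(logits - np.max(logits))
    probs = probs / probs.sum()

    # Rank tokens from most to least probable and accumulate their mass
    sorted_indices = np.argsort(-probs)
    sorted_probs = probs[sorted_indices]
    cumulative_probs = np.cumsum(sorted_probs)

    # searchsorted (side="left") gives the first i with cumulative_probs[i] >= p,
    # so positions 0..i form the smallest set whose mass is at least p.
    # min() clamps the cutoff if float round-off leaves the total just below p.
    cutoff_index = np.searchsorted(cumulative_probs, p)
    cutoff_index = min(cutoff_index + 1, len(sorted_probs))

    # Renormalize over the nucleus and sample one token id from it
    top_indices = sorted_indices[:cutoff_index]
    top_probs = probs[top_indices]
    top_probs = top_probs / top_probs.sum()

    return int(np.random.choice(top_indices, p=top_probs))
```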
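
The solution draws from NumPy's global random state, which makes results hard to reproduce in tests. A possible variant, sketched below under the assumption that the caller can pass a NumPy Generator, applies the same nucleus rule but samples with Generator.choice. The function name top_p_sampling_rng and its rng parameter are hypothetical; np.random.default_rng and Generator.choice are standard NumPy APIs.

```python
import numpy as np

def top_p_sampling_rng(
    logits: list[float],
    p: float,
    rng: np.random.Generator,
    temperature: float = 1.0,
) -> int:
    """Same nucleus rule as top_p_sampling, but draws from a caller-supplied Generator."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    order = np.argsort(-probs)
    cumulative = np.cumsum(probs[order])
    # First position where the running mass reaches p, clamped for float round-off
    cutoff = min(int(np.searchsorted(cumulative, p)) + 1, len(order))

    nucleus = order[:cutoff]
    weights = probs[nucleus] / probs[nucleus].sum()
    return int(rng.choice(nucleus, p=weights))


rng = np.random.default_rng(0)
draws = [top_p_sampling_rng([2.0, 1.0, 0.1, -3.0], p=0.9, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))  # [0, 1, 2]: token 3 lies outside the 0.9 nucleus
```

For these logits the sorted probabilities are roughly 0.656, 0.241, 0.098 and 0.004, so the cumulative mass first reaches 0.9 at the third token and the least likely token is never drawn. Seeding the Generator makes the sequence of draws repeatable across runs.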

Explanation

  1. Scale logits by temperature, then convert to probabilities via softmax (subtract max for numerical stability).
  2. Sort tokens by descending probability.
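  3. Compute cumulative probabilities and find the cutoff index where the cumulative mass first reaches or exceeds p (np.searchsorted with its default side="left"); keep every token up to and including that index.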
  4. Restrict the distribution to only those top tokens, renormalize, and sample.
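
Steps 1 to 3 can be inspected in isolation with a small helper that returns the nucleus instead of sampling from it. The name nucleus_indices is hypothetical, and it assumes temperature > 0. The example shows two effects: lowering the temperature sharpens the softmax and shrinks the nucleus, and the side="left" search stops exactly at a token whose cumulative mass equals p.

```python
import numpy as np

def nucleus_indices(logits: list[float], p: float, temperature: float = 1.0) -> list[int]:
    """Steps 1-3 only: return the nucleus token ids, most probable first (temperature > 0)."""
    scaled = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    order = np.argsort(-probs)
    cumulative = np.cumsum(probs[order])
    # side="left": first i with cumulative[i] >= p, so an exact hit at p stops there
    cutoff = min(int(np.searchsorted(cumulative, p, side="left")) + 1, len(order))
    return order[:cutoff].tolist()


logits = [2.0, 1.0, 0.1, -3.0]
print(nucleus_indices(logits, 0.9))                   # [0, 1, 2]
print(nucleus_indices(logits, 0.9, temperature=0.5))  # [0, 1]
print(len(nucleus_indices([0.0] * 4, 0.5)))           # 2
```

In the uniform case the cumulative mass is exactly 0.25, 0.5, 0.75, 1.0, so two tokens already reach p = 0.5; searching with side="right" instead would keep a third, unnecessary token.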

Complexity

  • Time: O(V log V) where V is the vocabulary size (due to sorting)
  • Space: O(V) for the probability and index arrays