Combined Token Sampling Pipeline (Temperature + Top-k + Top-p)

#419 · Inference · Hard

Solve on deep-ml.com

Problem

Implement a combined token sampling pipeline that applies temperature scaling, top-k filtering, and top-p (nucleus) filtering in sequence. Given raw logits from a language model, apply: (1) temperature scaling, (2) top-k filtering, (3) top-p filtering, (4) softmax to get probabilities, and (5) sample a token.

Solution

import math
import random

NEG_INF = float("-inf")


def combined_sampling(
    logits: list[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> tuple[int, list[float]]:
    n = len(logits)
    if n == 0:
        raise ValueError("logits must not be empty")

    # Step 1: Temperature scaling (temperature <= 0 falls back to greedy decoding)
    if temperature <= 0:
        max_idx = max(range(n), key=lambda i: logits[i])
        probs = [0.0] * n
        probs[max_idx] = 1.0
        return max_idx, probs
    scaled = [x / temperature for x in logits]

    # Step 2: Top-k filtering; tokens tied with the k-th largest value are kept too
    if 0 < top_k < n:
        threshold = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= threshold else NEG_INF for s in scaled]

    # Step 3: Top-p (nucleus) filtering over the softmax of the surviving logits
    if top_p < 1.0:
        finite = [s for s in scaled if s != NEG_INF]
        if finite:
            max_val = max(finite)
            order = sorted(range(n), key=lambda i: scaled[i], reverse=True)
            # Unnormalized probabilities in descending order (math.exp(-inf) == 0.0)
            exps = [math.exp(scaled[i] - max_val) for i in order]
            total = sum(exps)
            cumulative = 0.0
            keep = set()
            for i, e in zip(order, exps):
                cumulative += e / total
                keep.add(i)
                if cumulative >= top_p:
                    break
            scaled = [s if i in keep else NEG_INF for i, s in enumerate(scaled)]

    # Step 4: Softmax over the remaining logits, shifted by the max for stability
    valid = [s for s in scaled if s != NEG_INF]
    if not valid:
        # Every logit was -inf: fall back to a uniform distribution
        probs = [1.0 / n] * n
    else:
        max_val = max(valid)
        exps = [math.exp(s - max_val) for s in scaled]  # masked tokens get 0.0
        total = sum(exps)
        probs = [e / total for e in exps]

    # Step 5: Inverse-CDF sampling. The strict "<" means a zero-probability
    # token can never be returned, even if random.random() yields exactly 0.0.
    r = random.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i, probs
    # Round-off can leave the final cumulative sum just below 1.0; fall back
    # to the last token that still has non-zero probability.
    last = max(i for i, p in enumerate(probs) if p > 0)
    return last, probs
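
Step 5 walks the cumulative probabilities linearly. The same inverse-CDF idea can be sketched with the standard library's bisect module, which helps when many draws are taken from one distribution: the prefix sum is built once in O(n) and each draw is then a binary search. bisect_right returns the first index whose cumulative probability is strictly greater than u, so a token whose probability was filtered to zero is never selected, even when u is exactly 0.0. The function name sample_inverse_cdf and the injected rng argument are illustrative assumptions, not part of the problem statement.

```python
import bisect
import itertools
import random


def sample_inverse_cdf(probs: list[float], rng: random.Random) -> int:
    """Return the first index whose cumulative probability is strictly greater than u."""
    cdf = list(itertools.accumulate(probs))  # non-decreasing, so bisect applies
    u = rng.random()  # uniform in [0.0, 1.0)
    idx = bisect.bisect_right(cdf, u)
    if idx == len(probs):
        # Round-off left cdf[-1] slightly below u: fall back to the last non-zero entry
        idx = max(i for i, p in enumerate(probs) if p > 0)
    return idx


rng = random.Random(0)  # seeded so the run is reproducible
probs = [0.0, 0.5, 0.0, 0.3, 0.2]  # hypothetical post-filtering distribution
counts = [0] * len(probs)
for _ in range(10_000):
    counts[sample_inverse_cdf(probs, rng)] += 1
print([c / 10_000 for c in counts])  # empirical frequencies, close to probs
```

The zero-probability entries at indices 0 and 2 never appear in the counts, and the other frequencies track their probabilities up to sampling noise.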

Explanation

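  1. Temperature controls randomness: the logits z are divided by T before the softmax, p_i = exp(z_i / T) / Σ_j exp(z_j / T). T > 1 flattens the distribution (more random); T < 1 sharpens it (more deterministic). As T → 0 sampling approaches greedy (argmax) decoding, which is why the solution treats T ≤ 0 as greedy.
  2. Top-k keeps only the k highest-scoring tokens (plus any tied with the k-th value) and sets the rest to -inf, so they receive zero probability after the softmax.
  3. Top-p (nucleus) sorts the surviving tokens by probability and keeps the smallest high-probability prefix whose cumulative probability reaches or exceeds p. Because the cut-off depends on the shape of the distribution rather than a fixed count, the candidate set shrinks when the model is confident and grows when it is uncertain.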
  4. After filtering, softmax converts the remaining logits to a valid probability distribution.
  5. Sampling uses the inverse CDF: draw u uniformly from [0, 1), walk through the cumulative probabilities, and return the first token whose running total passes u.
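
To make items 1 and 3 concrete, here is a minimal sketch that isolates temperature and top-p (no top-k step) on a hypothetical 4-token logit vector. The helper names softmax_with_temperature and nucleus are illustrative, not part of the problem's API. For three temperatures it prints the probabilities and the indices kept at top_p = 0.8.

```python
import math


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus(probs: list[float], top_p: float) -> list[int]:
    """Indices of the smallest highest-probability prefix whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical model output
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs], nucleus(probs, top_p=0.8))
```

With top_p = 0.8 the nucleus holds one token at T = 0.5 ([0]), two at T = 1.0 ([0, 1]) and three at T = 2.0 ([0, 1, 2]): the same p admits more candidates as temperature flattens the distribution, which is the adaptive behaviour described in item 3.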

Complexity

  • Time: O(n log n) due to sorting for top-k and top-p
  • Space: O(n)
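
The O(n log n) term comes from full sorts. For the top-k step alone a partial selection is enough; a minimal sketch using the standard library's heapq.nlargest, which picks the k largest items in O(n log k), is shown here. top_k_mask is a hypothetical helper, and unlike the solution's threshold test it keeps exactly k tokens when several tie at the k-th value. Top-p still needs its candidates in sorted order, so this only changes the cost of the top-k step.

```python
import heapq


def top_k_mask(logits: list[float], k: int) -> list[float]:
    """Keep the k largest logits and set every other entry to -inf.

    k <= 0 or k >= len(logits) disables the filter, mirroring top_k = 0 above.
    """
    if k <= 0 or k >= len(logits):
        return list(logits)
    # heapq.nlargest selects the k best indices in O(n log k) without a full sort
    keep = set(heapq.nlargest(k, range(len(logits)), key=logits.__getitem__))
    return [x if i in keep else float("-inf") for i, x in enumerate(logits)]


print(top_k_mask([0.1, 2.5, -1.0, 1.7, 0.3], k=2))  # [-inf, 2.5, -inf, 1.7, -inf]
```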