
Temperature Sampling

#378 · Machine Learning · Medium


Problem

Implement temperature sampling for language model text generation. Given a logits vector from a model, apply temperature scaling and sample from the resulting probability distribution. Higher temperature produces more random outputs; lower temperature makes outputs more deterministic.
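For reference, the scaled distribution is a temperature softmax. In the formula below, z denotes the logits vector, V the vocabulary size, and T > 0 the temperature. These symbols are notation introduced for this summary and do not appear in the original problem statement.

```latex
p_i = \frac{\exp(z_i / T)}{\sum_{j=1}^{V} \exp(z_j / T)}
    = \frac{\exp\big((z_i - m)/T\big)}{\sum_{j=1}^{V} \exp\big((z_j - m)/T\big)},
\qquad m = \max_j z_j
```

The second form is the one a numerically stable implementation evaluates. The common factor exp(-m/T) cancels between numerator and denominator, so every p_i is unchanged, but exp can no longer overflow. As T grows, every p_i approaches 1/V. As T approaches 0 from above, all probability mass moves onto the largest logit (assuming it is unique).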

Solution

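```python
import numpy as np
from typing import Optional


def temperature_sample(
    logits: np.ndarray,
    temperature: float = 1.0,
    top_k: int = 0,
    rng: Optional[np.random.Generator] = None,
) -> int:
    """Sample one token index from temperature-scaled logits.

    temperature <= 0 falls back to greedy (argmax) decoding;
    top_k > 0 restricts sampling to the k highest-scoring tokens.
    """
    logits = np.asarray(logits, dtype=np.float64)  # accept lists and int arrays
    if temperature <= 0:
        return int(np.argmax(logits))

    # Apply temperature: > 1 flattens the distribution, < 1 sharpens it
    scaled_logits = logits / temperature

    # Optional top-k filtering: every token outside the k largest gets -inf
    if top_k > 0:
        top_k = min(top_k, scaled_logits.size)
        indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
        masked = np.full_like(scaled_logits, -np.inf)
        masked[indices] = scaled_logits[indices]
        scaled_logits = masked

    # Numerically stable softmax; -inf entries become probability 0
    exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
    probs = exp_logits / exp_logits.sum()

    # Sample with the given Generator, or NumPy's global RNG if none is given
    sampler = rng if rng is not None else np.random
    return int(sampler.choice(probs.size, p=probs))


def generate(
    logits_fn,
    prompt_tokens: list[int],
    max_len: int,
    temperature: float = 1.0,
    top_k: int = 0,
    rng: Optional[np.random.Generator] = None,
) -> list[int]:
    """Append max_len sampled tokens to a copy of prompt_tokens.

    logits_fn maps the current token sequence to a logits vector over the vocabulary.
    """
    tokens = list(prompt_tokens)
    for _ in range(max_len):
        logits = logits_fn(tokens)
        next_token = temperature_sample(logits, temperature, top_k, rng)
        tokens.append(next_token)
    return tokens
```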
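You can sanity-check a sampler like this one by drawing many samples and comparing the observed frequencies with the softmax probabilities. The sketch below is self-contained, so it re-implements the temperature softmax as a hypothetical helper, `softmax_with_temperature`. It draws in bulk with `numpy.random.Generator.choice` instead of calling `temperature_sample`. To test the solution itself, call `temperature_sample` in a loop with a seeded Generator instead.

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    """Temperature-scaled softmax (hypothetical helper; assumes temperature > 0)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    e = np.exp(z - z.max())
    return e / e.sum()


def empirical_frequencies(logits, temperature, n_samples, seed=0):
    """Draw n_samples token indices and return the fraction of draws per index."""
    rng = np.random.default_rng(seed)
    probs = softmax_with_temperature(logits, temperature)
    draws = rng.choice(probs.size, size=n_samples, p=probs)
    return np.bincount(draws, minlength=probs.size) / n_samples


logits = [2.0, 1.0, 0.0]
expected = softmax_with_temperature(logits, temperature=1.0)
observed = empirical_frequencies(logits, temperature=1.0, n_samples=100_000)
print("expected:", np.round(expected, 3))
print("observed:", np.round(observed, 3))
```

With 100,000 draws, the observed frequencies should agree with the expected probabilities to roughly two decimal places.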

Explanation

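  1. Divide the logits by the temperature. A high temperature (> 1) flattens the distribution toward uniform, and a low temperature (< 1) sharpens it toward the highest-scoring token.
  2. Optionally apply top-k filtering: set every logit outside the k largest to -inf, so only those k tokens can be sampled.
  3. Apply a numerically stable softmax (subtract the maximum before exponentiating) to get a valid probability distribution.
  4. Sample a token index from that distribution. As the temperature approaches 0, the distribution collapses onto the argmax, so the code treats temperature <= 0 as greedy (argmax) decoding.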
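To see steps 1 to 3 numerically, the sketch below computes the sampling distribution for a few temperatures without making the final draw. It also reports each distribution's Shannon entropy, which should rise as the temperature increases. `temperature_probs` and `entropy` are hypothetical helpers written for this illustration, and the logits are made-up example values.

```python
import numpy as np


def temperature_probs(logits, temperature, top_k=0):
    """Distribution produced by temperature scaling + optional top-k (assumes temperature > 0)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        k = min(top_k, z.size)
        keep = np.argpartition(z, -k)[-k:]      # indices of the k largest scaled logits
        masked = np.full_like(z, -np.inf)
        masked[keep] = z[keep]
        z = masked
    e = np.exp(z - z.max())                     # exp(-inf) == 0 for filtered tokens
    return e / e.sum()


def entropy(probs):
    """Shannon entropy in nats, ignoring zero-probability entries."""
    nz = probs[probs > 0]
    return float(-(nz * np.log(nz)).sum())


logits = [2.0, 1.0, 0.5, -1.0]
for T in (0.5, 1.0, 2.0):
    p = temperature_probs(logits, T)
    print(f"T={T}: probs={np.round(p, 3)}, entropy={entropy(p):.3f}")

print("top_k=2, T=1.0:", np.round(temperature_probs(logits, 1.0, top_k=2), 3))
```

The output should show entropy increasing with temperature. The top-k row should put zero mass on the two lowest-scoring tokens and renormalize the remaining two.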

Complexity

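  • Time: O(V) per sampled token, where V is the vocabulary size. Scaling, the np.argpartition top-k selection, the softmax, and the draw are each linear in V. generate calls the sampler max_len times, so it costs O(max_len · V) plus the cost of the max_len calls to logits_fn.
  • Space: O(V) for the scaled logits, the top-k mask, and the probability vector.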