Implement temperature sampling for language model text generation. Given a logits vector from a model, apply temperature scaling and sample from the resulting probability distribution. Higher temperature produces more random outputs; lower temperature makes outputs more deterministic.
import numpy as np
def temperature_sample(
    logits: np.ndarray,
    temperature: float = 1.0,
    top_k: int = 0,
    *,
    rng: "np.random.Generator | None" = None,
) -> int:
    """Sample a token index from ``logits`` after temperature scaling.

    The logits are divided by ``temperature`` and turned into a probability
    distribution with a numerically stable softmax; one index is then drawn
    from that distribution. Higher temperatures flatten the distribution
    (more random output), lower temperatures sharpen it (more deterministic
    output).

    Args:
        logits: Unnormalized scores, one per token. Any array-like of real
            numbers is accepted; it is converted to a flat float64 array so
            that low-precision inputs (e.g. float16) cannot overflow to inf
            when divided by a small temperature.
        temperature: Scaling factor. A value ``<= 0`` selects greedy decoding
            (``argmax`` of the raw logits) with no randomness.
        top_k: If positive, only the ``top_k`` highest-scoring tokens can be
            sampled; a non-positive value disables the filter.
        rng: Optional NumPy ``Generator`` used for the draw, for reproducible
            sampling. When ``None``, NumPy's global random state is used.

    Returns:
        The index of the chosen token as a Python ``int``.

    Raises:
        ValueError: If ``logits`` is empty or ``temperature`` is NaN.
    """
    # float64 avoids overflow/precision loss for float16/float32 model outputs.
    scores = np.asarray(logits, dtype=np.float64).ravel()
    if scores.size == 0:
        raise ValueError("logits must contain at least one value")
    if np.isnan(temperature):
        raise ValueError("temperature must not be NaN")

    # Non-positive temperature: deterministic greedy choice.
    if temperature <= 0:
        return int(np.argmax(scores))

    scaled = scores / temperature

    # Top-k: everything outside the k best gets -inf, i.e. zero probability.
    # When k covers the whole vocabulary the mask would keep everything, so skip it.
    if 0 < top_k < scaled.size:
        keep = np.argpartition(scaled, -top_k)[-top_k:]
        filtered = np.full_like(scaled, -np.inf)
        filtered[keep] = scaled[keep]
        scaled = filtered

    # Stable softmax: subtracting the max keeps exp() from overflowing.
    weights = np.exp(scaled - np.max(scaled))
    probs = weights / weights.sum()

    if rng is None:
        return int(np.random.choice(probs.size, p=probs))
    return int(rng.choice(probs.size, p=probs))
def generate(
    logits_fn,
    prompt_tokens: list[int],
    max_len: int,
    temperature: float = 1.0,
    top_k: int = 0,
) -> list[int]:
    """Autoregressively extend ``prompt_tokens`` by temperature sampling.

    At each step ``logits_fn`` is called with the tokens so far, and its
    output is passed to ``temperature_sample`` to choose the next token.

    Args:
        logits_fn: Callable taking the current token list and returning the
            logits for the next token. It receives the live working list,
            so it should not mutate it.
        prompt_tokens: Initial tokens. Any iterable of ints is accepted and
            copied, so the caller's sequence is never modified.
        max_len: Number of NEW tokens to generate (not the total length);
            non-positive values generate nothing.
        temperature: Forwarded to ``temperature_sample``.
        top_k: Forwarded to ``temperature_sample``; ``0`` disables top-k
            filtering.

    Returns:
        A new list holding the prompt followed by the generated tokens.
    """
    # list() (not slicing) so tuple/other sequence prompts still support append.
    tokens = list(prompt_tokens)
    for _ in range(max_len):
        next_token = temperature_sample(logits_fn(tokens), temperature, top_k=top_k)
        tokens.append(next_token)
    return tokens