#419 · Inference · Hard
# Solve on deep-ml.com
# Implement a combined token sampling pipeline that applies temperature scaling, top-k filtering, and top-p (nucleus) filtering in sequence. Given raw logits from a language model, apply: (1) temperature scaling, (2) top-k filtering, (3) top-p filtering, (4) softmax to get probabilities, and (5) sample a token.
import math
def combined_sampling(
    logits: list[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0
) -> tuple[int, list[float]]:
    """Sample a token index from raw logits using temperature, top-k and top-p.

    The pipeline runs in this order: temperature scaling, top-k filtering,
    top-p (nucleus) filtering, softmax over the surviving logits, and a
    single draw from the resulting distribution.

    Args:
        logits: Raw, unnormalized scores, one per token. Must be non-empty.
        temperature: Divisor applied to every logit. A value ``<= 0`` selects
            greedy decoding: the argmax is returned with a one-hot
            distribution and no filtering or sampling takes place.
        top_k: When ``0 < top_k < len(logits)``, only logits at least as large
            as the k-th largest are kept (ties at the threshold all survive).
            Any other value disables the filter.
        top_p: When ``< 1.0``, tokens are taken in descending-probability
            order until their cumulative probability reaches ``top_p``; the
            rest are masked. At least one token is always kept.

    Returns:
        ``(token_index, probs)`` where ``probs`` is the final distribution
        over all ``len(logits)`` tokens (masked tokens have probability 0.0).

    Raises:
        ValueError: If ``logits`` is empty.
    """
    import random

    n = len(logits)
    if n == 0:
        raise ValueError("logits must not be empty")

    neg_inf = float("-inf")

    # Step 1: Temperature scaling (non-positive temperature means greedy).
    if temperature <= 0:
        max_idx = max(range(n), key=logits.__getitem__)
        probs = [0.0] * n
        probs[max_idx] = 1.0
        return max_idx, probs
    scaled = [x / temperature for x in logits]

    # Step 2: Top-k filtering -- mask everything below the k-th largest value.
    if 0 < top_k < n:
        threshold = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= threshold else neg_inf for s in scaled]

    # Step 3: Top-p (nucleus) filtering.
    if top_p < 1.0:
        finite = [s for s in scaled if s != neg_inf]
        # With every logit masked there is nothing to rank; leave the logits
        # alone so Step 4 applies its uniform fallback instead of crashing.
        if finite:
            peak = max(finite)
            # Stable sort: equal logits keep their original index order.
            order = sorted(range(n), key=lambda i: -scaled[i])
            weights = [
                math.exp(scaled[i] - peak) if scaled[i] != neg_inf else 0.0
                for i in order
            ]
            total = sum(weights)
            cumulative = 0.0
            keep = set()
            for i, w in zip(order, weights):
                cumulative += w / total
                keep.add(i)
                if cumulative >= top_p:
                    break
            scaled = [s if i in keep else neg_inf for i, s in enumerate(scaled)]

    # Step 4: Softmax over surviving logits (shifted by the max for stability).
    finite = [s for s in scaled if s != neg_inf]
    if not finite:
        # Everything is masked: fall back to a uniform distribution.
        probs = [1.0 / n] * n
    else:
        peak = max(finite)
        exps = [math.exp(s - peak) if s != neg_inf else 0.0 for s in scaled]
        total = sum(exps)
        probs = [e / total for e in exps]

    # Step 5: Inverse-CDF sampling. random.random() lies in [0.0, 1.0), so the
    # comparison is strict and zero-probability tokens are skipped; otherwise
    # r == 0.0 could select a filtered token at the front of the list.
    r = random.random()
    cumulative = 0.0
    last_valid = n - 1
    for i, p in enumerate(probs):
        if p <= 0.0:
            continue
        cumulative += p
        last_valid = i
        if r < cumulative:
            return i, probs
    # Round-off can leave the running sum just below r; return the last token
    # that actually has probability mass rather than blindly using n - 1.
    return last_valid, probs