Implement beam search decoding for a sequence model. Given a function that returns next-token log-probabilities and a beam width k, maintain the top-k most likely partial sequences at each step and return the best complete sequence.
import numpy as np
def beam_search(log_prob_fn, beam_width: int, max_len: int, start_token: int, end_token: int, vocab_size: int) -> list[int]:
    """Decode the most likely token sequence with beam search.

    At every step each unfinished hypothesis is extended by candidate tokens
    and all extensions are ranked by summed log-probability. The
    ``beam_width`` best *unfinished* extensions become the next beam. Any
    finished extension (one ending in ``end_token``) ranked above the last
    kept unfinished one is moved to the completed set. Finished hypotheses
    therefore never occupy a beam slot.

    Args:
        log_prob_fn: Callable taking the current token list (starting with
            ``start_token``) and returning a 1-D sequence of next-token
            log-probabilities with at least ``vocab_size`` entries.
        beam_width: Number of unfinished hypotheses kept per step; must be >= 1.
        max_len: Maximum number of tokens generated after ``start_token``.
        start_token: Token every hypothesis starts with. It is never treated
            as a terminator, so it may equal ``end_token``.
        end_token: Token that marks a hypothesis as complete.
        vocab_size: Number of candidate tokens (ids ``0 .. vocab_size - 1``).

    Returns:
        The highest-scoring completed sequence, including ``start_token`` and
        the trailing ``end_token``. If no hypothesis finished within
        ``max_len`` steps, the highest-scoring unfinished hypothesis is
        returned instead.

    Raises:
        ValueError: If ``beam_width < 1``, or if ``log_prob_fn`` returns
            something other than a 1-D vector with at least ``vocab_size``
            scores.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    # Each hypothesis: (summed log-probability, token sequence).
    beams: list[tuple[float, list[int]]] = [(0.0, [start_token])]
    completed: list[tuple[float, list[int]]] = []
    # A hypothesis has at most one finished extension (the end_token one).
    # Its top beam_width + 1 tokens therefore contain every extension that
    # can survive the global cut below. The remaining tokens are skipped
    # without copying their sequences.
    per_beam = beam_width + 1

    for _ in range(max_len):
        candidates: list[tuple[float, list[int]]] = []
        for score, seq in beams:
            log_probs = np.asarray(log_prob_fn(seq), dtype=np.float64)
            if log_probs.ndim != 1 or log_probs.shape[0] < vocab_size:
                raise ValueError(
                    f"log_prob_fn must return a 1-D vector of at least "
                    f"{vocab_size} scores, got shape {log_probs.shape}"
                )
            log_probs = log_probs[:vocab_size]
            # Stable sort: highest log-prob first, ties broken by lower token id.
            for token in np.argsort(-log_probs, kind="stable")[:per_beam]:
                candidates.append((score + float(log_probs[token]), seq + [int(token)]))
        if not candidates:
            break

        # Python's sort is stable, so equal scores keep beam/token order.
        candidates.sort(key=lambda cand: cand[0], reverse=True)
        next_beams: list[tuple[float, list[int]]] = []
        for score, seq in candidates:
            if seq[-1] == end_token:
                completed.append((score, seq))
            else:
                next_beams.append((score, seq))
                if len(next_beams) == beam_width:
                    break
        beams = next_beams
        if not beams:
            # Every surviving extension finished; nothing left to expand.
            break

    # Prefer finished hypotheses; fall back to partial ones only if none finished.
    pool = completed or beams
    return max(pool, key=lambda hyp: hyp[0])[1]