Beam Search Decoding

#385 · Machine Learning · Medium

Problem

Implement beam search decoding for a sequence model. Given a function that returns next-token log-probabilities and a beam width k, maintain the top-k most likely partial sequences at each step and return the best complete sequence.
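
To make the interface concrete, here is a minimal sketch of what a next-token scoring function might look like. The name toy_log_prob_fn, the 4-token vocabulary, and the transition table are all hypothetical and chosen only for illustration. A real model would compute these values from the whole sequence rather than from the last token alone.

```python
import numpy as np

# Hypothetical 4-token vocabulary: 0 = start, 1 = "a", 2 = "b", 3 = end.
# TRANSITIONS[i, j] is an assumed probability of token j following token i.
TRANSITIONS = np.array([
    [0.0, 0.6, 0.4, 0.0],   # after start
    [0.0, 0.1, 0.5, 0.4],   # after "a"
    [0.0, 0.3, 0.1, 0.6],   # after "b"
    [0.0, 0.0, 0.0, 1.0],   # after end (a finished beam is never expanded)
])

def toy_log_prob_fn(seq: list[int]) -> np.ndarray:
    """Return log-probabilities over the next token, given the sequence so far."""
    with np.errstate(divide="ignore"):  # log(0) -> -inf is intended here
        return np.log(TRANSITIONS[seq[-1]])
```

The returned array has one entry per vocabulary token, and impossible tokens get a log-probability of -inf, so they can never enter the beam ahead of a feasible candidate.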

Solution

import numpy as np

def beam_search(log_prob_fn, beam_width: int, max_len: int, start_token: int, end_token: int, vocab_size: int) -> list[int]:
    # Each beam: (cumulative log-probability, token sequence)
    beams = [(0.0, [start_token])]
    completed = []

    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            # A beam that ends with end_token is finished: set it aside, do not expand it
            if seq[-1] == end_token:
                completed.append((score, seq))
                continue
            log_probs = log_prob_fn(seq)  # shape: (vocab_size,)
            for token in range(vocab_size):
                # Log-probabilities add along the sequence
                candidates.append((score + float(log_probs[token]), seq + [token]))

        if not candidates:
            # Every beam has finished and is already in completed
            beams = []
            break

        # Keep top-k candidates
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_width]

    # Add any remaining beams (finished on the last step, or cut off at max_len)
    completed.extend(beams)
    completed.sort(key=lambda x: x[0], reverse=True)
    return completed[0][1]

Explanation

  1. Start with a single beam containing only the start token and zero log-probability.
  2. At each step, expand every active beam by all possible next tokens, scoring each extension.
  3. Keep only the top-k (beam width) candidates by cumulative log-probability.
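  4. When a beam ends with the end token, move it to the completed list instead of expanding it further.
  5. Stop after max_len steps, or earlier if no active beams remain, and return the highest-scoring sequence among the completed ones and any beams still active.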
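
Steps 2 and 3 can be traced by hand on a single expansion. The sketch below uses made-up numbers: two active beams, a 3-token vocabulary, and an assumed next_probs table standing in for the model. Token ids are arbitrary placeholders.

```python
import math

beam_width = 2
# Two active beams as (cumulative log-probability, sequence) pairs.
beams = [(math.log(0.6), [0, 1]), (math.log(0.4), [0, 2])]

# Assumed next-token probabilities for each beam (vocabulary of 3 tokens).
next_probs = {
    (0, 1): [0.1, 0.2, 0.7],
    (0, 2): [0.5, 0.4, 0.1],
}

candidates = []
for score, seq in beams:
    for token, p in enumerate(next_probs[tuple(seq)]):
        # Adding log-probabilities corresponds to multiplying probabilities.
        candidates.append((score + math.log(p), seq + [token]))

# Prune: keep the beam_width best of the k * V = 6 candidates.
candidates.sort(key=lambda c: c[0], reverse=True)
beams = candidates[:beam_width]

for score, seq in beams:
    print(seq, round(math.exp(score), 2))
```

The six candidates have probabilities 0.06, 0.12, 0.42, 0.20, 0.16 and 0.04, so the survivors are [0, 1, 2] (0.42) and [0, 2, 0] (0.20). The second survivor comes from the lower-scoring parent beam, a path that greedy decoding would have dropped one step earlier.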

Complexity

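  • Time: O(L * k * V) candidate scores computed, where L is max sequence length, k is beam width, V is vocabulary size. This excludes the cost of the up to L * k calls to log_prob_fn. The code above also fully sorts the k * V candidates at every step, which adds a log(k * V) factor, and copies each sequence when extending it, which adds up to a factor of L.
  • Space: O(k * L) for the beams kept between steps. The candidate list temporarily holds k * V sequences, so peak memory in the code above is O(k * V * L), plus whatever accumulates in completed.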
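
The per-step selection does not need a full sort. As one possible refinement (not part of the solution above), heapq.nlargest from the standard library picks the top k of n candidates in O(n log k). The candidate pool below is randomly generated and purely illustrative.

```python
import heapq
import random

random.seed(0)
k, V = 3, 1000
# Hypothetical candidate pool of (score, sequence) pairs, n = k * V items.
candidates = [(random.uniform(-10.0, 0.0), [i]) for i in range(k * V)]

# Full sort: O(n log n).
by_sort = sorted(candidates, key=lambda c: c[0], reverse=True)[:k]

# Heap-based selection: O(n log k), keeping only k items in the heap.
by_heap = heapq.nlargest(k, candidates, key=lambda c: c[0])

# Both approaches select the same top-k candidates in the same order.
assert by_sort == by_heap
```

When k is much smaller than k * V, which is the usual case, this removes most of the sorting overhead while leaving the selected beams unchanged.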