Byte Pair Encoding (BPE) Tokenizer

#380 · NLP · Medium

Problem

Implement Byte Pair Encoding (BPE) tokenization. Starting from individual characters, iteratively merge the most frequent pair of adjacent tokens to build a vocabulary of subword units. Then use the learned merges to tokenize new text.

Solution

from collections import Counter

def train_bpe(corpus: list[str], num_merges: int) -> list[tuple[str, str]]:
    # Initialize: split each word into characters with end-of-word marker
    vocab = {}
    for word in corpus:
        chars = tuple(list(word) + ["</w>"])
        vocab[chars] = vocab.get(chars, 0) + 1

    merges = []
    for _ in range(num_merges):
        # Count all adjacent pairs
        pairs = Counter()
        for word, freq in vocab.items():
            for i in range(len(word) - 1):
                pairs[(word[i], word[i + 1])] += freq

        if not pairs:
            break

        # Find the most frequent pair (ties go to the pair counted first)
        best = max(pairs, key=pairs.get)
        merges.append(best)

        # Merge the best pair in all words
        new_vocab = {}
        for word, freq in vocab.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == best[0] and word[i + 1] == best[1]:
                    new_word.append(best[0] + best[1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_vocab[tuple(new_word)] = freq
        vocab = new_vocab

    return merges

def tokenize_bpe(text: str, merges: list[tuple[str, str]]) -> list[str]:
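    # Treat the input as one word: characters plus the end-of-word marker
    tokens = list(text) + ["</w>"]
    # Apply each merge rule in the order it was learned
    for a, b in merges: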
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                new_tokens.append(a + b)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        tokens = new_tokens
    return tokens

Explanation

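  1. Initialize the vocabulary by splitting each word into individual characters plus an end-of-word marker ("</w>"), and record how often each word occurs.
  2. Count all adjacent token pairs across the corpus, weighting each pair by the frequency of the word it appears in.
  3. Merge the most frequent pair into a single new token everywhere it occurs, scanning each word left to right.
  4. Repeat for the requested number of merges, or stop early once no pairs remain, recording each merged pair as a merge rule.
  5. To tokenize new text, split it into characters plus the end-of-word marker, then apply the merge rules in the order they were learned.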
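
To make steps 1 to 3 concrete, the following sketch traces a single merge on a three-word toy corpus. The helper names count_pairs and merge_pair are hypothetical and do not appear in the solution above; they only isolate the counting and merging logic so the intermediate values can be inspected.

```python
from collections import Counter

def count_pairs(vocab):
    """Count adjacent token pairs, weighted by word frequency (step 2).

    Hypothetical helper, not part of the solution above.
    """
    pairs = Counter()
    for word, freq in vocab.items():
        for left, right in zip(word, word[1:]):
            pairs[(left, right)] += freq
    return pairs

def merge_pair(word, pair):
    """Replace each non-overlapping occurrence of `pair` in `word` (step 3).

    Hypothetical helper, not part of the solution above.
    """
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
            merged.append(word[i] + word[i + 1])
            i += 2
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)

# Step 1: characters plus the end-of-word marker, keyed by word frequency.
corpus = ["low", "low", "lower"]
vocab = Counter(tuple(word) + ("</w>",) for word in corpus)

pairs = count_pairs(vocab)
# ("l", "o") and ("o", "w") both occur 3 times; max keeps the first one counted.
best = max(pairs, key=pairs.get)
new_vocab = {merge_pair(word, best): freq for word, freq in vocab.items()}

print(best)       # ('l', 'o')
print(new_vocab)  # {('lo', 'w', '</w>'): 2, ('lo', 'w', 'e', 'r', '</w>'): 1}
```

The tie between ("l", "o") and ("o", "w") shows why the merge order depends on iteration order: a different tie-breaking rule would learn a different first merge rule from the same corpus.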

Complexity

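  • Time: O(num_merges * V * L) for training, where V is the number of distinct words in the corpus and L is the average token sequence length per word; tokenizing a text of n characters takes O(len(merges) * n)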
  • Space: O(V * L) for the vocabulary
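
The bounds above treat each word as its own token sequence, but tokenize_bpe as written takes the whole input as a single word, so spaces become ordinary characters. For text with several words, one option is to split on whitespace first and apply the rules to each word. This is an assumption about how one might extend the solution, not part of it. The names apply_merges and tokenize_text are hypothetical, and the merge rules are hand-written for illustration.

```python
def apply_merges(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Apply merge rules in learned order to one word (same logic as tokenize_bpe)."""
    tokens = list(word) + ["</w>"]
    for a, b in merges:
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

def tokenize_text(text: str, merges: list[tuple[str, str]]) -> list[str]:
    """Hypothetical wrapper: split on whitespace, then tokenize each word."""
    tokens = []
    for word in text.split():
        tokens.extend(apply_merges(word, merges))
    return tokens

# Hand-written rules for illustration; in practice they would come from train_bpe.
merges = [("l", "o"), ("lo", "w"), ("low", "</w>")]
result = tokenize_text("low lower", merges)
print(result)  # ['low</w>', 'low', 'e', 'r', '</w>']
```

Each rule costs one pass over the word's current tokens, so the total work stays proportional to the number of rules times the number of characters in the text.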