Implement Byte Pair Encoding (BPE) tokenization. Starting from individual characters, iteratively merge the most frequent pair of adjacent tokens to build a vocabulary of subword units. Then use the learned merges to tokenize new text.
from collections import Counter
def train_bpe(corpus: list[str], num_merges: int) -> list[tuple[str, str]]:
    """Learn an ordered list of BPE merge rules from a list of words.

    Every entry of ``corpus`` is treated as one word: it is split into its
    individual characters and terminated with the ``"</w>"`` end-of-word
    symbol. Entries are not split on whitespace. On each round the most
    frequent adjacent symbol pair (weighted by word count) is fused into a
    single symbol everywhere it occurs.

    Args:
        corpus: Words to learn from; repeated words add to that word's weight.
        num_merges: Upper bound on the number of merge rules to learn.

    Returns:
        The learned merges in the order they were chosen. The list is shorter
        than ``num_merges`` when no adjacent pairs remain. When several pairs
        share the highest count, the one encountered first (word order, then
        position inside the word) is chosen.
    """
    # Map each distinct word, as a tuple of symbols, to its occurrence count.
    word_counts = {}
    for entry in corpus:
        symbols = (*entry, "</w>")
        word_counts[symbols] = word_counts.get(symbols, 0) + 1

    learned = []
    for _ in range(num_merges):
        # Tally every adjacent symbol pair, weighted by the word's count.
        pair_counts = Counter()
        for symbols, count in word_counts.items():
            for left, right in zip(symbols, symbols[1:]):
                pair_counts[(left, right)] += count

        if not pair_counts:
            # Every word has collapsed to a single symbol; nothing to merge.
            break

        # max() keeps the first maximal item, so ties resolve by insertion
        # order of the pairs.
        winner, _count = max(pair_counts.items(), key=lambda item: item[1])
        learned.append(winner)
        fused = winner[0] + winner[1]

        # Rewrite every word, scanning left to right without overlapping.
        rewritten = {}
        for symbols, count in word_counts.items():
            merged = []
            pos = 0
            while pos < len(symbols):
                # A two-symbol slice equal to the winning pair is a match;
                # the one-symbol slice at the end of a word never matches.
                if symbols[pos:pos + 2] == winner:
                    merged.append(fused)
                    pos += 2
                else:
                    merged.append(symbols[pos])
                    pos += 1
            rewritten[tuple(merged)] = count
        word_counts = rewritten

    return learned
def tokenize_bpe(text: str, merges: list[tuple[str, str]]) -> list[str]:
    """Tokenize ``text`` by replaying learned BPE merges in order.

    The whole of ``text`` is treated as a single word: it is split into
    characters and one ``"</w>"`` end-of-word symbol is appended. Whitespace
    is not split, so a space stays an ordinary symbol. Each merge rule is then
    applied once, in the given order, over the entire symbol sequence.

    Args:
        text: The string to tokenize.
        merges: Ordered ``(left, right)`` symbol pairs to fuse.

    Returns:
        The resulting list of subword symbols. An empty ``text`` yields
        ``["</w>"]``.
    """
    symbols = [*text, "</w>"]
    for left, right in merges:
        target = [left, right]
        fused = left + right
        merged = []
        pos = 0
        while pos < len(symbols):
            # The slice is shorter than two items at the last position, so it
            # cannot equal the target there.
            if symbols[pos:pos + 2] == target:
                merged.append(fused)
                pos += 2
            else:
                merged.append(symbols[pos])
                pos += 1
        symbols = merged
    return symbols