Implement the BLEU (Bilingual Evaluation Understudy) score for evaluating text generation quality. BLEU compares n-gram overlaps between a candidate translation and one or more reference translations, applying a brevity penalty.
import math
from collections import Counter
from typing import List, Optional
def get_ngrams(tokens: List[str], n: int) -> Counter:
    """Count every contiguous window of ``n`` tokens in ``tokens``.

    Args:
        tokens: The token sequence to scan.
        n: Window size (the n-gram order).

    Returns:
        A Counter mapping each n-gram, stored as a tuple of tokens, to the
        number of times it occurs. The Counter is empty when ``tokens`` has
        fewer than ``n`` items.
    """
    counts = Counter()
    window_count = len(tokens) - n + 1
    for start in range(window_count):
        # Tuples are hashable, so each window can serve as a Counter key.
        window = tuple(tokens[start:start + n])
        counts[window] += 1
    return counts
def bleu_score(
    candidate: List[str],
    references: List[List[str]],
    max_n: int = 4,
    weights: Optional[List[float]] = None,
) -> float:
    """Compute the sentence-level BLEU score of a tokenized candidate.

    The score is ``BP * exp(sum_n w_n * log(p_n))``, where ``p_n`` is the
    modified (clipped) n-gram precision of order ``n`` and ``BP`` is the
    brevity penalty. No smoothing is applied: a zero precision at any order
    that carries a non-zero weight makes the whole score 0.0.

    Args:
        candidate: Tokens of the hypothesis being scored.
        references: One or more tokenized reference translations.
        max_n: Highest n-gram order to include; must be at least 1.
        weights: Per-order weights, ``weights[n - 1]`` belonging to order
            ``n``. Defaults to the uniform ``1 / max_n``. At least ``max_n``
            entries are required; any extra entries are ignored. Orders
            whose weight is 0 are skipped entirely, so they cannot zero out
            the score.

    Returns:
        The BLEU score. With the default weights it lies in ``[0, 1]``.
        Returns 0.0 for an empty candidate, for a candidate too short to
        contain an n-gram of a weighted order, or when a weighted order has
        no n-gram in common with any reference.

    Raises:
        ValueError: If ``max_n`` is less than 1, ``references`` is empty, or
            ``weights`` has fewer than ``max_n`` entries.
    """
    if max_n < 1:
        raise ValueError("max_n must be at least 1")
    if not references:
        raise ValueError("references must contain at least one reference")
    if weights is None:
        weights = [1.0 / max_n] * max_n
    elif len(weights) < max_n:
        raise ValueError("weights must provide a value for every order up to max_n")

    cand_len = len(candidate)
    if cand_len == 0:
        return 0.0

    # Effective reference length: the one closest to the candidate length,
    # with ties broken in favour of the shorter reference.
    ref_len = min(
        (len(ref) for ref in references),
        key=lambda length: (abs(length - cand_len), length),
    )
    # Only candidates shorter than the effective reference are penalized.
    if cand_len < ref_len:
        brevity_penalty = math.exp(1 - ref_len / cand_len)
    else:
        brevity_penalty = 1.0

    log_avg = 0.0
    for n in range(1, max_n + 1):
        weight = weights[n - 1]
        if weight == 0:
            # A zero-weight order contributes nothing to the geometric mean,
            # so it must not be able to force the score to zero either.
            continue

        cand_ngrams = get_ngrams(candidate, n)
        if not cand_ngrams:
            # Candidate is shorter than n tokens: precision is zero.
            return 0.0

        # Counter union keeps, per n-gram, the maximum count seen in any
        # single reference; this is the clipping ceiling.
        max_ref_counts = Counter()
        for ref in references:
            max_ref_counts |= get_ngrams(ref, n)

        # Counter intersection takes the per-n-gram minimum, i.e. the
        # candidate counts clipped to the reference ceiling.
        clipped = sum((cand_ngrams & max_ref_counts).values())
        if clipped == 0:
            # log(0) is undefined; without smoothing the score is zero.
            return 0.0

        total = sum(cand_ngrams.values())
        log_avg += weight * math.log(clipped / total)

    return brevity_penalty * math.exp(log_avg)