"""Exact-match scoring function with text normalization.

Given a predicted answer and a reference answer, normalize both (lowercase, remove punctuation, remove the articles "a", "an" and "the", and collapse extra whitespace), then check whether the normalized strings match exactly.
"""
import re
import string
def normalize_text(text: str) -> str:
    """Return a canonical form of *text* for exact-match comparison.

    The steps are applied in this order:

    1. Lowercase the text.
    2. Delete every character in ``string.punctuation`` (ASCII only, so
       characters such as curly quotes are left in place).
    3. Replace the standalone words "a", "an" and "the" with a space.
    4. Collapse whitespace runs to single spaces and trim both ends.

    Punctuation is deleted, not replaced by a space, so ``"well-known"``
    becomes ``"wellknown"``. Because this happens before article removal,
    ``"a.m."`` becomes ``"am"`` instead of losing its leading "a".
    """
    text = text.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    # \b keeps the match to whole words: "theory" and "another" survive.
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    # split() with no arguments also discards leading/trailing whitespace,
    # so the joined result needs no further strip().
    return ' '.join(text.split())
def exact_match_score(prediction: str, reference: str) -> float:
    """Score one prediction against its reference answer.

    Both strings are passed through ``normalize_text`` and the results are
    compared for equality.

    Returns:
        ``1.0`` if the normalized strings are equal, otherwise ``0.0``.
    """
    if normalize_text(prediction) == normalize_text(reference):
        return 1.0
    return 0.0
def batch_exact_match(
    predictions: list,
    references: list
) -> dict:
    """Score paired predictions and references and summarize the results.

    Args:
        predictions: Predicted answers.
        references: Reference answers, aligned by position with
            ``predictions``.

    Returns:
        A dict with the keys:

        * ``"scores"``: the ``exact_match_score`` result for each pair.
        * ``"average"``: mean of ``scores``; ``0.0`` when there are no pairs.
        * ``"total_matches"``: sum of ``scores`` converted to ``int``.
        * ``"total"``: number of pairs scored.

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(predictions) != len(references):
        raise ValueError("Lists must have equal length")

    scores = [exact_match_score(p, r) for p, r in zip(predictions, references)]
    total = len(scores)
    matched = sum(scores)
    return {
        "scores": scores,
        # Guard the division so an empty batch reports 0.0 instead of raising.
        "average": matched / total if total else 0.0,
        "total_matches": int(matched),
        "total": total,
    }