Implement MMLU (Massive Multitask Language Understanding) letter-matching evaluation. Given a model's text output and the correct answer letter (A/B/C/D), extract the model's predicted letter from the generated text and check if it matches.
import re
from typing import Dict, List
# Answer-extraction patterns, ordered from most to least specific.  Every
# pattern captures exactly one uppercase option letter (A-D) in group 1.
_ANSWER_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # "The answer is B", "the answer is (B)", "The answer is **B**".
        # The trailing \b stops the first letter of an ordinary word
        # ("The answer is Boyle's law") from being read as the answer.
        r'[Tt]he\s+answer\s+is[\s*]+\(?([A-D])\b',
        # "Answer: B", "answer:(B)", "Answer: **B**", "**Answer:** B".
        r'[Aa]nswer:[\s*]*\(?([A-D])\b',
        # "B) some option text"
        r'\b([A-D])\)\s',
        # "B." as in "I choose B."
        r'\b([A-D])\.',
        # Last resort: any standalone letter.  This also covers an output
        # that consists of nothing but the letter and whitespace.
        r'\b([A-D])\b',
    )
)


def extract_answer_letter(text: str) -> str:
    """Extract the predicted multiple-choice letter from generated text.

    The patterns in ``_ANSWER_PATTERNS`` are tried in order, most specific
    first; the first pattern that matches anywhere in ``text`` decides the
    result (its leftmost match is used).

    Args:
        text: Raw model output. ``None`` or an empty string is treated as
            "no answer".

    Returns:
        One of ``"A"``, ``"B"``, ``"C"``, ``"D"``, or ``""`` when no option
        letter could be found.

    Note:
        Only uppercase letters are recognised.  The standalone-letter
        fallback is a heuristic: it will still pick up the English article
        "A" when nothing more specific matches.
    """
    # A missing generation counts as "no answer" instead of crashing the
    # whole evaluation run inside re.search.
    if not text:
        return ""

    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(text)
        if match:
            # The capture group is restricted to [A-D], so the value is
            # already an uppercase letter.
            return match.group(1)
    return ""
def mmlu_letter_eval(
    model_outputs: List[str],
    correct_answers: List[str]
) -> Dict:
    """Score model generations against gold MMLU answer letters.

    Each generation is passed to ``extract_answer_letter`` and the extracted
    letter is compared with the upper-cased gold answer at the same index.

    Args:
        model_outputs: Generated texts, one per question.
        correct_answers: Gold answer letters, aligned with ``model_outputs``.

    Returns:
        A dict with the keys ``"accuracy"`` (fraction correct, rounded to
        four decimals; ``0.0`` for empty input), ``"correct"`` (number of
        matches), ``"total"`` (number of questions) and ``"predictions"``
        (the extracted letter for every generation, in input order).

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(model_outputs) != len(correct_answers):
        raise ValueError("Lists must have equal length")

    extracted = []
    num_hits = 0
    for generation, gold in zip(model_outputs, correct_answers):
        letter = extract_answer_letter(generation)
        extracted.append(letter)
        # Upper-casing the gold label makes the comparison tolerant of
        # lowercase labels.
        num_hits += int(letter == gold.upper())

    num_items = len(model_outputs)
    if num_items > 0:
        accuracy = round(num_hits / num_items, 4)
    else:
        # Nothing to score; avoid dividing by zero.
        accuracy = 0.0

    return {
        "accuracy": accuracy,
        "correct": num_hits,
        "total": num_items,
        "predictions": extracted,
    }