MMLU Letter-Matching Evaluation

#326 · NLP · Medium

Problem

Implement MMLU (Massive Multitask Language Understanding) letter-matching evaluation. Given a model's text output and the correct answer letter (A/B/C/D), extract the model's predicted letter from the generated text and check if it matches.

Solution

import re
from typing import Dict, List

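def extract_answer_letter(text: str) -> str:
    """Return the predicted letter (A-D) found in text, or "" if none is found."""
    # Try patterns from most to least specific. Priority is by pattern order,
    # not by position in the text: an earlier pattern wins wherever it matches.
    patterns = [
        r'[Tt]he\s+answer\s+is\s+\(?([A-D])\)?',  # "The answer is (B)"
        r'[Aa]nswer:\s*\(?([A-D])\)?',            # "Answer: C"
        r'\b([A-D])\)\s',                         # "D) because ..."
        r'\b([A-D])\.',                           # "B. Paris"
        r'^\s*([A-D])\s*$',                       # the whole output is one letter
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            # Every pattern captures an uppercase letter only
            return match.group(1)

    # Fallback: first standalone uppercase A-D anywhere in the text
    match = re.search(r'\b([A-D])\b', text)
    if match:
        return match.group(1)
    # No letter found
    return ""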

def mmlu_letter_eval(
    model_outputs: List[str],
    correct_answers: List[str]
) -> Dict:
    if len(model_outputs) != len(correct_answers):
        raise ValueError("Lists must have equal length")

    correct = 0
    predictions = []
    for output, answer in zip(model_outputs, correct_answers):
        pred = extract_answer_letter(output)
        predictions.append(pred)
        if pred == answer.upper():
            correct += 1

    total = len(model_outputs)
    return {
        "accuracy": round(correct / total, 4) if total > 0 else 0.0,
        "correct": correct,
        "total": total,
        "predictions": predictions
    }

Explanation

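  1. For each model output, try the regex patterns in order, from most specific ("The answer is A") to least specific (any standalone A-D), and take the first pattern that matches. Priority follows pattern order rather than position in the text, so "The answer is A" wins even if "C." appears earlier in the output. If nothing matches, the prediction is an empty string, which counts as incorrect.
  2. Compare each extracted prediction to the correct answer letter, which is uppercased first so that a lowercase key such as "b" still matches.
  3. Return the accuracy (fraction correct, rounded to four decimal places, or 0.0 for empty input), the count of correct predictions, the total number of outputs, and the list of extracted predictions. Input lists of different lengths raise a ValueError.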
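
The steps above can be traced on a few sample outputs. The following is a minimal, self-contained sketch rather than the reference solution: it restates the same patterns in a single list (the name PATTERNS and the sample strings are illustrative assumptions, and the fallback is folded in as the last entry, which gives the same results). It also shows one limitation of the fallback: a sentence that starts with the article "A" is read as answer A.

```python
import re
from typing import Dict, List

# Same regexes as the solution, in the same priority order.
# PATTERNS is an illustrative name; the last entry is the fallback.
PATTERNS = [
    r'[Tt]he\s+answer\s+is\s+\(?([A-D])\)?',
    r'[Aa]nswer:\s*\(?([A-D])\)?',
    r'\b([A-D])\)\s',
    r'\b([A-D])\.',
    r'^\s*([A-D])\s*$',
    r'\b([A-D])\b',
]

def extract_answer_letter(text: str) -> str:
    # Return the capture of the first pattern (by priority) that matches.
    for pattern in PATTERNS:
        match = re.search(pattern, text)
        if match:
            return match.group(1)
    return ""

def mmlu_letter_eval(model_outputs: List[str], correct_answers: List[str]) -> Dict:
    if len(model_outputs) != len(correct_answers):
        raise ValueError("Lists must have equal length")
    predictions = [extract_answer_letter(o) for o in model_outputs]
    correct = sum(p == a.upper() for p, a in zip(predictions, correct_answers))
    total = len(model_outputs)
    return {
        "accuracy": round(correct / total, 4) if total > 0 else 0.0,
        "correct": correct,
        "total": total,
        "predictions": predictions,
    }

# Pattern order beats text position: "the answer is A" wins over the earlier "C."
print(extract_answer_letter("C. Wait, the answer is A"))  # A
# No uppercase A-D letter anywhere: empty prediction
print(repr(extract_answer_letter("I am not sure")))       # ''
# Fallback pitfall: the article "A" is taken as the answer, not "C"
print(extract_answer_letter("A good guess would be C"))   # A

# Hypothetical sample batch; the lowercase key "c" still matches.
result = mmlu_letter_eval(
    ["The answer is (B).", "Answer: C", "I am not sure"],
    ["B", "c", "A"],
)
print(result["accuracy"], result["predictions"])  # 0.6667 ['B', 'C', '']
```

If the fallback misfires too often on free-form outputs, one option is to constrain the prompt so the model ends with a fixed phrase such as "The answer is X", which the first pattern catches.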

Complexity

  • Time: O(B * n), where B is the number of outputs and n is the average output length; each output is scanned by a fixed number of patterns
  • Space: O(B) for the list of predictions