Implement Stratified Train-Test Split

#275 · Machine Learning · Medium

Problem

Implement a Stratified Train-Test Split — split a dataset into training and test sets while preserving the proportion of each class label in both sets.

Solution

Group samples by class, randomly shuffle each group, and split each group according to the desired ratio.

import random

def stratified_split(
    X: list[list[float]],
    y: list[int],
    test_size: float = 0.2,
    random_seed: int | None = None,
) -> dict:
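    if len(X) != len(y):
        raise ValueError("X and y must contain the same number of samples")

    # Seed the module-level generator so repeated calls give the same split
    if random_seed is not None:
        random.seed(random_seed)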

    # Group indices by class
    class_indices: dict[int, list[int]] = {}
    for i, label in enumerate(y):
        class_indices.setdefault(label, []).append(i)

    train_indices: list[int] = []
    test_indices: list[int] = []

    for label in sorted(class_indices.keys()):
        indices = class_indices[label]
        random.shuffle(indices)
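        # Aim for a test_size fraction of this class, with at least one test sample
        n_test = max(1, round(len(indices) * test_size))
        # Always leave at least one sample for training, so a class with a
        # single sample goes entirely to the training set
        n_test = min(n_test, len(indices) - 1)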
        test_indices.extend(indices[:n_test])
        train_indices.extend(indices[n_test:])

    # Sort so each split keeps the samples in their original order
    train_indices.sort()
    test_indices.sort()

    X_train = [X[i] for i in train_indices]
    y_train = [y[i] for i in train_indices]
    X_test = [X[i] for i in test_indices]
    y_test = [y[i] for i in test_indices]

    return {
        "X_train": X_train,
        "y_train": y_train,
        "X_test": X_test,
        "y_test": y_test,
        "train_indices": train_indices,
        "test_indices": test_indices,
    }
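
One way to sanity-check a split is to compare label proportions before and after. The sketch below is illustrative only: `class_proportions` is a hypothetical helper, and the label lists are hand-written stand-ins for `y`, `y_train`, and `y_test` (they match the counts the per-class rounding gives for 6 and 4 samples at test_size = 0.2, but were not produced by running the solution).

```python
from collections import Counter


def class_proportions(labels: list[int]) -> dict[int, float]:
    """Return the fraction of samples carrying each label."""
    counts = Counter(labels)
    total = len(labels)
    return {label: count / total for label, count in sorted(counts.items())}


# Hand-written stand-ins: 6 samples of class 0 and 4 of class 1,
# with one sample per class moved to the test set.
y = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
y_train = [0, 0, 0, 0, 0, 1, 1, 1]
y_test = [0, 1]

print("full :", class_proportions(y))
print("train:", class_proportions(y_train))
print("test :", class_proportions(y_test))
```

With classes this small the training proportions (0.625 and 0.375) only approximate the original 0.6 and 0.4, and the test set ends up balanced at 0.5 each, because each class contributes a whole number of samples.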

Explanation

  1. Group all sample indices by their class label.
  2. Shuffle each group independently.
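  3. For each class, allocate round(test_size × class size) samples to the test set and the rest to the training set.
  4. Because every class is split at the same ratio, class proportions in both sets approximate the original distribution; rounding makes the match inexact for small classes.
  5. Every class with at least two samples gets at least one sample in both splits; a class with a single sample goes entirely to the training set.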
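
The per-class allocation in the steps above can be isolated into a small function to see how rounding and the two clamps interact. This is a minimal sketch: `per_class_test_count` is a hypothetical name that repeats the two `n_test` lines from the solution, and the class sizes are made up for illustration.

```python
def per_class_test_count(class_size: int, test_size: float) -> int:
    """Number of samples a single class sends to the test set."""
    # At least one test sample, otherwise the rounded share of the class
    n_test = max(1, round(class_size * test_size))
    # Never take the last sample away from the training set
    return min(n_test, class_size - 1)


# (class size, test_size) pairs chosen to hit the edge cases
cases = [(1, 0.2), (2, 0.2), (10, 0.2), (50, 0.2), (5, 0.5)]
for class_size, test_size in cases:
    n_test = per_class_test_count(class_size, test_size)
    print(f"size={class_size:>2} test_size={test_size}: "
          f"{n_test} test, {class_size - n_test} train")
```

A class of size 1 sends nothing to the test set, and a class of size 2 is split one and one even though 20% of 2 rounds to 0. The last case yields 2 test samples rather than 3 because Python's `round` rounds halves to the nearest even integer, so `round(2.5)` is 2.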

Complexity

  • Time: O(n log n) overall; grouping, shuffling, and splitting take O(n), and the final sort of the index lists takes O(n log n)
  • Space: O(n) for storing indices and split data