Find the Best Gini-Based Split for a Binary Decision Tree

#138 · Machine Learning · Medium

Problem

Find the best binary split for a decision tree node using the Gini impurity criterion. Given feature values and labels, evaluate all possible splits and return the one with the lowest weighted Gini impurity.

Solution

import numpy as np

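def gini_impurity(labels: np.ndarray) -> float:
    """Gini impurity 1 - sum(p_i^2) of a label array; an empty array counts as pure."""
    if len(labels) == 0:
        return 0.0
    # Only the per-class counts are needed; the class values themselves are unused.
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / len(labels)
    return 1.0 - np.sum(probs ** 2)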

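def best_gini_split(X: np.ndarray, y: np.ndarray) -> dict:
    """Search every feature of a 2-D array X for the threshold with the lowest weighted Gini.

    If no feature has two distinct values, "feature" and "threshold" stay None
    and "gini" stays inf.
    """
    n_samples, n_features = X.shape
    best_gini = float('inf')
    best_feature = None
    best_threshold = None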

    for feature_idx in range(n_features):
        values = X[:, feature_idx]
        sorted_unique = np.unique(values)

        # Candidate thresholds: midpoints between consecutive unique values
        thresholds = (sorted_unique[:-1] + sorted_unique[1:]) / 2.0

        for threshold in thresholds:
            left_mask = values <= threshold
            right_mask = ~left_mask

            n_left = np.sum(left_mask)
            n_right = np.sum(right_mask)

            if n_left == 0 or n_right == 0:
                continue

            gini_left = gini_impurity(y[left_mask])
            gini_right = gini_impurity(y[right_mask])

            weighted_gini = (n_left * gini_left + n_right * gini_right) / n_samples

            if weighted_gini < best_gini:
                best_gini = weighted_gini
                best_feature = feature_idx
                best_threshold = threshold

    return {
        "feature": best_feature,
        "threshold": best_threshold,
        "gini": round(best_gini, 4)
    }

Explanation

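  1. Gini impurity is the probability that a sample drawn at random from the node would be mislabeled if it were assigned a label at random according to the node's class distribution: 1 - sum(p_i^2), where p_i is the fraction of samples in class i. A pure node has impurity 0.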
  2. For each feature, compute candidate thresholds as midpoints between consecutive sorted unique values.
  3. For each threshold, split data into left (<=) and right (>) subsets.
  4. Compute the weighted Gini impurity: (n_left * gini_left + n_right * gini_right) / n_total.
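  5. Return the feature index and threshold that minimize the weighted Gini, together with that Gini value rounded to 4 decimals. Because the comparison is a strict <, ties keep the first split found.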
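
The steps above can be traced by hand on a tiny dataset. The sketch below is an illustration only: the four-sample data is made up, and it uses plain Python lists with collections.Counter instead of NumPy so that it runs on its own. It lists the midpoint thresholds, the weighted Gini at each, and the threshold that wins.

```python
from collections import Counter

def gini(labels):
    """Gini impurity 1 - sum(p_i^2); an empty list counts as pure (0.0)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

# Hypothetical toy data: one feature, four samples, two classes.
values = [1.0, 2.0, 3.0, 4.0]
labels = [0, 0, 1, 1]

# Step 2: midpoints between consecutive sorted unique values.
unique_sorted = sorted(set(values))
thresholds = [(a + b) / 2 for a, b in zip(unique_sorted, unique_sorted[1:])]

# Steps 3 and 4: split on each threshold and compute the weighted Gini.
results = {}
for t in thresholds:
    left = [y for x, y in zip(values, labels) if x <= t]
    right = [y for x, y in zip(values, labels) if x > t]
    weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
    results[t] = round(weighted, 4)

# Step 5: pick the threshold with the lowest weighted Gini.
best_threshold = min(results, key=results.get)

print(results)         # {1.5: 0.3333, 2.5: 0.0, 3.5: 0.3333}
print(best_threshold)  # 2.5
```

Splitting at 2.5 separates the two classes completely, so both children are pure and the weighted Gini is 0. The other two thresholds leave one child with a 1-to-2 class mix (impurity 4/9) holding three of the four samples, which gives 3/4 · 4/9 = 1/3.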

Complexity

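  • Time: O(F · U · N log N), where F = features, N = samples, and U = unique values per feature (U ≤ N), so O(F N² log N) in the worst case. Each of the U - 1 candidate thresholds rebuilds the masks in O(N) and calls np.unique on both subsets, which outweighs the single O(N log N) sort per feature.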
  • Space: O(N) for the masks and sorted values
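
For comparison, one common way to cut the per-feature cost is to sort the samples once and sweep through them while keeping running class counts, so each candidate threshold costs only a pass over the classes instead of over all samples. The sketch below is not the solution above but an alternative under stated assumptions: the names gini_from_counts and best_split_one_feature are hypothetical, it handles a single feature, and it uses plain Python so that it runs without NumPy. With K classes it takes roughly O(N log N + N K) per feature.

```python
from collections import Counter

def gini_from_counts(counts, n):
    """Gini impurity from a class -> count mapping covering n samples."""
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def best_split_one_feature(values, labels):
    """Return (threshold, weighted_gini) for one feature, or (None, inf).

    Hypothetical helper: sorts once, then moves samples from the right
    child to the left child one at a time, updating class counts.
    """
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])
    left_counts = Counter()
    right_counts = Counter(labels)
    best_threshold, best_gini = None, float("inf")

    for pos in range(n - 1):
        i = order[pos]
        left_counts[labels[i]] += 1
        right_counts[labels[i]] -= 1

        lo, hi = values[i], values[order[pos + 1]]
        if lo == hi:
            continue  # no threshold fits between two equal values

        n_left = pos + 1
        n_right = n - n_left
        weighted = (n_left * gini_from_counts(left_counts, n_left)
                    + n_right * gini_from_counts(right_counts, n_right)) / n
        # Strict < keeps the first (smallest) threshold on ties.
        if weighted < best_gini:
            best_threshold, best_gini = (lo + hi) / 2, weighted

    return best_threshold, best_gini

print(best_split_one_feature([3.0, 1.0, 2.0, 1.0], [1, 0, 1, 0]))  # (1.5, 0.0)
```

Running this helper once per column and keeping the lowest result would give a full search. The thresholds it considers are the same midpoints between consecutive unique values, visited in ascending order.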