Implement Entropy-based Split Selection

#284 · Machine Learning · Medium

Source: deep-ml.com

Problem

Implement entropy-based split selection for a decision tree. Given a dataset with features and labels, find the best feature and threshold to split on by maximizing the information gain (reduction in entropy).
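Stated formally, with the same convention the solution uses (a sample goes left when its feature value is at most the threshold), the task is the following. The symbols x_ij (feature j of sample i), n (number of samples), and S (a subset of sample indices) are notation introduced here for illustration.

```latex
\begin{aligned}
H(y_S) &= -\sum_{c} p_c \log_2 p_c, \qquad p_c = \frac{\lvert \{ i \in S : y_i = c \} \rvert}{\lvert S \rvert} \\
L(j,t) &= \{\, i : x_{ij} \le t \,\}, \qquad R(j,t) = \{\, i : x_{ij} > t \,\} \\
\mathrm{IG}(j,t) &= H(y) - \frac{\lvert L(j,t) \rvert}{n}\, H\big(y_{L(j,t)}\big) - \frac{\lvert R(j,t) \rvert}{n}\, H\big(y_{R(j,t)}\big) \\
(j^{*}, t^{*}) &= \operatorname*{arg\,max}_{j,\; t \,:\, L(j,t) \neq \emptyset,\; R(j,t) \neq \emptyset} \mathrm{IG}(j,t)
\end{aligned}
```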

Solution

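import numpy as np

def entropy(labels: np.ndarray) -> float:
    """Shannon entropy, in bits, of a 1-D array of class labels.

    Labels are assumed to be non-negative integer class ids (0, 1, 2, ...),
    as required by np.bincount.
    """
    if len(labels) == 0:
        return 0.0
    counts = np.bincount(labels.astype(int))
    # Drop empty classes so log2 is never evaluated at 0.
    probs = counts[counts > 0] / len(labels)
    return float(-np.sum(probs * np.log2(probs)))

def information_gain(parent: np.ndarray, left: np.ndarray, right: np.ndarray) -> float:
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    if n == 0:
        return 0.0
    w_left = len(left) / n
    w_right = len(right) / n
    return entropy(parent) - w_left * entropy(left) - w_right * entropy(right)

def best_split(X: np.ndarray, y: np.ndarray) -> dict:
    """Find the (feature, threshold) pair with the highest information gain.

    Samples with X[:, feature] <= threshold go left, the rest go right.
    If no valid split exists, feature and threshold are None and gain is -1.0.
    """
    best_gain = -1.0
    best_feature = None
    best_threshold = None

    n_features = X.shape[1]
    for feature_idx in range(n_features):
        column = X[:, feature_idx]
        # Every distinct value of the feature is a candidate threshold.
        for threshold in np.unique(column):
            left_mask = column <= threshold
            right_mask = ~left_mask
            # Skip splits that leave one side empty (always the case for the maximum value).
            if not left_mask.any() or not right_mask.any():
                continue
            gain = information_gain(y, y[left_mask], y[right_mask])
            # Strict '>' keeps the first split found when gains tie.
            if gain > best_gain:
                best_gain = gain
                best_feature = feature_idx
                best_threshold = threshold

    return {"feature": best_feature, "threshold": best_threshold, "gain": best_gain}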

Explanation

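  1. Entropy measures the impurity of a set of labels: H = -sum(p * log2(p)), where the sum runs over the classes and p is the fraction of samples in each class. A pure set has entropy 0; an evenly mixed two-class set has entropy 1 bit.
  2. Information gain is the reduction in entropy achieved by a split: IG = H(parent) - (n_left / n) * H(left) - (n_right / n) * H(right), i.e. the parent's entropy minus the size-weighted average entropy of the two children.
  3. For every feature, each of its unique values is tried as a threshold: samples with value <= threshold go left, the rest go right. Splits that leave one side empty (which always includes the feature's maximum value) are skipped, and the information gain of every remaining split is computed.
  4. The split with the highest information gain is selected. Because the comparison is a strict >, ties go to the first split encountered (lowest feature index, then lowest threshold). If no feature has at least two distinct values, the result has feature and threshold set to None and gain -1.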
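As a hand-worked illustration of steps 1 to 4 (toy data chosen here, not part of the original problem), take a single feature x = (1, 2, 3, 4) with labels y = (0, 0, 1, 1), so H(y) = 1 bit. The candidate thresholds are the unique values of x; t = 4 leaves the right side empty and is skipped.

```latex
\begin{aligned}
H\big((0,0,1)\big) = H\big((0,1,1)\big) &= -\tfrac{2}{3}\log_2\tfrac{2}{3} - \tfrac{1}{3}\log_2\tfrac{1}{3} \approx 0.9183 \\
t = 1:\quad \mathrm{IG} &= 1 - \tfrac{1}{4}\cdot 0 - \tfrac{3}{4}\cdot 0.9183 \approx 0.3113 \\
t = 2:\quad \mathrm{IG} &= 1 - \tfrac{2}{4}\cdot 0 - \tfrac{2}{4}\cdot 0 = 1 \\
t = 3:\quad \mathrm{IG} &= 1 - \tfrac{3}{4}\cdot 0.9183 - \tfrac{1}{4}\cdot 0 \approx 0.3113
\end{aligned}
```

The split at t = 2 separates the two classes perfectly, so it is selected with a gain of 1 bit.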

Complexity

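  • Time: O(n · m · k), where n is the number of samples, m is the number of features, and k is the average number of unique values per feature. Each candidate threshold costs O(n) for the masks, label indexing, and entropy counts; since k ≤ n, the worst case is O(m · n²).
  • Space: O(n) for the boolean masks and the left/right label arrays of the current candidate.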
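The O(n) cost per candidate comes from rebuilding the masks and recounting labels for every threshold. A common optimization, which is not part of the original solution, is to sort each feature once and sweep the thresholds in order while keeping running class counts, bringing the time down to roughly O(m · n · (log n + C)) for C classes. Below is a minimal sketch under those assumptions. The names best_split_sorted and entropy_from_counts are hypothetical, labels may be any sortable values (they are re-encoded with np.unique), and it is meant to return the same feature, threshold, and gain as best_split, up to floating-point rounding in near-tie cases.

```python
import numpy as np


def entropy_from_counts(counts: np.ndarray) -> np.ndarray:
    """Entropy in bits along the last axis of an array of class counts (hypothetical helper)."""
    totals = counts.sum(axis=-1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        probs = np.where(totals > 0, counts / totals, 0.0)
        # 0 * log2(0) is taken as 0, matching the usual entropy convention.
        terms = np.where(probs > 0, probs * np.log2(probs), 0.0)
    return -terms.sum(axis=-1)


def best_split_sorted(X: np.ndarray, y: np.ndarray) -> dict:
    """Sort-and-sweep variant of best_split (hypothetical name; same <= convention)."""
    n, n_features = X.shape
    # Re-encode labels as 0..C-1 and one-hot them so counts can be accumulated with cumsum.
    classes, y_idx = np.unique(y, return_inverse=True)
    onehot = np.eye(len(classes))[y_idx.ravel()]
    parent_entropy = entropy_from_counts(onehot.sum(axis=0))

    best = {"feature": None, "threshold": None, "gain": -1.0}
    for j in range(n_features):
        order = np.argsort(X[:, j], kind="stable")
        xs = X[order, j]
        # Row k holds the class counts of the k + 1 smallest values of feature j.
        prefix = np.cumsum(onehot[order], axis=0)
        # A cut after sorted position k is valid only where the value changes:
        # k is then the last occurrence of threshold xs[k] and the right side is non-empty.
        cuts = np.flatnonzero(xs[:-1] < xs[1:])
        if cuts.size == 0:
            continue  # constant feature: no valid split
        left = prefix[cuts]
        right = prefix[-1] - left
        n_left = cuts + 1
        gains = (
            parent_entropy
            - (n_left / n) * entropy_from_counts(left)
            - ((n - n_left) / n) * entropy_from_counts(right)
        )
        k = int(np.argmax(gains))  # first maximum = smallest threshold on ties
        if gains[k] > best["gain"]:
            best = {"feature": j, "threshold": xs[cuts[k]], "gain": float(gains[k])}
    return best
```

For example, with X = np.array([[0, 1], [1, 2], [0, 3], [1, 4]]) and y = np.array([0, 0, 1, 1]), feature 0 gives no gain while feature 1 separates the classes at threshold 2, so this sketch and best_split both pick feature 1, threshold 2, gain 1.0. Sorting per feature trades the O(n) recount per threshold for a single O(n log n) sort plus prefix sums.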