Decision Tree Learning

#20 · Machine Learning · Hard

Problem

Implement a Decision Tree classifier from scratch using information gain (entropy-based) to select the best feature splits.

Solution

import math
from collections import Counter
from typing import Any


def entropy(examples: list[dict], target: str) -> float:
    """Shannon entropy (in bits) of the target labels in `examples`."""
    total = len(examples)
    if total == 0:
        return 0.0
    counts = Counter(ex[target] for ex in examples)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def decision_tree(examples: list[dict], features: list[str], target: str) -> Any:
    """Return a nested dict {feature: {value: subtree}}, or a bare label at a leaf."""
    labels = [ex[target] for ex in examples]
    label_counts = Counter(labels)

    # Base case: all examples share one label, so this node is a leaf
    if len(label_counts) == 1:
        return labels[0]

    # Base case: no features left to split on, so return the majority label
    if not features:
        return label_counts.most_common(1)[0][0]

    # Find the feature with the highest information gain
    current_entropy = entropy(examples, target)
    best_gain = -1.0
    best_feature = None

    for feature in features:
        weighted_entropy = 0.0
        for val in set(ex[feature] for ex in examples):
            subset = [ex for ex in examples if ex[feature] == val]
            weighted_entropy += (len(subset) / len(examples)) * entropy(subset, target)
        gain = current_entropy - weighted_entropy
        if gain > best_gain:  # strict ">" keeps the earliest feature on exact ties
            best_gain = gain
            best_feature = feature

    # Split on the best feature and recurse on each value's subset.
    # Values are drawn from `examples`, so every subset is non-empty.
    tree = {best_feature: {}}
    remaining_features = [f for f in features if f != best_feature]
    for val in set(ex[best_feature] for ex in examples):
        subset = [ex for ex in examples if ex[best_feature] == val]
        tree[best_feature][val] = decision_tree(subset, remaining_features, target)

    return tree
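The tree returned above is a nested dict of the form {feature: {value: subtree}}, with bare labels at the leaves, so classifying a new example means walking it down from the root. The sketch below shows one way to do that walk; `predict` and its `default` fallback for feature values never seen in training are hypothetical additions, not part of the original problem.

```python
from typing import Any


def predict(tree: Any, example: dict, default: Any = None) -> Any:
    """Walk a nested-dict tree ({feature: {value: subtree}}) down to a leaf label.

    Hypothetical helper: `default` is returned when the example has a feature
    value (or a missing feature) that the tree has no branch for.
    """
    # Anything that is not a dict is a leaf, i.e. the predicted label itself.
    while isinstance(tree, dict):
        # Each internal node has exactly one key: the feature it splits on.
        feature, branches = next(iter(tree.items()))
        value = example.get(feature)
        if value not in branches:
            return default
        tree = branches[value]
    return tree


# A small hand-written tree in the same format the solution produces.
tree = {
    "outlook": {
        "sunny": {"humidity": {"high": "no", "normal": "yes"}},
        "overcast": "yes",
        "rain": {"wind": {"weak": "yes", "strong": "no"}},
    }
}

print(predict(tree, {"outlook": "sunny", "humidity": "normal"}))  # yes
print(predict(tree, {"outlook": "rain", "wind": "strong"}))       # no
print(predict(tree, {"outlook": "snow"}, default="unknown"))      # unknown
```

Returning a caller-chosen default keeps the walk total; another common choice is to store the majority label at each internal node and fall back to it, which would require changing the tree format.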

Explanation

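  1. Base cases: If every example has the same label, return that label as a leaf; if no features remain, return the majority label.
  2. Entropy: Measures the impurity of a set S as H(S) = -sum(p_i * log2(p_i)), where p_i is the fraction of examples in S with label i. A pure set has entropy 0; an even two-class split has entropy 1.
  3. Information gain: For each feature A, compute how much entropy drops after splitting on it: Gain(S, A) = H(S) - sum over values v of (|S_v| / |S|) * H(S_v), where S_v is the subset with A = v. Select the feature with the highest gain (on an exact tie, the earlier feature in the list wins).
  4. Recursion: Split the data on the best feature, drop that feature from the list, and recursively build one subtree per observed value of it.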
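To make steps 2 and 3 concrete, here is a small standalone sketch that computes entropy and information gain on a hypothetical four-example dataset. The helpers `entropy` and `information_gain` are illustrative names; they mirror the logic inside the solution rather than reuse it.

```python
import math
from collections import Counter


def entropy(labels: list) -> float:
    """Shannon entropy in bits: -sum(p * log2(p)) over label frequencies."""
    total = len(labels)
    if total == 0:
        return 0.0
    return -sum((c / total) * math.log2(c / total) for c in Counter(labels).values())


def information_gain(examples: list[dict], feature: str, target: str) -> float:
    """Parent entropy minus the size-weighted entropy of each child after splitting on `feature`."""
    parent = entropy([ex[target] for ex in examples])
    weighted = 0.0
    for val in set(ex[feature] for ex in examples):
        child = [ex[target] for ex in examples if ex[feature] == val]
        weighted += len(child) / len(examples) * entropy(child)
    return parent - weighted


# Hypothetical toy data: "windy" separates the labels perfectly, "hot" tells us nothing.
data = [
    {"windy": True,  "hot": True,  "play": "no"},
    {"windy": True,  "hot": False, "play": "no"},
    {"windy": False, "hot": True,  "play": "yes"},
    {"windy": False, "hot": False, "play": "yes"},
]

print(entropy([ex["play"] for ex in data]))     # 1.0 (two labels, 50/50)
print(information_gain(data, "windy", "play"))  # 1.0 (both children are pure)
print(information_gain(data, "hot", "play"))    # 0.0 (each child is still 50/50)
```

On this data the solution would pick "windy" at the root, and both of its children would immediately hit the all-same-label base case.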

Complexity

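  • Time: O(f * n * k) per tree level, where f is the number of features, n the number of examples and k the largest number of distinct values any feature takes: each node scores every remaining feature by rebuilding one subset per value, and the nodes on one level together hold at most n examples. No sorting is involved, so there is no log(n) factor. Each level consumes one feature, so the depth is at most f and the total is O(f^2 * n * k), whether or not the tree is balanced.
  • Space: O(f * n) for the recursion stack and data subsets: the depth is at most f, and each frame on the current path holds lists of at most n examples.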
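Because the solution filters the examples once per distinct value, scoring a single feature at a node costs O(n * k). One common refinement, sketched here as an assumption rather than part of the original solution, is to tally label counts per feature value in a single pass, which brings the cost of scoring one feature down to O(n). The names `entropy_from_counts` and `information_gain_single_pass` are hypothetical.

```python
import math
from collections import Counter, defaultdict


def entropy_from_counts(counts: Counter, total: int) -> float:
    """Entropy in bits computed from precomputed label counts."""
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def information_gain_single_pass(examples: list[dict], feature: str, target: str) -> float:
    """Information gain of `feature`, tallying labels per feature value in one pass."""
    n = len(examples)
    if n == 0:
        return 0.0
    parent_counts = Counter()
    per_value = defaultdict(Counter)  # feature value -> Counter of target labels
    for ex in examples:
        parent_counts[ex[target]] += 1
        per_value[ex[feature]][ex[target]] += 1

    # Size-weighted entropy of the children, from the tallies alone.
    weighted = 0.0
    for counts in per_value.values():
        size = sum(counts.values())
        weighted += (size / n) * entropy_from_counts(counts, size)
    return entropy_from_counts(parent_counts, n) - weighted


# Hypothetical data: the parent has 3 "yes" / 1 "no"; only the "red" child is impure.
data = [
    {"color": "red",   "play": "yes"},
    {"color": "red",   "play": "no"},
    {"color": "blue",  "play": "yes"},
    {"color": "green", "play": "yes"},
]

print(round(information_gain_single_pass(data, "color", "play"), 4))  # 0.3113
```

Here H(parent) is about 0.8113 and the weighted child entropy is 0.5 (half the examples sit in a 50/50 child), giving a gain of about 0.3113. Swapping this into the feature loop of `decision_tree` would not change which feature is chosen, only how fast each feature is scored.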