
Implement Decision Tree for Regression

#286 · Machine Learning · Medium

Solve on deep-ml.com

Problem

Implement a decision tree for regression from scratch. Build a tree that splits on features to minimize the variance (or MSE) of the target values in each leaf, then predict by returning the mean of the leaf's values.
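As a worked example of the split criterion, the sketch below scores every candidate threshold of a single feature on a toy dataset. It measures impurity as the total squared error around the mean (variance × count), the quantity the solution's `mse` helper returns. The data and the helper name `sse` are illustrative assumptions.

```python
import numpy as np

def sse(y: np.ndarray) -> float:
    """Sum of squared deviations from the mean (variance * count)."""
    return float(np.var(y) * len(y)) if len(y) else 0.0

# Toy data (illustrative): the target jumps from 1 to 10 between x=3 and x=4.
x = np.array([1, 2, 3, 4, 5, 6])
y = np.array([1.0, 1.0, 1.0, 10.0, 10.0, 10.0])

gains = {}
for t in np.unique(x)[:-1]:  # skip the max value, which would leave the right side empty
    mask = x <= t
    # Error reduction = parent SSE minus the SSE of both children.
    gains[int(t)] = sse(y) - sse(y[mask]) - sse(y[~mask])

best_t = max(gains, key=gains.get)
print(best_t, gains[best_t])  # 3 121.5
```

Splitting at x <= 3 makes both children pure, so the whole parent error (6 × 20.25 = 121.5) is removed.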

Solution

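import numpy as np

class Node:
    """Tree node: internal nodes store a split, leaves store a prediction."""
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature      # index of the feature used for the split
        self.threshold = threshold  # rows with x[feature] <= threshold go left
        self.left = left
        self.right = right
        self.value = value          # leaf prediction (mean target); None for internal nodes

def mse(y: np.ndarray) -> float:
    # Total squared error around the mean (variance * count), i.e. the SSE.
    # Weighting by count makes the children's errors directly comparable to the parent's.
    if len(y) == 0:
        return 0.0
    return float(np.var(y) * len(y))

def best_split(X: np.ndarray, y: np.ndarray):
    """Return the (feature, threshold) pair with the largest error reduction,
    or (None, None) if no split reduces the error."""
    best_gain = 0.0
    best_feature = None
    best_threshold = None
    parent_mse = mse(y)

    for feature_idx in range(X.shape[1]):
        # Every observed value of the feature is a candidate threshold.
        thresholds = np.unique(X[:, feature_idx])
        for threshold in thresholds:
            left_mask = X[:, feature_idx] <= threshold
            right_mask = ~left_mask
            # Skip degenerate splits (the largest value leaves the right side empty).
            if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
                continue
            gain = parent_mse - mse(y[left_mask]) - mse(y[right_mask])
            if gain > best_gain:
                best_gain = gain
                best_feature = feature_idx
                best_threshold = threshold

    return best_feature, best_threshold

def build_tree(X: np.ndarray, y: np.ndarray, max_depth: int = 10, min_samples: int = 2) -> Node:
    # Stop when the node is too small, the depth budget is used up, or all targets are equal.
    if len(y) < min_samples or max_depth == 0 or len(np.unique(y)) == 1:
        return Node(value=float(np.mean(y)))

    feature, threshold = best_split(X, y)
    if feature is None:  # no split reduces the error
        return Node(value=float(np.mean(y)))

    left_mask = X[:, feature] <= threshold
    right_mask = ~left_mask

    left_node = build_tree(X[left_mask], y[left_mask], max_depth - 1, min_samples)
    right_node = build_tree(X[right_mask], y[right_mask], max_depth - 1, min_samples)

    return Node(feature=feature, threshold=threshold, left=left_node, right=right_node)

def predict_one(node: Node, x: np.ndarray) -> float:
    # Only leaves carry a value; otherwise follow the node's split condition.
    if node.value is not None:
        return node.value
    if x[node.feature] <= node.threshold:
        return predict_one(node.left, x)
    return predict_one(node.right, x)

def predict(node: Node, X: np.ndarray) -> np.ndarray:
    # Predict each row independently by walking from the root to a leaf.
    return np.array([predict_one(node, x) for x in X])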
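The solution's `predict` calls `predict_one` once per row, so every sample gets its own Python-level traversal. One alternative, sketched here under the assumption that the tree uses the same `Node` fields, routes all rows at once: each internal node splits an index array with a single vectorized comparison. The function name `predict_vectorized` and the hand-built tree are illustrative, and `Node` is redefined so the block runs on its own.

```python
import numpy as np

class Node:
    # Same fields as the solution's Node: leaves carry `value`, internal nodes a split.
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value

def predict_vectorized(node: Node, X: np.ndarray) -> np.ndarray:
    """Route all rows through the tree at once, one boolean mask per internal node."""
    out = np.empty(len(X), dtype=float)
    stack = [(node, np.arange(len(X)))]  # (subtree, indices of rows that reach it)
    while stack:
        current, idx = stack.pop()
        if idx.size == 0:
            continue
        if current.value is not None:  # leaf: every row that reached it gets the leaf mean
            out[idx] = current.value
            continue
        goes_left = X[idx, current.feature] <= current.threshold
        stack.append((current.left, idx[goes_left]))
        stack.append((current.right, idx[~goes_left]))
    return out

# Hand-built tree (illustrative): split feature 0 at 3.0, leaves predict 1.0 and 10.0.
tree = Node(feature=0, threshold=3.0, left=Node(value=1.0), right=Node(value=10.0))
X_new = np.array([[0.5], [3.0], [3.5], [8.0]])
print(predict_vectorized(tree, X_new))  # [ 1.  1. 10. 10.]
```

The number of Python-level steps now scales with the number of tree nodes rather than with samples × depth; the per-row comparisons happen inside NumPy.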

Explanation

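  1. At each node, iterate over all features and their unique values as candidate thresholds, and pick the split that most reduces the total squared error (variance × count, i.e. the sum of squared deviations from the mean) summed over the two children.
  2. Recursively build the left and right subtrees until a stopping criterion is met: maximum depth reached, fewer than min_samples samples, all targets identical, or no split that reduces the error.
  3. Leaf nodes store the mean of their training targets, the constant that minimizes squared error on those targets.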
  4. Prediction traverses the tree from root to leaf using the learned split conditions.
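Steps 1 and 3 rest on two standard identities. For a node with targets y_1, …, y_n and mean ȳ (Var denotes the population variance, matching np.var's default ddof=0), the best constant prediction is the mean, its error is n · Var(y), and a split on feature j at threshold t is scored by the error it removes:

```latex
\mathrm{SSE}(c) = \sum_{i=1}^{n} (y_i - c)^2, \qquad
\frac{d}{dc}\,\mathrm{SSE}(c) = -2 \sum_{i=1}^{n} (y_i - c) = 0
\;\Longrightarrow\; c^{*} = \bar{y} = \frac{1}{n} \sum_{i=1}^{n} y_i

\mathrm{SSE}(\bar{y}) = \sum_{i=1}^{n} (y_i - \bar{y})^2 = n \cdot \operatorname{Var}(y)

\mathrm{gain}(j, t) = \mathrm{SSE}(y) - \mathrm{SSE}(y_L) - \mathrm{SSE}(y_R),
\qquad L = \{\, i : x_{ij} \le t \,\}, \quad R = \{\, i : x_{ij} > t \,\}
```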

Complexity

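  • Time: O(n · m · k · d) to build, where n = samples, m = features, k = unique values per feature, d = depth. Each candidate threshold costs O(n_node) to evaluate, and the nodes at one depth partition the n samples. Since k ≤ n, the worst case is O(m · n² · d).
  • Space: O(n) for the tree itself, since every leaf holds at least one training sample. During the build, the row subsets copied along the current recursion path add up to O(n · m · d).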
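Most of the build time goes into recomputing each child's variance from scratch for every candidate threshold. A common optimization, not used in the original solution, sorts a feature once and reads every split's error off cumulative sums, using SSE = Σy² − (Σy)²/n. The sketch below (hypothetical name `best_split_sorted`) handles one feature and checks itself against a brute-force scan that mirrors the solution's `best_split`; the random data is illustrative.

```python
import numpy as np

def best_split_sorted(x: np.ndarray, y: np.ndarray):
    """Best threshold on one feature via sorting + prefix sums.

    Returns (threshold, gain), or (None, 0.0) if no split reduces the error.
    """
    n = len(y)
    if n < 2:
        return None, 0.0
    order = np.argsort(x, kind="stable")
    xs, ys = x[order], y[order]

    # Prefix sums: the first k sorted samples form the left child.
    csum = np.cumsum(ys)
    csq = np.cumsum(ys ** 2)
    total, total_sq = csum[-1], csq[-1]

    k = np.arange(1, n)  # left-child sizes 1 .. n-1
    left_sse = csq[:-1] - csum[:-1] ** 2 / k
    right_sse = (total_sq - csq[:-1]) - (total - csum[:-1]) ** 2 / (n - k)
    gain = (total_sq - total ** 2 / n) - left_sse - right_sse

    # Only cut between distinct values, so "x <= threshold" reproduces the partition.
    gain = np.where(xs[:-1] < xs[1:], gain, -np.inf)
    i = int(np.argmax(gain))
    if gain[i] <= 0:
        return None, 0.0
    return xs[i], float(gain[i])

def sse(v: np.ndarray) -> float:
    return float(np.var(v) * len(v)) if len(v) else 0.0

# Random data with repeated x values, to exercise the tie handling.
rng = np.random.default_rng(0)
x = rng.integers(0, 20, size=200).astype(float)
y = np.sin(x) + rng.normal(scale=0.1, size=200)

t_fast, g_fast = best_split_sorted(x, y)

# Brute-force scan over observed values, as in the solution's best_split.
best_t, best_g = None, 0.0
for t in np.unique(x):
    mask = x <= t
    if mask.all():
        continue
    g = sse(y) - sse(y[mask]) - sse(y[~mask])
    if g > best_g:
        best_t, best_g = t, g

print(t_fast, best_t, g_fast, best_g)
```

Per node and feature this costs O(n log n) for the sort plus O(n) for the scan, instead of O(n · k) for the threshold loop. The trade-off is some floating-point cancellation in Σy² − (Σy)²/n when targets are large relative to their spread.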