Decision Tree Pruning with Cost-Complexity

#285 · Machine Learning · Medium

Problem

Implement cost-complexity pruning for decision trees. Given a tree, compute the effective alpha for each subtree and prune nodes where the cost-complexity criterion indicates pruning improves generalization.

Solution

import numpy as np

class TreeNode:
    def __init__(self, value=None, left=None, right=None, n_samples=0, impurity=0.0):
        self.value = value
        self.left = left
        self.right = right
        self.n_samples = n_samples
        self.impurity = impurity

def count_leaves(node: TreeNode) -> int:
    # |T_t|: number of leaves in the subtree rooted at this node.
    if node.left is None and node.right is None:
        return 1
    count = 0
    if node.left:
        count += count_leaves(node.left)
    if node.right:
        count += count_leaves(node.right)
    return count

def subtree_cost(node: TreeNode) -> float:
    # R(T_t): sum of impurity * n_samples over the leaves of the subtree.
    if node.left is None and node.right is None:
        return node.impurity * node.n_samples
    cost = 0.0
    if node.left:
        cost += subtree_cost(node.left)
    if node.right:
        cost += subtree_cost(node.right)
    return cost

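def effective_alpha(node: TreeNode) -> float:
    # A leaf has nothing to collapse, so it is never a pruning candidate.
    if node.left is None and node.right is None:
        return float('inf')
    r_t = node.impurity * node.n_samples  # R(t): cost if this node became a leaf
    r_subtree = subtree_cost(node)        # R(T_t): summed cost of the subtree's leaves
    leaves = count_leaves(node)           # |T_t|
    if leaves <= 1:
        # Single-child chain: collapsing removes no leaves; skip it to avoid dividing by zero.
        return float('inf')
    return (r_t - r_subtree) / (leaves - 1)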

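def prune_tree(node: TreeNode, alpha: float) -> TreeNode:
    # Prunes in place and returns the same node object.
    if node.left is None and node.right is None:
        return node
    # Post-order: prune the children first so this node's alpha is
    # computed on the already pruned subtree.
    if node.left:
        node.left = prune_tree(node.left, alpha)
    if node.right:
        node.right = prune_tree(node.right, alpha)
    # Collapse this node into a leaf when its effective alpha is at or below the threshold.
    if effective_alpha(node) <= alpha:
        node.left = None
        node.right = None
    return node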

Explanation

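  1. Cost-complexity criterion: for each internal node t, compute the effective alpha = (R(t) - R(T_t)) / (|T_t| - 1). Here R(t) = impurity × n_samples is the cost of the node if its subtree were replaced by a single leaf, R(T_t) is the sum of that same cost over the leaves of the subtree rooted at t, and |T_t| is the number of those leaves. Leaves themselves get an alpha of infinity, so they are never pruning candidates.
  2. A lower alpha means the subtree buys little cost reduction per extra leaf, so it is the first candidate for pruning.
  3. Given a target alpha, prune bottom-up: prune the children first, then recompute the node's effective alpha on the already pruned subtree; if it is at or below the threshold, replace the node with a leaf.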
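
The three steps above can be traced on a small tree. This is only a sketch under stated assumptions: the tree shape, sample counts, and impurity values are invented for illustration, `build_tree`, `is_leaf`, and `children` are hypothetical helpers, and the other functions are compact restatements of the solution so the snippet runs by itself.

```python
class TreeNode:
    def __init__(self, value=None, left=None, right=None, n_samples=0, impurity=0.0):
        self.value, self.left, self.right = value, left, right
        self.n_samples, self.impurity = n_samples, impurity

def is_leaf(node):
    return node.left is None and node.right is None

def children(node):
    return [c for c in (node.left, node.right) if c is not None]

def count_leaves(node):
    return 1 if is_leaf(node) else sum(count_leaves(c) for c in children(node))

def subtree_cost(node):
    if is_leaf(node):
        return node.impurity * node.n_samples
    return sum(subtree_cost(c) for c in children(node))

def effective_alpha(node):
    if is_leaf(node):
        return float("inf")
    r_t = node.impurity * node.n_samples
    return (r_t - subtree_cost(node)) / (count_leaves(node) - 1)

def prune_tree(node, alpha):
    if is_leaf(node):
        return node
    if node.left is not None:
        node.left = prune_tree(node.left, alpha)
    if node.right is not None:
        node.right = prune_tree(node.right, alpha)
    if effective_alpha(node) <= alpha:
        node.left = node.right = None  # collapse into a leaf
    return node

def build_tree():
    # Illustrative values only (assumption): chosen so every cost is an exact float.
    a = TreeNode(left=TreeNode(n_samples=16, impurity=0.25),    # leaf cost 4
                 right=TreeNode(n_samples=16, impurity=0.125),  # leaf cost 2
                 n_samples=32, impurity=0.25)                   # R(a) = 8
    b = TreeNode(n_samples=32, impurity=0.25)                   # leaf cost 8
    return TreeNode(left=a, right=b, n_samples=64, impurity=0.5)  # R(root) = 32

tree = build_tree()
print(effective_alpha(tree.left))  # (8 - 6) / (2 - 1) = 2.0
print(effective_alpha(tree))       # (32 - 14) / (3 - 1) = 9.0

# prune_tree mutates its argument, so build a fresh tree for each threshold.
for alpha in (1.0, 2.0, 16.0):
    pruned = prune_tree(build_tree(), alpha)
    print(alpha, count_leaves(pruned))  # leaves left: 3, then 2, then 1
```

At alpha = 2 only the left subtree collapses. The root's alpha is then recomputed on the pruned tree as (32 - 16) / (2 - 1) = 16, so the root itself survives until the threshold reaches 16.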

Complexity

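  • Time: O(n^2) in the worst case. prune_tree visits each of the n nodes once, but effective_alpha re-traverses that node's subtree through subtree_cost and count_leaves, giving O(n · h) overall: O(n^2) for a degenerate chain-like tree and O(n log n) for a balanced one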
  • Space: O(h) where h is the height of the tree (recursion stack)
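
The repeated subtree traversals can be avoided if each recursive call also returns the cost and leaf count of its pruned subtree. The following is a minimal sketch of that idea, not part of the original solution: `prune_single_pass` is a hypothetical name, and the example tree uses made-up values. It is intended to apply the same pruning rule in O(n) time with the same O(h) recursion stack.

```python
class TreeNode:
    def __init__(self, value=None, left=None, right=None, n_samples=0, impurity=0.0):
        self.value, self.left, self.right = value, left, right
        self.n_samples, self.impurity = n_samples, impurity

def prune_single_pass(node, alpha):
    """Prune in place; return (cost, leaf_count) of the pruned subtree."""
    r_t = node.impurity * node.n_samples  # cost if this node were a leaf
    if node.left is None and node.right is None:
        return r_t, 1
    cost, leaves = 0.0, 0
    for child in (node.left, node.right):
        if child is not None:
            child_cost, child_leaves = prune_single_pass(child, alpha)
            cost += child_cost
            leaves += child_leaves
    # Same test as effective_alpha(node) <= alpha, using totals already in hand.
    if leaves > 1 and (r_t - cost) / (leaves - 1) <= alpha:
        node.left = node.right = None
        return r_t, 1
    return cost, leaves

# Made-up example tree (assumption): one internal child and one leaf child.
root = TreeNode(
    left=TreeNode(left=TreeNode(n_samples=16, impurity=0.25),
                  right=TreeNode(n_samples=16, impurity=0.125),
                  n_samples=32, impurity=0.25),
    right=TreeNode(n_samples=32, impurity=0.25),
    n_samples=64, impurity=0.5,
)
result = prune_single_pass(root, 2.0)
print(result)  # (16.0, 2): the left subtree collapsed, the root kept both children
```

Each node is visited exactly once, at the price of a return value that carries two numbers instead of the node.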