#285 · Machine Learning · Medium
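Solve on deep-ml.com. Implement cost-complexity pruning for decision trees. Given a tree, compute the effective alpha for each subtree and prune nodes where the cost-complexity criterion indicates pruning is worthwhile: collapse every internal node whose effective alpha is less than or equal to the given complexity parameter alpha into a leaf, trading a little training cost for a simpler tree that generalizes better.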
import numpy as np
class TreeNode:
    """A binary decision-tree node consumed by the cost-complexity pruning helpers.

    Attributes:
        value: payload stored on the node; not read by the pruning helpers
            (presumably the node's prediction — confirm with callers).
        left, right: child nodes, or None when a child is absent. A node with
            both children None is treated as a leaf.
        n_samples: sample count at the node; multiplied by ``impurity`` to get
            the node's weighted cost.
        impurity: impurity of the node.
    """

    def __init__(self, value=None, left=None, right=None, n_samples=0, impurity=0.0):
        self.value = value
        self.left, self.right = left, right
        self.n_samples, self.impurity = n_samples, impurity
def count_leaves(node: TreeNode) -> int:
    """Count the leaves in the subtree rooted at ``node``.

    A leaf is a node whose ``left`` and ``right`` are both None; a childless
    ``node`` therefore yields 1. Missing (None) children are skipped, so a
    node with a single child contributes only that child's leaves.
    """
    leaf_total = 0
    to_visit = [node]  # explicit stack instead of recursion
    while to_visit:
        current = to_visit.pop()
        if current.left is None and current.right is None:
            leaf_total += 1
            continue
        to_visit.extend(child for child in (current.left, current.right) if child)
    return leaf_total
def subtree_cost(node: TreeNode) -> float:
    """Return R(T_t), the summed weighted impurity of the leaves under ``node``.

    Each leaf contributes ``impurity * n_samples``; for a childless ``node``
    that is simply its own weighted impurity. Internal nodes add their
    children's costs, left child first, then right; None children are skipped.
    """
    if node.left is None and node.right is None:
        return node.impurity * node.n_samples
    total = 0.0
    for child in (node.left, node.right):
        if child:
            total += subtree_cost(child)
    return total
def effective_alpha(node: TreeNode) -> float:
    """Return the effective alpha g(t) at which collapsing ``node`` breaks even.

    g(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1), where R(t) is the node's own
    weighted impurity (``impurity * n_samples``) and R(T_t) is the summed cost
    of the leaves beneath it. Collapsing the node is no worse under the
    cost-complexity measure exactly when g(t) <= alpha.

    Returns:
        ``inf`` for a leaf (there is nothing to prune). When the subtree has a
        single leaf (a chain of single-child nodes), collapsing leaves the leaf
        count unchanged, so the trade-off does not depend on alpha: returns
        0.0 if collapsing does not increase the cost, otherwise ``inf``. This
        avoids the division by zero in the general formula.
    """
    if node.left is None and node.right is None:
        return float('inf')
    r_t = node.impurity * node.n_samples
    r_subtree = subtree_cost(node)
    leaves = count_leaves(node)
    if leaves <= 1:
        # |leaves(T_t)| - 1 == 0: no complexity is saved, only cost matters.
        return 0.0 if r_t <= r_subtree else float('inf')
    return (r_t - r_subtree) / (leaves - 1)
def prune_tree(node: TreeNode, alpha: float) -> TreeNode:
    """Prune the subtree rooted at ``node`` in place for complexity ``alpha``.

    Works bottom-up. Children are pruned first. The node is then collapsed
    into a leaf (both children set to None) when its effective alpha,
    measured on the already-pruned subtree, is <= ``alpha``. That condition
    is equivalent to R(t) + alpha <= R(T_t) + alpha * |leaves(T_t)|, i.e. the
    collapsed node costs no more under the cost-complexity measure. Ties
    prune.

    Each node is visited once: the recursion returns the leaf count and leaf
    cost of every pruned subtree. This avoids re-walking subtrees at every
    ancestor, which costs O(n^2) on skewed trees.

    Returns:
        ``node`` itself (mutated in place).
    """
    def _prune(t):
        # Prune t's subtree in place; return (leaf count, leaf cost) of what remains.
        if t.left is None and t.right is None:
            return 1, t.impurity * t.n_samples
        leaves, cost = 0, 0.0
        for child in (t.left, t.right):
            if child:
                child_leaves, child_cost = _prune(child)
                leaves += child_leaves
                cost += child_cost
        r_t = t.impurity * t.n_samples
        if leaves <= 1:
            # Single-leaf chain: collapsing saves no leaves, so only cost matters
            # (also avoids dividing by zero).
            g = 0.0 if r_t <= cost else float('inf')
        else:
            g = (r_t - cost) / (leaves - 1)
        if g <= alpha:
            t.left = None
            t.right = None
            return 1, r_t
        return leaves, cost

    _prune(node)
    return node