#138 · Machine Learning · Medium
# Solve on deep-ml.com
#
# Find the best binary split for a decision tree node using the Gini impurity
# criterion. Given feature values and labels, evaluate all possible splits and
# return the one with the lowest weighted Gini impurity.
import numpy as np
def gini_impurity(labels: np.ndarray) -> float:
    """Return the Gini impurity ``1 - sum(p_i ** 2)`` of a set of labels.

    ``p_i`` is the fraction of samples that carry class ``i``.

    Args:
        labels: Class labels of the samples in one node.

    Returns:
        The impurity as a built-in ``float``. It is ``0.0`` for an empty
        collection and for a collection that holds a single class.
    """
    n_labels = len(labels)
    if n_labels == 0:
        # An empty node has no class mix; returning early also avoids a
        # division by zero below.
        return 0.0
    # Only the per-class counts matter, not the class values themselves.
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / n_labels
    # Convert so callers get the annotated built-in float, not np.float64.
    return float(1.0 - np.sum(probs ** 2))
def best_gini_split(X: np.ndarray, y: np.ndarray) -> dict:
    """Find the binary split with the lowest weighted Gini impurity.

    Every feature is scanned, and every midpoint between two consecutive
    distinct values of that feature is tried as a threshold. Samples with
    ``value <= threshold`` go to the left child, the rest to the right.

    Each child's impurity is ``1 - sum(p_i ** 2)`` (computed by
    ``gini_impurity``), and a split is scored as
    ``(n_left * gini_left + n_right * gini_right) / n_total``.
    On ties the first split encountered (lowest feature index, then lowest
    threshold) is kept.

    Args:
        X: Feature matrix of shape ``(n_samples, n_features)``. A 1-D input
            is treated as a single feature column.
        y: Class labels, one per row of ``X``.

    Returns:
        A dict with the keys ``"feature"`` (column index of the best split),
        ``"threshold"`` (split value) and ``"gini"`` (weighted impurity
        rounded to 4 decimals). If no feature has two distinct values there
        is nothing to split on: ``"feature"`` and ``"threshold"`` are then
        ``None`` and ``"gini"`` is ``inf``.
    """
    # Accept plain lists as well: boolean-mask indexing needs real arrays.
    X = np.asarray(X)
    y = np.asarray(y)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    n_samples, n_features = X.shape

    best_gini = float("inf")
    best_feature = None
    best_threshold = None

    for feature_idx in range(n_features):
        values = X[:, feature_idx]
        sorted_unique = np.unique(values)
        # Candidate thresholds: midpoints between consecutive unique values
        thresholds = (sorted_unique[:-1] + sorted_unique[1:]) / 2.0

        for threshold in thresholds:
            left_mask = values <= threshold
            right_mask = ~left_mask
            n_left = int(np.sum(left_mask))
            n_right = int(np.sum(right_mask))
            if n_left == 0 or n_right == 0:
                # Safety net: a split that leaves one child empty is not a
                # split at all (e.g. a midpoint distorted by float rounding).
                continue

            gini_left = gini_impurity(y[left_mask])
            gini_right = gini_impurity(y[right_mask])
            weighted_gini = float(
                (n_left * gini_left + n_right * gini_right) / n_samples
            )

            # Strict comparison keeps the earliest split among equals.
            if weighted_gini < best_gini:
                best_gini = weighted_gini
                best_feature = feature_idx
                best_threshold = float(threshold)

    return {
        "feature": best_feature,
        "threshold": best_threshold,
        "gini": round(best_gini, 4),
    }