Calculate BIC/AIC for Model Selection

#368 · Machine Learning · Medium

Problem

Implement BIC (Bayesian Information Criterion) and AIC (Akaike Information Criterion) for model selection. Given a fitted model's log-likelihood, number of parameters, and number of data points, compute both criteria.
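
As a quick sanity check of the two formulas, the following sketch evaluates them by hand. The input values (a log-likelihood of -100, 3 parameters, 50 samples) are hypothetical and chosen only for illustration.

```python
import math

# Hypothetical inputs, not taken from the problem statement.
log_likelihood = -100.0
n_params = 3
n_samples = 50

# BIC penalizes each parameter by ln(n_samples); AIC by a constant 2.
bic = -2 * log_likelihood + n_params * math.log(n_samples)
aic = -2 * log_likelihood + 2 * n_params

print(round(bic, 4))  # 211.7361
print(aic)            # 206.0
```

With 50 samples, ln(50) is about 3.91, so BIC charges nearly twice as much per parameter as AIC does.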

Solution

import numpy as np

def compute_bic_aic(
    log_likelihood: float,
    n_params: int,
    n_samples: int
) -> dict:
    # BIC charges ln(n_samples) per parameter; AIC charges a flat 2 per parameter.
    bic = -2 * log_likelihood + n_params * np.log(n_samples)
    aic = -2 * log_likelihood + 2 * n_params
    return {"bic": float(bic), "aic": float(aic)}

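def gmm_log_likelihood(X: np.ndarray, pi: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> float:
    # X: (n, d) data; pi: (k,) mixing weights; mu: (k, d) means; sigma: (k, d, d) covariances
    n, d = X.shape
    k = len(pi)
    ll = 0.0
    for i in range(n):
        # Mixture density at X[i]: sum over j of pi[j] * N(X[i] | mu[j], sigma[j])
        prob = 0.0
        for j in range(k):
            diff = X[i] - mu[j]
            inv_cov = np.linalg.inv(sigma[j])
            det_cov = np.linalg.det(sigma[j])
            exponent = -0.5 * diff @ inv_cov @ diff
            # The 1e-300 terms guard against division by zero and log(0)
            gauss = np.exp(exponent) / (np.sqrt((2 * np.pi) ** d * det_cov) + 1e-300)
            prob += pi[j] * gauss
        ll += np.log(prob + 1e-300)
    return float(ll)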

def gmm_bic_aic(X: np.ndarray, pi: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> dict:
    n, d = X.shape
    k = len(pi)
    # Free parameters: k*d means, k*d*(d+1)/2 covariance entries, k-1 mixing weights
    n_params = k * d + k * d * (d + 1) // 2 + (k - 1)
    ll = gmm_log_likelihood(X, pi, mu, sigma)
    return compute_bic_aic(ll, n_params, n)

Explanation

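  1. BIC = -2 * log_likelihood + n_params * ln(n_samples). The penalty per parameter is ln(n_samples), so it grows with the sample size and exceeds AIC's penalty once n_samples ≥ 8 (where ln(n_samples) > 2).
  2. AIC = -2 * log_likelihood + 2 * n_params, a fixed penalty of 2 per parameter regardless of sample size.
  3. For a GMM with k components in d dimensions, the free parameters are: k*d for the means, k*d*(d+1)/2 for the covariance matrices (each is symmetric, so only d*(d+1)/2 entries are free), and k-1 for the mixing coefficients (they sum to 1).
  4. Lower BIC/AIC values indicate a better trade-off between fit and complexity. The values are only meaningful relative to other models fitted to the same data.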
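
To illustrate how the criteria are used for selection, here is a minimal sketch that scores GMMs with 1 to 4 components and picks the lowest value of each criterion. The log-likelihoods, the sample size (200), the dimension (2), and the helper name `gmm_n_params` are all hypothetical, made up for this example rather than produced by fitting real data.

```python
import math

def gmm_n_params(k: int, d: int) -> int:
    """Free parameters of a full-covariance GMM: means + covariances + weights."""
    return k * d + k * d * (d + 1) // 2 + (k - 1)

# Hypothetical setup: 200 samples in 2 dimensions.
n_samples, d = 200, 2

# Hypothetical log-likelihoods keyed by number of components.
# Fit improves with more components, but with diminishing returns.
candidates = {1: -620.0, 2: -540.0, 3: -532.0, 4: -530.0}

scores = {}
for k, ll in candidates.items():
    p = gmm_n_params(k, d)
    scores[k] = {
        "n_params": p,
        "bic": -2 * ll + p * math.log(n_samples),
        "aic": -2 * ll + 2 * p,
    }

# Lower is better for both criteria.
best_bic = min(scores, key=lambda k: scores[k]["bic"])
best_aic = min(scores, key=lambda k: scores[k]["aic"])
print(best_bic, best_aic)
```

For these made-up numbers BIC selects 2 components while AIC selects 3: the small gain in fit from the third component covers AIC's penalty of 2 per parameter but not BIC's penalty of ln(200) ≈ 5.3 per parameter.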

Complexity

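  • Time: O(n k d^3) as written, because the covariance inverse and determinant are recomputed for every sample; computing them once per component gives O(k d^3 + n k d^2)
  • Space: O(d^2) for the temporary inverse covariance, beyond input storage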
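
One way to reach the O(k d^3 + n k d^2) bound is to factor each covariance once and work in log space. The sketch below is an alternative under stated assumptions, not the solution above: it assumes every covariance matrix is positive definite and every mixing weight is strictly positive, and it trades memory for speed by storing an (n, k) array of log-densities. The function name `gmm_log_likelihood_fast` is introduced here for illustration.

```python
import numpy as np

def gmm_log_likelihood_fast(X: np.ndarray, pi: np.ndarray,
                            mu: np.ndarray, sigma: np.ndarray) -> float:
    """Log-likelihood of X under a GMM, factoring each covariance once.

    Assumes sigma[j] is positive definite and pi[j] > 0 for every j.
    """
    n, d = X.shape
    k = len(pi)
    log_probs = np.empty((n, k))
    for j in range(k):
        # One Cholesky factorization per component: sigma[j] = L @ L.T
        L = np.linalg.cholesky(sigma[j])
        log_det = 2.0 * np.sum(np.log(np.diag(L)))
        diff = X - mu[j]                    # shape (n, d)
        # Solving L z = diff.T gives z whose squared column norms are
        # the Mahalanobis distances diff @ inv(sigma[j]) @ diff.
        z = np.linalg.solve(L, diff.T)      # shape (d, n)
        maha = np.sum(z ** 2, axis=0)       # shape (n,)
        log_probs[:, j] = np.log(pi[j]) - 0.5 * (d * np.log(2 * np.pi) + log_det + maha)
    # Log-sum-exp over components, shifted by the row maximum for stability.
    m = log_probs.max(axis=1, keepdims=True)
    per_sample = m[:, 0] + np.log(np.sum(np.exp(log_probs - m), axis=1))
    return float(np.sum(per_sample))
```

Working with log-densities also avoids the underflow that can occur when exponentiating large negative exponents directly, so the 1e-300 guards are not needed in this variant.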