Gaussian Mixture Model with EM Algorithm

#341 · Machine Learning · Medium

Problem

Implement a Gaussian Mixture Model (GMM) fitted with the Expectation-Maximization (EM) algorithm. Given data points, fit a mixture of K Gaussians by iteratively computing responsibilities (E-step) and updating parameters (M-step).
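In symbols, the model and the two steps can be written as below. The notation is an assumption introduced here for illustration (it does not appear in the problem statement): \pi_j, \mu_j and \sigma_j^2 are the weight, mean and variance of component j, \gamma_{ij} is the responsibility of component j for point x_i, and the data are one-dimensional, as in the solution.

```latex
\begin{aligned}
p(x_i) &= \sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma_j^2) \\[4pt]
\gamma_{ij} &= \frac{\pi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma_j^2)}
                   {\sum_{l=1}^{K} \pi_l \, \mathcal{N}(x_i \mid \mu_l, \sigma_l^2)}
  && \text{(E-step)} \\[4pt]
N_j &= \sum_{i=1}^{n} \gamma_{ij}, \qquad
\mu_j = \frac{1}{N_j} \sum_{i=1}^{n} \gamma_{ij} \, x_i
  && \text{(M-step)} \\[4pt]
\sigma_j^2 &= \frac{1}{N_j} \sum_{i=1}^{n} \gamma_{ij} \, (x_i - \mu_j)^2, \qquad
\pi_j = \frac{N_j}{n}
  && \text{(M-step)} \\[4pt]
\log L &= \sum_{i=1}^{n} \log \sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma_j^2)
\end{aligned}
```

The variance update uses the freshly updated mean \mu_j, and \log L is the quantity monitored for convergence.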

Solution

import numpy as np
from typing import Dict, Tuple

def gaussian_pdf(x: np.ndarray, mean: float, var: float) -> np.ndarray:
    return (1.0 / np.sqrt(2 * np.pi * var)) * np.exp(-0.5 * (x - mean) ** 2 / var)

def gmm_em(
    data: np.ndarray,
    k: int,
    max_iter: int = 100,
    tol: float = 1e-6,
    seed: int = 42
) -> Dict:
    np.random.seed(seed)
    n = len(data)

    # Initialize parameters
    indices = np.random.choice(n, k, replace=False)
    means = data[indices].astype(float)
    variances = np.full(k, np.var(data))
    weights = np.full(k, 1.0 / k)

    log_likelihood = -np.inf

    for iteration in range(max_iter):
        # E-step: compute responsibilities
        resp = np.zeros((n, k))
        for j in range(k):
            resp[:, j] = weights[j] * gaussian_pdf(data, means[j], variances[j])

        row_sums = resp.sum(axis=1, keepdims=True)
        row_sums = np.maximum(row_sums, 1e-300)
        resp /= row_sums

        # M-step: update parameters
        Nk = resp.sum(axis=0)
        for j in range(k):
            if Nk[j] < 1e-10:
                continue
            means[j] = np.dot(resp[:, j], data) / Nk[j]
            variances[j] = np.dot(resp[:, j], (data - means[j]) ** 2) / Nk[j]
            variances[j] = max(variances[j], 1e-6)
        weights = Nk / n

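        # Check convergence. row_sums holds the mixture density of each point
        # under the parameters used in this E-step (before the M-step update).
        new_ll = np.sum(np.log(row_sums))
        converged = abs(new_ll - log_likelihood) < tol
        # Store the latest value before breaking so the returned
        # log-likelihood is not one iteration stale.
        log_likelihood = new_ll
        if converged:
            break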

    return {
        "means": means.tolist(),
        "variances": variances.tolist(),
        "weights": weights.tolist(),
        "log_likelihood": float(log_likelihood),
        "iterations": iteration + 1
    }

Explanation

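  1. Initialize K components: means are K distinct data points drawn at random, every variance is set to the variance of the whole dataset, and weights are uniform (1/K).
  2. E-step: For each data point, compute the responsibility of each component, i.e. the posterior probability that the component generated the point. It is proportional to weight * N(x; mean, var) and is normalized so that each point's responsibilities sum to 1.
  3. M-step: Update each component's mean (responsibility-weighted average of the data), variance (responsibility-weighted average of squared deviations from the new mean, floored at 1e-6), and weight (the component's share of total responsibility, Nk / n). A component whose total responsibility is below 1e-10 keeps its previous mean and variance.
  4. Repeat until the change in log-likelihood between iterations falls below tol, or until max_iter iterations have run.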
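
The steps above can also be sketched in vectorized form, replacing the per-component loops with NumPy broadcasting. This is an illustrative variant, not the reference solution: the function name em_step, the synthetic two-cluster data, and the fixed starting means are all assumptions made for the example.

```python
import numpy as np

def em_step(data, means, variances, weights):
    """One EM iteration for a 1-D Gaussian mixture (hypothetical helper)."""
    x = data[:, None]  # shape (n, 1), broadcasts against (k,) parameter arrays

    # E-step: weight_j * N(x_i; mean_j, var_j), shape (n, k)
    dens = np.exp(-0.5 * (x - means) ** 2 / variances) / np.sqrt(2 * np.pi * variances)
    weighted = weights * dens
    row_sums = np.maximum(weighted.sum(axis=1, keepdims=True), 1e-300)
    resp = weighted / row_sums  # each row sums to 1

    # Log-likelihood of the data under the parameters passed in
    log_likelihood = float(np.sum(np.log(row_sums)))

    # M-step: responsibility-weighted statistics
    nk = resp.sum(axis=0)        # effective number of points per component
    alive = nk > 1e-10           # near-empty components keep their old values
    safe_nk = np.where(alive, nk, 1.0)
    new_means = np.where(alive, (resp * x).sum(axis=0) / safe_nk, means)
    new_vars = np.where(
        alive, (resp * (x - new_means) ** 2).sum(axis=0) / safe_nk, variances
    )
    new_vars = np.maximum(new_vars, 1e-6)  # same variance floor as the solution
    new_weights = nk / len(data)
    return new_means, new_vars, new_weights, log_likelihood

# Synthetic data (assumed for illustration): two clusters near -5 and +5
rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-5.0, 1.0, 200), rng.normal(5.0, 1.0, 200)])

means = np.array([-4.0, 4.0])        # arbitrary starting means
variances = np.full(2, data.var())   # overall variance, as in the solution
weights = np.full(2, 0.5)

log_likelihoods = []
for _ in range(30):
    means, variances, weights, ll = em_step(data, means, variances, weights)
    log_likelihoods.append(ll)
```

Tracking log_likelihoods across iterations is a useful sanity check: as long as the variance floor and the empty-component guard do not trigger, EM should never decrease it.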

Complexity

  • Time: O(I * n * K), where I is the number of iterations, n the number of data points, and K the number of components
  • Space: O(n * K) for the responsibility matrix
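
The (n, K) responsibility matrix is also where underflow can occur: for a point far from every component, all densities round to zero and the solution falls back on the 1e-300 clamp. A possible alternative with the same O(n * K) cost per iteration is to compute the E-step in the log domain with a log-sum-exp. The sketch below is an assumption-laden illustration, not part of the original solution; the function name log_responsibilities is hypothetical, and it assumes all weights are strictly positive.

```python
import numpy as np

def log_responsibilities(data, means, variances, weights):
    """Log-domain E-step for a 1-D mixture; returns (resp, log_likelihood)."""
    x = data[:, None]  # shape (n, 1)

    # log(weight_j) + log N(x_i; mean_j, var_j), shape (n, k)
    log_p = (
        np.log(weights)
        - 0.5 * np.log(2 * np.pi * variances)
        - 0.5 * (x - means) ** 2 / variances
    )

    # log-sum-exp over components, shifted by the row maximum for stability
    row_max = log_p.max(axis=1, keepdims=True)
    log_norm = row_max + np.log(np.exp(log_p - row_max).sum(axis=1, keepdims=True))

    resp = np.exp(log_p - log_norm)  # rows sum to 1
    return resp, float(log_norm.sum())

# A point at 1000 underflows the direct density but is handled here
resp, ll = log_responsibilities(
    np.array([0.0, 1.0, 1000.0]),
    means=np.array([0.0, 1.0]),
    variances=np.array([1.0, 1.0]),
    weights=np.array([0.5, 0.5]),
)
```

In this example the outlier is assigned almost entirely to the nearer component (mean 1.0) and the log-likelihood stays finite, whereas the direct computation would produce a zero row sum for that point.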