
Implement Mini-Batch K-Means

#363 · Machine Learning · Medium

Source: deep-ml.com

Problem

Implement Mini-Batch K-Means, a scalable variant of K-Means that updates cluster centers using small random batches of data instead of the full dataset at each iteration.

Solution

import numpy as np

def mini_batch_kmeans(
    X: np.ndarray, k: int, batch_size: int = 100, n_iters: int = 100
) -> tuple[np.ndarray, np.ndarray]:
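    n, d = X.shape
    if not 1 <= k <= n:
        raise ValueError("k must be between 1 and the number of samples")
    # Initialize centers from k distinct random samples; cast to float so the
    # streaming updates below are not truncated when X has an integer dtype
    indices = np.random.choice(n, k, replace=False)
    centers = X[indices].astype(float)
    # counts[j] = total number of points assigned to center j so far
    counts = np.zeros(k)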

    for _ in range(n_iters):
        # Sample a mini-batch
        batch_idx = np.random.choice(n, min(batch_size, n), replace=False)
        batch = X[batch_idx]

        # Assign batch points to nearest center
        dists = np.sum((batch[:, None] - centers[None, :]) ** 2, axis=2)
        assignments = np.argmin(dists, axis=1)

        # Update each center with a streaming average: eta = 1/count shrinks as
        # the center absorbs more points, so it tracks their running mean
        for i, idx in enumerate(assignments):
            counts[idx] += 1
            eta = 1.0 / counts[idx]
            centers[idx] = (1 - eta) * centers[idx] + eta * batch[i]

    # Assign all points to the final centers (broadcasting allocates an (n, k, d) array)
    dists = np.sum((X[:, None] - centers[None, :]) ** 2, axis=2)
    labels = np.argmin(dists, axis=1)
    return centers, labels

Explanation

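  1. Initialize centers by randomly selecting k distinct data points.
  2. At each iteration, sample a mini-batch of min(batch_size, n) points without replacement and assign each one to its nearest center, using the centers as they stood before this batch's updates.
  3. Update each assigned center with a streaming average: center = (1 - eta) * center + eta * x, where eta = 1 / count and count is the total number of points assigned to that center across all batches so far. Because eta shrinks as count grows, each center equals the running mean of the points it has absorbed and moves less as training proceeds.
  4. After all iterations, assign every point to its nearest final center.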
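Because eta = 1 / count, the per-point loop in step 3 has a closed form: if a batch assigns m_j points with sum S_j to center j, the new center is (count_j * center_j + S_j) / (count_j + m_j). The sketch below assumes the same `centers`, `counts`, `batch`, and `assignments` arrays as the solution, replaces the Python loop with NumPy aggregation (`np.bincount` and `np.add.at`), and checks the result against the loop on random data. The function names `update_centers_loop` and `update_centers_vectorized` are hypothetical, introduced only for this comparison.

```python
import numpy as np


def update_centers_loop(
    centers: np.ndarray, counts: np.ndarray, batch: np.ndarray, assignments: np.ndarray
) -> None:
    """Per-point streaming average, as in the solution's inner loop (in place)."""
    for i, idx in enumerate(assignments):
        counts[idx] += 1
        eta = 1.0 / counts[idx]
        centers[idx] = (1 - eta) * centers[idx] + eta * batch[i]


def update_centers_vectorized(
    centers: np.ndarray, counts: np.ndarray, batch: np.ndarray, assignments: np.ndarray
) -> None:
    """Same update in closed form: each touched center becomes
    (count * center + batch_sum) / (count + m), applied in place."""
    k = centers.shape[0]
    # m[j]: number of batch points assigned to center j
    m = np.bincount(assignments, minlength=k).astype(float)
    # sums[j]: sum of the batch points assigned to center j, shape (k, d)
    sums = np.zeros_like(centers)
    np.add.at(sums, assignments, batch)
    new_counts = counts + m
    touched = m > 0  # centers with no points in this batch stay unchanged
    centers[touched] = (
        counts[touched][:, None] * centers[touched] + sums[touched]
    ) / new_counts[touched][:, None]
    counts[:] = new_counts


rng = np.random.default_rng(0)
batch = rng.normal(size=(50, 2))
centers0 = rng.normal(size=(3, 2))
counts0 = np.array([0.0, 4.0, 10.0])  # center 0 has never been updated
assignments = rng.integers(0, 3, size=50)

c_loop, n_loop = centers0.copy(), counts0.copy()
update_centers_loop(c_loop, n_loop, batch, assignments)
c_vec, n_vec = centers0.copy(), counts0.copy()
update_centers_vectorized(c_vec, n_vec, batch, assignments)
print(np.allclose(c_loop, c_vec), np.array_equal(n_loop, n_vec))  # True True
```

The two versions agree up to floating-point rounding. When count_j is 0, the formula reduces to the mean of the batch points assigned to center j, which matches the loop: its first update uses eta = 1 and overwrites the initial center.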

Complexity

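  • Time: O(T * B * k * d) for the mini-batch iterations, plus O(n * k * d) for the final assignment, where T is the number of iterations (n_iters), B is the batch size, n is the number of points, and d is the dimension
  • Space: O(k * d) for the centers and counts, but the broadcast distance computation allocates an O(B * k * d) array per iteration and an O(n * k * d) array for the final assignment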
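The O(n * k * d) working memory in the final assignment comes from broadcasting `X[:, None] - centers[None, :]`. A common alternative, sketched here as an assumption rather than part of the reference solution, expands ||x - c||^2 = ||x||^2 - 2 x·c + ||c||^2 and processes X in chunks, so the extra memory peaks at O(chunk_size * k). The helper name `assign_labels_chunked` and its `chunk_size` parameter are hypothetical.

```python
import numpy as np


def assign_labels_chunked(
    X: np.ndarray, centers: np.ndarray, chunk_size: int = 4096
) -> np.ndarray:
    """Label each row of X with its nearest center without building an (n, k, d) array."""
    # ||c||^2 for every center, shape (k,); computed once
    c_sq = np.einsum("ij,ij->i", centers, centers)
    labels = np.empty(X.shape[0], dtype=np.intp)
    for start in range(0, X.shape[0], chunk_size):
        chunk = X[start:start + chunk_size]
        # ||x||^2 is the same for every center in a row, so it cannot change
        # the argmin and is dropped; scores has shape (len(chunk), k)
        scores = c_sq[None, :] - 2.0 * (chunk @ centers.T)
        labels[start:start + chunk_size] = np.argmin(scores, axis=1)
    return labels


rng = np.random.default_rng(1)
X = rng.normal(size=(5000, 4))
centers = rng.normal(size=(6, 4))
fast = assign_labels_chunked(X, centers, chunk_size=1024)

# Reference: the broadcast computation used in the solution
d_full = ((X[:, None] - centers[None, :]) ** 2).sum(axis=2)
brute = np.argmin(d_full, axis=1)
rows = np.arange(X.shape[0])
print(np.allclose(d_full[rows, brute], d_full[rows, fast]))  # True
```

The expanded form can lose precision when points lie far from the origin relative to their spacing, so near-ties may resolve differently than in the broadcast version. For that reason the check compares the distances to the chosen centers rather than the label indices.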