"""Implementation of Mini-Batch K-Means.

Mini-Batch K-Means is a scalable variant of K-Means that updates cluster
centers using small random batches of data instead of the full dataset at
each iteration.
"""
import numpy as np
def mini_batch_kmeans(
    X: np.ndarray, k: int, batch_size: int = 100, n_iters: int = 100
) -> tuple[np.ndarray, np.ndarray]:
    """Cluster the rows of ``X`` into ``k`` groups with Mini-Batch K-Means.

    Centers start as ``k`` distinct rows of ``X`` drawn at random.  Each
    iteration draws a mini-batch without replacement, assigns every batch
    point to its nearest center (squared Euclidean distance, measured against
    the centers as they stood at the start of the iteration) and then moves
    the centers with the streaming-average rule::

        center = (1 - eta) * center + eta * x    where eta = 1 / count

    where ``count`` is the number of points assigned to that center so far.
    After the last iteration every row of ``X`` is labelled with its nearest
    center.

    Randomness comes from the global ``np.random`` state; call
    ``np.random.seed`` beforehand for reproducible results.

    Args:
        X: Data matrix of shape ``(n, d)``.
        k: Number of clusters, ``1 <= k <= n``.
        batch_size: Number of points sampled per iteration (capped at ``n``).
        n_iters: Number of mini-batch iterations.

    Returns:
        A ``(centers, labels)`` tuple.  ``centers`` has shape ``(k, d)`` and a
        floating dtype (the dtype of ``X`` when it is already floating,
        ``float64`` otherwise).  ``labels`` has shape ``(n,)`` and holds the
        index of the nearest final center for each row of ``X``.

    Raises:
        ValueError: If ``X`` is not two-dimensional or ``k`` is outside
            ``[1, n]``.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got {X.ndim} dimension(s)")
    n = X.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k must satisfy 1 <= k <= n (n={n}), got k={k}")

    # Centers must be floating point: with an integer dtype every averaged
    # update would be truncated when written back into the array.
    center_dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    seed_rows = np.random.choice(n, k, replace=False)
    centers = X[seed_rows].astype(center_dtype)
    counts = np.zeros(k, dtype=np.intp)

    batch_rows = min(batch_size, n)
    for _ in range(n_iters):
        sampled = np.random.choice(n, batch_rows, replace=False)
        batch = X[sampled].astype(center_dtype, copy=False)
        assignments = _nearest_center(batch, centers)

        # Aggregate the batch per center.  Applying the per-point rule
        # (eta = 1 / count) to b points in a row yields the same center as a
        # single step with eta = b / count towards the mean of those points.
        batch_counts = np.bincount(assignments, minlength=k)
        batch_sums = np.zeros_like(centers)
        np.add.at(batch_sums, assignments, batch)

        touched = batch_counts > 0  # centers that received no point stay put
        counts += batch_counts
        eta = (batch_counts[touched] / counts[touched])[:, None]
        batch_means = batch_sums[touched] / batch_counts[touched][:, None]
        centers[touched] = (1.0 - eta) * centers[touched] + eta * batch_means

    labels = _nearest_center(X, centers)
    return centers, labels


def _nearest_center(
    points: np.ndarray, centers: np.ndarray, chunk_size: int = 4096
) -> np.ndarray:
    """Return, for each row of ``points``, the index of its closest center.

    Closeness is squared Euclidean distance.  Rows are processed
    ``chunk_size`` at a time so the temporary difference array holds at most
    ``chunk_size * k * d`` elements instead of ``n * k * d``.
    """
    total = points.shape[0]
    nearest = np.empty(total, dtype=np.intp)
    for start in range(0, total, chunk_size):
        stop = start + chunk_size
        diff = points[start:stop, None, :] - centers[None, :, :]
        nearest[start:stop] = np.argmin(np.sum(diff**2, axis=2), axis=1)
    return nearest