Implement the core residualization step for a Mixture Density Network (MDN). Given input features X and metadata to control for, fit a linear (least-squares) residualization that removes metadata-correlated variance from these intermediate representations, then model the residual with a mixture of Gaussians.
import numpy as np
def mdn_residualize(
    X: np.ndarray,
    metadata: np.ndarray,
    n_components: int = 3,
    *,
    n_iter: int = 50,
    reg_covar: float = 1e-6,
    random_state=None,
) -> dict:
    """Regress metadata out of ``X`` and fit a Gaussian mixture to the residual.

    Step 1 solves ``X ~= [metadata, 1] @ beta`` by least squares and keeps the
    residual, which is orthogonal to every metadata column and the intercept.
    Step 2 fits a full-covariance Gaussian mixture model (GMM) to that
    residual with ``n_iter`` rounds of expectation-maximisation (EM).

    Args:
        X: Features, shape ``(n,)`` or ``(n, d)``. A 1-D array is treated as
            a single feature column.
        metadata: Covariates to control for, shape ``(n,)`` or ``(n, p)``.
        n_components: Number of mixture components K; must be in ``[1, n]``.
        n_iter: Number of EM iterations (fixed count, no early stopping).
        reg_covar: Value added to each covariance diagonal after every
            M-step so the matrices stay positive definite.
        random_state: Seed or ``np.random.Generator`` used to pick the
            initial means. ``None`` draws from the global ``np.random`` state
            (the historical behaviour).

    Returns:
        dict with keys
            ``beta``     -- ``(p + 1, d)`` coefficients; the last row is the intercept,
            ``residual`` -- ``(n, d)`` ``X`` minus its metadata fit,
            ``pi``       -- ``(K,)`` mixture weights,
            ``mu``       -- ``(K, d)`` component means,
            ``sigma``    -- ``(K, d, d)`` component covariances.

    Raises:
        ValueError: If ``n_components`` is not in ``[1, n]``.
    """
    beta, residual = _residualize(X, metadata)
    pi, mu, sigma = _fit_gmm(residual, n_components, n_iter, reg_covar, random_state)
    return {"beta": beta, "residual": residual, "pi": pi, "mu": mu, "sigma": sigma}


def _residualize(X, metadata):
    """Return ``(beta, residual)`` from an OLS fit of ``X`` on ``[metadata, 1]``."""
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    M = np.asarray(metadata, dtype=float)
    if M.ndim == 1:
        M = M.reshape(-1, 1)
    M_aug = np.hstack([M, np.ones((M.shape[0], 1))])  # append intercept column
    beta, *_ = np.linalg.lstsq(M_aug, X, rcond=None)
    return beta, X - M_aug @ beta


def _log_gaussian(points, mean, cov):
    """Log-density of N(mean, cov) at each row of ``points``."""
    d = points.shape[1]
    diff = points - mean
    # slogdet avoids det() over/underflow; solve avoids forming an explicit inverse.
    _, logdet = np.linalg.slogdet(cov)
    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)


def _fit_gmm(residual, n_components, n_iter, reg_covar, random_state):
    """Fit a full-covariance GMM to ``residual`` by EM; return ``(pi, mu, sigma)``."""
    n, d = residual.shape
    if not 1 <= n_components <= n:
        raise ValueError(f"n_components must be in [1, {n}], got {n_components}")

    if random_state is None:
        indices = np.random.choice(n, n_components, replace=False)
    else:
        rng = np.random.default_rng(random_state)
        indices = rng.choice(n, n_components, replace=False)

    pi = np.full(n_components, 1.0 / n_components)
    mu = residual[indices].copy()
    sigma = np.tile(np.eye(d), (n_components, 1, 1))
    tiny = np.finfo(float).tiny  # floor for log(pi) when a component empties

    for _ in range(n_iter):
        # E-step in log space: exp() of the Mahalanobis term underflows for
        # distant points or large d, which would zero whole rows of resp.
        log_prob = np.column_stack(
            [
                np.log(max(pi[k], tiny)) + _log_gaussian(residual, mu[k], sigma[k])
                for k in range(n_components)
            ]
        )
        log_norm = log_prob.max(axis=1, keepdims=True)
        log_norm += np.log(np.exp(log_prob - log_norm).sum(axis=1, keepdims=True))
        resp = np.exp(log_prob - log_norm)

        # M-step: responsibility-weighted weights, means and covariances.
        Nk = resp.sum(axis=0)
        pi = Nk / n
        mu = (resp.T @ residual) / (Nk[:, None] + 1e-10)
        for k in range(n_components):
            diff = residual - mu[k]
            sigma[k] = (diff.T @ (diff * resp[:, k:k + 1])) / (Nk[k] + 1e-10)
            sigma[k] += np.eye(d) * reg_covar
    return pi, mu, sigma