Implement Core MDN Residualization

#358 · Deep Learning · Hard

Problem

Implement the core residualization step for a Mixture Density Network (MDN). Given input features X and metadata to control for, fit a linear residualization step that removes the metadata-correlated variance from the features, then model the residual with a mixture of Gaussians.
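
In symbols, the two stages can be written as below. The notation is introduced here for illustration and is not part of the problem statement: $\tilde{M}$ is the metadata matrix with an appended column of ones, $\hat{B}$ the least-squares coefficients, $R$ the residual, and $K$ the number of mixture components.

```latex
\begin{aligned}
\hat{B} &= \arg\min_{B} \lVert X - \tilde{M} B \rVert_F^2,
  \qquad \tilde{M} = [\, M \;\; \mathbf{1} \,] \\
R &= X - \tilde{M}\hat{B} \\
p(r) &= \sum_{k=1}^{K} \pi_k \, \mathcal{N}(r \mid \mu_k, \Sigma_k),
  \qquad \sum_{k=1}^{K} \pi_k = 1
\end{aligned}
```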

Solution

import numpy as np

def mdn_residualize(X: np.ndarray, metadata: np.ndarray, n_components: int = 3) -> dict:
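    # Step 1: Residualize X by regressing out metadata
    # Solve X = metadata @ beta + residual via least squares
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X.reshape(-1, 1)  # treat a 1-D input as a single feature column
    M = np.asarray(metadata, dtype=float)
    if M.ndim == 1:
        M = M.reshape(-1, 1)
    M_aug = np.hstack([M, np.ones((M.shape[0], 1))])  # add intercept column
    beta, _, _, _ = np.linalg.lstsq(M_aug, X, rcond=None)
    residual = X - M_aug @ beta  # part of X not linearly explained by metadata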

    # Step 2: Fit simple GMM on residual for MDN output params
    n, d = residual.shape
    # Initialize GMM parameters
    pi = np.ones(n_components) / n_components
    indices = np.random.choice(n, n_components, replace=False)
    mu = residual[indices].copy()
    sigma = np.array([np.eye(d) for _ in range(n_components)])

    for _ in range(50):
        # E-step
        resp = np.zeros((n, n_components))
        for k in range(n_components):
            diff = residual - mu[k]
            inv_sigma = np.linalg.inv(sigma[k])
            det_sigma = np.linalg.det(sigma[k])
            exponent = -0.5 * np.sum(diff @ inv_sigma * diff, axis=1)
            resp[:, k] = pi[k] * np.exp(exponent) / (np.sqrt((2 * np.pi) ** d * det_sigma) + 1e-300)
        resp /= resp.sum(axis=1, keepdims=True) + 1e-300

        # M-step
        Nk = resp.sum(axis=0)
        pi = Nk / n
        for k in range(n_components):
            mu[k] = (resp[:, k:k+1].T @ residual) / (Nk[k] + 1e-10)
            diff = residual - mu[k]
            sigma[k] = (diff.T @ (diff * resp[:, k:k+1])) / (Nk[k] + 1e-10)
            sigma[k] += np.eye(d) * 1e-6

    return {"beta": beta, "residual": residual, "pi": pi, "mu": mu, "sigma": sigma}

Explanation

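  1. Regress X on the metadata (plus an intercept) using ordinary least squares. The residual is, by construction, linearly uncorrelated with every metadata column.
  2. Fit a Gaussian Mixture Model to the residual with EM. Its parameters (pi, mu, sigma) play the role of the MDN output parameters here; this solution does not train a neural network.
  3. The E-step computes responsibilities (soft assignments of each sample to each component); the M-step updates the mixing coefficients, means, and covariances from those responsibilities. The loop runs for a fixed 50 iterations with no convergence check, and a small ridge (1e-6 on the diagonal) keeps each covariance invertible.
  4. The function returns the regression coefficients (beta), the residual, and the GMM parameters (pi, mu, sigma).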
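
A quick way to see what step 1 guarantees is to check that the residual is orthogonal to the augmented metadata matrix. The sketch below uses a hypothetical helper `residualize` and synthetic data, both made up for illustration. It only demonstrates the linear property and says nothing about nonlinear dependence on the metadata.

```python
import numpy as np

def residualize(X, metadata):
    """Regress X on metadata (plus an intercept); return residual and design matrix."""
    M = metadata.reshape(-1, 1) if metadata.ndim == 1 else metadata
    M_aug = np.hstack([M, np.ones((M.shape[0], 1))])
    beta, *_ = np.linalg.lstsq(M_aug, X, rcond=None)
    return X - M_aug @ beta, M_aug

# Synthetic data: X depends linearly on two metadata columns plus noise
rng = np.random.default_rng(0)
meta = rng.normal(size=(200, 2))
X = meta @ rng.normal(size=(2, 3)) + rng.normal(size=(200, 3))

residual, M_aug = residualize(X, meta)

# The normal equations imply M_aug.T @ residual == 0 up to floating-point error
max_cross = np.abs(M_aug.T @ residual).max()
print("largest cross-product entry:", max_cross)
```

Because the design matrix includes a column of ones, the same check also implies that each residual column has zero mean.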

Complexity

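  • Time: O(n m (m + d)) for the least-squares fit with m metadata columns, then O(T K (n d^2 + d^3)) for EM, where K is the number of components and T is the number of EM iterations (50 here)
  • Space: O(n d + n K + K d^2) for the residual, responsibilities, and covariance matrices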
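
The d^2 and d^3 terms come from the covariance handling in the E-step. As an alternative to the explicit inverse and determinant in the solution, the sketch below computes the responsibilities in the log domain with a Cholesky factor. This is not the solution's method: `log_responsibilities` is a hypothetical helper written for illustration. It has the same asymptotic cost per component but avoids underflow when samples lie far from every mean. It assumes all mixing weights are strictly positive.

```python
import numpy as np

def log_responsibilities(residual, pi, mu, sigma):
    """Return an (n, K) array of log-responsibilities, computed stably."""
    n, d = residual.shape
    K = len(pi)
    log_prob = np.empty((n, K))
    for k in range(K):
        L = np.linalg.cholesky(sigma[k])               # sigma[k] = L @ L.T
        # General solve; a triangular solver would be cheaper but needs SciPy
        z = np.linalg.solve(L, (residual - mu[k]).T)   # whitened diffs, shape (d, n)
        maha = np.sum(z ** 2, axis=0)                  # squared Mahalanobis distances
        log_det = 2.0 * np.sum(np.log(np.diag(L)))     # log|sigma[k]|
        log_prob[:, k] = np.log(pi[k]) - 0.5 * (d * np.log(2 * np.pi) + log_det + maha)
    # Normalize across components with the log-sum-exp trick
    m = log_prob.max(axis=1, keepdims=True)
    log_norm = m + np.log(np.sum(np.exp(log_prob - m), axis=1, keepdims=True))
    return log_prob - log_norm

# Usage: samples far from both means, where naive densities would underflow to 0
rng = np.random.default_rng(0)
r = rng.normal(size=(5, 2)) * 100.0
pi = np.array([0.5, 0.5])
mu = np.zeros((2, 2))
sigma = np.array([np.eye(2), 2.0 * np.eye(2)])
resp = np.exp(log_responsibilities(r, pi, mu, sigma))
print(resp.sum(axis=1))  # each row should sum to 1 up to rounding
```

Exponentiating only after normalization keeps every row a valid probability vector, whereas dividing raw densities can produce all-zero rows.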