MDN with Label Collinearity Control

#360 · Deep Learning · Hard

Problem

Implement a Mixture Density Network loss with label collinearity control. When multiple labels are correlated, add a regularization term that penalizes the MDN for capturing collinear label variance, encouraging the network to model independent residual structure.
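For reference, the objective computed by the solution on this page can be written compactly. The symbols are notation assumed here rather than taken from the problem statement: n samples, K mixture components, d output dimensions, and \rho the Pearson correlation matrix of the label columns. The small 1e-300 guards used in the code are omitted.

```latex
\mathcal{L}
  = -\frac{1}{n} \sum_{i=1}^{n} \log \sum_{k=1}^{K}
      \pi_{ik}\, \mathcal{N}\!\left(y_i \mid \mu_{ik}, \Sigma_{ik}\right)
    + \lambda \sum_{a \neq b} \rho_{ab}^{2},
\qquad
\mathcal{N}(y \mid \mu, \Sigma)
  = \frac{\exp\!\left(-\tfrac{1}{2}(y-\mu)^{\top} \Sigma^{-1} (y-\mu)\right)}
         {\sqrt{(2\pi)^{d} \det \Sigma}}
```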

Solution

import numpy as np

def mdn_collinearity_loss(
    Y: np.ndarray,
    pi: np.ndarray,
    mu: np.ndarray,
    sigma: np.ndarray,
    labels: np.ndarray,
    lam: float = 0.1
) -> float:
    n = Y.shape[0]
    n_components = pi.shape[-1]

    # Negative log-likelihood of MDN
    nll = 0.0
    for i in range(n):
        prob = 0.0
        for k in range(n_components):
            diff = Y[i] - mu[i, k]
            d = len(diff)
            det_s = np.linalg.det(sigma[i, k])
            inv_s = np.linalg.inv(sigma[i, k])
            exponent = -0.5 * diff @ inv_s @ diff
            gauss = np.exp(exponent) / (np.sqrt((2 * np.pi) ** d * det_s) + 1e-300)
            prob += pi[i, k] * gauss
        nll -= np.log(prob + 1e-300)
    nll /= n

    # Collinearity penalty: measures how strongly the label columns correlate with each other
    # Compute the Pearson correlation matrix of the labels (rows = samples, columns = labels)
    label_corr = np.corrcoef(labels.T)
    # Penalty = sum of squared off-diagonal correlations scaled by lambda
    mask = ~np.eye(label_corr.shape[0], dtype=bool)
    collinearity_penalty = lam * np.sum(label_corr[mask] ** 2)

    return nll + collinearity_penalty

Explanation

  1. Compute the standard MDN negative log-likelihood: for each sample, sum the weighted Gaussian densities over all mixture components, take the negative log, and average over the n samples.
  2. Compute a collinearity penalty from the label correlation matrix: the sum of squared off-diagonal correlations measures how much redundant information the labels share.
  3. The total loss is the NLL plus lam times the penalty, with the intent of encouraging the model to focus on independent label structure. As implemented, the penalty depends only on labels, not on pi, mu, or sigma, so for a fixed batch it adds a constant offset to the loss.
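To make step 2 concrete, here is a minimal sketch of the penalty on its own. The helper name `label_collinearity_penalty` and the synthetic data are illustrative assumptions, not part of the original solution; the sketch also assumes `labels` has at least two columns.

```python
import numpy as np

def label_collinearity_penalty(labels, lam=0.1):
    """lam times the sum of squared off-diagonal label correlations.

    Hypothetical helper mirroring the penalty term in the solution.
    Assumes labels has shape (n_samples, n_labels) with n_labels >= 2.
    """
    corr = np.corrcoef(labels.T)                   # (n_labels, n_labels)
    off_diag = ~np.eye(corr.shape[0], dtype=bool)  # True everywhere except the diagonal
    return float(lam * np.sum(corr[off_diag] ** 2))

rng = np.random.default_rng(0)
x = rng.normal(size=200)

# Second column is an exact linear function of the first: correlation is 1.
collinear = np.column_stack([x, 2.0 * x + 1.0])
# Two independently drawn columns: correlation is near 0.
independent = rng.normal(size=(200, 2))

print(label_collinearity_penalty(collinear))    # approximately 0.2 (lam * 2)
print(label_collinearity_penalty(independent))  # small, close to 0
```

With two perfectly collinear labels, both off-diagonal entries equal 1, so the penalty reaches its two-label maximum of 2 * lam.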

Complexity

  • Time: O(n K d^3) for the NLL (each component needs a d×d determinant and inverse) + O(L^2 * n) for the label correlation, where L is the number of labels
  • Space: O(n K d^2) for sigma storage + O(L^2) for the correlation matrix
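The per-sample, per-component loops in the solution can also be batched. The sketch below is one possible variant, not the reference solution: it assumes shapes Y (n, d), pi (n, K), mu (n, K, d), and sigma (n, K, d, d), uses `np.linalg.solve` and `np.linalg.slogdet` in place of an explicit inverse and determinant, and works in log space with a log-sum-exp over components. The function name `mdn_nll_logspace` is hypothetical.

```python
import numpy as np

def mdn_nll_logspace(Y, pi, mu, sigma):
    """Mean NLL of a full-covariance Gaussian mixture, computed in log space.

    Assumed shapes: Y (n, d), pi (n, K), mu (n, K, d), sigma (n, K, d, d).
    Assumes every sigma[i, k] is symmetric positive definite.
    """
    d = Y.shape[1]
    diff = Y[:, None, :] - mu                              # (n, K, d)
    # Batched solve gives sigma^{-1} @ diff for every (sample, component) pair.
    sol = np.linalg.solve(sigma, diff[..., None])[..., 0]  # (n, K, d)
    maha = np.sum(diff * sol, axis=-1)                     # squared Mahalanobis distance, (n, K)
    _, logdet = np.linalg.slogdet(sigma)                   # log|sigma|, (n, K)
    log_gauss = -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)
    log_mix = np.log(pi) + log_gauss                       # log(pi_k * N_k), (n, K)
    # Log-sum-exp over components, shifted by the row maximum for stability.
    m = log_mix.max(axis=1, keepdims=True)
    log_prob = m[:, 0] + np.log(np.exp(log_mix - m).sum(axis=1))
    return float(-log_prob.mean())
```

The asymptotic cost is unchanged, since each covariance still needs a d×d factorization, but the work moves into batched LAPACK calls and the result no longer underflows when every component density is tiny.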