# Mixture Density Network (MDN) loss with label-collinearity control: when
# several labels are correlated, add a regularization term that penalizes the
# MDN for reproducing that collinear label variance, encouraging the network
# to model the independent residual structure instead.
import numpy as np
def mdn_collinearity_loss(
    Y: np.ndarray,
    pi: np.ndarray,
    mu: np.ndarray,
    sigma: np.ndarray,
    labels: np.ndarray,
    lam: float = 0.1,
) -> float:
    """Return the MDN negative log-likelihood plus a label-collinearity penalty.

    The loss has two parts:

    * ``nll`` -- the mean negative log-likelihood of ``Y`` under the
      per-sample Gaussian mixtures ``(pi, mu, sigma)``.  It is evaluated in
      log space (log-sum-exp over components, ``slogdet``/``solve`` instead
      of ``det``/``inv``), so targets far from every component give their
      exact loss instead of underflowing to a clamped value.
    * ``lam * penalty`` -- for every pair of distinct output columns
      ``(a, b)``, the squared product of the label correlation
      ``corr(labels[:, a], labels[:, b])`` and the correlation between the
      model's mixture means for those same two outputs.  The term is zero
      when the labels are uncorrelated.  Otherwise it grows when the
      network's predicted means reproduce the labels' collinear structure,
      which pushes the model toward independent residual structure.  It
      depends on ``pi`` and ``mu``, so it actually regularizes the model; a
      penalty built from ``labels`` alone would be a constant offset.

    Args:
        Y: Targets, shape ``(n, d)``.
        pi: Mixture weights, shape ``(n, K)``.  Each row is assumed to sum
            to 1 (e.g. a softmax output).
        mu: Component means, shape ``(n, K, d)``.
        sigma: Component covariance matrices, shape ``(n, K, d, d)``.
        labels: Label matrix used to measure collinearity, shape ``(m, d)``.
            A 1-D array is treated as a single column.  Its columns must line
            up with the ``d`` model outputs.  Zero-variance columns count as
            uncorrelated with everything.
        lam: Weight of the collinearity penalty.  When ``lam == 0`` the
            penalty (including the ``labels`` shape check) is skipped.

    Returns:
        The scalar loss as a Python ``float``.  It is ``inf`` if a sample
        has zero probability under its mixture (e.g. all weights zero).

    Raises:
        ValueError: If ``Y`` is not 2-D or has no rows, or if ``labels``
            does not have ``d`` columns while ``lam != 0``.
        numpy.linalg.LinAlgError: If any covariance matrix has a
            non-positive determinant (this includes singular matrices).
    """
    Y = np.asarray(Y, dtype=float)
    pi = np.asarray(pi, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)

    if Y.ndim != 2:
        raise ValueError(f"Y must be 2-D (n, d); got shape {Y.shape}")
    if Y.shape[0] == 0:
        raise ValueError("Y must contain at least one sample")

    nll = _mdn_nll(Y, pi, mu, sigma)
    if lam == 0:
        return float(nll)
    penalty = _collinearity_penalty(pi, mu, np.asarray(labels, dtype=float))
    return float(nll + lam * penalty)


def _mdn_nll(
    Y: np.ndarray, pi: np.ndarray, mu: np.ndarray, sigma: np.ndarray
) -> float:
    """Mean negative log-likelihood of Y under per-sample Gaussian mixtures."""
    d = Y.shape[1]
    diff = Y[:, None, :] - mu  # (n, K, d)

    sign, logdet = np.linalg.slogdet(sigma)  # each (n, K)
    # A covariance must be positive definite; a non-positive determinant
    # (including singular matrices) is detectably invalid.
    if np.any(sign <= 0):
        raise np.linalg.LinAlgError(
            "covariance matrices in sigma must be positive definite "
            "(found a non-positive determinant)"
        )

    # solve(sigma, diff) is cheaper and more accurate than inv(sigma) @ diff.
    # The trailing axis makes b a stack of (d, 1) matrices, which batches the
    # same way under NumPy 1.x and 2.x.
    solved = np.linalg.solve(sigma, diff[..., None])[..., 0]
    mahalanobis = np.einsum("nkd,nkd->nk", diff, solved)
    log_gauss = -0.5 * (mahalanobis + logdet + d * np.log(2.0 * np.pi))

    # Zero mixture weights become -inf log-weights, which log-sum-exp handles.
    with np.errstate(divide="ignore"):
        log_weighted = np.log(pi) + log_gauss
    return -float(np.mean(_logsumexp_rows(log_weighted)))


def _logsumexp_rows(a: np.ndarray) -> np.ndarray:
    """Numerically stable log(sum(exp(a), axis=1)) for a 2-D array."""
    peak = np.max(a, axis=1, keepdims=True)
    # A row that is entirely -inf would otherwise give (-inf) - (-inf) = NaN.
    # Shifting it by 0 yields the correct -inf instead.
    peak = np.where(np.isfinite(peak), peak, 0.0)
    with np.errstate(divide="ignore"):
        return peak[:, 0] + np.log(np.sum(np.exp(a - peak), axis=1))


def _collinearity_penalty(
    pi: np.ndarray, mu: np.ndarray, labels: np.ndarray
) -> float:
    """Sum of squared (label corr * predicted-mean corr) over off-diagonal pairs."""
    if labels.ndim == 1:
        labels = labels[:, None]
    # Mixture mean E[y | x] = sum_k pi_k * mu_k, shape (n, d).
    mean_pred = np.einsum("nk,nkd->nd", pi, mu)
    if labels.ndim != 2 or labels.shape[1] != mean_pred.shape[1]:
        raise ValueError(
            "labels must have one column per model output: expected "
            f"{mean_pred.shape[1]} columns, got shape {labels.shape}"
        )

    label_corr = _safe_corrcoef(labels)
    pred_corr = _safe_corrcoef(mean_pred)
    # A single output has no off-diagonal pairs, so the penalty is 0.
    off_diag = ~np.eye(label_corr.shape[0], dtype=bool)
    return float(np.sum((label_corr * pred_corr)[off_diag] ** 2))


def _safe_corrcoef(x: np.ndarray) -> np.ndarray:
    """Column correlation matrix of a 2-D array, with constant columns giving 0.

    Unlike ``np.corrcoef``, this always returns a square ``(p, p)`` matrix
    (never a 0-d scalar) and never produces NaN.  Columns with no spread
    beyond rounding noise get correlation 0 with every column, including
    themselves on the diagonal.  Fewer than two rows yields all zeros.
    """
    n_rows, n_cols = x.shape
    if n_rows < 2:
        return np.zeros((n_cols, n_cols))
    centered = x - x.mean(axis=0, keepdims=True)
    norms = np.sqrt(np.sum(centered * centered, axis=0))
    # Treat columns whose spread is at floating-point noise level as constant,
    # so that noise does not turn into spurious +/-1 correlations.
    tol = np.finfo(np.float64).eps * n_rows * np.max(np.abs(x), axis=0)
    varying = norms > tol
    unit = np.where(varying, centered / np.where(varying, norms, 1.0), 0.0)
    return np.clip(unit.T @ unit, -1.0, 1.0)