Mutual Information

#204 · Information Theory · Medium

Problem

Compute the Mutual Information I(X; Y) between two discrete random variables from their joint and marginal distributions. Mutual information measures how much knowing one variable reduces uncertainty about the other.
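One way to see the "reduces uncertainty" reading: I(X; Y) = H(X) - H(X | Y) = H(X) + H(Y) - H(X, Y), where H is Shannon entropy. The sketch below computes MI through this identity. It assumes a joint table whose rows index X and whose columns index Y. `entropy` and `mi_via_entropies` are hypothetical helper names used only for illustration, and results are in nats (natural log).

```python
import numpy as np

def entropy(p) -> float:
    """Shannon entropy in nats, treating 0 * log(0) as 0."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]                      # drop zero-probability outcomes
    return float(-np.sum(p * np.log(p)))

def mi_via_entropies(joint) -> float:
    """Hypothetical helper: I(X; Y) = H(X) + H(Y) - H(X, Y)."""
    joint = np.asarray(joint, dtype=float)
    joint = joint / joint.sum()       # normalize to a distribution
    h_x = entropy(joint.sum(axis=1))  # uncertainty about X alone
    h_y = entropy(joint.sum(axis=0))  # uncertainty about Y alone
    h_xy = entropy(joint)             # uncertainty about the pair
    return h_x + h_y - h_xy

# X is a fair coin and Y copies X exactly, so knowing Y removes all
# uncertainty about X: I(X; Y) = H(X) = ln 2.
copy = np.array([[0.5, 0.0],
                 [0.0, 0.5]])
print(mi_via_entropies(copy))  # about 0.6931 (= ln 2)
```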

Solution

import numpy as np

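def mutual_information(joint_probs: np.ndarray) -> float:
    """Return I(X; Y) in nats. Rows of joint_probs index X, columns index Y."""
    joint = np.array(joint_probs, dtype=float)
    if joint.ndim != 2 or np.any(joint < 0) or joint.sum() <= 0:
        raise ValueError("joint_probs must be a non-negative 2-D table with a positive sum")
    # Normalize (this also accepts raw co-occurrence counts)
    joint = joint / joint.sum()
    # Marginals
    p_x = joint.sum(axis=1)  # P(X): sum each row over Y
    p_y = joint.sum(axis=0)  # P(Y): sum each column over X

    mi = 0.0
    for i in range(joint.shape[0]):
        for j in range(joint.shape[1]):
            # Zero cells contribute nothing (0 * log 0 is taken as 0),
            # so skipping them also avoids log(0) and division by zero.
            if joint[i, j] > 0 and p_x[i] > 0 and p_y[j] > 0:
                mi += joint[i, j] * np.log(
                    joint[i, j] / (p_x[i] * p_y[j])
                )
    return float(mi)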

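def mutual_information_from_samples(x: list, y: list) -> float:
    x = np.array(x)
    y = np.array(y)
    if len(x) != len(y) or len(x) == 0:
        # zip() below would otherwise silently truncate mismatched inputs
        raise ValueError("x and y must be non-empty and of equal length")
    x_vals = np.unique(x)
    y_vals = np.unique(y)
    n = len(x)

    # Count co-occurrences of each (x, y) pair
    joint = np.zeros((len(x_vals), len(y_vals)))
    x_map = {v: i for i, v in enumerate(x_vals)}
    y_map = {v: j for j, v in enumerate(y_vals)}
    for xi, yi in zip(x, y):
        joint[x_map[xi], y_map[yi]] += 1
    joint /= n  # empirical joint distribution P(X, Y)

    return mutual_information(joint)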
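The double loop visits every cell of the table one at a time. As an alternative (not the reference solution), the same sum can be written with NumPy broadcasting and a boolean mask; `mutual_information_vectorized` is a hypothetical name, and the result is in nats like the looped version.

```python
import numpy as np

def mutual_information_vectorized(joint_probs) -> float:
    """Hypothetical vectorized variant: same quantity as the double loop (nats)."""
    joint = np.asarray(joint_probs, dtype=float)
    joint = joint / joint.sum()              # normalize to a distribution
    p_x = joint.sum(axis=1, keepdims=True)   # shape (|X|, 1)
    p_y = joint.sum(axis=0, keepdims=True)   # shape (1, |Y|)
    expected = p_x * p_y                     # P(x) P(y) for every cell via broadcasting
    nz = joint > 0                           # keep only cells with P(x, y) > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / expected[nz])))

# Y is a noisy copy of a fair bit X that flips with probability 0.1
noisy = np.array([[0.45, 0.05],
                  [0.05, 0.45]])
print(mutual_information_vectorized(noisy))  # about 0.368 nats
```

The mask plays the same role as the `if` test in the loop: zero cells are excluded before `np.log` is evaluated.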

Explanation

  1. From joint distribution: Given P(X, Y), compute marginals P(X) = sum_y P(X, Y) and P(Y) = sum_x P(X, Y).
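  2. MI formula: I(X; Y) = sum_{x,y} P(x, y) * log(P(x, y) / (P(x) * P(y))). With the natural log (np.log) the result is in nats; divide by ln 2 to convert to bits.
  3. Skip terms where P(x, y) = 0: by the convention 0 * log 0 = 0 they contribute nothing, and skipping them avoids evaluating log(0). Whenever P(x, y) > 0, both marginals are positive as well.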
  4. From samples: Build the empirical joint distribution by counting co-occurrences, then apply the same formula.
  5. MI is always non-negative, and it equals 0 if and only if X and Y are independent.
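The steps above can be traced on a tiny sample. This is a minimal sketch, assuming paired discrete samples given as Python lists; `empirical_mi` is a hypothetical helper that builds the empirical joint distribution with `collections.Counter` (step 4), derives the marginals (step 1), and sums only the terms that actually occur (steps 2 and 3). Dividing by ln 2 converts the result from nats to bits.

```python
import math
from collections import Counter

def empirical_mi(x, y) -> float:
    """Hypothetical helper: plug-in estimate of I(X; Y) in nats from paired samples."""
    if len(x) != len(y) or not x:
        raise ValueError("x and y must be non-empty and the same length")
    n = len(x)
    joint = Counter(zip(x, y))  # counts of each observed (x, y) pair
    px = Counter(x)             # marginal counts of X
    py = Counter(y)             # marginal counts of Y
    mi = 0.0
    # Only pairs that occur appear in `joint`, so every term has
    # P(x, y) > 0 and log(0) never arises.
    for (a, b), c in joint.items():
        p_ab = c / n
        mi += p_ab * math.log(p_ab / ((px[a] / n) * (py[b] / n)))
    return mi

# Y is fully determined by X: knowing Y pins X down completely.
nats = empirical_mi([0, 0, 1, 1], ["a", "a", "b", "b"])
print(nats, nats / math.log(2))  # ln 2 nats, i.e. 1.0 bit

# Every (x, y) combination appears equally often: X and Y look independent.
print(empirical_mi([0, 0, 1, 1], ["a", "b", "a", "b"]))  # 0.0
```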

Complexity

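  • Time: O(|X| * |Y|) for the joint-distribution version; O(n log n + |X| * |Y|) from samples, since np.unique sorts the n samples before the O(n) counting pass
  • Space: O(|X| * |Y|) for the joint probability table, plus O(n) for the sample arrays in the sample version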
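To make the cost breakdown concrete, here is a sketch of the table-building step that avoids the Python dictionaries. It assumes 1-D inputs of equal length, and `joint_table_from_samples` is a hypothetical name. `np.unique(..., return_inverse=True)` sorts the samples (the O(n log n) part) and returns each sample's index into the sorted unique values; `np.add.at` then accumulates counts into the |X| x |Y| table in O(n).

```python
import numpy as np

def joint_table_from_samples(x, y) -> np.ndarray:
    """Hypothetical helper: empirical joint distribution P(X, Y) as a |X| x |Y| array."""
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != 1 or x.shape != y.shape or x.size == 0:
        raise ValueError("x and y must be non-empty 1-D arrays of equal length")
    # Sorting-based unique: O(n log n); inverse maps each sample to its value index
    x_vals, x_idx = np.unique(x, return_inverse=True)
    y_vals, y_idx = np.unique(y, return_inverse=True)
    counts = np.zeros((x_vals.size, y_vals.size))  # O(|X| * |Y|) table
    np.add.at(counts, (x_idx, y_idx), 1)           # unbuffered: repeated pairs all counted
    return counts / x.size

# Rows correspond to x in {0, 1}, columns to y in {"a", "b"}
table = joint_table_from_samples([0, 0, 1, 1, 1], ["a", "b", "b", "b", "a"])
print(table)
```

Plain fancy-index assignment (`counts[x_idx, y_idx] += 1`) would count a repeated (x, y) pair only once, which is why the unbuffered `np.add.at` is used here.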