PCA Color Augmentation

#191 · Computer Vision · Hard

Problem

Implement PCA Color Augmentation (AlexNet-style color jittering). Given an image, compute PCA on the set of RGB pixel values, then add multiples of the principal components (scaled by their eigenvalues and random factors) to every pixel.
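In symbols, let p_i and lambda_i be the eigenvectors and eigenvalues of the 3x3 covariance of the image's RGB values, and let sigma be the standard deviation of the random factors (alpha_std in the solution; the AlexNet paper uses 0.1 and computes the PCA over the whole training set rather than a single image). The shift applied to each pixel I_xy can then be written as:

```latex
\delta =
\begin{bmatrix} \mathbf{p}_1 & \mathbf{p}_2 & \mathbf{p}_3 \end{bmatrix}
\begin{bmatrix} \alpha_1 \lambda_1 \\ \alpha_2 \lambda_2 \\ \alpha_3 \lambda_3 \end{bmatrix},
\qquad \alpha_i \sim \mathcal{N}(0, \sigma^2),
\qquad I'_{xy} = I_{xy} + \delta \quad \text{for every pixel } (x, y)
```

The alpha_i are drawn once per image, so delta is a single RGB vector shared by all pixels.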

Solution

import numpy as np

def pca_color_augmentation(image: np.ndarray,
                            alpha_std: float = 0.1) -> np.ndarray:
    # image: (H, W, 3) with float values in [0, 1] or [0, 255]
    # Returns a float64 array of the same shape.
    img = image.astype(np.float64)
    pixels = img.reshape(-1, 3)  # (N, 3), N = H * W

    # Center the pixels (np.cov would also center internally)
    mean = pixels.mean(axis=0)
    centered = pixels - mean

    # 3x3 covariance matrix of the RGB channels
    cov = np.cov(centered, rowvar=False)

    # Eigendecomposition of the symmetric covariance matrix;
    # eigh returns eigenvalues in ascending order
    eigenvalues, eigenvectors = np.linalg.eigh(cov)

    # Sort by descending eigenvalue (column i of eigenvectors is p_i)
    idx = np.argsort(eigenvalues)[::-1]
    eigenvalues = eigenvalues[idx]
    eigenvectors = eigenvectors[:, idx]

    # One random factor per component: alpha_i ~ N(0, alpha_std^2)
    alphas = np.random.normal(0, alpha_std, size=3)

    # Perturbation sum_i alpha_i * lambda_i * p_i as a matrix-vector product.
    # Note: eigenvalues grow with the square of the pixel range, so for the
    # same alpha_std a [0, 255] image gets a much larger relative shift.
    delta = eigenvectors @ (alphas * eigenvalues)  # shape (3,)

    # Add the same RGB perturbation to every pixel (broadcasts over H and W)
    result = img + delta

    # Clip to the valid range, guessed from the input's maximum value
    if image.max() > 1.0:
        result = np.clip(result, 0.0, 255.0)
    else:
        result = np.clip(result, 0.0, 1.0)

    return result
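To check the behavior deterministically, it helps to fix the alphas instead of sampling them. The sketch below uses a hypothetical helper, pca_components, and a tiny synthetic image; it confirms that the shift is one global RGB vector, that its coordinates in the PCA basis are exactly alpha_i * lambda_i, and that the matrix form matches the explicit sum from step 4 of the explanation.

```python
import numpy as np

def pca_components(image: np.ndarray):
    """Hypothetical helper: RGB eigenvalues/eigenvectors, largest eigenvalue first."""
    pixels = image.reshape(-1, 3).astype(np.float64)
    vals, vecs = np.linalg.eigh(np.cov(pixels, rowvar=False))
    order = np.argsort(vals)[::-1]              # descending eigenvalues
    return vals[order], vecs[:, order]

rng = np.random.default_rng(42)
image = rng.random((4, 5, 3))                   # tiny synthetic image in [0, 1]
alphas = np.array([0.1, -0.1, 0.05])            # fixed instead of sampled, for a repeatable check

vals, vecs = pca_components(image)
delta = vecs @ (alphas * vals)                  # matrix form of sum_i alpha_i * lambda_i * p_i
loop_delta = sum(alphas[i] * vals[i] * vecs[:, i] for i in range(3))  # explicit sum
shifted = image + delta                         # broadcasting adds one RGB shift to every pixel

print(np.allclose(loop_delta, delta))                         # True: both forms agree
print(np.allclose((shifted - image).reshape(-1, 3), delta))   # True: a single global color shift
print(np.allclose(vecs.T @ delta, alphas * vals))             # True: PCA coordinates are alpha_i * lambda_i
```

Passing the alphas in explicitly separates the random draw from the deterministic math, which is what makes assertions like these possible.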

Explanation

  1. Flatten the (H, W, 3) image to an (N, 3) matrix of RGB pixel values, where N = H * W.
  2. Center the pixels and compute the 3x3 covariance matrix of the R, G, and B channels.
  3. Eigendecompose the covariance matrix: its eigenvectors p_i are the principal components and its eigenvalues lambda_i are the variances along them. Sort both by descending eigenvalue.
  4. Sample one random factor per component, alpha_i ~ N(0, alpha_std^2) (that is, with standard deviation alpha_std), and compute the perturbation delta = sum_i alpha_i * lambda_i * p_i.
  5. Add this same RGB perturbation to every pixel, then clip to [0, 255] if the input's maximum exceeds 1.0 and to [0, 1] otherwise.
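Step 4 uses lambda_i directly, and the eigenvalues of a covariance matrix grow with the square of the pixel scale. With the same alphas, an image stored in [0, 255] therefore gets a shift about 255^2 times larger than the same image stored in [0, 1], while its value range only grows 255 times, so the shift can easily exceed the whole range and be clipped away. The sketch below checks this on correlated synthetic pixels; naive_delta is a hypothetical name for the solution's formula applied to raw values.

```python
import numpy as np

def naive_delta(pixels: np.ndarray, alphas: np.ndarray) -> np.ndarray:
    # The solution's formula on raw pixel values: sum_i alpha_i * lambda_i * p_i
    vals, vecs = np.linalg.eigh(np.cov(pixels, rowvar=False))
    order = np.argsort(vals)[::-1]
    return vecs[:, order] @ (alphas * vals[order])

rng = np.random.default_rng(0)
gray = rng.random((1000, 1))
pixels01 = 0.8 * gray + 0.2 * rng.random((1000, 3))  # correlated synthetic RGB in [0, 1]
pixels255 = 255.0 * pixels01                          # the same pixels on a [0, 255] scale
alphas = np.array([0.1, -0.05, 0.02])

# The eigenvector matrix V is orthonormal, so ||V (alpha * lambda)|| = ||alpha * lambda||:
# the ratio below depends only on the eigenvalues, not on eigenvector signs.
ratio = np.linalg.norm(naive_delta(pixels255, alphas)) / np.linalg.norm(naive_delta(pixels01, alphas))
print(ratio)  # close to 255**2 = 65025, although the pixel values only grew 255-fold
```

One possible workaround, an assumption rather than part of the original solution, is to run the PCA on pixels / 255 and multiply the resulting delta by 255, so that the shift scales linearly with the pixel range.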

Complexity

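  • Time: O(H * W). Computing the channel means and the 3x3 covariance takes a constant number of passes over the N = H * W pixels, as do adding the perturbation and clipping; the 3x3 eigendecomposition is O(1).
  • Space: O(H * W) for the float64 copy of the image, the centered pixel array, and the output; the reshape itself is normally a view and does not copy.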
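The linear time bound comes mostly from forming the covariance. Written as one matrix product, that step visibly costs about 9 * N multiply-adds and always yields a 3x3 result; the sketch below checks it against np.cov with its default N - 1 normalization. The 480 x 640 size is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
pixels = rng.random((480 * 640, 3))                      # N = H * W synthetic RGB pixels

centered = pixels - pixels.mean(axis=0)                  # one O(N) pass plus an (N, 3) temporary
cov_manual = centered.T @ centered / (len(pixels) - 1)   # (3, N) @ (N, 3): about 9 * N multiply-adds

print(cov_manual.shape)                                  # (3, 3): fixed size, independent of N
print(np.allclose(cov_manual, np.cov(pixels, rowvar=False)))  # True
```

Because the matrix handed to the eigensolver is always 3x3, its cost does not grow with the image, which is why it counts as O(1).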