Implement PCA color augmentation (AlexNet-style color jittering). Given an RGB image, run PCA on its pixel values, then add the same random combination of the principal components to every pixel. Each component is scaled by its eigenvalue and by a random Gaussian factor.
import numpy as np
def pca_color_augmentation(image: np.ndarray,
                           alpha_std: float = 0.1,
                           *,
                           rng: "np.random.Generator | None" = None,
                           max_value: "float | None" = None) -> np.ndarray:
    """Apply AlexNet-style PCA color augmentation to an RGB image.

    PCA is computed over the image's own RGB pixel values. Three factors
    ``alpha_i ~ N(0, alpha_std)`` are sampled. The perturbation
    ``delta = sum_i alpha_i * lambda_i * p_i`` is then added to every pixel,
    where ``lambda_i`` are the eigenvalues of the 3x3 channel covariance
    and ``p_i`` are the matching eigenvectors. The result is clipped to
    ``[0, max_value]``.

    Args:
        image: Array whose last axis holds the 3 color channels, typically
            ``(H, W, 3)`` with values in ``[0, 1]`` or ``[0, 255]``.
        alpha_std: Standard deviation of the Gaussian factors ``alpha_i``.
        rng: Optional random source with a ``normal`` method, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state.
        max_value: Upper clipping bound. If ``None`` it is inferred: 255
            for integer dtypes or when any value exceeds 1.0, else 1.0.

    Returns:
        A new float64 array with the same shape as ``image``.

    Raises:
        ValueError: If the last axis of ``image`` does not have length 3.
    """
    arr = np.asarray(image)
    if arr.ndim == 0 or arr.shape[-1] != 3:
        raise ValueError(
            "expected an image whose last axis has 3 channels, "
            f"got shape {arr.shape}"
        )
    img = arr.astype(np.float64)
    pixels = img.reshape(-1, 3)  # (N, 3): one row per pixel

    sampler = np.random if rng is None else rng
    alphas = sampler.normal(0, alpha_std, size=3)

    if pixels.shape[0] < 2:
        # np.cov (ddof=1) divides by N - 1, so fewer than two pixels would
        # produce NaN everywhere; there is no color variation to follow,
        # so apply no shift.
        delta = np.zeros(3)
    else:
        # np.cov subtracts the per-channel mean itself.
        cov = np.cov(pixels, rowvar=False)
        eigenvalues, eigenvectors = np.linalg.eigh(cov)
        # Descending order so alphas[0] pairs with the dominant component.
        order = np.argsort(eigenvalues)[::-1]
        # eigh may return tiny negative values for a PSD matrix; clamp them.
        eigenvalues = np.maximum(eigenvalues[order], 0.0)
        eigenvectors = eigenvectors[:, order]
        # Columns of eigenvectors are p_i, so this is sum_i alpha_i*lambda_i*p_i.
        delta = eigenvectors @ (alphas * eigenvalues)

    # (3,) broadcasts over every leading axis, preserving the input shape.
    result = img + delta

    if max_value is None:
        looks_8bit = np.issubdtype(arr.dtype, np.integer) or (
            arr.size > 0 and arr.max() > 1.0
        )
        max_value = 255.0 if looks_8bit else 1.0
    return np.clip(result, 0.0, max_value)