Implement FP4 quantization with the Microscaling (MX) format, specifically MXFP4. In MXFP4, a shared exponent (block scale) is computed per block of values, and each individual value is quantized to a 4-bit floating-point representation (1 sign bit, 2 exponent bits, 1 mantissa bit). Implement the quantization and dequantization.
import math
# Magnitudes representable in FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit),
# in ascending order.  The list index equals the 3-bit magnitude code
# (two exponent bits followed by the mantissa bit), so an even index means
# the mantissa bit is 0.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def nearest_fp4(x: float) -> float:
    """Round *x* to the nearest FP4 E2M1 value, keeping its sign.

    Magnitudes at or beyond the largest E2M1 value (6.0), including
    infinities, saturate to +/-6.0.  An input exactly halfway between two
    representable values goes to the neighbour whose mantissa bit is 0
    (round-half-to-even), so ties are not systematically pulled toward zero.

    Args:
        x: Value to quantize.

    Returns:
        The element of ``FP4_VALUES`` closest to ``abs(x)``, negated when
        ``x`` is negative.

    Raises:
        ValueError: If ``x`` is NaN; E2M1 has no NaN encoding to map it to.
    """
    if math.isnan(x):
        raise ValueError("cannot quantize NaN to FP4 E2M1")

    magnitude = abs(x)
    if magnitude >= FP4_VALUES[-1]:
        # Saturate explicitly.  This also covers infinity, whose distance to
        # every candidate is infinite and so can never win a "closest" search.
        best = FP4_VALUES[-1]
    else:
        # Primary key: distance to the candidate.  Secondary key: parity of
        # the candidate's code, so an exact tie resolves to the even code.
        code = min(
            range(len(FP4_VALUES)),
            key=lambda i: (abs(magnitude - FP4_VALUES[i]), i % 2),
        )
        best = FP4_VALUES[code]
    return best if x >= 0 else -best
def mxfp4_quantize(
    values: list[float],
    block_size: int = 32
) -> dict:
    """Quantize *values* to MXFP4 and report the round-trip result.

    The input is split into consecutive blocks of ``block_size`` elements
    (the last block may be shorter).  Each block gets one power-of-two scale
    and every element is stored as the FP4 E2M1 value nearest to
    ``value / scale``.

    The scale exponent follows the MX rule
    ``floor(log2(max_abs)) - emax_elem`` with ``emax_elem = 2`` for E2M1,
    clamped to the E8M0 exponent range [-127, 127].  The largest element of a
    block therefore scales into [4, 8); anything above 6 saturates to 6.

    Args:
        values: Finite numbers to quantize.
        block_size: Number of elements sharing one scale; must be >= 1.

    Returns:
        A dict with:
          * ``"block_scales"``: the exact power-of-two scale of each block
            (1.0 for an all-zero block);
          * ``"quantized_fp4"``: the signed E2M1 value chosen per element;
          * ``"dequantized"``: ``quantized * scale`` per element, rounded to
            6 decimals;
          * ``"mse"``: mean squared error between ``values`` and the
            unrounded dequantized values, rounded to 8 decimals (0 for
            empty input);
          * ``"compression_ratio"``: 32 / 4 = 8.0, i.e. FP32 vs FP4 element
            width; the per-block scale is not counted.

    Raises:
        ValueError: If ``block_size`` is smaller than 1 or ``values``
            contains NaN or infinity.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    if not all(math.isfinite(v) for v in values):
        raise ValueError("values must all be finite (no NaN or infinity)")

    # Largest E2M1 magnitude is 6.0 = 1.5 * 2**2, so the element emax is 2.
    elem_emax = 2
    # An E8M0 shared scale encodes exponents in this inclusive range.
    scale_exp_min, scale_exp_max = -127, 127

    n = len(values)
    block_scales = []
    quantized = []
    dequantized = []

    for start in range(0, n, block_size):
        block = values[start:start + block_size]
        max_abs = max(abs(v) for v in block)

        if max_abs == 0:
            # All-zero block: any scale works, use the neutral one.
            scale = 1.0
        else:
            # frexp gives max_abs = m * 2**e with 0.5 <= m < 1, hence
            # floor(log2(max_abs)) == e - 1 exactly, with no log rounding.
            _, exponent = math.frexp(max_abs)
            shared_exp = (exponent - 1) - elem_emax
            shared_exp = min(max(shared_exp, scale_exp_min), scale_exp_max)
            scale = math.ldexp(1.0, shared_exp)
        block_scales.append(scale)

        for v in block:
            q = nearest_fp4(v / scale)
            quantized.append(q)
            dequantized.append(q * scale)

    if n > 0:
        mse = sum((v - d) ** 2 for v, d in zip(values, dequantized)) / n
    else:
        mse = 0

    return {
        # Scales are exact powers of two; rounding them would break that.
        "block_scales": block_scales,
        "quantized_fp4": quantized,
        "dequantized": [round(d, 6) for d in dequantized],
        "mse": round(mse, 8),
        "compression_ratio": round(32 / 4, 1)  # FP32 to FP4
    }