FP4 Quantization with Microscaling (MXFP4)

#427 · Deep Learning · Hard

Problem

Implement FP4 quantization with the Microscaling (MX) format, specifically MXFP4. In MXFP4, a shared exponent (block scale) is computed per block of values, and each individual value is quantized to a 4-bit floating-point representation (1 sign bit, 2 exponent bits, 1 mantissa bit). Implement the quantization and dequantization.
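The set of FP4 values used below follows directly from the E2M1 bit layout. The following minimal sketch decodes each 4-bit code. It assumes the sign-exponent-mantissa bit order and the exponent bias of 1 used for E2M1 in the OCP MX specification. `decode_e2m1` is a hypothetical helper for illustration and is not part of the solution.

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit E2M1 code laid out as [sign | exponent (2 bits) | mantissa]."""
    sign = -1.0 if (code >> 3) & 1 else 1.0
    exp = (code >> 1) & 0b11    # biased exponent, bias = 1
    man = code & 0b1            # single mantissa bit
    if exp == 0:
        # Subnormal: 0.m * 2**(1 - bias) = m * 0.5
        return sign * man * 0.5
    # Normal: 1.m * 2**(exp - bias)
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

# Codes 0000..0111 decode to 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0;
# setting the top (sign) bit, codes 1000..1111, gives the negatives.
for code in range(8):
    print(f"{code:04b} -> {decode_e2m1(code)}")
```

E2M1 reserves no encodings for infinity or NaN, so all 16 codes are finite values. Code 1000 decodes to -0.0.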

Solution

import math

# Positive FP4 E2M1 magnitudes. The list index equals the 3-bit code
# (2 exponent bits + 1 mantissa bit), so an even index means mantissa bit 0.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_EMAX = 2                         # largest E2M1 exponent: 6.0 = 1.5 * 2**2
SCALE_EMIN, SCALE_EMAX = -127, 127   # exponent range of the 8-bit E8M0 scale

def nearest_fp4(x: float) -> float:
    """Round x to the nearest FP4 E2M1 value.

    Magnitudes above 6.0 saturate to 6.0. Exact ties go to the value whose
    mantissa bit is 0 (round-half-to-even, as in IEEE 754 default rounding).
    """
    abs_x = abs(x)
    best = 0.0
    best_dist = float('inf')
    for i, v in enumerate(FP4_VALUES):
        d = abs(abs_x - v)
        # A strictly closer value wins; on an exact tie keep the even-index one.
        if d < best_dist or (d == best_dist and i % 2 == 0):
            best_dist = d
            best = v
    return best if x >= 0 else -best

def mxfp4_quantize(
    values: list[float],
    block_size: int = 32
) -> dict:
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    n = len(values)
    n_blocks = math.ceil(n / block_size)

    block_scales = []
    quantized = []
    dequantized = []

    for b in range(n_blocks):
        block = values[b * block_size:(b + 1) * block_size]

        # Shared exponent, following the OCP MX conversion rule:
        # floor(log2(max|v|)) - FP4_EMAX. After scaling, the block maximum
        # lies in [4, 8); anything above 6 saturates to 6 in nearest_fp4.
        max_abs = max(abs(v) for v in block)
        if max_abs == 0:
            shared_exp = 0           # all-zero block: any scale is exact
        else:
            shared_exp = math.floor(math.log2(max_abs)) - FP4_EMAX
            # Clamp to what an E8M0 scale byte can represent.
            shared_exp = max(SCALE_EMIN, min(SCALE_EMAX, shared_exp))
        scale = 2.0 ** shared_exp
        block_scales.append(scale)

        for v in block:
            q = nearest_fp4(v / scale)   # FP4 element (a 4-bit code in storage)
            quantized.append(q)
            dequantized.append(q * scale)

    mse = sum((x - d) ** 2 for x, d in zip(values, dequantized)) / n if n > 0 else 0.0

    return {
        "block_scales": [round(s, 8) for s in block_scales],
        "quantized_fp4": quantized,
        "dequantized": [round(d, 6) for d in dequantized],
        "mse": round(mse, 8),
        # Nominal FP32 -> FP4 ratio; ignores the 8-bit scale stored per block.
        "compression_ratio": round(32 / 4, 1)
    }

Explanation

  1. MXFP4 (Microscaling FP4) uses a shared block exponent plus per-element 4-bit FP values. The block scale is a power of 2.
  2. FP4 E2M1 has 1 sign bit, 2 exponent bits, 1 mantissa bit. The representable positive values are {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
  3. The block scale is computed as a power-of-2 that maps the block's max absolute value into the FP4 representable range.
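  4. Each value is divided by the block scale and rounded to the nearest FP4 value. Magnitudes beyond 6 saturate to ±6. In storage, each element is a 4-bit code, although the solution keeps the decoded FP4 value as a float for readability.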
  5. Dequantization multiplies the FP4 value by the block scale.
  6. The 8-bit shared scale is amortized over the block (32 elements in the MX specification). It adds 8 / 32 = 0.25 bits per element, for 4.25 bits per element in total.
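The explanation describes 4-bit elements and an 8-bit scale per block, while the solution keeps FP4 values as Python floats. The following sketch shows one way the bit-level storage could look. All helper names are hypothetical. The E8M0 encoding assumes the exponent bias of 127 from the OCP MX specification. Packing two codes per byte with the first code in the low nibble is an arbitrary choice made here, not something the problem prescribes.

```python
import math

# Positive E2M1 magnitudes; the index of each value is its 3-bit code.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4_to_code(q: float) -> int:
    """Map an FP4 E2M1 value to its 4-bit code (sign in bit 3)."""
    sign = 0b1000 if q < 0 else 0
    return sign | FP4_VALUES.index(abs(q))

def pack_nibbles(codes: list[int]) -> bytes:
    """Pack two 4-bit codes per byte; the first code goes in the low nibble."""
    if len(codes) % 2:
        codes = codes + [0]          # pad odd-length input with a zero code
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))

def scale_to_e8m0(scale: float) -> int:
    """Encode a power-of-two scale 2**e as the byte e + 127 (E8M0)."""
    m, e = math.frexp(scale)         # scale = m * 2**e with 0.5 <= m < 1
    exp = e - 1                      # exact powers of two have m == 0.5
    if m != 0.5 or not -127 <= exp <= 127:
        raise ValueError("scale must be a power of two in [2**-127, 2**127]")
    return exp + 127

def mx_bits_per_element(n: int, block_size: int = 32) -> float:
    """Average storage bits per element: 4-bit codes plus one 8-bit scale per block."""
    n_blocks = math.ceil(n / block_size)
    return (4 * n + 8 * n_blocks) / n

quantized = [0.5, -1.5, 6.0, 0.0]    # example FP4 values from one block
codes = [fp4_to_code(q) for q in quantized]
packed = pack_nibbles(codes)
print(codes, packed.hex(), scale_to_e8m0(0.25), mx_bits_per_element(1024))
# prints: [1, 11, 7, 0] b107 125 4.25
```

For any input whose length is a multiple of the block size, `mx_bits_per_element` returns 4.25. This matches the 0.25 bits of per-element scale overhead noted above.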

Complexity

  • Time: O(n) where n is the number of values
  • Space: O(n) for quantized and dequantized arrays