Block-wise FP8 Quantization

#234 · Deep Learning · Medium

Problem

Implement block-wise FP8 quantization for model weights: divide a weight tensor into fixed-size blocks, compute a per-block scaling factor, and quantize each block to an FP8 representation.

Solution

def fp8_block_quantize(
    weights: list[list[float]],
    block_size: int = 32,
    fp8_max: float = 448.0,
) -> dict:
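    """
    Simulate block-wise FP8 quantization of a 2-D weight matrix.

    Returns a dict with "quantized_weights" (the dequantized values, same
    shape as the input) and "scales" (one dequantization scale per block:
    abs_max / fp8_max, or 1.0 for an all-zero block).
    """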
    rows = len(weights)
    # Guard against an empty matrix, where weights[0] would raise IndexError
    cols = len(weights[0]) if rows else 0

    # Flatten
    flat = []
    for r in range(rows):
        for c in range(cols):
            flat.append(weights[r][c])

    total = len(flat)
    num_blocks = (total + block_size - 1) // block_size

    quantized_flat = [0.0] * total
    scales = []

    for b in range(num_blocks):
        start = b * block_size
        end = min(start + block_size, total)
        block = flat[start:end]

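        # Compute per-block scale; all-zero blocks fall back to 1.0
        # to avoid dividing by zero
        abs_max = max(abs(v) for v in block)
        if abs_max < 1e-12:
            scale = 1.0
        else:
            scale = fp8_max / abs_max
        # Store the inverse (dequantization) scale, abs_max / fp8_max
        scales.append(round(1.0 / scale, 8))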

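        # Quantize: multiply by scale and clamp to the FP8 range, then
        # dequantize. No rounding to the FP8 value grid is applied here,
        # so the result differs from the input only by floating-point
        # error and the final 6-decimal rounding.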
        for i in range(len(block)):
            q = block[i] * scale
            q = max(-fp8_max, min(fp8_max, q))
            # Dequantize back
            quantized_flat[start + i] = round(q / scale, 6)

    # Reshape
    quantized = []
    idx = 0
    for r in range(rows):
        row = []
        for c in range(cols):
            row.append(quantized_flat[idx])
            idx += 1
        quantized.append(row)

    return {"quantized_weights": quantized, "scales": scales}
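
As a quick check of the scale arithmetic, here is a minimal sketch that computes only the stored per-block scales. The helper name block_dequant_scales is hypothetical and not part of the solution; it uses abs_max / fp8_max directly, which is algebraically the same as the 1.0 / scale stored above.

```python
def block_dequant_scales(
    flat: list[float],
    block_size: int = 32,
    fp8_max: float = 448.0,
) -> list[float]:
    """Hypothetical helper: one dequantization scale per block."""
    scales = []
    for start in range(0, len(flat), block_size):
        block = flat[start:start + block_size]
        abs_max = max(abs(v) for v in block)
        # All-zero blocks fall back to 1.0, mirroring the solution
        if abs_max < 1e-12:
            scales.append(1.0)
        else:
            scales.append(round(abs_max / fp8_max, 8))
    return scales


# Three blocks of size 3: abs_max is 2.0, 0.0, and 7.0 (a short last block)
example = [0.5, -2.0, 1.0, 0.0, 0.0, 0.0, 7.0]
print(block_dequant_scales(example, block_size=3))
# → [0.00446429, 1.0, 0.015625]
```

The last block holds a single element because the total length is not a multiple of the block size; it still gets its own scale.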

Explanation

  1. Flatten the weight matrix and divide into blocks of fixed size.
  2. For each block, compute the absolute maximum value.
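  3. Compute a scale factor: scale = fp8_max / abs_max, so the block's range maps to [-fp8_max, fp8_max]. All-zero blocks use scale = 1.0 to avoid division by zero.
  4. Quantize by multiplying by scale and clamping to [-fp8_max, fp8_max], then dequantize by dividing by scale. The code does not round to the FP8 value grid, so the returned weights match the inputs up to floating-point error and the 6-decimal rounding.
  5. Store the per-block inverse scales (1 / scale = abs_max / fp8_max) for later dequantization during inference.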
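
The solution clamps scaled values but never snaps them to representable FP8 numbers. The sketch below shows what that extra rounding step might look like, assuming the E4M3 layout (4 exponent bits, 3 mantissa bits, minimum normal exponent -6), which is the common format whose largest finite value is 448. The function name round_to_e4m3 and the choice of format are assumptions for illustration, not part of the original solution.

```python
import math


def round_to_e4m3(x: float, fp8_max: float = 448.0) -> float:
    """Round x to the nearest value on an assumed E4M3 grid, saturating at fp8_max."""
    if x == 0.0 or math.isnan(x):
        return x
    mag = min(abs(x), fp8_max)  # saturate instead of overflowing
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so floor(log2(mag)) = e - 1
    _, e = math.frexp(mag)
    exp = max(e - 1, -6)  # values below 2**-6 share the subnormal spacing
    step = 2.0 ** (exp - 3)  # 3 mantissa bits: 8 steps per power of two
    q = round(mag / step) * step  # Python's round() resolves ties to even
    return math.copysign(min(q, fp8_max), x)


print(round_to_e4m3(1.07))   # → 1.125
print(round_to_e4m3(-3.3))   # → -3.25
print(round_to_e4m3(500.0))  # → 448.0
```

Applying such a function to q between the clamp and the division by scale would make the round trip lossy, which is the behavior real FP8 storage would show.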

Complexity

  • Time: O(n) where n is the total number of weight elements
  • Space: O(n + n/block_size) for quantized weights and scales