Post-Training Quantization with Per-Channel Scale Factors

#426 · Deep Learning · Medium

Source: deep-ml.com

Problem

Implement post-training quantization (PTQ) with per-channel scale factors. Given an FP32 weight matrix, quantize it to INT8 using symmetric per-output-channel quantization, where each row of the matrix is one output channel. Compute the scale factor for each output channel, quantize, then dequantize and measure the quantization error (mean squared error and maximum absolute error).
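As a worked example of the arithmetic for a single channel, using the symmetric scheme the solution implements (scale s = max|w| / qmax, with qmax = 127 for INT8), take the illustrative row w = (0.3, -1.0, 0.004). These numbers are made up for the example and are not from the problem's test cases.

```latex
\begin{aligned}
s &= \frac{\max_j |w_j|}{q_{\max}} = \frac{1.0}{127} \approx 0.007874 \\
q &= \operatorname{clamp}\!\left(\operatorname{round}(w / s),\, -128,\, 127\right) = (38,\ -127,\ 1) \\
\hat{w} &= q \cdot s \approx (0.299213,\ -1.0,\ 0.007874) \\
|w - \hat{w}| &\approx (0.000787,\ 0,\ 0.003874) \le \tfrac{s}{2} \approx 0.003937
\end{aligned}
```

The last line reflects a general property of max-abs symmetric scaling: no value is clipped, so every element's error is bounded by half a quantization step.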

Solution

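```python
def per_channel_quantize(
    weights: list[list[float]],
    n_bits: int = 8
) -> dict:
    """Symmetric per-output-channel quantization of a 2-D weight matrix.

    Each row of `weights` is treated as one output channel with its own scale.
    """
    if not weights or any(len(row) == 0 for row in weights):
        raise ValueError("weights must be a non-empty matrix with non-empty rows")

    # Signed integer range, e.g. [-128, 127] for n_bits = 8
    qmin = -(1 << (n_bits - 1))
    qmax = (1 << (n_bits - 1)) - 1

    scales = []
    quantized = []
    dequantized = []
    sum_sq_error = 0.0   # running sum of squared reconstruction errors
    max_error = 0.0      # largest absolute reconstruction error
    total_elements = 0

    for row in weights:
        # Symmetric scale: the largest |w| in this channel maps to qmax.
        # An all-zero row gets scale 1.0 to avoid division by zero.
        max_abs = max(abs(v) for v in row)
        scale = max_abs / qmax if max_abs > 0 else 1.0
        scales.append(scale)

        q_row = []
        dq_row = []
        for v in row:
            q = round(v / scale)           # Python round(): exact ties go to the even integer
            q = max(qmin, min(qmax, q))    # clamp to the representable range
            dq = q * scale                 # dequantize back to float
            q_row.append(q)
            dq_row.append(dq)

            # Accumulate error statistics in the same pass (works for ragged rows too)
            err = abs(v - dq)
            sum_sq_error += err * err
            max_error = max(max_error, err)
            total_elements += 1
        quantized.append(q_row)
        dequantized.append(dq_row)

    mse = sum_sq_error / total_elements

    return {
        "scales": [round(s, 8) for s in scales],
        "quantized": quantized,
        "dequantized": dequantized,
        "mse": round(mse, 8),
        "max_error": round(max_error, 8),
    }
```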
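The solution uses Python's built-in `round()`, which rounds exact ties to the nearest even integer (`round(2.5) == 2`). C's `round()` and some quantization kernels round ties away from zero instead, so results can differ by one integer step on exact ties. If you need to match such a reference, a helper like the hypothetical `round_half_away` below could stand in for `round(v / scale)`; it is a sketch, not something the problem requires.

```python
import math


def round_half_away(x: float) -> int:
    """Round to the nearest integer, sending exact .5 ties away from zero.

    Hypothetical helper, not part of the original solution.
    """
    frac, whole = math.modf(x)   # both parts carry the sign of x; the split is exact
    if abs(frac) == 0.5:         # exact tie: step one unit away from zero
        return int(whole + math.copysign(1.0, x))
    return round(x)              # not a tie: the nearest integer is unique


# Compare with the built-in round-half-to-even behaviour.
for x in (0.5, 1.5, 2.5, -2.5, 2.4, -2.6):
    print(f"{x:5}: round={round(x):3d}  half_away={round_half_away(x):3d}")
```

For these inputs the two functions disagree only on the ties 0.5, 2.5 and -2.5. In the solution's inner loop the swap would read `q = round_half_away(v / scale)`; exact ties are uncommon with real-valued weights, so the two choices usually produce the same INT8 values.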

Explanation

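  1. Per-channel quantization computes a separate scale factor for each output channel (each row of the weight matrix). A channel with small weights therefore gets a fine step size of its own instead of inheriting the coarse step set by the largest weight in the whole tensor, which is what per-tensor quantization does.
  2. For symmetric quantization, scale = max(|w|) / qmax for each channel, with qmax = 2^(n_bits - 1) - 1 = 127 for INT8. This maps [-max(|w|), max(|w|)] onto [-127, 127]; the code -128 is never produced, and the clamp to [qmin, qmax] = [-128, 127] is only a safeguard. An all-zero channel is given scale 1.0 to avoid dividing by zero.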
  3. Quantize: q = clamp(round(w / scale), qmin, qmax).
  4. Dequantize: w_approx = q * scale.
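  5. The MSE and the maximum absolute error between the original and dequantized weights measure quantization quality. Because nothing is clipped, each element's error is at most scale / 2 for its channel. Per-channel scales shrink this bound for every channel narrower than the widest one, which is why they typically give much lower error than a single per-tensor scale when channels have very different value ranges.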
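To see the per-channel advantage concretely, the sketch below compares one shared per-tensor scale against per-row scales on a small made-up matrix whose rows differ in magnitude by about 100x. The helpers `quantize_rows` and `mse` are hypothetical and follow the same symmetric max-abs scheme as the solution; the rows are chosen nonzero so no zero-scale guard is needed.

```python
def quantize_rows(weights, scales, n_bits=8):
    """Symmetric quantize-then-dequantize, using scales[i] for row i."""
    qmin, qmax = -(1 << (n_bits - 1)), (1 << (n_bits - 1)) - 1
    return [
        [max(qmin, min(qmax, round(v / s))) * s for v in row]
        for row, s in zip(weights, scales)
    ]


def mse(a, b):
    """Mean squared error between two equally shaped matrices."""
    diffs = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(diffs) / len(diffs)


# Made-up weights: row 0 spans about +-1.0, row 1 only about +-0.01.
weights = [
    [0.9, -1.0, 0.33, 0.05],
    [0.004, -0.01, 0.0031, 0.0077],
]
qmax = 127

# Per-tensor: a single scale derived from the largest |w| in the whole matrix.
global_max = max(abs(v) for row in weights for v in row)
per_tensor = quantize_rows(weights, [global_max / qmax] * len(weights))

# Per-channel: one scale per row (output channel), as in the solution.
per_channel = quantize_rows(weights, [max(abs(v) for v in row) / qmax for row in weights])

print("per-tensor MSE: ", mse(weights, per_tensor))
print("per-channel MSE:", mse(weights, per_channel))
```

Here the global maximum comes from row 0, so both schemes quantize row 0 identically and the whole gap comes from row 1: with the shared scale its values collapse onto the integer levels -1, 0 and 1, while its own scale spreads them across [-127, 127].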

Complexity

  • Time: O(m * n) where m is the number of output channels and n is the input dimension
  • Space: O(m * n) for the quantized and dequantized matrices