Classify LLM Prefill vs Decode as Compute-Bound or Memory-Bound

#417 · Inference · Medium

Solve on deep-ml.com

Problem

Classify whether the prefill and decode phases of LLM inference are compute-bound or memory-bound. The prefill phase processes all prompt tokens in parallel (large matrix-matrix multiplies), while the decode phase generates one token at a time (matrix-vector products). Given the model configuration and hardware specifications, compute the arithmetic intensity of each phase and classify it against the hardware's ridge point.
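The classification rests on the roofline model: a kernel with arithmetic intensity I (FLOPs per byte of memory traffic) can sustain at most min(peak_flops, I × peak_bandwidth), and the ridge point is where the two limits meet. A minimal sketch; the helper names `attainable_flops` and `bound_by` are illustrative, not part of the problem, and the A100-like hardware figures are assumptions.

```python
def attainable_flops(intensity: float, peak_flops: float, peak_bandwidth: float) -> float:
    """Roofline bound on sustained FLOP/s for a given arithmetic intensity."""
    # Memory can feed at most intensity * peak_bandwidth FLOP/s;
    # the compute units cap throughput at peak_flops.
    return min(peak_flops, intensity * peak_bandwidth)


def bound_by(intensity: float, peak_flops: float, peak_bandwidth: float) -> str:
    """Classify against the ridge point, where the two limits meet."""
    ridge_point = peak_flops / peak_bandwidth
    return "compute-bound" if intensity >= ridge_point else "memory-bound"


# Assumed A100-like figures: 312 TFLOP/s FP16, 2 TB/s memory bandwidth.
PEAK_FLOPS, PEAK_BW = 312e12, 2e12
print(bound_by(1.0, PEAK_FLOPS, PEAK_BW))    # memory-bound
print(bound_by(500.0, PEAK_FLOPS, PEAK_BW))  # compute-bound
print(attainable_flops(1.0, PEAK_FLOPS, PEAK_BW))  # bandwidth caps this at 2e12 FLOP/s
```

Below the ridge point, doubling bandwidth doubles attainable throughput; above it, only more compute helps.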

Solution

def classify_prefill_vs_decode(
    d_model: int,
    n_layers: int,
    seq_len: int,
    peak_flops: float,
    peak_bandwidth: float,
    dtype_bytes: int = 2
) -> dict:
    ridge_point = peak_flops / peak_bandwidth

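    # Prefill: batched matmul (seq_len, d_model) x (d_model, d_model).
    # One d_model x d_model linear costs ~2 * seq_len * d_model^2 FLOPs.
    # A standard layer holds ~12 * d_model^2 weights: 4 * d_model^2 for the
    # Q, K, V, O projections plus 8 * d_model^2 for the two FFN matrices
    # (assuming a 4x hidden expansion). Attention-score FLOPs and KV-cache
    # traffic are ignored.
    prefill_flops_per_layer = 2 * seq_len * 12 * d_model * d_model
    # Bytes moved: all weights once, plus roughly one input and one output
    # activation of shape (seq_len, d_model).
    prefill_mem_per_layer = (
        12 * d_model * d_model * dtype_bytes
        + 2 * seq_len * d_model * dtype_bytes
    )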
    prefill_total_flops = prefill_flops_per_layer * n_layers
    prefill_total_mem = prefill_mem_per_layer * n_layers
    prefill_ai = prefill_total_flops / prefill_total_mem if prefill_total_mem > 0 else 0

    # Decode: matrix-vector (1, d_model) x (d_model, d_model). The same
    # 12 * d_model^2 weights are read, but they serve only a single token.
    decode_flops_per_layer = 2 * 1 * 12 * d_model * d_model
    decode_mem_per_layer = (
        12 * d_model * d_model * dtype_bytes
        + 2 * 1 * d_model * dtype_bytes
    )
    decode_total_flops = decode_flops_per_layer * n_layers
    decode_total_mem = decode_mem_per_layer * n_layers
    decode_ai = decode_total_flops / decode_total_mem if decode_total_mem > 0 else 0

    return {
        "prefill": {
            "arithmetic_intensity": round(prefill_ai, 4),
            "classification": "compute-bound" if prefill_ai >= ridge_point else "memory-bound"
        },
        "decode": {
            "arithmetic_intensity": round(decode_ai, 4),
            "classification": "compute-bound" if decode_ai >= ridge_point else "memory-bound"
        },
        "ridge_point": round(ridge_point, 4)
    }
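Dividing the per-layer FLOP and byte counts used in the code gives a closed form for each phase (n_layers cancels). Writing d = d_model, s = seq_len and b = dtype_bytes:

```latex
\[
I_{\text{prefill}}
  = \frac{2 \cdot s \cdot 12 d^{2}}{12 d^{2} b + 2 s d b}
  = \frac{12\, s\, d}{b\,(6d + s)}
  \approx \frac{2s}{b} \quad (s \ll 6d),
\qquad
I_{\text{decode}} = I_{\text{prefill}}\big|_{s=1}
  = \frac{12\, d}{b\,(6d + 1)}
  \approx \frac{2}{b}.
\]
```

In FP16/BF16 (b = 2), prefill intensity is therefore about seq_len FLOP/byte and decode intensity about 1 FLOP/byte, nearly independent of d_model under this cost model.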

Explanation

  1. Prefill processes seq_len tokens at once, so each linear layer is a batched GEMM. Every weight loaded from memory is reused across all seq_len tokens, so arithmetic intensity grows with seq_len and pushes prefill toward compute-bound territory.
  2. Decode processes one token at a time, so each linear layer is a matrix-vector product. Every weight must still be loaded from memory but takes part in only one multiply-add (2 FLOPs), so arithmetic intensity is low (about 2 / dtype_bytes FLOP/byte) and decode is typically memory-bound.
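  3. The ridge point (peak FLOP/s / peak bandwidth) marks the crossover: a phase whose arithmetic intensity is at or above it is compute-bound, and one below it is memory-bound. For an A100 80 GB (312 TFLOP/s dense FP16 Tensor Core throughput, about 2 TB/s HBM bandwidth), the ridge point is 312e12 / 2e12 ≈ 156 FLOP/byte.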
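  4. In FP16, prefill intensity is roughly seq_len FLOP/byte while seq_len is much smaller than d_model, so prompts of a few hundred tokens already exceed this ridge point; decode at batch size 1 sits near 1 FLOP/byte, far below it.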
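Solving prefill intensity ≥ ridge_point for seq_len gives the shortest prompt that makes prefill compute-bound, under the same simplified cost model as the solution (weight matmuls only, no attention-score FLOPs or KV-cache traffic). A sketch; `prefill_intensity` and `min_compute_bound_seq_len` are hypothetical helper names, and the A100-like ridge point and d_model = 4096 are assumed example values.

```python
import math
from typing import Optional


def prefill_intensity(seq_len: int, d_model: int, dtype_bytes: int = 2) -> float:
    # Closed form of the solution's ratio:
    # (2 * s * 12 d^2) / (12 d^2 b + 2 s d b) = 12 s d / (b (6d + s))
    return 12 * seq_len * d_model / (dtype_bytes * (6 * d_model + seq_len))


def min_compute_bound_seq_len(
    d_model: int, ridge_point: float, dtype_bytes: int = 2
) -> Optional[int]:
    # 12 s d >= R b (6d + s)  <=>  s (12 d - R b) >= 6 R b d
    denom = 12 * d_model - ridge_point * dtype_bytes
    if denom <= 0:
        # Intensity saturates at 12 d / b, which never reaches the ridge.
        return None
    s = max(1, math.ceil(6 * ridge_point * dtype_bytes * d_model / denom))
    # Nudge s to absorb floating-point rounding at the boundary.
    while prefill_intensity(s, d_model, dtype_bytes) < ridge_point:
        s += 1
    while s > 1 and prefill_intensity(s - 1, d_model, dtype_bytes) >= ridge_point:
        s -= 1
    return s


ridge = 312e12 / 2e12  # assumed A100-like figures: 156 FLOP/byte
print(min_compute_bound_seq_len(d_model=4096, ridge_point=ridge))  # 157
```

With these assumed numbers the crossover lands just above the ridge point itself, matching the rule of thumb that FP16 prefill intensity is close to seq_len for short prompts.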

Complexity

  • Time: O(1)
  • Space: O(1)