#417 · Inference · Medium
⊣ Solve on deep-ml.com
Classify whether the prefill phase and the decode phase of LLM inference are compute-bound or memory-bound. The prefill phase processes all prompt tokens in parallel (large batched matrix multiplies), while the decode phase generates one token at a time (matrix-vector operations). Given a model configuration and hardware specs, compute the arithmetic intensity of each phase and classify each phase as compute-bound or memory-bound.
def _linear_stack_intensity(
    tokens: int,
    d_model: int,
    n_layers: int,
    dtype_bytes: int,
) -> float:
    """Return FLOPs per byte moved when `tokens` tokens pass through the layers.

    Uses the approximation that each layer holds 12 * d_model^2 weights and
    that every weight costs 2 FLOPs (one multiply, one add) per token.
    Returns 0 when no bytes are moved (e.g. n_layers == 0).
    """
    # The multiplication order mirrors the original expressions so results
    # stay bit-identical even if a caller passes floats.
    flops_per_layer = 2 * tokens * 12 * d_model * d_model
    # Weight bytes plus activation bytes; the factor 2 on the activation term
    # presumably covers one read and one write of the (tokens, d_model)
    # activations -- TODO confirm.
    bytes_per_layer = (
        12 * d_model * d_model * dtype_bytes
        + 2 * tokens * d_model * dtype_bytes
    )
    total_flops = flops_per_layer * n_layers
    total_bytes = bytes_per_layer * n_layers
    return total_flops / total_bytes if total_bytes > 0 else 0


def classify_prefill_vs_decode(
    d_model: int,
    n_layers: int,
    seq_len: int,
    peak_flops: float,
    peak_bandwidth: float,
    dtype_bytes: int = 2
) -> dict:
    """Classify the prefill and decode phases as compute- or memory-bound.

    The arithmetic intensity (FLOPs per byte) of each phase is compared with
    the hardware ridge point ``peak_flops / peak_bandwidth``. A phase whose
    intensity is greater than or equal to the ridge point is reported as
    "compute-bound", otherwise as "memory-bound".

    Args:
        d_model: Hidden size of the model.
        n_layers: Number of transformer layers.
        seq_len: Number of prompt tokens processed together during prefill.
        peak_flops: Peak hardware throughput, in FLOPs per unit time.
        peak_bandwidth: Peak memory bandwidth, in bytes per the same unit
            of time.
        dtype_bytes: Bytes per stored element (defaults to 2).

    Returns:
        A dict with the keys ``"prefill"`` and ``"decode"`` (each holding
        ``"arithmetic_intensity"`` rounded to 4 decimals and a
        ``"classification"`` string) and ``"ridge_point"`` rounded to
        4 decimals.

    Raises:
        ZeroDivisionError: If ``peak_bandwidth`` is zero.
    """
    ridge_point = peak_flops / peak_bandwidth

    def _report(intensity: float) -> dict:
        # Classification uses the unrounded intensity; rounding is only for
        # the reported value.
        return {
            "arithmetic_intensity": round(intensity, 4),
            "classification": (
                "compute-bound" if intensity >= ridge_point else "memory-bound"
            ),
        }

    # Prefill: batched matmul (seq_len, d_model) x (d_model, d_model).
    # FLOPs for one linear ~ 2 * seq_len * d_model * d_model. Each transformer
    # layer has ~4 such linear ops (Q, K, V, O) + 2 FFN, roughly
    # 12 * d_model^2 parameters per layer.
    # Prefill processes all seq_len tokens at once, making it a batched GEMM.
    # The arithmetic intensity scales with seq_len, pushing it toward
    # compute-bound territory.
    prefill_ai = _linear_stack_intensity(
        seq_len, d_model, n_layers, dtype_bytes
    )

    # Decode: matrix-vector (1, d_model) x (d_model, d_model); a single token
    # is pushed through the same weights.
    decode_ai = _linear_stack_intensity(1, d_model, n_layers, dtype_bytes)

    return {
        "prefill": _report(prefill_ai),
        "decode": _report(decode_ai),
        "ridge_point": round(ridge_point, 4),
    }