Estimate the KV cache memory usage for a transformer model given its configuration: number of layers, number of KV heads, head dimension, sequence length, batch size, and data type. The KV cache stores key and value tensors for all previous tokens across all layers to avoid recomputation during autoregressive decoding.
def estimate_kv_cache_size(
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    seq_len: int,
    batch_size: int = 1,
    dtype_bytes: int = 2
) -> dict:
    # Per layer: 2 (K and V) * batch_size * n_kv_heads * seq_len * head_dim * dtype_bytes
    per_layer_bytes = 2 * batch_size * n_kv_heads * seq_len * head_dim * dtype_bytes
    total_bytes = per_layer_bytes * n_layers
    total_mb = total_bytes / (1024 ** 2)
    total_gb = total_bytes / (1024 ** 3)
    return {
        "per_layer_bytes": per_layer_bytes,
        "total_bytes": total_bytes,
        "total_mb": round(total_mb, 2),
        "total_gb": round(total_gb, 4)
    }

Each cached K and V tensor has shape (batch_size, n_kv_heads, seq_len, head_dim) per layer. n_kv_heads may differ from the number of attention heads when using Grouped-Query Attention (GQA) or Multi-Query Attention (MQA), which shrinks the KV cache in proportion to the reduction in KV heads.
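To make the effect of n_kv_heads concrete, here is a small sketch that applies the same formula to three attention layouts. The helper name kv_cache_bytes and the configuration values (32 layers, head dimension 128, 4096 tokens, 2-byte elements) are illustrative assumptions, not taken from any particular model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, dtype_bytes=2):
    # Same formula as above: 2 tensors (K and V) per layer,
    # each of shape (batch_size, n_kv_heads, seq_len, head_dim).
    per_layer = 2 * batch_size * n_kv_heads * seq_len * head_dim * dtype_bytes
    return per_layer * n_layers


# Hypothetical configuration shared by all three variants.
base = dict(n_layers=32, head_dim=128, seq_len=4096, batch_size=1, dtype_bytes=2)

# Only the number of KV heads changes between the variants.
sizes_gib = {
    name: kv_cache_bytes(n_kv_heads=heads, **base) / 1024 ** 3
    for name, heads in [("MHA", 32), ("GQA", 8), ("MQA", 1)]
}

for name, gib in sizes_gib.items():
    print(f"{name}: {gib:.4f} GiB")
# MHA: 2.0000 GiB
# GQA: 0.5000 GiB
# MQA: 0.0625 GiB
```

Because the cache size is linear in n_kv_heads, going from 32 to 8 KV heads cuts the estimate by a factor of 4, and a single KV head cuts it by a factor of 32. The same linear scaling applies to seq_len and batch_size.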