
Estimate KV Cache Size from Model Config

#418 · Inference · Medium

Solve on deep-ml.com

Problem

Estimate the KV cache memory usage for a transformer model given its configuration: number of layers, number of KV heads, head dimension, sequence length, batch size, and data type. The KV cache stores key and value tensors for all previous tokens across all layers to avoid recomputation during autoregressive decoding.
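In symbols, the estimate is given below. The notation is introduced here for compactness: L = n_layers, B = batch_size, H_kv = n_kv_heads, S = seq_len, d = head_dim, b = dtype_bytes.

```latex
M_{\mathrm{KV}} = 2 \cdot L \cdot B \cdot H_{\mathrm{kv}} \cdot S \cdot d \cdot b \quad \text{bytes}
```

The leading 2 counts the K and V tensors. Every other factor enters linearly, so doubling the context length or the batch size doubles the cache.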

Solution

def estimate_kv_cache_size(
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    seq_len: int,
    batch_size: int = 1,
    dtype_bytes: int = 2
) -> dict:
    # Per layer: 2 (K and V) * batch_size * n_kv_heads * seq_len * head_dim * dtype_bytes
    per_layer_bytes = 2 * batch_size * n_kv_heads * seq_len * head_dim * dtype_bytes
    total_bytes = per_layer_bytes * n_layers
    # Binary units: total_mb is MiB (1024**2 bytes), total_gb is GiB (1024**3 bytes)
    total_mb = total_bytes / (1024 ** 2)
    total_gb = total_bytes / (1024 ** 3)
    return {
        "per_layer_bytes": per_layer_bytes,
        "total_bytes": total_bytes,
        "total_mb": round(total_mb, 2),
        "total_gb": round(total_gb, 4)
    }
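The total is linear in seq_len and batch_size, so it is often convenient to reason about the cache cost per token. The sketch below is a standalone helper that computes that per-token cost for a few common element sizes. The name kv_bytes_per_token and the DTYPE_BYTES table are illustrative assumptions and are not part of the problem.

```python
# Hypothetical lookup table: bytes per element for common cache dtypes.
DTYPE_BYTES = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, dtype: str = "fp16") -> int:
    """Bytes of K and V cache added per token for a single sequence."""
    # 2 tensors (K and V) per layer, each storing n_kv_heads * head_dim values per token.
    return 2 * n_layers * n_kv_heads * head_dim * DTYPE_BYTES[dtype]


if __name__ == "__main__":
    # Llama 3 70B-style shape: 80 layers, 8 KV heads, head_dim 128.
    for dtype in ("fp32", "bf16", "int8"):
        per_tok = kv_bytes_per_token(80, 8, 128, dtype)
        print(f"{dtype}: {per_tok} bytes/token = {per_tok / 1024:.0f} KiB/token")
```

For bf16 this gives 327,680 bytes (320 KiB) per token. Multiplying by seq_len * batch_size reproduces total_bytes from estimate_kv_cache_size. For example, 327,680 × 8192 = 2,684,354,560 bytes.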

Explanation

  1. For each layer, we store both a Key and a Value tensor. Each has shape (batch_size, n_kv_heads, seq_len, head_dim).
  2. The factor of 2 accounts for both K and V.
  3. n_kv_heads may be smaller than the number of query (attention) heads when the model uses Grouped-Query Attention (GQA) or Multi-Query Attention (MQA, a single KV head). Compared with standard multi-head attention, the cache shrinks by a factor of n_heads / n_kv_heads.
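  4. For example, consider Llama 3 70B with 80 layers, 8 KV heads (versus 64 query heads), head_dim=128, seq_len=8192, batch_size=1, and FP16. The total is 80 × 2 × 8 × 8192 × 128 × 2 = 2,684,354,560 bytes, which is 2.5 GiB (about 2.7 GB in decimal units) per request.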
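To make items 3 and 4 concrete, the standalone sketch below recomputes the Llama 3 70B example. It also compares that example with hypothetical variants of the same model: one using full multi-head attention (64 KV heads, one per query head) and one using multi-query attention (1 KV head). Only n_kv_heads changes between the rows. The helper name kv_cache_bytes is illustrative and uses the same formula as estimate_kv_cache_size.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int, seq_len: int,
                   batch_size: int = 1, dtype_bytes: int = 2) -> int:
    """Total K+V cache size in bytes (same formula as estimate_kv_cache_size)."""
    return 2 * batch_size * n_kv_heads * seq_len * head_dim * dtype_bytes * n_layers


if __name__ == "__main__":
    # Llama 3 70B shape: 80 layers, 64 query heads, head_dim 128; FP16, 8192 tokens, batch 1.
    # The MHA and MQA rows are hypothetical variants that change only n_kv_heads.
    variants = {"MHA (64 KV heads)": 64, "GQA (8 KV heads)": 8, "MQA (1 KV head)": 1}
    for name, kv_heads in variants.items():
        total = kv_cache_bytes(80, kv_heads, 128, 8192)
        print(f"{name}: {total:,} bytes = {total / 1024**3:.2f} GiB")
```

Under these assumptions, the GQA configuration needs 2.5 GiB, while the MHA variant needs 20 GiB. This is the 8× reduction (64 / 8) described in item 3.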

Complexity

  • Time: O(1)
  • Space: O(1)