
Compute Attention Memory Traffic and FLOPs

#416 · Inference · Medium


Problem

For a multi-head self-attention layer, compute the total memory traffic (bytes moved to and from memory) and the total FLOPs. The inputs are the batch size B, sequence length S, number of heads H, head dimension D (so the model dimension is d_model = H * D), and the size of the data type in bytes.

Solution

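def attention_memory_and_flops(
    B: int, S: int, H: int, D: int, dtype_bytes: int = 2
) -> dict:
    """Estimate memory traffic (bytes) and FLOPs for one multi-head self-attention layer.

    Counting conventions used below:
      * each stage reads its inputs and writes its outputs to memory once;
        intermediates such as the (B, H, S, S) scores are fully materialized
        (no fusion across stages is modeled);
      * a (m x k) @ (k x n) matmul costs 2 * m * k * n FLOPs;
      * softmax is approximated at 5 FLOPs per attention score.

    "arithmetic_intensity" is total FLOPs per byte of memory traffic.
    """
    d_model = H * D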
    # QKV projection: 3 weight matrices of shape (d_model, d_model)
    qkv_param_bytes = 3 * d_model * d_model * dtype_bytes
    qkv_activation_bytes = B * S * d_model * dtype_bytes  # input
    qkv_output_bytes = 3 * B * S * d_model * dtype_bytes  # Q, K, V
    qkv_memory = qkv_param_bytes + qkv_activation_bytes + qkv_output_bytes
    qkv_flops = 3 * 2 * B * S * d_model * d_model  # 3 matmuls

    # Attention scores: Q @ K^T -> (B, H, S, S)
    attn_scores_bytes = B * H * S * S * dtype_bytes
    q_k_read_bytes = 2 * B * S * d_model * dtype_bytes
    attn_score_memory = q_k_read_bytes + attn_scores_bytes
    attn_score_flops = 2 * B * H * S * S * D  # batched matmul

    # Softmax: read and write attention scores (in-place)
    softmax_memory = 2 * attn_scores_bytes
    # Rough estimate: ~5 FLOPs per score (e.g. max, subtract, exp, sum, divide)
    softmax_flops = B * H * S * S * 5

    # Attention output: scores @ V -> (B, H, S, D)
    attn_v_read = attn_scores_bytes + B * S * d_model * dtype_bytes
    attn_v_out = B * S * d_model * dtype_bytes
    attn_v_memory = attn_v_read + attn_v_out
    attn_v_flops = 2 * B * H * S * S * D

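    # Output projection: read weights and concatenated heads, write output
    out_proj_param = d_model * d_model * dtype_bytes
    out_proj_memory = (
        out_proj_param
        + B * S * d_model * dtype_bytes  # read concatenated heads
        + B * S * d_model * dtype_bytes  # write projected output
    )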
    out_proj_flops = 2 * B * S * d_model * d_model

    total_memory = qkv_memory + attn_score_memory + softmax_memory + attn_v_memory + out_proj_memory
    total_flops = qkv_flops + attn_score_flops + softmax_flops + attn_v_flops + out_proj_flops

    return {
        "total_memory_bytes": total_memory,
        "total_flops": total_flops,
        "arithmetic_intensity": round(total_flops / total_memory, 4) if total_memory > 0 else 0
    }
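Every matmul term in the solution uses the convention that an (m x k) @ (k x n) product costs 2 * m * k * n FLOPs: one multiply and one add per inner-product term. The pure-Python sketch below counts its own operations to confirm that convention. `naive_matmul_with_flop_count` is a hypothetical helper for illustration, not part of the problem; it counts the add into the zero-initialized accumulator so the total is exactly 2 * m * k * n.

```python
def naive_matmul_with_flop_count(a, b):
    """Multiply a (m x k) by b (k x n) using plain lists and count FLOPs.

    Hypothetical illustration: each inner-product term is counted as
    1 multiply + 1 add, so the expected total is 2 * m * k * n.
    """
    m, k = len(a), len(a[0])
    k2, n = len(b), len(b[0])
    assert k == k2, "inner dimensions must match"
    out = [[0.0] * n for _ in range(m)]
    flops = 0
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]  # 1 multiply + 1 add
                flops += 2
            out[i][j] = acc
    return out, flops


m, k, n = 3, 4, 5
a = [[1.0] * k for _ in range(m)]
b = [[1.0] * n for _ in range(k)]
_, flops = naive_matmul_with_flop_count(a, b)
print(flops, 2 * m * k * n)  # 120 120
```

For the QKV stage, substituting m = B * S and k = n = d_model gives 2 * B * S * d_model^2 per projection; three projections give the `3 * 2 * B * S * d_model * d_model` used in `qkv_flops`.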

Explanation

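  1. QKV projections: Three linear layers project the input to Q, K, and V. Each is a matrix multiply of shape (B*S, d_model) x (d_model, d_model), costing 2 * B * S * d_model^2 FLOPs. Memory traffic covers reading the three weight matrices and the input once, and writing Q, K, and V.
  2. Attention scores: Q @ K^T produces a (B, H, S, S) attention-score matrix. FLOPs = 2 * B * H * S^2 * D. Memory traffic is reading Q and K and writing the scores.
  3. Softmax: Applied row-wise, normalizing each of the B * H * S rows across its S key positions. It reads and writes the full attention matrix and is estimated at about 5 FLOPs per score.
  4. Scores @ V: Multiplies the (B, H, S, S) attention weights by V to produce the (B, H, S, D) context output. FLOPs = 2 * B * H * S^2 * D. Memory traffic is reading the weights and V and writing the context.
  5. Output projection: A final linear layer maps the concatenated heads, shape (B*S, d_model), back to d_model dimensions, costing 2 * B * S * d_model^2 FLOPs. Memory traffic is reading its weights and input and writing the output.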
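The solution charges softmax about 5 FLOPs per score (`softmax_flops = B * H * S * S * 5`). One way to arrive at that constant, which is an assumption on our part since the code's original comment lists only exp, sum, and divide, is to count the steps of a numerically stable softmax: max, subtract, exp, sum, divide. The sketch below (`stable_softmax_with_op_count` is a hypothetical helper) counts those operations for one row of length S and gets 5S - 2, so roughly 5 per element across all B * H * S rows.

```python
import math


def stable_softmax_with_op_count(row):
    """Numerically stable softmax over one row of attention scores.

    Hypothetical illustration. Returns (probabilities, op_count), where the
    ops counted are: comparisons for the max, subtractions, exponentials,
    additions for the sum, and divisions.
    """
    ops = 0
    # 1) Row max: S - 1 comparisons
    row_max = row[0]
    for x in row[1:]:
        if x > row_max:
            row_max = x
        ops += 1
    # 2) Subtract the max and 3) exponentiate: 2 ops per element
    exps = []
    for x in row:
        exps.append(math.exp(x - row_max))
        ops += 2
    # 4) Sum the exponentials: S - 1 additions
    total = exps[0]
    for e in exps[1:]:
        total += e
        ops += 1
    # 5) Normalize: S divisions
    probs = []
    for e in exps:
        probs.append(e / total)
        ops += 1
    return probs, ops


S = 8
probs, ops = stable_softmax_with_op_count([float(i) for i in range(S)])
print(ops, 5 * S)  # 38 40
```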

Complexity

  • Time: O(1) to compute the formulas
  • Space: O(1)
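Summing the five stages exactly as the solution counts them (writing b = dtype_bytes and d = d_model = H * D, notation introduced here) gives closed-form totals, which is why evaluating them is O(1):

```latex
\begin{aligned}
\text{FLOPs} &= \underbrace{6BSd^2}_{\text{QKV}}
  + \underbrace{2BHS^2D}_{QK^\top}
  + \underbrace{5BHS^2}_{\text{softmax}}
  + \underbrace{2BHS^2D}_{\text{scores}\,V}
  + \underbrace{2BSd^2}_{\text{out proj}} \\
&= 8BSd^2 + 4BS^2d + 5BHS^2 \qquad (\text{using } HD = d) \\
\text{Bytes} &= b\,\bigl[(3d^2 + 4BSd) + (2BSd + BHS^2) + 2BHS^2
  + (BHS^2 + 2BSd) + (d^2 + 2BSd)\bigr] \\
&= b\,\bigl(4d^2 + 10BSd + 4BHS^2\bigr)
\end{aligned}
```

Comparing the two matmul groups, 4BS^2 d / 8BSd^2 = S / (2d): under these counting conventions the score and context matmuls cost more FLOPs than the projections once S exceeds 2 * d_model.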