For a multi-head self-attention layer, compute the total memory traffic (bytes moved to and from memory) and the total FLOPs, given the batch size B, sequence length S, number of heads H, head dimension D (so the model dimension is d_model = H * D), and the size of the data type in bytes.
def attention_memory_and_flops(
    B: int, S: int, H: int, D: int, dtype_bytes: int = 2
) -> dict:
    """Estimate memory traffic and FLOPs for one multi-head self-attention layer.

    The layer is broken into five stages (QKV projection, Q @ K^T, softmax,
    scores @ V, output projection). Every stage is charged for the tensors it
    reads and the tensors it writes, so the (B, H, S, S) score matrix is
    counted once when written, twice by softmax and once more when it is
    multiplied with V.

    Args:
        B: Batch size.
        S: Sequence length.
        H: Number of attention heads.
        D: Dimension of each head; the model dimension is ``H * D``.
        dtype_bytes: Size of one element in bytes.

    Returns:
        A dict with ``total_memory_bytes``, ``total_flops`` and
        ``arithmetic_intensity`` (FLOPs per byte rounded to 4 decimals, or 0
        when no bytes are moved).
    """
    d_model = H * D

    # Footprints of the tensors that recur across stages, in bytes.
    weight_bytes = d_model * d_model * dtype_bytes  # one (d_model, d_model) matrix
    token_bytes = B * S * d_model * dtype_bytes     # one (B, S, d_model) activation
    score_bytes = B * H * S * S * dtype_bytes       # the (B, H, S, S) score matrix

    # Costs of the matmuls that recur across stages (2 FLOPs per multiply-add).
    proj_flops = 2 * B * S * d_model * d_model      # (B*S, d_model) @ (d_model, d_model)
    head_matmul_flops = 2 * B * H * S * S * D       # batched per-head matmul

    # One (bytes moved, FLOPs) pair per stage.
    stages = (
        # QKV projection: three weight matrices, the input, and Q/K/V outputs.
        (3 * weight_bytes + token_bytes + 3 * token_bytes, 3 * proj_flops),
        # Q @ K^T: read Q and K, write the scores.
        (2 * token_bytes + score_bytes, head_matmul_flops),
        # Softmax: read and write the scores; roughly 5 FLOPs per element
        # (exp, sum, div).
        (2 * score_bytes, B * H * S * S * 5),
        # scores @ V: read the scores and V, write the per-head outputs.
        (score_bytes + token_bytes + token_bytes, head_matmul_flops),
        # Output projection: one weight matrix, its input and its output.
        (weight_bytes + token_bytes + token_bytes, proj_flops),
    )

    bytes_moved = sum(stage_bytes for stage_bytes, _ in stages)
    flop_count = sum(stage_flops for _, stage_flops in stages)

    # Guard the division: degenerate shapes can move zero bytes.
    if bytes_moved > 0:
        intensity = round(flop_count / bytes_moved, 4)
    else:
        intensity = 0

    return {
        "total_memory_bytes": bytes_moved,
        "total_flops": flop_count,
        "arithmetic_intensity": intensity,
    }