
Kernel Fusion Memory Savings Calculator

#424 · Inference · Medium

⊣ Solve on deep-ml.com

Problem

Compute the memory savings from kernel fusion. When multiple sequential GPU operations are fused into a single kernel, intermediate results stay in registers or shared memory instead of being written to global memory and read back. Given a list of operations with their input and output sizes (in elements), calculate the global-memory traffic in bytes with and without fusion.
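Written out, the traffic model that the solution implements is as follows. Here $b$ is `dtype_bytes`, and op $i$ of $n$ has input size $\text{in}_i$ and output size $\text{out}_i$, both in elements. This is the problem's simplified model, which counts one input tensor per op:

```latex
\begin{aligned}
B_{\text{unfused}} &= b \sum_{i=1}^{n} \left(\text{in}_i + \text{out}_i\right) \\
B_{\text{fused}}   &= b \left(\text{in}_1 + \text{out}_n\right) \\
\text{savings}     &= B_{\text{unfused}} - B_{\text{fused}}, \qquad
\text{savings\%} = 100 \cdot \frac{\text{savings}}{B_{\text{unfused}}} \\
\text{speedup}     &\approx \frac{B_{\text{unfused}}}{B_{\text{fused}}}
\end{aligned}
```

For example, in a chain where every tensor holds $N$ elements, $B_{\text{unfused}} = 2nNb$ and $B_{\text{fused}} = 2Nb$, so the speedup estimate is simply $n$.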

Solution

def kernel_fusion_savings(
    operations: list[dict],
    dtype_bytes: int = 2
) -> dict:
    # Each op: {"name": str, "input_size": int, "output_size": int}
    # Sizes are element counts; dtype_bytes converts them to bytes (2 = fp16/bf16).

    # Handle the empty chain first, returning the same keys as the normal path
    if len(operations) == 0:
        return {
            "unfused_bytes": 0,
            "fused_bytes": 0,
            "savings_bytes": 0,
            "savings_pct": 0.0,
            "speedup_estimate": 0,
        }

    # Without fusion: each op reads its input from global memory and writes its output back
    unfused_reads = 0
    unfused_writes = 0
    for op in operations:
        unfused_reads += op["input_size"] * dtype_bytes
        unfused_writes += op["output_size"] * dtype_bytes

    # With fusion: only the first input is read and the last output is written;
    # intermediate tensors stay in registers/shared memory
    fused_reads = operations[0]["input_size"] * dtype_bytes
    fused_writes = operations[-1]["output_size"] * dtype_bytes

    unfused_total = unfused_reads + unfused_writes
    fused_total = fused_reads + fused_writes
    savings = unfused_total - fused_total
    savings_pct = (savings / unfused_total * 100) if unfused_total > 0 else 0.0

    return {
        "unfused_bytes": unfused_total,
        "fused_bytes": fused_total,
        "savings_bytes": savings,
        "savings_pct": round(savings_pct, 2),
        # Ratio of bytes moved; a runtime speedup only if both versions are bandwidth-bound
        "speedup_estimate": round(unfused_total / fused_total, 2) if fused_total > 0 else 0,
    }
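The model above counts one input tensor per op. A Linear layer also reads a weight matrix from global memory, and fusing the chain does not remove that read. Below is a minimal sketch of that extension. It assumes a hypothetical optional `weight_size` key and a hypothetical helper `fusion_traffic_with_weights`, and applies them to the LayerNorm + Linear + GELU chain mentioned in the explanation, using illustrative shapes:

```python
def fusion_traffic_with_weights(operations: list[dict], dtype_bytes: int = 2) -> dict:
    """Global-memory bytes for a chain of ops, with and without fusion.

    Each op: {"name": str, "input_size": int, "output_size": int,
              "weight_size": int}  # weight_size is a hypothetical, optional key
    Sizes are element counts. Weights are read from global memory in both
    cases, so fusion only removes intermediate activation traffic.
    """
    if not operations:
        return {"unfused_bytes": 0, "fused_bytes": 0, "savings_bytes": 0}

    weights = sum(op.get("weight_size", 0) for op in operations)
    # Unfused: every op reads its input and writes its output
    activations_unfused = sum(op["input_size"] + op["output_size"] for op in operations)
    # Fused: only the chain's first input and last output touch global memory
    activations_fused = operations[0]["input_size"] + operations[-1]["output_size"]

    unfused = (activations_unfused + weights) * dtype_bytes
    fused = (activations_fused + weights) * dtype_bytes
    return {"unfused_bytes": unfused, "fused_bytes": fused, "savings_bytes": unfused - fused}


# LayerNorm -> Linear -> GELU on T tokens, hidden size d, Linear output width 4d
# (shapes are illustrative only; LayerNorm's small gamma/beta vectors are ignored).
T, d = 1024, 1024
ops = [
    {"name": "layernorm", "input_size": T * d, "output_size": T * d},
    {"name": "linear", "input_size": T * d, "output_size": T * 4 * d, "weight_size": d * 4 * d},
    {"name": "gelu", "input_size": T * 4 * d, "output_size": T * 4 * d},
]
result = fusion_traffic_with_weights(ops)
print(result)
```

Here `savings_bytes` comes out to 20,971,520 (20 MiB). In the unfused version, the LayerNorm output (T·d elements) and the Linear output (4·T·d elements) are each written once and read once. The fixed weight traffic keeps the unfused/fused ratio (about 2.1) below the activation-only ratio of 3.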

Explanation

  1. Without fusion: each operator reads its input from GPU global memory (HBM) and writes its output back. The next operator then reads that output as its input, so every intermediate tensor costs one extra write and one extra read of global memory.
  2. With fusion: a fused kernel chains operations so intermediate results stay in fast on-chip memory (registers or shared memory). Only the initial input and final output touch global memory.
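  3. For example, fusing LayerNorm + Linear + GELU keeps the two intermediate tensors (the LayerNorm output and the Linear output) on chip, saving two HBM round trips (one write plus one read each). The Linear layer's weight reads are not affected.
  4. The savings matter most for memory-bound operations with low arithmetic intensity, such as element-wise ops and reductions, whose runtime is dominated by global-memory traffic.
  5. The speedup estimate is the ratio unfused_bytes / fused_bytes. It approximates the runtime speedup only when both the unfused kernels and the fused kernel are limited by memory bandwidth. It ignores compute time and kernel-launch overhead.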
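The bandwidth-bound assumption can be checked with a simple roofline model. In that model, a kernel takes at least max(FLOPs / peak compute, bytes / peak bandwidth). Below is a minimal sketch. `PEAK_FLOPS`, `PEAK_BW`, `kernel_time`, and `fusion_speedup` are hypothetical names, and the hardware numbers are round illustrative values, not a particular GPU:

```python
# Roofline-style lower bound on kernel time: max(compute time, memory time).
# Kernel-launch overhead is ignored. Hardware numbers are hypothetical.
PEAK_FLOPS = 100e12  # hypothetical peak compute: 100 TFLOP/s
PEAK_BW = 2e12       # hypothetical global-memory bandwidth: 2 TB/s


def kernel_time(flops: float, bytes_moved: float) -> float:
    """Lower-bound runtime in seconds for one kernel."""
    return max(flops / PEAK_FLOPS, bytes_moved / PEAK_BW)


def fusion_speedup(ops: list[tuple[float, float]], fused_bytes: float) -> float:
    """ops holds (flops, bytes_moved) per unfused kernel. The fused kernel does
    the same total FLOPs but moves only fused_bytes."""
    unfused = sum(kernel_time(f, b) for f, b in ops)
    fused = kernel_time(sum(f for f, _ in ops), fused_bytes)
    return unfused / fused


N = 1e8                   # elements per tensor (illustrative)
bytes_per_op = 2 * N * 2  # each unfused op reads N and writes N fp16 values
fused_bytes = 2 * N * 2   # fused: read the first input, write the last output

elementwise = [(1 * N, bytes_per_op)] * 3       # ~1 FLOP per element: memory-bound
compute_heavy = [(1000 * N, bytes_per_op)] * 3  # ~1000 FLOPs per element: compute-bound

print(round(fusion_speedup(elementwise, fused_bytes), 2))
print(round(fusion_speedup(compute_heavy, fused_bytes), 2))
```

Both chains move three times as many bytes unfused as fused, so the byte-ratio estimate is 3.0 for each. Under this model, only the element-wise chain achieves it (about 3.0). The compute-bound chain comes out near 1.0, because its time is set by FLOPs, and fusion does not reduce FLOPs.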

Complexity

  • Time: O(n), where n is the number of operations
  • Space: O(1)