Compute the memory savings from kernel fusion. When multiple sequential GPU operations are fused into a single kernel, intermediate results stay in registers/shared memory instead of being written to and read from global memory. Given a list of operations (in execution order) with their input/output sizes, calculate the global-memory traffic in bytes with and without fusion.
def kernel_fusion_savings(
    operations: list[dict],
    dtype_bytes: int = 2,
) -> dict:
    """Estimate the memory traffic saved by fusing a chain of operations.

    ``operations`` is treated as a linear chain executed in order.

    * Unfused: every op reads its full input from global memory and writes
      its full output back to global memory.
    * Fused into one kernel: only the first op's input is read and only the
      last op's output is written. Intermediates are assumed to stay in
      registers/shared memory.

    Args:
        operations: Ops in execution order. Each is a dict with at least
            ``"input_size"`` and ``"output_size"`` element counts. A
            ``"name"`` key may be present but is not used.
        dtype_bytes: Bytes per element (default 2, e.g. fp16).

    Returns:
        A dict with these keys:

        * ``unfused_bytes``: total unfused traffic.
        * ``fused_bytes``: total fused traffic.
        * ``savings_bytes``: unfused minus fused traffic.
        * ``savings_pct``: savings as a percentage of unfused traffic,
          rounded to 2 decimals; 0.0 when unfused traffic is zero.
        * ``speedup_estimate``: unfused / fused traffic ratio, rounded to
          2 decimals; 0.0 when fused traffic is zero.

        An empty ``operations`` list yields the same keys, all zero.
    """
    if not operations:
        # Return the same keys as the non-empty path so callers can index
        # the result uniformly.
        return {
            "unfused_bytes": 0,
            "fused_bytes": 0,
            "savings_bytes": 0,
            "savings_pct": 0.0,
            "speedup_estimate": 0.0,
        }

    # Without fusion, each op round-trips its input and output through
    # global memory.
    unfused_reads = sum(op["input_size"] for op in operations) * dtype_bytes
    unfused_writes = sum(op["output_size"] for op in operations) * dtype_bytes

    # With fusion, only the chain's first input and last output touch
    # global memory.
    fused_reads = operations[0]["input_size"] * dtype_bytes
    fused_writes = operations[-1]["output_size"] * dtype_bytes

    unfused_total = unfused_reads + unfused_writes
    fused_total = fused_reads + fused_writes
    savings = unfused_total - fused_total

    savings_pct = savings / unfused_total * 100 if unfused_total > 0 else 0.0
    # The traffic ratio approximates speedup only for memory-bound kernels.
    speedup = unfused_total / fused_total if fused_total > 0 else 0.0

    return {
        "unfused_bytes": unfused_total,
        "fused_bytes": fused_total,
        "savings_bytes": savings,
        "savings_pct": round(savings_pct, 2),
        "speedup_estimate": round(speedup, 2),
    }