CUDA Kernel Fusion Basics
Kernel fusion is one of those optimizations that sounds obvious in hindsight but took me embarrassingly long to internalize.
The problem
Every CUDA kernel launch involves a round trip to global memory: inputs are read from HBM and outputs written back. For element-wise ops chained together — say, a ReLU after a GELU after a layer norm — you're doing three separate read-write cycles through HBM. At ~2 TB/s of bandwidth (A100), that sounds fine until you realize your compute units are sitting idle waiting for memory.
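A quick back-of-envelope makes the cost concrete. This is a sketch under assumptions not in the text — a 4096×4096 fp16 tensor and ~2e12 B/s of HBM bandwidth; the tensor size is arbitrary:

```python
# HBM traffic: three unfused element-wise kernels vs one fused kernel.
# Assumed workload (illustrative): a 4096 x 4096 fp16 tensor on an A100.
n_elements = 4096 * 4096
bytes_per_element = 2  # fp16

# Each unfused kernel reads and writes every element once.
traffic_unfused = 3 * 2 * n_elements * bytes_per_element
# A fused kernel reads each element once and writes it once.
traffic_fused = 2 * n_elements * bytes_per_element

bandwidth = 2e12  # bytes/s, roughly A100 HBM
print(traffic_unfused / traffic_fused)                 # 3.0x less traffic when fused
print(traffic_unfused / bandwidth * 1e6, "microseconds")  # ~100 us just moving bytes
```

The arithmetic per element (a few FLOPs) is trivial next to that memory time, which is what "memory-bound" means here.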
The fix
Fuse them: write a single kernel that applies all three ops while the values stay in registers, so each element is read from and written to HBM exactly once.
```python
# Bad — three kernel launches, three HBM roundtrips
x = layer_norm(x)
x = gelu(x)
x = relu(x)

# Good (conceptually) — one fused kernel
x = fused_layer_norm_gelu_relu(x)
```
When it actually matters
Fusion pays off when:
- Ops are memory-bound (not compute-bound)
- The fused result fits in L2/shared memory
- You're doing this in a hot path (forward/backward of large models)
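To make the fused example above concrete, here's a NumPy sketch of what `fused_layer_norm_gelu_relu` might compute. This is a CPU mock-up of the math, not a CUDA kernel — the function name comes from the example, and the layer-norm (no affine) and tanh-GELU formulas are the standard ones, assumed rather than taken from the text:

```python
import numpy as np

def fused_layer_norm_gelu_relu(x, eps=1e-5):
    # In a real fused kernel, all three stages would happen while the
    # values sit in registers/shared memory — HBM is touched only for
    # the initial load and the final store.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mu) / np.sqrt(var + eps)  # layer norm, no affine params
    y = 0.5 * y * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (y + 0.044715 * y**3)))  # tanh GELU
    return np.maximum(y, 0.0)  # ReLU

x = np.random.randn(4, 8).astype(np.float32)
out = fused_layer_norm_gelu_relu(x)
```

Note the caveat this surfaces: layer norm isn't purely element-wise (it reduces over the last axis), which is exactly why the fused result needs to fit in shared memory/registers per row.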
Flash Attention is the canonical example — 5-10x speedup over naive attention purely from fusion + tiling.
What I'm still figuring out
Triton makes this much more approachable than raw CUDA. But getting the tiling right for non-trivial ops is still painful. Next up: reading the Triton matmul tutorial carefully.