
CUDA Kernel Fusion Basics

Kernel fusion is one of those optimizations that sounds obvious in hindsight but took me embarrassingly long to internalize.

The problem

Every CUDA kernel reads its inputs from and writes its outputs to global memory. For element-wise ops chained together — say, a ReLU after a GELU after a layer norm — you're doing three separate read-write passes over HBM. At ~2 TB/s bandwidth (A100), that sounds fine until you realize your compute units are sitting idle waiting for memory.
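
A rough back-of-envelope sketch of what this costs. The bandwidth figure (2 TB/s) is the A100 number from above; fp16 element size and the helper name are my assumptions:

```python
# Rough cost model: each unfused element-wise kernel reads and writes
# every element once, so HBM traffic scales with the number of kernels.
# Assumed numbers: 2 TB/s HBM bandwidth (A100), fp16 (2 bytes/element).

def hbm_time_ms(n_elements, n_kernels, bytes_per_elem=2, bandwidth_bps=2e12):
    """Time spent just moving data, ignoring compute entirely."""
    bytes_moved = n_elements * bytes_per_elem * 2 * n_kernels  # read + write per kernel
    return bytes_moved / bandwidth_bps * 1e3

n = 1 << 28  # ~268M elements, a plausible large activation tensor
unfused = hbm_time_ms(n, n_kernels=3)
fused = hbm_time_ms(n, n_kernels=1)
print(f"unfused: {unfused:.2f} ms, fused: {fused:.2f} ms")  # 3x gap, straight from the traffic ratio
```

The model is crude (it ignores launch overhead and cache reuse), but it captures why the gap is exactly the ratio of HBM passes.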

The fix

Fuse them. Write a single kernel that does all three ops in registers, reading each element from HBM once and writing it back once.

# Bad — three kernel launches, three HBM roundtrips
x = layer_norm(x)
x = gelu(x)
x = relu(x)
 
# Good (conceptually) — one fused kernel
x = fused_layer_norm_gelu_relu(x)
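
As a sanity check that the fused version computes the same thing, here's a NumPy sketch. It doesn't save any memory traffic (NumPy materializes intermediates anyway) — it just verifies that composing the math in one function matches running the three ops separately. The tanh-approximation GELU is an assumption on my part:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def relu(x):
    return np.maximum(x, 0.0)

def fused_layer_norm_gelu_relu(x, eps=1e-5):
    # In a real fused kernel the intermediates live in registers;
    # here we just compose the same math in one function.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    h = (x - mu) / np.sqrt(var + eps)
    h = 0.5 * h * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h**3)))
    return np.maximum(h, 0.0)

x = np.random.default_rng(0).standard_normal((4, 128))
assert np.allclose(fused_layer_norm_gelu_relu(x), relu(gelu(layer_norm(x))))
```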

When it actually matters

Fusion pays off when:

  • Ops are memory-bound (not compute-bound)
  • The fused result fits in L2/shared memory
  • You're doing this in a hot path (forward/backward of large models)
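
The memory-bound criterion can be made concrete with a roofline-style check: compare a kernel's arithmetic intensity (FLOPs per byte of HBM traffic) against the hardware's ridge point. The A100 fp32 peak below (~19.5 TFLOP/s) and the helper name are my assumptions:

```python
# Roofline-style check: a kernel is memory-bound when its arithmetic
# intensity sits below the hardware ridge point (peak FLOPs / peak bandwidth).
# Assumed A100 numbers: ~19.5 TFLOP/s fp32, ~2 TB/s HBM.

PEAK_FLOPS = 19.5e12
PEAK_BW = 2e12
RIDGE = PEAK_FLOPS / PEAK_BW  # ~9.75 FLOPs per byte

def is_memory_bound(flops_per_elem, bytes_per_elem):
    return flops_per_elem / bytes_per_elem < RIDGE

# ReLU on fp32: 1 FLOP per element, 8 bytes moved (4-byte read + 4-byte write).
# 0.125 FLOPs/byte — nowhere near the ridge, so it's thoroughly memory-bound.
print(is_memory_bound(1, 8))
```

Chained element-wise ops all sit in this regime, which is exactly why fusing them converts wasted bandwidth directly into speedup.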

Flash Attention is the canonical example — 5-10x speedup over naive attention purely from fusion + tiling, since the full N×N attention matrix never has to be materialized in HBM.

What I'm still figuring out

Triton makes this much more approachable than raw CUDA. But getting the tiling right for non-trivial ops is still painful. Next up: reading the Triton matmul tutorial carefully.