KV Cache Quantization with FP8
Running 70B models with long contexts is basically a VRAM budgeting exercise. The KV cache grows linearly with sequence length, batch size, and layer count; at FP16, a 70B model with 8k context across 80 layers uses ~40GB just for KV.
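The arithmetic is easy to sanity-check yourself. A back-of-envelope sketch, with head counts as assumptions (Llama-style 70B: 80 layers, 64 attention heads, head_dim 128, no GQA); the exact total depends heavily on batch size and on whether the model shares KV heads:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch=1, bytes_per_elem=2):
    # 2x for K and V; one vector per layer, per KV head, per token
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Full multi-head attention (64 KV heads), FP16, 8k context -- an assumption,
# not a measured config:
full_mha = kv_cache_bytes(80, 64, 128, 8192, bytes_per_elem=2)
print(f"{full_mha / 2**30:.1f} GiB")  # prints 20.0 GiB per sequence
```

Note that GQA models store far fewer KV heads than query heads (8 instead of 64 would cut this 8x), so the realistic number for a given deployment swings by an order of magnitude either way.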
FP8 KV cache
Switching KV to FP8 halves that. H100s have native FP8 support, so there's no emulation overhead. The tricky part is the quantization scheme: per-tensor, per-token, and per-head granularities all have different accuracy/overhead tradeoffs.
Per-token quantization (a separate scale per token per head) gives the best accuracy, because key and value magnitudes are highly non-uniform across tokens. Per-tensor is cheaper to dequantize but loses more accuracy, especially on long-range dependencies.
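To see why per-token scales help, here's a small numpy sketch. The `simulate_e4m3` helper is a simplified software model of e4m3 rounding (normals plus a subnormal grid, no NaN handling), and the token-magnitude distribution is an illustrative assumption, not data from the runs below:

```python
import numpy as np

E4M3_MAX = 448.0

def simulate_e4m3(x):
    # Simplified model of FP8 e4m3: clip to the representable range, keep
    # 4 significant bits (1 implicit + 3 mantissa) for normals, and snap
    # values below 2**-6 onto the subnormal grid with spacing 2**-9.
    x = np.clip(x, -E4M3_MAX, E4M3_MAX)
    m, e = np.frexp(x)                                  # x = m * 2**e
    normal = np.ldexp(np.round(np.ldexp(m, 4)), e - 4)
    subnormal = np.round(np.ldexp(x, 9)) * 2.0 ** -9
    return np.where(np.abs(x) >= 2.0 ** -6, normal, subnormal)

def quantize_dequantize(k, scale):
    # Scale into FP8's range, round, then scale back out
    return simulate_e4m3(k / scale) * scale

rng = np.random.default_rng(0)
# Tokens with very different magnitudes, mimicking non-uniform activations
mags = 10.0 ** rng.uniform(-3, 2, size=(1024, 1))
k = rng.standard_normal((1024, 128)) * mags             # [tokens, head_dim]

per_tensor = quantize_dequantize(k, np.abs(k).max() / E4M3_MAX)
per_token = quantize_dequantize(k, np.abs(k).max(axis=-1, keepdims=True) / E4M3_MAX)

def mean_rel_err(q):
    return float((np.abs(k - q).mean(axis=-1) / np.abs(k).mean(axis=-1)).mean())

print(mean_rel_err(per_tensor))  # larger: small-magnitude tokens underflow the shared scale
print(mean_rel_err(per_token))   # smaller: each token is rescaled into FP8's usable range
```

The mechanism is underflow: with one shared scale, tokens whose values sit far below the tensor max land in FP8's subnormal region (or flush to zero), while per-token scales put every token's values in the well-resolved part of the format.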
What I tested
On Llama 3 70B with 16k context:
- FP16 KV: 78GB total VRAM, doesn't fit on 2x A6000
- FP8 KV (per-tensor): 51GB, fits, ~0.3 perplexity regression
- FP8 KV (per-token): 52GB (small overhead for scales), ~0.1 perplexity regression
Worth it. The per-token scales cost almost nothing relative to the accuracy gain.
Gotcha
You need to compute the attention logits and softmax in FP32 even when K/V are stored in FP8; otherwise softmax precision tanks, especially on long sequences. Most frameworks do this automatically, but it's worth verifying.