GPU Memory for LLMs
Where Every Byte Goes — Weights, KV Cache, Activations, Gradients, and the Tricks to Fit Bigger
GPU memory is the binding constraint of LLMs. Training a 70B model with mixed-precision Adam takes ~1.1 TB of HBM for model state alone (way past any single device); inference needs ~140 GB just for fp16 weights, plus a KV cache that grows linearly with context and batch size. The memory equation has four terms (model weights, KV cache, activations, and gradients plus optimizer state), and each gets attacked by a different family of optimizations: quantization for weights, GQA/MQA for KV cache, FlashAttention/recompute for activations, and ZeRO/tensor/pipeline parallelism to shard everything across devices.
Memory Anatomy: Where the Bytes Live
A 7B model in fp16, training vs inference. Same weights, very different total.
1. The Memory Equation
For inference of a transformer with N parameters in fp16, batch B, context L, hidden dim D, layers L_n, num_kv_heads h_kv, and per-head dimension head_dim (all results in bytes):
# Inference memory (per replica)
mem_weights = 2 * N
mem_kv_cache = 2 * 2 * B * L * L_n * h_kv * head_dim
# factors, in order: K and V, 2 bytes (fp16), batch, ctx, layers, KV heads, head dim
mem_activations ≈ 2 * B * L * D (roughly one step's intermediates)
mem_total_inf = mem_weights + mem_kv_cache + mem_activations
# Training memory (per replica, fp16 + Adam fp32 master)
mem_weights = 2 * N
mem_grads = 2 * N
mem_master = 4 * N (fp32 master copy of the weights)
mem_adam_state = 2 * 4 * N (m and v, fp32 each)
mem_activations ≈ 2 * B * L * D * L_n (saved for backward; a rough lower bound, and usually DOMINANT)
mem_total_train ≈ 16 * N + activations
The "16N" rule of thumb: training a model in standard mixed-precision Adam costs ~16 bytes per parameter for state alone (2 + 2 + 4 + 8). For a 70B model, that's 1.12 TB before activations. This is why training at this scale pretty much always requires sharding.
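The two budgets above can be turned into a small calculator. This is a minimal sketch that follows the same approximations; the function names are hypothetical, and the example uses the 80-layer, 8-KV-head, head_dim-128 configuration from the worked example later in the article, with a hidden size of 8192 as an assumption.

```python
def inference_bytes(n_params, batch, ctx, n_layers, n_kv_heads, head_dim,
                    hidden, bytes_per_value=2):
    """Rough inference footprint in bytes, term by term."""
    weights = bytes_per_value * n_params
    # Factor 2 for K and V; one vector per token, per layer, per KV head.
    kv_cache = 2 * bytes_per_value * batch * ctx * n_layers * n_kv_heads * head_dim
    activations = bytes_per_value * batch * ctx * hidden
    return {
        "weights": weights,
        "kv_cache": kv_cache,
        "activations": activations,
        "total": weights + kv_cache + activations,
    }


def training_state_bytes(n_params):
    """fp16 weights + fp16 grads + fp32 master + fp32 Adam m and v."""
    return (2 + 2 + 4 + 4 + 4) * n_params


N_70B = 70_000_000_000
inf = inference_bytes(N_70B, batch=1, ctx=8192, n_layers=80,
                      n_kv_heads=8, head_dim=128, hidden=8192)
print(f"weights     {inf['weights'] / 1e9:.1f} GB")
print(f"kv cache    {inf['kv_cache'] / 1e9:.2f} GB")
print(f"train state {training_state_bytes(N_70B) / 1e9:.0f} GB")
```

Activations are left out of `training_state_bytes` on purpose, since they depend on batch size, sequence length, and whether checkpointing is used.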
2. FlashAttention
Don't materialize the n×n attention matrix. Tile it through SRAM.
The standard attention implementation computes QKᵀ as an explicit n×n matrix in HBM, applies softmax, multiplies by V. For n=8192, that's 128 MB per head per layer in fp16 (256 MB in fp32) just for the score matrix, and HBM I/O dominates the latency.
FlashAttention (Dao et al. 2022) recasts attention as a tiled, fused kernel:
# Process Q, K, V in tiles of (block_M, block_N)
# Use online softmax (Milakov & Gimelshein 2018) to avoid storing
# the full n×n matrix anywhere.
for q_block in tiles_of(Q, block_M):
    o_block = zeros(...)
    m_block = -inf # running max (one per query row)
    l_block = 0 # running denominator (one per query row)
    for k_block, v_block in tiles_of(K, V, block_N):
        s = q_block @ k_block.T / sqrt(d)
        m_new = max(m_block, s.max(axis=-1)) # row-wise max over this tile
        # rescale running output and denominator
        ...
        o_block = ...
    write o_block to HBM
Memory: O(n) instead of O(n²). HBM traffic: up to ~10× lower. Throughput: 2-4× faster on long contexts. FlashAttention-2 (2023) improved parallelism (including across the sequence dimension) and work partitioning between warps. FlashAttention-3 (2024) added fp8 and the asynchronous execution features of the H100.
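The tiling loop above elides the rescaling step, so here is a minimal pure-Python sketch of the online-softmax update for a single query row. It is not the fused GPU kernel, only the arithmetic that lets each K/V tile be visited once with O(d) running state; the function names and the toy inputs are illustrative assumptions.

```python
import math


def naive_attention_row(q, keys, values):
    """Reference: softmax(q . K^T / sqrt(d)) . V with every score materialized."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / denom
            for j in range(len(values[0]))]


def tiled_attention_row(q, keys, values, block_n=2):
    """Same result, visiting K/V in tiles and keeping only running state."""
    d = len(q)
    m_run = -math.inf                 # running max of all scores seen so far
    l_run = 0.0                       # running softmax denominator
    o_run = [0.0] * len(values[0])    # running unnormalized output
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_blk]
        m_new = max(m_run, max(s))
        scale = math.exp(m_run - m_new)   # re-expresses old state relative to m_new
        p = [math.exp(si - m_new) for si in s]
        l_run = l_run * scale + sum(p)
        o_run = [o * scale + sum(pi * v[j] for pi, v in zip(p, v_blk))
                 for j, o in enumerate(o_run)]
        m_run = m_new
    return [o / l_run for o in o_run]


q = [0.5, -1.0, 0.25, 2.0]
keys = [[0.1 * i - 0.2 * j for j in range(4)] for i in range(5)]
values = [[float(i), float(i * i)] for i in range(5)]
print(naive_attention_row(q, keys, values))
print(tiled_attention_row(q, keys, values, block_n=2))
```

Both calls should agree up to floating-point rounding, whatever `block_n` is. The real kernel applies the same update to a whole block of query rows at once, with the tiles held in SRAM.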
3. Multi-Query and Grouped-Query Attention
Most of the KV cache is wasted on duplicate-ish information. Share K and V across heads.
| Variant | Q heads | KV heads | KV cache size | Used by |
|---|---|---|---|---|
| MHA (full multi-head) | 32 | 32 | 1.0× | GPT-3, Llama-1, Llama-2 7B |
| GQA (Grouped-Query) | 32 | 8 | 0.25× | Llama-2 70B, Llama-3, Mistral |
| MQA (Multi-Query) | 32 | 1 | 0.03× | PaLM, Falcon, StarCoder |
MQA gives the biggest savings but typically costs some quality (on the order of 1% on benchmarks). GQA (Ainslie et al. 2023) is the compromise that won: it stays close to MHA quality with a 4-8× KV cache reduction. Llama-2 70B, for example, uses 8 KV heads for 64 Q heads, an 8× reduction.
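A short sketch of how query heads map onto shared KV heads, and what that does to per-token cache size. The contiguous grouping used here (neighbouring query heads share a KV head) is an assumption for illustration, and the helper names are hypothetical.

```python
def kv_head_for(q_head, n_q_heads, n_kv_heads):
    """Index of the shared K/V head that a given query head reads from."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """K and V for one token, summed over all layers."""
    return 2 * bytes_per_value * n_layers * n_kv_heads * head_dim


n_q = 32
for name, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    mapping = [kv_head_for(h, n_q, n_kv) for h in range(n_q)]
    print(f"{name}: cache {n_kv / n_q:.2f}x, first 8 query heads -> KV heads {mapping[:8]}")
```

With 32 query heads, GQA with 8 KV heads sends query heads 0-3 to KV head 0, heads 4-7 to KV head 1, and so on. MQA sends every query head to KV head 0.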
4. Activation Recomputation (Gradient Checkpointing)
The backward pass needs activations from the forward pass. Storing them all costs O(B·L·D·L_n), which is the dominant term in training memory.
Activation checkpointing trades compute for memory: only save activations at boundaries (every 4-8 layers), and recompute the in-between activations during backward.
# Without checkpointing: store every layer's output (96 layers → 96 saves)
# With checkpointing every 4 layers: store 24 saves
# Stored activations: ~4× lower (plus one segment held live while it is recomputed)
# Compute: ~1.33× total (each segment's forward is re-run once, about one extra forward pass overall)
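# PyTorch (`layer` is a stand-in for one transformer block):
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
layer = nn.Linear(1024, 1024)
x = torch.randn(8, 1024, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False)  # activations inside `layer` are recomputed during backward
This can be the difference between fitting a large model on 8 GPUs vs 16. Modern training stacks (DeepSpeed, FSDP) support it, though it usually has to be switched on explicitly.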
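The memory and compute trade can be estimated with a few lines of arithmetic. This is a rough sketch in units of "one layer's saved activations", assuming that backward costs about twice the forward pass and that only one segment is recomputed and held live at a time; the function name is hypothetical.

```python
import math


def checkpoint_plan(n_layers, every):
    """Return (stored boundaries, approximate peak live units, relative compute)."""
    stored_boundaries = math.ceil(n_layers / every)
    # While one segment is being recomputed, its layers are live in addition
    # to the stored boundaries.
    peak_live = stored_boundaries + every
    # Forward ~1 unit, backward ~2 units; recomputation adds about one forward.
    relative_compute = (1 + 2 + 1) / (1 + 2)
    return stored_boundaries, peak_live, relative_compute


for every in (4, 8, 10):
    stored, peak, compute = checkpoint_plan(96, every)
    print(f"every {every:2d} layers: {stored} boundaries, ~{peak} live units, "
          f"{compute:.2f}x compute")
```

Under this simple model the peak is smallest when the segment length is near the square root of the layer count, which is why checkpoint intervals of roughly 8-10 layers are a reasonable starting point for a 96-layer model.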
5. Tensor, Pipeline, and ZeRO Parallelism
| Strategy | What's sharded | Comm | When |
|---|---|---|---|
| Data parallel | Nothing — replicate, average grads | AllReduce on grads | Model fits per device |
| Tensor parallel | Each weight matrix split across devices | AllReduce per layer | Within a node (NVLink) |
| Pipeline parallel | Layers split across devices in a pipeline | Activations between stages | Across nodes (slower link) |
| ZeRO-1 | Optimizer state | AllReduce + AllGather | Up to 4× less state memory, ~no slowdown |
| ZeRO-2 | + gradients | + ReduceScatter | Up to 8× less state memory |
| ZeRO-3 (FSDP) | + parameters | + AllGather params | State memory shrinks linearly with device count; large comm |
Modern training stacks (Megatron-LM, DeepSpeed, PyTorch FSDP) combine these strategies. A typical 70B run: tensor-parallel within an 8-GPU node, pipeline-parallel across 4 nodes, ZeRO-1 sharding optimizer state across the data-parallel group, gradient accumulation to hit a 4M-token batch.
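The ZeRO rows of the table can be made concrete with the per-device accounting for mixed-precision Adam: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and 12 bytes of fp32 optimizer state (master weights, m, v) per parameter. The sketch below is an illustrative calculator with a hypothetical function name; it ignores activations, buffers, and fragmentation.

```python
def zero_bytes_per_device(n_params, n_devices, stage):
    """Model-state bytes per device under ZeRO stage 0 (plain DP) through 3."""
    params = 2 * n_params
    grads = 2 * n_params
    optim = 12 * n_params
    if stage >= 1:
        optim /= n_devices      # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_devices      # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n_devices     # ZeRO-3: also shard parameters
    return params + grads + optim


n_params, n_devices = 70e9, 64
for stage in range(4):
    gb = zero_bytes_per_device(n_params, n_devices, stage) / 1e9
    print(f"stage {stage}: {gb:8.1f} GB per device")
```

The 4× and 8× figures for ZeRO-1 and ZeRO-2 are limits approached as the device count grows; only ZeRO-3 keeps shrinking linearly with more devices.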
6. Inference Memory Math (Practical)
Concrete worked example — Llama-3 70B serving on H100 (80 GB):
Weights (fp16): 140 GB ← won't fit on 1 GPU
Weights (fp8): 70 GB ← fits, but leaves only ~10 GB for KV cache
Weights (int4 GPTQ): 35 GB ← comfortable on 1× H100
KV cache per token (GQA, 80 layers, 8 KV heads, head_dim 128, fp16):
2 * 2 * 80 * 8 * 128 = 320 KB
KV cache per request (8k ctx): 320 KB * 8192 = 2.5 GB
With int4 weights (35 GB) + 2.5 GB/req KV:
(80 - 35) / 2.5 = 18 concurrent 8k-context requests (ignoring activations and runtime overhead)
With KV cache quantized to int8:
KV halves → 36 concurrent requests
Tradeoffs
| Optimization | Memory saved | Cost |
|---|---|---|
| int4 weights | 4× on weights | ~0.5% perplexity increase (method- and model-dependent) |
| GQA | 4-8× on KV cache | Tiny accuracy hit |
| FlashAttention | The O(n²) attention-score memory | None in accuracy (exact attention) |
| Activation checkpoint | 4-8× on activations | ~33% extra compute |
| ZeRO-3 / FSDP | Model state divided across N devices | Heavy AllGather traffic |
| fp8 KV cache | 2× on KV | Small quality loss; per-tensor scales |
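To see how the weight and KV-cache rows of the table combine, the serving arithmetic from the Llama-3 70B example can be scripted. This is a rough sketch with a hypothetical helper name; it uses the rounded 2.5 GB per 8k-context request from that example and ignores activations, fragmentation, and runtime overhead.

```python
def max_concurrent_requests(hbm_gb, weight_gb, kv_gb_per_request):
    """Requests whose KV cache fits in the HBM left over after the weights."""
    free_gb = hbm_gb - weight_gb
    if free_gb <= 0:
        return 0
    return int(free_gb // kv_gb_per_request)


KV_FP16_GB = 2.5  # per 8k-context request, fp16 KV cache

print(max_concurrent_requests(80, 35, KV_FP16_GB))       # int4 weights, fp16 KV -> 18
print(max_concurrent_requests(80, 35, KV_FP16_GB / 2))   # int4 weights, int8 KV -> 36
print(max_concurrent_requests(80, 70, KV_FP16_GB))       # fp8 weights, fp16 KV  -> 4
print(max_concurrent_requests(80, 140, KV_FP16_GB))      # fp16 weights          -> 0
```

Real serving stacks reserve part of HBM for activations and allocator overhead, so practical concurrency will be somewhat lower than these upper bounds.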
FAQ
Why does training need so much more memory than inference?
Three reasons: (1) you need gradients (same size as weights), (2) Adam keeps an fp32 master copy of the weights plus two fp32 moment buffers (12 bytes per parameter), (3) activations from forward must persist for backward, which is far more than a single decode step's intermediates. Total: ~16 bytes per parameter for naive training (8× the fp16 weights) vs ~2 bytes per parameter for inference (excluding KV cache).
What is "the activation cliff" people talk about?
Activation memory grows quadratically with sequence length without FlashAttention (because of the n×n attention matrix). Somewhere around 8k context on a 70B model, this term typically becomes dominant and you fall off a cliff. FlashAttention removes the quadratic part, leaving activation memory linear in sequence length.
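The cliff is easy to see by comparing the materialized score matrix with a linear-size activation tensor for one layer. The sketch below uses hypothetical helper names and assumes 64 attention heads and a hidden size of 8192, which are illustrative values rather than figures from this article.

```python
def score_matrix_bytes(seq_len, n_heads, batch=1, bytes_per_value=2):
    """Explicit n x n attention scores for ONE layer."""
    return batch * n_heads * seq_len * seq_len * bytes_per_value


def hidden_state_bytes(seq_len, hidden, batch=1, bytes_per_value=2):
    """One B x L x D activation tensor for ONE layer."""
    return batch * seq_len * hidden * bytes_per_value


MIB = 2 ** 20
for n in (1024, 2048, 4096, 8192, 16384):
    quad = score_matrix_bytes(n, n_heads=64)
    lin = hidden_state_bytes(n, hidden=8192)
    print(f"n={n:6d}  scores {quad / MIB:9.0f} MiB  hidden {lin / MIB:5.0f} MiB  "
          f"ratio {quad / lin:5.1f}x")
```

Doubling the context doubles the linear term but quadruples the score term, so the ratio between them grows in proportion to sequence length.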
Can I quantize the KV cache?
Yes. int8 KV cache works well with per-channel scaling. fp8 KV is becoming standard on H100. Lower than int8 typically hurts quality because attention is sensitive to KV precision (it's used in QKᵀ at every step).
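As a minimal illustration of per-channel scaling, the sketch below quantizes a tiny K matrix (tokens × channels) to int8 with one symmetric scale per channel. It is a toy in pure Python with hypothetical function names, not how any particular serving engine implements it.

```python
def quantize_per_channel(rows):
    """Symmetric int8 quantization with one scale per channel (column)."""
    n_channels = len(rows[0])
    scales = []
    for c in range(n_channels):
        max_abs = max(abs(row[c]) for row in rows)
        scales.append(max_abs / 127 if max_abs > 0 else 1.0)
    quantized = [
        [max(-127, min(127, round(row[c] / scales[c]))) for c in range(n_channels)]
        for row in rows
    ]
    return quantized, scales


def dequantize(quantized, scales):
    """Map int8 codes back to approximate float values."""
    return [[code * s for code, s in zip(row, scales)] for row in quantized]


# Toy K cache: channel 0 is small, channel 1 has large outlier-like values.
K = [[0.010, 8.0], [-0.020, -6.5], [0.015, 7.2]]
K_q, K_scales = quantize_per_channel(K)
K_hat = dequantize(K_q, K_scales)
max_err = max(abs(a - b) for ra, rb in zip(K, K_hat) for a, b in zip(ra, rb))
print(K_q, K_scales, max_err)
```

With a single scale for the whole tensor, the large channel would set the step size and the small channel would round to zero; one scale per channel keeps the error in each channel proportional to that channel's own range.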
What's the difference between ZeRO-3 and FSDP?
Conceptually identical: shard parameters across devices, AllGather them just before they're needed, then immediately re-shard. ZeRO-3 is the DeepSpeed implementation; FSDP is the PyTorch-native one. FSDP is a common default for teams already on PyTorch because it composes better with other PyTorch features.
Why is HBM bandwidth (not capacity) the inference bottleneck?
For decode steps with batch=1, the matmul has dimensions (1 × D) × (D × D) — almost no compute reuse. The GPU spends most time waiting for HBM reads. Larger batches amortize the weight-loading cost across requests, which is why batching matters so much.
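A roofline-style estimate makes the argument quantitative. The sketch below counts only weight traffic (no KV-cache reads) and uses approximate H100 SXM figures of 3.35 TB/s HBM bandwidth and about 990 TFLOP/s dense fp16 as assumptions; the function name is hypothetical.

```python
def decode_step_estimate(n_params, batch, bytes_per_weight=2,
                         hbm_bytes_per_s=3.35e12, peak_flops=9.9e14):
    """Lower bounds on one decode step from memory traffic and from compute."""
    flops = 2 * n_params * batch               # one multiply-add per weight per sequence
    bytes_read = bytes_per_weight * n_params   # every weight is read once per step
    t_memory = bytes_read / hbm_bytes_per_s
    t_compute = flops / peak_flops
    return {
        "intensity": flops / bytes_read,       # FLOPs per byte of weights read
        "t_memory_s": t_memory,
        "t_compute_s": t_compute,
        "bound": "memory" if t_memory >= t_compute else "compute",
    }


for batch in (1, 8, 64, 512):
    est = decode_step_estimate(7e9, batch)
    print(f"batch {batch:3d}: {est['intensity']:5.0f} FLOP/B, "
          f"mem {est['t_memory_s'] * 1e3:.2f} ms, "
          f"compute {est['t_compute_s'] * 1e3:.2f} ms -> {est['bound']}-bound")
```

In fp16 the arithmetic intensity equals the batch size, so under these assumed hardware numbers a 7B model stays memory-bound until the batch reaches a few hundred sequences.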