GPU Memory for LLMs

Where Every Byte Goes — Weights, KV Cache, Activations, Gradients, and the Tricks to Fit Bigger

GPU memory is the binding constraint of LLMs. Training a 70B model with mixed-precision Adam takes roughly 1.1 TB of HBM for weights, gradients, and optimizer state before activations (far beyond any single device); inference needs ~140 GB just for fp16 weights, plus a KV cache that grows linearly with context. The memory equation has four terms: model weights, KV cache, activations, and gradients (with their optimizer state). Each is attacked by a different family of optimizations: quantization for weights, GQA/MQA for KV cache, FlashAttention/recompute for activations, and ZeRO/tensor/pipeline parallelism to shard everything across devices.

Memory Anatomy: Where the Bytes Live

A 7B model in fp16, training vs inference. Same weights, very different total.

Inference (7B fp16, 8k ctx, batch=8)
  Weights:       14 GB
  KV cache:       8 GB
  Activations:    2 GB
  Total:        ~24 GB

Training (7B fp16, Adam optimizer, no recompute)
  Weights:       14 GB
  Gradients:     14 GB
  Adam state:    28 GB
  Activations:   50-80 GB
  Total:        ~120 GB

Inference cost per request (8k context):
  Weights:      ONE-TIME, shared across all requests
  KV cache:     PER-REQUEST, grows linearly with sequence length
  Activations:  TRANSIENT, only the current step's intermediates

Key Numbers

2 bytes
fp16/bf16 weight, the inference baseline
12-16 bytes
Per parameter during training (weight + grad + Adam state)
2.5 GB
KV cache per Llama-3 70B request at 8k context (GQA)
8×
KV cache reduction with GQA (8 KV heads vs 64 Q heads)
O(N)
FlashAttention activation memory · vs O(N²) standard
80 GB
A100 / H100 single-device HBM
141 GB
H200 HBM3e

1. The Memory Equation

For inference of a transformer with N parameters in fp16, batch B, context L, hidden dim D, L_n layers, h_kv KV heads, and per-head dimension head_dim:

# Inference memory (per replica)
mem_weights      = 2 * N
mem_kv_cache     = 2 * 2 * B * L * L_n * h_kv * head_dim
                   # factors in order: K and V, 2 bytes (fp16), batch, ctx, layers, KV heads, head dim
mem_activations  = roughly 2 * B * L * D    (one step's intermediates)
mem_total_inf    = mem_weights + mem_kv_cache + mem_activations

# Training memory in bytes (per replica, fp16 + Adam with fp32 master weights)
mem_weights      = 2 * N
mem_grads        = 2 * N
mem_master       = 4 * N               (fp32 copy of the weights)
mem_adam_state   = 2 * 4 * N           (m and v, fp32 each)
mem_activations  = 2 * B * L * D * L_n  (saved for backward; often DOMINANT)
mem_total_train  = ~16 * N + activations

The "16N" rule of thumb: training a model in standard mixed-precision Adam costs ~16 bytes per parameter for state alone (2 for fp16 weights, 2 for fp16 gradients, 4 for fp32 master weights, 8 for the two fp32 Adam moments). For a 70B model, that's 1.12 TB before activations. This is why training at this scale requires sharding across devices.
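The two equations above can be turned into a small calculator. This is a minimal sketch rather than a profiler: the function names are illustrative, the Llama-3-70B-like shape in the usage lines (80 layers, 8 KV heads, head_dim 128, hidden dim 8192) is an assumption, and the activation term is the same rough estimate used in the formula.

```python
def inference_memory_bytes(n_params, batch, ctx, n_layers, n_kv_heads,
                           head_dim, hidden_dim, bytes_per_el=2):
    """Rough per-replica inference memory in bytes, following the formulas above."""
    weights = bytes_per_el * n_params
    # K and V each hold batch * ctx * layers * kv_heads * head_dim elements.
    kv_cache = 2 * bytes_per_el * batch * ctx * n_layers * n_kv_heads * head_dim
    # One step's intermediates (rough estimate, as in the text).
    activations = bytes_per_el * batch * ctx * hidden_dim
    return {
        "weights": weights,
        "kv_cache": kv_cache,
        "activations": activations,
        "total": weights + kv_cache + activations,
    }


def training_state_bytes(n_params):
    """fp16 weights + fp16 grads + fp32 master weights + fp32 Adam m and v."""
    return (2 + 2 + 4 + 8) * n_params


if __name__ == "__main__":
    # Assumed Llama-3-70B-like shape, single 8k-context request.
    est = inference_memory_bytes(70_000_000_000, batch=1, ctx=8192, n_layers=80,
                                 n_kv_heads=8, head_dim=128, hidden_dim=8192)
    print({name: round(value / 1e9, 2) for name, value in est.items()})
    print(training_state_bytes(70_000_000_000) / 1e12, "TB of training state")
```

Activations during training are left out of `training_state_bytes` on purpose, since they depend on batch size, sequence length, and whether recomputation is enabled.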

2. FlashAttention

Don't materialize the n×n attention matrix. Tile it through SRAM.

The standard attention implementation computes QKᵀ as an explicit n×n matrix in HBM, applies softmax, then multiplies by V. For n=8192, the score matrix alone is 128 MB per head per layer in fp16 (256 MB in fp32), and the HBM traffic to write and re-read it dominates the latency.

FlashAttention (Dao et al. 2022) recasts attention as a tiled, fused kernel:

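# Process Q, K, V in tiles of (block_M, block_N).
# Use online softmax (Milakov & Gimelshein 2018) to avoid storing
# the full n×n matrix anywhere. NumPy reference for the math only;
# the real kernel keeps each tile in on-chip SRAM.
import numpy as np

def flash_attention(Q, K, V, block_M=64, block_N=64):
    n, d = Q.shape
    O = np.empty((n, V.shape[1]))
    for i in range(0, n, block_M):
        q_block = Q[i:i + block_M]
        rows = q_block.shape[0]
        o_block = np.zeros((rows, V.shape[1]))
        m_block = np.full(rows, -np.inf)  # running max, one per query row
        l_block = np.zeros(rows)          # running denominator
        for j in range(0, K.shape[0], block_N):
            k_block, v_block = K[j:j + block_N], V[j:j + block_N]
            s = q_block @ k_block.T / np.sqrt(d)
            m_new = np.maximum(m_block, s.max(axis=1))
            # rescale running output and denominator to the new max
            scale = np.exp(m_block - m_new)
            p = np.exp(s - m_new[:, None])
            l_block = scale * l_block + p.sum(axis=1)
            o_block = scale[:, None] * o_block + p @ v_block
            m_block = m_new
        O[i:i + block_M] = o_block / l_block[:, None]  # write o_block to HBM
    return O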

Memory: O(n) instead of O(n²). HBM traffic: up to ~10× lower. Throughput: 2-4× faster on long contexts. FlashAttention-2 (2023) added parallelism across the sequence dimension and better work partitioning between warps. FlashAttention-3 (2024) added fp8 support and uses the asynchronous hardware features of Hopper (H100).

3. Multi-Query and Grouped-Query Attention

Most of the KV cache is wasted on duplicate-ish information. Share K and V across heads.

Variant | Q heads | KV heads | KV cache size | Used by
MHA (full multi-head) | 32 | 32 | 1.0× | GPT-3, Llama-1, Llama-2 7B
GQA (Grouped-Query) | 32 | 8 | 0.25× | Llama-2 70B, Llama-3, Mistral
MQA (Multi-Query) | 32 | 1 | 0.03× | PaLM, Falcon-7B, StarCoder

MQA gives the biggest savings but costs some quality (on the order of 1% on benchmarks). GQA (Ainslie et al. 2023) is the compromise that won: it keeps quality close to MHA while cutting the KV cache 4-8×. Llama-2 70B uses 8 KV heads for 64 Q heads, an 8× reduction.
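To make the head sharing concrete, here is a minimal NumPy sketch of grouped-query attention for a single sequence. It is illustrative only: the function names are made up for this example, there is no causal mask or batching, and the mapping of consecutive query heads to one KV head is the usual convention rather than something the text above specifies.

```python
import numpy as np


def grouped_query_attention(q, k, v):
    """q: (h_q, n, d); k, v: (h_kv, n, d), with h_q divisible by h_kv."""
    h_q, n, d = q.shape
    h_kv = k.shape[0]
    if h_q % h_kv != 0:
        raise ValueError("h_q must be a multiple of h_kv")
    group = h_q // h_kv
    # Each KV head serves `group` consecutive query heads. Only the h_kv
    # originals need to live in the cache; the repeat is a compute-time view.
    k_full = np.repeat(k, group, axis=0)
    v_full = np.repeat(v, group, axis=0)
    scores = q @ k_full.transpose(0, 2, 1) / np.sqrt(d)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v_full


def kv_cache_ratio(h_q, h_kv):
    """KV cache size relative to full multi-head attention."""
    return h_kv / h_q
```

Setting `h_kv == h_q` recovers MHA and `h_kv == 1` recovers MQA, so the three rows of the table are the same code path with different KV head counts.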

4. Activation Recomputation (Gradient Checkpointing)

The backward pass needs activations from the forward pass. Storing them all costs O(B·L·D·L_n), the dominant term in training memory.

Activation checkpointing trades compute for memory: only save activations at boundaries (every 4-8 layers), and recompute the in-between activations during backward.

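# Without checkpointing: store every layer's output (96 layers → 96 saves)
# With checkpointing every 4 layers: store 24 saves
# Memory: ~4× lower for saved activations
# Compute: ~1.33× total (each segment's forward runs a second time during
#          backward, and forward is about 1/3 of a training step)

# PyTorch:
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(1024, 1024)   # stand-in for one transformer block
x = torch.randn(8, 1024, requires_grad=True)
# Intermediates inside `layer` are dropped after forward and recomputed in backward.
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()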

This can be the difference between fitting a 70B model on 8 H100s and needing 16. Most modern training stacks (DeepSpeed, FSDP) support it as a configuration option.
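The memory and compute figures for checkpointing can be estimated with a few lines of arithmetic. The sketch below is a simplified model under stated assumptions: every layer output has the same size, the backward pass costs twice the forward pass, and every segment is recomputed once. The function name and the `backward_to_forward` parameter are illustrative. It also counts the one segment that is live while being recomputed, so the reduction it reports is slightly below the headline figure.

```python
import math


def checkpoint_tradeoff(n_layers, every, backward_to_forward=2.0):
    """Return (saved checkpoints, memory reduction factor, compute factor)."""
    # Checkpoints kept from the forward pass.
    saved = math.ceil(n_layers / every)
    # During backward, one segment's layer outputs are rematerialized at a time.
    peak_layer_outputs = saved + every
    memory_reduction = n_layers / peak_layer_outputs
    # Baseline step = forward + backward; checkpointing adds about one forward.
    baseline = 1.0 + backward_to_forward
    compute_factor = (baseline + 1.0) / baseline
    return saved, memory_reduction, compute_factor


if __name__ == "__main__":
    for every in (2, 4, 8):
        print(every, checkpoint_tradeoff(96, every))
```

Under this model, larger segments save fewer checkpoints but make the live segment bigger, which is why segment sizes near the square root of the layer count are a common starting point.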

5. Tensor, Pipeline, and ZeRO Parallelism

Strategy | What's sharded | Communication | When / savings
Data parallel | Nothing (replicate, average grads) | AllReduce on grads | Model fits per device
Tensor parallel | Each weight matrix split across devices | AllReduce per layer | Within a node (NVLink)
Pipeline parallel | Layers split across devices in a pipeline | Activations between stages | Across nodes (slower link)
ZeRO-1 | Optimizer state | AllReduce + AllGather | Up to 4× less state memory, ~no slowdown
ZeRO-2 | + gradients | + ReduceScatter | Up to 8× less state memory
ZeRO-3 (FSDP) | + parameters | + AllGather params | State memory divided by device count; large comm

Modern training stacks (Megatron-LM, DeepSpeed, PyTorch FSDP) combine tensor, pipeline, and sharded data parallelism. A typical 70B run: tensor-parallel within an 8-GPU node, pipeline-parallel across 4 nodes, ZeRO-1 sharding optimizer state across the data-parallel group, and gradient accumulation to hit a 4M-token batch.
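How much each ZeRO stage saves depends on the number of devices, which a short calculation makes visible. The sketch below assumes the 16 bytes per parameter layout from section 1 (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 master weights plus the two Adam moments) and even sharding over the data-parallel group. The function name is illustrative, and activations and communication buffers are ignored.

```python
def zero_state_bytes_per_device(n_params, n_devices, stage):
    """Per-device model-state bytes for ZeRO stage 0 (plain DP) through 3."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2, or 3")
    weights = 2 * n_params   # fp16 parameters
    grads = 2 * n_params     # fp16 gradients
    optim = 12 * n_params    # fp32 master weights + Adam m and v
    if stage >= 1:
        optim /= n_devices   # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_devices   # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_devices # ZeRO-3: also shard parameters
    return weights + grads + optim


if __name__ == "__main__":
    for stage in range(4):
        gb = zero_state_bytes_per_device(70e9, 64, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per device")
```

The 4× and 8× figures for ZeRO-1 and ZeRO-2 are limits approached as the device count grows; only ZeRO-3 keeps shrinking in proportion to the number of devices.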

6. Inference Memory Math (Practical)

Concrete worked example — Llama-3 70B serving on H100 (80 GB):

Weights (fp16):                  140 GB  ← won't fit on 1 GPU
Weights (fp8):                    70 GB  ← fits, but leaves only ~10 GB for KV cache
Weights (int4 GPTQ):              35 GB  ← comfortable on 1× H100

KV cache per token (GQA, 80 layers, 8 KV heads, 128 dim, fp16):
  2 (K and V) * 2 bytes * 80 * 8 * 128 = 320 KB

KV cache per request (8k ctx):    320 KB * 8192 = 2.5 GB

With int4 weights + 2.5 GB/req KV:
  (80 - 35) GB free / 2.5 GB per request = 18 concurrent 8k-context requests

With KV cache quantized to int8:
  KV halves → 36 concurrent requests

These are upper bounds: activations and runtime overhead also need room.
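The same arithmetic as a small helper, so other models and context lengths can be plugged in. This is a sketch that mirrors the worked example's units (the per-request cache is converted with 2^30 bytes per GB, as the 2.5 GB figure implies), and it gives an upper bound because it leaves no room for activations or runtime overhead. The function names are illustrative.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_el=2):
    """Bytes of K and V stored per token across all layers."""
    return 2 * bytes_per_el * n_layers * n_kv_heads * head_dim


def max_concurrent_requests(gpu_gb, weights_gb, ctx, n_layers, n_kv_heads,
                            head_dim, kv_bytes_per_el=2):
    """Upper bound on full-context requests that fit beside the weights."""
    per_token = kv_bytes_per_token(n_layers, n_kv_heads, head_dim, kv_bytes_per_el)
    per_request_gb = per_token * ctx / 2**30
    return int((gpu_gb - weights_gb) // per_request_gb)


if __name__ == "__main__":
    shape = dict(ctx=8192, n_layers=80, n_kv_heads=8, head_dim=128)
    print(kv_bytes_per_token(80, 8, 128) // 1024, "KB per token")
    print(max_concurrent_requests(80, 35, **shape), "requests with fp16 KV")
    print(max_concurrent_requests(80, 35, kv_bytes_per_el=1, **shape),
          "requests with int8 KV")
```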

Tradeoffs

Optimization | Memory saved | Cost
int4 weights | 4× on weights | ~0.5% perplexity
GQA | 4-8× on KV cache | Tiny accuracy hit
FlashAttention | The O(n²) attention-score activations | None in accuracy (exact attention): pure win
Activation checkpoint | 4-8× on activations | ~33% extra compute
ZeRO-3 / FSDP | ~N× on weights, grads, and optimizer state across N devices | Heavy AllGather traffic
fp8 KV cache | 2× on KV | Small quality loss; per-tensor scales

FAQ

Why does training need so much more memory than inference?

Three reasons: (1) you need gradients (same size as weights), (2) mixed-precision Adam keeps fp32 master weights plus two fp32 moments, 12 bytes per parameter, and (3) activations from forward must persist for backward, far more than a single decode step's intermediates. Total: ~16 bytes per parameter for naive training vs ~2 bytes for fp16 inference (excluding KV cache), an 8× gap before activations are counted.

What is "the activation cliff" people talk about?

Without FlashAttention, activation memory grows quadratically with sequence length (because of the n×n attention matrix). At around 8k context on a 70B model, this term becomes dominant and memory use climbs steeply, which is the "cliff". FlashAttention removes the quadratic part.
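The quadratic term is easy to quantify. The sketch below computes the size of the materialized attention score matrix for a given sequence length; the function name and defaults are illustrative, and it counts only the scores (not softmax outputs or dropout masks that a training framework might also keep).

```python
def attention_score_bytes(n, n_heads=1, batch=1, bytes_per_el=2):
    """Bytes needed to materialize the n x n score matrix for one layer."""
    return batch * n_heads * n * n * bytes_per_el


if __name__ == "__main__":
    for n in (2048, 8192, 32768):
        mib = attention_score_bytes(n) / 2**20
        print(f"n={n}: {mib:.0f} MiB per head per layer in fp16")
```

Doubling the context quadruples this term, while weights stay fixed and the KV cache only doubles, which is why the score matrix overtakes everything else at long context.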

Can I quantize the KV cache?

Yes. int8 KV cache works well with per-channel scaling, and fp8 KV is becoming standard on H100. Going below 8 bits typically hurts quality because attention is sensitive to KV precision (K enters QKᵀ at every step).
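A minimal NumPy sketch of symmetric int8 quantization with per-channel scales, applied to one head's K or V block of shape (tokens, channels). This illustrates the idea only and is not any particular serving engine's scheme: the function names are made up for this example, and real implementations often compute scales per group of tokens and fuse dequantization into the attention kernel.

```python
import numpy as np


def quantize_kv_int8(x):
    """x: (tokens, channels) float array -> (int8 values, per-channel scales)."""
    # One scale per channel so a single large channel does not crush the rest.
    scale = np.abs(x).max(axis=0, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid division by zero
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)


def dequantize_kv_int8(q, scale):
    """Recover an approximate float32 tensor from int8 values and scales."""
    return q.astype(np.float32) * scale


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    k = rng.standard_normal((1024, 128)).astype(np.float32)
    q, scale = quantize_kv_int8(k)
    err = np.abs(dequantize_kv_int8(q, scale) - k).max()
    print("bytes:", q.nbytes, "max abs error:", float(err))
```

The int8 tensor takes half the bytes of its fp16 counterpart, and the rounding error per element is bounded by half of that channel's scale.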

What's the difference between ZeRO-3 and FSDP?

Conceptually identical: shard parameters across devices, AllGather them just before they're needed, then immediately re-shard. ZeRO-3 is the DeepSpeed implementation; FSDP is the PyTorch-native one. FSDP is a common default for PyTorch teams because it composes well with other PyTorch features.

Why is HBM bandwidth (not capacity) the inference bottleneck?

For decode steps with batch=1, each weight matmul has dimensions (1 × D) × (D × D): every weight is read from HBM and used for a single multiply-add, so there is almost no compute reuse. The GPU spends most of its time waiting for HBM reads. Larger batches amortize the weight-loading cost across requests, which is why batching matters so much.
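A roofline-style sketch of that argument. The functions are illustrative and count weight traffic only (KV cache reads and activations are ignored); the hardware numbers in the usage lines are assumed round figures in the neighborhood of an H100-class part, not taken from a spec sheet.

```python
def decode_arithmetic_intensity(batch, bytes_per_weight=2):
    """FLOPs per byte of weight traffic for a (batch x D) @ (D x D) matmul."""
    # 2 * batch * D * D FLOPs over D * D * bytes_per_weight bytes; D cancels.
    return 2 * batch / bytes_per_weight


def min_decode_step_seconds(weight_bytes, hbm_bytes_per_s):
    """Lower bound on one decode step: every weight is read once from HBM."""
    return weight_bytes / hbm_bytes_per_s


def batch_to_saturate_compute(peak_flops, hbm_bytes_per_s, bytes_per_weight=2):
    """Batch size at which the matmul stops being bandwidth bound."""
    return peak_flops / hbm_bytes_per_s * bytes_per_weight / 2


if __name__ == "__main__":
    bandwidth = 3.35e12   # assumed HBM bandwidth, bytes/s
    peak = 1.0e15         # assumed dense fp16 throughput, FLOP/s
    step = min_decode_step_seconds(14e9, bandwidth)  # 7B model in fp16
    print(f"at least {step * 1e3:.2f} ms per decode step at any small batch")
    print(f"batch of ~{batch_to_saturate_compute(peak, bandwidth):.0f} "
          "needed before compute becomes the limit")
```

At batch 1 the intensity is about one FLOP per byte, far below what the hardware can sustain, so the step time is set by how fast the weights stream in rather than by arithmetic.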