Autoregressive Generation and KV Caching in Transformers
Given a sequence of token representations $X \in \mathbb{R}^{n \times d_{\text{model}}}$, each token's representation is projected into three vectors, a query, a key, and a value:

$$Q = XW^Q, \qquad K = XW^K, \qquad V = XW^V$$
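As a minimal sketch of these projections (the sizes and random weights below are illustrative assumptions, not values from the text):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes; the dimension names follow the shapes used later in the text
n, d_model, d_k, d_v = 4, 8, 6, 5
X = rng.standard_normal((n, d_model))   # one row per token

# Learned projection matrices (random here, just to show the shapes)
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V     # (n, d_k), (n, d_k), (n, d_v)
```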
Parallel Training with Full Sequence
During training, the entire sequence is processed in parallel.
Shapes:
- $Q, K$: $(n, d_k)$
- $V$: $(n, d_v)$
- scores $QK^T$: $(n, n)$
This is a matrix-matrix multiplication. A causal mask is applied to $QK^T$ to prevent tokens from attending to future positions:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V, \qquad M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$
The mask is applied to the scores before softmax, not after.
The $-\infty$ values become 0 after softmax, and the remaining values in each row sum to 1. Note that the mask is added to the scores, not multiplied.
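The masked, full-sequence attention described above can be sketched in NumPy as follows (a didactic single-head version, not an optimized implementation):

```python
import numpy as np

def causal_attention(Q, K, V):
    """Full-sequence attention with a causal mask, as used during training."""
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)                 # (n, n) matrix-matrix product
    mask = np.triu(np.full((n, n), -np.inf), k=1)   # -inf strictly above the diagonal
    scores = scores + mask                          # add the mask, don't multiply
    # Row-wise softmax; exp(-inf) = 0, so future positions get zero weight
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights
```

The zero entries above the diagonal of `weights` show exactly which positions the mask blocked, and each row still sums to 1.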
Autoregressive Generation
During generation, however, tokens are produced one at a time. At step $t$:

$$\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right) V_{1:t}$$
Shapes:
- $Q_t$: $(1, d_k)$
- $K_{1:t}$: $(t, d_k)$
- $V_{1:t}$: $(t, d_v)$
- scores $Q_t K_{1:t}^T$: $(1, t)$
This is a vector-matrix multiplication. No mask is needed because future tokens do not exist yet.
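The single-step, unmasked case can be sketched as a vector-matrix product (again a toy single-head version; the argument names are assumptions for illustration):

```python
import numpy as np

def attend_one_step(q_t, K_prev, V_prev):
    """Attention for one query against all keys/values seen so far.

    q_t: (d_k,) query for the current token
    K_prev: (t, d_k) keys for positions 1..t
    V_prev: (t, d_v) values for positions 1..t
    """
    d_k = q_t.shape[-1]
    scores = K_prev @ q_t / np.sqrt(d_k)   # (t,) vector-matrix product
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax over the t existing positions
    return weights @ V_prev                # (d_v,); no mask, the future doesn't exist yet
```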
KV Caching
At each generation step $t$, the attention operation requires:
| Component | What's Needed | Size |
|---|---|---|
| $Q$ | Only $Q_t$ (current token) | $(1, d_k)$ |
| $K$ | All previous: $K_1, K_2, \ldots, K_t$ | $(t, d_k)$ |
| $V$ | All previous: $V_1, V_2, \ldots, V_t$ | $(t, d_v)$ |
The current token's query $Q_t$ compares against all previous keys $K_{1:t}$ to compute attention weights, then retrieves a weighted combination of all previous values $V_{1:t}$.
Rather than recomputing $K$ and $V$ for all previous tokens at each step, we cache them:
- At step $t$, compute $K_t$ and $V_t$ for the new token and append to the cache.
- Reuse cached $K_{1:t-1}$ and $V_{1:t-1}$ from previous steps.
- $Q_t$ is not cached because it is only used at step $t$ and never referenced again.
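The three cache rules above can be sketched as a generation-step loop. `KVCache` and `cached_attention_step` are hypothetical helper names for illustration, not a library API:

```python
import numpy as np

class KVCache:
    """Minimal per-layer KV cache sketch: grows by one row per generated token."""
    def __init__(self):
        self.K = None  # (t, d_k) after t steps
        self.V = None  # (t, d_v) after t steps

    def append(self, k_t, v_t):
        k_t, v_t = k_t[None, :], v_t[None, :]
        self.K = k_t if self.K is None else np.vstack([self.K, k_t])
        self.V = v_t if self.V is None else np.vstack([self.V, v_t])
        return self.K, self.V

def cached_attention_step(x_t, W_Q, W_K, W_V, cache):
    """One generation step: project only the new token, append K/V, attend."""
    q_t = x_t @ W_Q                        # used only at this step, never cached
    K, V = cache.append(x_t @ W_K, x_t @ W_V)
    d_k = q_t.shape[-1]
    scores = K @ q_t / np.sqrt(d_k)        # q_t against all cached keys K_{1:t}
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V                           # weighted combination of cached values
```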
Without caching, each generation step would re-run attention over the entire prefix, recomputing $K$ and $V$ for every previous token: generating the token at position $n$ costs $O(n^2)$, and generating a sequence of length $n$ costs $O(n^3)$ overall. With caching, each step computes only the new token's $K_t$ and $V_t$ and attends over the $t$ cached entries, so generating a token is $O(n)$ and the overall complexity is $O(n^2)$.
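A back-of-envelope way to see the gap is to count how many per-token K/V projections are performed over a whole generation run (counting projections only, not attention FLOPs):

```python
# Count token K/V projections needed to generate n tokens.
n = 1000

# No cache: step t re-projects K and V for all t tokens in the prefix,
# so the total is 1 + 2 + ... + n = n*(n+1)/2 -- quadratic growth.
without_cache = sum(t for t in range(1, n + 1))

# Cache: step t projects K and V only for the one new token -- linear growth.
with_cache = n
```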