Autoregressive Generation and KV Caching in Transformers

Given an input sequence $X \in \mathbb{R}^{n \times d_{\text{model}}}$, each token's representation is projected into query, key, and value vectors:

$$ Q = XW_Q \quad K = XW_K \quad V = XW_V $$
The self-attention output is computed as:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$
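The projection and attention formulas above can be sketched directly in NumPy. This is a minimal single-head illustration (the dimensions and weight matrices below are arbitrary, not from the text):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (n, n)
    return softmax(scores) @ V        # (n, d_v)

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 4, 8, 8, 8     # illustrative sizes
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
out = attention(Q, K, V)
print(out.shape)  # (4, 8)
```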

Parallel Training with Full Sequence

During training, the entire sequence is processed in parallel.

Shapes:

$$ Q \in \mathbb{R}^{(n \times d_k)}, \quad K \in \mathbb{R}^{(n \times d_k)}, \quad V \in \mathbb{R}^{(n \times d_v)} $$
Computation:
$$ QK^T \in \mathbb{R}^{(n \times n)} $$
$$ \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot V \in \mathbb{R}^{(n \times d_v)} $$

This is a matrix-matrix multiplication. A causal mask is applied to $QK^T$ to prevent tokens from attending to future positions:
$$ M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases} $$

The mask is applied to the scores before softmax, not after.

$$ A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) $$

The $-\infty$ entries become 0 after the softmax, and the remaining values in each row sum to 1. Note that the mask is added to the scores, not multiplied.
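A short sketch of the additive causal mask, using the same NumPy conventions as before (sizes and inputs are illustrative):

```python
import numpy as np

def causal_attention(Q, K, V):
    """Masked attention: row i may only attend to columns j <= i."""
    n, d_k = Q.shape
    # Additive mask: 0 on and below the diagonal, -inf above it.
    mask = np.where(np.tril(np.ones((n, n), dtype=bool)), 0.0, -np.inf)
    scores = Q @ K.T / np.sqrt(d_k) + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # stability
    weights = np.exp(scores)                  # exp(-inf) -> exactly 0
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights, weights @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 8)) for _ in range(3))
W, out = causal_attention(Q, K, V)
# Row 0 can only see token 0, so its single nonzero weight is 1.
print(W[0])  # [1. 0. 0. 0.]
```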

Autoregressive Generation

During generation, however, tokens are produced one at a time. At step $t$:

Shapes:

$$ Q_t \in \mathbb{R}^{(1 \times d_k)}, \quad K_{1:t} \in \mathbb{R}^{(t \times d_k)}, \quad V_{1:t} \in \mathbb{R}^{(t \times d_v)} $$
Computation:
$$ Q_t K_{1:t}^T \in \mathbb{R}^{(1 \times t)} $$
$$ \text{softmax}\left(\frac{Q_t K_{1:t}^T}{\sqrt{d_k}}\right) \cdot V_{1:t} \in \mathbb{R}^{(1 \times d_v)} $$

This is a vector-matrix multiplication. No mask is needed because future tokens do not exist yet.
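The single-step computation above is just a vector-matrix product followed by a softmax. A minimal sketch (shapes and inputs are illustrative):

```python
import numpy as np

def decode_step(q_t, K_prefix, V_prefix):
    """One generation step: attend a single query over all keys/values so far.

    q_t: (d_k,), K_prefix: (t, d_k), V_prefix: (t, d_v).
    """
    d_k = q_t.shape[-1]
    scores = K_prefix @ q_t / np.sqrt(d_k)   # (t,): vector-matrix product
    scores = scores - scores.max()           # numerical stability
    w = np.exp(scores)
    w = w / w.sum()                          # attention weights over 1..t
    return w @ V_prefix                      # (d_v,)

rng = np.random.default_rng(0)
t, d_k, d_v = 5, 8, 8
out = decode_step(rng.standard_normal(d_k),
                  rng.standard_normal((t, d_k)),
                  rng.standard_normal((t, d_v)))
print(out.shape)  # (8,)
```

No mask appears anywhere: the key/value prefix only contains positions $1$ through $t$, so there is nothing in the future to mask out.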

KV Caching

At each generation step $t$, the attention operation requires:

| Component | What's needed | Shape |
|---|---|---|
| $Q$ | Only $Q_t$ (current token) | $(1, d_k)$ |
| $K$ | All previous: $K_1, K_2, \ldots, K_t$ | $(t, d_k)$ |
| $V$ | All previous: $V_1, V_2, \ldots, V_t$ | $(t, d_v)$ |

The current token's query $Q_t$ compares against all previous keys $K_{1:t}$ to compute attention weights, then retrieves a weighted combination of all previous values $V_{1:t}$.

Rather than recomputing $K$ and $V$ for all previous tokens at each step, we cache them:

  • At step $t$, compute $K_t$ and $V_t$ for the new token and append to the cache.
  • Reuse cached $K_{1:t-1}$ and $V_{1:t-1}$ from previous steps.
  • $Q_t$ is not cached because it is only used at step $t$ and never referenced again.
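The steps above can be sketched as a small cache plus a generation loop. The class and function names here are illustrative, not from any particular library, and this covers only a single head of a single layer:

```python
import numpy as np

class KVCache:
    """Minimal single-layer, single-head KV cache (illustrative)."""
    def __init__(self, d_k, d_v):
        self.K = np.empty((0, d_k))  # cached keys,   shape (t, d_k)
        self.V = np.empty((0, d_v))  # cached values, shape (t, d_v)

    def append(self, k_t, v_t):
        # Only the new token's K/V are computed; the rest is reused.
        self.K = np.vstack([self.K, k_t[None, :]])
        self.V = np.vstack([self.V, v_t[None, :]])
        return self.K, self.V

def attend(q_t, K, V):
    scores = K @ q_t / np.sqrt(q_t.shape[-1])
    w = np.exp(scores - scores.max())
    w = w / w.sum()
    return w @ V

# Simulated generation loop: project only the newest token each step.
rng = np.random.default_rng(0)
d_model, d_k, d_v = 8, 8, 8
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))
cache = KVCache(d_k, d_v)
for step in range(4):
    x_t = rng.standard_normal(d_model)        # newest token's representation
    K, V = cache.append(x_t @ W_K, x_t @ W_V)
    out = attend(x_t @ W_Q, K, V)
print(cache.K.shape)  # (4, 8)
```

Note that $Q_t$ never enters the cache: it is consumed within the step that produced it.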

Without caching, the $K$ and $V$ projections for every previous token would be recomputed at each step: step $t$ rebuilds the full attention over $t$ positions, costing $O(t^2)$, so generating a token at position $n$ costs $O(n^2)$ and generating the whole sequence costs $O(n^3)$. With caching, each step only computes the new token's projections and a single row of attention scores, so step $t$ costs $O(t)$ and the overall complexity drops to $O(n^2)$.
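A crude way to see the asymptotic gap is to count attention-score entries computed per step, ignoring constants and the projection costs:

```python
# Without a cache, step t rebuilds a (t x t) score matrix (~t^2 entries);
# with a cache, step t computes a single (1 x t) row (~t entries).
n = 1024
no_cache = sum(t * t for t in range(1, n + 1))   # ~n^3 / 3 total
with_cache = sum(t for t in range(1, n + 1))     # ~n^2 / 2 total
print(no_cache // with_cache)  # (2n + 1) / 3 = 683
```

The ratio grows linearly in $n$, which is exactly the gap between $O(n^3)$ and $O(n^2)$ total work.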