Multi-Head Latent Attention (MLA)

Paper: [2405.04434] DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model


Multi-Head Latent Attention (MLA) modifies the Self-Attention Mechanism for efficient inference: it compresses the input representation into a low-dimensional latent and then up-projects that latent to K and V. Only the latent needs to be stored, which shrinks the KV cache, a major bottleneck that limits the maximum batch size and sequence length.

For example, DeepSeek-V2 uses $d_c = 512$ with $d_{model} = 5120$, giving a 20x reduction in memory ($10240 \rightarrow 512$).

Key idea

Spend more training compute (learning to compress) in exchange for inference-time memory efficiency: cache each token's compressed representation instead of its K and V.

MLA compresses the KV cache by projecting the input into a low-dimensional latent vector before expanding back to K and V. Given input $X \in \mathbb{R}^{(n \times d_{model})}$:

Q projection follows the standard approach:

$$ Q = XW_Q \quad W_Q \in \mathbb{R}^{(d_{model} \times d_{model})} $$

Latent compression projects the input into a low-dimensional latent (512 vs 5120):

$$ c = XW_{DKV} \quad W_{DKV} \in \mathbb{R}^{(d_{model} \times d_c)} $$

This latent $c \in \mathbb{R}^{(n \times d_c)}$, rather than $X$, is what gets projected into K and V, and it is what gets cached during inference.

KV expansion projects the latent back to full-dimensional K and V during both training and inference:

$$ K = cW_{UK} \quad W_{UK} \in \mathbb{R}^{(d_c \times d_{model})} $$
$$ V = cW_{UV} \quad W_{UV} \in \mathbb{R}^{(d_c \times d_{model})} $$

From here, attention proceeds as usual:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

MLA actually adds compute during training, since you do two matmuls (compress, then expand) instead of one to get K and V. But the model learns to pack information into $c$, which pays off at inference time.
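Below is a minimal single-head sketch of this forward pass in PyTorch, following the weight shapes from the equations above. Multi-head splitting, query compression, the RoPE pathway, and any caching logic are omitted, and the class and variable names are purely illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLA(nn.Module):
    """Single-head MLA sketch: compress the input to a latent c, then expand to K and V."""

    def __init__(self, d_model: int, d_c: int):
        super().__init__()
        self.d_model = d_model
        self.W_Q = nn.Linear(d_model, d_model, bias=False)   # Q = X W_Q
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)     # c = X W_DKV  (what gets cached)
        self.W_UK = nn.Linear(d_c, d_model, bias=False)      # K = c W_UK
        self.W_UV = nn.Linear(d_c, d_model, bias=False)      # V = c W_UV

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n, d_model)
        q = self.W_Q(x)
        c = self.W_DKV(x)                  # (batch, n, d_c) -- the only KV state worth caching
        k = self.W_UK(c)                   # expanded back to (batch, n, d_model)
        v = self.W_UV(c)
        scores = q @ k.transpose(-2, -1) / (self.d_model ** 0.5)
        return F.softmax(scores, dim=-1) @ v

# toy usage with small dimensions
x = torch.randn(2, 16, 256)
out = SimpleMLA(d_model=256, d_c=32)(x)    # (2, 16, 256)
```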

Note: queries are also compressed, through a separate latent of the token input; this mainly saves activation memory during training. Unlike $c$, the query latent is never cached at inference, since queries are not reused across decoding steps.

Decoupled RoPE Optimization

MLA compresses KV cache by caching a latent $c$ instead of K and V directly. Without Rotary Position Embeddings (RoPE), this enables an optimization where K is never explicitly computed:

$$ QK^T = (XW_Q)(cW_{UK})^T = X (W_Q W_{UK}^T) c^T $$

Precompute $W_{combined} = W_Q W_{UK}^T$, then:

$$ QK^T = X W_{combined} c^T $$

Attention scores are computed directly from $X$ and cached $c$. K is never materialized.
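Here is a small sketch of that absorption trick with toy tensor sizes, assuming a single head. `W_combined` is precomputed once, and the attention logits for a new token are formed directly from its raw input and the cached latents, without ever materializing K (the weights are random stand-ins, purely illustrative):

```python
import torch

d_model, d_c, t = 5120, 512, 16          # t = number of cached tokens

# weight matrices matching the equations above (random stand-ins here)
W_Q = torch.randn(d_model, d_model)
W_UK = torch.randn(d_c, d_model)

# precompute once (valid only without RoPE): shape (d_model, d_c)
W_combined = W_Q @ W_UK.T

x_new = torch.randn(1, d_model)          # current token's input
c_cache = torch.randn(t, d_c)            # cached latents for previous tokens

# logits of the new query against all cached tokens; K is never built
scores = (x_new @ W_combined) @ c_cache.T    # (1, t)
```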

Why RoPE Breaks This:

With RoPE, position-dependent rotation matrices $R_i$ and $R_j$ are applied:

$$ QK^T = (XW_Q R_i)(cW_{UK} R_j)^T = XW_Q R_i R_j^T W_{UK}^T c^T $$

The $R_i R_j^T$ term is stuck between $W_Q$ and $W_{UK}^T$: matrix multiplication isn't commutative, and the term depends on the token positions (for RoPE, $R_i R_j^T = R_{i-j}$), so no single fixed weight matrix can be precomputed to absorb it. This forces explicit computation of $K = cW_{UK}$ before applying RoPE, losing the benefit of working with the smaller $c$ directly.

| | Multiplies with | Size |
|---|---|---|
| Without RoPE | $c$ directly | $(t, d_c)$ |
| With RoPE | $K$ explicitly | $(t, d_{model})$ |

Since $d_c \ll d_{model}$, materializing $K$ for every cached token is a significant extra compute cost.
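To make the obstruction concrete, here is a tiny numerical sketch with toy dimensions and a hand-rolled rotation matrix (purely illustrative): the absorbed matrix that reproduces the RoPE score has to contain $R_{i-j}$, so it changes with the relative position and cannot be precomputed once.

```python
import torch

torch.manual_seed(0)
d_model, d_c = 8, 4      # tiny toy sizes; the rotary dim equals d_model here

def rot(pos: int, d: int) -> torch.Tensor:
    """Block-diagonal RoPE rotation matrix R_pos (toy construction)."""
    R = torch.zeros(d, d)
    for k in range(d // 2):
        theta = torch.tensor(pos / (10000 ** (2 * k / d)))
        c, s = torch.cos(theta), torch.sin(theta)
        R[2*k, 2*k], R[2*k, 2*k+1] = c, -s
        R[2*k+1, 2*k], R[2*k+1, 2*k+1] = s, c
    return R

W_Q = torch.randn(d_model, d_model)
W_UK = torch.randn(d_c, d_model)
x_i = torch.randn(1, d_model)    # query token at position i
c_j = torch.randn(1, d_c)        # cached latent of the key token at position j
i, j = 5, 2

# without RoPE, the absorbed form matches the explicit form
explicit = (x_i @ W_Q) @ (c_j @ W_UK).T
absorbed = x_i @ (W_Q @ W_UK.T) @ c_j.T
print(torch.allclose(explicit, absorbed, atol=1e-4))        # True

# with RoPE, R_i R_j^T = R_{i-j} sits between the weights; the "combined"
# matrix below reproduces the score, but it depends on i - j, so it cannot
# be a single precomputed matrix shared by all query/key position pairs
rope_score = (x_i @ W_Q @ rot(i, d_model)) @ (c_j @ W_UK @ rot(j, d_model)).T
frozen = x_i @ (W_Q @ rot(i - j, d_model) @ W_UK.T) @ c_j.T
print(torch.allclose(rope_score, frozen, atol=1e-4))        # True, but only for this offset
```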

Decoupled RoPE Solution:

Split Q and K into content (no RoPE) and position (with RoPE):

$$ Q = [Q_C; Q_R], \quad K = [K_C; K_R] $$

The attention score becomes:

$$ QK^T = Q_C K_C^T + Q_R K_R^T $$

Content term ($Q_C K_C^T$):

  • No RoPE involved
  • Uses the absorption optimization: $Q_C K_C^T = X W_{combined} c^T$
  • Only needs cached $c$

Position term ($Q_R K_R^T$):

  • RoPE applied to both $Q_R$ and $K_R$
  • Small dimensionality (e.g., $d_R = 64$ vs $d_{model} = 5120$)
  • $K_R$ cached separately with RoPE already applied

What Gets Cached:

| Component | Size | Notes |
|---|---|---|
| $c$ | $d_c$ | Compressed latent for content |
| $K_R$ | $d_R$ | Small, RoPE already applied |

Total cache per token: $d_c + d_R$, still much smaller than standard MHA's $2 \times d_{model}$.

Decoupled RoPE preserves the compression benefits for the bulk of the computation while handling positional information through a small separate pathway.
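Below is a sketch of this decoupled score computation, again single-head with toy sizes. The `rope` helper is a standard rotate-half implementation; `W_QR`, the random stand-in weights, and the scaling factor are illustrative assumptions rather than the paper's exact per-head formulation.

```python
import torch

def rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embeddings; x: (..., d), pos: (...) integer positions."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))
    angles = pos[..., None].float() * inv_freq       # (..., d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

d_model, d_c, d_R, t = 5120, 512, 64, 16

# content pathway: absorbed weights, needs only the raw input and cached latents
W_combined = torch.randn(d_model, d_c)    # stands in for the precomputed W_Q W_UK^T
c_cache = torch.randn(t, d_c)             # cached latents c

# position pathway: small decoupled projection plus a separately cached K_R
W_QR = torch.randn(d_model, d_R)          # hypothetical decoupled query projection
x_new = torch.randn(1, d_model)
q_R = rope(x_new @ W_QR, torch.tensor([t]))   # RoPE at the new token's position
kR_cache = torch.randn(t, d_R)                # cached K_R, RoPE already applied

content_scores = (x_new @ W_combined) @ c_cache.T   # (1, t), no K materialized
position_scores = q_R @ kR_cache.T                  # (1, t), tiny d_R-dim matmul
scores = (content_scores + position_scores) / (d_model + d_R) ** 0.5  # toy scaling
```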

KV Cache Comparison

| Method | Cached | Size per token per layer | Total cache |
|---|---|---|---|
| Standard MHA | $K, V$ | $2 \times d_{model}$ | $L \times n \times 2d_{model}$ |
| MLA | $c$ | $d_c$ | $L \times n \times d_c$ |
| MLA + Decoupled RoPE | $c, K_R$ | $d_c + d_R$ | $L \times n \times (d_c + d_R)$ |

Using DeepSeek-V2 numbers ($d_{model} = 5120$, $d_c = 512$, $d_R = 64$):

| Method | Size per token per layer |
|---|---|
| Standard MHA | $10240$ |
| MLA + Decoupled RoPE | $576$ |

~18x reduction.
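A quick back-of-the-envelope check of these numbers; the layer count and sequence length below are hypothetical example values, not taken from the table above.

```python
d_model, d_c, d_R = 5120, 512, 64    # DeepSeek-V2 numbers quoted above
L, n = 60, 4096                      # hypothetical layer count and sequence length

mha_per_token = 2 * d_model          # full K and V per layer
mla_per_token = d_c + d_R            # latent c plus decoupled K_R per layer

print(mha_per_token)                 # 10240
print(mla_per_token)                 # 576
print(mha_per_token / mla_per_token) # ~17.8x

print(L * n * mha_per_token)         # elements cached with standard MHA
print(L * n * mla_per_token)         # elements cached with MLA + decoupled RoPE
```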