Scaling Attention
The attention mechanism is expensive.
- The vanilla self-attention component has $O\left(n^2\right)$ time and memory complexity, where $n$ is the input sequence length.
- The term $QK^T$ is the main culprit.
- Q, K, V have shape (batch_size, max_length, model_dim)
- $QK^T$ has shape (batch_size, max_length, max_length)
- So for a sequence length of 64K, even a batch size of 1 gives a 64K $\times$ 64K score matrix, which is 16GB of memory in 32-bit floats per attention head (a quick sanity check of these numbers follows this list).
- For deep transformers with a large number of layers, this quickly explodes just for the forward pass.
- For training, i.e. the backward pass, we require roughly 5x this memory: parameters, activations, gradients, and first- and second-order optimizer moments (AdamW is generally used).
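A quick sanity check of the attention-matrix numbers above in plain Python (per attention head, batch size 1, 32-bit floats):

```python
# Memory of the n x n attention score matrix, per head, batch size 1, fp32 (4 bytes/elem).
def attn_matrix_gib(n: int, bytes_per_elem: int = 4) -> float:
    return n * n * bytes_per_elem / 2**30

for n in (4_096, 16_384, 65_536):
    print(f"n={n:>6}: {attn_matrix_gib(n):6.2f} GiB")
# n=  4096:   0.06 GiB
# n= 16384:   1.00 GiB
# n= 65536:  16.00 GiB
```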
There are several approaches and variants to help scale attention in modern architectures:
DeepSeek Sparse Attention (DSA)
Introduced in DeepSeek-V3.2-Exp (2025), which was comparable to Gemini 3.0 and GPT-5 in performance. It uses the latent vectors from Multi-Head Latent Attention (MLA) to compute similarity scores and select the "window" of tokens each query can attend to, making the attention sparse.
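The exact DSA indexer is more involved; the sketch below only illustrates the general pattern of scoring keys cheaply and keeping the top-k per query. The `index_q` / `index_k` vectors stand in for the low-dimensional latents, and `top_k` is an arbitrary choice, not DeepSeek's.

```python
import torch
import torch.nn.functional as F

def topk_sparse_attention(q, k, v, index_q, index_k, top_k=256):
    """Top-k sparse attention sketch: a cheap indexer similarity (index_q @ index_k^T)
    picks which keys each query may attend to; full attention is computed only there.
    Shapes: q, k, v are (n, d); index_q, index_k are (n, d_idx) low-dim index vectors."""
    n, d = q.shape
    scores = index_q @ index_k.T                           # cheap (n, n) indexer scores
    top_idx = scores.topk(min(top_k, n), dim=-1).indices   # (n, top_k) keys kept per query
    mask = torch.full((n, n), float("-inf"))
    mask.scatter_(1, top_idx, 0.0)                         # 0 where attention is allowed
    # For illustration we still build the full matrix; a real kernel would
    # gather only the selected keys/values per query.
    attn = F.softmax((q @ k.T) / d ** 0.5 + mask, dim=-1)
    return attn @ v
```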
Multi-Head Latent Attention (MLA)
Projects the key and value tensors into a lower-dimensional latent space, so only a small latent vector per token needs to be cached as the sequence grows. Introduced in DeepSeek-V2 (2024); it not only saves memory but can outperform MHA!
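A minimal sketch of the low-rank KV idea in PyTorch. The class name and dimensions are illustrative, not DeepSeek's, and the separate rotary-embedding path of MLA is omitted here.

```python
import torch
import torch.nn as nn

class LatentKV(nn.Module):
    """Compress each token's keys/values into one small latent vector, cache only that,
    and expand back to full-size K and V when attention is computed."""
    def __init__(self, d_model: int = 1024, d_latent: int = 128):
        super().__init__()
        self.down = nn.Linear(d_model, d_latent)   # per-token KV cache entry is d_latent wide
        self.up_k = nn.Linear(d_latent, d_model)   # expand latent -> keys for all heads
        self.up_v = nn.Linear(d_latent, d_model)   # expand latent -> values for all heads

    def forward(self, x: torch.Tensor):            # x: (batch, seq, d_model)
        latent = self.down(x)                      # (batch, seq, d_latent): what gets cached
        return latent, self.up_k(latent), self.up_v(latent)
```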
Grouped Query Attention (GQA)
Grouped-Query Attention (GQA, 2023) showed that each query head doesn't necessarily need its own keys and values; sharing the same key and value heads across groups of query heads saves a lot of memory without noticeably affecting modeling performance.
Used heavily in modern architectures (as of 2025, the time of writing).
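A minimal sketch of the idea in PyTorch (head counts and dimensions are arbitrary examples): the key/value heads are repeated so each is shared by a group of query heads.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, n_query_heads=8, n_kv_heads=2):
    """GQA sketch: n_query_heads query heads share n_kv_heads key/value heads.
    q: (seq, n_query_heads, d_head); k, v: (seq, n_kv_heads, d_head)."""
    group = n_query_heads // n_kv_heads
    # Repeat each KV head so it is shared by `group` consecutive query heads.
    k = k.repeat_interleave(group, dim=1)                    # (seq, n_query_heads, d_head)
    v = v.repeat_interleave(group, dim=1)
    q, k, v = (t.transpose(0, 1) for t in (q, k, v))         # (heads, seq, d_head)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return (F.softmax(scores, dim=-1) @ v).transpose(0, 1)   # (seq, heads, d_head)

seq, d_head = 16, 64
out = grouped_query_attention(torch.randn(seq, 8, d_head),
                              torch.randn(seq, 2, d_head),
                              torch.randn(seq, 2, d_head))
```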
Local Attention
- Key idea: Divide the input into groups of neighboring tokens, apply self-attention within each group separately, and then combine the outputs; this is called block-sparse attention.
- To make sure distant elements can still interact with each other, alternate block-attention layers with full-attention layers.
- Complexity reduces down to $O(n \sqrt{n})$
- Similar interleaved sparse and full self-attention is used in GPT-3
- Block-sparse local attention was also introduced in the work of Liu et al. (ICLR 2018)
- Longformer by Beltagy et al. (2020) introduces a couple of additional ideas
- Use a sliding window of size $w$ across layers, similar to stacked CNNs: each token attends to $\frac{w}{2}$ tokens on each side, giving complexity $O(n \times w)$ (see the sliding-window sketch at the end of this section).
- Additionally, sliding windows can be "dilated" by adding gaps of size $d$ within the window, increasing coverage without increasing computation.
- Global attention: Allow certain tokens to attend to all tokens, not just the $\frac{w}{2}$ window, to emulate properties like the [CLS] token in BERT.
- Global attention can be thought of as "memory" tokens and is generalized in Extended Transformer Construction (ETC) Ainslie et al.
- LongT5 by Guo et al. 2022 adopts similar idea by introducing Transient Global Attention (TGlobal)
- Instead of arbitrary global tokens as in ETC, create "transient" global tokens for fixed blocks of tokens by summing the tokens in each block.
- Every token then attends to its local window plus all of these global tokens, so information can still flow across the whole input.
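A minimal sketch of the sliding-window (local) attention pattern; dilated and global attention just modify this mask. For clarity it materializes the full $n \times n$ matrix, which a real implementation would never do.

```python
import torch
import torch.nn.functional as F

def sliding_window_attention(q, k, v, window=4):
    """Longformer-style local attention via a band mask:
    token i may attend only to tokens j with |i - j| <= window // 2."""
    n, d = q.shape
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    band = (i - j).abs() <= window // 2                    # (n, n) boolean band mask
    scores = (q @ k.T) / d ** 0.5
    scores = scores.masked_fill(~band, float("-inf"))      # drop everything outside the window
    return F.softmax(scores, dim=-1) @ v
```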
Compressed Attention
Memory-compressed attention
- Introduced by Liu et al. (ICLR 2018) for long sequence generation.
- Key idea: Use a convolution operation on top of the key and value matrices to reduce the size of the attention matrix.
- Introduces kernels $\theta_k$ and $\theta_v$ to compute self-attention as
$$ \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q (\theta_k \circledast K)^T}{\sqrt{d_k}}\right) (\theta_v \circledast V) $$
- Reduces the attention matrix size from $n \times n$ to $n \times (n/s)$, where $s$ is the stride and kernel size (3 in the original paper).
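A minimal sketch in PyTorch, using strided 1-D convolutions to shrink K and V along the sequence axis (layer sizes and the class name are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MemoryCompressedAttention(nn.Module):
    """Strided 1-D convolutions shrink K and V along the sequence axis by a factor s
    before the usual attention, so the score matrix is n x (n/s) instead of n x n."""
    def __init__(self, d_model=512, stride=3):
        super().__init__()
        # kernel_size == stride == 3, matching the setting described above
        self.conv_k = nn.Conv1d(d_model, d_model, kernel_size=stride, stride=stride)
        self.conv_v = nn.Conv1d(d_model, d_model, kernel_size=stride, stride=stride)

    def forward(self, q, k, v):                                  # all (batch, seq, d_model)
        d = q.shape[-1]
        k_c = self.conv_k(k.transpose(1, 2)).transpose(1, 2)     # (batch, seq//s, d_model)
        v_c = self.conv_v(v.transpose(1, 2)).transpose(1, 2)
        scores = q @ k_c.transpose(1, 2) / d ** 0.5              # (batch, seq, seq//s)
        return F.softmax(scores, dim=-1) @ v_c
```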
Low-rank approximation of attention matrix
- Introduced by Wang et al. (2020) in the Linformer paper
- Key idea: Self-attention matrix can be approximated by a low-rank matrix.
- Enables linear $O(n)$ time and space complexity!
- Provides a proof that such a low-rank approximation exists!
- Introduces two linear projection matrices $E_i, F_i \in \mathbb{R}^{n \times k}$ to project the original $n \times d$ key and value matrices down to $k \times d$. When $k=O\left(d / \epsilon^2\right)$, this approximates self-attention with error $\epsilon$.
$$ \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q (E_i K)^T}{\sqrt{d_k}}\right) F_i V $$
- The projection matrices can be shared across layers and heads!
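A minimal sketch of the Linformer-style projections (single head, no batch; the learned projections are written here as $k \times n$ matrices so the shapes line up, and the initialization is arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinformerAttention(nn.Module):
    """Project the n x d key and value matrices down to k x d with fixed learned
    projections (E_i and F_i above) before attention, so the score matrix is n x k."""
    def __init__(self, seq_len: int = 4096, k: int = 256):
        super().__init__()
        self.proj_e = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)  # E_i
        self.proj_f = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)  # F_i

    def forward(self, q, keys, values):            # each: (seq_len, d)
        d = q.shape[-1]
        k_proj = self.proj_e @ keys                # (k, d)
        v_proj = self.proj_f @ values              # (k, d)
        scores = q @ k_proj.T / d ** 0.5           # (seq_len, k): linear in sequence length
        return F.softmax(scores, dim=-1) @ v_proj  # (seq_len, d)
```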
Kernelized Attention
- We can view the exponential of the dot product $QK^T$ as a kernel function, i.e. it computes a similarity.
- Instead of computing the full dot product, can we find a feature map of Q and K that approximates the similarity function $k(x, y)=\exp\left(x^T y\right)$ used by full self-attention?
- Linear Transformers (Katharopoulos et al., 2020)
- Self-attention can be rewritten (with row-wise normalization) as $$ \operatorname{Attention}(Q, K, V)=\frac{\phi(Q)\left(\phi(K)^T V\right)}{\phi(Q)\left(\phi(K)^T \mathbf{1}\right)} $$
- They choose $\phi(x)=\operatorname{elu}(x)+1$
- Reduces complexity to linear time, even for causal attention.
- Performer (Choromanski et al., 2021) uses an unbiased random-feature kernel to approximate $\exp(QK^T)$.
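A minimal sketch of linear (kernelized) attention with $\phi(x)=\operatorname{elu}(x)+1$, computing $\phi(K)^T V$ first so the $n \times n$ matrix is never formed (non-causal case, single head):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """Kernelized attention with feature map phi(x) = elu(x) + 1.
    q, k, v: (n, d). Cost is O(n * d^2) instead of O(n^2 * d)."""
    phi_q = F.elu(q) + 1                                   # (n, d), all entries > 0
    phi_k = F.elu(k) + 1
    kv = phi_k.T @ v                                       # (d, d): computed once
    normalizer = phi_q @ phi_k.sum(dim=0, keepdim=True).T  # (n, 1) row-wise denominator
    return (phi_q @ kv) / normalizer
```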
Conditional Attention
- What if not all tokens need the same amount of computation with full self-attention? Can we learn to "route" inputs between light and heavy computation paths? Similar in idea to Mixture of Experts.
- CoLT5 (Ainslie et al., 2023) introduces token-level conditional computation.
- The light branch has a lower hidden dimension, fewer heads, and applies only local attention; the heavy branch performs full attention.
- How to "route" or find important tokens?
- Multiply tokens with a learned embedding to get scores $s$, then select the top-k highest-scoring tokens (see the routing sketch after this list).
- Apply conditional feedforward $X_i=X_i+\operatorname{FFd}_{\text {Light }}\left(X_i\right)+\tilde{s}_i \cdot \operatorname{FFd}_{\text {Heavy }}\left(X_i\right)$
- Apply conditional attention $X_i=X_i+\mathrm{A}_{\text {Light }}\left(X_i, X\right)+\tilde{s}_i^q \cdot \mathrm{A}_{\text {Heavy }}\left(X_i, \tilde{s}^{k v} X\right)$
- Results in up to 75% training speedup and 100% inference speedup over LongT5, with performance improvements.
- Can handle up to 64K tokens.
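A minimal sketch of the token-routing idea for the conditional feedforward in PyTorch. The class name, layer widths, and the sigmoid gate are illustrative stand-ins, not CoLT5's exact formulation.

```python
import torch
import torch.nn as nn

class ConditionalFFN(nn.Module):
    """Every token passes through a light FFN; only the top-k scoring tokens also get
    the heavy FFN, scaled by their routing score."""
    def __init__(self, d_model: int = 512, top_k: int = 64):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, 1, bias=False)  # learned scoring embedding
        self.light = nn.Sequential(nn.Linear(d_model, 2 * d_model), nn.ReLU(),
                                   nn.Linear(2 * d_model, d_model))
        self.heavy = nn.Sequential(nn.Linear(d_model, 8 * d_model), nn.ReLU(),
                                   nn.Linear(8 * d_model, d_model))

    def forward(self, x):                                  # x: (seq, d_model)
        s = self.router(x).squeeze(-1)                     # (seq,) routing scores
        top = s.topk(min(self.top_k, x.shape[0])).indices  # indices of routed tokens
        out = x + self.light(x)                            # light path for every token
        gate = torch.sigmoid(s[top]).unsqueeze(-1)         # stand-in for the normalized score
        heavy_out = gate * self.heavy(x[top])              # heavy path only for routed tokens
        return out.index_add(0, top, heavy_out)
```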
Recurrence-based Transformers
- Can we use recurrence and memory states with Transformers?
- RMT (Bulatov et al., 2022) uses global memory tokens (similar to ETC) as part of the input, and outputs updated memory tokens.
- Shown to scale to 1M+ tokens!
- Compatible with existing Transformers that have small input sizes.
- Incorporates recurrence:
- The input is split into N segments. The first segment is combined with the memory tokens and fed to the Transformer.
- The second segment is then combined with the memory tokens output from the first segment and fed to the Transformer again.
- This process is repeated until the full sequence is processed (see the sketch after this list).
- Basically a Recurrent Neural Network (RNN)!
- Quadratic complexity can be reduced to linear, and arbitrary input lengths can be handled.
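A minimal sketch of the segment-level recurrence in PyTorch. The helper name, the use of a stock TransformerEncoder, and the memory handling are simplified stand-ins, not the paper's exact setup.

```python
import torch
import torch.nn as nn

def rmt_style_forward(transformer, segments, mem):
    """Process a long input segment by segment, carrying a small block of memory tokens
    from one segment to the next. `transformer` maps (seq, d_model) -> (seq, d_model);
    `mem` is (n_mem, d_model)."""
    outputs = []
    for seg in segments:                     # each seg: (seg_len, d_model)
        x = torch.cat([mem, seg], dim=0)     # prepend memory tokens to the segment
        y = transformer(x)
        mem = y[: mem.shape[0]]              # updated memory, passed to the next segment
        outputs.append(y[mem.shape[0]:])     # token outputs for this segment
    return torch.cat(outputs, dim=0), mem

# Example with a single stock encoder layer as the fixed-context model:
d_model, n_mem = 64, 4
enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead=4), num_layers=1)
long_input = torch.randn(1024, d_model)
segments = long_input.split(128, dim=0)
out, final_mem = rmt_style_forward(lambda x: enc(x.unsqueeze(1)).squeeze(1),
                                   segments, torch.zeros(n_mem, d_model))
```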
Hardware Optimization
FlashAttention (Dao et al., 2022)
- Uses a single fused kernel for self-attention that takes the GPU memory hierarchy into account.
- Computes exact attention, not approximation.
- Two techniques:
- Incrementally perform the softmax by splitting the input into blocks (tiling); see the sketch at the end of this section.
- Store the softmax normalization statistics from the forward pass and recompute attention on-chip during the backward pass, instead of storing the full attention matrix.
- Up to 7.6x faster attention computation on GPT-2, and uses less memory.
- Allows longer sequence length (16K)
- Block-sparse version allows even longer (64K)
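A minimal sketch of the tiling / online-softmax numerics for a single query vector (plain PyTorch, not a fused kernel; the real implementation processes blocks of queries in on-chip SRAM and also handles the backward pass):

```python
import torch

def tiled_attention_row(q, K, V, block=128):
    """Stream over key/value blocks, maintaining a running max and running softmax
    denominator so the full score row is never materialized at once."""
    d = q.shape[0]
    m = torch.tensor(float("-inf"))    # running max of scores seen so far
    denom = torch.tensor(0.0)          # running softmax denominator
    acc = torch.zeros(d)               # running (unnormalized) weighted sum of values
    for start in range(0, K.shape[0], block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = k_blk @ q / d ** 0.5                    # scores for this block
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)                # rescale previous partial results
        p = torch.exp(s - m_new)
        denom = denom * scale + p.sum()
        acc = acc * scale + p @ v_blk
        m = m_new
    return acc / denom                              # equals softmax(q K^T / sqrt(d)) @ V

q, K, V = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
assert torch.allclose(tiled_attention_row(q, K, V),
                      torch.softmax(K @ q / 64 ** 0.5, dim=0) @ V, atol=1e-4)
```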