Scaling Attention

The attention mechanism is expensive.

$$ \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V $$

  • Vanilla self-attention has $O\left(n^2\right)$ time and memory complexity, where $n$ is the input sequence length.
  • The term $QK^T$ is the main culprit.
    • Q, K, V have shape (batch_size, max_length, model_dim)
    • $QK^T$ has shape (batch_size, max_length, max_length)
    • So for a sequence length of 64K, even a batch size of 1 gives a 64K $\times$ 64K score matrix, which is 16GB of memory in 32-bit floats.
    • For deep transformers with many layers, this quickly explodes just for the forward pass.
    • Training, i.e. the backward pass, requires roughly 5x this memory: parameters, activations, gradients, and first- and second-order moments (AdamW is generally used).
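
As a sanity check on the arithmetic above, a few lines of Python (the helper name is ours, purely illustrative):

```python
def attn_matrix_bytes(seq_len: int, batch_size: int = 1, dtype_bytes: int = 4) -> int:
    """Bytes for one (batch_size, seq_len, seq_len) attention score matrix."""
    return batch_size * seq_len * seq_len * dtype_bytes

for n in (4096, 16 * 1024, 64 * 1024):
    print(f"n={n:>6}: {attn_matrix_bytes(n) / 2**30:.2f} GiB")
# n=64K gives 16 GiB in float32, before multiplying by heads and layers.
```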

There are several approaches and variants to help scale attention in modern architectures:

DeepSeek Sparse Attention (DSA)

Introduced in DeepSeek-V3.2-Exp (2025), reported as comparable to Gemini 3.0 and GPT-5 in performance. It reuses the latent vectors of Multi-Head Latent Attention (MLA) to compute similarity scores and select the "window" of tokens each query attends to, which is what makes the attention sparse.

Multi-Head Latent Attention (MLA)

Projects the key and value tensors into a lower-dimensional latent space, shrinking the per-token KV cache that otherwise grows linearly with sequence length. Introduced in DeepSeek-V2 (2024); it not only saves memory but can outperform MHA!

Grouped Query Attention (GQA)

Grouped-Query Attention (GQA, 2023) showed that each query head doesn't necessarily need its own keys and values: sharing key and value heads across groups of query heads saves a lot of memory without noticeably affecting modeling performance.

Used heavily in modern architectures (as of 2025, the time of writing).
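
A minimal numpy sketch of the idea, assuming 8 query heads sharing 2 key/value heads (function and variable names are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (n_q_heads, n, d); k, v: (n_kv_heads, n, d), with n_kv_heads
    dividing n_q_heads. Consecutive query heads share one K/V head;
    the memory saving comes from the smaller K/V tensors."""
    n_q_heads, _, d = q.shape
    group = n_q_heads // k.shape[0]
    k = np.repeat(k, group, axis=0)  # expand shared K/V heads for the matmul
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(8, 16, 32))  # 8 query heads
k = rng.normal(size=(2, 16, 32))  # but only 2 key/value heads
v = rng.normal(size=(2, 16, 32))
out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 16, 32)
```

With n_kv_heads = 1 this degenerates to Multi-Query Attention; with n_kv_heads = n_q_heads it recovers standard MHA.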

Local Attention

  • Key idea: Divide the input into groups of neighbors, apply self-attention within each group separately, and then combine the outputs (called block-sparse attention).
    • To let elements that are far apart interact, alternate block-attention layers with full-attention layers.
    • Complexity reduces down to $O(n \sqrt{n})$
    • A similar interleaving of sparse and full self-attention is used in GPT-3
  • Also introduced in Liu et al. (ICLR 2018)
  • Longformer by Beltagy et al. (2020) introduces a couple of additional ideas
    • Use a sliding window of size $w$ across layers, similar to stacked CNN receptive fields: each token attends to $\frac{w}{2}$ tokens on each side, giving complexity $O(n \times w)$
    • Additionally, sliding windows can be "dilated" by adding gaps of size $d$ in the windows, increasing coverage without increasing computation.
    • Global attention: allow certain tokens to attend across all tokens, not just their $\frac{w}{2}$ neighbors, to emulate special tokens like BERT's [CLS]
  • Global attention can be thought of as "memory" tokens and is generalized in Extended Transformer Construction (ETC) by Ainslie et al.
  • LongT5 by Guo et al. (2022) adopts a similar idea by introducing Transient Global Attention (TGlobal)
    • Instead of designating arbitrary global tokens as in ETC, create "transient" global tokens for fixed-size blocks of the input by summing the tokens in each block.
    • Each token then attends to all of these global tokens in addition to its local window.
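
The window, dilation, and global-token ideas above can be sketched as a single boolean attention mask (a toy helper, not Longformer's actual implementation):

```python
import numpy as np

def longformer_mask(n, w, dilation=1, global_idx=()):
    """Boolean (n, n) mask, True where attention is allowed:
    a sliding window of w//2 (dilated) positions per side, plus
    optional global tokens that attend, and are attended, everywhere."""
    i = np.arange(n)
    offset = i[:, None] - i[None, :]
    mask = (np.abs(offset) <= (w // 2) * dilation) & (offset % dilation == 0)
    for g in global_idx:
        mask[g, :] = True  # global token attends to everyone
        mask[:, g] = True  # everyone attends to the global token
    return mask

m = longformer_mask(8, w=4, global_idx=[0])
print(m.astype(int))
```

Masked attention then only computes (or keeps) scores where the mask is True, which is where the $O(n \times w)$ complexity comes from.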

Compressed Attention

Memory-compressed attention

  • Introduced by Liu et al. (ICLR 2018) for long sequence generation

  • Key idea: Use a convolution operation on top of the key and value matrices to reduce the size of the attention matrix

  • Introduces kernels $\theta_k$ and $\theta_v$ to compute self-attention as

    $$ \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q (\theta_k \circledast K)^T}{\sqrt{d_k}}\right) (\theta_v \circledast V) $$

  • Reduces the attention matrix size from $n \times n$ to $n \times (n/s)$, where $s$ is the stride and kernel size (3 in the original paper).
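
A toy numpy sketch of the compression step, using a fixed averaging kernel in place of the learned $\theta_k$, $\theta_v$ (names and shapes are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def compress(x, kernel):
    """Strided 1-D convolution along the sequence axis with
    stride == kernel size. x: (n, d) -> (n // s, d)."""
    s = len(kernel)
    n, d = x.shape
    blocks = x[: (n // s) * s].reshape(n // s, s, d)
    return np.einsum("bsd,s->bd", blocks, kernel)

def memory_compressed_attention(q, k, v, stride=3):
    kernel = np.full(stride, 1.0 / stride)  # untrained stand-in kernel
    k_c, v_c = compress(k, kernel), compress(v, kernel)
    d = q.shape[-1]
    scores = q @ k_c.T / np.sqrt(d)  # (n, n // stride) instead of (n, n)
    return softmax(scores) @ v_c

rng = np.random.default_rng(0)
q = k = v = rng.normal(size=(12, 8))
out = memory_compressed_attention(q, k, v, stride=3)
print(out.shape)  # (12, 8)
```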

Low-rank approximation of attention matrix

  • Introduced by Wang et al. (2020) in the Linformer paper
  • Key idea: Self-attention matrix can be approximated by a low-rank matrix.
    • Enables linear time $O(n)$ and space complexity!
    • Provides a proof that such a low-rank matrix exists!
  • Introduces two linear projection matrices $E_i, F_i \in \mathbb{R}^{k \times n}$ to project the original $n \times d$ key and value matrices down to $k \times d$. When $k=O\left(d / \epsilon^2\right)$, this approximates self-attention within error $\epsilon$.
    $$ \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q (E_i K)^T}{\sqrt{d_k}}\right) F_i V $$
  • The projection matrices can be shared across layers and heads!
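
A minimal sketch of the projection, with $E$, $F$ as $(k \times n)$ matrices so the products typecheck; random matrices stand in for the learned projections:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def linformer_attention(q, k, v, E, F):
    """q, k, v: (n, d); E, F: (k_proj, n) project the *sequence*
    axis, shrinking keys/values from (n, d) to (k_proj, d)."""
    d = q.shape[-1]
    scores = q @ (E @ k).T / np.sqrt(d)  # (n, k_proj): linear in n
    return softmax(scores) @ (F @ v)

n, d, k_proj = 64, 16, 8
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
E = rng.normal(size=(k_proj, n)) / np.sqrt(n)  # learned in practice
F = rng.normal(size=(k_proj, n)) / np.sqrt(n)
out = linformer_attention(q, k, v, E, F)
print(out.shape)  # (64, 16)
```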

Kernelized Attention

  • The exponential of the dot product in $\operatorname{softmax}(QK^T)$ can be viewed as a kernel function, i.e. computing similarity.
  • Instead of computing the full dot product, can we find a feature map $\phi$ of Q and K that approximates the similarity kernel $\operatorname{sim}(q, k)=\exp(q^\top k / \sqrt{d_k})$ of full self-attention?
  • Linear transformers Katharopoulos et al. 2020
    • Self-attention can be rewritten as
      $$ \operatorname{Attention}(Q, K, V)=\frac{\phi(Q)\left(\phi(K)^T V\right)}{\phi(Q)\left(\phi(K)^T \mathbf{1}_n\right)} $$
      where $\mathbf{1}_n$ is the all-ones vector and the division is applied row-wise.
    • They choose $\phi(x)=\operatorname{elu}(x)+1$
    • Reduces complexity to linear time, including for causal attention.
  • Performer (Choromanski et al., 2021) uses an unbiased random-feature estimate to approximate the softmax kernel.
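
A numpy sketch of the linear-attention computation; the $(d \times d_v)$ summary $\phi(K)^T V$ is built once, so nothing of size $n \times n$ is ever materialized:

```python
import numpy as np

def elu(x):
    return np.where(x > 0, x, np.exp(np.minimum(x, 0)) - 1)

def linear_attention(q, k, v):
    """Katharopoulos-style linear attention with phi(x) = elu(x) + 1.
    q, k: (n, d); v: (n, d_v). Cost is O(n) in sequence length."""
    phi_q, phi_k = elu(q) + 1.0, elu(k) + 1.0
    kv = phi_k.T @ v               # (d, d_v), independent of n
    z = phi_q @ phi_k.sum(axis=0)  # (n,) per-row normalizer
    return (phi_q @ kv) / z[:, None]

rng = np.random.default_rng(0)
q, k = rng.normal(size=(32, 8)), rng.normal(size=(32, 8))
v = np.ones((32, 4))
out = linear_attention(q, k, v)
# phi is strictly positive, so the attention weights are positive and
# sum to 1; with constant V the output must be constant too:
print(np.allclose(out, 1.0))  # True
```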

Conditional Attention

  • What if not all tokens need the same amount of computation with full self-attention? Can we learn to "route" inputs between light and heavy computation paths? Similar in idea to Mixture of Experts.
  • CoLT5 (Ainslie et al., April 2023) introduces token-level conditional computation.
    • The light branch has a lower hidden dimension and fewer heads, and applies only local attention. The heavy branch performs full attention.
    • How to "route" or find important tokens?
      • Multiply tokens with a learned embedding to get scores $s$, then select the top-$k$ highest-scoring tokens.
      • Apply conditional feedforward $X_i=X_i+\operatorname{FFd}_{\text {Light }}\left(X_i\right)+\tilde{s}_i \cdot \operatorname{FFd}_{\text {Heavy }}\left(X_i\right)$
      • Apply conditional attention $X_i=X_i+\mathrm{A}_{\text {Light }}\left(X_i, X\right)+\tilde{s}_i^q \cdot \mathrm{A}_{\text {Heavy }}\left(X_i, \tilde{s}^{k v} X\right)$
    • Results in up to 75% faster training and 100% faster inference than LongT5, with performance improvements.
    • Can handle inputs of up to 64K tokens.
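
A toy sketch of the routing and conditional feedforward step (function names are illustrative, and raw scores are used as the gate for simplicity, where CoLT5 normalizes them):

```python
import numpy as np

def route(X, u, k):
    """Score each token against a learned routing embedding u and
    keep the top-k. Returns chosen indices and gate scores s~ that
    are zero for unrouted tokens."""
    s = X @ u                 # (n,) routing scores
    idx = np.argsort(s)[-k:]  # top-k token positions
    s_tilde = np.zeros_like(s)
    s_tilde[idx] = s[idx]     # gate for the heavy branch
    return idx, s_tilde

def conditional_ffd(X, u, k, ffd_light, ffd_heavy):
    """X_i = X_i + FFd_light(X_i) + s~_i * FFd_heavy(X_i): every token
    takes the light path; routed tokens also take the heavy path,
    weighted by their score."""
    _, s_tilde = route(X, u, k)
    return X + ffd_light(X) + s_tilde[:, None] * ffd_heavy(X)

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 4))
u = rng.normal(size=4)
out = conditional_ffd(X, u, k=3,
                      ffd_light=lambda x: 0.1 * x,  # stand-in cheap FFN
                      ffd_heavy=lambda x: x)        # stand-in heavy FFN
```

Conditional attention follows the same pattern, with the heavy branch restricted to routed keys and values.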

Recurrence based Transformers

  • Can we use recurrence and memory states with Transformers?
  • RMT (Bulatov et al., 2022) adds global memory tokens, similar to ETC, as part of the input, and reads updated memory tokens from the output.
    • Shown to scale to 1M+ tokens!
    • Compatible with existing Transformers that have small input sizes.
  • Incorporates recurrence:
    • The input is segmented into N segments. The first segment is augmented with memory tokens and fed to the Transformer.
    • The second segment is then augmented with the memory tokens output from the first segment and fed to the Transformer again.
    • This process is repeated until the full sequence is processed.
  • Basically a Recurrent Neural Network (RNN) over segments!
  • Quadratic complexity can be reduced to linear and can handle arbitrary input length.
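
The recurrence can be sketched with a stand-in "transformer" that just mixes its memory with the segment mean (everything here is a toy, only the threading of memory through segments is the point):

```python
import numpy as np

def toy_transformer(mem, seg):
    """Stand-in for a fixed-context Transformer: reads [mem; seg],
    returns updated memory tokens and the segment's outputs."""
    new_mem = 0.5 * mem + 0.5 * seg.mean(axis=0, keepdims=True)
    return new_mem, seg + new_mem  # outputs conditioned on memory

def rmt_forward(long_input, seg_len, mem):
    """Slice a long input into segments and thread the memory tokens
    through repeated calls: the RMT-style recurrence."""
    outputs = []
    for start in range(0, len(long_input), seg_len):
        mem, out = toy_transformer(mem, long_input[start:start + seg_len])
        outputs.append(out)
    return np.concatenate(outputs), mem

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 16))  # far longer than one "context window"
y, mem = rmt_forward(x, seg_len=100, mem=np.zeros((1, 16)))
print(y.shape)  # (1000, 16)
```

Because each call sees only seg_len tokens plus a fixed number of memory tokens, attention cost per step is constant and total cost is linear in input length.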

Hardware Optimization

FlashAttention (Dao et al., 2022)

  • Use a single fused kernel for self-attention that is designed around the GPU memory hierarchy (fast on-chip SRAM vs. slow HBM).
    • Computes exact attention, not approximation.
  • Two techniques:
    • Incrementally compute the softmax by splitting the input into blocks (tiling).
    • Store the softmax normalization statistics from the forward pass so attention can be quickly recomputed on-chip during the backward pass.
  • Up to 7.6x faster attention on GPT-2, while using less memory
    • Allows longer sequence lengths (16K)
    • The block-sparse version allows even longer sequences (64K)
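
Per query row, the tiling reduces to the online-softmax rescaling trick, which can be sketched in plain numpy (this shows the math only, not the fused CUDA kernel):

```python
import numpy as np

def online_softmax_row(scores, v, block=4):
    """Streaming softmax(scores) @ v for one query row, one block of
    keys at a time, carrying a running max m and sum s so earlier
    blocks can be rescaled as larger scores arrive."""
    m, s = -np.inf, 0.0
    acc = np.zeros(v.shape[-1])
    for i in range(0, len(scores), block):
        b, vb = scores[i:i + block], v[i:i + block]
        m_new = max(m, b.max())
        scale = np.exp(m - m_new)   # rescale what was accumulated so far
        p = np.exp(b - m_new)
        s = s * scale + p.sum()
        acc = acc * scale + p @ vb
        m = m_new
    return acc / s

rng = np.random.default_rng(0)
scores, v = rng.normal(size=16), rng.normal(size=(16, 8))
e = np.exp(scores - scores.max())
ref = (e / e.sum()) @ v  # ordinary softmax-weighted sum
print(np.allclose(online_softmax_row(scores, v), ref))  # True
```

Only the running statistics (m, s) and the accumulator live on-chip, so the full $n \times n$ score matrix is never written to slow memory.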