Attention
Motivation for Attention Mechanism
Attention builds the following question into the network architecture: how relevant is the $i$-th element of a sequence to the other elements of the same sequence?
Key ideas:
- don't try to learn one global representation for the source sequence
- rather learn context-sensitive token representations for each token
- when generating a target token, dynamically combine the most relevant source representations (weighted sum with weights representing some notion of similarity between tokens)
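As a toy illustration of the last point, the sketch below scores each source representation against a query vector with a dot product, normalizes the scores with a softmax, and returns the weighted sum. The query stands in for the decoder state of the target token being generated. The vectors and the choice of plain dot-product similarity are illustrative assumptions. Real models learn these representations.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability; the result is mathematically unchanged
    e = np.exp(x - np.max(x))
    return e / e.sum()

def combine_sources(query, sources):
    """Combine source representations, weighted by similarity to the query.

    query:   (d,)   hypothetical decoder state for the target token being generated
    sources: (T, d) context-sensitive source token representations
    """
    scores = sources @ query            # (T,) dot-product similarity per source token
    weights = softmax(scores)           # (T,) non-negative, sums to 1
    context = weights @ sources         # (d,) weighted sum of source representations
    return context, weights

# Toy data: 3 source tokens with 2-dimensional representations
sources = np.array([[1.0, 0.0],
                    [0.0, 1.0],
                    [1.0, 1.0]])
query = np.array([2.0, 0.0])
context, weights = combine_sources(query, sources)
print(weights.round(3))
```

Tokens 0 and 2 align equally well with the query and receive equal weight. Token 1 is orthogonal to the query and receives much less weight, so the context vector is pulled toward the relevant tokens.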
Self-Attention
The attention mechanism used inside Transformers is called scaled dot-product attention. Each sequence element provides a query, a key, and a value vector, each of which is a learned linear projection of the token representation.
For each element, we compare its query with the keys of all sequence elements (including its own key), turn these similarity scores into weights with a softmax, and return the weighted average of the value vectors. Every element therefore receives its own mixture of information from the rest of the sequence, and stacking attention layers mixes token information progressively.
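The per-element procedure above can be written as an explicit loop. This is a minimal NumPy sketch. It assumes randomly initialized matrices `W_q`, `W_k`, `W_v`, which are hypothetical stand-ins for learned weights. The sketch is deliberately unvectorized so that each step maps onto the description.

```python
import numpy as np

def self_attention_loop(X, W_q, W_k, W_v):
    """Scaled dot-product self-attention, computed one query at a time.

    X: (T, d_model) token representations
    W_q, W_k: (d_model, d_k); W_v: (d_model, d_v)
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v       # learned linear projections
    d_k = K.shape[1]
    out = np.empty((X.shape[0], V.shape[1]))
    for i in range(X.shape[0]):
        # Similarity of token i's query with every key (including its own)
        scores = K @ Q[i] / np.sqrt(d_k)        # (T,)
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()                # softmax -> attention weights
        # Output for token i: weighted average of all value vectors
        out[i] = weights @ V                    # (d_v,)
    return out

rng = np.random.default_rng(0)
T, d_model, d_k, d_v = 5, 8, 4, 4
X = rng.normal(size=(T, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d)) for d in (d_k, d_k, d_v))
out = self_attention_loop(X, W_q, W_k, W_v)
print(out.shape)  # (5, 4)
```

In practice the loop is replaced by matrix products over all queries at once, which compute the same result.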
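The inputs are a set of queries $Q \in \mathbb{R}^{T \times d_{k}}$, keys $K \in \mathbb{R}^{T \times d_{k}}$ and values $V \in \mathbb{R}^{T \times d_{v}}$, where $T$ is the sequence length and $d_{k}$ and $d_{v}$ are the dimensionalities of the queries/keys and of the values, respectively. The attention output is given by:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

where the softmax is applied row-wise, so row $i$ of the $T \times T$ weight matrix holds the attention weights of element $i$ over all elements.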
Large dot products cause the softmax to saturate (outputting ~1 for the largest score and ~0 for the others), which leads to vanishing gradients. Dividing by $\sqrt{d_k}$ keeps the dot products from growing too large, so the softmax stays in its sensitive range where gradients remain meaningful.
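A quick numerical check of the saturation argument: scaling the same logits by a large factor drives the softmax toward one-hot. The entries of its Jacobian, $\partial s_i / \partial z_j = s_i(\delta_{ij} - s_j)$, shrink toward zero at the same time. The logits and the factor 10 are arbitrary choices for illustration.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_jacobian(z):
    # d s_i / d z_j = s_i * (delta_ij - s_j)
    s = softmax(z)
    return np.diag(s) - np.outer(s, s)

z = np.array([1.0, 2.0, 3.0])
for scale in (1.0, 10.0):
    s = softmax(scale * z)
    J = softmax_jacobian(scale * z)
    print(f"scale={scale:>4}: softmax={np.round(s, 4)}, max|ds/dz|={np.abs(J).max():.2e}")
```

With the larger scale, almost all probability mass sits on one entry, and the gradient flowing back through the softmax becomes tiny.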
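Why $\sqrt{d_k}$ specifically? Assume the components of a query $q$ and a key $k$ are independent with mean 0 and variance 1. Then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ is a sum of $d_k$ terms, each with mean 0 and variance 1, so it has variance $d_k$ and standard deviation $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ restores unit variance regardless of the dimension. This High-Dimensional Dot Product Normalization keeps training stable across different attention head dimensions.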
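The variance argument can be checked empirically. This minimal NumPy sketch draws standard-normal queries and keys for a few head dimensions. The draws reflect the idealized independence assumption rather than trained weights. The sample count and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 10_000

for d_k in (16, 64, 256):
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.einsum("nd,nd->n", q, k)      # one dot product per sample
    scaled = dots / np.sqrt(d_k)            # scaled dot product
    print(f"d_k={d_k:4d}  var(q.k)={dots.var():7.1f}  var(q.k/sqrt(d_k))={scaled.var():.3f}")
```

The unscaled variance grows roughly linearly with $d_k$, while the scaled variance stays close to 1 for every head size.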
Also see: Autoregressive Generation and KV Caching in Transformers
Multi-Head Attention (MHA)
A sequence element can have multiple interpretations or different relations to its neighbors. To account for this, we can combine several attention mechanisms with Multi-Head Attention. In practice, we split the projected Q, K, V vectors into $h$ heads and compute $h$ attention maps in parallel.
If you have $h$ heads and model dimension $d_{model}$:
- Each head works with $d_k = d_{model} / h$
- You get $h$ separate attention maps, each $(n \times n)$
- Each head can learn to attend to different things
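The head bookkeeping can be made concrete with shapes. This NumPy sketch uses random inputs and hypothetical sizes $n=6$, $d_{model}=12$, $h=3$. Splitting a projected $(n, d_{model})$ matrix into heads is just a reshape: head $i$ sees columns $i \cdot d_k$ through $(i+1) \cdot d_k - 1$, and each head produces its own $(n \times n)$ attention map.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, h = 6, 12, 3
d_k = d_model // h                          # 4 dimensions per head

Q = rng.normal(size=(n, d_model))           # already-projected queries
K = rng.normal(size=(n, d_model))           # already-projected keys

# (n, d_model) -> (n, h, d_k) -> (h, n, d_k)
Qh = Q.reshape(n, h, d_k).transpose(1, 0, 2)
Kh = K.reshape(n, h, d_k).transpose(1, 0, 2)

# Head i is a contiguous slice of the feature dimension
for i in range(h):
    assert np.array_equal(Qh[i], Q[:, i * d_k:(i + 1) * d_k])

# One (n x n) attention map per head
scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n)
maps = np.exp(scores - scores.max(axis=-1, keepdims=True))
maps /= maps.sum(axis=-1, keepdims=True)             # row-wise softmax
print(maps.shape)
```

Because each head works on its own slice, the $h$ maps can weight positions differently even though they come from the same input.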
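Then the outputs from all heads are concatenated and projected:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W_O, \qquad \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

Here $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_k}$ are the per-head projections (taking $d_v = d_k$), each $\text{head}_i \in \mathbb{R}^{n \times d_k}$, and $W_O \in \mathbb{R}^{h \cdot d_k \times d_{model}}$ projects back to the original dimension. The whole thing is usually implemented as one big projection followed by a reshape rather than as $h$ separate per-head projections; the two are mathematically equivalent.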
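A small check of the equivalence claim. It assumes the per-head projection $W_i^Q$ is the $i$-th block of $d_k$ columns of the big $W_Q$, which is the natural correspondence. The sizes and random weights are arbitrary. The same block structure applies on the output side: concatenating the heads and multiplying by $W_O$ equals summing each head times its own $d_k$-row slice of $W_O$.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_model, h = 5, 8, 2
d_k = d_model // h

X = rng.normal(size=(n, d_model))
W_Q = rng.normal(size=(d_model, d_model))   # one big projection for all heads
W_O = rng.normal(size=(h * d_k, d_model))

# Route 1: one big matmul, then reshape into heads -> (h, n, d_k)
big = (X @ W_Q).reshape(n, h, d_k).transpose(1, 0, 2)

# Route 2: h separate projections, one (d_model, d_k) matrix per head
separate = np.stack([X @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(h)])

print(np.allclose(big, separate))  # True

# Output side: Concat(heads) @ W_O == sum_i head_i @ W_O[i*d_k:(i+1)*d_k]
heads = rng.normal(size=(h, n, d_k))        # stand-in head outputs
concat = heads.transpose(1, 0, 2).reshape(n, h * d_k)
summed = sum(heads[i] @ W_O[i * d_k:(i + 1) * d_k] for i in range(h))
print(np.allclose(concat @ W_O, summed))  # True
```

The single large matmul is preferred in practice because one big matrix multiplication uses the hardware more efficiently than $h$ small ones.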
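```python
import math
import torch

# Illustrative sizes; random weights stand in for learned parameters
n, d_model, h = 4, 8, 2                  # sequence length, model dim, heads
d_k = d_v = d_model // h                 # per-head dims
W_Q, W_K, W_V, W_O = (torch.randn(d_model, d_model) for _ in range(4))

# Input
X = torch.randn(n, d_model)              # (n, d_model)
# Single large projections
Q = X @ W_Q                              # (n, d_model)
K = X @ W_K                              # (n, d_model)
V = X @ W_V                              # (n, d_model)
# Reshape to separate heads
Q = Q.reshape(n, h, d_k).transpose(0, 1)  # (h, n, d_k)
K = K.reshape(n, h, d_k).transpose(0, 1)  # (h, n, d_k)
V = V.reshape(n, h, d_v).transpose(0, 1)  # (h, n, d_v)
# Batched attention over all heads
scores = Q @ K.transpose(-1, -2)         # (h, n, n)
scores = scores / math.sqrt(d_k)
# Causal mask: -inf above the diagonal, so position i ignores positions j > i
M = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
scores = scores + M                      # (n, n) mask broadcasts over heads
attn = torch.softmax(scores, dim=-1)     # (h, n, n)
out = attn @ V                           # (h, n, d_v)
# Concat heads and project
out = out.transpose(0, 1).reshape(n, h * d_v)  # (n, d_model)
out = out @ W_O                          # (n, d_model)
```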
Multi-head attention improves the attention layer by:
- Expanding the model's ability to focus on different positions.
- Giving the attention layer multiple "representation subspaces".