💡 The Attention Intuition
Before diving into matrices and formulas, it helps to build the right mental model for what attention actually does. The mechanism can be understood as a soft, differentiable lookup in a dictionary.
The Soft Dictionary Metaphor
Imagine a dictionary where every key and value is a vector. Given a query vector, instead of finding an exact key match, you compute how similar the query is to every key, then retrieve a weighted blend of all values, where the weights come from the similarities. This is attention. The "softness" means no hard decisions: every position contributes to every other, just with different weights.
For a token at position i trying to understand itself, the query asks "what do I need?" Each other token's key answers "what do I contain?" The highest-similarity keys get the most weight, and their values (the actual information content) are blended to form the output at position i.
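The soft-lookup idea fits in a few lines. The keys, values, and query below are made-up toy vectors, not anything a model would learn:

```python
import torch

keys = torch.tensor([[1.0, 0.0],   # key for entry 0
                     [0.0, 1.0]])  # key for entry 1
values = torch.tensor([[10.0],     # value stored under key 0
                       [20.0]])    # value stored under key 1

query = torch.tensor([0.9, 0.1])   # mostly resembles key 0

# Similarity of the query to every key, turned into weights that sum to 1
weights = torch.softmax(keys @ query, dim=0)

# Weighted blend of ALL values -- no hard lookup, every entry contributes
result = weights @ values   # lands between 10 and 20, much closer to 10
```

Because key 0 is more similar to the query, its value dominates the blend, but the second entry still contributes a little: that residual contribution is exactly what makes the lookup differentiable.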
A Concrete Language Example
Consider the sentence: "The animal didn't cross the street because it was too tired."
When processing the token "it", the model needs to resolve: does "it" refer to "animal" or "street"? Attention allows "it" to look at every other token in the sentence. The model learns to give high attention weight to "animal" (and low weight to "street") because "tired" applies to animals, not streets. The resulting representation of "it" is enriched with information from "animal", making coreference resolution possible.
Query, Key, Value in Plain English
Query (Q): "What am I looking for?" A representation of the current token's information needs, projected into a query space.
Key (K): "What do I advertise?" A representation of each token's content, projected into a key space designed to match against queries.
Value (V): "What do I actually contain?" The information that gets passed on when a token is attended to.
Q, K, V are all derived from the same input tokens via separate learned linear projections; the model learns what aspects of each token to "search by" vs "return."
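A minimal sketch of those three projections; the dimensions here are arbitrary toy values:

```python
import torch
import torch.nn as nn

d_model = 8
x = torch.randn(5, d_model)   # 5 tokens, one embedding vector each

# Three SEPARATE learned projections of the SAME input
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q, K, V = W_q(x), W_k(x), W_v(x)   # each [5, 8]
```

Three views of one sequence: what each token seeks (Q), what it advertises (K), and what it hands over when selected (V).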
📐 Scaled Dot-Product Attention
The formal definition of attention as given in "Attention Is All You Need" is remarkably compact. Every step has a clear purpose.
The Attention Formula
Attention(Q, K, V) = softmax( QKᵀ / √d_k ) · V
Where Q ∈ ℝ^(n×d_k), K ∈ ℝ^(m×d_k), V ∈ ℝ^(m×d_v), and the output is in ℝ^(n×d_v). For self-attention, n = m (query and key/value sequence lengths are the same).
```python
# Scaled Dot-Product Attention, step by step
import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: [batch, heads, seq_len, d_k]
    # K: [batch, heads, seq_len, d_k]
    # V: [batch, heads, seq_len, d_v]
    d_k = Q.size(-1)

    # Step 1: Compute similarity scores: how much does each query
    # match each key? Result shape: [batch, heads, seq_len, seq_len]
    scores = torch.matmul(Q, K.transpose(-2, -1))  # QK^T

    # Step 2: Scale by 1/sqrt(d_k) to prevent dot products from
    # growing too large (which would cause softmax to saturate,
    # producing near-zero gradients)
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply causal mask (decoder / autoregressive models).
    # Set future positions to -inf so softmax gives them weight 0
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax over the key dimension converts scores to
    # a probability distribution (weights that sum to 1)
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values: the actual information retrieval
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
```
Why Divide by √d_k?
The dot product QKᵀ produces values whose magnitude grows with the dimension d_k. If the query and key components are independent with zero mean and unit variance, each dot product has variance d_k. With large d_k (e.g., 128), the raw dot products routinely reach magnitudes in the tens, pushing the softmax into its saturated region where one logit dominates and gradients become vanishingly small. Dividing by √d_k normalises the variance back to 1, keeping gradients healthy during training.
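A quick numerical check of this claim, drawing random query and key components from a standard normal distribution:

```python
import torch

torch.manual_seed(0)
d_k = 128
q = torch.randn(10000, d_k)   # query components ~ N(0, 1)
k = torch.randn(10000, d_k)   # key components ~ N(0, 1)

raw = (q * k).sum(dim=-1)     # 10,000 unscaled dot products
scaled = raw / d_k ** 0.5     # after the 1/sqrt(d_k) correction

print(raw.var().item())       # close to d_k = 128
print(scaled.var().item())    # close to 1
```

Without the correction, typical logits differ by tens of units, so the softmax would concentrate almost all its mass on one key.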
The Causal Mask
In a decoder (autoregressive) model, the token at position i must not see tokens at positions i+1, i+2, … (they haven't been generated yet). The causal mask sets the strictly upper-triangular entries of the attention scores to −∞ before the softmax. Since exp(−∞) = 0, future positions receive exactly zero attention weight. The mask is precomputed once as a boolean matrix and is extremely cheap to apply; there is no separate masking layer.
```python
# Causal mask for sequence length 4:
#   Position 0 can attend to: [0]
#   Position 1 can attend to: [0, 1]
#   Position 2 can attend to: [0, 1, 2]
#   Position 3 can attend to: [0, 1, 2, 3]
mask = torch.tril(torch.ones(4, 4))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
```
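Putting the mask to work: the self-contained snippet below applies it to a random score matrix and confirms that every row is still a valid distribution while future positions get exactly zero weight:

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)           # stand-in for QK^T / sqrt(d_k)
mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular causal mask

scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(weights)  # rows sum to 1; everything above the diagonal is exactly 0
```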
🔀 Multi-Head Attention
Single-head attention computes one attention pattern per layer. But a sentence has many simultaneous relationships: syntactic (subject-verb agreement), semantic (coreference), positional (nearby tokens), and factual (entity associations). Multi-head attention runs several attention computations in parallel, each potentially learning a different relationship type.
```python
# Multi-Head Attention
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # each head operates on a slice

        # Four projection matrices: Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # Reshape [batch, seq, d_model] -> [batch, heads, seq, d_k]
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Project inputs to Q, K, V spaces and split into heads
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)

        # Each head computes its own attention independently
        attn_out, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: [batch, heads, seq, d_k] -> [batch, seq, d_model]
        attn_out = attn_out.transpose(1, 2).contiguous()
        attn_out = attn_out.view(batch_size, -1, self.num_heads * self.d_k)

        # Final linear projection to mix head outputs
        return self.W_o(attn_out)
```
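For comparison, PyTorch's built-in `nn.MultiheadAttention` wraps the same projections, head splitting, and output mixing; a self-attention call looks like this (toy sizes):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 64)   # [batch, seq, d_model]

out, weights = mha(x, x, x)  # self-attention: Q = K = V = x
print(out.shape)       # torch.Size([2, 10, 64])
print(weights.shape)   # torch.Size([2, 10, 10]) -- averaged over heads by default
```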
| Property | Single-Head Attention | Multi-Head Attention |
|---|---|---|
| Relationship types captured | One: all information must be packed into a single attention pattern per layer | h: each head can independently learn a different relationship type |
| Expressivity | Limited: one weighted aggregation | High: h independent subspace aggregations, then mixed |
| Computation | O(n² × d_model) | Same O(n² × d_model): heads share the total dimension, so total FLOPs are equal |
| Typical head counts | N/A | 32 (7B models), 40 (13B models), 64 (70B models) |
🔬 Attention Patterns & What Heads Learn
Mechanistic interpretability research has analysed attention heads in trained models, finding that different heads specialise for distinct linguistic and computational functions. This is not programmed; it emerges purely from gradient descent on next-token prediction.
Local Attention Heads
Some heads attend primarily to nearby tokens, often just the previous one or two. These heads capture local syntactic structure: adjacent word relationships, noun-adjective agreement, verb-subject proximity. They function like n-gram detectors but are learned rather than hand-coded. In practice they look like narrow diagonal bands in the attention weight matrix.
Global / Positional Heads
Other heads attend broadly across the sequence, sometimes distributing weight roughly uniformly. These may be involved in passing global context (e.g., the topic of a document) to all positions, or performing positional bookkeeping. Some heads consistently attend to the first token (often a BOS token or delimiter); researchers call these "sink" heads (related to the attention sink phenomenon below).
Coreference Heads
Several heads in larger models specialise in coreference: identifying which pronoun refers to which entity. When processing "it" in a sentence, these heads show strong activation toward the referent noun regardless of distance. This is verifiable by ablating (zeroing) specific heads and measuring degradation on coreference benchmarks: ablating the targeted heads causes dramatic drops while ablating others does not.
The Attention Sink Phenomenon
LLMs trained with sliding window attention (like Mistral) and other long-context models exhibit attention sinks: the first token (often a special BOS or delimiter token) accumulates disproportionately high attention weight from many heads across many layers, even when it is semantically irrelevant. This appears to be a learned mechanism for "dumping" attention mass that must go somewhere: softmax always outputs a probability distribution summing to 1, so some position must absorb the excess weight. Recognising attention sinks is important for KV-cache compression and attention head pruning. StreamingLLM exploits this to enable effectively unbounded context by always keeping the sink token in the cache.
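The StreamingLLM cache policy can be sketched as "keep a few sink positions plus a recent window, evict everything in between." The function name and parameter values below are illustrative, not taken from the paper's code:

```python
# Hypothetical sketch of a StreamingLLM-style KV-cache eviction policy:
# always retain the first num_sink positions plus the most recent window.
def streaming_keep_indices(seq_len, num_sink=4, window=1024):
    if seq_len <= num_sink + window:
        return list(range(seq_len))           # nothing to evict yet
    return list(range(num_sink)) + list(range(seq_len - window, seq_len))

print(streaming_keep_indices(10, num_sink=2, window=4))  # [0, 1, 6, 7, 8, 9]
```

Because the sink token never leaves the cache, the heads that rely on it as an attention dump keep working no matter how long generation runs.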
Not All Heads Are Equal: Head Pruning Research
Voita et al. (2019) and Michel et al. (2019) independently found that a large fraction of attention heads in trained models can be pruned (set to zero) with minimal impact on task performance. In some BERT models, over 60% of heads could be removed with less than 1% accuracy drop. This suggests massive redundancy. The important heads tend to fall into specific functional categories (positional, syntactic, rare-word detection). This finding motivates more efficient multi-head designs and has driven research into learned head importance scoring.
⚡ Efficient Attention Variants
Standard multi-head attention has O(n²) time and memory complexity. For a 128k context window, the attention matrix alone would be 128k × 128k ≈ 16 billion entries. Numerous techniques have been developed to reduce this cost without losing quality.
| Variant | Key Idea | Complexity | Used In |
|---|---|---|---|
| Flash Attention 2/3 (Dao et al., 2022–2024) | IO-aware implementation: tiles the attention computation to stay in SRAM (fast GPU cache), avoiding materialising the full n×n attention matrix in slow HBM. Flash Attention 3 targets H100 tensor cores with FP8 precision, reaching ~75% of theoretical hardware peak (~1.5–2× over FA2). Exact same results, just faster. | O(n²) compute, O(n) memory (HBM) | Virtually all modern LLM training and inference: Llama 3/3.1/3.3, Mistral, Gemma, DeepSeek, GPT-4 |
| Sliding Window Attention (SWA) | Each token only attends to the W nearest tokens (window of size W). Global information propagates through depth: layer L+1 can "see" 2W positions away because layer L already aggregated neighbours. Dramatically reduces the quadratic cost. | O(n × W) per layer | Mistral 7B / Small 3, Longformer, BigBird (combined with global tokens) |
| Multi-Query Attention (MQA) | All attention heads share a single set of K and V projections (only Q has per-head projections). Dramatically reduces KV-cache size at inference: fewer K/V tensors to store and load per decode step. Quality slightly lower than full MHA. | Same FLOPs, much smaller KV-cache | PaLM, Falcon, early Gemini variants |
| Grouped-Query Attention (GQA) | Compromise between MHA and MQA: heads are divided into G groups, and each group shares one set of K/V projections. G=1 is MQA; G=num_heads is MHA. Achieves most of MQA's inference speedup with nearly MHA quality. | Same FLOPs, moderate KV-cache reduction | Llama 3 / 3.1 / 3.2 / 3.3, Mistral Small 3, Gemma 2/3, Qwen2.5, GPT-4 (likely) |
| Multi-head Latent Attention (MLA) | DeepSeek (V2/V3) innovation: compresses the K/V projections through a low-rank latent bottleneck before expanding for each head. At inference, only the small latent vector needs to be cached rather than the full K/V tensors, reducing KV-cache memory by over 90% vs standard MHA while maintaining multi-head expressiveness. | Same FLOPs, ~10% of standard MHA KV-cache footprint | DeepSeek-V2, DeepSeek-V3, DeepSeek-R1 |
GQA Is the Industry Standard; MLA Is the Next Frontier
As of 2024–2025, virtually every new high-quality open-source LLM uses Grouped-Query Attention (Llama 3/3.3, Gemma 2/3, Mistral, Qwen2.5). The reason is straightforward: inference throughput is often bottlenecked by KV-cache memory bandwidth, not raw compute. GQA reduces the K/V data loaded from GPU memory on every decode step; for Llama 3 8B with 8 KV heads (vs 32 query heads), this is a 4× KV-cache reduction, enabling significantly higher batch sizes. DeepSeek's Multi-head Latent Attention goes further, achieving ~10× KV-cache compression vs standard MHA while preserving expressiveness, and is likely to be widely adopted as other labs study the DeepSeek-V3 architecture.
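The mechanics of GQA can be sketched by materialising the shared K/V heads for each query group before a standard attention computation; the head counts and dimensions below are toy values, not Llama's:

```python
import torch

batch, seq = 2, 6
num_q_heads, num_kv_heads, d_k = 8, 2, 16     # 4 query heads share each KV head
group = num_q_heads // num_kv_heads

Q = torch.randn(batch, num_q_heads, seq, d_k)
K = torch.randn(batch, num_kv_heads, seq, d_k)  # only 2 K/V heads are cached
V = torch.randn(batch, num_kv_heads, seq, d_k)

# Expand the shared K/V so every query head in a group sees its group's K/V
K = K.repeat_interleave(group, dim=1)   # -> [batch, 8, seq, d_k]
V = V.repeat_interleave(group, dim=1)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
out = torch.softmax(scores, dim=-1) @ V
print(out.shape)  # torch.Size([2, 8, 6, 16])
```

The KV-cache only ever stores the 2 original K/V heads; the expansion happens on the fly, which is why GQA cuts memory traffic without reducing the number of query heads.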
Flash Attention 2 & 3
Flash Attention 2 (2023) improved on the original by reordering the outer loop in the tiling algorithm and reducing non-matrix operations, achieving ~2× speedup over FA1. Flash Attention 3 (2024) targets H100 GPUs specifically, exploiting tensor core asynchrony and FP8 precision, reaching 75% of theoretical hardware peak, about 1.5–2× faster than FA2. These are purely engineering wins; the mathematical operation is unchanged. FA2 is now the default in PyTorch and most training frameworks; FA3 is used in high-throughput production serving.
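In PyTorch 2.x the fused kernels are reachable through `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a flash-attention backend when the hardware and dtypes allow, and falls back to a standard implementation otherwise:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)   # [batch, heads, seq, d_k]
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# One call replaces the manual matmul / scale / mask / softmax / matmul chain;
# is_causal=True applies the causal mask inside the kernel
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```

Because the backend choice is transparent, the same call runs on CPU for debugging and on a flash kernel in production.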
Linear Attention Approximations
Various methods approximate softmax attention with linear kernels, reducing complexity to O(n). Examples: Performer (random feature attention), Linear Transformers (kernelised attention), RWKV (recurrent reformulation), Mamba (selective state space model). These achieve true O(n) complexity but generally lag behind exact attention in quality on complex tasks. The quality gap has narrowed significantly; Mamba-2 and hybrid architectures (e.g., alternating Mamba and Transformer layers) are an active research frontier for long-context efficiency.
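A minimal sketch of the kernelised trick in the Linear Transformers style, using the elu(x) + 1 feature map: thanks to associativity, a small K/V summary is computed once instead of an n × n attention matrix (non-causal case, single head, for illustration only):

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Feature map phi(x) = elu(x) + 1 keeps all entries positive
    phi_q = F.elu(Q) + 1
    phi_k = F.elu(K) + 1
    # Associativity: phi(Q) @ (phi(K)^T @ V) costs O(n * d^2), not O(n^2 * d)
    kv = phi_k.transpose(-2, -1) @ V                               # [d_k, d_v]
    z = phi_q @ phi_k.sum(dim=-2, keepdim=True).transpose(-2, -1)  # normaliser
    return (phi_q @ kv) / z

Q, K, V = torch.randn(7, 4), torch.randn(7, 4), torch.randn(7, 3)
out = linear_attention(Q, K, V)   # [7, 3], never materialises a 7x7 matrix
```

This matches full kernel attention exactly for this feature map; the quality gap vs softmax attention comes from replacing exp(q·k) with the kernel, not from the associativity rewrite.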