โฑ 12 min read ๐Ÿ“Š Intermediate ๐Ÿ—“ Updated Jan 2025

💡 The Attention Intuition

Before diving into matrices and formulas, it helps to build the right mental model for what attention actually does. The mechanism can be understood as a soft, differentiable lookup in a dictionary.

The Soft Dictionary Metaphor

Imagine a dictionary where every key and value is a vector. Given a query vector, instead of finding an exact key match, you compute how similar the query is to every key, then retrieve a weighted blend of all values, where the weights come from the similarities. This is attention. The "softness" means no hard decisions: every position contributes to every other, just with different weights.

For a token at position i gathering context, the query asks "what do I need?" Each token's key answers "what do I contain?" The highest-similarity keys get the most weight, and their values (the actual information content) are blended to form the output at position i.
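The soft-dictionary metaphor fits in a few lines of code. A toy sketch, with keys, values, and the query all hand-picked for illustration:

```python
import torch
import torch.nn.functional as F

# Toy soft dictionary: 4 entries, 3-d keys, 2-d values (numbers arbitrary)
keys = torch.tensor([[1., 0., 0.],
                     [0., 1., 0.],
                     [0., 0., 1.],
                     [0., 1., 1.]])
values = torch.tensor([[10., 0.],
                       [0., 10.],
                       [5., 5.],
                       [1., 1.]])

query = torch.tensor([1., 0.2, 0.])  # most similar to the first key

# Similarity of the query to every key, normalised into weights summing to 1
weights = F.softmax(query @ keys.T, dim=-1)

# The retrieved result is a blend of ALL values, not a single exact match
result = weights @ values
```

Every entry contributes something; the first key simply contributes the most. Making the lookup soft is exactly what makes it differentiable, and hence trainable.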

A Concrete Language Example

Consider the sentence: "The animal didn't cross the street because it was too tired."

When processing the token "it", the model needs to resolve: does "it" refer to "animal" or "street"? Attention allows "it" to look at every other token in the sentence. The model learns to give high attention weight to "animal" (and low weight to "street") because "tired" applies to animals, not streets. The resulting representation of "it" is enriched with information from "animal", making coreference resolution possible.


Query, Key, Value โ€” Plain English

Query (Q): "What am I looking for?" A representation of the current token's information needs, projected to a query space.

Key (K): "What do I advertise?" A representation of each token's content, projected to a key space designed to match with queries.

Value (V): "What do I actually contain?" The information that gets passed on when a token is attended to.

Q, K, V are all derived from the same input tokens via separate learned linear projections; the model learns what aspects of each token to "search by" vs "return."
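In code, that is just three linear layers applied to the same token embeddings. A minimal sketch, with dimensions chosen arbitrarily:

```python
import torch
import torch.nn as nn

d_model = 8
x = torch.randn(5, d_model)  # 5 tokens, one d_model-dim embedding each

# Three separate learned projections applied to the SAME input
W_q = nn.Linear(d_model, d_model)
W_k = nn.Linear(d_model, d_model)
W_v = nn.Linear(d_model, d_model)

Q, K, V = W_q(x), W_k(x), W_v(x)  # what to search by vs what to return
```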

๐Ÿ“ Scaled Dot-Product Attention

The formal definition of attention as given in "Attention Is All You Need" is remarkably compact. Every step has a clear purpose.

The Attention Formula

Attention(Q, K, V) = softmax( QKᵀ / √dₖ ) · V

Where Q ∈ ℝ^(n×dₖ), K ∈ ℝ^(m×dₖ), V ∈ ℝ^(m×dᵥ), and the output is in ℝ^(n×dᵥ). For self-attention, n = m (query and key/value sequence length are the same).

# Scaled Dot-Product Attention โ€” step by step
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: [batch, heads, seq_len, d_k]
    # K: [batch, heads, seq_len, d_k]
    # V: [batch, heads, seq_len, d_v]

    d_k = Q.size(-1)

    # Step 1: Compute similarity scores: how much does each query
    # match each key? Result shape: [batch, heads, seq_len, seq_len]
    scores = torch.matmul(Q, K.transpose(-2, -1))  # QKᵀ

    # Step 2: Scale by 1/sqrt(d_k) to prevent dot products from
    # growing too large (which would cause softmax to saturate,
    # producing near-zero gradients)
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply causal mask (decoder / autoregressive models)
    # Set future positions to -inf so softmax gives them weight 0
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax over the key dimension converts scores to
    # a probability distribution (weights that sum to 1)
    attn_weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values: the actual information retrieval
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

Why Divide by √dₖ?

The dot product QKᵀ produces values whose magnitude grows with the dimension dₖ. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance dₖ. With large dₖ (e.g., 128), the raw dot products can be in the range of ±50, pushing the softmax into its saturated region where one logit dominates and gradients become vanishingly small. Dividing by √dₖ normalises the variance back to O(1), keeping gradients healthy during training.
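This is easy to verify empirically. A quick sketch drawing random Gaussian queries and keys:

```python
import torch

torch.manual_seed(0)
d_k = 256
q = torch.randn(10000, d_k)  # rows with unit-variance components
k = torch.randn(10000, d_k)

raw = (q * k).sum(-1)        # 10,000 unscaled dot products
scaled = raw / d_k ** 0.5

print(raw.std())     # close to sqrt(256) = 16
print(scaled.std())  # close to 1
```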

The Causal Mask

In a decoder (autoregressive) model, the token at position i must not see tokens at positions i+1, i+2, … (they haven't been generated yet). The causal mask is an upper-triangular matrix of −∞ values added to the attention scores before softmax. Since exp(−∞) = 0, future positions receive exactly zero attention weight. This is implemented once as a precomputed boolean mask and is extremely cheap: there's no separate masking layer.

# Causal mask for sequence length 4:
# Position 0 can attend to: [0]
# Position 1 can attend to: [0, 1]
# Position 2 can attend to: [0, 1, 2]
# Position 3 can attend to: [0, 1, 2, 3]

mask = torch.tril(torch.ones(4, 4))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]
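Combining the mask with softmax shows the effect directly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 4)  # raw attention scores for a length-4 sequence

mask = torch.tril(torch.ones(4, 4))
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

# Row i now has nonzero weight only on columns 0..i, and each row sums to 1
```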

🔀 Multi-Head Attention

Single-head attention computes one attention pattern per layer. But a sentence has many simultaneous relationships: syntactic (subject-verb agreement), semantic (coreference), positional (nearby tokens), and factual (entity associations). Multi-head attention runs several attention computations in parallel, each potentially learning a different relationship type.

# Multi-Head Attention
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # each head operates on a slice

        # Four projection matrices: Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # Reshape [batch, seq, d_model] -> [batch, heads, seq, d_k]
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Project inputs to Q, K, V spaces and split into heads
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)

        # Each head computes its own attention independently
        attn_out, _ = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: [batch, heads, seq, d_k] -> [batch, seq, d_model]
        attn_out = attn_out.transpose(1, 2).contiguous()
        attn_out = attn_out.view(batch_size, -1, self.num_heads * self.d_k)

        # Final linear projection to mix head outputs
        return self.W_o(attn_out)
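The reshape bookkeeping in `split_heads` and the merge back are where implementation bugs usually hide. A standalone round-trip check (shapes arbitrary):

```python
import torch

batch, seq, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads
x = torch.randn(batch, seq, d_model)

# split: [batch, seq, d_model] -> [batch, heads, seq, d_k]
h = x.view(batch, seq, num_heads, d_k).transpose(1, 2)

# merge: [batch, heads, seq, d_k] -> [batch, seq, d_model]
y = h.transpose(1, 2).contiguous().view(batch, seq, num_heads * d_k)

print(torch.equal(x, y))  # True: the split/merge round trip is lossless
```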
| Property | Single-Head Attention | Multi-Head Attention |
| --- | --- | --- |
| Relationship types captured | One: all information must be packed into a single attention pattern per layer | h: each head independently learns a different relationship type simultaneously |
| Expressivity | Limited: one weighted aggregation | High: h independent subspace aggregations, then mixed |
| Computation | O(n² × d_model) | Same O(n² × d_model): heads share the total dimension, so total FLOPs are equal |
| Typical head counts | N/A | 32 (7B models), 40 (13B models), 64 (70B+ models) |

🔬 Attention Patterns & What Heads Learn

Mechanistic interpretability research has analysed attention heads in trained models, finding that different heads specialise for distinct linguistic and computational functions. This is not programmed: it emerges purely from gradient descent on next-token prediction.

Local Attention Heads

Some heads attend primarily to nearby tokens, typically the previous one or two. These heads capture local syntactic structure: adjacent word relationships, noun-adjective agreement, verb-subject proximity. They function like n-gram detectors, but learned rather than hand-coded. In practice they look like narrow diagonal bands in the attention weight matrix.


Global / Positional Heads

Other heads attend broadly across the sequence, sometimes distributing weight roughly uniformly. These may be involved in passing global context (e.g., the topic of a document) to all positions, or performing positional bookkeeping. Some heads consistently attend to the first token (often a BOS token or delimiter); researchers call these "sink" heads (related to the attention sink phenomenon below).


Coreference Heads

Several heads in larger models specialise in coreference: identifying which pronoun refers to which entity. When processing "it" in a sentence, these heads show strong activation toward the referred noun regardless of distance. This is verifiable by ablating (zeroing) specific heads and measuring degradation on coreference benchmarks: the targeted heads cause dramatic drops while others do not.


The Attention Sink Phenomenon

LLMs trained with sliding window attention (like Mistral) and other long-context models exhibit attention sinks: the first token (often a special BOS or delimiter token) accumulates disproportionately high attention weight from many heads across many layers, even when it is semantically irrelevant. This appears to be a learned mechanism for "dumping" attention mass that must go somewhere: softmax always outputs a probability distribution summing to 1, so some position must absorb excess weight. Recognising attention sinks is important for KV-cache compression and attention head pruning. StreamingLLM exploits this to enable infinite context by always keeping the sink token in the cache.

Not All Heads Are Equal: Head Pruning Research

Voita et al. (2019) and Michel et al. (2019) independently found that a large fraction of attention heads in trained models can be pruned (set to zero) with minimal impact on task performance. In some BERT models, over 60% of heads could be removed with less than 1% accuracy drop. This suggests massive redundancy. The important heads tend to fall into specific functional categories (positional, syntactic, rare-word detection). This finding motivates more efficient multi-head designs and has driven research into learned head importance scoring.
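The ablation procedure itself is simple. A hypothetical sketch operating directly on per-head outputs (shapes and the head index are arbitrary, not taken from any particular model):

```python
import torch

# Hypothetical per-head attention outputs: [batch, heads, seq, d_k]
attn_out = torch.randn(1, 8, 10, 64)

def ablate_head(attn_out, head_idx):
    # Zero one head's contribution; the rest of the forward pass is unchanged
    out = attn_out.clone()
    out[:, head_idx] = 0.0
    return out

ablated = ablate_head(attn_out, head_idx=3)
# Rerun the model with this output and measure the metric drop per head;
# heads whose removal barely moves the metric are candidates for pruning
```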

⚡ Efficient Attention Variants

Standard multi-head attention has O(n²) time and memory complexity. For a 128k context window, the attention matrix alone would be 128k × 128k ≈ 16 billion entries. Numerous techniques have been developed to reduce this cost without losing quality.

| Variant | Key Idea | Complexity | Used In |
| --- | --- | --- | --- |
| Flash Attention 2/3 (Dao et al., 2022–2024) | IO-aware implementation: tiles the attention computation to stay in SRAM (fast GPU cache), avoiding materialising the full n×n attention matrix in slow HBM. Flash Attention 3 targets H100 tensor cores with FP8 precision, reaching ~75% of theoretical hardware peak (~1.5–2× over FA2). Exact same results, just faster. | O(n²) compute, O(n) memory (HBM) | Virtually all modern LLM training and inference: Llama 3/3.1/3.3, Mistral, Gemma, DeepSeek, GPT-4 |
| Sliding Window Attention (SWA) | Each token only attends to the W nearest tokens (window of size W). Global information propagates through depth: layer L+1 can "see" 2W positions away because layer L already aggregated neighbours. Dramatically reduces quadratic cost. | O(n × W) per layer | Mistral 7B / Small 3, Longformer, BigBird (combined with global tokens) |
| Multi-Query Attention (MQA) | All attention heads share a single set of K and V projections (only Q has per-head projections). Dramatically reduces KV-cache size at inference: fewer K/V tensors to store and load per decode step. Quality slightly lower than full MHA. | Same FLOPs, much smaller KV-cache | PaLM, Falcon, early Gemini variants |
| Grouped-Query Attention (GQA) | Compromise between MHA and MQA: heads are divided into G groups, each group shares one set of K/V projections. G=1 is MQA; G=num_heads is MHA. Achieves most of MQA's inference speedup with nearly MHA quality. | Same FLOPs, moderate KV-cache reduction | Llama 3 / 3.1 / 3.2 / 3.3, Mistral Small 3, Gemma 2/3, Qwen2.5, GPT-4 (likely) |
| Multi-head Latent Attention (MLA) | DeepSeek (V2/V3) innovation: compresses the K/V projections through a low-rank latent bottleneck before expanding for each head. At inference, only the small latent vector needs to be cached rather than the full K/V tensors, reducing KV-cache memory by over 90% vs standard MHA while maintaining multi-head expressiveness. | Same FLOPs, ~10% of standard MHA KV-cache footprint | DeepSeek-V2, DeepSeek-V3, DeepSeek-R1 |
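The sliding-window idea reduces each row of the attention mask to a narrow band. A minimal causal-SWA mask sketch (window size chosen arbitrarily):

```python
import torch

def sliding_window_mask(seq_len, window):
    # Causal SWA: position i may attend to positions i-window+1 .. i
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)

mask = sliding_window_mask(6, 3)
# Each row has at most 3 True entries, forming a band along the diagonal
```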

GQA Is the Industry Standard; MLA Is the Next Frontier

As of 2024–2025, virtually every new high-quality open-source LLM uses Grouped-Query Attention (Llama 3/3.3, Gemma 2/3, Mistral, Qwen2.5). The reason is straightforward: inference throughput is often bottlenecked by KV-cache memory bandwidth, not raw compute. GQA reduces the K/V data loaded from GPU memory on every decode step; for Llama 3 8B with 8 KV heads (vs 32 query heads), this is a 4× KV-cache reduction, enabling significantly higher batch sizes. DeepSeek's Multi-head Latent Attention goes further, achieving ~10× KV-cache compression vs standard MHA while preserving expressiveness, and is likely to be widely adopted as other labs study the DeepSeek-V3 architecture.
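A minimal sketch of the GQA mechanics, using Llama-3-8B-like head counts (32 query heads, 8 KV heads). Only the small K/V tensors would be cached; the `repeat_interleave` expansion mirrors how common open-source implementations broadcast shared KV heads for the matmul, though optimised kernels avoid materialising the copies:

```python
import torch

batch, seq = 1, 10
num_q_heads, num_kv_heads, d_k = 32, 8, 128  # Llama-3-8B-like ratios

Q = torch.randn(batch, num_q_heads, seq, d_k)
K = torch.randn(batch, num_kv_heads, seq, d_k)  # only these are cached
V = torch.randn(batch, num_kv_heads, seq, d_k)

# Each group of 4 query heads shares one K/V head: broadcast K/V to 32 heads
group = num_q_heads // num_kv_heads
K = K.repeat_interleave(group, dim=1)
V = V.repeat_interleave(group, dim=1)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
out = torch.softmax(scores, dim=-1) @ V  # [1, 32, 10, 128]
```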

Flash Attention 2 & 3

Flash Attention 2 (2023) improved on the original by reordering the outer loop in the tiling algorithm and reducing non-matrix operations, achieving ~2× speedup over FA1. Flash Attention 3 (2024) targets H100 GPUs specifically, exploiting tensor core asynchrony and FP8 precision, reaching 75% of theoretical hardware peak, about 1.5–2× faster than FA2. These are purely engineering wins; the mathematical operation is unchanged. FA2 is now the default in PyTorch and most training frameworks; FA3 is used in high-throughput production serving.
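In practice you rarely implement this yourself: PyTorch 2.x exposes `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a fused Flash-style kernel when the device and dtype allow:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # [batch, heads, seq, d_k]
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Dispatches to a fused kernel where available; mathematically identical
# to the naive softmax(QK^T / sqrt(d_k)) V with a causal mask
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```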


Linear Attention Approximations

Various methods approximate softmax attention with linear kernels, reducing complexity to O(n). Examples: Performer (random feature attention), Linear Transformers (kernelised attention), RWKV (recurrent reformulation), Mamba (selective state space model). These achieve true O(n) complexity but generally lag behind exact attention in quality on complex tasks. The quality gap has narrowed significantly: Mamba-2 and hybrid architectures (e.g., alternating Mamba/Transformer layers) are an active research frontier for long-context efficiency.
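A minimal sketch of the kernelised trick in the Linear Transformers style (non-causal, with the elu(x)+1 feature map): by associativity, φ(Q)(φ(K)ᵀV) is computed without ever forming the n×n attention matrix:

```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V):
    # Kernelised attention: replace softmax(QK^T)V with phi(Q) (phi(K)^T V).
    # Grouping the matmuls this way makes the cost linear in sequence length.
    phi = lambda x: F.elu(x) + 1          # positive feature map
    Qf, Kf = phi(Q), phi(K)
    KV = Kf.transpose(-2, -1) @ V         # [d_k, d_v]: no n x n matrix
    Z = Qf @ Kf.sum(dim=-2, keepdim=True).transpose(-2, -1)  # normaliser
    return (Qf @ KV) / Z

out = linear_attention(torch.randn(4, 16), torch.randn(4, 16), torch.randn(4, 8))
```

The per-row weights still sum to 1 (the normaliser Z plays the role of the softmax denominator); only the shape of the weighting function changes.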
