How transformer self-attention actually works

Self-attention is the mechanism that lets every token in a sequence look at every other token and decide what to pull in. This note walks through the math, the causal mask that makes decoding work, and the KV-cache that makes it cheap.

Queries, keys, and values

Each token’s embedding is projected into three vectors using learned weight matrices. For an input matrix X of shape (seq_len, d_model):

Q = X @ W_q   # what this token is looking for
K = X @ W_k   # what each token offers as a match
V = X @ W_v   # what each token actually contributes

Attention scores are the dot product of queries against keys, scaled and softmaxed into weights, then used to mix the values.

Attention(Q, K, V) = softmax( Q·Kᵀ / √d_k ) · V
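The formula can be sketched in a few lines of NumPy. This is a minimal illustration, not a production implementation: the helper names `softmax` and `attention`, the sizes, and the random matrices standing in for learned weights are all assumptions made for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (seq_len, seq_len)
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8          # illustrative sizes
X = rng.normal(size=(seq_len, d_model))
# Random stand-ins for the learned projection matrices W_q, W_k, W_v.
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = attention(X @ W_q, X @ W_k, X @ W_v)
```

Each row of `weights` is one token's distribution over the sequence, and the matching row of `out` is the corresponding weighted average of the value vectors.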

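The √d_k scaling matters: if the components of a query and a key have roughly unit variance, their dot product has variance d_k, so unscaled scores grow with dimension and push the softmax into saturated regions where gradients vanish. Dividing by √d_k keeps the score variance near 1.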
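A quick numerical check makes the growth visible. The sketch below assumes query and key components are independent draws with zero mean and unit variance, which is an idealization of real activations; the function name `dot_product_std` and the sample count are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5_000   # illustrative sample count

def dot_product_std(d_k, scaled):
    # Draw random unit-variance queries and keys, then measure the spread of q . k
    q = rng.normal(size=(n_samples, d_k))
    k = rng.normal(size=(n_samples, d_k))
    dots = (q * k).sum(axis=1)
    if scaled:
        dots = dots / np.sqrt(d_k)
    return dots.std()

for d_k in (16, 64, 256):
    raw = dot_product_std(d_k, scaled=False)
    fixed = dot_product_std(d_k, scaled=True)
    print(f"d_k={d_k:4d}  unscaled std={raw:6.2f}  scaled std={fixed:5.2f}")
```

Under these assumptions the unscaled spread should track √d_k while the scaled spread stays close to 1, whatever the dimension.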

Multi-head attention

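A single attention operation produces one set of weights per query, so it can only express one “kind” of relationship at a time. Multi-head attention runs h independent attention computations in parallel, each on its own (d_model / h)-dimensional projection of the embedding, then concatenates the results: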

Component | Shape per head           | Purpose
Q, K, V   | (seq, d_model / h)       | independent subspaces
heads     | h of them                | learn distinct relations
output    | concat → (seq, d_model)  | projected by W_o

One head might track subject-verb agreement, another local adjacency. The split is cheap because each head works in only d_model / h dimensions, so total compute stays roughly the same as for one full-width head.
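The split-attend-concatenate pattern can be sketched as follows. It is one common way to lay the computation out, assuming d_model is divisible by h. The function name `multi_head_attention`, the sizes, and the random weights are illustrative stand-ins for a trained model.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    seq_len, d_model = X.shape
    d_head = d_model // h                      # per-head width, d_model / h

    def split_heads(M):
        # (seq, d_model) -> (h, seq, d_head): head i gets columns i*d_head:(i+1)*d_head
        return M.reshape(seq_len, h, d_head).transpose(1, 0, 2)

    Q, K, V = (split_heads(X @ W) for W in (W_q, W_k, W_v))
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (h, seq, seq)
    heads = softmax(scores) @ V                            # (h, seq, d_head)
    # Concatenate heads back to (seq, d_model), then apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 16, 4                 # illustrative sizes
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) for _ in range(4))
out = multi_head_attention(X, W_q, W_k, W_v, W_o, h)
```

The projections are still full (d_model, d_model) matrix multiplies; only the score and mixing steps are done per head, which is why the total cost is close to that of a single wide head.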

Causal masking

Decoder-only models (the GPT family) must not let a token attend to future tokens during training, or they would trivially “cheat” by reading the answer. The fix is a mask added to the scores before softmax:

       t0    t1    t2    t3
t0  [  0  | -inf | -inf | -inf ]
t1  [  0  |  0   | -inf | -inf ]
t2  [  0  |  0   |  0   | -inf ]
t3  [  0  |  0   |  0   |  0   ]

Since exp(-inf) is 0, masked positions get exactly zero weight after softmax, so position t1 only sees t0 and itself. This lower-triangular mask is what makes a single forward pass equivalent to predicting every next token simultaneously.
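The mask can be built and checked with a short NumPy sketch. The function name `causal_attention`, the sizes, and the random inputs are assumptions made for illustration. The second half perturbs the last token to show that earlier positions do not depend on it.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(Q, K, V):
    seq_len, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # Additive mask: 0 on and below the diagonal, -inf strictly above it.
    mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
    weights = softmax(scores + mask)
    return weights @ V, weights

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))   # illustrative sizes
out, weights = causal_attention(Q, K, V)

# Perturb the last token's key and value: rows for earlier tokens must not change.
K2, V2 = K.copy(), V.copy()
K2[-1] += 1.0
V2[-1] += 1.0
out2, _ = causal_attention(Q, K2, V2)
print(np.allclose(out[:-1], out2[:-1]))
```

Every row keeps at least its diagonal entry unmasked, so the softmax never sees a row made up entirely of -inf.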

The KV-cache

At generation time you produce one token at a time. Naively, generating token n recomputes K and V for all n previous tokens, which adds up to O(n²) redundant projection work across a sequence. But under a causal mask, past keys and values never change, so you cache them.

# step n: only the new token needs fresh q, k, v
q_new, k_new, v_new = project(token_n)
K_cache = cat([K_cache, k_new])   # append, don't recompute
V_cache = cat([V_cache, v_new])
scores = q_new @ K_cache.T / sqrt(d_k)   # one query against all keys
out = softmax(scores) @ V_cache          # mix the cached values

With the cache, each step does O(n) attention work (one query against n cached keys) instead of rerunning O(n²) attention over the whole prefix. The price is memory: the cache grows linearly with sequence length and can dominate VRAM for long contexts, which is exactly what techniques like grouped-query attention and paged KV-caches exist to tame.
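To check that caching changes the cost but not the result, the sketch below decodes a short sequence one token at a time and compares it with full causal attention computed in one pass. It is a single-head toy setup under assumed sizes, with random matrices standing in for learned weights and `np.vstack` standing in for a real cache append.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_model = 6, 8                       # illustrative sizes
X = rng.normal(size=(seq_len, d_model))       # stand-in token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Reference: full causal attention over the whole sequence in one pass.
Q, K, V = X @ W_q, X @ W_k, X @ W_v
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
full = softmax(Q @ K.T / np.sqrt(d_model) + mask) @ V

# Incremental decoding: project only the newest token and append to the cache.
K_cache = np.empty((0, d_model))
V_cache = np.empty((0, d_model))
steps = []
for x in X:
    q_new, k_new, v_new = x @ W_q, x @ W_k, x @ W_v
    K_cache = np.vstack([K_cache, k_new])     # append, don't recompute
    V_cache = np.vstack([V_cache, v_new])
    scores = q_new @ K_cache.T / np.sqrt(d_model)   # one query against all cached keys
    steps.append(softmax(scores) @ V_cache)
cached = np.stack(steps)

print(np.allclose(cached, full))
```

No mask is needed in the loop because the cache only ever contains the current and earlier tokens. A real implementation would likely preallocate the cache rather than copy it on every append as `np.vstack` does here.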

Wrap up

  • Attention is a learned, content-based weighted average: queries match keys, weights mix values.
  • The causal mask is the single trick that makes attention usable for autoregressive language modeling: each position sees only itself and the past.
  • The KV-cache trades memory for speed and is the main lever (and bottleneck) in long-context inference.

References

  • Vaswani et al., Attention Is All You Need (2017)
  • Ainslie et al., GQA: Training Generalized Multi-Query Transformer Models (2023)