38  Attention and Transformers

Imagine reading a long sentence and trying to remember the very first word by the time you reach the last. If you could only hold one running summary in your head, updated word by word, the early details would fade. That is essentially how a recurrent neural network (Chapter 15) reads a sequence: one step at a time, carrying a single hidden state forward and updating it as each new token1 arrives. This design made recurrent networks the default tool for language and time series for many years, but it has structural limits.

Attention takes a different stance. Instead of squeezing the past into one running summary, it lets every position in a sequence look directly at every other position and decide, on the fly, which ones matter. The Transformer architecture builds an entire model out of this idea, and it is the backbone of nearly every modern large language model (Chapter 40).

This chapter builds attention up from intuition. We start by spelling out why recurrence runs into trouble, then derive scaled dot-product attention (the core operation) and explain each piece, including the scaling factor that many treatments leave unexplained. From there we assemble multi-head attention, positional encoding, and the full Transformer block, distinguishing the encoder from the decoder. To make the equations concrete, we implement attention from scratch in base R on a five-word toy sentence and inspect the attention weights directly. We close with the efficient variants that tame attention’s cost on long inputs, practical guidance on choosing hyperparameters, and the link to BERT (Chapter 39).

Key idea

Attention replaces “remember everything in one state” with “look up whatever is relevant, whenever you need it.” Every position can reach every other position in a single step.

38.1 Why Move Beyond Recurrence

A recurrent network reads a sequence \(x_1, x_2, \dots, x_n\) and produces hidden states \(h_t = f(h_{t-1}, x_t)\). Three problems follow from this recurrence.

First, the computation is inherently sequential. To compute \(h_t\) you need \(h_{t-1}\), so the work cannot be spread across the sequence in parallel. Long sequences mean long chains of dependent operations, which makes training slow and underuses modern hardware that is built for parallel matrix math.

Second, long-range dependencies are hard to capture. Information from \(x_1\) that matters for a prediction at position \(n\) has to survive being repeatedly transformed through every intermediate state. The path length between two positions grows linearly with their distance, so distant tokens influence each other only weakly.

Third, training suffers from vanishing and exploding gradients. Backpropagation through time multiplies many Jacobian factors together. If those factors consistently have norm below one, the gradient shrinks toward zero over long spans; if they are consistently above one, it blows up. Gated units such as LSTMs and GRUs soften this problem but do not eliminate it.
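A minimal base R sketch shows why a long product of Jacobians is fragile. It makes a deliberately simple assumption: every step's Jacobian is the same random orthogonal matrix scaled by a constant `gain`. Under that assumption, the product over 50 steps shrinks or stretches any vector by exactly `gain^50`. Real recurrent Jacobians change from step to step, so treat this as a caricature. The seed, sizes, and helper names are illustrative.

```r
set.seed(21)
d <- 4
n_steps <- 50
O <- qr.Q(qr(matrix(rnorm(d * d), d, d)))   # random orthogonal matrix: preserves norms
v <- rnorm(d)                               # stand-in for an upstream gradient vector

# Multiply v by the same step Jacobian (gain * O) n_steps times; return the final norm.
backprop_norm <- function(gain) {
    g <- v
    for (step in seq_len(n_steps)) g <- (gain * O) %*% g
    sqrt(sum(g^2))
}

gains <- c(0.9, 1.0, 1.1)
ratios <- sapply(gains, backprop_norm) / sqrt(sum(v^2))   # norm after / norm before
names(ratios) <- paste("gain =", gains)
signif(ratios, 3)
```

A gain only ten percent below or above one changes the gradient norm by a factor of roughly 0.005 or 117 over 50 steps.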

Attention attacks all three at once. It connects any two positions with a path of length one, so distant tokens interact in a single step. The core computation is a set of matrix multiplications that run in parallel across positions. And because there is no long multiplicative chain along the sequence axis, the gradient between two distant positions passes through a fixed number of operations rather than one per intervening step, so it is far less prone to vanishing or exploding.

Note

Attention trades one problem for another. By comparing every position with every other position, it does work that grows with the square of the sequence length. That quadratic cost is the price of the short paths, and we return to it at the end of the chapter.

With the motivation in place, we can now define the operation itself.

38.2 Scaled Dot-Product Attention

Attention answers a retrieval question, and a useful mental model is a soft dictionary lookup. In an ordinary dictionary you match a search term exactly against one key and get back its value. Attention softens that: it matches a search term against all keys by degree, then returns a blend of the values weighted by how well each key matched.

The three roles carry standard names. For each position we form a query that describes what that position is looking for. Every position also exposes a key that describes what it contains, and a value that holds the content to be returned. A position compares its query against all keys, turns the comparisons into weights, and reads out a weighted blend of the values.
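The contrast between an exact lookup and a soft one fits in a few lines of base R. The three two-dimensional keys, their scalar values, and the query are made-up numbers. The scaling factor introduced below is omitted so the arithmetic stays visible.

```r
# A toy "dictionary": three keys (rows) and one scalar value per key.
keys   <- rbind(c(1, 0), c(0, 1), c(1, 1))
values <- c(10, 20, 30)
query  <- c(0.9, 0.2)

scores <- as.vector(keys %*% query)           # dot-product match of the query to each key

# Hard lookup: return the value of the single best-matching key.
hard <- values[which.max(scores)]

# Soft lookup: turn scores into softmax weights, then blend all the values.
weights <- exp(scores - max(scores))
weights <- weights / sum(weights)
soft <- sum(weights * values)

list(scores = scores, weights = round(weights, 3), hard = hard, soft = round(soft, 2))
```

The hard lookup returns 30. The soft lookup returns a blend of about 20.8: it leans toward the best-matching key but still draws on the others.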

Intuition

Query = “what am I looking for?”, key = “what do I offer?”, value = “what do I hand over if chosen?”. A pronoun’s query might match the key of the noun it refers to, and pull in that noun’s value.

Stack the queries, keys, and values for all \(n\) positions into matrices \(Q \in \mathbb{R}^{n \times d_k}\), \(K \in \mathbb{R}^{n \times d_k}\), and \(V \in \mathbb{R}^{n \times d_v}\). The similarity between query \(i\) and key \(j\) is the dot product \(q_i^\top k_j\). Collecting all pairs gives the score matrix \(QK^\top \in \mathbb{R}^{n \times n}\), whose entry \((i, j)\) is how strongly position \(i\) attends to position \(j\) before normalization.

We turn each row of scores into a probability distribution with a row-wise softmax, then use those weights to average the value vectors. The full operation is

\[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V . \]

The softmax acts along each row, so row \(i\) of the resulting weight matrix sums to one and tells us how position \(i\) distributes its attention over all positions.2 Multiplying by \(V\) replaces each position with a convex combination of value vectors, that is, a weighted average where the weights are non-negative and sum to one.

38.2.1 Why Divide by the Square Root of the Key Dimension

The scaling factor \(1/\sqrt{d_k}\) is not cosmetic. Suppose the entries of a query \(q\) and a key \(k\) are independent with mean zero and variance one. The dot product is \(q^\top k = \sum_{m=1}^{d_k} q_m k_m\). Each term has mean zero and variance one, and the terms are independent, so the sum has mean zero and variance \(d_k\). Its standard deviation grows like \(\sqrt{d_k}\).
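The variance argument is easy to check by simulation. The sketch below assumes the same setting as the text: query and key entries are independent standard normals. It estimates the variance of the raw and scaled dot products for a few key dimensions. The seed, number of draws, and dimensions are arbitrary choices.

```r
set.seed(38)
n_draws <- 10000
dims <- c(16, 64, 256)

# For each key dimension, draw many (q, k) pairs with iid N(0, 1) entries and
# record the sample variance of the raw and the scaled dot product.
score_var <- sapply(dims, function(d_k) {
    q <- matrix(rnorm(n_draws * d_k), n_draws, d_k)
    k <- matrix(rnorm(n_draws * d_k), n_draws, d_k)
    s <- rowSums(q * k)                       # one dot product per row
    c(raw = var(s), scaled = var(s / sqrt(d_k)))
})
colnames(score_var) <- paste0("d_k = ", dims)
round(score_var, 2)
```

The raw variances track \(d_k\), while the scaled ones stay near one regardless of width.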

When \(d_k\) is large the raw scores spread over a wide range. Feeding large values into a softmax pushes it toward a one-hot distribution where almost all weight lands on a single position. In that regime the softmax gradient is tiny, which stalls learning. Dividing by \(\sqrt{d_k}\) rescales the scores back to roughly unit variance, keeping the softmax in a responsive range.
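To see the saturation directly, the following sketch scores one random query against ten random keys with \(d_k = 256\). It then compares the softmax of the raw and the scaled scores by their largest weight and their entropy. The sizes and seed are illustrative. The comparison always comes out the same way: dividing scores by a constant larger than one can only lower the top weight and raise the entropy.

```r
set.seed(53)
d_k <- 256
n_keys <- 10
q <- rnorm(d_k)
K <- matrix(rnorm(n_keys * d_k), n_keys, d_k)
raw_scores <- as.vector(K %*% q)              # standard deviation about sqrt(d_k) = 16

softmax_vec <- function(s) { e <- exp(s - max(s)); e / sum(e) }
entropy <- function(p) -sum(p[p > 0] * log(p[p > 0]))   # 0 log 0 treated as 0

a_raw    <- softmax_vec(raw_scores)
a_scaled <- softmax_vec(raw_scores / sqrt(d_k))

round(c(max_weight_raw = max(a_raw), max_weight_scaled = max(a_scaled),
        entropy_raw = entropy(a_raw), entropy_scaled = entropy(a_scaled),
        entropy_uniform = log(n_keys)), 3)
```

With this setup the unscaled row typically puts nearly all of its mass on one key, while the scaled row spreads it across several.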

Why it matters

Without the scaling, simply making the model wider (larger \(d_k\)) would silently saturate the softmax and freeze learning. The \(1/\sqrt{d_k}\) factor decouples the temperature of the attention from the width of the representation.

38.2.1.1 The saturation argument made precise

The qualitative claim that large scores “freeze learning” can be made exact by examining the softmax Jacobian. Write \(a = \text{softmax}(s)\) for a single row of scores \(s \in \mathbb{R}^n\). The derivative of the output with respect to the scores is

\[ \frac{\partial a_i}{\partial s_j} = a_i(\delta_{ij} - a_j), \qquad J = \text{diag}(a) - a a^\top , \tag{38.1}\]

which follows from differentiating \(a_i = e^{s_i} / \sum_m e^{s_m}\) with the quotient rule. As a row saturates toward a one-hot vector \(e_{i^\star}\), the diagonal term \(a_{i^\star}(1 - a_{i^\star}) \to 0\) and every off-diagonal term \(a_i a_j \to 0\), so \(J \to 0\) entrywise. Away from saturation the Jacobian is also bounded. Its spectral norm satisfies \(\|J\|_2 \le \tfrac12\), with equality when \(a\) splits its mass equally between two positions (the nonzero eigenvalue is then \(2 \cdot \tfrac12 \cdot \tfrac12 = \tfrac12\)). Each diagonal entry separately satisfies \(a_i(1 - a_i) \le \tfrac14\). Concretely, a single dominant logit gap \(\Delta = s_{i^\star} - \max_{j \ne i^\star} s_j\) gives \(1 - a_{i^\star} \le (n-1) e^{-\Delta}\). Every entry of \(J\) is bounded in magnitude by \(1 - a_{i^\star}\), so the gradient that flows back through attention decays exponentially in the logit gap. Because the gap scales like \(\sqrt{d_k}\) before scaling (the variance computation above), an unscaled softmax over a wide model sits in exactly the regime where Equation 38.1 vanishes. Dividing by \(\sqrt{d_k}\) keeps \(\Delta = O(1)\) in expectation, which keeps typical rows out of the saturated regime and \(\|J\|\) away from zero.
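Equation 38.1 and the bounds around it can be checked numerically. The sketch below does four things:

- compares \(J = \text{diag}(a) - aa^\top\) with central finite differences;
- confirms \(\|J\|_2 \le \tfrac12\), with equality at the two-position split;
- shows the norm collapsing once the scores are multiplied by 10;
- checks the \((n-1)e^{-\Delta}\) bound.

The score vector, step size, and factor of 10 are arbitrary choices.

```r
softmax_vec <- function(s) { e <- exp(s - max(s)); e / sum(e) }

# Equation 38.1: J = diag(a) - a a^T, where J[i, j] = d a_i / d s_j.
softmax_jacobian <- function(s) {
    a <- softmax_vec(s)
    diag(a, nrow = length(a)) - outer(a, a)
}
spectral_norm <- function(M) max(abs(eigen(M, symmetric = TRUE)$values))

s <- c(1.2, -0.4, 0.3, 2.0, -1.1, 0.5)        # arbitrary fixed scores
J <- softmax_jacobian(s)

# Central finite differences: column j approximates d a / d s_j.
h <- 1e-6
J_fd <- sapply(seq_along(s), function(j) {
    bump <- replace(numeric(length(s)), j, h)
    (softmax_vec(s + bump) - softmax_vec(s - bump)) / (2 * h)
})

# Saturate the row by widening every logit gap tenfold.
s_sat <- 10 * s
a_sat <- softmax_vec(s_sat)
top2  <- sort(s_sat, decreasing = TRUE)[1:2]
gap   <- top2[1] - top2[2]                    # the logit gap Delta

round(c(max_error_vs_finite_diff = max(abs(J - J_fd)),
        norm_J           = spectral_norm(J),
        norm_J_two_atoms = spectral_norm(softmax_jacobian(c(0, 0, -Inf))),
        norm_J_saturated = spectral_norm(softmax_jacobian(s_sat)),
        one_minus_top    = 1 - max(a_sat),
        gap_bound        = (length(s) - 1) * exp(-gap)), 6)
```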

38.2.1.2 Attention as kernel smoothing

Scaled dot-product attention is a learned instance of the Nadaraya-Watson kernel regression estimator from Chapter 4. With the exponential similarity \(\kappa(q_i, k_j) = \exp(q_i^\top k_j / \sqrt{d_k})\), row \(i\) of the output is

\[ o_i = \sum_{j=1}^{n} \frac{\kappa(q_i, k_j)}{\sum_{m=1}^{n} \kappa(q_i, k_m)}\, v_j , \tag{38.2}\]

which is precisely a Nadaraya-Watson estimate of a value “regressed” on a query. It uses an asymmetric exponential kernel whose bandwidth is set by \(\sqrt{d_k}\) and whose geometry is learned through the query and key projections. Causal masking is the same estimator restricted to a one-sided neighborhood. This connection explains both the smoothing behavior of attention and its failure mode. Outputs are convex combinations of values, so they cannot extrapolate beyond the convex hull of the rows of \(V\). And like any kernel smoother, attention degrades when the effective neighborhood is either too diffuse (uniform weights, no selectivity) or too peaked (a single value copied, no aggregation).
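The equivalence in Equation 38.2 can be confirmed by computing attention twice: once in matrix form, and once as an explicit kernel-weighted average for each query. The sketch also checks a consequence of convexity: every output coordinate falls inside the range of the corresponding column of \(V\). This coordinate-wise check is weaker than membership in the convex hull. The random inputs and the helper name `kernel_exp` are illustrative.

```r
softmax_rows <- function(M) { M <- M - apply(M, 1, max); E <- exp(M); E / rowSums(E) }

set.seed(73)
n <- 5; d_k <- 3; d_v <- 2
Q <- matrix(rnorm(n * d_k), n, d_k)
K <- matrix(rnorm(n * d_k), n, d_k)
V <- matrix(rnorm(n * d_v), n, d_v)

# Matrix form of scaled dot-product attention.
O_matrix <- softmax_rows(Q %*% t(K) / sqrt(d_k)) %*% V

# Kernel-smoother form (Equation 38.2), one query at a time.
kernel_exp <- function(q, k) exp(sum(q * k) / sqrt(d_k))
O_kernel <- t(sapply(seq_len(n), function(i) {
    w <- sapply(seq_len(n), function(j) kernel_exp(Q[i, ], K[j, ]))
    colSums((w / sum(w)) * V)                 # row j of V weighted by w_j / sum(w)
}))

# Convexity check: each output coordinate lies within that column's range in V.
inside <- all(sweep(O_matrix, 2, apply(V, 2, min)) >= -1e-12) &&
          all(sweep(O_matrix, 2, apply(V, 2, max)) <=  1e-12)

list(max_difference = max(abs(O_matrix - O_kernel)), within_value_range = inside)
```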

38.2.1.3 Backpropagation through attention

Training requires gradients of the loss \(L\) with respect to \(Q\), \(K\), and \(V\). Write \(S = QK^\top / \sqrt{d_k}\), \(A = \text{softmax}_{\text{row}}(S)\), and \(O = AV\). Given the upstream gradient \(G = \partial L / \partial O \in \mathbb{R}^{n \times d_v}\), the value and weight gradients are immediate from \(O = AV\),

\[ \frac{\partial L}{\partial V} = A^\top G, \qquad \frac{\partial L}{\partial A} = G V^\top . \]

Propagating through the row-wise softmax uses Equation 38.1 applied per row, which for the full matrix collapses to the compact form

\[ \frac{\partial L}{\partial S} = A \odot \left( \frac{\partial L}{\partial A} - \big( (\tfrac{\partial L}{\partial A} \odot A)\, \mathbf{1} \big) \mathbf{1}^\top \right), \tag{38.3}\]

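where \(\odot\) is the Hadamard product and \(\mathbf{1}\) is the all-ones vector. The product \(\big(\tfrac{\partial L}{\partial A} \odot A\big)\mathbf{1}\) collects, for each row, the mean of \(\partial L / \partial A\) weighted by that row of \(A\). The outer product with \(\mathbf{1}^\top\) then subtracts this weighted mean from every entry of the row. Finally the chain rule through \(S = QK^\top / \sqrt{d_k}\) gives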

\[ \frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S} K, \qquad \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_k}} \Big(\frac{\partial L}{\partial S}\Big)^\top Q . \]

The factor \(1/\sqrt{d_k}\) multiplies the gradient as well as the forward score, so the same scaling that tames the forward softmax also rescales the backward signal. Equation 38.3 needs \(A\), the \(n \times n\) matrix that dominates memory. Memory-efficient kernels such as FlashAttention therefore recompute \(A\) block by block during the backward pass instead of storing it.
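The gradient formulas can be verified against finite differences. The sketch assumes a linear loss \(L = \sum_{ij} G_{ij} O_{ij}\), chosen only because its upstream gradient is exactly \(G\). It implements the four equations above and compares each analytic gradient with a central-difference estimate. The sizes, seed, and helper names are illustrative.

```r
softmax_rows <- function(M) { M <- M - apply(M, 1, max); E <- exp(M); E / rowSums(E) }

attn_forward <- function(Q, K, V) {
    A <- softmax_rows(Q %*% t(K) / sqrt(ncol(K)))
    list(A = A, O = A %*% V)
}

# Analytic gradients from the equations above, given the upstream G = dL/dO.
attn_backward <- function(Q, K, V, G) {
    s  <- sqrt(ncol(K))
    A  <- attn_forward(Q, K, V)$A
    dV <- t(A) %*% G
    dA <- G %*% t(V)
    dS <- A * (dA - rowSums(dA * A))          # Equation 38.3; row sums recycle down columns
    list(dQ = dS %*% K / s, dK = t(dS) %*% Q / s, dV = dV)
}

# Central finite differences of a scalar function f with respect to a matrix X.
num_grad <- function(f, X, h = 1e-6) {
    g <- X
    for (idx in seq_along(X)) {
        Xp <- X; Xp[idx] <- Xp[idx] + h
        Xm <- X; Xm[idx] <- Xm[idx] - h
        g[idx] <- (f(Xp) - f(Xm)) / (2 * h)
    }
    g
}

set.seed(89)
n <- 4; d_k <- 3; d_v <- 2
Q <- matrix(rnorm(n * d_k), n, d_k)
K <- matrix(rnorm(n * d_k), n, d_k)
V <- matrix(rnorm(n * d_v), n, d_v)
G <- matrix(rnorm(n * d_v), n, d_v)

grads <- attn_backward(Q, K, V, G)
errors <- c(
    dQ = max(abs(grads$dQ - num_grad(function(X) sum(G * attn_forward(X, K, V)$O), Q))),
    dK = max(abs(grads$dK - num_grad(function(X) sum(G * attn_forward(Q, X, V)$O), K))),
    dV = max(abs(grads$dV - num_grad(function(X) sum(G * attn_forward(Q, K, X)$O), V)))
)
errors
```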

38.2.2 Computational Cost

The score matrix \(QK^\top\) has \(n^2\) entries, and each entry is a dot product of length \(d_k\), giving \(O(n^2 d_k)\) work to form the scores. Applying the weights to \(V\) costs \(O(n^2 d_v)\). Writing \(d\) for the representation width, the cost is \(O(n^2 d)\) in time and \(O(n^2)\) in memory for the attention matrix. The quadratic dependence on sequence length \(n\) is the main scaling concern, and it motivates the efficient variants discussed later.
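A short calculation puts numbers on the \(O(n^2)\) memory term. It assumes 4-byte (single-precision) entries, one attention head, and one sequence. Real training multiplies the result by the number of heads, layers, and sequences in a batch.

```r
# Memory for a single n x n attention matrix, assuming 4-byte (float32) entries.
n <- c(1024, 8192, 65536)
mib <- n^2 * 4 / 2^20                         # bytes converted to MiB
data.frame(n = n, attention_matrix_MiB = mib)
```

Under these assumptions the matrix takes 4 MiB at \(n = 1024\) and 256 MiB at \(n = 8192\). At \(n = 65536\) it takes 16384 MiB (16 GiB): an 8-fold longer input costs 64 times the memory.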

38.3 Multi-Head Attention

A single attention computation forms one set of weights and one blend of values. That forces every relationship in the sequence through a single similarity pattern, which is limiting: a sentence has many kinds of relationships at once (grammatical agreement, reference, topic) and one pattern cannot capture all of them. Multi-head attention runs several attention computations in parallel, each with its own learned projections, so different heads can specialize. One head might track local word order while another links a pronoun to its referent.

Given input representations \(X \in \mathbb{R}^{n \times d}\), head \(i\) uses learned projection matrices \(W_i^Q, W_i^K \in \mathbb{R}^{d \times d_k}\) and \(W_i^V \in \mathbb{R}^{d \times d_v}\) to build its own queries, keys, and values,

\[ \text{head}_i = \text{Attention}(X W_i^Q,\; X W_i^K,\; X W_i^V) . \]

The \(h\) head outputs are concatenated along the feature axis and mixed by an output projection \(W^O \in \mathbb{R}^{h d_v \times d}\),

\[ \text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O . \]

A common choice is \(d_k = d_v = d / h\), so the total work of multi-head attention is about the same as a single full-width head, but spread across specialized subspaces.
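One way to see why the cost matches a single wide head is to notice that the per-head projections are just column blocks of one \(d \times d\) matrix. The sketch below shows that concatenating \(XW_i^Q\) over heads equals \(X\) times the column-concatenated \(W^Q\), and that the parameter counts agree. Computing all heads with one matrix product and then splitting the columns is a common implementation pattern; the sizes here are illustrative.

```r
set.seed(107)
n <- 5; d <- 8; h <- 2
d_k <- d / h                                  # each head gets d / h dimensions

X <- matrix(rnorm(n * d), n, d)
W_heads <- lapply(seq_len(h), function(i) matrix(rnorm(d * d_k), d, d_k))

# Per-head query projections X W_i^Q, concatenated along the feature axis ...
per_head <- do.call(cbind, lapply(W_heads, function(W) X %*% W))

# ... equal one full-width projection with W^Q = [W_1^Q, ..., W_h^Q].
W_full <- do.call(cbind, W_heads)
full <- X %*% W_full

c(max_difference = max(abs(per_head - full)),
  params_all_heads = h * d * d_k, params_one_wide_head = d * d)
```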

Tip

Multi-head attention is not “more compute for more power.” With \(d_k = d / h\) the heads split the existing width among themselves, so you get several specialized views at roughly the cost of one wide view.

38.3.1 Self-Attention and Cross-Attention

In self-attention the queries, keys, and values all come from the same sequence. Every position attends to every position in its own sequence, which is how a model builds context-aware representations of its input.

In cross-attention the queries come from one sequence while the keys and values come from another. A translation decoder, for example, forms queries from the target sentence it is generating and reads keys and values from the encoded source sentence. Cross-attention is how one sequence conditions on another.
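The only change from self-attention is where the queries, keys, and values come from, and the shapes make that visible. In the sketch below, a hypothetical three-token target sequence attends over a six-token source sequence. The result is a \(3 \times 6\) weight matrix: one row per target token, each a distribution over source tokens. The random inputs and projections are placeholders.

```r
softmax_rows <- function(M) { M <- M - apply(M, 1, max); E <- exp(M); E / rowSums(E) }

set.seed(117)
d <- 4; d_k <- 4
tgt <- matrix(rnorm(3 * d), 3, d)             # 3 target tokens generated so far
src <- matrix(rnorm(6 * d), 6, d)             # 6 encoded source tokens

W_Q <- matrix(rnorm(d * d_k), d, d_k)
W_K <- matrix(rnorm(d * d_k), d, d_k)
W_V <- matrix(rnorm(d * d_k), d, d_k)

Q <- tgt %*% W_Q                              # queries come from the target sequence
K <- src %*% W_K                              # keys come from the source sequence
V <- src %*% W_V                              # values come from the source sequence

A <- softmax_rows(Q %*% t(K) / sqrt(d_k))     # 3 x 6: each target token over source tokens
O <- A %*% V                                  # 3 x d_k: one output per target token
list(weights_dim = dim(A), output_dim = dim(O))
```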

38.4 Positional Encoding

There is a catch hiding in everything so far: attention has no built-in notion of order. If you permute the rows of \(Q\), \(K\), and \(V\) together, the set of outputs is permuted the same way and otherwise unchanged. The operation is permutation equivariant,3 so on its own it treats a sentence as a bag of tokens. To attention, “dog bites man” and “man bites dog” look identical, which is clearly wrong for language. Word order has to be injected explicitly.

To state this precisely, let \(P \in \{0,1\}^{n \times n}\) be a permutation matrix. Self-attention with projections applied to \(PX\) satisfies

\[ \text{Attention}(PXW^Q, PXW^K, PXW^V) = \text{softmax}\!\left(\frac{P (XW^Q)(XW^K)^\top P^\top}{\sqrt{d_k}}\right) P X W^V = P\, \text{Attention}(XW^Q, XW^K, XW^V), \tag{38.4}\]

where the last step uses two facts. First, a row-wise softmax commutes with a simultaneous row-and-column permutation, \(\text{softmax}(PSP^\top) = P\,\text{softmax}(S)\,P^\top\): the permutation reorders the rows being normalized and applies the same reordering to the entries within each row. Second, \(P^\top P = I\). Equation 38.4 is the formal content of permutation equivariance. It shows that any order information must enter through \(X\) itself, which is exactly what positional encoding does.
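Equation 38.4 can be checked numerically, along with the way positional encoding breaks the symmetry. The sketch permutes the input rows and confirms that the output rows are permuted in the same way and otherwise unchanged. It then adds the sinusoidal encoding from the next section after the tokens are placed, so that a permuted token receives a different position vector. With that change, the outputs are no longer a permutation of each other. Random weights and the particular permutation are illustrative.

```r
softmax_rows <- function(M) { M <- M - apply(M, 1, max); E <- exp(M); E / rowSums(E) }
self_attention <- function(X, W_Q, W_K, W_V) {
    Q <- X %*% W_Q; K <- X %*% W_K; V <- X %*% W_V
    softmax_rows(Q %*% t(K) / sqrt(ncol(K))) %*% V
}

set.seed(127)
n <- 5; d <- 4
X   <- matrix(rnorm(n * d), n, d)
W_Q <- matrix(rnorm(d * d), d, d)
W_K <- matrix(rnorm(d * d), d, d)
W_V <- matrix(rnorm(d * d), d, d)

perm <- c(3, 1, 5, 2, 4)
P <- diag(n)[perm, ]                          # permutation matrix: P %*% X equals X[perm, ]

# Equation 38.4: permuting the input rows only permutes the output rows.
equivariance_gap <- max(abs(self_attention(P %*% X, W_Q, W_K, W_V) -
                            P %*% self_attention(X, W_Q, W_K, W_V)))

# Sinusoidal positional encoding (as in the text), tied to positions, not tokens.
pos <- 0:(n - 1)
PE <- sapply(seq_len(d), function(j) {
    i <- (j - 1) %/% 2
    if (j %% 2 == 1) sin(pos / 10000^(2 * i / d)) else cos(pos / 10000^(2 * i / d))
})
with_position_gap <- max(abs(self_attention(P %*% X + PE, W_Q, W_K, W_V) -
                             P %*% self_attention(X + PE, W_Q, W_K, W_V)))

c(equivariance_gap = equivariance_gap, with_position_gap = with_position_gap)
```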

The original Transformer adds a fixed positional encoding to the input embeddings. For position \(\text{pos}\) and embedding dimension index \(i\), the encoding uses sines and cosines at geometrically spaced frequencies,

\[ PE_{(\text{pos},\, 2i)} = \sin\!\left(\frac{\text{pos}}{10000^{\,2i/d}}\right), \qquad PE_{(\text{pos},\, 2i+1)} = \cos\!\left(\frac{\text{pos}}{10000^{\,2i/d}}\right). \]

Each dimension is a sinusoid with its own wavelength, ranging from short to very long across the \(d\) dimensions. This gives every position a distinct signature, and because shifting position by a fixed offset corresponds to a linear rotation of these sinusoids, the model can express relative position with linear operations. Adding \(PE\) to the token embeddings lets attention see where each token sits without changing the attention mechanism itself. Learned positional embeddings are a common alternative that replaces the fixed formula with a trainable table.

38.4.0.1 Why the offset is a linear map

The claim that a shift in position is a linear transformation can be derived directly. Fix a frequency \(\omega_i = 10000^{-2i/d}\) and group the pair of coordinates for that frequency into the vector \(u(\text{pos}) = (\sin(\omega_i\,\text{pos}),\, \cos(\omega_i\,\text{pos}))^\top\). For any fixed offset \(\Delta\), the angle-addition identities give

\[ u(\text{pos} + \Delta) = \underbrace{\begin{pmatrix} \cos(\omega_i \Delta) & \sin(\omega_i \Delta) \\ -\sin(\omega_i \Delta) & \cos(\omega_i \Delta) \end{pmatrix}}_{R(\omega_i \Delta)} \, u(\text{pos}) , \tag{38.5}\]

a rotation by angle \(\omega_i \Delta\) that depends on \(\Delta\) but not on \(\text{pos}\). Stacking the per-frequency rotations into a single block-diagonal matrix \(R_\Delta\) gives \(PE_{\text{pos}+\Delta} = R_\Delta\, PE_{\text{pos}}\) for all positions. The dot product of two encodings then depends only on the relative offset, \(\langle PE_{\text{pos}}, PE_{\text{pos}+\Delta} \rangle = \sum_i \cos(\omega_i \Delta)\), so a linear query-key projection can read off relative position without ever learning absolute coordinates. Equation 38.5 is also the seed of rotary position embeddings (RoPE). RoPE applies the rotations \(R(\omega_i\, m)\) directly to the query at position \(m\), pair by pair, and likewise \(R(\omega_i\, m')\) to the key at position \(m'\). Their dot product then depends on position only through the offset \(m - m'\).
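The derivation can be checked in a few lines. The sketch builds the encoding for \(d = 8\) with the same (sin, cos) pairing as the text and confirms three things:

- \(PE_{\text{pos}+\Delta} = R_\Delta PE_{\text{pos}}\);
- the dot product equals \(\sum_i \cos(\omega_i \Delta)\);
- a RoPE-style score gives the same value at positions (2, 5) and (12, 15), which share the same offset.

The RoPE part is a sketch that pairs adjacent coordinates as in the text; real implementations may pair coordinates differently. The values of \(d\), the positions, and the random vectors are illustrative.

```r
d <- 8
omega <- 10000^(-2 * (0:(d / 2 - 1)) / d)     # one frequency per (sin, cos) pair

# Encoding of one position: coordinates (2i+1, 2i+2) hold (sin, cos) of omega_i * pos.
pe <- function(pos) as.vector(rbind(sin(omega * pos), cos(omega * pos)))

# The 2 x 2 rotation of Equation 38.5 and its block-diagonal stack R_Delta.
rot <- function(theta) matrix(c(cos(theta), -sin(theta), sin(theta), cos(theta)), 2, 2)
R_delta <- function(delta) {
    R <- matrix(0, d, d)
    for (i in seq_along(omega)) {
        idx <- c(2 * i - 1, 2 * i)
        R[idx, idx] <- rot(omega[i] * delta)
    }
    R
}

pos <- 7; delta <- 3
shift_error <- max(abs(pe(pos + delta) - R_delta(delta) %*% pe(pos)))
dot_error   <- abs(sum(pe(pos) * pe(pos + delta)) - sum(cos(omega * delta)))

# RoPE-style score: rotate the query by its position m and the key by its position m2.
set.seed(141)
q <- rnorm(d); k <- rnorm(d)
rope_score <- function(m, m2) sum((R_delta(m) %*% q) * (R_delta(m2) %*% k))

c(shift_error = shift_error, dot_error = dot_error,
  score_2_5 = rope_score(2, 5), score_12_15 = rope_score(12, 15))
```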

38.5 The Transformer Block

The Transformer, introduced in “Attention Is All You Need” (Vaswani et al., 2017), stacks identical blocks built from a small set of pieces.

Each block has two sublayers: a multi-head attention sublayer and a position-wise feed-forward sublayer. Around each sublayer sits a residual connection followed by layer normalization. If a sublayer computes a function \(\text{Sublayer}(\cdot)\), the block produces

\[ \text{LayerNorm}\big(x + \text{Sublayer}(x)\big). \]

The residual connection4 lets the input pass through unchanged when the sublayer has little to add, which keeps gradients flowing through deep stacks. Layer normalization standardizes each position’s feature vector to stabilize training.

Note

The two ingredients play complementary roles. Residual connections keep a clean gradient highway through the depth of the network; layer normalization keeps the scale of activations stable so that highway does not drift. Together they are what let Transformers stack dozens of blocks.

The position-wise feed-forward sublayer applies the same two-layer network to each position independently,

\[ \text{FFN}(x) = \max(0,\, x W_1 + b_1)\, W_2 + b_2 , \]

with an inner width several times larger than \(d\). It adds nonlinear capacity on top of the mixing that attention provides.
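The pieces so far assemble into one block. The sketch below is a minimal post-norm block following the formulas above. It uses single-head self-attention, a ReLU feed-forward network with inner width \(4d\), and layer normalization without a learned gain or bias; the last is a simplification, since real layers learn both. All weights are random placeholders.

```r
softmax_rows <- function(M) { M <- M - apply(M, 1, max); E <- exp(M); E / rowSums(E) }

# Layer normalization per row, without learned gain or bias (a simplification).
layer_norm <- function(X, eps = 1e-5) {
    mu <- rowMeans(X)
    (X - mu) / sqrt(rowMeans((X - mu)^2) + eps)
}

self_attention <- function(X, p) {
    Q <- X %*% p$W_Q; K <- X %*% p$W_K; V <- X %*% p$W_V
    softmax_rows(Q %*% t(K) / sqrt(ncol(K))) %*% V
}

# Position-wise FFN: max(0, x W1 + b1) W2 + b2, applied to every row.
ffn <- function(X, p) {
    H <- sweep(X %*% p$W1, 2, p$b1, "+")
    H[H < 0] <- 0                             # ReLU
    sweep(H %*% p$W2, 2, p$b2, "+")
}

# Post-norm block: LayerNorm(x + Sublayer(x)) around each of the two sublayers.
transformer_block <- function(X, p) {
    H <- layer_norm(X + self_attention(X, p))
    layer_norm(H + ffn(H, p))
}

set.seed(161)
n <- 5; d <- 4; d_ff <- 4 * d
rand <- function(rows, cols) matrix(rnorm(rows * cols), rows, cols)
p <- list(W_Q = rand(d, d), W_K = rand(d, d), W_V = rand(d, d),
          W1 = rand(d, d_ff), b1 = rnorm(d_ff), W2 = rand(d_ff, d), b2 = rnorm(d))
X <- matrix(rnorm(n * d), n, d)
Y <- transformer_block(X, p)

round(cbind(row_mean = rowMeans(Y), row_rms = sqrt(rowMeans(Y^2))), 3)
```

The output keeps the input's shape, and every row leaves the final normalization with mean zero and unit root-mean-square.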

38.5.0.1 Pre-norm versus post-norm

The placement of the normalization relative to the residual branch is not a cosmetic choice; it is one of the most consequential details for training deep stacks. The original formula \(\text{LayerNorm}(x + \text{Sublayer}(x))\) is post-norm: the normalization sits on the residual path itself. Because post-norm rescales the accumulated residual at every layer, the identity “highway” is repeatedly renormalized, and the effective gradient through \(\ell\) layers can grow or shrink geometrically. This is a large part of why the original Transformer relied on a learning-rate warmup schedule to train stably. The now-standard pre-norm variant moves the normalization inside the branch,

\[ x \;\leftarrow\; x + \text{Sublayer}\big(\text{LayerNorm}(x)\big), \tag{38.6}\]

leaving the residual path in Equation 38.6 a clean sum of sublayer outputs. The gradient from the loss reaches every layer through an unobstructed identity term. As a result, pre-norm typically trains stably with little or no warmup and makes much deeper stacks practical. The cost is a larger variance in the final residual stream, which a final normalization corrects. Most modern large models use pre-norm.
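A deliberately extreme case shows the difference between the two identity paths. Suppose every sublayer outputs zero, meaning it has nothing to add. A stack of pre-norm blocks then returns its input unchanged. A stack of post-norm blocks still replaces the input with its normalized version, because the normalization sits on the residual path. The zero sublayer, the depth of 12, and the input statistics are illustrative choices.

```r
layer_norm <- function(X, eps = 1e-5) {
    mu <- rowMeans(X)
    (X - mu) / sqrt(rowMeans((X - mu)^2) + eps)
}

# Post-norm: LayerNorm(x + Sublayer(x)).  Pre-norm (Equation 38.6): x + Sublayer(LayerNorm(x)).
post_norm <- function(X, sublayer) layer_norm(X + sublayer(X))
pre_norm  <- function(X, sublayer) X + sublayer(layer_norm(X))

# A sublayer that contributes nothing.
silent <- function(Z) 0 * Z

# Run a stack of `depth` identical blocks, each wrapping the silent sublayer.
run_stack <- function(X, block, depth) {
    for (layer in seq_len(depth)) X <- block(X, silent)
    X
}

set.seed(169)
X <- matrix(rnorm(5 * 4, mean = 3, sd = 2), 5, 4)

c(pre_norm_change  = max(abs(run_stack(X, pre_norm, 12) - X)),
  post_norm_change = max(abs(run_stack(X, post_norm, 12) - X)))
```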

38.5.1 Encoder and Decoder

The encoder is a stack of these blocks using self-attention, where every position attends freely in both directions. It maps an input sequence to a sequence of context-aware representations.

The decoder adds a cross-attention sublayer that reads the encoder output, and it uses causal masking in its self-attention. Causal masking sets the scores for future positions to \(-\infty\) before the softmax, so position \(i\) can attend only to positions \(1, \dots, i\).5 This keeps the model from peeking at tokens it has not yet generated, which is what allows the decoder to generate one token at a time at inference while still training in parallel.

Key idea

Encoder self-attention is bidirectional (every token sees the whole sequence), decoder self-attention is causal (every token sees only the past). That single difference is what separates an understanding model like BERT from a generative model like GPT.

38.6 A From-Scratch Implementation in Base R

The code below implements scaled dot-product attention and a small multi-head version using only base R matrix algebra on a toy sequence of five tokens. No deep learning framework is involved. The point is to make the equations concrete and to inspect the attention weights directly.

When to use this

This implementation is for learning, not production. Real models use optimized framework layers (shown at the end of the chapter). Read it to see exactly how the formula turns into matrix operations, then reach for a framework when you build something real.

We begin with the two building blocks: a numerically safe row-wise softmax, and the attention function itself. The softmax subtracts each row’s maximum before exponentiating, a standard trick that prevents overflow without changing the result.6

# Row-wise softmax with the standard max-subtraction trick for numerical safety.
row_softmax <- function(M) {
    M <- M - apply(M, 1, max)        # subtract row max before exponentiating
    E <- exp(M)
    E / rowSums(E)
}

# Scaled dot-product attention.
# Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v
# Returns the output matrix and the attention weight matrix.
scaled_dot_product_attention <- function(Q, K, V, mask = NULL) {
    d_k <- ncol(K)
    scores <- (Q %*% t(K)) / sqrt(d_k)   # n_q x n_k raw similarities, scaled
    if (!is.null(mask)) {
        scores[mask] <- -Inf             # masked positions get -Inf before softmax
    }
    weights <- row_softmax(scores)       # each row sums to 1
    list(output = weights %*% V, weights = weights)
}

We build a tiny sequence of five positions, each with a four-dimensional embedding, and add a sinusoidal positional encoding so that order is represented.

set.seed(2017)

n <- 5    # sequence length
d <- 4    # embedding dimension

tokens <- c("the", "cat", "sat", "on", "mat")

# Random token embeddings standing in for a learned embedding table.
X_tok <- matrix(rnorm(n * d), nrow = n, ncol = d)

# Sinusoidal positional encoding, following the formula in the text.
positional_encoding <- function(n, d) {
    PE <- matrix(0, nrow = n, ncol = d)
    pos <- 0:(n - 1)
    for (i in 0:(d / 2 - 1)) {
        denom <- 10000^(2 * i / d)
        PE[, 2 * i + 1] <- sin(pos / denom)
        PE[, 2 * i + 2] <- cos(pos / denom)
    }
    PE
}

PE <- positional_encoding(n, d)
X  <- X_tok + PE          # inputs carry both content and position
rownames(X) <- tokens
round(X, 3)
#>       [,1]   [,2]   [,3]   [,4]
#> the  1.434  1.452  0.343  2.194
#> cat  0.764 -1.418  1.582  0.518
#> sat  1.648 -0.418 -0.727  2.318
#> on  -1.617 -1.255  0.337 -0.130
#> mat -0.827  0.910 -1.390  0.073

We start with single-head self-attention. The queries, keys, and values are linear projections of the same input \(X\), which is what makes this self-attention.

d_k <- 4

W_Q <- matrix(rnorm(d * d_k), d, d_k)
W_K <- matrix(rnorm(d * d_k), d, d_k)
W_V <- matrix(rnorm(d * d_k), d, d_k)

Q <- X %*% W_Q
K <- X %*% W_K
V <- X %*% W_V

attn <- scaled_dot_product_attention(Q, K, V)

A <- attn$weights
rownames(A) <- tokens
colnames(A) <- tokens
round(A, 3)
#>       the   cat   sat    on   mat
#> the 0.000 1.000 0.000 0.000 0.000
#> cat 0.052 0.001 0.000 0.002 0.944
#> sat 0.000 1.000 0.000 0.000 0.000
#> on  0.956 0.000 0.014 0.000 0.029
#> mat 0.001 0.054 0.944 0.001 0.000

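Each row of A is the attention distribution for one query token over all key tokens. Reading across a row tells you where that token “looks.” With random projections these patterns are arbitrary (a trained model would learn meaningful ones), but the mechanics are identical. Most rows are also nearly one-hot. Nothing constrains the random projections to produce unit-variance queries and keys, so even after dividing by \(\sqrt{d_k}\) the scores are spread widely, and the softmax lands in the saturated regime described in Section 38.2.1. By construction every row is a probability distribution, so the rows sum to one. We check this explicitly rather than trusting it, since a single coding slip in the softmax would break the guarantee.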

rowsum_check <- rowSums(A)
print(round(rowsum_check, 8))
#> the cat sat  on mat 
#>   1   1   1   1   1
stopifnot(all(abs(rowsum_check - 1) < 1e-8))
cat("All attention rows sum to 1.\n")
#> All attention rows sum to 1.

The output vectors are the attention-weighted blends of the value vectors, one row per token. Table 38.1 lists these output vectors for the toy sequence.

out <- attn$output
rownames(out) <- tokens
colnames(out) <- paste0("dim", seq_len(ncol(out)))

knitr::kable(
    round(out, 3),
    caption = "Self-attention output vectors, one row per input token."
)
Table 38.1: Self-attention output vectors, one row per input token.

| token | dim1 | dim2 | dim3 | dim4 |
|---|---:|---:|---:|---:|
| the | -2.871 | -1.431 | -1.750 | -0.112 |
| cat | 1.710 | 0.491 | 1.361 | 0.276 |
| sat | -2.871 | -1.431 | -1.750 | -0.112 |
| on | -2.388 | 3.451 | 1.082 | 2.614 |
| mat | -0.992 | 1.186 | 2.077 | 2.846 |

A heatmap makes the attention pattern easy to read. Figure 38.1 shows the attention weight matrix, where brighter cells mean a query token (row) places more weight on a key token (column).

library(ggplot2)

A_df <- expand.grid(query = tokens, key = tokens)
A_df$weight <- as.vector(A)   # column-major matches expand.grid ordering

# Keep token order on the axes instead of alphabetical sorting.
A_df$query <- factor(A_df$query, levels = rev(tokens))
A_df$key   <- factor(A_df$key,   levels = tokens)

ggplot(A_df, aes(x = key, y = query, fill = weight)) +
    geom_tile(color = "white", linewidth = 0.5) +
    geom_text(aes(label = sprintf("%.2f", weight)), size = 3) +
    scale_fill_viridis_c(limits = c(0, 1)) +
    labs(
        x = "Key token (attended to)",
        y = "Query token (attending)",
        fill = "Weight",
        title = "Scaled dot-product attention weights"
    ) +
    coord_fixed()
Figure 38.1: Attention weight matrix from single-head self-attention on the toy sequence. Each row is a probability distribution over the key tokens.

Now a small multi-head version. We run two heads with their own projections, each of width \(d_k = 2\), concatenate the head outputs, and mix them with an output projection. This mirrors the multi-head formula from earlier.

multi_head_attention <- function(X, h, d_k, mask = NULL) {
    d <- ncol(X)
    heads <- vector("list", h)
    weights_list <- vector("list", h)
    for (i in seq_len(h)) {
        WQ <- matrix(rnorm(d * d_k), d, d_k)
        WK <- matrix(rnorm(d * d_k), d, d_k)
        WV <- matrix(rnorm(d * d_k), d, d_k)
        a  <- scaled_dot_product_attention(X %*% WQ, X %*% WK, X %*% WV, mask)
        heads[[i]] <- a$output
        weights_list[[i]] <- a$weights
    }
    concat <- do.call(cbind, heads)          # n x (h * d_k)
    W_O <- matrix(rnorm((h * d_k) * d), h * d_k, d)
    list(output = concat %*% W_O, weights = weights_list)
}

set.seed(7)
mha <- multi_head_attention(X, h = 2, d_k = 2)

cat("Head 1 attention rows sum to 1:",
    all(abs(rowSums(mha$weights[[1]]) - 1) < 1e-8), "\n")
#> Head 1 attention rows sum to 1: TRUE
cat("Head 2 attention rows sum to 1:",
    all(abs(rowSums(mha$weights[[2]]) - 1) < 1e-8), "\n")
#> Head 2 attention rows sum to 1: TRUE

mh_out <- mha$output
rownames(mh_out) <- tokens
colnames(mh_out) <- paste0("dim", seq_len(ncol(mh_out)))
round(mh_out, 3)
#>       dim1   dim2   dim3   dim4
#> the -1.406 -4.674 -0.573 -1.912
#> cat  0.054 -2.120  0.296 -1.767
#> sat -0.154 -0.697 -1.257 -0.058
#> on   1.024  1.931 -0.848  0.558
#> mat -1.275 -1.998 -3.874  0.231

We can also see causal masking in action. The mask blocks every position from attending to later positions, which produces a lower-triangular weight matrix.

# TRUE where a query (row) should NOT see a key (column): strictly future keys.
causal_mask <- upper.tri(matrix(0, n, n))

masked <- scaled_dot_product_attention(Q, K, V, mask = causal_mask)
Am <- masked$weights
rownames(Am) <- tokens
colnames(Am) <- tokens
round(Am, 3)
#>       the   cat   sat    on mat
#> the 1.000 0.000 0.000 0.000   0
#> cat 0.978 0.022 0.000 0.000   0
#> sat 0.000 1.000 0.000 0.000   0
#> on  0.984 0.001 0.015 0.000   0
#> mat 0.001 0.054 0.944 0.001   0

The upper triangle is exactly zero, so each token attends only to itself and the tokens before it, and each row still sums to one. The first token can attend only to itself, which is why its row puts all weight on “the”. This lower-triangular pattern is precisely the constraint a generative decoder needs.

To recap the implementation: a row-wise softmax plus three matrix products is the entire core of attention, multi-head attention is the same function run several times with different projections and stitched together, and a single boolean mask turns bidirectional attention into causal attention. Everything else in a Transformer (residual connections, normalization, the feed-forward sublayer) wraps around this core.

38.7 Efficient and Variant Attention

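We flagged the quadratic cost at the start of the chapter; here is how the field addresses it. The \(O(n^2)\) cost of full attention becomes the binding constraint for long sequences: doubling the input quadruples the attention work and memory. Two broad families of approximate variants reduce it, described next. A separate line of work, exemplified by FlashAttention, keeps the exact computation and instead reorganizes it so that the \(n \times n\) matrix is never stored.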

Sparse attention restricts each query to a subset of keys instead of all of them. Patterns include local windows, strided or dilated connections, and a few global tokens that every position can reach. Limiting the number of key comparisons per query brings the cost below quadratic while preserving the ability to reach distant context through the global tokens.
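A local-window pattern is the simplest sparse scheme to write down. The sketch masks every key farther than \(w\) positions from the query, so each query uses at most \(2w + 1\) keys. It still builds the full \(n \times n\) score matrix, so it shows only the pattern, not the savings. An efficient implementation would compute just the band. The window size and random inputs are illustrative, and global tokens are omitted.

```r
softmax_rows <- function(M) { M <- M - apply(M, 1, max); E <- exp(M); E / rowSums(E) }

set.seed(396)
n <- 8; d_k <- 4; w <- 1                      # query i sees keys j with |i - j| <= w
Q <- matrix(rnorm(n * d_k), n, d_k)
K <- matrix(rnorm(n * d_k), n, d_k)
V <- matrix(rnorm(n * d_k), n, d_k)

# TRUE where query i must NOT see key j (outside the local window).
outside_window <- abs(outer(seq_len(n), seq_len(n), "-")) > w

scores <- Q %*% t(K) / sqrt(d_k)
scores[outside_window] <- -Inf               # excluded keys get zero weight after softmax
A <- softmax_rows(scores)
O <- A %*% V

rowSums(A > 0)                                # keys actually used by each query
```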

Linear attention rewrites the computation to avoid forming the full \(n \times n\) matrix. By replacing the softmax with a feature map \(\phi(\cdot)\) and reordering the products as \(\phi(Q)\,(\phi(K)^\top V)\), the cost becomes linear in \(n\). The approximation trades some expressiveness for scalability on long inputs.

38.7.0.1 Deriving the linear-attention reordering

The key algebraic fact is associativity of matrix multiplication. Suppose the exponential kernel is replaced by a separable surrogate \(\exp(q_i^\top k_j) \approx \phi(q_i)^\top \phi(k_j)\) for some nonnegative feature map \(\phi: \mathbb{R}^{d_k} \to \mathbb{R}^{r}\). The \(1/\sqrt{d_k}\) factor can be absorbed into the inputs of \(\phi\). The normalized output at position \(i\), written as a row vector, is then

\[ o_i = \sum_{j=1}^{n} \frac{\phi(q_i)^\top \phi(k_j)}{\sum_{m=1}^{n} \phi(q_i)^\top \phi(k_m)}\, v_j = \frac{\phi(q_i)^\top \left( \sum_{j=1}^{n} \phi(k_j)\, v_j^\top \right)} {\phi(q_i)^\top \left( \sum_{m=1}^{n} \phi(k_m) \right)} . \tag{38.7}\]

The two sums in Equation 38.7 do not depend on \(i\). The numerator sum \(\sum_j \phi(k_j) v_j^\top \in \mathbb{R}^{r \times d_v}\) costs \(O(n r d_v)\), and the denominator sum \(\sum_m \phi(k_m) \in \mathbb{R}^{r}\) costs \(O(n r)\). Each is formed once, after which every query is answered with one \(r\)-dimensional vector-matrix product in \(O(r d_v)\). Total cost is \(O(n r d_v)\) time and \(O(r d_v)\) memory, linear in \(n\) rather than quadratic, with no \(n \times n\) matrix ever materialized. The price is that \(\phi(q)^\top \phi(k)\) is a rank-\(r\) similarity: it can only represent kernels that factor through an \(r\)-dimensional feature space. The exact softmax (an infinite-dimensional RBF-type kernel) is therefore recovered only in the limit \(r \to \infty\). For causal (decoder) attention the running sums become prefix sums. This makes linear attention equivalent to a linear recurrent state of fixed size \(r \times d_v\), recovering the constant-memory streaming behavior of an RNN while keeping parallel training.
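Both claims can be checked in base R: that the reordering is exact, and that the causal version is a fixed-size recurrence. The sketch uses \(\phi(x) = \text{elu}(x) + 1\), one positive feature map used in the linear-attention literature, chosen here only for illustration. It computes the output three ways:

- with the full \(n \times n\) kernel matrix;
- with the reordered sums of Equation 38.7;
- causally, with running sums updated one token at a time.

The causal result is compared with the masked quadratic version.

```r
# elu(x) + 1: a positive feature map (an illustrative choice), applied elementwise.
phi <- function(X) ifelse(X > 0, X + 1, exp(X))

set.seed(406)
n <- 6; d_k <- 3; d_v <- 2
Q <- matrix(rnorm(n * d_k), n, d_k)
K <- matrix(rnorm(n * d_k), n, d_k)
V <- matrix(rnorm(n * d_v), n, d_v)
FQ <- phi(Q); FK <- phi(K)                    # n x r feature matrices (here r = d_k)

# Quadratic form: materialize the n x n kernel matrix, normalize rows, blend values.
Kmat <- FQ %*% t(FK)
O_quadratic <- (Kmat / rowSums(Kmat)) %*% V

# Linear form (Equation 38.7): both key sums are formed once.
KV <- t(FK) %*% V                             # r x d_v: sum_j phi(k_j) v_j^T
z  <- colSums(FK)                             # length r: sum_m phi(k_m)
O_linear <- (FQ %*% KV) / as.vector(FQ %*% z) # divide row i by phi(q_i)^T z

# Causal form: prefix sums held in a fixed-size state (r x d_v matrix plus r vector).
state_KV <- matrix(0, ncol(FK), d_v)
state_z  <- numeric(ncol(FK))
O_causal <- matrix(0, n, d_v)
for (i in seq_len(n)) {
    state_KV <- state_KV + FK[i, ] %o% V[i, ] # add phi(k_i) v_i^T
    state_z  <- state_z + FK[i, ]
    O_causal[i, ] <- (FQ[i, ] %*% state_KV) / sum(FQ[i, ] * state_z)
}

# Reference for the causal form: quadratic version with future keys zeroed out.
K_causal <- Kmat * lower.tri(Kmat, diag = TRUE)
O_causal_ref <- (K_causal / rowSums(K_causal)) %*% V

c(linear_vs_quadratic = max(abs(O_linear - O_quadratic)),
  causal_vs_reference = max(abs(O_causal - O_causal_ref)))
```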

The complexity tradeoffs of the main families are summarized in Table 38.2.

Table 38.2: Cost of attention variants for sequence length \(n\), width \(d\), window \(w\), and feature dimension \(r\). FlashAttention keeps the exact result but tiles the computation so the \(n \times n\) matrix is never stored.

| Mechanism | Time | Memory | Exact softmax? |
|---|---|---|---|
| Full attention | \(O(n^2 d)\) | \(O(n^2)\) | yes |
| FlashAttention | \(O(n^2 d)\) | \(O(n)\) | yes |
| Sparse (window \(w\)) | \(O(n w d)\) | \(O(n w)\) | no |
| Linear / kernel | \(O(n r d)\) | \(O(r d)\) | no |
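To make the memory column of Table 38.2 concrete, the arithmetic below sizes what each variant stores for one head and one sequence. It assumes 4-byte (float32) entries and the example sizes \(n = 16384\), \(w = 512\), \(r = 256\), \(d = 64\); all of these values are hypothetical.

```r
# Back-of-envelope storage for ONE head and ONE sequence.
# Assumptions: float32 entries (4 bytes); example sizes chosen for illustration.
bytes_per <- 4
n <- 16384; w <- 512; r <- 256; d <- 64

mem <- c(
    full   = n * n * bytes_per,   # n x n score matrix, O(n^2)
    window = n * w * bytes_per,   # n x w banded scores, O(n w)
    linear = r * d * bytes_per    # r x d summary state, O(r d)
)
mem / 2^20                        # convert bytes to MiB
```

Under these assumptions the full matrix needs 1 GiB per head, the window 32 MiB, and the linear state 64 KiB. FlashAttention is left out of the arithmetic because its saving comes from never storing the full matrix, not from computing a smaller one.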
Warning

Sub-quadratic attention is an approximation, not a free lunch. Each variant assumes the important interactions follow a particular structure (local windows, a low-rank pattern, a chosen feature map). When that assumption fits the data it scales beautifully; when it does not, accuracy drops. Match the variant to the structure you actually expect in your sequences.

38.8 Choosing the Hyperparameters in Practice

The architecture exposes a small set of dials, and a few rules of thumb cover most cases. The model width \(d\) (also called \(d_{\text{model}}\)) and depth (number of blocks \(L\)) are the primary capacity knobs. Each block's parameter count is dominated by \(4 d^2\) in the attention sublayer (the query, key, value, and output projections) plus \(2 \cdot 4 d^2 = 8 d^2\) in the feed-forward sublayer (a \(d \times 4d\) and a \(4d \times d\) matrix), about \(12 d^2\) in total, so the stack grows like \(12 L d^2\). The number of heads \(h\) should divide \(d\) so that \(d_k = d/h\) is an integer; typical values keep \(d_k\) in the range \(32\) to \(128\), since heads that are too small lose per-head expressiveness and heads that are too large defeat the purpose of having several. The feed-forward inner width is almost always \(4d\). In practice these values are rarely tuned individually; they are inherited from a known configuration at the target scale.
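The counting rule is easy to encode. The sketch below (helper names hypothetical) counts weight-matrix entries only, ignoring biases, layer-norm parameters, and embeddings, and rejects a head count that does not divide \(d\). The example sizes \(d = 512\), \(h = 8\), \(L = 6\) match the base model of Vaswani et al. (2017).

```r
# Rough weight count per Transformer block, ignoring biases, layer-norm
# parameters, and embeddings. Assumption: feed-forward width ffn_mult * d.
block_params <- function(d, ffn_mult = 4) {
    attn <- 4 * d^2               # query, key, value, output: each d x d
    ffn  <- 2 * ffn_mult * d^2    # d x 4d followed by 4d x d
    attn + ffn
}

# Per-head width d_k; stops if h does not divide d.
check_heads <- function(d, h) {
    if (d %% h != 0) stop("number of heads h must divide d")
    d %/% h
}

d <- 512; h <- 8; L <- 6
check_heads(d, h)                 # per-head width d_k
L * block_params(d)               # total weights in a six-block stack
```

With these sizes each block holds \(12 d^2 = 3{,}145{,}728\) weights, \(d_k = 64\), and the six-block stack holds about 18.9 million weights before embeddings.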

For diagnostics, the attention weight matrix is directly inspectable, exactly as in the toy example above. Two pathologies are worth watching. Attention collapse, where every row concentrates on one or two positions (often a leading or delimiter token acting as a no-op sink), shows up as near-one-hot rows and signals that those heads have stopped routing information; the softmax-Jacobian argument above explains why such heads also receive almost no gradient. Rank collapse, where repeated self-attention without the feed-forward and residual branches drives all token representations toward a single vector, is the theoretical reason the feed-forward sublayer and residual connections are not optional: they are what keep the representations from degenerating as depth grows. A practical monitor is the entropy of each attention row, \(H_i = -\sum_j A_{ij} \log A_{ij}\), which should sit comfortably between zero (collapsed) and \(\log n\) (uniform, no selectivity).
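The entropy monitor takes only a few lines of base R. The helper `row_entropy` (a hypothetical name) uses the convention \(0 \log 0 = 0\). The three test matrices are a uniform pattern, a collapsed pattern in which every row attends only to token 1 (mimicking a sink), and softmax rows of random scores.

```r
# Row entropies H_i = -sum_j A_ij log A_ij for an attention matrix A whose
# rows sum to one. Convention: 0 * log(0) = 0.
row_entropy <- function(A) {
    terms <- ifelse(A > 0, A * log(A), 0)
    -rowSums(terms)
}

# Row-wise softmax, subtracting each row's maximum for numerical stability.
softmax_rows <- function(S) {
    E <- exp(S - apply(S, 1, max))
    E / rowSums(E)
}

n <- 5
uniform   <- matrix(1 / n, n, n)                    # no selectivity
collapsed <- matrix(0, n, n); collapsed[, 1] <- 1   # every row on token 1
set.seed(2)
typical   <- softmax_rows(matrix(rnorm(n * n), n))  # random scores

rbind(uniform   = row_entropy(uniform),
      collapsed = row_entropy(collapsed),
      typical   = row_entropy(typical))
log(n)                                              # upper bound (uniform)
```

The uniform rows reach the upper bound \(\log 5 \approx 1.609\), the collapsed rows sit at zero, and the random-score rows fall strictly in between.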

Where the compute goes

For short sequences the cost is dominated by the feed-forward and projection matrices (\(O(n d^2)\)), not the attention matrix (\(O(n^2 d)\)). Attention only becomes the bottleneck once \(n \gtrsim d\). Below that crossover, a sub-quadratic attention variant buys little because the projections cost the same either way, so profile before optimizing.
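The crossover can be estimated by counting multiply-adds per layer. The sketch assumes the shapes used in this chapter (four \(d \times d\) projections, feed-forward width \(4d\), all heads combined) and ignores the softmax, normalization, and memory traffic, so it is an order-of-magnitude guide rather than a substitute for profiling.

```r
# Approximate multiply-adds per layer for one sequence of length n.
# Assumptions: 4 projections of size d x d, FFN width 4d, all heads combined;
# softmax, layer norm, and memory traffic are ignored.
proj_cost <- function(n, d) 12 * n * d^2   # Q, K, V, O (4 n d^2) + FFN (8 n d^2)
attn_cost <- function(n, d) 2 * n^2 * d    # Q K^T (n^2 d) + A V (n^2 d)

d <- 512
n <- c(128, 512, 2048, 3072, 8192)
data.frame(n = n,
           attn_share = attn_cost(n, d) / (attn_cost(n, d) + proj_cost(n, d)))

# attn_cost / proj_cost = n / (6 d), so the two are equal at n = 6 d
crossover <- 6 * d
```

Under these counts the attention share reaches one half at \(n = 6d\) (3072 for \(d = 512\)), the same order as \(d\) and consistent with the rule of thumb above.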

The variance argument behind the \(1/\sqrt{d_k}\) scaling is easy to confirm by simulation: the standard deviation of unscaled dot products should grow like \(\sqrt{d_k}\), and dividing by \(\sqrt{d_k}\) should restore unit scale.

set.seed(1)
check_sd <- function(d_k, reps = 20000) {
    s <- replicate(reps, sum(rnorm(d_k) * rnorm(d_k)))  # q^T k, unit-variance entries
    c(raw_sd = sd(s), scaled_sd = sd(s / sqrt(d_k)))
}
dims <- c(4, 16, 64, 256)
tab <- t(sapply(dims, check_sd))
rownames(tab) <- paste0("d_k=", dims)
round(tab, 3)
#>         raw_sd scaled_sd
#> d_k=4    2.010     1.005
#> d_k=16   4.017     1.004
#> d_k=64   8.070     1.009
#> d_k=256 16.089     1.006

The raw_sd column tracks \(\sqrt{d_k}\) (about 2, 4, 8, 16) while scaled_sd stays near one, which is exactly the invariance the scaling factor is designed to enforce.
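Restoring unit scale matters because of what it does to the softmax. The companion sketch below draws one random query against ten random keys and records the largest softmax weight with raw and with scaled scores, averaged over repetitions. The number of keys, the repetitions, and the seed are arbitrary choices.

```r
# Largest softmax weight with raw vs scaled scores, averaged over reps.
# Assumptions: one query vs n_keys keys, unit-variance Gaussian entries.
set.seed(3)
softmax <- function(s) { e <- exp(s - max(s)); e / sum(e) }

mean_max_weight <- function(d_k, n_keys = 10, reps = 2000) {
    out <- replicate(reps, {
        q <- rnorm(d_k)
        K <- matrix(rnorm(n_keys * d_k), n_keys)
        s <- as.vector(K %*% q)                    # raw scores q^T k_j
        c(raw = max(softmax(s)), scaled = max(softmax(s / sqrt(d_k))))
    })
    rowMeans(out)
}

dims <- c(4, 16, 64, 256)
sat <- t(sapply(dims, mean_max_weight))
rownames(sat) <- paste0("d_k=", dims)
round(sat, 3)
```

As \(d_k\) grows, the raw column should approach one (near one-hot rows, the saturated regime in which the softmax passes back almost no gradient), while the scaled column should stay at roughly the same moderate level for every \(d_k\).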

38.9 Connection to BERT

The encoder stack described here is exactly what BERT is built from. BERT uses the bidirectional self-attention encoder with no causal mask, so every token attends to its full left and right context, and it is pretrained on large text corpora before being fine-tuned for downstream prediction tasks. The BERT chapter (Chapter 39) develops that pretraining and fine-tuning story in detail and shows how to use the resulting contextual embeddings as features for prediction.

When you reach for a deep learning framework, you will not write the matrix algebra by hand. Attention and Transformer blocks are available as ready-made layers, and your job is to wire them together. The snippet below sketches what a single Transformer block looks like in Keras, and you should be able to recognize every piece from this chapter: a multi-head attention layer, a residual add, layer normalization, the feed-forward sublayer, and a second residual-plus-norm. It is illustrative only and is not run here, since it requires the Keras/TensorFlow backend, which this book does not build against.

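# Illustrative only. Requires TensorFlow/Keras, which is not run in this book.
library(keras)

seq_len <- 128   # example sequence length (placeholder value)
d_model <- 512   # example model width; must be divisible by num_heads

inputs <- layer_input(shape = c(seq_len, d_model))

attn_out <- layer_multi_head_attention(
    num_heads = 8,
    key_dim   = d_model %/% 8                  # per-head width d_k = d / h
)(inputs, inputs)                              # self-attention: query = key = value (key defaults to value)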

x <- layer_add(list(inputs, attn_out)) %>%     # residual connection
    layer_layer_normalization()

ffn <- x %>%
    layer_dense(units = 4 * d_model, activation = "relu") %>%
    layer_dense(units = d_model)

block_out <- layer_add(list(x, ffn)) %>%       # second residual connection
    layer_layer_normalization()

model <- keras_model(inputs, block_out)
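Without a Keras installation, the same block can still be traced in base R. The sketch below mirrors the Keras wiring (self-attention, residual add and layer normalization, feed-forward sublayer, second residual add and normalization) with random, untrained weights. It assumes no biases, no dropout, layer-norm gain and shift fixed at 1 and 0, and ReLU in the feed-forward sublayer; the helper names (`layer_norm_rows`, `transformer_block`, and so on) are hypothetical and are not Keras functions.

```r
# Base-R forward pass of one post-norm Transformer encoder block.
# Assumptions: random untrained weights, no biases, no dropout,
# layer-norm gain/shift fixed at 1/0, ReLU feed-forward of width 4 * d_model.
set.seed(4)
n <- 5; d_model <- 16; h <- 4; d_k <- d_model / h

softmax_rows <- function(S) { E <- exp(S - apply(S, 1, max)); E / rowSums(E) }

layer_norm_rows <- function(X, eps = 1e-5) {
    mu <- rowMeans(X)
    v  <- rowMeans((X - mu)^2)
    (X - mu) / sqrt(v + eps)                       # normalize each token (row)
}

init <- function(rows, cols) matrix(rnorm(rows * cols, sd = 1 / sqrt(rows)), rows, cols)
W_Q <- init(d_model, d_model); W_K <- init(d_model, d_model)
W_V <- init(d_model, d_model); W_O <- init(d_model, d_model)
W_1 <- init(d_model, 4 * d_model); W_2 <- init(4 * d_model, d_model)

multi_head_attention <- function(X) {
    Q <- X %*% W_Q; K <- X %*% W_K; V <- X %*% W_V
    heads <- lapply(seq_len(h), function(i) {
        cols <- ((i - 1) * d_k + 1):(i * d_k)      # this head's slice
        A <- softmax_rows(Q[, cols] %*% t(K[, cols]) / sqrt(d_k))
        A %*% V[, cols]
    })
    do.call(cbind, heads) %*% W_O                  # concatenate heads, project
}

transformer_block <- function(X) {
    X   <- layer_norm_rows(X + multi_head_attention(X))  # residual + norm
    ffn <- pmax(X %*% W_1, 0) %*% W_2                    # ReLU feed-forward
    layer_norm_rows(X + ffn)                             # second residual + norm
}

X <- matrix(rnorm(n * d_model), n, d_model)        # 5 tokens, width 16
Y <- transformer_block(X)
dim(Y)
```

The output keeps the input shape, one row per token, and each row has mean zero and (population) variance close to one because the block ends in a layer normalization. Stacking blocks is just repeated calls to `transformer_block` with fresh weights.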

38.10 Further Reading

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems. This paper introduced the Transformer and scaled dot-product attention.

Bahdanau, D., Cho, K., and Bengio, Y. (2015). “Neural Machine Translation by Jointly Learning to Align and Translate.” International Conference on Learning Representations. This earlier work introduced the attention mechanism in the context of recurrent encoder-decoder translation.


  1. A token is one unit of the input sequence, for example a word or a sub-word piece in text, or a time step in a time series.

  2. The softmax of a row \(s\) has entries \(\text{softmax}(s)_j = e^{s_j} / \sum_m e^{s_m}\), which are non-negative and sum to one, so each row is a genuine probability distribution.

  3. Permutation equivariant means that reordering the inputs reorders the outputs in the same way without changing their values. Contrast with permutation invariant, where reordering leaves the output completely unchanged.

  4. A residual connection adds the sublayer’s input back to its output, so the sublayer only has to learn a correction to the identity rather than the whole mapping. This is the same trick that made very deep image networks trainable.

  5. Why \(-\infty\) and not zero? The mask is applied to the scores before the softmax. Exponentiating \(-\infty\) gives exactly zero weight, while a raw score of zero would still receive non-trivial weight after the softmax.

  6. Subtracting a constant \(c\) from every score in a row leaves the softmax unchanged, because the factor \(e^{-c}\) cancels between numerator and denominator. Choosing \(c\) to be the row maximum keeps every exponent at most zero, so \(e^{\cdot}\) never overflows.