Feb 1, 2026
Here’s where we’re at:
We have an input \(X \in \R^{n \times d}\), where \(n\) is the sequence length, and \(d\) is the embedding dimension:
\[ X := \begin{bmatrix} t_{11} & t_{12} & \cdots & t_{1d} \\ t_{21} & t_{22} & \cdots & t_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ t_{n1} & t_{n2} & \cdots & t_{nd} \end{bmatrix} \]
Here \(t_{ij}\) means the \(j^{th}\) component of the \(i^{th}\) position in the input.
And up to now, we basically have one tool at our disposal: the humble MLP. The problem is that the MLP has type \(\R^a \to \R^b\)… i.e. it takes vector input, not matrix input.
The naive solution is to simply apply the MLP row-wise to each position:
\[ \begin{bmatrix} t_{11} & t_{12} & \cdots & t_{1d} \\ t_{21} & t_{22} & \cdots & t_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ t_{n1} & t_{n2} & \cdots & t_{nd} \end{bmatrix} \mapsto \begin{bmatrix} \mathrm{MLP}(t_{11}, t_{12}, \dots, t_{1d}) \\ \mathrm{MLP}(t_{21}, t_{22}, \dots, t_{2d}) \\ \vdots \\ \mathrm{MLP}(t_{n1}, t_{n2}, \dots, t_{nd}) \end{bmatrix} \]
Of course this doesn’t actually work. We don’t even need to test it: even if each row were sent through a different MLP, each output would only ever see one position. It’s not even sequence modeling. It’s literally just an expensive bigram model, predicting each next token from the current token alone.
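Here is a quick numerical check of that claim. It is a sketch with a small random two-layer ReLU MLP applied row-wise. The sizes are illustrative, and it uses NumPy instead of PyTorch so it stands alone. Perturbing one row of the input changes only that row's output, so no position ever sees another.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, d_hidden = 8, 16, 32  # illustrative sizes, not the post's hyperparameters

# one shared 2-layer ReLU MLP, R^d -> R^d
W1, b1 = rng.normal(size=(d, d_hidden)), np.zeros(d_hidden)
W2, b2 = rng.normal(size=(d_hidden, d)), np.zeros(d)

def mlp_rowwise(X):
    # X @ W1 applies the same weights to every row independently
    return np.maximum(X @ W1 + b1, 0.0) @ W2 + b2

X = rng.normal(size=(n, d))
X_perturbed = X.copy()
X_perturbed[3] += 1.0  # change only position 3

# which output rows differ between the two inputs?
changed = ~np.all(np.isclose(mlp_rowwise(X), mlp_rowwise(X_perturbed)), axis=1)
print(np.flatnonzero(changed))
```

Only index 3 shows up. The other rows are computed from identical inputs with identical weights, so they cannot change.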
So then we try another approach: simply concatenate all the rows \(t_1, t_2, \dots, t_n\) of \(X\) into one huge input vector and feed that in:
\[ \begin{bmatrix} t_{11} & t_{12} & \cdots & t_{1d} \\ t_{21} & t_{22} & \cdots & t_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ t_{n1} & t_{n2} & \cdots & t_{nd} \end{bmatrix} \mapsto \mathrm{MLP}(\mathrm{concat}(t_1, t_2, \dots, t_n)) \]
In theory this can work! But (as we’ve seen in the last few posts in the series) it hits a wall pretty quickly in practice. The MLP-concat-net is nice because it has about the least inductive bias we could reasonably build into a model, but it performs poorly.
It’s also grossly inefficient. A layer costs \(O((nd)^2)\), compared to \(O(n \cdot d^2)\) for the row-wise MLP, assuming hidden layers at least as wide as the input.
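To put rough numbers on that, here is a back-of-the-envelope sketch in plain Python. It counts multiply-adds for one square layer, borrowing \(n = 32\) and \(d = 256\) from the hyperparameters used later in the post. The function names are hypothetical.

```python
def concat_layer_macs(n: int, d: int) -> int:
    # one dense (nd) x (nd) layer over the flattened input
    return (n * d) ** 2

def rowwise_layer_macs(n: int, d: int) -> int:
    # one shared d x d layer, applied to each of the n rows
    return n * d ** 2

n, d = 32, 256
print(concat_layer_macs(n, d))   # 67108864
print(rowwise_layer_macs(n, d))  # 2097152
print(concat_layer_macs(n, d) // rowwise_layer_macs(n, d))  # 32, i.e. a factor of n
```

The parameter gap is even larger: \((nd)^2\) weights versus \(d^2\) shared ones, a factor of \(n^2\).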
So where do we stand? We know concatenated input hits a wall, and we know that row-wise MLP is nicer, but doesn’t work out-of-the-box because it can’t take in context information.
What if there was a way to ‘bake-in’ context to each position, before we ran the MLP? Something like:
\[ \mathrm{MLP}(\varphi(X)) \]
Where here \(\mathrm{MLP}\) means row-wise MLP (the same function applied to each input position/token), and \(\varphi : \R^{n \times d} \to \R^{n \times d}\) is some magical function that does this ‘baking-in’ of previous context. So row \(i\) of the output, \(\varphi_i(X)\), incorporates the information given by tokens \(t_1, t_2, \dots, t_i\), i.e. the token at position \(i\) and everything before it.
Well, that’d be great but how do we do this?
First we’re going to need to mix the position information somehow. Note that if \(A\) is some \((n\times n)\) learned weight matrix, then the expression \(AX\) accomplishes this by mixing the rows of \(X\).
So first write \(\varphi(X) = A X\)
Next, observe that if \(A\) were a static weight matrix, this would amount to setting position \(i\) to some fixed combination of the other positions, regardless of their contents. In other words, \(A\) needs to be a function of \(X\), so set \(A = XB\) for some \(B \in \R^{d\times n}\). Then \(\varphi(X) = (XB)X\)
Again, observe that if \(B\) were a static weight matrix, the entry \(A_{ij} = X_i \cdot B_j\) (row \(i\) of \(X\) dotted with column \(j\) of \(B\)) depends on the content at position \(i\) but not position \(j\). We need \(A_{ij}\) to depend on both tokens, so \(B\) must be data-dependent too.
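The next sketch makes that concrete. It uses NumPy, with a random matrix standing in for a static learned \(B\). Replacing the token at position \(j\) changes only row \(j\) of \(A = XB\). The score \(A_{ij}\) that position \(i\) assigns to position \(j\) (for \(i \neq j\)) is blind to what token \(j\) actually is.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4
B = rng.normal(size=(d, n))  # a *static* matrix (random stand-in for learned weights)
X = rng.normal(size=(n, d))

j = 2
X_new = X.copy()
X_new[j] = rng.normal(size=d)  # replace the token at position j

A, A_new = X @ B, X_new @ B
diff_rows = np.flatnonzero(~np.all(np.isclose(A, A_new), axis=1))
print(diff_rows)  # only row j of A changes

# column j, excluding the diagonal entry, is unchanged: A[i, j] ignores token j
print(np.allclose(np.delete(A[:, j], j), np.delete(A_new[:, j], j)))  # True
```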
Since \(B\) must be data-dependent and live in \(\R^{d \times n}\), the simplest choice is \(X^T\).
Then we have \(\varphi(X) = (XX^T) X\).
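Row by row, \((XX^T)X\) sets row \(i\) to \(\sum_j (x_i \cdot x_j)\, x_j\), where \(x_i\) is the \(i^{th}\) row of \(X\). In other words, every position becomes a similarity-weighted sum of all positions. This small NumPy sketch checks the identity on random data.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3
X = rng.normal(size=(n, d))

phi = (X @ X.T) @ X  # (n, n) similarity scores times (n, d) tokens

# the same computation, spelled out as an explicit loop over rows
phi_loop = np.zeros_like(X)
for i in range(n):
    for j in range(n):
        phi_loop[i] += (X[i] @ X[j]) * X[j]

print(np.allclose(phi, phi_loop))  # True
```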
This does accomplish a thorough mixing of position information, but consider what this expression is really saying:
“We mix the tokens of \(X\) according to their similarity scores.” But this amounts to a questionable choice of inductive bias. We want the model to decide how to mix them. Currently the only way it can do so is by learning an incredibly nice embedding space, one in which the \((XX^T) X\) mixer happens to accomplish something useful.
Fortunately, there’s a clear fix: let the model project each term into a new space with learned matrices, so that the model can learn to combine the position information as it sees fit.
So let \(W_K, W_Q, W_V\) be our projection matrices, each in \(\R^{d \times d}\), and we have (making sure the types match):
\(\varphi(X) = (X W_Q)(W_K^T X^T) X W_V = (X W_Q)(X W_K)^T X W_V\)
We choose the subscripts because the post-hoc explainer goes something like this:
(for notation’s sake let \(Q := (X W_Q)\), \(K := X W_K\), and \(V := X W_V\))
The term \(QK^T\) is responsible for deciding which rows of \(V\) get combined, i.e. which tokens actually need to influence each other. If the “keys” (\(K\)) and “queries” (\(Q\)) match, they get high scores in that term. The term \(V\) represents the direction, or “value”, in which each token should push the others.
This is the version you’ve probably heard about or seen in the original paper. Personally, I prefer to think of it as just: \((XX^T)X\) is the nice mixer architecture, and we need to put learned projection matrices on each term to let the model decide how to use it.
There are two problems with this that I’ve ignored till now.
First: In the current formulation, the network is cheating.
Suppose \(QK^T\) is NOT lower-triangular (it will not be if the model has any say lol).
Think about what, for instance, a \(1\) in the first row and \(n^{th}\) column will do: it will pull the information from the \(n^{th}\) row of the value matrix into the first row of the output.
In other words, the model can use tokens from late in the sequence to predict tokens early in the sequence. But the whole task we’re trying to do is use tokens \(t_1, t_2, \dots, t_i\) to predict token \(t_{i+1}\), at every position \(i\).
So what we need to do is mask out the strictly upper-triangular part of \(QK^T\) (everything above the diagonal). The full expression becomes:
\(\varphi(X) = \mathrm{mask}(QK^T)V\)
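Here is a sketch of what the mask buys us, in NumPy with random stand-ins for the learned projections. With no softmax yet, masking simply zeroes the scores above the diagonal; here that is done with `np.tril`. Perturbing the token at (0-indexed) position 4 then leaves every earlier output untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))  # random stand-ins

def masked_mixer(X):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = np.tril(Q @ K.T)  # zero out entries above the diagonal (future positions)
    return scores @ V

X = rng.normal(size=(n, d))
X_new = X.copy()
X_new[4] += 1.0  # change a "future" token

out, out_new = masked_mixer(X), masked_mixer(X_new)
print(np.allclose(out[:4], out_new[:4]))  # True: positions 0..3 can't see position 4
print(np.allclose(out[4:], out_new[4:]))  # False: positions 4 and 5 can
```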
Now, I’m going to break the pedagogical purity of “we only do things that seem motivated” and tweak the magic mixer formula one more time. This time, we’ll have:
\(\varphi(X) = \mathrm{softmax}(\mathrm{mask}(QK^T))V\)
This is because:
Hopefully you can buy that the softmax is, at minimum, not a harmful inductive bias that just makes debugging easier, and at maximum, that it’s super nice for the optimizer, and so we’re doing it in either case.
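One concrete property is easy to check. Mask with \(-\infty\) instead of \(0\) before the softmax, as the code later in the post does. Each row then becomes a probability distribution over the positions up to and including its own, with exactly zero weight on the future. This NumPy sketch uses random scores.

```python
import numpy as np

def causal_softmax(scores):
    n = scores.shape[-1]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True strictly above the diagonal
    masked = np.where(future, -np.inf, scores)
    masked = masked - masked.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(masked)  # exp(-inf) == 0
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
S = causal_softmax(rng.normal(size=(5, 5)))
print(np.allclose(S.sum(axis=-1), 1.0))        # True: each row is a convex combination
print(np.all(S[np.triu_indices(5, k=1)] == 0))  # True: no weight on future positions
```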
Great. But now there’s one more problem:
The model doesn’t have access to positional information.
This is a subtle point. When you pass an input through the model, the first row of \(X\) gets transformed into the first row of the output; i.e. the matmul isn’t like ‘scrambling’ the input or anything. And the first token, in some sense, “knows” it’s the first token because the rest of the tokens are masked in the attention matrix.
The issue is that the model can’t actually use position as a feature in the attention computation. Think about it: every position gets put through the same learned projection matrix, so the actual machinery of the attention mixer can’t reason about position.
Think of it this way:
Token “cat” at position 2 will produce the same output from e.g. \(X W_Q\) as the token “cat” at position, say, 10. To the learned matrices, tokens look the same no matter where they show up in the sequence. Each row gets the same projection, no matter where it is in the input.
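The following NumPy sketch illustrates the point. The embedding table, the positional matrix and \(W_Q\) are random stand-ins, and token id 7 is a hypothetical stand-in for “cat”. Without a positional term, the same token at positions 2 and 10 gets an identical query. Adding a positional matrix before the projection breaks the tie.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, n, d = 10, 12, 8
E = rng.normal(size=(vocab_size, d))  # token embedding table (random stand-in)
W_pos = rng.normal(size=(n, d))       # positional matrix (random stand-in)
W_Q = rng.normal(size=(d, d))

cat = 7                               # hypothetical token id for "cat"
tokens = rng.integers(0, vocab_size, size=n)
tokens[2] = tokens[10] = cat          # "cat" at positions 2 and 10

X = E[tokens]
Q = X @ W_Q
print(np.allclose(Q[2], Q[10]))       # True: same query regardless of position

Q_pos = (X + W_pos) @ W_Q
print(np.allclose(Q_pos[2], Q_pos[10]))  # False: position now changes the query
```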
The simplest fix is to add some learned \(n \times d\) matrix to \(X\) before it reaches the rest of the network, so the model has a chance to bake in positional information.
So from now on, when we say \(X\), we implicitly mean \(X + W_{pos}\) where the latter term is a learned weight matrix.
Before we speculate on further changes to our new attention mechanism, let’s try it. The core code is here:
```python
import math

import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, d, d_hidden, d_out, n_layers):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.extend([nn.Linear(d, d_hidden), nn.ReLU()])
            d = d_hidden
        layers.append(nn.Linear(d_hidden, d_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, n, d):
        super().__init__()
        self.n = n
        self.d = d
        self.sqrt_d = math.sqrt(d)  # not used until we add score scaling
        self.last_entropy = float('-inf')
        self.W_Q = nn.Parameter(torch.randn(d, d))
        self.W_K = nn.Parameter(torch.randn(d, d))
        self.W_V = nn.Parameter(torch.randn(d, d))

    def forward(self, x: torch.Tensor):
        n, sqrt_d = self.n, self.sqrt_d
        W_Q, W_K, W_V = self.W_Q, self.W_K, self.W_V
        # causal attention: scores above the diagonal (future positions) -> -inf -> weight 0
        attn = (torch.masked_fill(
            (x @ W_Q) @ (x @ W_K).transpose(-2, -1),
            torch.triu(
                torch.ones(n, n, dtype=torch.bool, device=x.device),
                diagonal=1
            ),
            float('-inf')
        )).softmax(dim=-1)
        # track attention entropy (mean over rows, in nats) for debugging
        self.last_entropy = -(attn * (attn + 1e-10).log()).sum(dim=-1).mean().item()
        return attn @ (x @ W_V)


class AttnNet(nn.Module):
    def __init__(self, tokenizer, n, d, mlp_dh, mlp_depth):
        super().__init__()
        self.n = n
        self.d = d
        self.d_h = mlp_dh
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        self.W_pos = nn.Parameter(torch.randn(n, d))
        self.embedding = nn.Embedding(self.vocab_size, d)
        self.attention = Attention(n, d)
        self.mlp = MLP(d, mlp_dh, self.vocab_size, mlp_depth)

    # x: (b, n) token ids |-> (b, n, vocab_size) logits  (b = batch dimension)
    def forward(self, x: torch.Tensor):
        x = self.embedding(x) + self.W_pos  # (b,n) |-> (b,n,d)
        x = self.attention(x)               # (b,n,d) |-> (b,n,d)
        return self.mlp(x)                  # (b,n,d) |-> (b,n,vocab_size)
```

And our hyperparams:
```python
vocab_size = 2048
ctx_len = 32
batch_size = 4096
batch_log_interval = 100
grad_log_interval = 500
batch_val_interval = 1000
embedding_dim = 256
mlp_dh_coeff = 8
mlp_dh = embedding_dim * mlp_dh_coeff
mlp_depth = 2

# `model` is the AttnNet built from the values above; `optim` is torch.optim
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.05,
    nesterov=True,
    momentum=0.9
)
```
And the results are… bad. Val perplexity is in the hundreds, generation quality is abysmal. The key thing to note is attn_entropy:
{'batch': '100/3548', 'epoch': '1/3', 'attn_entropy': '0.003', 'train_loss': 6.475}
{'batch': '200/3548', 'epoch': '1/3', 'attn_entropy': '0.003', 'train_loss': 6.472}
{'batch': '300/3548', 'epoch': '1/3', 'attn_entropy': '0.003', 'train_loss': 6.425}
{'batch': '400/3548', 'epoch': '1/3', 'attn_entropy': '0.003', 'train_loss': 6.366}
Recall that an entropy near zero means attention is essentially one-hot — the model is ignoring context.
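For reference, this sketch uses the same entropy formula as the `last_entropy` debug metric, ported to NumPy. A one-hot row has entropy \(0\), and a uniform row over \(k\) tokens has entropy \(\ln k\). The metric uses the natural log, so it is in nats. That makes \(e^{\text{entropy}}\) a rough “effective number of tokens attended to”.

```python
import numpy as np

def attn_entropy(attn, eps=1e-10):
    # -sum p log p per row, averaged over rows (natural log, so nats)
    return float(-(attn * np.log(attn + eps)).sum(axis=-1).mean())

one_hot = np.eye(4)[[2]]            # all weight on a single token
uniform = np.full((1, 32), 1 / 32)  # weight spread evenly over 32 tokens

print(attn_entropy(one_hot))              # ~0
print(attn_entropy(uniform), np.log(32))  # both ~3.47
print(np.exp(attn_entropy(uniform)))      # ~32 "effective" tokens
```

By this yardstick, the logged 0.003 is essentially one-hot. An entropy of about 1 nat corresponds to roughly \(e \approx 2.7\) tokens' worth of uniform attention.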
Ok so our attention mechanism isn’t working. The model is only looking at one token. Why? Well, let’s think about the \(QK^T\) matrix at the start of training.
We’re initializing \(W_Q, W_K, W_V\) as \(d \times d\) matrices with i.i.d. entries from \(\mathcal{N}(0,1)\). Suppose the entries of \(X\) also have mean zero and unit variance, independent of the weights. Then (wlog, taking \(Q = XW_Q\)): \[\text{Var}((XW_Q)_{ij}) = \text{Var}\left(\sum_{k=1}^d X_{ik} (W_Q)_{kj}\right) = \sum_{k=1}^d \text{Var}(X_{ik}(W_Q)_{kj}) = d\]
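The variance is huge, and it gets worse. Each entry of \(QK^T\) is a sum of \(d\) products of two such entries, so its variance is on the order of \(d \cdot d \cdot d = d^3\), about \(1.7 \times 10^7\) for \(d = 256\). With scores that spread out, one entry is almost certainly going to dominate each row of the softmax output.
The fix is to normalize the attention scores down to unit variance: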
\[\mathrm{softmax}(\mathrm{mask}(QK^T / \sqrt d))V\]
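```python
attn = (torch.masked_fill(
    (x @ W_Q) @ (x @ W_K).transpose(-2, -1),
    torch.triu(
        torch.ones(n, n, dtype=torch.bool, device=x.device),
        diagonal=1
    ),
    float('-inf')
) / sqrt_d).softmax(dim=-1)  # -inf / sqrt_d is still -inf, so the mask survives the scaling
```

Also it’s worth doing the same for our weight initialization. Dividing the scores by \(\sqrt d\) only brings them to unit variance if the entries of \(Q\) and \(K\) already have unit variance:

```python
# in Attention.__init__: entries ~ N(0, 1/d), so each entry of x @ W has variance ~1
self.W_Q = nn.Parameter(torch.randn(d, d) / self.sqrt_d)
self.W_K = nn.Parameter(torch.randn(d, d) / self.sqrt_d)
self.W_V = nn.Parameter(torch.randn(d, d) / self.sqrt_d)
```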
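Here is a rough empirical check of the variance argument. It is a NumPy sketch with a smaller \(d = 64\), chosen for speed, and unit-variance Gaussian inputs standing in for \(X\). With \(\mathcal{N}(0,1)\) weights, the projections have variance about \(d\) and the raw scores about \(d^3\). With weights scaled by \(1/\sqrt d\) and scores divided by \(\sqrt d\), both come out near unit variance.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1024, 64
X = rng.normal(size=(n, d))  # mean 0, unit variance, as assumed in the derivation

def score_stats(weight_scale, score_scale):
    W_Q = rng.normal(size=(d, d)) * weight_scale
    W_K = rng.normal(size=(d, d)) * weight_scale
    Q, K = X @ W_Q, X @ W_K
    scores = (Q @ K.T) * score_scale
    return Q.var(), scores.var()

var_q, var_s = score_stats(1.0, 1.0)  # N(0,1) weights, no 1/sqrt(d) anywhere
print(var_q / d, var_s / d**3)        # both close to 1

var_q_scaled, var_s_scaled = score_stats(1 / np.sqrt(d), 1 / np.sqrt(d))
print(var_q_scaled, var_s_scaled)     # both close to 1
```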
Empirical results: the attention entropy reads 1.492 at first, which is healthy, but then 0.089, which is degenerate again…
Recall the shape of our model: \(\mathcal{L} = \mathrm{CE}(\mathrm{MLP}(\mathrm{Attn}(X)), y)\) (CE is cross-entropy).
Note that the only way the MLP sees the input \(X\) is through the attention block. So it’s only seeing \(SV\), where \(S = \mathrm{softmax}(\mathrm{mask}(QK^T/\sqrt d))\) is the attention matrix and \(V\) the value matrix.
Suppose \(S\) is uniform row-wise over the lower-triangular part. Then \((SV)_j\), the \(j^{th}\) row of the output, is a uniform mixture of the value projections of all tokens up to position \(j\). The MLP sees kind of a fuzzy picture.
So all the gradient flowing back through the MLP pushes the attention matrix toward one-hot, in order to clearly mark out the most important token.
It turns out also that the gradient flow for a degenerated attention matrix is quite bad; i.e. once the attention mask is only attending to one token, the gradient signal back to \(W_K\) and \(W_Q\) is effectively norm zero, so it’s ~unrecoverable. The gradient math is a lot and since this post is long enough I’ll leave demonstrating this as an exercise for the reader.
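The full derivation stays an exercise, but one key ingredient is easy to check numerically. Every gradient reaching \(W_Q\) and \(W_K\) has to pass through the softmax Jacobian \(\mathrm{diag}(s) - ss^T\), and that Jacobian vanishes as \(s\) approaches one-hot. This NumPy sketch checks it on a single row of scores.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_jacobian(z):
    # d softmax(z)_i / d z_j = s_i (delta_ij - s_j), i.e. diag(s) - s s^T
    s = softmax(z)
    return np.diag(s) - np.outer(s, s)

spread = np.zeros(4)                      # uniform attention over 4 tokens
peaked = np.array([50.0, 0.0, 0.0, 0.0])  # saturated, effectively one-hot

print(np.linalg.norm(softmax_jacobian(spread)))  # ~0.433
print(np.linalg.norm(softmax_jacobian(peaked)))  # essentially zero
```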
What’s the fix? Instead of making the MLP input solely a function of the attention transform, let attention compute a delta from the original input, and give the MLP the sum. Suddenly the row-wise MLP no longer has such a strong incentive to zero out all but one of the attention weights, since it always sees each token’s own embedding directly. The gradients also have much nicer dynamics, since the sum passes them straight back through the skip connection. We can do the same for the MLP, so really the entire network’s job is to compute a delta from our original vector, which we then project out to the vocab size.
\(\mathcal{L} = \mathrm{CE}\big((H + \mathrm{MLP}(H))\,W_U,\ y\big)\), where \(H := X + \mathrm{Attn}(X)\),
for some learned unembedding projection \(W_U\) (a \(d \times\) vocab-size matrix, W_out in the code) back to our vocab size.
The code is probably clearer:
```python
# in the MLP, make sure the output is the same shape as the input
class MLP(nn.Module):
    def __init__(self, d, d_hidden, n_layers):
        super().__init__()
        layers = []
        d_in = d
        for _ in range(n_layers):
            layers.extend([nn.Linear(d, d_hidden), nn.ReLU()])
            d = d_hidden
        layers.append(nn.Linear(d_hidden, d_in))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.net(x)


# (in AttnNet.__init__, after self.vocab_size is set)
self.W_out = nn.Parameter(torch.randn(d, self.vocab_size) / math.sqrt(d))
self.mlp = MLP(d, mlp_dh, mlp_depth)  # the MLP now maps d -> d
# ... rest of init

# (in AttnNet.forward)
# x: (b, n) |-> (b, n, vocab_size)
def forward(self, x: torch.Tensor):
    x = self.embedding(x) + self.W_pos  # (b,n) |-> (b,n,d)
    x = x + self.attention(x)           # (b,n,d) |-> (b,n,d)
    x = x + self.mlp(x)                 # (b,n,d) |-> (b,n,d)
    return x @ self.W_out               # (b,n,d) @ (d,vocab_size) = (b,n,vocab_size)
```

Does it work?
Yep! Attention entropy stays consistently around 1. The entropy is in nats (the code uses the natural log), and a uniform distribution over \(k\) tokens has entropy \(\ln k\). So ~1.16 nats would mean attending uniformly to about \(e^{1.16} \approx 3\) tokens. It’s probably spikier in reality, but at any rate it’s not one-hot!
{'batch': '3500/3548', 'epoch': '3/3', 'attn_entropy': '1.162', 'val_ppl': 68.284}
Also, note the val ppl! Our model has 6.5M params and we’re competitive with our earlier 200M MLP model!
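As a sanity check on the 6.5M figure, here is a plain-Python tally of the residual AttnNet's parameters under the hyperparameters above. The helper is hypothetical, not from the repo. It assumes, as in the code, that the `nn.Linear` layers carry biases and the attention projections don't.

```python
def attnnet_params(vocab_size, ctx_len, d, d_h, depth):
    # hypothetical helper: parameter count of the residual AttnNet
    embedding = vocab_size * d
    w_pos = ctx_len * d
    attention = 3 * d * d                   # W_Q, W_K, W_V (no biases)
    mlp = d * d_h + d_h                     # first Linear(d, d_h)
    mlp += (depth - 1) * (d_h * d_h + d_h)  # remaining Linear(d_h, d_h) layers
    mlp += d_h * d + d                      # final Linear(d_h, d)
    w_out = d * vocab_size
    return embedding + w_pos + attention + mlp + w_out

print(attnnet_params(2048, 32, 256, 256 * 8, 2))  # 6500608
```

Almost all of that (about 5.2M) sits in the MLP's hidden layers; attention itself is under 0.2M.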
I think this is a long enough post for just attention. We have the core architecture down and it’s all reasonably motivated.
Next time, we’ll flesh out the network to the transformer architecture, which amounts basically to stacking a few of these in parallel and on top of each other and making sure variance doesn’t blow up again.
Also there’s a lot of stuff that you might be struggling to ‘typecheck’ here. Like, how do we compute cross-entropy loss from N predictions at once? How do we actually make the predictions? With a sliding window or what?
First, I’d suggest checking out the linked repo, because it has all of that. And if that’s not enough, the next post will have a super thorough “wtf is going on and what are the mathematical types” section at the end.