The things that still confuse me about transformers are:
1. Why do we _add_ the positional embedding to the semantic embedding? It seems like it means certain semantic directions are irreversibly entangled with certain positions.
2. I don't understand why the attention head (which I can implement and follow the math of) is described as "key query value lookup". Specifically, the Q and K matrices aren't structurally distinct – the projections into them will learn different weights, but one doesn't start out biased key-ward and the other query-ward.
On the first one: transformers are permutation-equivariant by nature. Permute the input and the output is permuted the same way, but is otherwise identical; the transformer itself has no positional information. RNNs, by comparison, have positional information by design, since they go token by token, but the transformer is parallel and all tokens are just independent "channels". So what can be done? You inject positional embeddings, either by adding them to the token embeddings (concatenation also works, but is less efficient) or by inserting relative-distance biases into the attention matrix. It's a fix to make the model understand time. It's still puzzling that this works, because mixing semantic and positional information in the same vector seems like it should cause a conflict, but in practice it doesn't: the model learns to use the embedding vector for both, perhaps specialising one part for semantics and another for position.
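A minimal sketch of the "add position to the tokens" option, using the sinusoidal encoding from the original transformer paper (all sizes here are made up for illustration):

```python
import numpy as np

def sinusoidal_pe(seq_len, d_model):
    # Sinusoidal encoding from the original transformer: even dimensions
    # get sin, odd dimensions get cos, at geometrically spaced frequencies.
    pos = np.arange(seq_len)[:, None]              # (seq_len, 1)
    i = np.arange(0, d_model, 2)[None, :]          # (1, d_model/2)
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

seq_len, d_model = 8, 16
token_emb = np.random.randn(seq_len, d_model)    # stand-in for learned embeddings
x = token_emb + sinusoidal_pe(seq_len, d_model)  # addition, not concatenation
```

The point is only that the position signal is summed into the same vector the semantic embedding lives in; the model is then free to carve up that vector however it likes.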
On the second question: the network learns to differentiate the keys from the queries simply by gradient descent. If we tell the model it should generate a specific token here, it has to adjust the key and query projections to make that happen. The architecture is pretty dumb; the secret is the training data. Everything the transformer learns comes from the training set, so we should think about the training data when we marvel at what transformers can do. The architecture alone doesn't tell us why they work so well.
With regards to the "It's still puzzling this works" wrt positional encoding, I have developed an intuition (that may be very wrong ;-). If you take the Fourier transform of a linear or sawtooth function (akin to the progress of time), I think you get something that resembles the positional encoding in the original transformer. EDIT: fixed typo
This is a good intuition. At times it reminds me of the old-school hand-rolled feature engineering used in time series modelling: assuming the signal is made up of a stationary component and a sine wave. Though I haven't managed to mathematically figure out whether the two are equivalent.
> The architecture is pretty dumb, the secret is the training data
If this were true, we could throw the same training data at any other "dumb" architecture and it would learn language at least as well/fast as transformers do. But we don't see that happening, so the architecture must be smartly designed for this purpose.
Actually there are alternatives by the hundreds, with similar results. Reformer, Linformer, Performer, Longformer... none is better than vanilla overall; they all have an edge in some use case.
And then we have MLP-mixer which just doesn't do "attention" at all, MLP is all you need. A good solution for edge models.
Other dumb architectures don't parallelize as well. Other architectures that parallelize at similar levels (RNN-RWKV, H3, S4, etc.) do perform well at similar parameter counts and data sizes.
Regarding the positional encoding, why not include a scalar in the range (0..1) with every token, where the scalar encodes the position of the token? This adds a small amount of complexity to the network, but it could aid comprehensibility, which to me seems preferable if you're still doing research on these networks.
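A sketch of what that proposal could look like; the function name and sizes are made up:

```python
import numpy as np

def add_scalar_position(token_emb):
    # Hypothetical scheme from the comment above: append one extra feature
    # per token holding its normalized position in [0, 1].
    seq_len, d_model = token_emb.shape
    pos = np.linspace(0.0, 1.0, seq_len)[:, None]    # (seq_len, 1)
    return np.concatenate([token_emb, pos], axis=1)  # (seq_len, d_model + 1)

x = add_scalar_position(np.random.randn(5, 8))
```

The trade-off is that a single scalar gives the first attention layer much less to work with than a multi-frequency encoding: relative offsets like "two tokens to the right" are harder to express as a linear function of one number.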
I'm still not clear on the second question. If lalaithion's original statement "the Q and K matrices aren't structurally distinct" is true, then once the neural network is trained, how can we look at the two matrices and confidently say that one is the query matrix instead of it being the key matrix (or vice versa)? To put it another way: is the distinction between query and key roles "real" or is it just an analogy for humans?
I am not an expert, but I think they are structurally identical only in decoder-only transformers like GPT. The original transformer was used for translation, so the encoder-decoder layers take Q from the decoder layer and K from the encoder layer. The "Attention Is All You Need" paper has an explanation:
> In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models such as...
Would this not imply that if I encrypt the input and then decrypt the output I would get the correct result (i.e. what I would have gotten if I used the plaintext input)?
I recently had the same questions and here is how I understand it:
1. You could concatenate the positional embedding and the semantic embedding and that way isolate them from each other. But if that separation is necessary, the model can learn the separation itself as well (it can make positional embeddings and semantic embeddings orthogonal to each other), so using addition is strictly more general.
2. My sense is that you could merge the Q and K matrices and everything would work mostly the same, but with multi-headed attention this will typically result in a larger matrix than the combined sizes of Q and K. It's basically a more efficient matrix factorization.
Curious to see if I got this right and if there is more to it.
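The merge in point 2 can be checked numerically: with separate projections the attention logits are (xW_q)(xW_k)^T, which equals x(W_q W_k^T)x^T, i.e. a single but larger d_model x d_model matrix of rank at most d_k. A toy check, all sizes made up:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, n = 16, 4, 6

W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
x = rng.normal(size=(n, d_model))

# Attention logits with separate projections: (x W_q)(x W_k)^T
scores_factored = (x @ W_q) @ (x @ W_k).T

# Same logits with a single merged matrix M = W_q W_k^T.
# M has d_model * d_model = 256 parameters, versus 2 * d_model * d_k = 128
# for W_q and W_k together: the factored form is a low-rank parametrization.
M = W_q @ W_k.T
scores_merged = x @ M @ x.T

assert np.allclose(scores_factored, scores_merged)
```

So per head the Q/K split buys nothing mathematically; it's the cheaper factorization, which matters once you have many heads.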
One advantage of summing is that the lower frequency terms hardly change for a small text, so effectively there is more capacity for embeddings with short texts, while still encoding order in long texts.
1. High dimensional embedding space is way more vast than you'd think, so adding two vectors together doesn't really destroy information the way addition does in low dimensional Cartesian space; the semantic and position information remains separable.
2. I find the QKV nomenclature unintuitive too. Cross attention explains it a bit, where the Q and K come from different places. For self attention they are the same, but the terminology stuck.
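The separability claim in point 1 can be sanity-checked: two random directions in high-dimensional space are nearly orthogonal, so their sum keeps both parts readable. A toy check, with the dimension chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4096  # a plausible embedding width, order of magnitude only

# Two random unit vectors standing in for "semantic" and "positional" parts.
sem = rng.normal(size=d); sem /= np.linalg.norm(sem)
pos = rng.normal(size=d); pos /= np.linalg.norm(pos)

cos = float(sem @ pos)         # close to 0: nearly orthogonal by chance alone
mixed = sem + pos              # the "added" embedding
sem_back = float(mixed @ sem)  # close to 1: the semantic part is still readable
```

In 2 or 3 dimensions the same experiment routinely gives large overlaps; the near-orthogonality is a high-dimensional effect (the cosine concentrates around 0 with spread ~1/sqrt(d)).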
1. It works; the direct alternative (concatenation) leaves a smaller dimension for the semantic embedding. Also, added positional embeddings are no longer common in newer Transformers; schemes like RoPE and ALiBi are more common.
2. I'm not 100% sure I understand your question. The Ks correspond to the Vs, and so are used to compute the weighted sum over Vs. This is easiest to see in an encoder-decoder model (Qs come from the decoder, KVs come from the encoder), or when decoding in a decoder-only model (there is one Q and multiple KVs).
One aspect of the specific positional embedding used there is that it explicitly encodes a signal the very first attention layer can directly use for both relative and absolute position. That is, there can be a trivial set of weights for "pay attention only to the token two to the right of the target token", another trivial set for "pay attention to the first token in the sequence", and another for "pay attention to this aspect of all the words, weighted by distance from the target token". Since the transformer architecture is position-blind by default, having the flexibility to learn all these different types of relations is important; many possible simple, clean, efficient position encodings make some relations easy to represent but others very difficult, perhaps theoretically possible but needing extra layers and/or hard to learn by gradient descent.
To answer (2): You are token i. In order to see how much of a token j's value v_j you update yourself with, you compare your query q_i with token j's key k_j. This gives you the asymmetry between queries and keys.
This is even more apparent in a cross-attention setting, where one stream of tokens will have only queries associated with it and the other will have only keys/values.
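A minimal single-query sketch of that lookup, with made-up sizes: token i projects a query, every token projects a key and a value, and the softmaxed query-key scores weight the values.

```python
import numpy as np

rng = np.random.default_rng(2)
d_model, d_k, n = 8, 4, 5
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))
tokens = rng.normal(size=(n, d_model))

i = 2
q_i = tokens[i] @ W_q             # token i asks a question...
K = tokens @ W_k                  # ...every token j advertises a key k_j
V = tokens @ W_v                  # ...and carries a value v_j
logits = K @ q_i / np.sqrt(d_k)   # compare q_i against each k_j
w = np.exp(logits - logits.max())
w /= w.sum()                      # softmax over positions j
update = w @ V                    # token i's weighted sum of values
```

The asymmetry is in the roles, not the shapes: q_i comes from the one token doing the looking, the k_j from all the tokens being looked at, and in cross-attention those two sets of tokens are different streams entirely.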