A (57:20)
And this is actually kind of an interesting one, I gotta say. So we need to talk about the way a standard transformer works, basically in the MLP section of the transformer. Usually you have your attention mechanism, which is figuring out what parts of the input to pay attention to. Then you have your MLP, which chews on the data you get from attention, and then you pass it on through the residual stream into the next layer. Right? Okay.

So typically the MLP part of those layers has three steps. First, you take your input from the previous layer, which is some vector, some list of numbers, and you multiply it by a matrix to project it up to a higher-dimensional space, creating a longer vector. Then you do some interesting processing on it. Then you project it back down to the residual-stream dimension and off it goes. And the way people have started to think about this, at least recently, is that the first step, where you blow up the size of that vector, is doing an operation akin to generating a key, in the sense of keys and values. Keys say: hey, I'm here, here's all the information I have on offer. Basically this is the information the token is broadcasting to the world: here's what I can share. Then you do a bunch of processing on it, and when you project back down to the residual dimension, that's usually interpreted as retrieving the corresponding values for that key. So you're kind of saying: here's the key, you blow it up, here's what I have to share, I'm a token, here's the information I contain, what do you want to do with that? And the process of chewing on that and spitting out something interesting, that's the last step, where you retrieve the values.

So the life of a token going through this process: you come out of the previous layer, you go through attention, and now you enter the MLP. Normally you'd multiply that input by a big matrix, the usual way this works, blowing it up and generating that key: hey, here I am, here's the information I contain. But that's computationally expensive. It involves matrix math. Ugh, matrix math. You don't like it, GPUs don't like it either. It takes a lot of time. So instead, what we're going to do is assign a unique ID to every token as it's being fed in. "The" has an ID, "eating" has an ID, every chunk of text has a unique ID. And instead of multiplying the vector that represents that token at that level of processing by a matrix to blow it up and generate the key, you ask: what's that token's ID? Then you look it up in a lookup table, one that holds an embedding that is trained over time, but it's just a straight-up table lookup for that token's embedding at that layer, unique to that layer. You pull in that embedding, so you're basically doing a memory retrieval instead of a matrix-vector multiplication. And that's much, much faster. You've cut out that whole matrix multiplication step.
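To make the contrast concrete, here's a minimal PyTorch sketch of the two designs as described above. All the names and dimensions are made up for illustration; this is the idea, not the paper's actual implementation. Notice that in the lookup variant, the key never sees the hidden state at all, which sets up the context problem discussed next.

```python
import torch
import torch.nn as nn

d_model, d_ff, vocab_size = 512, 2048, 32000  # illustrative sizes

class StandardMLP(nn.Module):
    """Usual three steps: project up (generate the 'key'), process, project down."""
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(d_model, d_ff)    # expensive matmul, context-aware
        self.down = nn.Linear(d_ff, d_model)  # 'value retrieval' back to residual dim

    def forward(self, hidden):
        key = self.up(hidden)                 # key mixes in whatever context attention wrote
        return self.down(torch.relu(key))     # process, then project back down

class LookupMLP(nn.Module):
    """Sketch of the lookup idea: the up-projection matmul is replaced by
    indexing a per-layer embedding table with the token's ID."""
    def __init__(self):
        super().__init__()
        self.table = nn.Embedding(vocab_size, d_ff)  # trained lookup table, one per layer
        self.down = nn.Linear(d_ff, d_model)

    def forward(self, hidden, token_ids):
        key = self.table(token_ids)           # pure memory retrieval, no matmul
        # note: `hidden` never touches the key, so context is lost at this step
        return self.down(torch.relu(key))

x = torch.randn(2, 8, d_model)                # (batch, seq, hidden)
ids = torch.randint(0, vocab_size, (2, 8))    # token IDs for the same positions
out = LookupMLP()(x, ids)                     # (2, 8, d_model)
```

The point of the swap is that `nn.Embedding` is just an indexed memory read, so the per-token cost of generating the key drops from a full d_model-by-d_ff matrix multiply to a single row fetch.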
So now you proceed to the middle calculations that normally happen before projecting back down. Now, the key thing is that normally, that upward projection with its matrix multiplication is context-aware. In other words, the matrix you're using is a learned matrix, and it learns to account for context from the overall prompt, not just the token you're looking at. The problem is that when we do this lookup instead, just retrieving an embedding for that token from a table and swapping it in, that process is not actually context-dependent. You're losing context.

And what they realized is that in most cases, a SwiGLU is used in between the generation of the key and the down-projection, and the SwiGLU part is already context-aware. So there's a kind of redundant use of context in that process. So they went: you know what, we can throw out the context-aware matrix multiplication that usually generates the key, swap in retrieval of this embedding from a table, and count on the next step in the processing, the SwiGLU, to be context-dependent anyway. Then we project down and we're good to go.

This has a really, really important consequence for the interpretability of the embeddings you get for those tokens. When you multiply by a matrix in the usual context-dependent way to generate your key vector, the problem is you're mixing all that context in. So the embedding you get, the vector, is this weird mangled Frankenstein monster that combines, yes, the token you put in, but also the meaning of the sentence around it. When you just do a table lookup, you're looking exclusively at the meaning of that one token at that layer. And what this means is that the representations of all those tokens can end up much more cleanly resolved over time, because as you train, you're iterating directly on the representations of the tokens you pull out of that retrieval process. Those token representations end up semantically much more distinct: the overlap between them, the cosine similarity, is very, very low.

This also has the advantage of letting you offload a lot of this work to the CPU, because the CPU can handle retrieval just fine; it's the big matrix multiplications it can't handle efficiently. And there's a whole bunch more detail, this is actually a really, really interesting paper. The net result is they can cut the amount of compute required to reach a given level of performance fairly significantly, by about a third, precisely because the MLP layers hold the dominant fraction of the parameters in the network, and you're functionally getting rid of a third of, not the parameters, but the processing involved in that step. So, really interesting paper, great implications for a whole bunch of things, including even training stability. Check it out if you're interested in that kind of thing.
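Here's the earlier sketch extended with a SwiGLU-style gate, which is where context re-enters under this design: the gate is computed from the residual-stream hidden state, while the value comes from the context-free table. Again, names and sizes are illustrative assumptions, not the paper's actual code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, d_ff, vocab_size = 512, 2048, 32000  # illustrative sizes

class LookupSwiGLU(nn.Module):
    """Context-free value from a per-layer token table, with context re-injected
    by a SwiGLU-style gate computed from the hidden state."""
    def __init__(self):
        super().__init__()
        self.table = nn.Embedding(vocab_size, d_ff)  # token-ID lookup, CPU-friendly
        self.gate = nn.Linear(d_model, d_ff)         # the gate still sees context
        self.down = nn.Linear(d_ff, d_model)

    def forward(self, hidden, token_ids):
        value = self.table(token_ids)     # depends only on the token's identity
        g = F.silu(self.gate(hidden))     # depends on the surrounding context
        return self.down(g * value)       # gated combination, then project down

# Because each table row is trained for exactly one (token, layer) pair, with no
# context mixed in, you can compare token representations directly:
mlp = LookupSwiGLU()
emb = mlp.table.weight                    # (vocab_size, d_ff)
sim = F.cosine_similarity(emb[0], emb[1], dim=0)  # per the claim, low after training
```

The last two lines show why the interpretability point falls out for free: with no context baked into the rows, cosine similarity between two table entries is a clean comparison of two token meanings at that layer.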