Transformers have been all the rage in NLP for the last few years, and there’s even been talk about whether LLMs have solved all the NLP problems, but S4 and Mamba models seem ready to take the world by storm.
So I got around to reading about them. They're basically more advanced RNNs. The issues we had with RNNs - the exploding/vanishing gradient problem, and the limited context that comes from squeezing everything into a fixed-length context vector - were 'solved' by Transformers and attention, but that fix introduced a problem of its own: the attention layer scales quadratically with sequence length.
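To make the 'quadratic' point concrete, here's a toy numpy sketch (random vectors, no real model): the score matrix attention builds is n × n, so doubling the sequence length quadruples the memory and compute for that step.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64                                # toy head dimension
for n in (1024, 2048):                # sequence lengths
    Q = rng.standard_normal((n, d))
    K = rng.standard_normal((n, d))
    scores = Q @ K.T                  # every token scores against every token
    print(n, scores.shape)            # (n, n): cost grows as n^2
```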
State Space Models, which take inspiration from the way we model systems in fields as varied as physics (I've seen examples explaining SSMs with spring systems), take RNNs and make them much more tractable by making them Linear Time Invariant (LTI). This means the matrices that are multiplied with the inputs are independent of the input, so the whole sequence can be processed in one fell swoop (as a convolution with a precomputed kernel) rather than token by token.
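To see what that buys us, here's a minimal numpy sketch with toy sizes and random matrices (nothing like the actual S4 parameterisation): because A, B and C never change, the step-by-step recurrence produces exactly the same outputs as a single convolution with a kernel we can precompute once.

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, seq_len = 4, 8               # toy sizes
A = 0.9 * np.eye(d_state) + 0.01 * rng.standard_normal((d_state, d_state))
B = rng.standard_normal((d_state, 1))
C = rng.standard_normal((1, d_state))
x = rng.standard_normal(seq_len)      # a scalar input sequence

# Recurrent view: h_t = A h_{t-1} + B x_t,  y_t = C h_t
h = np.zeros((d_state, 1))
y_rec = []
for t in range(seq_len):
    h = A @ h + B * x[t]
    y_rec.append((C @ h).item())

# Convolutional view: since A, B, C are fixed, precompute the kernel
# K_k = C A^k B once and get the same outputs with a causal convolution.
K = np.array([(C @ np.linalg.matrix_power(A, k) @ B).item() for k in range(seq_len)])
y_conv = [sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(seq_len)]

print(np.allclose(y_rec, y_conv))     # True: the two views agree
```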
This obviously creates a problem: the model can't treat tokens in the sequence differently, deciding what to "remember" and "forget" based on the token it's currently seeing. Mamba remedies this by dropping the LTI constraint and letting the SSM parameters depend on the input. "But this just defeats the point and makes them inefficient again", I hear you say. They deal with this by adding a "selective scan" algorithm that's hardware-aware.
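Roughly, "selective" means the matrices that write to and read from the hidden state become functions of the current token. Below is a toy single-channel sketch; w_B, w_C and w_dt are hypothetical projections I made up for illustration, and the discretisation is simplified compared to the paper - it just shows where the input-dependence enters.

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, seq_len = 4, 8                      # toy sizes, not Mamba's real dimensions

A = -np.arange(1.0, d_state + 1)             # fixed diagonal "decay" rates
w_B = rng.standard_normal(d_state)           # hypothetical projections that make B,
w_C = rng.standard_normal(d_state)           # C and the step size depend on the token
w_dt = rng.standard_normal()
x = rng.standard_normal(seq_len)             # one input channel

h = np.zeros(d_state)
outputs = []
for t in range(seq_len):
    # The "selective" part: all of these are recomputed from the current token,
    # so the model can decide per token how much to write to / read from the state.
    delta = np.log1p(np.exp(w_dt * x[t]))    # softplus keeps the step size positive
    B_t = w_B * x[t]
    C_t = w_C * x[t]
    A_bar = np.exp(delta * A)                # simple exponential discretisation
    h = A_bar * h + delta * B_t * x[t]       # state update, now input-dependent
    outputs.append(C_t @ h)
print(np.round(outputs, 3))
```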
What this means is that they use a parallel prefix sum (scan) to make the calculation of the outputs for the whole sequence efficient on the GPU. (The prefix sum, or scan, is a parallel algorithm for calculations that look inherently sequential but, because the underlying operation is associative, actually aren't.) They also keep the intermediate state in the GPU's SRAM (the much faster, much smaller on-chip memory) while doing the calculations, instead of shuttling data back and forth between SRAM and HBM. This is also the trick used in FlashAttention - the method that's made exact attention much faster and more memory-efficient by avoiding exactly those HBM round trips (it doesn't actually "linearise" attention, but it removes a lot of the memory overhead).
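Here's the scan half of the idea in plain Python with toy numbers (the SRAM/HBM half is a CUDA-kernel-level detail that won't show up in a sketch like this): a linear recurrence h_t = a_t·h_{t-1} + b_t can be rewritten as repeated application of an associative combine on (a, b) pairs, and anything associative can be computed in O(log n) parallel passes instead of n strictly sequential steps.

```python
import numpy as np

def combine(e1, e2):
    """Compose two steps of the recurrence h <- a*h + b.
    Applying e1 then e2 is itself a step of the same form, and the
    composition is associative - which is what lets a scan parallelize it."""
    a1, b1 = e1
    a2, b2 = e2
    return (a1 * a2, a2 * b1 + b2)

def parallel_scan(elems):
    """Hillis-Steele inclusive scan: log2(n) passes, and each pass is a batch
    of independent combines (the part a GPU would run in parallel)."""
    elems = list(elems)
    n, d = len(elems), 1
    while d < n:
        elems = [elems[i] if i < d else combine(elems[i - d], elems[i])
                 for i in range(n)]
        d *= 2
    return elems

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=16)    # toy per-step decay factors
b = rng.standard_normal(16)           # toy per-step inputs

# Sequential recurrence: h_t = a_t * h_{t-1} + b_t, starting from h = 0.
h, h_seq = 0.0, []
for t in range(16):
    h = a[t] * h + b[t]
    h_seq.append(h)

# Same states from the scan: each prefix (A_t, B_t) gives h_t = B_t (since h starts at 0).
h_scan = [B for _, B in parallel_scan(zip(a, b))]
print(np.allclose(h_seq, h_scan))     # True
```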
While Mamba models don't seem to have the same level of in-context learning ability (few-shot and zero-shot learning) that Transformers do, they add a TON of efficiency. People have realised this and mixed Mamba layers with Transformer layers to drastically reduce the memory and compute footprint and stretch context lengths (e.g., Jamba).
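As a toy illustration of the hybrid idea - the ratio below is made up, and real hybrids like Jamba pick their own interleaving (and add mixture-of-experts layers on top) - the stack ends up being mostly Mamba blocks with the occasional attention block:

```python
def build_layer_stack(n_layers, attention_every=8):
    """Toy hybrid stack: mostly SSM (Mamba) layers, with an attention layer
    every `attention_every` layers. The ratio here is illustrative only."""
    return ["attention" if (i + 1) % attention_every == 0 else "mamba"
            for i in range(n_layers)]

print(build_layer_stack(16))
# seven 'mamba' layers, then one 'attention' layer, twice over
```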
This is exciting - while Mamba likely won’t completely replace Transformers, this means there are ways to deal with the inefficiencies of the latter.