Were RNNs All We Needed?
Maybe, at least parallelizable ones. But we forgot about many of them 😢
Authors: Leo Feng, Frederick Tung, Mohamed Osama Ahmed, Yoshua Bengio, Hossein Hajimirsadeghi
Paper: https://arxiv.org/abs/2410.01201
The resurgence of recurrent networks continues. This time, it's back to the classics (RNN/LSTM/GRU), not the more recent Structured State Space Models (SSM), which are not equivalent to RNNs and belong to a simpler complexity class. (For more details, see this video).
RNNs have fundamental advantages in terms of memory requirements. They are linear O(n) (with respect to sequence length) during training and constant O(1) during inference. This is not the case for vanilla transformers, with quadratic O(n²) and linear O(n) complexities, respectively.
However, RNNs had one major drawback: training couldn't be parallelized. They were trained sequentially using backpropagation through time (BPTT), which made training extremely slow on long sequences. Here, transformers' parallelization capabilities, despite their worse asymptotic complexity, provided a crucial advantage, enabling scalable training and leading to their dominance over RNNs in almost every application.
Over the past few years, new works have addressed this limitation of RNNs, giving rise to models like LRU, Griffin, RWKV, Mamba 1 and 2, and others. I recently gave a talk about non-transformer architectures at Yerevan DataFest 2024; here are the slides.
All of these modern architectures can be parallelized efficiently using the same algorithm: the parallel prefix scan, also known as parallel prefix sum (reference).
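To make this concrete, here is a minimal sketch (my own code, not the paper's implementation) of how a first-order linear recurrence h_t = a_t * h_{t-1} + b_t can be computed with an associative scan, here via jax.lax.associative_scan; all function and variable names are mine.

```python
# A minimal parallel-scan sketch (my code, not the paper's implementation).
# The recurrence h_t = a_t * h_{t-1} + b_t is a chain of affine maps h -> a*h + b,
# and composing affine maps is associative, so the whole sequence can be
# evaluated with a parallel prefix scan instead of a sequential loop.
import jax
import jax.numpy as jnp

def combine(left, right):
    # Compose two affine maps: apply `left` first, then `right`.
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

def parallel_linear_recurrence(a, b, h0):
    # a, b: (T, d); h0: (d,). Returns h_1..h_T stacked as (T, d).
    A, B = jax.lax.associative_scan(combine, (a, b))  # prefix compositions
    return A * h0 + B

def sequential_linear_recurrence(a, b, h0):
    # Naive reference: the O(T) sequential loop an RNN would normally run.
    h, hs = h0, []
    for t in range(a.shape[0]):
        h = a[t] * h + b[t]
        hs.append(h)
    return jnp.stack(hs)

# Sanity check that the parallel and sequential versions agree.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (8, 4))
b = jax.random.normal(key_b, (8, 4))
h0 = jnp.zeros(4)
assert jnp.allclose(parallel_linear_recurrence(a, b, h0),
                    sequential_linear_recurrence(a, b, h0), atol=1e-5)
```

For numerical stability, practical implementations (including, if I recall correctly, the paper's) run this scan in log-space; the toy version above skips that for readability.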
The authors of this paper adapt parallel scan to classic LSTM/GRU models, eliminating the dependency of input, forget, and update gates on the hidden state (H). They also remove the tanh non-linearity (similar to LRU). They do not consider vanilla RNNs, citing issues with vanishing and exploding gradients (but recall LRU from DeepMind, which was actually a variation of a vanilla RNN!).
LSTM, by the way, also has countless variations, such as peephole connections, which introduce additional dependencies of the gates on the content of the memory cell. Remember that LSTM has two state variables: the internal cell state (C, not visible externally) and the hidden state (H, which is visible externally).
In my view, there are two fundamental sources of information about LSTMs beyond the original papers. One is the PhD dissertation of Felix Gers (link), who introduced the forget gate into the architecture (initially, only the other two gates were present) and also added peephole connections. The other is the PhD dissertation of Alex Graves (link), who came up with CTC loss and multidimensional RNNs. The power of good PhDs—what can I say?
The authors achieve minimalist versions of LSTM and GRU (minLSTM and minGRU, respectively), which require fewer parameters, can be parallelized during training, and achieve good performance. It's worth remembering that there have been many earlier attempts to enable fast parallel training of recurrent networks, such as QRNN (which differs in its use of convolutions) or SRU.
The authors looked at the original LSTM and GRU architectures and removed elements that hindered their training using parallel scan.
In the GRU, they removed the dependency of the update gate (z) and the candidate hidden state (h̃) on the previous value of h; the reset gate was eliminated entirely. They also removed the tanh non-linearity from the candidate state. As a result, minGRU requires O(2*d_h*d_x) parameters instead of the O(3*d_h(d_x + d_h)) in the original GRU.
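As a sketch of what that leaves (my own code and parameter names, biases omitted, not the authors' implementation), minGRU is two input-only projections plus the linear recurrence above, so it plugs straight into parallel_linear_recurrence from the scan example:

```python
# A minimal minGRU sketch: the update gate and candidate state depend only on x_t,
# so h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t is a scannable linear recurrence.
# Parameter names are mine; biases are omitted for brevity.
def init_min_gru(key, d_x, d_h):
    k_z, k_h = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(d_x)
    uniform = lambda k: jax.random.uniform(k, (d_x, d_h), minval=-scale, maxval=scale)
    return {"W_z": uniform(k_z), "W_h": uniform(k_h)}

def min_gru(params, x, h0):
    # x: (T, d_x), h0: (d_h,)
    z = jax.nn.sigmoid(x @ params["W_z"])   # update gate, no h_{t-1} inside
    h_tilde = x @ params["W_h"]             # candidate state, no reset gate, no tanh
    return parallel_linear_recurrence(1.0 - z, z * h_tilde, h0)
```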
In the LSTM, they removed the dependency on the previous state h in the forget and input gates, as well as in the candidate content of the memory cell c. The tanh was also removed from the computation of c; the output gate was then dropped as well, so the separate memory cell c disappears and only h remains. minLSTM requires O(3*d_h*d_x) parameters instead of O(4*d_h(d_x + d_h)) in LSTM.
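And a matching minLSTM sketch under the same assumptions (my names, biases omitted, reusing parallel_linear_recurrence); if I read the paper correctly, it also normalizes the forget and input gates so they sum to one, which keeps the hidden state on a stable scale:

```python
# A minimal minLSTM sketch: gates and candidate state depend only on x_t;
# no output gate, no tanh, and the cell state merges into h.
# Parameter names are mine; biases are omitted for brevity.
def init_min_lstm(key, d_x, d_h):
    k_f, k_i, k_h = jax.random.split(key, 3)
    scale = 1.0 / jnp.sqrt(d_x)
    uniform = lambda k: jax.random.uniform(k, (d_x, d_h), minval=-scale, maxval=scale)
    return {"W_f": uniform(k_f), "W_i": uniform(k_i), "W_h": uniform(k_h)}

def min_lstm(params, x, h0):
    # x: (T, d_x), h0: (d_h,)
    f = jax.nn.sigmoid(x @ params["W_f"])       # forget gate, input-only
    i = jax.nn.sigmoid(x @ params["W_i"])       # input gate, input-only
    h_tilde = x @ params["W_h"]                 # candidate state, no tanh
    f_norm, i_norm = f / (f + i), i / (f + i)   # gates normalized to sum to one
    # h_t = f'_t * h_{t-1} + i'_t * h_tilde_t: the same scannable linear recurrence
    return parallel_linear_recurrence(f_norm, i_norm * h_tilde, h0)
```

Both sketches return the full sequence of hidden states from a single scan, which is exactly what makes training parallel; at inference time you would simply run the one-step recurrence.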
In terms of computational speed, the new minLSTM and minGRU models are comparable to Mamba. For sequence lengths of 512 elements, they are 235x and 175x faster than the original LSTM/GRU, respectively. For longer sequences, the improvement is even more significant.
They do require more memory than the original LSTM/GRU (up to 88% more), since parallelization builds a larger computation graph; Mamba, in turn, requires 56% more memory than minGRU.
In the Selective Copy task from the Mamba paper, the minLSTM, minGRU, and Mamba (S6) models successfully solve the task, while S4, H3, and Hyena only partially succeed (based on the Mamba paper results).
In RL MuJoCo locomotion tasks (HalfCheetah, Hopper, Walker) from the D4RL benchmark, the models were compared against Decision Transformer variants, including Decision S4, Decision Mamba, and (Decision) Aaren. minLSTM and minGRU outperformed Decision S4 and were comparable to Decision Transformer, Aaren, and Mamba.
For language modeling, the authors used a character-level GPT (nanoGPT) trained on Shakespeare's works. The test losses for minGRU, minLSTM, Mamba, and the transformer were close. Mamba performed slightly worse than the others but trained the fastest (400 steps); minGRU/minLSTM reached their optimum in 575/625 steps, while the transformer required 2000.
It’s peculiar that they compare minLSTM/minGRU with transformers and SSMs but not with optimized RNNs like LRU, SRU, or QRNN—it wouldn’t surprise me if they performed comparably. It would also be interesting to see comparisons with the recently introduced xLSTM, for which official code is now available (GitHub).
Overall, it seems that a comprehensive review paper comparing the many known parallel RNNs is sorely needed. Any takers?