Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
Even better than Mamba!
Authors: Soham De, Samuel L. Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, Guillaume Desjardins, Arnaud Doucet, David Budden, Yee Whye Teh, Razvan Pascanu, Nando De Freitas, Caglar Gulcehre
Paper: https://arxiv.org/abs/2402.19427
Recently, RecurrentGemma (https://arxiv.org/abs/2404.07839), built on the Griffin architecture, was published. Griffin itself was released by DeepMind in late February 2024. Let’s delve into it.
RNNs strike back!
This work revolves around a new recurrent block, RG-LRU, on which two architectures are built: Hawk (alternating RG-LRU blocks with MLP blocks) and Griffin (alternating MLP blocks with a mix of RG-LRU and local attention). Hawk exceeds the reported performance of a similarly sized Mamba, while Griffin matches Llama-2 despite being trained on over six times fewer tokens.
The architecture is built from repeating residual blocks, similar to those used in pre-norm transformers: (RMSNorm + temporal-mixing block) and (RMSNorm + MLP block), each wrapped in a residual connection.
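A minimal sketch of this pre-norm layer in plain Python, assuming a sequence is a list of per-token feature vectors and that the temporal-mixing block and the MLP are passed in as functions (all names here are hypothetical, not taken from the paper's code). The temporal mixer sees the whole sequence, while the MLP acts on each token independently:

```python
import math
from typing import Callable, List

Vec = List[float]
Seq = List[Vec]  # a sequence: [time][feature]


def rms_norm(x: Vec, scale: Vec, eps: float = 1e-6) -> Vec:
    # RMSNorm: rescale by the root-mean-square of the features (no mean subtraction)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [s * v / rms for s, v in zip(scale, x)]


def residual_block(xs: Seq, scale: Vec, sublayer: Callable[[Seq], Seq]) -> Seq:
    # Pre-norm: normalize each token, apply the sublayer, then add the skip connection
    ys = sublayer([rms_norm(x, scale) for x in xs])
    return [[a + b for a, b in zip(x, y)] for x, y in zip(xs, ys)]


def griffin_style_layer(xs: Seq, temporal_mixer: Callable[[Seq], Seq],
                        mlp: Callable[[Vec], Vec], scale_t: Vec, scale_m: Vec) -> Seq:
    # Block 1: temporal mixing (RG-LRU, local MQA or global MQA) mixes across time
    xs = residual_block(xs, scale_t, temporal_mixer)
    # Block 2: the MLP is applied to every token independently
    return residual_block(xs, scale_m, lambda seq: [mlp(v) for v in seq])
```

Passing functions that return zeros as both sublayers gives back the input unchanged, which shows that each block only adds a correction on top of its skip connection.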
The MLP block is a gated block similar to Noam Shazeer's GeGLU (https://arxiv.org/abs/2002.05202), which the paper calls GeGeLU. It has two branches, each a linear layer that expands the input from D to M*D dimensions (M=3 throughout this work). One branch applies the GeLU nonlinearity, and the other computes coefficients for element-wise multiplication. The two branches are merged by element-wise multiplication, and the result is passed through a final linear layer back to dimension D.
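Here is a small sketch of this gated MLP in plain Python, assuming the tanh approximation of GeLU and toy random weights (the helper names and the initialization are illustrative assumptions, not the paper's code):

```python
import math
import random
from typing import List

Matrix = List[List[float]]  # shape [out][in]


def gelu(x: float) -> float:
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def gated_mlp(x: List[float], w_gelu: Matrix, w_gate: Matrix, w_out: Matrix) -> List[float]:
    h_gelu = [gelu(v) for v in matvec(w_gelu, x)]     # branch 1: D -> M*D, then GeLU
    h_gate = matvec(w_gate, x)                        # branch 2: D -> M*D, linear coefficients
    merged = [a * b for a, b in zip(h_gelu, h_gate)]  # merge by element-wise product
    return matvec(w_out, merged)                      # final linear layer: M*D -> D


def init(rows: int, cols: int, rng: random.Random) -> Matrix:
    # Toy Gaussian initialization (an assumption, for illustration only)
    return [[rng.gauss(0.0, 1.0 / math.sqrt(cols)) for _ in range(cols)] for _ in range(rows)]


D, M = 4, 3  # M = 3 as in the paper; D = 4 is a toy width
rng = random.Random(0)
w_gelu, w_gate, w_out = init(M * D, D, rng), init(M * D, D, rng), init(D, M * D, rng)
y = gated_mlp([0.5, -1.0, 2.0, 0.0], w_gelu, w_gate, w_out)
```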
The most interesting part, and the one that differs between the models, is the temporal-mixing block. It comes in three versions:
1) global Multi-Query Attention (MQA)
2) local (sliding-window) MQA
3) a new RG-LRU recurrent block.
Version 1 is Multi-Query Attention (MQA, also by Noam Shazeer, https://arxiv.org/abs/1911.02150), which replaces the classic Multi-Head Attention (MHA): all query heads share a single set of keys K and values V. It uses RoPE positional embeddings.
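To make the sharing concrete, here is a sketch of a single decoding step of MQA in plain Python: several query heads, one shared key head and one shared value head. RoPE and the input/output projections are left out, and the function names are hypothetical. The practical benefit is that the cache stores one K/V head per token instead of one per query head:

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in xs]
    s = sum(e)
    return [v / s for v in e]


def mqa_step(queries: List[List[float]], keys: List[List[float]],
             values: List[List[float]]) -> List[List[float]]:
    """One decoding step of multi-query attention.

    queries: [num_heads][d_head] for the current token.
    keys, values: [T][d_head], a single K head and a single V head shared by all query heads.
    """
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs = []
    for q in queries:  # every query head attends over the same cached K and V
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(len(values[0]))])
    return outputs
```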

Version 2, with local attention (also called sliding window attention), is similar to the local attention in Longformer. The local attention window is set to 1024 tokens.
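A sliding-window causal mask can be sketched as follows. Whether the 1024-token window counts the current token is an assumption here (this sketch counts it), and the helper name is hypothetical:

```python
from typing import List


def local_causal_mask(seq_len: int, window: int) -> List[List[bool]]:
    # mask[i][j] is True when query position i may attend to key position j:
    # causal (j <= i) and within the last `window` positions, counting i itself
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


mask = local_causal_mask(seq_len=8, window=3)  # toy sizes; Griffin uses window = 1024
```

Each query attends to at most `window` keys, so the attention cost per token stays constant no matter how long the sequence gets.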

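Finally, version 3, the recurrent block, resembles a Mamba block. It has two branches: one applies a linear layer followed by GeLU (as in the MLP block), and the other applies a linear layer, a one-dimensional convolution and the RG-LRU layer. The branches are merged by element-wise multiplication and followed by a final linear layer.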
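A rough sketch of the recurrent block's data flow, with the linear layers and the RG-LRU passed in as functions (hypothetical names). The convolution here is a causal, per-channel (depthwise) filter; the paper uses a short temporal filter, but treat the kernel width you pass in as an assumption of this sketch:

```python
import math
from typing import Callable, List

Vec = List[float]
Seq = List[Vec]  # [time][channel]


def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def causal_depthwise_conv1d(xs: Seq, kernel: Seq) -> Seq:
    """Per-channel causal convolution: kernel[k][c] multiplies xs[t - k][c],
    so the output at time t only sees the current and the previous width-1 steps."""
    out = []
    for t in range(len(xs)):
        acc = [0.0] * len(xs[0])
        for k in range(len(kernel)):
            s = t - k
            if s < 0:
                break  # implicit zero padding before the sequence start
            for c in range(len(acc)):
                acc[c] += kernel[k][c] * xs[s][c]
        out.append(acc)
    return out


def recurrent_block(xs: Seq, linear_gate: Callable[[Vec], Vec], linear_in: Callable[[Vec], Vec],
                    kernel: Seq, rg_lru: Callable[[Seq], Seq],
                    linear_out: Callable[[Vec], Vec]) -> Seq:
    # Branch 1: linear layer + GeLU, applied per token
    gate = [[gelu(v) for v in linear_gate(x)] for x in xs]
    # Branch 2: linear layer -> causal Conv1D -> RG-LRU, mixing information across time
    h = rg_lru(causal_depthwise_conv1d([linear_in(x) for x in xs], kernel))
    # Merge the branches by element-wise product, then project back
    return [linear_out([a * b for a, b in zip(g, v)]) for g, v in zip(gate, h)]
```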
The RG-LRU (Real-Gated Linear Recurrent Unit) is an advancement of LRU (https://arxiv.org/abs/2303.06349) with two added gates that do not depend on the previous recurrent state, only on the input.
The input gate 𝑖_t is similar to the one in an LSTM: it filters or scales the input. The second gate, the recurrence gate 𝑟_t, is new: it lets the unit approximately interpolate between the standard LRU update from the original work and the previous hidden state, which effectively discards the input and preserves information from the past. Appendix A explores the behavior of the recurrence gate further.
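Written out, the RG-LRU update looks like this, where σ is the sigmoid, ⊙ is element-wise multiplication, Λ is a learnable per-channel parameter and c is a fixed constant (set to 8 in the paper):

```latex
\begin{aligned}
r_t &= \sigma(W_a x_t + b_a) && \text{(recurrence gate)} \\
i_t &= \sigma(W_x x_t + b_x) && \text{(input gate)} \\
a_t &= a^{c\, r_t}, \qquad a = \sigma(\Lambda) \\
h_t &= a_t \odot h_{t-1} + \sqrt{1 - a_t^2} \odot (i_t \odot x_t)
\end{aligned}
```

Setting r_t to 0 makes a_t = 1 and the input term vanish, so h_t = h_{t-1}: the state is carried over and the input is discarded. With r_t = 1, the update decays the state by a^c and mixes in the gated input, as in an LRU-style step.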
RG-LRU is initialized neither with orthogonal polynomials, as in HiPPO, nor via discretization of a continuous-time system, as in SSMs. Unlike LRU, it also does not use complex numbers in the recurrence. A complex-valued variant, CG-LRU (Complex-Gated Linear Recurrent Unit), is considered in Appendix B: it is more expressive but does not help with language modeling in practice.
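A plain-Python sketch of the RG-LRU recurrence h_t = a_t ⊙ h_{t-1} + sqrt(1 - a_t²) ⊙ (i_t ⊙ x_t), with a_t = σ(Λ)^(c·r_t) and c = 8. It computes a_t in log space for numerical stability and initializes Λ so that a^c is uniform in [0.9, 0.999], following the paper's initialization. The dense gate matrices and all function names are simplifying assumptions of this sketch, not the paper's implementation:

```python
import math
import random
from typing import List, Optional

C = 8.0  # the constant c in a_t = a^(c * r_t)


def sigmoid(x: float) -> float:
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow for large negative x
    return z / (1.0 + z)


def softplus(x: float) -> float:
    # log(1 + e^x), written to avoid overflow
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def init_lambda(dim: int, rng: random.Random, lo: float = 0.9, hi: float = 0.999) -> List[float]:
    # Pick Λ so that a^c = sigmoid(Λ)^c is uniform in [lo, hi]
    lam = []
    for _ in range(dim):
        a = rng.uniform(lo, hi) ** (1.0 / C)  # a = sigmoid(Λ)
        lam.append(math.log(a / (1.0 - a)))   # inverse sigmoid
    return lam


def rg_lru(xs: List[List[float]], lam: List[float],
           w_a: List[List[float]], b_a: List[float],
           w_x: List[List[float]], b_x: List[float],
           h0: Optional[List[float]] = None) -> List[List[float]]:
    """xs: [T][D]; w_a, w_x: [D][D] gate weights (dense here, an assumption); b_a, b_x: [D]."""
    h = list(h0) if h0 is not None else [0.0] * len(lam)
    out = []
    for x in xs:
        # Both gates depend only on the current input, not on h
        r = [sigmoid(sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(w_a, b_a)]
        i = [sigmoid(sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(w_x, b_x)]
        new_h = []
        for d in range(len(h)):
            log_a = -C * r[d] * softplus(-lam[d])       # log a_t = c * r_t * log sigmoid(Λ)
            a_t = math.exp(log_a)
            mult = math.sqrt(-math.expm1(2.0 * log_a))  # sqrt(1 - a_t^2), computed stably
            new_h.append(a_t * h[d] + mult * (i[d] * x[d]))
        h = new_h
        out.append(h)
    return out
```

The sqrt(1 - a_t²) factor shrinks the input contribution when the decay is close to 1, which keeps the state from blowing up on long sequences.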
Results
The results are interesting. Three variants are considered:
MQA-Transformer as the baseline.
Hawk with the same residual structure and MLP block as the transformer baseline, but with the RG-LRU recurrent block as the temporal-mixing block.
Griffin with the same residual structure and MLP block as the transformer baseline, but with a mix of recurrent and local MQA blocks: two residual blocks with RG-LRU followed by one block with local MQA, repeated throughout the network.
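The layer patterns of the three models can be sketched as a small configuration helper (hypothetical names; it only lists which temporal-mixing block each residual layer uses):

```python
from typing import List


def layer_pattern(model: str, num_layers: int) -> List[str]:
    # Temporal-mixing block type for each residual layer
    if model == "mqa_transformer":
        return ["global_mqa"] * num_layers
    if model == "hawk":
        return ["rglru"] * num_layers
    if model == "griffin":
        period = ["rglru", "rglru", "local_mqa"]  # two recurrent blocks, then one local-attention block
        return [period[i % len(period)] for i in range(num_layers)]
    raise ValueError(f"unknown model: {model}")
```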
Model sizes range from 100M to 14B parameters, with the number of training tokens scaled following the Chinchilla recipe.
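As a back-of-the-envelope illustration of Chinchilla-style scaling, here is the common rule of thumb of about 20 training tokens per parameter and about 6·N·D training FLOPs for N parameters and D tokens. These constants are the usual approximations, not necessarily the exact schedule used in the paper, and the helper name is hypothetical:

```python
from typing import Tuple


def chinchilla_budget(num_params: float, tokens_per_param: float = 20.0) -> Tuple[float, float]:
    # Rule-of-thumb compute-optimal token count (~20 tokens per parameter)
    tokens = tokens_per_param * num_params
    # Standard dense-model training cost estimate: ~6 FLOPs per parameter per token
    flops = 6.0 * num_params * tokens
    return tokens, flops


tokens_1b, flops_1b = chinchilla_budget(1e9)  # a 1B-parameter model
```

For a 1B-parameter model this gives 2e10 tokens and about 1.2e20 FLOPs; the 6·N·D estimate ignores the attention term, which is one reason the real cost differs between architectures.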
For evaluations on downstream tasks, the models were trained on 300B tokens. All models show a clean power-law relationship between loss and training FLOPs. Griffin's loss is consistently slightly lower than the transformer baseline's at the same budget. Hawk's loss is higher, though the gap narrows as the budget increases.
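A power law L = A·C^b is a straight line in log-log space, so its exponent can be estimated with ordinary least squares on log(loss) versus log(FLOPs). A sketch, checked on synthetic points rather than the paper's data (real scaling-law fits often add an irreducible loss term, which this omits; the function name is hypothetical):

```python
import math
from typing import List, Tuple


def fit_power_law(flops: List[float], losses: List[float]) -> Tuple[float, float]:
    # Fit loss ≈ A * flops**b by least squares on log(loss) vs log(flops)
    xs = [math.log(f) for f in flops]
    ys = [math.log(l) for l in losses]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    log_a = my - b * mx
    return math.exp(log_a), b
```

Comparing two models then amounts to comparing their fitted curves at the same FLOP budget, which is how statements like "slightly lower loss at the same budget" are read off the plots.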
External baselines included Mamba-3B and Llama-2 (7B and 13B), which were trained on more tokens (600B and 2T, respectively) and on different datasets. Hawk and Griffin perform very well, beating Mamba despite being trained on fewer tokens.
To train large models across many devices, the authors used model-parallel training with layer sharding. Implementing the recurrence efficiently on accelerators is a separate challenge: unlike classic architectures, recurrences have a low FLOPs-to-byte ratio, so the computation is memory bound. The authors therefore wrote custom kernels in Pallas, JAX's extension for writing custom kernels; you can see what this looks like in the RecurrentGemma repository. Their linear-scan kernel is about three times faster than the native implementation. An associative scan (as used in S5, https://arxiv.org/abs/2208.04933) turned out to be slower, and a convolutional formulation is not possible because the gating in RG-LRU is incompatible with a convolutional representation.
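To see why both scans are valid and how they differ, here is a sketch of the scalar recurrence h_t = a_t·h_{t-1} + b_t (the per-channel form of RG-LRU, with b_t = sqrt(1 - a_t²)·i_t·x_t) computed two ways: a sequential linear scan, and an associative scan built on composing affine maps. It is plain Python, so it only demonstrates that the results agree, not the speed on real hardware; in JAX the parallel version would typically use jax.lax.associative_scan:

```python
from typing import List, Tuple


def linear_scan(a: List[float], b: List[float], h0: float = 0.0) -> List[float]:
    # Sequential recurrence h_t = a_t * h_{t-1} + b_t: T steps, O(T) work
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


def combine(left: Tuple[float, float], right: Tuple[float, float]) -> Tuple[float, float]:
    # Apply the affine map `left` (h -> a1*h + b1), then `right`; this composition is associative
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


def associative_scan(a: List[float], b: List[float]) -> List[float]:
    # Hillis-Steele inclusive scan: O(log T) parallel rounds but O(T log T) total work
    elems = list(zip(a, b))
    step = 1
    while step < len(elems):
        elems = [elems[i] if i < step else combine(elems[i - step], elems[i])
                 for i in range(len(elems))]
        step *= 2
    # With h0 = 0, each composed map applied to 0 is just its offset term
    return [b_t for _, b_t in elems]
```

The extra combine work and memory traffic of the parallel version is a plausible reason it loses to a well-written linear scan when the computation is memory bound.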
As sequence lengths increase, Griffin trains faster than the transformer. This difference is especially noticeable when the sequence length is significantly greater than the model width and attention computation occupies a significant portion of the total time.
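A rough FLOP count shows why the ratio of sequence length T to width D matters. The sketch below assumes a layer with full-width Q/K/V/output projections, the gated MLP with M = 3, and non-causal attention; MQA's smaller K/V projections and causal masking would change the constants but not the trend (the function name is hypothetical):

```python
def attention_flops_fraction(seq_len: int, width: int, mlp_expansion: int = 3) -> float:
    """Share of per-layer forward FLOPs spent in attention (a multiply-add counts as 2 FLOPs)."""
    T, D = seq_len, width
    projections = 2 * 4 * T * D * D          # Q, K, V and output: four D x D matmuls
    mlp = 2 * 3 * mlp_expansion * T * D * D  # gated MLP: two D -> M*D matmuls, one M*D -> D
    attention = 2 * 2 * T * T * D            # QK^T scores plus the weighted sum of V
    return attention / (projections + mlp + attention)
```

Under these assumptions the attention share is 4T / (26D + 4T): about 13% when T = D, and over half once T reaches roughly 8D, which is the regime where replacing global attention pays off most.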
For inference latency, Hawk and Griffin are faster than the MQA transformer (which is in turn faster than classic MHA). The difference becomes noticeable at longer sequence lengths, mainly beyond 2048 tokens. Throughput (Figure 1b in the paper) is also better for the new models, especially Hawk, partly due to lower latency and partly due to smaller caches, which let a larger batch fit on the same device. Griffin's throughput is lower than Hawk's because its local-attention cache, although bounded by the window size, grows with the batch size.
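A sketch of why the cache sizes differ, counting cached values per sequence per layer. The sizes (head dimension, recurrent width, window, convolution width) are hypothetical defaults, not the paper's configuration. The global MQA cache grows with sequence length, the local-attention cache stops growing at the window size, and the RG-LRU state is constant:

```python
def cache_floats_per_layer(kind: str, seq_len: int, d_head: int = 128, d_rnn: int = 2560,
                           window: int = 1024, conv_width: int = 4) -> int:
    """Cached values for one sequence in one layer (all sizes are hypothetical)."""
    if kind == "global_mqa":
        return 2 * seq_len * d_head               # one shared K head and one V head per token
    if kind == "local_mqa":
        return 2 * min(seq_len, window) * d_head  # only the last `window` tokens are kept
    if kind == "rglru":
        return d_rnn + (conv_width - 1) * d_rnn   # recurrent state plus conv history
    raise ValueError(f"unknown block: {kind}")
```

The total is this figure times the number of such layers times the batch size, so even a bounded local-attention cache adds up at large batch sizes, consistent with Griffin's throughput sitting somewhat below Hawk's.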
On next-token prediction in long sequences, the new models outperform transformers and extrapolate to much longer sequences (at least 4x) than were seen in training.
Interestingly, models trained on shorter sequences (2k vs 8k) perform better on short sequences. It is therefore important to choose the training sequence length with the target tasks in mind.
A recent work “Repeat After Me: Transformers are Better than State Space Models at Copying” (https://arxiv.org/abs/2402.01032) showed that transformers perform better on tasks like copying or retrieval than SSMs.
The authors tested the new models on the Selective Copying and Induction Heads tasks (as in the Mamba paper). All three models solve the copying task perfectly, although Hawk learns it more slowly. On Induction Heads, all three solve the task up to a certain length; beyond it the transformer fails because it cannot extrapolate. Mamba also solves both tasks.
The previously mentioned work, "Repeat After Me: Transformers are Better than State Space Models at Copying," proposed a retrieval task based on a synthetic phonebook, in which the model must find a phone number by name. The prompt contains the phonebook, followed by two examples and the name whose phone number should be retrieved. On this task, Hawk's accuracy quickly drops to zero as the phonebook grows, similar to Mamba, which is not surprising given its small fixed-size state. The transformer holds up to the lengths seen during training and then drops to zero. Griffin performs perfectly up to the length of its local attention window and then begins to degrade, but it extrapolates further than the transformer.
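A sketch of a phonebook-style prompt generator following the structure described above (phonebook, two examples, then the query name). The exact formatting, the placeholder names and the number format are hypothetical, not the format used in the original paper:

```python
import random
from typing import Tuple


def make_phonebook_prompt(num_entries: int, rng: random.Random) -> Tuple[str, str]:
    """Return (prompt, expected_answer) for a synthetic phonebook lookup (needs >= 3 entries)."""
    names = [f"Person{i}" for i in range(num_entries)]  # hypothetical placeholder names
    numbers = [f"{rng.randint(100, 999)}-{rng.randint(1000, 9999)}" for _ in names]
    book = dict(zip(names, numbers))
    lines = [f"{name}: {phone}" for name, phone in book.items()]
    # Two distinct few-shot examples, then the name to look up
    ex1, ex2, query = rng.sample(names, 3)
    prompt = "\n".join(lines) + f"\n\n{ex1}: {book[ex1]}\n{ex2}: {book[ex2]}\n{query}:"
    return prompt, book[query]


prompt, answer = make_phonebook_prompt(20, random.Random(0))
```

Growing `num_entries` lengthens the prompt, so a fixed-size recurrent state must compress more and more entries, while a local-attention window only sees the last part of the book, which matches the degradation patterns described above.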
An interesting development indeed!