Authors:
(1) Soham De, Google DeepMind, equal contribution;
(2) Samuel L. Smith, Google DeepMind, equal contribution;
(3) Anushan Fernando, Google DeepMind, equal contribution;
(4) Aleksandar Botev, Google DeepMind, equal contribution;
(5) George Cristian-Muraru, Google DeepMind, equal contribution;
(6) Albert Gu, Work done while at Google DeepMind;
(7) Ruba Haroun, Google DeepMind;
(8) Leonard Berrada, Google DeepMind;
(9) Yutian Chen, Google DeepMind;
(10) Srivatsan Srinivasan, Google DeepMind;
(11) Guillaume Desjardins, Google DeepMind;
(12) Arnaud Doucet, Google DeepMind;
(13) David Budden, Google DeepMind;
(14) Yee Whye Teh, Google DeepMind;
(15) Razvan Pascanu, Google DeepMind;
(16) Nando De Freitas, Google DeepMind;
(17) Caglar Gulcehre, Google DeepMind.
Table of Links
3 Recurrent Models Scale as Efficiently as Transformers
3.2. Evaluation on downstream tasks
4.2. Efficient linear recurrences on device
4.3. Training speed on longer sequences
5.1. A simple model of the decode step
6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts
6.2. Copy and retrieval capabilities
8. Conclusion, Acknowledgements, and References
B. Complex-Gated Linear Recurrent Unit (CG-LRU)
C. Model Scale Hyper-Parameters
D. Efficient Linear Recurrences on Device
E. The Local Attention Window Size of Griffin
G. Improving Next Token Prediction with Longer Contexts: Additional Results
H. Additional Details of the Copy and Retrieval Tasks
Recurrent neural networks (RNNs) have fast inference and scale efficiently on long sequences, but they are difficult to train and hard to scale. We propose Hawk, an RNN with gated linear recurrences, and Griffin, a hybrid model that mixes gated linear recurrences with local attention. Hawk exceeds the reported performance of Mamba on downstream tasks, while Griffin matches the performance of Llama-2 despite being trained on over 6 times fewer tokens. We also show that Griffin can extrapolate on sequences significantly longer than those seen during training. Our models match the hardware efficiency of Transformers during training, and during inference they have lower latency and significantly higher throughput. We scale Griffin up to 14B parameters, and explain how to shard our models for efficient distributed training.
1. Introduction
Recurrent neural networks (RNNs) played a central role in the early days of deep learning and NLP research (Elman, 1990; Siegelmann and Sontag, 1991; Hochreiter and Schmidhuber, 1997; Mikolov et al., 2010; Bahdanau et al., 2014; Sutskever et al., 2014), and achieved practical success in many applications, including Google’s first end-to-end machine translation system (Wu et al., 2016). However, in recent years, both deep learning and NLP have been dominated by the Transformer architecture (Vaswani et al., 2017), which interleaves multi-layer perceptrons (MLPs) and multi-head attention (MHA). Transformers achieve better performance than RNNs in practice and are also very efficient at utilizing modern hardware (Kaplan et al., 2020). Transformer-based large language models trained on massive datasets collected from the web have achieved remarkable success (Brown et al., 2020; Rae et al., 2021; Hoffmann et al., 2022; Touvron et al., 2023; Achiam et al., 2023; Gemini Team Google, 2023).
Despite their successes, Transformers are difficult to scale efficiently to long sequences due to the quadratic complexity of global attention. Additionally, the linear growth of the Key-Value (KV) cache with the sequence length makes Transformers slow during inference. Although Multi-Query Attention (MQA) (Shazeer, 2019) partially mitigates this issue by reducing the cache size by a constant factor, the cache still grows linearly in sequence length. Recurrent language models present a compelling alternative as they compress the entire sequence into a fixed-size hidden state which is updated iteratively. However, to replace Transformers, new RNN models must not only demonstrate comparable performance at scale but also achieve similar hardware efficiency (Gu et al., 2021a; Mehta et al., 2022; Smith et al., 2022; Orvieto et al., 2023b; Dao et al., 2022b; Poli et al., 2023; Gu and Dao, 2023).
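To make the memory argument concrete, here is a back-of-the-envelope sketch comparing the inference-time cache of a multi-head attention (MHA) Transformer, an MQA Transformer, and a recurrent model that keeps one fixed-size state per layer. The layer count, head count, head dimension, state width, and bf16 storage (2 bytes per element) are illustrative assumptions, not the configurations used in the paper, and the recurrent estimate deliberately ignores any extra per-layer buffers a real recurrent block may carry.

```python
# Back-of-the-envelope cache sizes for one sequence at inference time.
# All model dimensions below are hypothetical, chosen only for illustration.

MIB = 1024 ** 2

NUM_LAYERS = 32      # hypothetical depth
NUM_HEADS = 32       # hypothetical number of query heads
HEAD_DIM = 128       # hypothetical per-head dimension
STATE_DIM = 4096     # hypothetical width of the recurrent hidden state
BYTES_PER_ELEM = 2   # assumes bf16 storage


def kv_cache_bytes(seq_len: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = BYTES_PER_ELEM) -> int:
    """Keys plus values for every cached token: grows linearly with seq_len."""
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_elem


def recurrent_state_bytes(num_layers: int, state_dim: int,
                          bytes_per_elem: int = BYTES_PER_ELEM) -> int:
    """One fixed-size hidden state per layer: independent of sequence length."""
    return num_layers * state_dim * bytes_per_elem


if __name__ == "__main__":
    for seq_len in (1_024, 8_192, 65_536):
        # MHA caches one key/value head per query head; MQA shares a single one.
        mha = kv_cache_bytes(seq_len, NUM_LAYERS, NUM_HEADS, HEAD_DIM)
        mqa = kv_cache_bytes(seq_len, NUM_LAYERS, 1, HEAD_DIM)
        rnn = recurrent_state_bytes(NUM_LAYERS, STATE_DIM)
        print(f"{seq_len:>6} tokens: MHA {mha / MIB:9.1f} MiB | "
              f"MQA {mqa / MIB:7.1f} MiB | recurrent {rnn / MIB:.2f} MiB")
```

Under these assumptions MQA shrinks the cache by exactly the factor NUM_HEADS = 32, yet its cache still doubles whenever the sequence length doubles, while the recurrent state stays at 0.25 MiB for any length.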
In this work, we propose the RG-LRU layer, a novel gated linear recurrent layer, around which we design a new recurrent block to replace MQA. We build two new models using this recurrent block: Hawk, a model which interleaves MLPs with recurrent blocks, and Griffin, a hybrid model which interleaves MLPs with a mixture of recurrent blocks and local attention (Beltagy et al., 2020). We show that:
-
Hawk and Griffin exhibit power law scaling between held-out loss and training FLOPs, up to and beyond 7B parameters (Figure 1(a)), as previously observed for Transformers (Kaplan et al., 2020).
-
Griffin achieves slightly lower held-out loss than strong Transformer baselines at all model scales.
-
We overtrain Hawk and Griffin on 300B tokens at a range of model scales. Hawk-3B exceeds the reported performance of Mamba-3B (Gu and Dao, 2023) on downstream tasks, despite being trained on half as many tokens. Griffin-7B and Griffin-14B match the performance of Llama-2 (Touvron et al., 2023) despite being trained on roughly 7 times fewer tokens (Section 3.2).
-
Both Hawk and Griffin achieve training efficiency comparable to that of Transformers on TPU-v3. Since diagonal RNN layers are memory bound, we achieve this with a kernel for the RG-LRU layer, implemented in Pallas (Bradbury et al., 2018), that minimizes memory transfers (Section 4).
-
During inference, both Hawk and Griffin achieve significantly higher throughput than MQA Transformers (Figure 1(b)), and they achieve lower latency when sampling long sequences (Section 5).
-
Griffin performs better than Transformers when evaluated on sequences longer than those seen during training, and can also efficiently learn copying and retrieval tasks from training data (Section 6). However, Hawk and Griffin perform less well than Transformers when we evaluate pre-trained models on copying and exact-retrieval tasks without fine-tuning.
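The recurrent blocks described above compress the sequence into a state of fixed size. That property can be illustrated with a generic diagonal gated linear recurrence. This is a minimal sketch, not the RG-LRU itself: the paper defines the RG-LRU's exact gates and parameterization elsewhere, and the sigmoid input gate and the parameter names `w_gate` and `b_gate` here are hypothetical choices made for illustration. What it shows is that each decode step reads and writes a state of constant size D, and that h_{t-1} enters the update linearly, with the gate depending only on the current input.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def recurrence_step(h, x, w_gate, b_gate):
    """One decode step of a generic diagonal gated linear recurrence.

    a_t = sigmoid(w * x_t + b) is computed per channel from the input only,
    and the new state is h_t = a_t * h_{t-1} + (1 - a_t) * x_t.
    The gate form and the parameters w_gate, b_gate are hypothetical.
    """
    a = [sigmoid(w * xi + b) for w, xi, b in zip(w_gate, x, b_gate)]
    return [ai * hi + (1.0 - ai) * xi for ai, hi, xi in zip(a, h, x)]


def run_recurrence(xs, w_gate, b_gate):
    """Scan the step over a sequence; xs is a list of D-dimensional inputs."""
    h = [0.0] * len(w_gate)  # the state has D entries no matter how long xs is
    states = []
    for x in xs:
        h = recurrence_step(h, x, w_gate, b_gate)
        states.append(h)
    return states


if __name__ == "__main__":
    xs = [[1.0, -2.0], [0.5, 0.0], [2.0, 1.0]]
    for t, h in enumerate(run_recurrence(xs, w_gate=[0.0, 1.0], b_gate=[0.0, 0.5])):
        print(t, [round(v, 4) for v in h])
```

Because a_t depends only on x_t, the map from h_{t-1} to h_t is affine, so consecutive steps compose into another affine map and the sequence can be evaluated with either a sequential or an associative (parallel) scan.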
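Local attention (Beltagy et al., 2020), which Griffin mixes with recurrent blocks, restricts each token to a sliding window of recent positions. The sketch below builds such a causal sliding-window mask in plain Python; the window size is a free parameter here and is not the value the paper uses. Because no query looks further back than `window` tokens, an inference-time cache for this layer only needs the most recent `window` keys and values, so its memory stays bounded regardless of sequence length.

```python
def local_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[q][k] is True when query position q may attend to key position k.

    Keys must be causal (k <= q) and within the most recent `window` positions
    (q - k < window), so each query sees at most `window` keys.
    """
    return [[k <= q and q - k < window for k in range(seq_len)]
            for q in range(seq_len)]


if __name__ == "__main__":
    # Print the mask as a band: "x" marks an allowed query/key pair.
    for row in local_causal_mask(seq_len=6, window=3):
        print("".join("x" if allowed else "." for allowed in row))
```

The printed pattern is a diagonal band of width 3: early queries see fewer keys, and every later query sees exactly the last three positions.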
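The power-law claim in the first bullet means held-out loss L and training compute C are related approximately by L ≈ A · C^(−α), which is a straight line in log-log space. The sketch below recovers A and α by ordinary least squares on log-transformed values. The data points are synthetic, generated from a known law (A = 10, α = 0.05) purely to check the fit; they are not measurements from the paper, and practical scaling-law fits often add an irreducible-loss term that this simplified form omits.

```python
import math


def fit_power_law(flops, losses):
    """Fit loss = A * flops**(-alpha) via least squares on log(loss) vs log(flops)."""
    xs = [math.log(c) for c in flops]
    ys = [math.log(loss) for loss in losses]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var                    # equals -alpha for a pure power law
    intercept = mean_y - slope * mean_x  # equals log(A)
    return math.exp(intercept), -slope


if __name__ == "__main__":
    # Synthetic, noise-free points from a known law; not results from the paper.
    true_a, true_alpha = 10.0, 0.05
    flops = [1e18, 1e19, 1e20, 1e21, 1e22]
    losses = [true_a * c ** (-true_alpha) for c in flops]
    a, alpha = fit_power_law(flops, losses)
    print(f"A = {a:.4f}, alpha = {alpha:.4f}")
```

Fitting in log space turns the power law into a linear regression, so a single closed-form slope and intercept suffice; with noisy real measurements the same code gives a least-squares estimate rather than an exact recovery.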
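The abstract's "over 6 times fewer tokens" and the third bullet's "roughly 7 times fewer tokens" describe the same ratio. Assuming Llama-2's reported pre-training budget of 2 trillion tokens (Touvron et al., 2023) against the 300B tokens used to train Griffin here, the arithmetic is:

$$
\frac{2\,\text{T tokens (Llama-2)}}{300\,\text{B tokens (Griffin)}} = \frac{2000}{300} \approx 6.67
$$

A ratio of about 6.67 is both "over 6" and, rounded, "roughly 7".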
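The claim that diagonal RNN layers are memory bound follows from a simple arithmetic-intensity estimate (FLOPs performed per byte moved to or from memory). The sketch below compares an elementwise recurrence update with a dense matrix multiply under a deliberately simple cost model. For the recurrence h_t = a_t * h_{t-1} + b_t (with the gated input b_t precomputed), it assumes a_t and b_t are loaded and h_t is stored for each channel and step while h_{t-1} stays on-chip; for the matmul, it assumes each operand and the output are moved exactly once. The fp32 and bf16 element sizes are assumed defaults. None of this is a measurement of the paper's Pallas kernel.

```python
def recurrence_intensity(bytes_per_elem: int = 4) -> float:
    """FLOPs per byte for h_t = a_t * h_{t-1} + b_t, per channel and time step.

    Assumes a_t and b_t are loaded, h_t is stored, and h_{t-1} stays on-chip:
    2 FLOPs (one multiply, one add) against 3 elements of memory traffic.
    """
    flops = 2
    bytes_moved = 3 * bytes_per_elem
    return flops / bytes_moved


def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for an [m, k] @ [k, n] matmul, each operand moved once."""
    flops = 2 * m * k * n
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved


if __name__ == "__main__":
    print(f"recurrence (fp32): {recurrence_intensity():.3f} FLOPs/byte")
    print(f"matmul 4096^3 (bf16): {matmul_intensity(4096, 4096, 4096):.1f} FLOPs/byte")
```

Under this model the recurrence performs well under one FLOP per byte, whereas a large matmul performs over a thousand, and modern accelerators generally need on the order of a hundred or more FLOPs per byte to be compute bound. That gap is why reducing memory traffic, rather than arithmetic, is the main lever for a fast recurrence kernel.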
This paper is available on arxiv under CC BY 4.0 DEED license.