Authors:
(1) Soham De, Google DeepMind, with equal contributions;
(2) Samuel L. Smith, Google DeepMind, with equal contributions;
(3) Anushan Fernando, Google DeepMind, with equal contributions;
(4) Aleksandar Botev, Google DeepMind, with equal contributions;
(5) George Cristian-Muraru, Google DeepMind, with equal contributions;
(6) Albert Gu, work done while at Google DeepMind;
(7) Ruba Haroun, Google DeepMind;
(8) Leonard Berrada, Google DeepMind;
(9) Yutian Chen, Google DeepMind;
(10) Srivatsan Srinivasan, Google DeepMind;
(11) Guillaume Desjardins, Google DeepMind;
(12) Arnaud Doucet, Google DeepMind;
(13) David Budden, Google DeepMind;
(14) Yee Whye Teh, Google DeepMind;
(15) Razvan Pascanu, Google DeepMind;
(16) Nando De Freitas, Google DeepMind;
(17) Caglar Gulcehre, Google DeepMind.
Table of Links
3 Recurrent Models Scale as Efficiently as Transformers
3.2. Evaluation on downstream tasks
4.2. Efficient linear recurrences on device
4.3. Training speed on longer sequences
5.1. A simple model of the decode step
6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts
6.2. Copy and retrieval capabilities
8. Conclusion, Acknowledgements, and References
B. Complex-Gated Linear Recurrent Unit (CG-LRU)
C. Model Scale Hyper-Parameters
D. Efficient Linear Recurrences on Device
E. The Local Attention Window Size of Griffin
G. Improving Next Token Prediction with Longer Contexts: Additional Results
H. Additional Details of the Copy and Retrieval Tasks
7. Related Works
The Transformer architecture has become a more scalable alternative to RNNs because its training can be fully parallelized across the sequence. Classical RNNs, by contrast, process tokens sequentially and therefore suffer from slow training during both forward and backward propagation (Werbos, 1990). To mitigate this issue, researchers have explored alternative RNN-based methods. Notable examples include Quasi-RNNs (Bradbury et al., 2016), which combine convolutions and linear RNNs for greater parallelization, and the use of input-based gating mechanisms to parallelize linear RNN training (Martin and Cundy, 2017), as sketched below.
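To illustrate why linear recurrences admit parallel training, the JAX sketch below (our own illustrative code, not taken from any of the cited works; the gate values `a` and inputs `b` are placeholder random arrays) computes the gated recurrence h_t = a_t * h_{t-1} + b_t both with a sequential loop and with an associative scan.

```python
# A minimal sketch of why *linear* recurrences with input-dependent gates
# (Martin and Cundy, 2017) can be trained in parallel: each step is an affine
# map h -> a_t * h + b_t, and affine maps compose associatively, so the whole
# sequence can be evaluated with a parallel (associative) scan.
import jax
import jax.numpy as jnp


def sequential_recurrence(a, b):
    """Classical O(T) sequential loop: h_t = a_t * h_{t-1} + b_t, with h_0 = 0."""
    def step(h, inputs):
        a_t, b_t = inputs
        h = a_t * h + b_t
        return h, h
    _, hs = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return hs


def parallel_recurrence(a, b):
    """Same recurrence via an associative scan (O(log T) depth on parallel hardware)."""
    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        # Compose h -> a_l*h + b_l followed by h -> a_r*h + b_r.
        return a_r * a_l, a_r * b_l + b_r
    _, hs = jax.lax.associative_scan(combine, (a, b))
    return hs


# Both formulations give the same hidden states (up to numerical precision).
a = jax.nn.sigmoid(jax.random.normal(jax.random.PRNGKey(0), (128, 16)))  # gates in (0, 1)
b = jax.random.normal(jax.random.PRNGKey(1), (128, 16))                  # gated inputs
assert jnp.allclose(sequential_recurrence(a, b), parallel_recurrence(a, b), atol=1e-4)
```

The associative scan reduces the sequential depth from O(T) to O(log T), which is the property that makes training linear RNNs competitive with Transformers in wall-clock time.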
State-space models (SSMs) have recently emerged as a powerful tool for modeling long input sequences, demonstrating strong performance on the long-range arena benchmark (Tay et al., 2020) and on audio generation (Goel et al., 2022). SSMs combine concepts from classical state-space models (Kalman, 1960) with those of RNNs. Their reliance on linear recurrences allows hidden states to be computed efficiently, either through parallel scan operations or through convolutions (both views are illustrated in the sketch after this paragraph), resulting in training speeds comparable to Transformer models. The S4 model (Gu et al., 2021a) proposed a sophisticated normal-plus-low-rank parameterization to diagonalize the recurrence computation. S4D (Gu et al., 2022) parametrized the SSM directly with a diagonal state matrix and showed that it performs just as well while being much simpler. S5 (Smith et al., 2022) also diagonalized the recurrence and showed that it can be computed with an associative scan. The H3 model (Dao et al., 2022b) generalizes the recurrent interpretation of linear attention (Katharopoulos et al., 2020). Hyena (Poli et al., 2023) uses a similar architecture but replaces the S4D layer with a global convolution kernel parametrized by an MLP. RetNet (Sun et al., 2023) uses a simpler SSM design with a gating mechanism that allows the computation to be parallelized using a variant of multi-head attention. Orvieto et al. (2023b) systematically analyzed and ablated multiple modifications to standard RNNs, showing that with better parameterization and initialization a simplified linear RNN (the LRU) performs just as well as other SSM variants on various long-range tasks. RWKV (Peng et al., 2023) is a recent RNN shown to be competitive on language modeling tasks, based on another linear attention approximation inspired by the attention-free Transformer (Zhai et al., 2021). Concurrently with our work, Gu and Dao (2023) developed an SSM architecture called Mamba with an input-dependent selection mechanism and showed that it achieves performance comparable to Transformers with efficient inference. Several extensions of Mamba have been proposed for different applications (Wang et al., 2024; Zhu et al., 2024). An input-dependent gating mechanism similar to Mamba's was also proposed in GateLoop (Katsch, 2023).
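To make the scan/convolution duality concrete, the sketch below (illustrative code under simplifying assumptions: a time-invariant real diagonal state matrix and a single input/output channel, not the parameterization of any specific model above) computes the same SSM outputs once with a recurrent scan and once as a causal convolution with the kernel K_s = C A^s B. The convolutional view applies only when the recurrence is time-invariant; input-gated recurrences instead rely on the scan formulation.

```python
# A minimal sketch of the two equivalent ways a linear time-invariant SSM layer
# can compute its outputs: a recurrent scan over h_t = A h_{t-1} + B x_t, or a
# causal convolution of the input with the kernel K_s = C A^s B (diagonal A).
import jax
import jax.numpy as jnp

N, T = 8, 64   # state size, sequence length

A = jnp.exp(-jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (N,))))  # stable diagonal in (0, 1]
B = jax.random.normal(jax.random.PRNGKey(1), (N,))
C = jax.random.normal(jax.random.PRNGKey(2), (N,))
x = jax.random.normal(jax.random.PRNGKey(3), (T,))


def ssm_recurrent(A, B, C, x):
    """h_t = A * h_{t-1} + B * x_t (diagonal A applied elementwise), y_t = <C, h_t>."""
    def step(h, x_t):
        h = A * h + B * x_t
        return h, jnp.dot(C, h)
    _, y = jax.lax.scan(step, jnp.zeros(N), x)
    return y


def ssm_convolutional(A, B, C, x):
    """Materialize the kernel K_s = <C, A^s * B> and apply a causal convolution."""
    powers = A[None, :] ** jnp.arange(T)[:, None]   # (T, N): entry [s, n] = A[n]^s
    K = (powers * B[None, :]) @ C                   # (T,): K_s = sum_n C[n] A[n]^s B[n]
    return jnp.convolve(x, K, mode="full")[:T]      # keep the first T (causal) outputs


assert jnp.allclose(ssm_recurrent(A, B, C, x), ssm_convolutional(A, B, C, x), atol=1e-4)
```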
Linear attention (Katharopoulos et al., 2020) offers a computationally efficient approximation of self-attention by linearizing the attention computation, which can then be evaluated recurrently as a linear RNN (sketched at the end of this section). While this significantly reduces the computational cost compared to full attention, it often comes with a trade-off in model quality. Flash Attention (Dao et al., 2022a) improves the training speed of attention on GPUs by making efficient use of the memory hierarchy. Another increasingly popular approach to reducing the computational cost of global attention is to use sparse local attention (Child et al., 2019) or sliding window attention (Jiang et al., 2023).
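The recurrent view of linear attention fits in a few lines. The sketch below (illustrative only, using the commonly cited elu(x) + 1 feature map rather than the kernel of any particular model) checks that the quadratic-time masked form and the linear-time recurrent form produce the same outputs; the recurrent form keeps a fixed-size state rather than a growing KV cache.

```python
# A minimal sketch of causal linear attention: replacing softmax(q k^T) with the
# kernel phi(q)^T phi(k) lets the output be computed recurrently, accumulating
# the outer products phi(k_t) v_t^T in a constant-size state like a linear RNN.
import jax
import jax.numpy as jnp


def feature_map(x):
    return jax.nn.elu(x) + 1.0   # positive features keep the normalizer well defined


def linear_attention_parallel(q, k, v):
    """Quadratic-time form: causally masked attention with the linearized kernel."""
    phi_q, phi_k = feature_map(q), feature_map(k)
    scores = jnp.tril(phi_q @ phi_k.T)                            # (T, T), causal mask
    return (scores @ v) / (scores.sum(axis=-1, keepdims=True) + 1e-6)


def linear_attention_recurrent(q, k, v):
    """Equivalent linear-time form: a linear RNN over key-value outer products."""
    phi_q, phi_k = feature_map(q), feature_map(k)
    d_k, d_v = phi_k.shape[-1], v.shape[-1]

    def step(carry, inputs):
        S, z = carry                                              # S: (d_k, d_v), z: (d_k,)
        phi_q_t, phi_k_t, v_t = inputs
        S = S + jnp.outer(phi_k_t, v_t)                           # accumulate key-value memory
        z = z + phi_k_t                                           # accumulate normalizer
        y_t = (phi_q_t @ S) / (phi_q_t @ z + 1e-6)
        return (S, z), y_t

    init = (jnp.zeros((d_k, d_v)), jnp.zeros(d_k))
    _, y = jax.lax.scan(step, init, (phi_q, phi_k, v))
    return y


T, d = 32, 16
q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (T, d)) for i in range(3))
assert jnp.allclose(linear_attention_parallel(q, k, v),
                    linear_attention_recurrent(q, k, v), atol=1e-4)
```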
This paper is available on arXiv under a CC BY 4.0 DEED license.