Recurrent Models Scale as Efficiently as Transformers

13 Jan 2025

Authors:

(1) Soham De, Google DeepMind, equal contribution;

(2) Samuel L. Smith, Google DeepMind, equal contribution;

(3) Anushan Fernando, Google DeepMind, equal contribution;

(4) Aleksandar Botev, Google DeepMind, equal contribution;

(5) George Cristian-Muraru, Google DeepMind, equal contribution;

(6) Albert Gu, Work done while at Google DeepMind;

(7) Ruba Haroun, Google DeepMind;

(8) Leonard Berrada, Google DeepMind;

(9) Yutian Chen, Google DeepMind;

(10) Srivatsan Srinivasan, Google DeepMind;

(11) Guillaume Desjardins, Google DeepMind;

(12) Arnaud Doucet, Google DeepMind;

(13) David Budden, Google DeepMind;

(14) Yee Whye Teh, Google DeepMind;

(15) Razvan Pascanu, Google DeepMind;

(16) Nando De Freitas, Google DeepMind;

(17) Caglar Gulcehre, Google DeepMind.

1. Introduction

2. Model Architecture

3. Recurrent Models Scale as Efficiently as Transformers

3.1. Scaling curves

3.2. Evaluation on downstream tasks

4. Training Recurrent Models Efficiently on Device and 4.1. Model parallelism for large scale training

4.2. Efficient linear recurrences on device

4.3. Training speed on longer sequences

5. Inference Speed

5.1. A simple model of the decode step

5.2. Results

6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts

6.2. Copy and retrieval capabilities

7. Related Works

8. Conclusion, Acknowledgements, and References

A. RG-LRU Recurrence Gate

B. Complex-Gated Linear Recurrent Unit (CG-LRU)

C. Model Scale Hyper-Parameters

D. Efficient Linear Recurrences on Device

E. The Local Attention Window Size of Griffin

F. Inference Speeds

G. Improving Next Token Prediction with Longer Contexts: Additional Results

H. Additional Details of the Copy and Retrieval Tasks

3. Recurrent Models Scale as Efficiently as Transformers

Scaling studies provide important insights into how to tune a model's hyperparameters and how the model behaves at scale. Here, we define the models evaluated in our studies and provide scaling curves up to and beyond 7B parameters. Finally, we assess the performance of our models on downstream tasks. We consider three model families in this work: (1) an MQA Transformer baseline, (2) Hawk, our pure RNN model, and (3) Griffin, our hybrid model, which mixes recurrent blocks with local attention. We define the key model hyperparameters for models across a range of scales in Appendix C.

MQA Transformer baseline. Our Transformer baseline uses the residual pattern and the gated MLP block described in Section 2, in combination with MQA (Shazeer, 2019) and RoPE (Su et al., 2021).
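RoPE injects position information by rotating query and key vectors before their dot product, so that attention scores depend only on the relative offset between positions. The sketch below is a minimal pure-Python illustration, assuming the adjacent-pair convention and frequency base 10000 from Su et al. (2021); many implementations instead pair dimension i with i + d/2, and real code vectorizes this. The helper names `rope` and `dot` are hypothetical and not taken from the paper.

```python
import math
from typing import List


def rope(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    # Rotate each pair (x[2i], x[2i+1]) by angle pos * base**(-2i/d).
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even head dimension")
    out = [0.0] * d
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u: List[float], v: List[float]) -> float:
    # Plain dot product, standing in for a single query-key attention score.
    return sum(a * b for a, b in zip(u, v))


if __name__ == "__main__":
    q = [0.3, -1.2, 0.7, 0.5]
    k = [1.0, 0.4, -0.6, 0.9]
    # Same offset (query 2 positions after key) at different absolute positions.
    print(dot(rope(q, 5), rope(k, 3)), dot(rope(q, 12), rope(k, 10)))
```

The two printed scores agree up to floating-point error, because rotating by m·θ and n·θ leaves only the difference (n − m)·θ in the product.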

Griffin. The key advantage of recurrent blocks over global attention is that they use a fixed-size state to summarize the sequence, whereas the size of MQA's KV cache grows linearly with sequence length. Since local attention (Section 2.3) also keeps a bounded cache, which never exceeds the attention window, mixing recurrent blocks with local attention preserves this benefit. We have found this combination extremely effective, since local attention accurately models the recent past, while the recurrent layers can transmit information across long sequences.
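To make the memory argument concrete, the sketch below counts how many scalars each kind of temporal-mixing layer must keep per sequence during decoding. It is a simplified model under stated assumptions: one layer, MQA with a single shared key head and value head, local attention using the 1024-token window mentioned in the text, and the recurrent block's state treated as one fixed-size vector, ignoring any other internal buffers. The function names and the dimensions `head_dim = 128` and `state_width = 2560` are hypothetical, chosen only for illustration.

```python
# Hypothetical per-layer cache sizes, counted in stored scalars (not bytes).


def mqa_kv_cache_size(seq_len: int, head_dim: int) -> int:
    # Global MQA stores one shared key and one shared value vector for every
    # past token, so the cache grows linearly with seq_len.
    return 2 * seq_len * head_dim


def local_attention_cache_size(seq_len: int, head_dim: int, window: int = 1024) -> int:
    # Local attention only keeps keys/values for the most recent `window`
    # tokens, so the cache stops growing once seq_len exceeds the window.
    return 2 * min(seq_len, window) * head_dim


def recurrent_state_size(seq_len: int, state_width: int) -> int:
    # A recurrent block summarizes the whole prefix in a fixed-size state;
    # seq_len is accepted only to make the contrast explicit.
    del seq_len
    return state_width


if __name__ == "__main__":
    head_dim, state_width = 128, 2560  # illustrative values, not from the paper
    for seq_len in (512, 4096, 32768):
        print(
            seq_len,
            mqa_kv_cache_size(seq_len, head_dim),
            local_attention_cache_size(seq_len, head_dim),
            recurrent_state_size(seq_len, state_width),
        )
```

Only the first column of sizes keeps growing with sequence length; the local-attention cache plateaus at the window and the recurrent state never changes, which is the property Griffin relies on.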

Griffin uses the same residual pattern and MLP block as our Transformer baseline. However, unlike both our MQA Transformer baseline and the Hawk model, Griffin uses a mixture of recurrent blocks and MQA blocks. Specifically, we use a layered structure that repeats a pattern of two residual blocks containing a recurrent block, followed by one residual block containing the local (MQA) attention block described in Section 2.3. Unless otherwise stated, the local attention window size is fixed to 1024 tokens.
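The layer layout and the local attention window can be sketched as follows. This is an illustrative sketch, not the authors' code: it assumes the repeating unit starts with the two recurrent blocks, and it interprets a window of size w as "the current token plus the previous w − 1 tokens", which is one common convention. The helpers `griffin_block_pattern` and `local_causal_mask` are hypothetical names.

```python
from typing import List


def griffin_block_pattern(num_layers: int) -> List[str]:
    # Repeat (recurrent, recurrent, local attention): two residual blocks with
    # a recurrent block, then one residual block with local MQA attention.
    pattern = ("recurrent", "recurrent", "local_attention")
    return [pattern[i % len(pattern)] for i in range(num_layers)]


def local_causal_mask(seq_len: int, window: int = 1024) -> List[List[bool]]:
    # mask[i][j] is True when query position i may attend to key position j:
    # causal (j <= i) and within the window (assumed convention i - j < window).
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    print(griffin_block_pattern(6))
    # Small window so the banded structure is visible: "x" = allowed, "." = masked.
    for row in local_causal_mask(6, window=3):
        print("".join("x" if allowed else "." for allowed in row))
```

The mask is a band below the diagonal: each query sees at most `window` keys, which is why the local-attention cache stays bounded no matter how long the sequence gets.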

This paper is available on arXiv under the CC BY 4.0 DEED license.


[2] Note that we match parameters with an MHA block, although our Transformer baseline and Griffin ultimately rely on MQA to improve inference efficiency. This means that our recurrent blocks have slightly more parameters than the corresponding MQA blocks.
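Why an MHA-matched budget exceeds an MQA block can be seen by counting attention projection weights. The sketch below is a simplified count under stated assumptions: biases are ignored, the head dimension is d_model / num_heads, and only the Q, K, V and output projections are counted. It shows the direction of the gap the footnote describes; how large that gap is relative to a full residual block depends on the MLP and other components not modeled here. The function names and the example sizes are hypothetical.

```python
def mha_projection_params(d_model: int) -> int:
    # Multi-head attention: Q, K, V and output projections are each
    # d_model x d_model (biases ignored).
    return 4 * d_model * d_model


def mqa_projection_params(d_model: int, num_heads: int) -> int:
    # Multi-query attention: Q and output projections stay d_model x d_model,
    # but K and V each project to a single shared head of size head_dim.
    head_dim = d_model // num_heads
    return 2 * d_model * d_model + 2 * d_model * head_dim


if __name__ == "__main__":
    d_model, num_heads = 1024, 8  # illustrative values, not from the paper
    print(mha_projection_params(d_model), mqa_projection_params(d_model, num_heads))
```

With a single head the two counts coincide, and every additional head shrinks the MQA key and value projections, so a block sized to match MHA always carries at least as many parameters as the MQA block it is compared against.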