Efficient Training: Scaling Griffin Models for Large-Scale AI on TPUs

14 Jan 2025

Authors:

(1) Soham De, Google DeepMind (equal contribution);

(2) Samuel L. Smith, Google DeepMind (equal contribution);

(3) Anushan Fernando, Google DeepMind (equal contribution);

(4) Aleksandar Botev, Google DeepMind (equal contribution);

(5) George Cristian-Muraru, Google DeepMind (equal contribution);

(6) Albert Gu, Work done while at Google DeepMind;

(7) Ruba Haroun, Google DeepMind;

(8) Leonard Berrada, Google DeepMind;

(9) Yutian Chen, Google DeepMind;

(10) Srivatsan Srinivasan, Google DeepMind;

(11) Guillaume Desjardins, Google DeepMind;

(12) Arnaud Doucet, Google DeepMind;

(13) David Budden, Google DeepMind;

(14) Yee Whye Teh, Google DeepMind;

(15) Razvan Pascanu, Google DeepMind;

(16) Nando De Freitas, Google DeepMind;

(17) Caglar Gulcehre, Google DeepMind.

1. Introduction

2. Model Architecture

3. Recurrent Models Scale as Efficiently as Transformers

3.1. Scaling curves

3.2. Evaluation on downstream tasks

4. Training Recurrent Models Efficiently on Device and 4.1. Model parallelism for large scale training

4.2. Efficient linear recurrences on device

4.3. Training speed on longer sequences

5. Inference Speed

5.1. A simple model of the decode step

5.2. Results

6. Long Context Modeling and 6.1. Improving next token prediction with longer contexts

6.2. Copy and retrieval capabilities

7. Related Works

8. Conclusion, Acknowledgements, and References

A. RG-LRU Recurrence Gate

B. Complex-Gated Linear Recurrent Unit (CG-LRU)

C. Model Scale Hyper-Parameters

D. Efficient Linear Recurrences on Device

E. The Local Attention Window Size of Griffin

F. Inference Speeds

G. Improving Next Token Prediction with Longer Contexts: Additional Results

H. Additional Details of the Copy and Retrieval Tasks

4. Training Recurrent Models Efficiently on Device

We encountered two main engineering challenges when developing and scaling our models: first, how to efficiently shard our models across multiple devices; second, how to efficiently implement linear recurrences to maximize training efficiency on TPUs. We address both challenges in this section, before providing an empirical comparison of the training speed of Griffin and our MQA baseline.

4.1. Model parallelism for large scale training

As our models grow in size, they no longer fit on a single device during training, even with a batch size of 1 per device. We therefore use model parallelism to shard our large models across devices during training. Since communication between training devices is expensive, efficient sharding is critical for fast training at scale.

MLP and MQA block. For our gated-MLP block, we use Megatron-style sharding (Shoeybi et al., 2019), which requires a single all-reduce operation in both the forward and the backward pass. We apply the same strategy to the linear layers in the attention block, and additionally shard the attention mechanism over its heads (Narayanan et al., 2021).
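Why a single all-reduce suffices for the gated MLP can be checked numerically with a small pure-Python simulation. This is a sketch under stated assumptions. It assumes the gated-MLP form GeLU(x W_gate) multiplied elementwise by (x W_up), followed by an output projection W_down. The helper names (`column_shard`, `row_shard`, `all_reduce_sum`, `sharded_gated_mlp`) are hypothetical, and "devices" are simulated as list entries rather than real accelerators. The two input projections are split by columns and the output projection by rows. Each simulated device therefore produces a partial output, and one summation, standing in for the all-reduce, recovers the unsharded result.

```python
import math
import random


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix; matrices are lists of rows."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def gelu(x):
    """Exact GeLU via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gated_mlp(x, w_gate, w_up, w_down):
    """Assumed gated MLP: (GeLU(x @ w_gate) * (x @ w_up)) @ w_down."""
    gate = matmul(x, w_gate)
    up = matmul(x, w_up)
    hidden = [[gelu(g) * u for g, u in zip(g_row, u_row)] for g_row, u_row in zip(gate, up)]
    return matmul(hidden, w_down)


def column_shard(w, n):
    """Split the columns (hidden dimension) of w into n contiguous shards."""
    assert len(w[0]) % n == 0, "hidden size must divide evenly across devices"
    cols = len(w[0]) // n
    return [[row[s * cols:(s + 1) * cols] for row in w] for s in range(n)]


def row_shard(w, n):
    """Split the rows (hidden dimension) of w into n contiguous shards."""
    assert len(w) % n == 0, "hidden size must divide evenly across devices"
    rows = len(w) // n
    return [w[s * rows:(s + 1) * rows] for s in range(n)]


def all_reduce_sum(partials):
    """Stand-in for the cross-device all-reduce: elementwise sum of per-device outputs."""
    out = [[0.0] * len(partials[0][0]) for _ in partials[0]]
    for partial in partials:
        for i, row in enumerate(partial):
            for j, value in enumerate(row):
                out[i][j] += value
    return out


def sharded_gated_mlp(x, w_gate, w_up, w_down, n_devices):
    """Each device runs the MLP on its weight shards with no communication;
    the only collective in the forward pass is the final all-reduce."""
    gates = column_shard(w_gate, n_devices)
    ups = column_shard(w_up, n_devices)
    downs = row_shard(w_down, n_devices)
    partials = [gated_mlp(x, g, u, d) for g, u, d in zip(gates, ups, downs)]
    return all_reduce_sum(partials)


# Usage: compare the sharded computation against the unsharded reference.
random.seed(0)


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


d_model, d_hidden, n_devices = 4, 8, 4
x = random_matrix(3, d_model)  # 3 tokens; the input is replicated on every device
w_gate = random_matrix(d_model, d_hidden)
w_up = random_matrix(d_model, d_hidden)
w_down = random_matrix(d_hidden, d_model)

reference = gated_mlp(x, w_gate, w_up, w_down)
sharded = sharded_gated_mlp(x, w_gate, w_up, w_down, n_devices)
max_abs_err = max(
    abs(r - s) for r_row, s_row in zip(reference, sharded) for r, s in zip(r_row, s_row)
)
```

The gating is elementwise along the hidden dimension, so the devices do not need to communicate between the input and output projections. In the backward pass, the matching all-reduce is applied to the gradient with respect to the replicated input x.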

Other considerations. Optimizer states can consume significant memory, exceeding the size of the model parameters themselves. To address this, we employ ZeRO parallelism (Rajbhandari et al., 2020), distributing both optimizer states and model parameters across the batch shards. We also use a bfloat16 representation for model parameters and activations, which reduces data transfer overhead.
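A back-of-the-envelope memory estimate shows why sharding optimizer states matters. The byte counts below are assumptions not given in the text: bfloat16 parameters at 2 bytes each, and an Adam-style optimizer that holds two float32 moments, or 8 bytes per parameter. The estimate ignores model parallelism, activations, and any float32 master weights. The 1B-parameter model and the 8 batch shards in the usage lines are hypothetical.

```python
from dataclasses import dataclass

# Assumed byte sizes (hypothetical; not stated in the paper).
BF16_BYTES = 2    # bfloat16 parameters
FP32_BYTES = 4    # float32 optimizer moments
ADAM_MOMENTS = 2  # first and second moment per parameter


@dataclass(frozen=True)
class MemoryEstimate:
    param_bytes: float
    optimizer_bytes: float

    @property
    def total_bytes(self) -> float:
        return self.param_bytes + self.optimizer_bytes


def per_device_memory(n_params: int, n_shards: int, zero: bool) -> MemoryEstimate:
    """Bytes per device for parameters plus optimizer state."""
    param_bytes = n_params * BF16_BYTES
    optimizer_bytes = n_params * ADAM_MOMENTS * FP32_BYTES
    if zero:
        # ZeRO-style sharding: each batch shard stores only 1/n_shards of
        # both the parameters and the optimizer state.
        param_bytes /= n_shards
        optimizer_bytes /= n_shards
    return MemoryEstimate(param_bytes, optimizer_bytes)


# Usage with a hypothetical 1B-parameter model and 8 batch shards.
replicated = per_device_memory(1_000_000_000, n_shards=8, zero=False)
sharded = per_device_memory(1_000_000_000, n_shards=8, zero=True)
```

Under these assumptions, the optimizer state takes four times as much memory as the parameters. ZeRO-style sharding over 8 batch shards cuts the per-device total from 10 GB to 1.25 GB.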

This paper is available on arXiv under the CC BY 4.0 DEED license.