The Stability-Efficiency Dilemma: Investigating Sequence Length Warmup for Training GPT Models

  • Conglong Li,
  • Minjia Zhang,
  • Yuxiong He

2022 Neural Information Processing Systems

Recent works have demonstrated great success in pre-training large-scale autoregressive language models on massive GPUs. To reduce the wall-clock training time, a common practice is to increase the batch size and learning rate. However, this practice is often brittle and leads to a so-called stability-efficiency dilemma: increasing the batch size and learning rate improves training efficiency but can also cause training instability, leading to poor generalization accuracy or failed runs. To better understand this phenomenon, we conduct an in-depth analysis of large-scale pre-training experiments replicating the GPT-2 model. We find that there is a strong correlation between training instability and extreme values of gradient variance, and that samples with long sequence lengths contribute to these extreme gradient variance values, especially at the beginning of training, indicating that long sequence lengths can be a main source of training instability. Based on this analysis, we present a Sequence Length Warmup method that aims to solve the training stability-efficiency dilemma. Experiments replicating GPT-2 models show that our approach enables stable training with an 8x larger batch size and a 4x larger learning rate, whereas the baseline approach struggles with training instability. To achieve the same or better zero-shot evaluation results, our method reduces the required number of training tokens and wall-clock time by up to 2.2x and 3.7x, respectively. Experiments replicating the GPT-3 model (125M) show that our approach enables stable training with an 8x larger batch size and a 40x larger learning rate, and retains 99% of the zero-shot accuracy on 11 tasks using 10x less data and 17x less time compared to the original GPT-3 training recipe, while the baseline diverges under the same settings and retains only 95% of the accuracy under a lower learning rate.
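
The abstract names Sequence Length Warmup without giving its schedule, so the following is only a minimal sketch of one plausible reading: start training on short sequences and grow the length toward the full context early in training. The linear schedule, the rounding to a multiple of 8, truncation as the shortening mechanism, and every name here (`seqlen_at_step`, `truncate_batch`, `warmup_steps`, `min_len`, `max_len`, `multiple`) are assumptions made for illustration, not details taken from the paper.

```python
def seqlen_at_step(step, warmup_steps, min_len, max_len, multiple=8):
    """Return the sequence length to use at a given training step.

    Assumed schedule: linear growth from min_len to max_len over
    warmup_steps, rounded down to a multiple of `multiple`.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return max_len
    frac = step / warmup_steps
    length = min_len + frac * (max_len - min_len)
    # Round down so the length stays hardware-friendly (assumption).
    length = int(length) // multiple * multiple
    # Clamp in case rounding pushed the value below min_len.
    return max(min_len, min(max_len, length))


def truncate_batch(batch, seqlen):
    """Truncate each token sequence in the batch to at most seqlen tokens."""
    return [seq[:seqlen] for seq in batch]


if __name__ == "__main__":
    # Hypothetical settings: warm up from 64 to 1024 tokens over 100 steps.
    for step in (0, 25, 50, 100):
        print(step, seqlen_at_step(step, 100, 64, 1024))
```

In a training loop, each batch would be passed through `truncate_batch` with the length returned for the current step, so that early updates see only short sequences while later updates see the full context.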