ICML'20 Improving Transformer Optimization Through Better Initialization
Authors: Xiao Shi Huang, Felipe Perez, Jimmy Ba, Maksims Volkovs
<a name="intro"/>Introduction
This repository contains a full implementation of the T-Fixup algorithm built on the fairseq library, including both training and evaluation routines for the IWSLT'14 De-En dataset.
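To make the algorithm more concrete, the sketch below computes the T-Fixup initialization constants as we read them from Algorithm 1 of the paper:

- Xavier initialization for all weights except the input embeddings.
- A Gaussian with standard deviation d^(-1/2) for the input embeddings.
- A (9N)^(-1/4) scale for the decoder attention value/output projections, the decoder MLP weights and the input embeddings.
- A 0.67·N^(-1/4) scale for the encoder attention value/output projections and the encoder MLP weights.

This is not the repository's code. The helper names `tfixup_constants`, `xavier_uniform` and `scale_matrix` are hypothetical. The layer count N = 6 is an assumption, and d = 512 is read off the 512-1024-4 model name.

```python
import math
import random

# Hypothetical helpers (not from this repository): a pure-Python sketch of the
# T-Fixup initialization constants, following our reading of Algorithm 1 in the
# paper. Real training code would apply these factors to PyTorch parameters.


def tfixup_constants(num_layers: int, embed_dim: int) -> dict:
    """Return T-Fixup constants for an encoder-decoder with N layers in each stack.

    num_layers: N, the number of layers in the encoder (and in the decoder).
    embed_dim:  d, the embedding dimension.
    """
    if num_layers < 1 or embed_dim < 1:
        raise ValueError("num_layers and embed_dim must be positive")
    return {
        # Input embeddings are drawn from a Gaussian with std d^(-1/2).
        "embed_std": embed_dim ** -0.5,
        # Decoder attention value/output projections, decoder MLP weights and
        # the encoder/decoder input embeddings are scaled by (9N)^(-1/4).
        "decoder_scale": (9 * num_layers) ** -0.25,
        # Encoder attention value/output projections and encoder MLP weights
        # are scaled by 0.67 * N^(-1/4).
        "encoder_scale": 0.67 * num_layers ** -0.25,
    }


def xavier_uniform(fan_out: int, fan_in: int, rng: random.Random) -> list:
    """Xavier (Glorot) uniform init: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)] for _ in range(fan_out)]


def scale_matrix(matrix: list, factor: float) -> list:
    """Multiply every entry of a weight matrix by a T-Fixup scale factor."""
    return [[w * factor for w in row] for row in matrix]


if __name__ == "__main__":
    # N = 6 is an assumed layer count; d = 512 follows the 512-1024-4 naming.
    consts = tfixup_constants(num_layers=6, embed_dim=512)
    rng = random.Random(0)
    # A toy 4x8 "encoder MLP" weight: Xavier init, then the encoder scale.
    w = scale_matrix(xavier_uniform(4, 8, rng), consts["encoder_scale"])
    print(consts)
```

For N = 6, the decoder factor comes out at about 0.37 and the encoder factor at about 0.43. Both shrink as the model gets deeper. This scaling is what lets training proceed without layer normalization or warmup.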
T-Fixup was used by Javier Martin and Andres Torrubia in their 3rd place solution (out of 3395 teams) for the "Riiid Answer Correctness Prediction" Kaggle challenge. See this blog post.
<a name="env"/>Environment
The Python code was developed and tested in the following environment:
- Python 3.7
- PyTorch 1.2.0
Experiments on the IWSLT'14 De-En and En-De datasets were run on an NVIDIA V100 GPU with 32GB of GPU memory; all other experiments were run on an IBM server with 160 POWER9 CPUs, 600GB of RAM and 4 Tesla V100 GPUs.
<a name="dataset"/>Dataset
The example execution script builds the IWSLT'14 De-En dataset; for the WMT'14 En-De and WMT'17 En-De datasets, refer to fairseq's instructions here.
Running The Code
- (Optionally) launch TensorBoard to monitor progress:
tensorboard --logdir=<log_path>
This script runs the small 512-1024-4 Transformer encoder-decoder model (see the paper for details) with both layer normalization and learning rate warmup removed. The starting learning rate is set to the post-warmup value of 0.0005 (vs. 1e-07 with warmup). By default all available GPUs are used, but parameters such as the batch size are set for a single GPU. If multiple GPUs are available, either restrict the script to one GPU or adjust these parameters accordingly.
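The effect of removing warmup can be illustrated with a learning rate schedule. The README does not name the scheduler the script uses. The sketch below assumes an inverse square-root schedule with linear warmup, modeled on the shape of fairseq's `inverse_sqrt` scheduler, and a hypothetical 4000 warmup updates.

Under those assumptions, the two settings differ only in where the ramp starts:

- Standard Transformer training ramps linearly from 1e-07 up to 0.0005.
- The T-Fixup setting starts the ramp at 0.0005 itself, so the model trains at the post-warmup rate from the first update.

The function name `inverse_sqrt_lr` is hypothetical.

```python
import math


def inverse_sqrt_lr(step: int, peak_lr: float = 5e-4, warmup_init_lr: float = 1e-7,
                    warmup_updates: int = 4000) -> float:
    """Learning rate after `step` updates (step >= 1).

    Linear ramp from warmup_init_lr to peak_lr over warmup_updates, then decay
    proportional to 1/sqrt(step). The 4000-update default is an assumption.
    """
    if step < 1 or warmup_updates < 1:
        raise ValueError("step and warmup_updates must be >= 1")
    if step < warmup_updates:
        # Linear warmup. With warmup_init_lr == peak_lr this is a flat line,
        # i.e. effectively no warmup, as in the T-Fixup setting above.
        return warmup_init_lr + (peak_lr - warmup_init_lr) * step / warmup_updates
    # Inverse square-root decay, continuous with the ramp at step == warmup_updates.
    return peak_lr * math.sqrt(warmup_updates / step)


if __name__ == "__main__":
    for step in (1, 1000, 4000, 16000):
        with_warmup = inverse_sqrt_lr(step)                        # starts near 1e-07
        tfixup_style = inverse_sqrt_lr(step, warmup_init_lr=5e-4)  # starts at 5e-4
        print(step, with_warmup, tfixup_style)
```

After the warmup window the two curves coincide. The schedules differ only in the early updates, which is where training without layer normalization would normally rely on warmup for stability.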
Validation Curves
Training and validation loss curves for a Transformer model trained with T-Fixup on the IWSLT'14 De-En dataset during the first 300 epochs. One epoch is around 1100 updates and we checkpoint the model after each epoch.
The BLEU score, evaluated using the average of 10 checkpoints, reaches 35.73 at epochs 278-287.
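Checkpoint averaging takes the element-wise mean of the parameters in several saved checkpoints. The sketch below shows the arithmetic behind the numbers above:

- Epochs 278-287 are 10 consecutive epoch checkpoints.
- At roughly 1100 updates per epoch, epoch 287 ends near update 315,700.

It assumes fairseq's default `checkpoint{epoch}.pt` file naming. The helpers `epoch_checkpoint_names`, `approx_updates` and `average_state_dicts` are hypothetical. fairseq ships its own averaging utility (`scripts/average_checkpoints.py`), which is what you would use in practice.

```python
# Hypothetical helpers illustrating 10-checkpoint averaging; not the repo's code.


def epoch_checkpoint_names(last_epoch: int, count: int) -> list:
    """File names of the `count` epoch checkpoints ending at `last_epoch`.

    Assumes fairseq-style names such as checkpoint287.pt.
    """
    first = last_epoch - count + 1
    if first < 1:
        raise ValueError("not enough epochs for the requested count")
    return [f"checkpoint{epoch}.pt" for epoch in range(first, last_epoch + 1)]


def approx_updates(epoch: int, updates_per_epoch: int = 1100) -> int:
    """Approximate update count at the end of `epoch` (about 1100 updates/epoch)."""
    return epoch * updates_per_epoch


def average_state_dicts(state_dicts: list) -> dict:
    """Element-wise mean of parameters across checkpoints.

    Works for plain floats and for tensors (anything supporting + and / by int).
    """
    if not state_dicts:
        raise ValueError("need at least one state dict")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise KeyError("checkpoints have mismatched parameter names")
    averaged = {}
    for name in keys:
        total = state_dicts[0][name]
        for sd in state_dicts[1:]:
            total = total + sd[name]
        averaged[name] = total / len(state_dicts)
    return averaged


if __name__ == "__main__":
    names = epoch_checkpoint_names(last_epoch=287, count=10)
    print(names[0], "...", names[-1], "~", approx_updates(287), "updates")
    print(average_state_dicts([{"w": 1.0}, {"w": 3.0}]))
```

With real checkpoints, each state dict would be loaded from disk (for example with `torch.load`) before averaging. The averaged parameters are then written back into a single checkpoint for evaluation.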