TRIME: Training Language Models with Memory Augmentation

This is the repository for our EMNLP 2022 paper Training Language Models with Memory Augmentation, by Zexuan Zhong, Tao Lei, and Danqi Chen.

Overview

<img src="images/method.png" width="600">

We propose TRIME, a new training objective for language modeling that aligns model outputs with both token embeddings and in-batch memories. We also devise new methods for batching data and constructing training memories, so that our models can effectively leverage long-range contexts and an external datastore.
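As a rough illustration of what "aligning model outputs with both token embeddings and in-batch memories" can look like, the NumPy sketch below scores a context vector against every token embedding and every memory key, then normalizes over both sets in a single softmax. The function name, the use of one shared context vector, and the plain dot-product similarity are simplifying assumptions made for this example; the exact formulation is in the paper.

```python
import numpy as np


def trime_log_prob(query, token_emb, mem_keys, mem_tokens, target):
    """Log-probability of `target` under one softmax over vocabulary and memory scores.

    query:      (d,)   context representation (hypothetical single vector)
    token_emb:  (V, d) output token embeddings
    mem_keys:   (M, d) representations of in-batch memory contexts
    mem_tokens: (M,)   the token that followed each memory context
    """
    vocab_scores = token_emb @ query   # similarity to every token embedding
    mem_scores = mem_keys @ query      # similarity to every memory (assumed dot product)
    all_scores = np.concatenate([vocab_scores, mem_scores])

    # Positives: the target's embedding plus every memory whose next token is the target.
    positives = np.concatenate([vocab_scores[target:target + 1],
                                mem_scores[mem_tokens == target]])

    # Numerically stable log-sum-exp for numerator and normalizer.
    shift = all_scores.max()
    log_num = shift + np.log(np.exp(positives - shift).sum())
    log_den = shift + np.log(np.exp(all_scores - shift).sum())
    return log_num - log_den


rng = np.random.default_rng(0)
query = rng.normal(size=8)
token_emb = rng.normal(size=(5, 8))
mem_keys = rng.normal(size=(3, 8))
mem_tokens = np.array([2, 4, 2])

# Training loss for one position whose gold next token is 2.
loss = -trime_log_prob(query, token_emb, mem_keys, mem_tokens, target=2)
```

With an empty memory, this reduces to the usual softmax cross-entropy over token embeddings.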

Please find more details of this work in our paper.

Setup

Requirements and dependencies

The code is based on the following requirements/dependencies (we specify the version we used in our experiments in brackets):

You can install this project (based on Fairseq) as follows:

pip install --editable .

Datasets

We conduct experiments on the Wikitext-103 and enwik8 datasets. Please use get_data.sh to download and preprocess the datasets.

bash get_data.sh {wikitext-103 | enwik8}

The processed datasets will be stored in data-bin/wikitext-103 and data-bin/enwik8.

Run Pre-Trained Models

The examples below run pre-trained models on Wikitext-103 with model size = 247M and segment length = 3072. For other experiments (e.g., with different datasets or models), see run_pretrained_models.md, which contains the scripts for all experimental settings.

TrimeLM (local memory)

TrimeLM uses only the local memory (constructed using tokens in the input). It can be viewed as a lightweight replacement for vanilla language models.

# download the pre-trained TrimeLM
mkdir -p pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime.zip;
unzip wiki103-247M-trime.zip; rm -f wiki103-247M-trime.zip
cd ..

# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
    --path pretrained_models/wiki103-247M-trime/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --softmax-temp 1.17

# the following output is expected:
# Loss (base 2): 4.0962, Perplexity: 17.10
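The two numbers in the expected output are related: perplexity is 2 raised to the base-2 loss. A short check (the function name is just for illustration) makes it easy to verify a run against the reported values:

```python
def perplexity_from_bits(loss_base2: float) -> float:
    """Convert a per-token loss measured in bits into perplexity."""
    return 2.0 ** loss_base2


print(f"{perplexity_from_bits(4.0962):.2f}")  # 17.10
```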

Arguments:

TrimeLM_long (local + long-term memory)

TrimeLM_long uses local memory and long-term memory during inference. The model is able to leverage long contexts, although it is trained with shorter ones.

# download the pre-trained TrimeLM_long
mkdir -p pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_long.zip;
unzip wiki103-247M-trime_long.zip; rm -f wiki103-247M-trime_long.zip
cd ..

# run evaluation
python eval_lm-trime.py data-bin/wikitext-103 \
    --path pretrained_models/wiki103-247M-trime_long/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --use-long --mem-size 12288 --softmax-temp 1.22

# the following output is expected:
# Loss (base 2): 4.0879, Perplexity: 17.01

Arguments:

TrimeLM_ext (local + long-term + external memory)

TrimeLM_ext uses local memory, long-term memory, and external memory. During inference, we run the model on the training set to build the external memory and use the Faiss library to build an index for retrieving the top-K nearest neighbors from the external memory. We also calibrate a separate distribution over the memory and interpolate the output distribution with the memory distribution, similarly to kNN-LM (see details in the paper).
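The interpolation step can be pictured with a small NumPy sketch in the style of kNN-LM. The function names are hypothetical, and treating the temperature as a divisor of the neighbor scores is an assumption about how a flag like --interp-temp is applied; the repository's implementation may differ in detail.

```python
import numpy as np


def memory_distribution(scores, tokens, vocab_size, temp):
    """Softmax over retrieved-neighbor scores, summed per target token."""
    weights = np.exp((scores - scores.max()) / temp)
    weights /= weights.sum()
    probs = np.zeros(vocab_size)
    np.add.at(probs, tokens, weights)  # neighbors sharing a token pool their mass
    return probs


def interpolate(p_model, p_memory, lmbda):
    """kNN-LM-style mixture: lmbda weights the memory distribution."""
    return lmbda * p_memory + (1.0 - lmbda) * p_model


p_model = np.array([0.7, 0.2, 0.1])
p_memory = memory_distribution(
    scores=np.array([12.0, 11.0, 3.0]), tokens=np.array([1, 1, 2]),
    vocab_size=3, temp=10.5)
p_final = interpolate(p_model, p_memory, lmbda=0.3)
```

Because both inputs are proper distributions, the mixture also sums to one for any lmbda between 0 and 1.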

We first download the pre-trained TrimeLM_ext:

mkdir -p pretrained_models; cd pretrained_models
wget https://nlp.cs.princeton.edu/projects/trime/pretrained_models/wiki103-247M-trime_ext.zip;
unzip wiki103-247M-trime_ext.zip; rm -f wiki103-247M-trime_ext.zip
cd ..

Next, we generate the external memory (keys and values) using the training set and build the Faiss index:

MODEL_PATH=pretrained_models/wiki103-247M-trime_ext

# generate the external memory (keys and values) using the training set
python eval_lm.py data-bin/wikitext-103 \
    --path ${MODEL_PATH}/checkpoint_best.pt \
    --sample-break-mode none --max-tokens 3072 \
    --softmax-batch 1024 --gen-subset train \
    --context-window 2560 --tokens-per-sample 512 \
    --dstore-mmap ${MODEL_PATH}/dstore --knn-keytype last_ffn_input \
    --dstore-size 103224461 \
    --save-knnlm-dstore --fp16 --dstore-fp16


# build Faiss index
python build_dstore.py \
    --dstore_mmap ${MODEL_PATH}/dstore \
    --dstore_size 103224461 --dimension 1024 \
    --faiss_index ${MODEL_PATH}/knn.index \
    --num_keys_to_add_at_a_time 500000 \
    --starting_point 0  --dstore_fp16  --dist ip
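Before running the two commands above, it can help to estimate the disk footprint and to keep in mind what the index computes. The sketch below assumes the keys are stored as a dense half-precision matrix of shape (dstore_size, dimension), as in kNN-LM-style datastores, and that `ip` stands for inner-product similarity. The exact top-k function is a toy reference for the search a Faiss index approximates at scale, not the repository's code.

```python
import numpy as np


def dstore_key_bytes(dstore_size, dimension, fp16=True):
    """Bytes needed for a dense (dstore_size, dimension) key matrix."""
    return dstore_size * dimension * (2 if fp16 else 4)


def top_k_inner_product(keys, query, k):
    """Exact top-k search by inner product (what an approximate index emulates)."""
    scores = keys @ query
    top = np.argsort(-scores)[:k]
    return top, scores[top]


print(dstore_key_bytes(103224461, 1024) / 1e9)  # about 211 GB for the keys alone

keys = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]])
ids, scores = top_k_inner_product(keys, query=np.array([1.0, 0.5]), k=2)
```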

Now, we are ready to evaluate the model:

MODEL_PATH=pretrained_models/wiki103-247M-trime_ext

python eval_lm-trime.py data-bin/wikitext-103 \
    --path ${MODEL_PATH}/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 --context-window 2560 \
    --softmax-batch 1024 --gen-subset valid --fp16 \
    --max-sentences 1 --knn-keytype last_ffn_input \
    --use-local --use-long --mem-size 12288 --softmax-temp 1.25 \
    --use-external --dstore-filename ${MODEL_PATH}/dstore --indexfile ${MODEL_PATH}/knn.index.ip \
    --probe 32 --dstore-fp16 --faiss-metric-type ip --no-load-keys --k 1024 \
    --use-interp --interp-temp 10.5 --lmbda 0.3 

# the following output is expected:
# Loss (base 2): 3.9580, Perplexity: 15.54

Arguments:

Performance of pre-trained models

We list the performance of the released pre-trained models on Wikitext-103 and enwik8, as well as their download links.

| Dataset | Model | Dev | Test | Hyper-parameters |
|---|---|---|---|---|
| Wikitext-103 | TrimeLM<br/>(247M, L=3072) | 17.10 | 17.76 | --softmax-temp 1.17 |
| Wikitext-103 | TrimeLM_long<br/>(247M, L=3072) | 17.01 | 17.64 | --softmax-temp 1.22 --mem-size 12288 |
| Wikitext-103 | TrimeLM_ext<br/>(247M, L=3072) | 15.54 | 15.46 | --softmax-temp 1.25 --mem-size 12288 --interp-temp 10.5 --lmbda 0.3 |
| Wikitext-103 | TrimeLM<br/>(150M, L=150) | 24.45 | 25.61 | --softmax-temp 1.03 |
| Wikitext-103 | TrimeLM_long<br/>(150M, L=150) | 21.76 | 22.62 | --softmax-temp 1.07 --mem-size 15000 |
| enwik8 | TrimeLM<br/>(38M, L=512) | 1.14 | 1.12 | --softmax-temp 1.05 |
| enwik8 | TrimeLM_long<br/>(38M, L=512) | 1.08 | 1.05 | --softmax-temp 1.10 --mem-size 24576 |

Train TrimeLM

Trime loss functions

We follow Fairseq's training recipe (e.g., optimizer, learning rate, batch size) to train TrimeLM. Unlike the standard recipe, we use our own loss functions (specified by --criterion) and data batching methods.

<img src="images/batching.png" width="600">

We trained three varieties of TrimeLM by using different data batching and memory construction methods.

Training scripts

Here is an example of training a TrimeLM_ext model. You can find all training scripts we used in our experiments in train_scripts.

We train our models on 4 NVIDIA RTX3090 GPUs.

# download the results of bm25 batching
wget https://nlp.cs.princeton.edu/projects/trime/bm25_batch/wiki103-l3072-batches.json -P data-bin/wikitext-103/

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir output/wiki103-247M-trime_ext \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion trime_ext_loss --max-tokens 3072 --update-freq 6 --tokens-per-sample 3072 --seed 1 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --knn-keytype last_ffn_input --fp16 \
    --ce-warmup-epoch 9 --cross-sent-ratio 0.9 \
    --predefined-batches data-bin/wikitext-103/wiki103-l3072-batches.json
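For anyone adapting this command to a different number of GPUs, the effective batch size is worth keeping in view. In Fairseq, --max-tokens caps the tokens per batch on each GPU and --update-freq accumulates gradients over that many batches, so an upper bound on tokens per optimizer step can be computed as follows (assuming the command runs data-parallel on all 4 GPUs mentioned above):

```python
def tokens_per_update(max_tokens: int, update_freq: int, num_gpus: int) -> int:
    """Upper bound on tokens contributing to one optimizer step."""
    return max_tokens * update_freq * num_gpus


print(tokens_per_update(max_tokens=3072, update_freq=6, num_gpus=4))  # 73728
```

To keep the same effective batch size on 2 GPUs, for example, --update-freq would need to be doubled to 12.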

Important arguments:

BM25 batching

When training the TrimeLM_ext model with --criterion trime_ext_loss, we use BM25 scores to batch training data.

We use the Pyserini library to build the BM25 index. The library can be installed via pip.

pip install pyserini

We first save all the segments from the training set into a .json file.

mkdir -p bm25/wiki103-l3072/segments
CUDA_VISIBLE_DEVICES=0 python train.py --task language_modeling \
    data-bin/wikitext-103 \
    --max-tokens 6144 --tokens-per-sample 3072 \
    --arch transformer_lm_wiki103 \
    --output-segments-to-file bm25/wiki103-l3072/segments/segments.json

# Modify --tokens-per-sample for different segment lengths

Then, we build the BM25 index using Pyserini.

python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input bm25/wiki103-l3072/segments \
  --index bm25/wiki103-l3072/bm25_index \
  --generator DefaultLuceneDocumentGenerator --threads 1 \
  --storePositions --storeDocvectors --storeRaw
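Pyserini's JsonCollection reads documents that carry an `id` field and a `contents` field, for example as a JSON array stored in a file inside the input folder. The sketch below writes a toy file in that shape. It illustrates only the format the indexer accepts; the helper name and the sample texts are made up, and the file actually produced by --output-segments-to-file is not shown here.

```python
import json
import tempfile
from pathlib import Path


def write_json_collection(segments, out_path):
    """Write segments as a JSON array of {"id", "contents"} documents."""
    docs = [{"id": str(i), "contents": text} for i, text in enumerate(segments)]
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(docs), encoding="utf-8")
    return docs


# Toy usage in a temporary folder laid out like bm25/<name>/segments/segments.json.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "segments" / "segments.json"
    write_json_collection(["first segment text", "second segment text"], path)
    loaded = json.loads(path.read_text(encoding="utf-8"))
```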

Next, for each training segment, we search for similar segments using the BM25 index we built above.

python bm25_search.py \
    --index_path bm25/wiki103-l3072/bm25_index/ \
    --segments_path bm25/wiki103-l3072/segments/segments.json \
    --results_path bm25/wiki103-l3072/bm25_results

# Use --num_shards and --shard_id to parallelize the nearest-neighbor search (e.g., --num_shards 20).

Finally, based on the retrieval results, we create batches by grouping similar segments.

python bm25_make_batches.py \
    --results_path bm25/wiki103-l3072/bm25_results \
    --batch_file data-bin/wikitext-103/wiki103-l3072-batches.json

The output file wiki103-l3072-batches.json contains a list of training-segment indices, ordered so that adjacent segments are likely to be similar.
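One simple way to obtain such an ordering from retrieval results is a greedy walk: start from a segment, hop to its most similar unvisited neighbor, and fall back to the next unvisited segment when every neighbor is used up. The sketch below is a hypothetical illustration of that idea, not a description of bm25_make_batches.py, whose logic is not shown in this README.

```python
def greedy_order(neighbors, start=0):
    """Order segment ids so that consecutive ids tend to be similar.

    neighbors[i] lists segment ids ranked by similarity to segment i,
    most similar first (e.g., taken from BM25 search results).
    """
    n = len(neighbors)
    visited = [False] * n
    order = []
    current = start
    next_unvisited = 0  # fallback pointer used when all neighbors are taken
    while len(order) < n:
        visited[current] = True
        order.append(current)
        nxt = next((j for j in neighbors[current] if not visited[j]), None)
        if nxt is None:
            while next_unvisited < n and visited[next_unvisited]:
                next_unvisited += 1
            if next_unvisited == n:
                break
            nxt = next_unvisited
        current = nxt
    return order


toy_neighbors = [[2, 1], [0, 3], [0, 3], [2, 1]]
order = greedy_order(toy_neighbors)  # visits every segment exactly once
```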

The batch file wiki103-l3072-batches.json can be used during the training of TrimeLM_ext, with the argument --predefined-batches. During training, we simply get training batches by taking sub-lists sequentially from the file.
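Taking sub-lists sequentially amounts to slicing the index list into consecutive chunks. A minimal sketch, assuming the file holds one flat list of segment indices and that each batch contains a fixed number of segments (both are assumptions; the actual loader may handle sizes differently):

```python
def consecutive_batches(indices, segments_per_batch):
    """Split an ordered index list into consecutive, non-overlapping batches."""
    return [indices[i:i + segments_per_batch]
            for i in range(0, len(indices), segments_per_batch)]


batches = consecutive_batches([7, 3, 9, 1, 4], segments_per_batch=2)
print(batches)  # [[7, 3], [9, 1], [4]]
```

Because neighboring indices in the file tend to be similar, each chunk groups segments that are likely to share vocabulary.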

Machine Translation

For machine translation code and experiments, please check out the subdirectory.

Bugs or Questions?

If you have any questions related to the code or the paper, or you encounter any problems when using the code, feel free to email Zexuan Zhong (zzhong@cs.princeton.edu) or open an issue. Please describe the problem in detail so we can help you better and faster!

Citation

If you use our code in your research, please cite our work:

@inproceedings{zhong2022training,
   title={Training Language Models with Memory Augmentation},
   author={Zhong, Zexuan and Lei, Tao and Chen, Danqi},
   booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
   year={2022}
}

Acknowledgement

Our repo is based on the fairseq, knnlm, and adaptive-knn-mt projects. We thank the authors for open-sourcing the great code!