Latent Alignment and Variational Attention

This is a PyTorch implementation of the paper Latent Alignment and Variational Attention, built on a fork of OpenNMT.

Dependencies

The code was tested with Python 3.6 and PyTorch 0.4. To install the dependencies, run

pip install -r requirements.txt

Running the code

All commands are in the script va.sh.

Preprocessing the data

To preprocess the data, run

source va.sh && preprocess_bpe

The raw data in data/iwslt14-de-en was obtained from the fairseq repo with BPE_TOKENS=14000.

Training the model

To train a model, run one of the following commands:

source va.sh && CUDA_VISIBLE_DEVICES=0 train_soft_b6
source va.sh && CUDA_VISIBLE_DEVICES=0 train_exact_b6
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_enum_b6
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_sample_b6
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_gumbel_b6
source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_wsram_b6
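The training targets differ in how the attention (alignment) variable is treated. Going by the target names and the Trained Models table, and not by the contents of va.sh, train_soft_b6 appears to train soft attention, train_exact_b6 exact marginalization over alignments, and the train_cat_* targets variational attention with a categorical alignment posterior under different gradient estimators. The toy sketch below contrasts the three objectives for a single target word y. It uses a hypothetical linear-softmax generator, and W, xs, alpha and q are illustrative names rather than the repository's. Soft attention feeds the attention-averaged input to the output layer. Exact marginalization averages the output distributions. The enumerated ELBO lower-bounds the exact log-likelihood.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def output_dist(W, h):
    # Hypothetical toy generator: p(y | h) = softmax(W h), one row of W per vocabulary word.
    return softmax([sum(w_j * h_j for w_j, h_j in zip(row, h)) for row in W])


def soft_attention_log_prob(W, xs, alpha, y):
    # log p(y | sum_i alpha_i x_i): inputs are averaged *before* the output layer.
    context = [sum(a * x[j] for a, x in zip(alpha, xs)) for j in range(len(xs[0]))]
    return math.log(output_dist(W, context)[y])


def exact_log_prob(W, xs, alpha, y):
    # log sum_i alpha_i p(y | x_i): the alignment z is latent with prior alpha.
    return math.log(sum(a * output_dist(W, x)[y] for a, x in zip(alpha, xs)))


def elbo_enum(W, xs, alpha, q, y):
    # E_q[log p(y | x_z)] - KL(q || alpha), expectation computed by enumerating every z.
    expected = sum(qi * math.log(output_dist(W, x)[y]) for qi, x in zip(q, xs))
    kl = sum(qi * math.log(qi / a) for qi, a in zip(q, alpha) if qi > 0)
    return expected - kl


W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # 3-word vocabulary, 2-dim inputs
xs = [[2.0, 0.0], [0.0, 2.0]]               # 2 source positions
alpha = [0.5, 0.5]                          # attention prior p(z | x)
```

With q set to the true posterior, q_i proportional to alpha_i p(y | x_i), the enumerated ELBO equals the exact log-likelihood; any other q gives a smaller value.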

Checkpoints will be saved to the project's root directory.

Evaluating on the test set

The exact perplexity of the generative model can be obtained by running the following command, with $model replaced by the path to a saved checkpoint.

source va.sh && CUDA_VISIBLE_DEVICES=0 eval_cat $model
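Perplexity is the exponentiated average negative log-likelihood per target token. For the latent-alignment models, an exact value needs the per-token term to be the marginal log p(y_t) = log sum_i p(z_t = i) p(y_t | x_i), computed by enumerating the source positions, rather than a variational lower bound. The following is a minimal sketch of that arithmetic; the helper names are hypothetical and this is not the code behind eval_cat.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x) for x in xs)).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def exact_token_log_prob(prior_log_probs, cond_log_probs):
    # log p(y_t) = log sum_i p(z_t = i) p(y_t | x_i), done in log space.
    return logsumexp([a + b for a, b in zip(prior_log_probs, cond_log_probs)])


def perplexity(token_log_probs):
    # exp of the mean negative log-likelihood over all target tokens.
    return math.exp(-sum(token_log_probs) / len(token_log_probs))
```

For example, a model that assigns probability 0.25 to every target token has perplexity 4.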

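The model can also be used to generate translations of the test data. The first command writes them to $model.out; the second removes the BPE separators and scores the result with multi-bleu.perl: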

source va.sh && CUDA_VISIBLE_DEVICES=0 gen_cat $model
sed -e "s/@@ //g" $model.out | perl tools/multi-bleu.perl data/iwslt14-de-en/test.en
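The post-processing pipeline above can be sketched in Python. remove_bpe performs the same substitution as the sed command, and corpus_bleu is a simplified, single-reference approximation of what tools/multi-bleu.perl computes (clipped 1- to 4-gram precisions, their geometric mean, and a brevity penalty). Whitespace tokenization is assumed, and scores may not match the Perl script exactly, so the script remains the reference for reported numbers.

```python
import math
from collections import Counter


def remove_bpe(line):
    # Same substitution as `sed -e "s/@@ //g"`: join subword pieces marked with "@@ ".
    return line.replace("@@ ", "")


def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses, references, max_n=4):
    """Corpus-level BLEU with one reference per sentence, scaled to 0-100."""
    matches = [0] * max_n
    totals = [0] * max_n
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        h, r = hyp.split(), ref.split()
        hyp_len += len(h)
        ref_len += len(r)
        for n in range(1, max_n + 1):
            h_ngrams, r_ngrams = ngrams(h, n), ngrams(r, n)
            # Clipped counts: an n-gram matches at most as often as it occurs in the reference.
            matches[n - 1] += sum(min(c, r_ngrams[g]) for g, c in h_ngrams.items())
            totals[n - 1] += max(len(h) - n + 1, 0)
    if min(matches) == 0:
        return 0.0  # any zero precision makes the geometric mean zero
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    brevity = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return 100 * brevity * math.exp(log_precision)
```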

Trained Models

Models with the lowest validation PPL were selected for evaluation on test. Numbers are slightly different from those reported in the paper since this is a re-implementation.
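A minimal sketch of the selection step, assuming the fork keeps upstream OpenNMT-py's habit of writing the validation perplexity into each checkpoint filename (for example model_acc_60.12_ppl_7.34_e10.pt). Both that pattern and the helper best_checkpoint are assumptions, so check the actual filenames in the project root before relying on it.

```python
import re

# Assumed filename pattern (upstream OpenNMT-py style); this fork may differ.
CKPT_RE = re.compile(r"_ppl_(\d+(?:\.\d+)?)_e\d+\.pt$")


def best_checkpoint(filenames):
    """Return the checkpoint whose filename reports the lowest validation perplexity."""
    scored = []
    for name in filenames:
        match = CKPT_RE.search(name)
        if match:
            scored.append((float(match.group(1)), name))
    if not scored:
        raise ValueError("no checkpoint filenames matched the assumed pattern")
    return min(scored)[1]
```

For example, best_checkpoint(glob.glob("*.pt")) would return the path to pass as $model to eval_cat and gen_cat.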

Model                                  Test PPL   Test BLEU
Soft Attention                         7.17       32.77
Exact Marginalization                  6.34       33.29
Variational Attention + Enumeration    6.08       33.69
Variational Attention + Sampling       6.17       33.30