# Latent Alignment and Variational Attention
This is a PyTorch implementation of the paper *Latent Alignment and Variational Attention*, built as a fork of OpenNMT-py.
## Dependencies
The code was tested with Python 3.6 and PyTorch 0.4.
To install the dependencies, run

```bash
pip install -r requirements.txt
```
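If you want an isolated environment pinned to the tested versions, something along these lines should work (the environment name and exact torch wheel below are assumptions, not part of this repo):

```bash
# Hypothetical setup pinning the tested versions; adjust names/versions as needed.
python3.6 -m venv va-env            # or: conda create -n va-env python=3.6
source va-env/bin/activate
pip install torch==0.4.0            # PyTorch 0.4, as noted above
pip install -r requirements.txt
```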
## Running the code
All commands are in the script `va.sh`.
### Preprocessing the data
To preprocess the data, run

```bash
source va.sh && preprocess_bpe
```
The raw data in `data/iwslt14-de-en` was obtained from the fairseq repo with `BPE_TOKENS=14000`.
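If you want to regenerate that data yourself, fairseq ships an IWSLT'14 preparation script that downloads the corpus and applies BPE; the sketch below assumes the usual script location in the fairseq repo and that you change its `BPE_TOKENS` variable to 14000 before running:

```bash
# Assumption: fairseq's examples/translation/prepare-iwslt14.sh (path may differ across versions).
git clone https://github.com/pytorch/fairseq.git
cd fairseq/examples/translation
sed -i 's/^BPE_TOKENS=.*/BPE_TOKENS=14000/' prepare-iwslt14.sh   # match this repo's setting
bash prepare-iwslt14.sh
```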
### Training the model
To train a model, run one of the following commands:

- Soft attention

  ```bash
  source va.sh && CUDA_VISIBLE_DEVICES=0 train_soft_b6
  ```

- Categorical attention with exact evidence

  ```bash
  source va.sh && CUDA_VISIBLE_DEVICES=0 train_exact_b6
  ```

- Variational categorical attention with exact ELBO

  ```bash
  source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_enum_b6
  ```

- Variational categorical attention with REINFORCE

  ```bash
  source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_sample_b6
  ```

- Variational categorical attention with Gumbel-Softmax

  ```bash
  source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_gumbel_b6
  ```

- Variational categorical attention with the wake-sleep algorithm (Ba et al., 2015)

  ```bash
  source va.sh && CUDA_VISIBLE_DEVICES=0 train_cat_wsram_b6
  ```
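Each of these names is a shell function defined in `va.sh` (an assumption based on how they are invoked), so you can print exactly what a given shortcut runs before launching it:

```bash
# Print the body of one of the training functions defined in va.sh
source va.sh
declare -f train_soft_b6
```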
Checkpoints will be saved to the project's root directory.
## Evaluating on test
The exact perplexity of the generative model can be obtained by running the following command, with `$model` replaced by a saved checkpoint:

```bash
source va.sh && CUDA_VISIBLE_DEVICES=0 eval_cat $model
```
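For example, with a checkpoint saved in the project root (the filename below is illustrative, not a real artifact of this repo):

```bash
# Hypothetical checkpoint name; substitute one of the .pt files saved during training.
model=va_cat_enum_b6_step_100000.pt
source va.sh && CUDA_VISIBLE_DEVICES=0 eval_cat $model
```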
The model can also be used to generate translations of the test data:

```bash
source va.sh && CUDA_VISIBLE_DEVICES=0 gen_cat $model
sed -e "s/@@ //g" $model.out | perl tools/multi-bleu.perl data/iwslt14-de-en/test.en
```
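The `sed` step undoes the BPE segmentation (subword pieces are joined with a trailing `@@ ` marker) before BLEU is computed against the reference; a quick illustration:

```bash
# BPE splits rare words into subword pieces ending in "@@ "; removing the marker restores whole words.
echo "vari@@ ational atten@@ tion" | sed -e "s/@@ //g"
# -> variational attention
```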
## Trained Models
Models with the lowest validation PPL were selected for evaluation on test. Numbers are slightly different from those reported in the paper since this is a re-implementation.
| Model | Test PPL | Test BLEU |
|---|---|---|
| Soft Attention | 7.17 | 32.77 |
| Exact Marginalization | 6.34 | 33.29 |
| Variational Attention + Enumeration | 6.08 | 33.69 |
| Variational Attention + Sampling | 6.17 | 33.30 |