CsaNMT

PyTorch implementation of CsaNMT, from the paper "Learning to Generalize to More: Continuous Semantic Augmentation for Neural Machine Translation" by Xiangpeng Wei et al. (ACL 2022).

Requirements and Installation

Prepare Data

1. Download the WMT14 En-De and WMT14 En-Fr datasets.

2. Preprocess the data:

bash preprocess_dataset_for_nmt.sh
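Before training, it can help to confirm that the binarized data is complete. The sketch below assumes that preprocess_dataset_for_nmt.sh writes a fairseq-style binarized directory (the data-bin/$data_dir path and the training flags suggest a fairseq-based codebase, but this README does not say so). check_data_bin is a hypothetical helper, not part of the repo, and it expects both a train and a valid split.

```shell
#!/usr/bin/env bash
# Hypothetical sanity check (not part of the CsaNMT repo). Assumes the
# fairseq-preprocess naming scheme: dict.<lang>.txt and
# <split>.<src>-<tgt>.<lang>.{bin,idx}. Returns non-zero if any file is missing.
check_data_bin() {
  local dir=$1 src=$2 tgt=$3 missing=0
  local lang split ext f
  for lang in "$src" "$tgt"; do
    if [ ! -f "$dir/dict.$lang.txt" ]; then
      echo "missing: $dir/dict.$lang.txt" >&2
      missing=1
    fi
    for split in train valid; do
      for ext in bin idx; do
        f="$dir/$split.$src-$tgt.$lang.$ext"
        if [ ! -f "$f" ]; then
          echo "missing: $f" >&2
          missing=1
        fi
      done
    done
  done
  return "$missing"
}
```

For example, check_data_bin "data-bin/$data_dir" en de && echo "data looks complete" before launching phase 1.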

Train

bash train_alitranx_csanmt.sh

1. Phase 1: train the semantic encoder.

# data_dir: name of the binarized dataset directory under data-bin/
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python train.py "data-bin/$data_dir" \
  --distributed-world-size 8 -s en -t de \
  --arch transformer_encoder_alitranx \
  --optimizer adam --adam-betas '(0.98, 0.998)' \
  --clip-norm 0.0 \
  --lr 1e-4 \
  --max-tokens 4096 \
  --max-epoch 20 \
  --log-interval 100 \
  --save-interval-updates 1000 \
  --keep-interval-updates 10 \
  --seed 1234
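--max-tokens caps the number of tokens in each per-GPU batch, and (in fairseq, which these flags appear to come from) gradients are accumulated over --update-freq batches, default 1, on each of the --distributed-world-size workers before one optimizer step. Multiplying the three gives an upper bound on the tokens per update. A minimal sketch; tokens_per_update is a hypothetical helper name:

```shell
#!/usr/bin/env bash
# Hypothetical helper: upper bound on tokens per optimizer update,
# i.e. world_size * max_tokens * update_freq (update_freq defaults to 1).
tokens_per_update() {
  local world_size=$1 max_tokens=$2 update_freq=${3:-1}
  echo $(( world_size * max_tokens * update_freq ))
}

tokens_per_update 8 4096      # phase 1 settings: prints 32768
tokens_per_update 8 4096 2    # phase 2 adds --update-freq 2: prints 65536
```

Because actual batches are padded and rarely fill --max-tokens exactly, the real number of tokens per update is usually somewhat lower than this bound.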

2. Phase 2: train the NMT model, loading the phase 1 semantic encoder via --encoder-path and keeping it fixed (--encoder-lr 0.0).

# encoder_path: semantic-encoder checkpoint produced in phase 1.
# --encoder-lr 0.0 keeps the semantic encoder fixed in this phase.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
nohup python train.py "data-bin/$data_dir" \
  --distributed-world-size 8 -s en -t de \
  --arch transformer_t2t_alitranx \
  --optimizer adam --adam-betas '(0.9, 0.997)' \
  --clip-norm 0.0 \
  --encoder-path "$encoder_path" \
  --share-decoder-input-output-embed \
  --lr-scheduler inverse_sqrt \
  --warmup-init-lr 1e-07 --warmup-updates 8000 \
  --lr 0.001 --min-lr 1e-09 --weight-decay 0.0 \
  --encoder-lr 0.0 \
  --reset-dataloader \
  --no-progress-bar \
  --criterion label_smoothed_cross_entropy \
  --label-smoothing 0.1 \
  --max-tokens 4096 \
  --max-epoch 10 \
  --update-freq 2 \
  --log-interval 100 \
  --save-interval-updates 1000 \
  --keep-interval-updates 10 \
  --ddp-backend no_c10d \
  --seed 1234 \
  --eval-bleu \
  --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
  --eval-bleu-detok moses \
  --eval-bleu-remove-bpe \
  --best-checkpoint-metric bleu --maximize-best-checkpoint-metric
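Phase 2 uses the inverse_sqrt learning-rate scheduler. Assuming it follows fairseq's implementation (which the flag names suggest), the rate rises linearly from --warmup-init-lr to --lr over --warmup-updates steps and then decays in proportion to 1/sqrt(step). The hypothetical awk-based helper below reproduces that formula so the schedule for these settings can be inspected; it does not model --min-lr.

```shell
#!/usr/bin/env bash
# Hypothetical helper: learning rate at a given update under an inverse_sqrt
# schedule. Args: step, peak lr (--lr), initial lr (--warmup-init-lr),
# warmup steps (--warmup-updates).
inverse_sqrt_lr() {
  awk -v step="$1" -v lr="$2" -v init="$3" -v warmup="$4" 'BEGIN {
    s = step + 0; w = warmup + 0; peak = lr + 0; start = init + 0
    if (s < w) {
      v = start + s * (peak - start) / w   # linear warmup
    } else {
      v = peak * sqrt(w) / sqrt(s)         # inverse square-root decay
    }
    printf "%.6e\n", v
  }'
}

# Phase 2 settings: --lr 0.001 --warmup-init-lr 1e-07 --warmup-updates 8000
for step in 0 4000 8000 32000; do
  echo "$step $(inverse_sqrt_lr "$step" 0.001 1e-07 8000)"
done
```

With these settings the rate peaks at 1.000000e-03 at update 8000 and has halved to 5.000000e-04 by update 32000, since sqrt(8000/32000) = 0.5.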

3. Phase 3: fine-tune the NMT model and the semantic encoder jointly, with a non-zero encoder learning rate (--encoder-lr 1e-5, compared with 0.0 in phase 2).

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
nohup python train.py "data-bin/$data_dir" \
  ...
  --encoder-lr 1e-5 \
  ...
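The other phase 3 arguments are elided above. If you keep the phase 2 arguments in a bash array and want phase 3 to differ only in the encoder learning rate, a small override helper avoids duplicating the whole command. This is a sketch under an assumption the README does not confirm (that the remaining phase 3 arguments match phase 2); override_flag and PHASE2_ARGS are hypothetical names.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not from the CsaNMT repo): print a copy of an argument
# list, one argument per line, with the value that follows FLAG replaced by VALUE.
override_flag() {
  local flag=$1 value=$2
  shift 2
  local replace_next=0 arg
  for arg in "$@"; do
    if [ "$replace_next" -eq 1 ]; then
      printf '%s\n' "$value"       # substitute the new value
      replace_next=0
    else
      printf '%s\n' "$arg"         # copy the argument unchanged
      if [ "$arg" = "$flag" ]; then
        replace_next=1             # the next argument is the flag's value
      fi
    fi
  done
}
```

For example, with bash 4 or later: mapfile -t PHASE3_ARGS < <(override_flag --encoder-lr 1e-5 "${PHASE2_ARGS[@]}"), then nohup python train.py "data-bin/$data_dir" "${PHASE3_ARGS[@]}". Printing one argument per line keeps values with spaces, such as the --eval-bleu-args JSON, intact, but would break for arguments containing newlines.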

Results

Model                En-De (BLEU)   En-Fr (BLEU)
Transformer (base)   27.2           38.9
CsaNMT (base)        29.3           40.6
Transformer (big)    28.1           40.7
CsaNMT (big)         29.8           42.1

BLEU scores are computed with SacreBLEU (signature: nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0).
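To score your own translations in the same way, one assumed workflow is to decode with fairseq-generate (again assuming a fairseq-based codebase; the README does not describe its evaluation commands) and extract the detokenized hypotheses, the lines starting with D-, in original sentence order. fairseq-generate prints sentences in batch order, so the output has to be sorted by sentence id. extract_detok_hyps is a hypothetical helper name.

```shell
#!/usr/bin/env bash
# Hypothetical helper: pull detokenized hypotheses out of fairseq-generate
# output. Lines look like "D-<id><TAB><score><TAB><hypothesis>"; strip the
# "D-" prefix, sort numerically by id, and keep only the hypothesis text.
extract_detok_hyps() {
  grep '^D-' "$1" \
    | sed 's/^D-//' \
    | sort -t $'\t' -k1,1n \
    | cut -f3-
}
```

For example, extract_detok_hyps gen.out > hyp.txt, then sacrebleu ref.txt -i hyp.txt -m bleu with SacreBLEU 2.x, whose default BLEU settings (tok:13a, case:mixed, smooth:exp) match the signature above. Here gen.out and ref.txt (a detokenized reference file) are placeholder names.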

Citation

Please cite as:

@inproceedings{wei-etal-2022-learning,
  title = {Learning to Generalize to More: Continuous Semantic Augmentation for Neural Machine Translation},
  author = {Xiangpeng Wei and Heng Yu and Yue Hu and Rongxiang Weng and Weihua Luo and Rong Jin},
  booktitle = {Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics, ACL 2022},
  year = {2022},
}