RetGen: A Joint Framework for Retrieval and Grounded Text Generation Modeling

This repository contains the source code and trained models for "Joint Retrieval and Generation Training for Grounded Text Generation". RetGen is a joint training framework that simultaneously optimizes a dense passage retriever and a knowledge-grounded text generator in an end-to-end fashion. It can be applied to scenarios including, but not limited to, conversational modeling, text generation, and open-domain question answering. The implementation is based on DialoGPT, Huggingface Transformers, DPR, and ANCE. Our human evaluation results indicate that RetGen can generate more relevant, interesting, and human-like text than vanilla DialoGPT or GPT-2.

Figure: RetGen overview.
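To make the idea of joint training concrete, here is a minimal sketch of one common formulation: the generator likelihood is marginalized over the retrieved documents, so the loss gradient reaches both the retriever scores and the generator. It illustrates the general technique and is not taken from joint_training.py. The function names and all numbers are hypothetical.

```python
import math


def log_softmax(scores):
    """Turn raw retrieval scores into log-probabilities over the retrieved documents."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def marginal_log_likelihood(retrieval_scores, gen_log_probs):
    """Compute log sum_k p(z_k | x) * p(y | x, z_k) stably in log space."""
    log_prior = log_softmax(retrieval_scores)
    joint = [lp + lg for lp, lg in zip(log_prior, gen_log_probs)]
    m = max(joint)
    return m + math.log(sum(math.exp(j - m) for j in joint))


# Hypothetical values for one (context, response) pair and 4 retrieved documents.
retrieval_scores = [2.0, 1.0, 0.5, -1.0]      # e.g. query-document inner products
gen_log_probs = [-12.0, -15.0, -14.0, -20.0]  # log p(response | context, document)

loss = -marginal_log_likelihood(retrieval_scores, gen_log_probs)
print(f"negative marginal log-likelihood: {loss:.3f}")
```

In a real training loop both inputs would be differentiable tensors, so minimizing this loss raises the score of documents that help the generator and improves the generator on the documents it is given.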

If this repo is helpful to your research, please cite our paper:

@article{zhang2021joint,
  title={Joint Retrieval and Generation Training for Grounded Text Generation},
  author={Zhang, Yizhe and Sun, Siqi and Gao, Xiang and Fang, Yuwei and Brockett, Chris and Galley, Michel and Gao, Jianfeng and Dolan, Bill},
  journal={arXiv preprint arXiv:2105.06597},
  year={2021}
}

Environment

Conda

For CUDA 10.0, run

conda env create -f RetGen.yml
conda activate RetGen
conda install pytorch=1.4.0 torchvision cudatoolkit=10.0 -c pytorch

Then install apex (clone it to a location outside this repository):

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

For CUDA 10.1, run

conda install pytorch=1.5.0 torchvision cudatoolkit=10.1 -c pytorch

instead of

conda install pytorch=1.4.0 torchvision cudatoolkit=10.0 -c pytorch

Next, install Fairseq, also in a location outside this repository:

git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
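After the installation steps above, a quick sanity check can confirm that the packages are visible to the active environment. The sketch below is not part of the repository. It only assumes the usual import names (`torch`, `apex`, `fairseq`) and reports the PyTorch version and CUDA availability when PyTorch is present.

```python
import importlib.util

# Import names of the packages installed in the steps above.
REQUIRED = ["torch", "apex", "fairseq"]


def missing_packages(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
print("missing packages:", ", ".join(missing) if missing else "none")

if "torch" not in missing:
    import torch

    # The conda commands above pin pytorch 1.4.0 (CUDA 10.0) or 1.5.0 (CUDA 10.1).
    print("torch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
```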

Docker

The container was built with

docker build -f dockerfile.txt -t gdpt .

Start the container with

docker run --gpus all --ipc=host --rm -it --mount src=/your_source_code_dir,dst=/code,type=bind --mount src=/gdpt,dst=/gdpt,type=bind intersun/gdpt

Preprocessing

The training corpus must first be compressed into a *.db file using the command below:

python dialogpt/prepro.py --corpus data/train.doc_ctx_rsp.txt --max_seq_len 512

Training

Example training command for Reddit data with 8 GPUs (each with 32GB of VRAM):

python -m torch.distributed.launch --nproc_per_node=8 joint_training.py \
     --model_name_or_path configs \
     --init_checkpoint models/reddit_generator.pkl \
     --train_input_file data/reddit_train.db \
     --eval_input_file data/reddit_test.txt \
     --output_dir output/joint_reddit \
     --file_suffix joint_reddit \
     --train_batch_size 4 \
     --gradient_accumulation_steps 2 \
     --eval_batch_size 2 \
     --num_optim_steps 16000 \
     --encoder_model_type ance_roberta \
     --pretrained_model_cfg bert-base-uncased \
     --model_file models/reddit_retriever.pkl \
     --ctx_file data/wiki.txt \
     --num_shards 8 \
     --batch_size 128 \
     --n_docs 4 \
     --encoding \
     --load_trained_model

For a single GPU, use the following command:

CUDA_VISIBLE_DEVICES=0 python joint_training.py \
     --model_name_or_path configs \
     --init_checkpoint models/reddit_generator.pkl \
     --train_input_file data/reddit_train.db \
     --eval_input_file data/reddit_test.txt \
     --output_dir output/joint_reddit \
     --file_suffix joint_reddit \
     --train_batch_size 2 \
     --gradient_accumulation_steps 2 \
     --eval_batch_size 2 \
     --num_optim_steps 16000 \
     --encoder_model_type ance_roberta \
     --pretrained_model_cfg bert-base-uncased \
     --model_file models/reddit_retriever.pkl \
     --ctx_file data/wiki.txt \
     --num_shards 1 \
     --batch_size 128 \
     --n_docs 2 \
     --encoding \
     --load_trained_model
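Both commands pass `--num_shards` together with `--ctx_file`, which suggests that the passage file is split into shards for encoding. How joint_training.py actually does this is not documented here, so the sketch below shows only one plausible scheme, as an assumption: contiguous, equally sized shards, with the last shard taking the remainder. The function name `shard_bounds` and the passage count are hypothetical.

```python
def shard_bounds(num_items, num_shards, shard_id):
    """Return the [start, end) slice of a corpus handled by one shard."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    shard_size = num_items // num_shards
    start = shard_id * shard_size
    # The last shard also takes any remainder so that no passage is dropped.
    end = num_items if shard_id == num_shards - 1 else start + shard_size
    return start, end


num_passages = 1_000_003  # hypothetical corpus size
for shard_id in range(8):
    start, end = shard_bounds(num_passages, 8, shard_id)
    print(f"shard {shard_id}: passages [{start}, {end})")
```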

Model checkpoints

We release model checkpoints that can be used directly or further fine-tuned on a custom dataset.

| Model | Reddit | arXiv |
|---|---|---|
| RetGen generator | [link] | [link] |
| RetGen retriever | [link] | [link] |

For the generator, you can find the corresponding configuration files (merges.txt, config.json, vocab.json) in ./configs/*.

Data

The preprocessed Wikipedia dump used in our work can be downloaded below. We also provide the raw text data for arXiv. For Reddit, due to copyright issues, we will instead release a script to automatically extract the training, development, and test data.

| Data | Link |
|---|---|
| Wiki (2.5GB) | [link] |
| arXiv (2.4GB) | [link] |
| Reddit | TBD |

Inference

We note that even with a properly filtered Reddit dataset, our model can still sometimes generate moderately toxic or inappropriate responses. For this reason, we are unable to provide the inference code at this time. We are currently working on a controlled decoding method to prevent toxic generation. Please stay tuned.

Evaluation

Generation Evaluation

Please follow the DialoGPT evaluation script described in dialogpt/README.md.

Retrieval Evaluation

We provide an example for evaluating the retriever. The evaluation estimates recall@10, recall@30, and recall@50 over 10K samples. The required files can be downloaded here.

CUDA_VISIBLE_DEVICES=0 python eval_checkpoint.py \
        --eval_mode rank \
        --encoder_model_type ance_roberta \
        --pretrained_model_cfg bert-base-uncased \
        --model_file models/reddit_retriever.pkl \
        --qa_file data/2k_positive.txt \
        --ctx_file data/10k.txt \
        --n_docs 50 \
        --batch_size 64 \
        --shard_id 0 \
        --num_shards 1 \
        --load_trained_model \
        --encoding

If everything runs correctly, the system should output results like the following:

Validation results: recall@10:0.617
Validation results: recall@30:0.703
Validation results: recall@50:0.744
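For readers unfamiliar with the metric, the sketch below shows how recall@k is typically computed from ranked retrieval results. It is not the logic of eval_checkpoint.py. It assumes, for illustration, that each query has exactly one gold passage id and that the retriever returns a ranked list of passage ids per query. The toy ids are hypothetical.

```python
def recall_at_k(ranked_ids, gold_ids, k):
    """Fraction of queries whose gold passage appears among the top-k retrieved ids."""
    if len(ranked_ids) != len(gold_ids):
        raise ValueError("need one ranked list per gold id")
    if not gold_ids:
        raise ValueError("need at least one query")
    hits = sum(1 for ranked, gold in zip(ranked_ids, gold_ids) if gold in ranked[:k])
    return hits / len(gold_ids)


# Toy example: two queries, three retrieved passages each (best first).
ranked = [[3, 1, 2], [5, 4, 6]]
gold = [1, 6]

for k in (1, 2, 3):
    print(f"recall@{k}: {recall_at_k(ranked, gold, k):.3f}")
```

Recall@k can only stay the same or grow as k increases, which matches the ordering of the three numbers reported above.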

Human evaluation

We further conducted human evaluations (500 examples for each method, with each example evaluated by 3 human judges). The results provide strong evidence that our generation quality is better than that of vanilla DialoGPT/GPT-2 under this non-interactive Turing test:

Coherence: A and B, which is more relevant to, and coherent with, the context?

| Dataset | System A | A Wins (%) | Ties (%) | B Wins (%) | System B |
|---|---|---|---|---|---|
| Reddit | RetGen 345M | 43.7 | 28.3 | 28.0 | DialoGPT 345M |
| arXiv | RetGen 345M | 32.1 | 41.7 | 26.3 | GPT-2 345M |

Informativeness: A and B, which is more informative (usually more specific content)?

| Dataset | System A | A Wins (%) | Ties (%) | B Wins (%) | System B |
|---|---|---|---|---|---|
| Reddit | RetGen 345M | 44.5 | 27.8 | 27.7 | DialoGPT 345M |
| arXiv | RetGen 345M | 36.3 | 37.2 | 26.5 | GPT-2 345M |

Human-likeness: A and B, which is more likely to be generated by a human rather than a machine?

| Dataset | System A | A Wins (%) | Ties (%) | B Wins (%) | System B |
|---|---|---|---|---|---|
| Reddit | RetGen 345M | 36.4 | 34.0 | 29.6 | DialoGPT 345M |
| arXiv | RetGen 345M | 29.7 | 43.6 | 26.7 | GPT-2 345M |
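Win, tie, and loss percentages like those above can be tallied from raw pairwise judgments. The README does not say whether the reported numbers pool all individual judgments or use a majority vote per example. The sketch below assumes pooling, and the vote labels and data are hypothetical.

```python
from collections import Counter


def preference_percentages(judgments):
    """Percentage of 'A', 'tie', and 'B' votes among pooled pairwise judgments."""
    if not judgments:
        raise ValueError("need at least one judgment")
    counts = Counter(judgments)
    total = len(judgments)
    return {label: round(100.0 * counts[label] / total, 1) for label in ("A", "tie", "B")}


# Hypothetical votes: 2 examples, each rated by 3 judges.
votes = ["A", "A", "tie", "B", "A", "tie"]
print(preference_percentages(votes))
```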

Related Project

Contact

Please contact DialoGPT@microsoft.com if you have any questions or suggestions. Responses may be sporadic, so please expect delays.

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution.

Disclaimer

This repository aims to facilitate research in large-scale pretraining for conversational data. This toolkit contains only part of the modeling machinery needed to actually produce a model weight file in a running dialog. On its own, this model provides only information about the weights of various text spans; in order for a researcher to actually use it, they will need to bring conversational data of their own and decode the response generation from the pretrained system. We are not responsible for any generation from the 3rd party utilization of the pretrained system.
