LMEDR
Code for the AAAI 2023 paper: Learning to Memorize Entailment and Discourse Relations for Persona-Consistent Dialogues.
Requirements
The code requires the following package versions:
- python==3.8
- torch==1.9.1
- transformers==4.14.1
- pytorch-ignite==0.4.9
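Before training, it can help to confirm that the active environment matches these pins. The following is a minimal sketch using only the standard library (`importlib.metadata`, available from Python 3.8); the function name `check_requirements` is hypothetical and not part of this repository. It ignores local version suffixes such as `+cu111` that CUDA builds of torch often carry.

```python
import sys
from importlib import metadata

# Version pins copied from the requirement list above (distribution names as on PyPI).
REQUIRED = {
    "torch": "1.9.1",
    "transformers": "4.14.1",
    "pytorch-ignite": "0.4.9",
}


def check_requirements(required=REQUIRED, python=(3, 8)):
    """Return a list of human-readable mismatches; an empty list means all pins match."""
    problems = []
    if sys.version_info[:2] != python:
        problems.append(
            f"python {sys.version_info[0]}.{sys.version_info[1]} != {python[0]}.{python[1]}"
        )
    for dist, wanted in required.items():
        try:
            found = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist} not installed (need {wanted})")
            continue
        # Drop a local version label such as "+cu111" before comparing.
        if found.split("+")[0] != wanted:
            problems.append(f"{dist} {found} != {wanted}")
    return problems
```

Running `print("\n".join(check_requirements()) or "all requirements satisfied")` reports every mismatch at once instead of failing on the first import error.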
Install ParlAI as follows:
git clone https://github.com/Chenrj233/ParlAI.git
cd ParlAI
python setup.py install
Replace eval_f1.py and eval_hits.py in /ParlAI/projects/convai2/ with the corresponding files in /other/. Similarly, replace generation_utils.py in the installed transformers/ package with the corresponding file in /other/. The file is located at a path similar to:
|-- python3.8
    |-- site-packages
        |-- transformers
            |-- modeling_utils.py
            |-- generation_utils.py
            |-- ...
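The file replacements described above can also be scripted. This is a sketch under assumptions: it is run from the repository root (so the patched files live in a relative `other/` directory), and the helper names `package_dir` and `replace_files` are hypothetical. It locates the installed package with `importlib.util.find_spec` instead of guessing the site-packages path, and keeps a `.orig` backup of each file it overwrites.

```python
import importlib.util
import shutil
from pathlib import Path


def package_dir(name: str) -> Path:
    """Return the directory of an installed package, e.g. site-packages/transformers."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"package {name!r} is not installed")
    return Path(list(spec.submodule_search_locations)[0])


def replace_files(src_dir, dst_dir, filenames, backup_suffix=".orig"):
    """Copy each file from src_dir over dst_dir, backing up the original once."""
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    replaced = []
    for fname in filenames:
        src, dst = src_dir / fname, dst_dir / fname
        if not src.is_file():
            raise FileNotFoundError(src)
        if dst.exists():
            backup = dst.with_name(dst.name + backup_suffix)
            # Never overwrite an existing backup, so re-running stays safe.
            if not backup.exists():
                shutil.copy2(dst, backup)
        shutil.copy2(src, dst)
        replaced.append(dst)
    return replaced
```

For example, `replace_files("other", package_dir("transformers"), ["generation_utils.py"])` patches transformers, and `replace_files("other", "ParlAI/projects/convai2", ["eval_f1.py", "eval_hits.py"])` patches the ParlAI evaluation scripts, assuming ParlAI was cloned next to this repository.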
Data
The datasets used in the paper can be obtained from the following link:
Training
- PersonaChat
Use the following script to train on the original PersonaChat dataset. To train on the revised dataset, add --revised.
python train_PersonaChat.py --lr 8e-6 \
--epochs 20 \
--train_batch_size 2 \
--valid_batch_size 2 \
--infer_batch_size 64
- DSTC7-AVSD
To train on DSTC7-AVSD, run:
python train_dstc.py --lr 8e-6 \
--epochs 20 \
--train_batch_size 2 \
--valid_batch_size 2 \
--infer_batch_size 10
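If you launch the training runs from a Python driver script, the commands above can be assembled programmatically. This is a hypothetical helper, not part of the repository: it only reproduces the flags and values shown in this README, and it assumes `--revised` applies to PersonaChat alone, since that is the only place the README documents it.

```python
from typing import List

# Hyperparameters copied from the training commands in this README.
PERSONA_FLAGS = {
    "lr": "8e-6",
    "epochs": 20,
    "train_batch_size": 2,
    "valid_batch_size": 2,
    "infer_batch_size": 64,
}
# DSTC7-AVSD uses the same settings except a smaller inference batch.
DSTC_FLAGS = {**PERSONA_FLAGS, "infer_batch_size": 10}


def build_train_cmd(task: str, revised: bool = False, python: str = "python") -> List[str]:
    """Return an argv list for train_PersonaChat.py or train_dstc.py."""
    if task == "personachat":
        cmd, flags = [python, "train_PersonaChat.py"], PERSONA_FLAGS
    elif task == "dstc":
        if revised:
            # Assumption: --revised is only documented for PersonaChat.
            raise ValueError("--revised only applies to PersonaChat")
        cmd, flags = [python, "train_dstc.py"], DSTC_FLAGS
    else:
        raise ValueError(f"unknown task {task!r}")
    for name, value in flags.items():
        cmd += [f"--{name}", str(value)]
    if revised:
        cmd.append("--revised")
    return cmd
```

Passing the list to `subprocess.run(build_train_cmd("personachat", revised=True), check=True)` avoids shell quoting issues and raises if training exits with an error.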
Evaluation
- PersonaChat
Model checkpoints can be obtained from persona_original and persona_revised.
- Hits@1
python evaluation_PersonaChat.py --model_checkpoint persona_original \
--eval_type hits@1
- F1
python evaluation_PersonaChat.py --model_checkpoint persona_original \
--eval_type f1 \
--beam 2 \
--max_history 7
- PPL
python train_PersonaChat.py --load_from persona_original \
--eval
- C.Score
Please refer to PAML.
- DSTC7-AVSD
First, use dstc_generate.py to generate the predicted responses, then evaluate them with dstc7avsd_eval. The model checkpoint can be obtained from dstc_model.
python dstc_generate.py --load_from dstc_model \
--beam 5
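For reference, the three PersonaChat metrics listed above can be sketched as follows. This is an illustration of the standard definitions, not the code in this repository: the unigram F1 follows the normalization style of ParlAI (lowercase, strip punctuation and the articles a/an/the), but the modified eval_f1.py in /other/ may differ, and the exact token averaging inside train_PersonaChat.py --eval is not documented here. All function names are hypothetical.

```python
import math
import re
import string
from collections import Counter
from typing import List, Sequence

_PUNCT = re.compile(f"[{re.escape(string.punctuation)}]")
_ARTICLES = re.compile(r"\b(a|an|the)\b")


def normalize(text: str) -> str:
    """Lowercase, replace punctuation and articles with spaces, collapse whitespace."""
    text = _PUNCT.sub(" ", text.lower())
    text = _ARTICLES.sub(" ", text)
    return " ".join(text.split())


def unigram_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of unigram precision and recall after normalization."""
    pred, ref = normalize(prediction).split(), normalize(reference).split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def hits_at_1(scores: Sequence[Sequence[float]], gold_index: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring candidate is the gold response."""
    hits = sum(
        1 for s, g in zip(scores, gold_index)
        if max(range(len(s)), key=s.__getitem__) == g
    )
    return hits / len(scores)


def perplexity(token_nlls: List[float]) -> float:
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))
```

For example, `unigram_f1("hello world", "hello there")` shares one of two tokens on each side, giving precision and recall of 0.5 and an F1 of 0.5.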
Results
We also provide the final generated texts, which can be found in /results/.