Introduction
This repository contains the official implementation accompanying our
- EMNLP'21 Findings paper R2-D2: A Modular Baseline for Open-Domain Question Answering
- preprint Pruning the Index Contents for Memory Efficient Open-Domain QA.
The sources in this repository can be used to train new models. Please note that our paper is accompanied by two repositories. If you are interested in running model inference in the pipeline instead, check the R2-D2-pipeline repository.
If you use R2-D2, please cite our paper:
@inproceedings{fajcik-etal-2021-r2-d2,
title = "{R2-D2}: {A} {M}odular {B}aseline for {O}pen-{D}omain {Q}uestion {A}nswering",
author = "Fajcik, Martin and
Docekal, Martin and
Ondrej, Karel and
Smrz, Pavel",
booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
month = nov,
year = "2021",
address = "Punta Cana, Dominican Republic",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.findings-emnlp.73",
pages = "854--870",
abstract = "This work presents a novel four-stage open-domain QA pipeline R2-D2 (Rank twice, reaD twice). The pipeline is composed of a retriever, passage reranker, extractive reader, generative reader and a mechanism that aggregates the final prediction from all system{'}s components. We demonstrate its strength across three open-domain QA datasets: NaturalQuestions, TriviaQA and EfficientQA, surpassing state-of-the-art on the first two. Our analysis demonstrates that: (i) combining extractive and generative reader yields absolute improvements up to 5 exact match and it is at least twice as effective as the posterior averaging ensemble of the same models with different parameters, (ii) the extractive reader with fewer parameters can match the performance of the generative reader on extractive QA datasets.",
}
If you use the corpus pruning approach from our preprint, please cite:
@article{fajcik2021pruning,
title={{P}runing the {I}ndex {C}ontents for {M}emory {E}fficient {O}pen-{D}omain {QA}},
author={Fajcik, Martin and Docekal, Martin and Ondrej, Karel and Smrz, Pavel},
journal={arXiv preprint arXiv:2102.10697},
year={2021}
}
Prerequisites
Installation
Set your system's locale.
export LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8
Install this package using Python 3.6:
git clone https://github.com/KNOT-FIT-BUT/scalingQA.git
cd scalingQA; python -m pip install -r requirements.txt; python setup.py install
Data
Datasets
The following hyperlinks contain the preprocessed datasets required for training:
- NQ-open processed via DPR retriever (for reranker/reader training)
- TriviaQA-open processed via DPR multiset retriever (for reranker/reader training)
These files were created using DPR retrieval over all 21M passages of Wikipedia.
Additionally, we release the original files we used as inputs to DPR to generate the preprocessed datasets for the readers and the reranker.
If you would like to process your custom data, follow the "Retrieving the Data via DPR (Optional)" guide at the end of this README.
Index
The SQLite database of 21M passages is available here.
The embedding matrix for the full 21M passages trained on NQ is available here.
The embedding matrix for the full 21M passages trained on multiple datasets (used for the Trivia experiments in the paper) is available here.
Training R2-D2 models
Passage Reranker
Data Pre-processing
The datasets mentioned above comprise a set of the best-retrieved passages and one ground truth passage, if it exists. For several samples, no retrieved passage contains an answer, and the ground truth is unknown. Those samples should be removed from the reranker training data using the following command:
grep -v '"gt_index": -1, "hit_rank": -1,' [INPUT] > [FILTERED_OUTPUT]
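The grep command matches the exact serialized text of each line, so it relies on key order and spacing in the JSON. The following minimal Python sketch does the same filtering by parsing each record. It assumes each line is one JSON object that contains the "gt_index" and "hit_rank" fields from the grep pattern. The function name is hypothetical and is not part of the repository.

```python
import json


def filter_unanswerable(input_path: str, output_path: str) -> int:
    """Copy only samples whose ground-truth passage is known.

    Assumes one JSON object per line with "gt_index" and "hit_rank"
    fields (as in the grep pattern); -1 in both marks a sample where
    no retrieved passage contains an answer. Returns the number kept.
    """
    kept = 0
    with open(input_path, encoding="utf-8") as src, \
            open(output_path, "w", encoding="utf-8") as dst:
        for line in src:
            if not line.strip():
                continue  # skip blank lines
            sample = json.loads(line)
            if sample.get("gt_index") == -1 and sample.get("hit_rank") == -1:
                continue  # no usable ground truth for the reranker: drop it
            dst.write(line if line.endswith("\n") else line + "\n")
            kept += 1
    return kept
```

Parsing the records is slower than grep, but the result does not depend on how the JSON was serialized.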
Training the Model
The scripts for passage reranker training can be found in the folder scalingqa/reranker/training. See the help for more information about the training configuration:
python -m scalingqa.reranker.training.train_reranker --help
Our results should be easily replicable using several ready-made scripts, e.g. for the NQ dataset:
python -m scalingqa.reranker.training.train_reranker_nq
Note that a GPU with at least 12 GB of memory (tested on GeForce RTX 2080Ti) is required for training.
Reranker Outside the Pipeline
The passage reranker can be run separately on input in the same format as the training data. See the help for more information:
python -m scalingqa.reranker.run_reranker --help
Extractive Reader
Data Pre-processing
The extractive reader always expects at least one answer span per training example. To ensure this, run:
python -m scalingqa.extractivereader.run_extractive_reader_filter your_config.py
The filtering script can be configured. An example of a configuration file for the filter is:
toy_examples/extractive_reader/filter_dataset/run_config.py
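The configuration file defines the real filtering criteria. The sketch below only illustrates the requirement: an example is kept if at least one of its answers appears verbatim in one of its passages. The "answers"/"passages" field names, the case-insensitive substring match and the function names are assumptions for illustration. This is not the repository's filter.

```python
from typing import Dict, List


def has_answer_span(example: Dict) -> bool:
    """Return True if some answer occurs in some passage (case-insensitive).

    Hypothetical schema: {"answers": [str, ...], "passages": [str, ...]}.
    """
    answers = [a.lower() for a in example.get("answers", []) if a]
    for passage in example.get("passages", []):
        text = passage.lower()
        if any(answer in text for answer in answers):
            return True
    return False


def filter_examples(examples: List[Dict]) -> List[Dict]:
    """Drop examples that would give the extractive reader no span to learn."""
    return [ex for ex in examples if has_answer_span(ex)]
```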
Training the Model
To train your own model, use:
python -m scalingqa.extractivereader.run_extractive_reader_train your_config.py
An example of a configuration file for the training script is:
toy_examples/extractive_reader/train/run_config.py
If you want to learn more about the usage of our scripts, read the descriptions in the configuration files. There are also ready-to-run toy examples in
toy_examples/extractive_reader/
Replicate
To replicate the training of our model on NaturalQuestions-Open, run:
./scalingqa/extractivereader/replicate/replicate_nq.sh
For TriviaQA-Open, run:
./scalingqa/extractivereader/replicate/replicate_trivia.sh
The scripts expect that all data files are already in the .data folder (see configurations in scalingqa/extractivereader/replicate). They also run the filtering.
Generative Reader
Training the Model
The run-files for replicating our results on NQ and Trivia are available in the folder scalingqa/generative_reader/training. To run the training, adjust the config dictionary right inside the file (you will probably want to set the paths to your data and output directories).
config = {
"save_dir": ".saved", # where the checkpoints will be saved
"results": ".results", # where validation results will be saved
"validate_after_steps": 500, # validation period, divided by 2 after 2/3 of training
###############################
# Data
###############################
"data_cache_dir": ".data/reader/NQ/ranked/", # where the preprocessed datafiles will be cached
"train_data": ".data/reader/NQ/ranked/NQ-open_TRAINING_maxlen_5_ms_with_dpr_annotation.jsonl_dpr_official_nqsingle_of_impossible.jsonl",
"val_data": ".data/reader/NQ/ranked/NQ-open_DEV_maxlen_5_ms_with_dpr_annotation.json_dpr_official_nqsingle_of_impossible.jsonl",
"test_data": ".data/reader/NQ/ranked/NQ-open_TEST.jsonl_nq-open_dpr_official_nqsingle_of_impossible.jsonl",
"pass_database": ".index/wiki2018_dpr_blocks.db", # database of passages and titles
# number of passages encoded from mini-batch
# for training dataset there is always the ground truth passage and the rest is filled with the others recommended by retriever
# for validation dataset only the passages from retriever are used
"context_length": 25, # number of passages at the input of FiD
# ...
}
Afterwards, simply run the module, e.g. to replicate the results of FiD-large on NQ:
python -m scalingqa.generative_reader.training.run_train_fid_large_nq
Note that training is expected to run with an on-hardware batch size of 1. FiD-large on NQ takes about 9 days to converge on a single RTX 8000 48GB GPU.
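The comments on "context_length" in the config describe how the FiD input for a question is assembled. In training, the ground-truth passage is always included and the remaining slots are filled with retrieved passages. In validation, only retrieved passages are used. The following minimal sketch shows that rule. It assumes passages are identified by integer ids and that the retriever output is a list of ids ranked best-first. The function and its arguments are hypothetical.

```python
from typing import List, Optional


def select_context(retrieved: List[int], context_length: int,
                   gt_passage: Optional[int] = None,
                   training: bool = True) -> List[int]:
    """Pick the passage ids fed to FiD for one question (illustrative only).

    Training: the ground-truth passage comes first, and the remaining slots
    are filled with retrieved passages in rank order, skipping duplicates.
    Validation (or no known ground truth): the top retrieved passages.
    """
    if context_length <= 0:
        return []
    if not training or gt_passage is None:
        return retrieved[:context_length]
    context = [gt_passage]
    for pid in retrieved:
        if len(context) == context_length:
            break
        if pid != gt_passage:
            context.append(pid)
    return context
```

With "context_length": 25, each training question would therefore see its ground-truth passage plus 24 retrieved ones.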
Common Use-Cases
To evaluate a checkpoint on the test data, add its path to the config dictionary under the "pre_initialize" key and set "test_only" to True:
config = {
"pre_initialize": PATH_TO_CHECKPOINT,
"test_only": True,
# ...
}
To resume training from a checkpoint, use "resume_training" and "resume_checkpoint" analogously to the previous example:
config = {
"resume_checkpoint": PATH_TO_CHECKPOINT,
"resume_training": True,
# ...
}
You can also train the system in mixed precision (see the "fp16" flag). Note that while the system seems to converge after the initial updates, we have never fully trained it this way, and thus cannot guarantee that it works as intended.
To quickly check that everything works, you can try the toy-example run-file run_train_fid_base_nq_toy.py, which runs FiD-base training using just 2 retrieved passages (it runs on a 12 GB GPU).
Exporting the Checkpoint for R2-D2 Pipeline
To use the trained checkpoint in the R2-D2 pipeline, the checkpoint needs to be restructured so that it contains just a state dictionary and a model configuration. This can be done via the script scalingqa/generative_reader/training/export_checkpoint.py:
python -m scalingqa.generative_reader.training.export_checkpoint INPUT_FILE OUTPUT_FILE [fp16]
Use the fp16 option to save the checkpoint in 16-bit precision.
Retrieving the Data via DPR (Optional)
Here we describe how to process a custom dataset that follows the same format as NQ-open or TriviaQA-open via the retriever.
First, you will need to adjust the configuration in the scalingqa/retriever/extract_DPR_predictions.py script by changing the contents of the config dictionary at the start of the file. Here is an example of how this configuration might look:
config = {
# Omit an option if you do not have the corresponding file in your split
# (e.g. if you only have a training/validation split, comment out the "test_data_file" option here)
# Path to your training data
"training_data_file": ".data/nqopen/nq-open_train_short_maxlen_5_ms_with_dpr_annotation.jsonl",
# Path to your validation data
"validation_data_file": ".data/nqopen/nq-open_dev_short_maxlen_5_ms_with_dpr_annotation.jsonl",
# Path to your test data
"test_data_file": ".data/nqopen/NQ-open-test.jsonl",
# Output directory where the files with retrieval information are saved
"output_directory": "retrieved_data",
# Path to your passage embeddings
"embeddings": ".embeddings/DPR_nqsingle_official.h5",
# Path to the database containing passages
"db_path": ".wikipedia/wiki2018_dpr_blocks.db",
# Path to retriever model
"model_path": ".checkpoints/dpr_official_questionencoder_nq.pt",
# How many top-K passage indices to save into the output file
"topK_extract": 400,
# ...
}
- Note that the Trivia files also contain a "human_answer" entry for each example, which is used to supervise the FiD reader.
- This code does exact retrieval (dot product with the embedding matrix). Therefore, if you use the full matrix of 21M passages in this step, you will need to fit it into your RAM (~65GB).
- You can find download URLs for the compressed index/database/retriever in every R2-D2-pipeline configuration (for example, check configurations/pipeline/NQ/r2d2_full.json to get the files needed to run this code snippet).
Afterwards, simply run the module to extract DPR's predictions:
python -m scalingqa.retriever.extract_DPR_predictions
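Exact retrieval scores every question embedding against every passage embedding. This is why the whole matrix has to be held in memory. The NumPy sketch below shows the idea and the memory estimate. It assumes 768-dimensional float32 embeddings, which is the hidden size of a BERT-base DPR encoder but is not stated in this README. The function names are illustrative and do not come from the retriever script.

```python
import numpy as np


def exact_top_k(question_emb: np.ndarray, passage_embs: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k passages with the highest dot-product score, best first.

    question_emb: shape (d,); passage_embs: shape (n_passages, d).
    """
    if k <= 0:
        return np.empty(0, dtype=np.int64)
    scores = passage_embs @ question_emb          # one score per passage
    k = min(k, scores.shape[0])
    top = np.argpartition(-scores, k - 1)[:k]     # the k best, in arbitrary order
    return top[np.argsort(-scores[top])]          # sort them by descending score


def matrix_size_gb(n_passages: int, dim: int = 768, bytes_per_value: int = 4) -> float:
    """Size of a dense n_passages x dim matrix in decimal gigabytes."""
    return n_passages * dim * bytes_per_value / 1e9
```

For example, matrix_size_gb(21_000_000) gives about 64.5, which agrees with the ~65GB figure above. argpartition avoids sorting all 21M scores when only the top "topK_extract" entries are needed.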
Pruning the Index Contents
1. Constructing the Golden Dataset (a dataset with relevant and irrelevant passages)
To build the NQ-Golden set, run the script scalingqa/index_pruning/dataset/NQ/build_dataset.py:
python -m scalingqa.index_pruning.dataset.NQ.build_dataset
The script works with 4 arguments. They are not passed on the command line; instead, edit them directly in the script's main() function:
raw_passages_db = ".index/wiki2018_dpr_blocks.db"
output_folder = ".data/nq_corpus_pruning"
training_data_source = ".data/nqopen/nq-open_train_short_maxlen_5_ms_with_dpr_annotation.jsonl"
validation_data_source = ".data/nqopen/nq-open_dev_short_maxlen_5_ms_with_dpr_annotation.jsonl"
Note that the data here are the same as the inputs to DPR used to generate the data for reranker and reader training. The validation and test sets for this task are built from nq-open's validation set. You should end up with 176,628 examples for training, 4,332 examples for validation, and 8,698 examples for testing on NQ.
Similarly, you can use scalingqa/index_pruning/dataset/Trivia/build_dataset.py to build the Trivia-Golden dataset.
2. Training the Irrelevant Passage Classifier (Pruner)
Run
python -m scalingqa.index_pruning.training.run_irrelevant_doc_classifier_training_[NQ|TRIVIA]
Adjust the parameters in the config if needed; in particular, you might want to set the paths to your data. For example, the defaults for the NQ dataset are:
"data_cache_dir": '.data/nq_corpus_pruning',
"training_data": "train.jsonl",
"validation_data": "val.jsonl",
"test_data": "test.jsonl",
Training takes about 1.5h on a 2080Ti 12 GB GPU for both datasets. In the paper, we use the following checkpoints.
3. Inferring Passage Irrelevance Probabilities
Once the model is trained, the next step is to extract the irrelevance probability for each passage into an h5 matrix via:
python -m scalingqa.index_pruning.inference.run_irrelevant_doc_predictor
The parameters can again be adjusted inside the run-file's config:
"passage_source": ".data/index/psgs_w100.tsv", # all passages from DPR
"prob_file": ".pruning/psgs_w100_pruneprobs.h5", # output file
"cls_checkpoint": ".saved/irrelevant_doc_cls_google_electra-base-discriminator_acc_0.9049_2020-12-26_23:51.pt" # checkpoint from training
This is usually the longest step: for 21M passages, extracting the probabilities takes about 24h. To get the Wikipedia passages, you can use the link available in the official DPR implementation.
You can get the extracted probabilities we used in the paper from the following links:
4. Choosing the Relevant Documents
Now, prune the index (manually) via the Jupyter notebook scalingqa/index_pruning/inference/get_pruning_index.ipynb.
There, you can select the number of passages or manually adjust the pruner's threshold. Running the notebook creates a file containing the set of all passage indices to keep in the index.
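The notebook lets you keep either a fixed number of passages or every passage below a threshold. Both choices reduce to simple array operations, assuming one irrelevance probability per passage where a higher value means more likely irrelevant. The following NumPy sketch is an illustration and is not the notebook's code. The file format that the pruning scripts expect for the kept indices is not shown here.

```python
import numpy as np


def keep_top_n(irrelevance: np.ndarray, n: int) -> np.ndarray:
    """Indices of the n passages least likely to be irrelevant, in ascending index order."""
    n = max(0, min(n, irrelevance.shape[0]))
    order = np.argsort(irrelevance, kind="stable")  # most relevant passages first
    return np.sort(order[:n])


def keep_below_threshold(irrelevance: np.ndarray, threshold: float) -> np.ndarray:
    """Indices of passages whose irrelevance probability is below the threshold."""
    return np.flatnonzero(irrelevance < threshold)
```

Selecting by count fixes the final index size in advance. Selecting by threshold fixes the pruner's confidence instead, and the index size follows from it.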
5. Dumping the Pruned Index
Finally, the embedding index and the database can be pruned. You can use index_pruning/inference/prune_embeddings.py to prune the embedding matrix. Adjust the paths to the full embeddings (FULL_EMBEDDINGS) and to the file from the previous step (PRUNE_FILE) directly in the file.
python -m scalingqa.index_pruning.inference.prune_embeddings
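Conceptually, pruning the embedding matrix keeps only the rows of the kept passages. After pruning, a row's position no longer equals the original passage index, so a mapping back to the original ids has to be kept. The sketch below assumes an in-memory NumPy matrix in which row i belongs to passage i. The real script works with the files configured in FULL_EMBEDDINGS and PRUNE_FILE. The function name is hypothetical.

```python
from typing import Tuple

import numpy as np


def prune_rows(embeddings: np.ndarray, keep: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return (pruned_embeddings, original_ids).

    Row j of the pruned matrix is the embedding of passage original_ids[j].
    """
    original_ids = np.unique(keep)  # sorted, duplicates removed
    return embeddings[original_ids], original_ids
```

A retrieval hit at row j of the pruned matrix then maps back to passage original_ids[j] in the full database.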
Analogously, use index_pruning/inference/prune_db.py to prune the SQLite database. There, adjust the path to the database (FULL_DB_PATH) and PRUNE_FILE.
python -m scalingqa.index_pruning.inference.prune_db
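Pruning the database amounts to deleting every passage whose id is not in the keep set. The following standard-library sqlite3 sketch assumes a hypothetical table "passages" with an integer primary key "id". This README does not describe the actual schema of wiki2018_dpr_blocks.db.

```python
import sqlite3
from typing import Iterable


def prune_passages(conn: sqlite3.Connection, keep_ids: Iterable[int]) -> int:
    """Delete all rows of the (hypothetical) passages table not in keep_ids.

    Returns the number of remaining rows.
    """
    cur = conn.cursor()
    cur.execute("CREATE TEMP TABLE keep_ids_tmp (id INTEGER PRIMARY KEY)")
    cur.executemany("INSERT OR IGNORE INTO keep_ids_tmp (id) VALUES (?)",
                    ((int(i),) for i in keep_ids))
    cur.execute("DELETE FROM passages WHERE id NOT IN (SELECT id FROM keep_ids_tmp)")
    cur.execute("DROP TABLE keep_ids_tmp")
    conn.commit()
    (remaining,) = cur.execute("SELECT COUNT(*) FROM passages").fetchone()
    return remaining
```

Deleting rows does not shrink an SQLite file on disk. Running VACUUM afterwards reclaims the freed space.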
See any of the configurations/pipeline/[NQ|Trivia]/*_pruned.json files in R2-D2-pipeline for links to the pruned versions of the NQ/Trivia indices we used in the paper.