[NeurIPS 2023] Model-enhanced Vector Index (Paper)
Environment
[Option 1] Create conda environment:
conda env create -f environment.yml
conda activate mevi
[Option 2] Use docker:
docker pull hugozhl/nci:latest
MSMARCO Passage
Data Process
[1] Download and preprocess:
bash dataprocess/msmarco_passage/download_data.sh
python dataprocess/msmarco_passage/prepare_origin.py \
--data_dir data/marco --origin
[2] Tokenize documents:
# tokenize for T5-ANCE and AR2
# T5-ANCE
python dataprocess/msmarco_passage/prepare_passage_tokenized.py \
--output_dir data/marco/ance \
--document_path data/marco/raw/corpus.tsv \
--dataset marco --model ance
rm data/marco/ance/all_document_indices_*.pkl
rm data/marco/ance/all_document_tokens_*.bin
rm data/marco/ance/all_document_masks_*.bin
# AR2
python dataprocess/msmarco_passage/prepare_passage_tokenized.py \
--output_dir data/marco/ar2 \
--document_path data/marco/raw/corpus.tsv \
--dataset marco --model ar2
rm data/marco/ar2/all_document_indices_*.pkl
rm data/marco/ar2/all_document_tokens_*.bin
rm data/marco/ar2/all_document_masks_*.bin
[3] Query generation for augmentation:
We use the same docT5query checkpoint as NCI. The generated queries (QG data) are used only for training.
Please download the finetuned docT5query checkpoint to data/marco/ckpts/doc2query-t5-base-msmarco before running the script.
# MUST download the finetuned docT5query ckpt before running the scripts
python dataprocess/msmarco_passage/doc2query.py --data_dir data/marco
# if the QG data is of poor quality (e.g. empty queries or many duplicate queries), run the additional script below
python dataprocess/msmarco_passage/complement_qg10.py --data_dir data/marco # Optional
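The quality problems mentioned above (empty or duplicated generated queries) can be checked before deciding whether the optional script is needed. The following is a minimal sketch, not part of the repository: the in-memory `(doc_id, query)` pair format, the function names, and the `expected_per_doc` parameter are all assumptions made for illustration.

```python
from collections import defaultdict


def qg_quality_report(pairs, expected_per_doc):
    """Summarize empty and duplicate generated queries per document.

    pairs: iterable of (doc_id, query) tuples (hypothetical in-memory format).
    expected_per_doc: number of distinct queries wanted per document (assumed).
    """
    by_doc = defaultdict(list)
    for doc_id, query in pairs:
        by_doc[doc_id].append(query.strip())

    report = {}
    for doc_id, queries in by_doc.items():
        non_empty = [q for q in queries if q]
        # Case-insensitive comparison, so "What is X" and "what is x" count once.
        unique = set(q.lower() for q in non_empty)
        report[doc_id] = {
            "empty": len(queries) - len(non_empty),
            "duplicates": len(non_empty) - len(unique),
            "missing": max(0, expected_per_doc - len(unique)),
        }
    return report


def docs_needing_complement(report):
    """Return the doc ids whose distinct queries fall short of the target."""
    return sorted(d for d, r in report.items() if r["missing"] > 0)


pairs = [
    ("d1", "what is a vector index"),
    ("d1", "What is a vector index"),
    ("d1", ""),
    ("d2", "how does hnsw work"),
    ("d2", "what is residual quantization"),
    ("d2", "who wrote nci"),
]
report = qg_quality_report(pairs, expected_per_doc=3)
print(report["d1"])
print(docs_needing_complement(report))
```

If many documents show up in the second list, running the complement script is probably worthwhile.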
[4] Generate document embeddings and construct RQ
For T5-ANCE, please download the T5-ANCE checkpoint to data/marco/ckpts/t5-ance.
For AR2, please download the AR2 checkpoint to data/marco/ckpts/ar2g_marco_finetune.pkl and the coCondenser checkpoint to data/marco/ckpts/co-condenser-marco-retriever.
# MUST download the checkpoints before running the scripts
export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
bash MEVI/marco_generate_embedding_n_rq.sh
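For readers unfamiliar with RQ (taken here to mean residual quantization), the sketch below shows the generic idea: each level picks the centroid nearest to what the previous levels left unexplained, so a vector becomes a short sequence of codes. This is an illustration under stated assumptions, not the repository's implementation; the toy codebooks are hand-written rather than learned, and the function names are hypothetical.

```python
def nearest(centroids, vec):
    """Index of the centroid closest to vec in squared Euclidean distance."""
    best_idx, best_dist = 0, float("inf")
    for idx, c in enumerate(centroids):
        dist = sum((a - b) ** 2 for a, b in zip(vec, c))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx


def rq_encode(codebooks, vec):
    """Greedy residual quantization: one code per level, plus the final residual."""
    residual = list(vec)
    codes = []
    for centroids in codebooks:
        idx = nearest(centroids, residual)
        codes.append(idx)
        # Subtract the chosen centroid; the next level quantizes what is left.
        residual = [r - c for r, c in zip(residual, centroids[idx])]
    return codes, residual


def rq_decode(codebooks, codes):
    """Reconstruct a vector as the sum of the selected centroids."""
    out = [0.0] * len(codebooks[0][0])
    for centroids, idx in zip(codebooks, codes):
        out = [o + c for o, c in zip(out, centroids[idx])]
    return out


# Toy codebooks (assumed values; real codebooks would be learned, e.g. by k-means).
codebooks = [
    [[0.0, 0.0], [4.0, 4.0]],              # level 1: coarse centroids
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],  # level 2: residual centroids
]
codes, residual = rq_encode(codebooks, [5.0, 4.25])
print(codes, residual)
print(rq_decode(codebooks, codes))
```

Documents with the same code prefix fall into the same coarse cluster, which is what makes such codes usable as hierarchical identifiers.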
Training
Train the RQ-based NCI.
export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
export WANDB_TOKEN="your wandb token"
bash MEVI/marco_train_nci_rq.sh
Twin-tower Model Evaluation
First generate query embeddings.
# for T5-ANCE
python MEVI/generate.py \
--query_file data/marco/origin/dev_mevi_dedup.tsv \
--model_path data/marco/ckpts/t5-ance \
--tokenizer_path data/marco/ckpts/t5-ance \
--query_embedding_path data/marco/ance/query_emb.bin \
--gpus 0,1,2,3,4,5,6,7 --gen_query
# for AR2
python MEVI/generate.py \
--query_file data/marco/origin/dev_mevi_dedup.tsv \
--model_path data/marco/ckpts/ar2g_marco_finetune.pkl \
--tokenizer_path bert-base-uncased \
--query_embedding_path data/marco/ar2/query_emb.bin \
--gpus 0,1,2,3,4,5,6,7 --gen_query
Then use faiss for ANN search.
# for T5-ANCE; for AR2, replace the ance directory with the ar2 directory
python MEVI/faiss_search.py \
--query_path data/marco/ance/query_emb.bin \
--doc_path data/marco/ance/docemb.bin \
--output_path data/marco/ance/hnsw256.txt \
--raw_query_path data/marco/origin/dev_mevi_dedup.tsv \
--param HNSW256
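Because HNSW search is approximate, it can be useful to compare its output against an exact search on a small sample. The sketch below is a pure-Python reference, assuming inner-product similarity and embeddings already loaded as lists of floats; it does not read the .bin files above (their layout is not described here), and the function names are hypothetical.

```python
import heapq


def exact_top_k(query, docs, k):
    """Exact top-k by inner product; returns [(doc_index, score), ...]."""
    scored = (
        (sum(q * d for q, d in zip(query, doc)), idx)
        for idx, doc in enumerate(docs)
    )
    top = heapq.nlargest(k, scored)
    return [(idx, score) for score, idx in top]


def recall_against_exact(ann_ids, exact_ids):
    """Fraction of the exact neighbours that the ANN result also returned."""
    if not exact_ids:
        return 0.0
    return len(set(ann_ids) & set(exact_ids)) / len(exact_ids)


docs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0]]
query = [1.0, 0.5]
exact = exact_top_k(query, docs, k=2)
print(exact)

# Pretend the ANN index returned documents 0 and 1 for this query.
ann_recall = recall_against_exact([0, 1], [idx for idx, _ in exact])
print(ann_recall)
```

A low value on a sample of queries suggests the index parameters are too aggressive for the corpus.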
Sequence-to-sequence Model Evaluation
Please download our checkpoint for MSMARCO Passage or train from scratch before evaluation, and put the checkpoint in data/marco/ckpts. If using the downloaded checkpoint, please also download the corresponding RQ files.
# MUST download or train a ckpt before running the scripts
export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
bash MEVI/marco_eval_nci_rq.sh
Ensemble
Ensemble the results from the twin-tower model and the sequence-to-sequence model.
export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
bash MEVI/marco_ensemble.sh
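One common way to ensemble two retrievers is a weighted sum of their per-document scores. The sketch below illustrates that generic scheme only; whether the script above uses this exact formula is an assumption, and the `alpha` weight, the function name, and the treatment of missing scores as 0.0 are all hypothetical choices.

```python
def ensemble_scores(dense, seq2seq, alpha=0.5):
    """Combine two {doc_id: score} dicts with a weighted sum.

    dense: scores from the twin-tower model.
    seq2seq: scores from the sequence-to-sequence model.
    alpha: weight on the twin-tower score (hypothetical parameter).
    A document missing from one side contributes 0.0 for that side.
    """
    doc_ids = set(dense) | set(seq2seq)
    combined = {
        d: alpha * dense.get(d, 0.0) + (1.0 - alpha) * seq2seq.get(d, 0.0)
        for d in doc_ids
    }
    # Highest score first; ties broken by doc id for deterministic output.
    return sorted(combined.items(), key=lambda item: (-item[1], item[0]))


dense = {"d1": 1.0, "d2": 0.5, "d3": 0.25}
seq2seq = {"d2": 1.0, "d4": 0.5}
ranked = ensemble_scores(dense, seq2seq, alpha=0.5)
print(ranked)
```

In practice the two score distributions usually need to be brought to a comparable scale before they are summed.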
Natural Questions (DPR version)
Data Process
[1] Download and preprocess:
bash dataprocess/NQ_dpr/download_data.sh
python dataprocess/NQ_dpr/preprocess.py --data_dir data/nq_dpr
[2] Tokenize documents:
# use AR2
python dataprocess/NQ_dpr/tokenize_passage_ar2.py \
--output_dir data/nq_dpr \
--document_path data/nq_dpr/corpus.tsv
rm data/nq_dpr/all_document_indices_*.pkl
rm data/nq_dpr/all_document_tokens_*.bin
rm data/nq_dpr/all_document_masks_*.bin
[3] Query generation for augmentation:
We used the docT5query checkpoint as in NCI. The QG data is only for training. Please refer to the QG section for MSMARCO Passage.
# download finetuned docT5query ckpt to data/marco/ckpts/doc2query-t5-base-msmarco
python dataprocess/NQ_dpr/doc2query.py \
--data_dir data/nq_dpr --n_gen_query 1 \
--ckpt_path data/marco/ckpts/doc2query-t5-base-msmarco
[4] Generate document embeddings and construct RQ
Please download the AR2 checkpoint to data/marco/ckpts/ar2g_nq_finetune.pkl and the ERNIE checkpoint to data/marco/ckpts/ernie-2.0-base-en.
# MUST download the checkpoints before running the scripts
bash MEVI/nqdpr_generate_embedding_n_rq.sh
[5] Tokenize query
NQ has a very large number of augmented queries, so we tokenize the queries ahead of time and store them in a form that can be memory-mapped (memmap), which reduces runtime memory usage.
python dataprocess/NQ_dpr/tokenize_query.py \
--output_dir data/nq_dpr \
--tok_train 1 --tok_corpus 1 --tok_qg 1
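To illustrate why pre-tokenizing helps, the sketch below stores token ids as fixed-width rows in a binary file and reads individual rows through a memory map, so only the pages actually touched are loaded. It uses only the standard library and makes several assumptions that may differ from the repository: the int32 little-endian layout, the `MAX_LEN` value, zero padding, and the file name are hypothetical.

```python
import mmap
import os
import struct
import tempfile

MAX_LEN = 4              # hypothetical fixed query length in tokens
ROW_BYTES = MAX_LEN * 4  # one int32 per token


def write_token_file(path, token_rows):
    """Pad or truncate each row to MAX_LEN and store it as little-endian int32."""
    with open(path, "wb") as f:
        for row in token_rows:
            padded = (list(row) + [0] * MAX_LEN)[:MAX_LEN]
            f.write(struct.pack(f"<{MAX_LEN}i", *padded))


def read_row(mm, index):
    """Decode one row directly from the memory map, without reading the file."""
    return list(struct.unpack_from(f"<{MAX_LEN}i", mm, index * ROW_BYTES))


path = os.path.join(tempfile.mkdtemp(), "query_tokens.bin")
write_token_file(path, [[101, 2054, 102], [101, 2129, 2515, 102, 999], [101, 102]])

with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
    num_rows = len(mm) // ROW_BYTES
    second = read_row(mm, 1)
    last = read_row(mm, num_rows - 1)

print(num_rows, second, last)
```

With NumPy available, `numpy.memmap` gives the same random access with array semantics.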
[6] Get answers
We sort the answers for fast evaluation. (Time-consuming! Please download the processed binary files if necessary.)
python dataprocess/NQ_dpr/get_answers.py \
--data_dir data/nq_dpr \
--dev 1 --test 1
python dataprocess/NQ_dpr/get_inverse_answers.py \
--data_dir data/nq_dpr \
--dev 1 --test 1
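One reason sorted answers can speed up evaluation is that membership checks become binary searches. The sketch below shows that idea only, assuming each query maps to a sorted list of ids of passages that contain an answer; the data layout, names, and the hit@k metric are hypothetical and may not match what the scripts above produce.

```python
from bisect import bisect_left


def contains(sorted_ids, doc_id):
    """Binary search for doc_id in a sorted list."""
    pos = bisect_left(sorted_ids, doc_id)
    return pos < len(sorted_ids) and sorted_ids[pos] == doc_id


def hit_at_k(ranked_ids, sorted_answer_ids, k):
    """True if any of the top-k retrieved passages is an answer passage."""
    return any(contains(sorted_answer_ids, d) for d in ranked_ids[:k])


# Hypothetical toy data: answer passage ids per query, and retrieval results.
answers = {"q1": sorted([42, 7, 19]), "q2": sorted([3])}
retrieved = {"q1": [5, 19, 8], "q2": [9, 1, 3]}

hit_rates = {}
for k in (1, 2, 3):
    hits = sum(hit_at_k(retrieved[q], answers[q], k) for q in answers)
    hit_rates[k] = hits / len(answers)
print(hit_rates)
```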
Training
Train the RQ-based NCI.
export WANDB_TOKEN="your wandb token"
bash MEVI/nqdpr_train_nci_rq.sh
Twin-tower Model Evaluation
First generate query embeddings.
python MEVI/generate.py \
--query_file data/nq_dpr/nq-test.qa.csv \
--model_path data/marco/ckpts/ar2g_nq_finetune.pkl \
--tokenizer_path bert-base-uncased \
--query_embedding_path data/nq_dpr/query_emb.bin \
--gpus 0,1,2,3,4,5,6,7 --gen_query
Then use faiss for ANN search.
python MEVI/faiss_search.py \
--query_path data/nq_dpr/query_emb.bin \
--doc_path data/nq_dpr/docemb.bin \
--output_path data/nq_dpr/hnsw256.txt \
--raw_query_path data/nq_dpr/nq-test.qa.csv \
--param HNSW256
Sequence-to-sequence Model Evaluation
Please download our checkpoint for NQ or train from scratch before evaluation, and put the checkpoint in data/marco/ckpts. If using the downloaded checkpoint, please also download the corresponding RQ files.
# MUST download or train a ckpt before running the scripts
bash MEVI/nqdpr_eval_nci_rq.sh
Ensemble
Ensemble the results from the twin-tower model and the sequence-to-sequence model.
bash MEVI/nqdpr_ensemble.sh
Citation
If you find this work useful, please cite our paper.
Acknowledgement
We learned a lot and borrowed some code from the following projects when building MEVI.