
RepBERT

RepBERT is currently (as of June 2020) the state-of-the-art first-stage retrieval technique on the MS MARCO Passage Ranking task. It represents documents and queries with fixed-length contextualized embeddings, and the inner product between a query embedding and a document embedding is used as their relevance score. Its efficiency is comparable to that of bag-of-words methods. For more details, check out our paper:
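To make the scoring rule concrete, here is a toy sketch in Python/NumPy that scores three passages against one query by inner product and ranks them. The 4-dimensional vectors are made up for illustration. Real RepBERT embeddings come from the model and are much larger; only the scoring and ranking logic carries over.

```python
import numpy as np

# Made-up embeddings standing in for RepBERT outputs (illustration only).
query_emb = np.array([0.1, 0.3, -0.2, 0.5], dtype=np.float32)
passage_embs = np.array([
    [0.2, 0.1, 0.0, 0.4],    # passage 0
    [-0.3, 0.2, 0.5, -0.1],  # passage 1
    [0.1, 0.4, -0.1, 0.6],   # passage 2
], dtype=np.float32)

def score_and_rank(query_emb, passage_embs):
    """Return (passage indices sorted by descending score, all scores)."""
    # Relevance score = inner product of the query with each passage embedding.
    scores = passage_embs @ query_emb
    # argsort is ascending, so negate the scores to put the best passage first.
    ranking = np.argsort(-scores)
    return ranking, scores

ranking, scores = score_and_rank(query_emb, passage_embs)
print(ranking.tolist())  # [2, 0, 1]
```

Because the score is a plain inner product, passage embeddings can be computed once offline, and ranking a query reduces to a matrix-vector product.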

MS MARCO Passage Ranking Leaderboard (Jun 28th 2020)

Method | Category | Eval MRR@10 | Latency
--- | --- | --- | ---
BM25 + BERT (Nogueira and Cho, 2019) | Cascade | 0.358 | 3400 ms
RepBERT (this code) | First-Stage | 0.294 | 80 ms
BiLSTM + Co-Attention + self attention based document scorer (Alaparthi et al., 2019) (best non-BERT) | Cascade | 0.291 | -
docTTTTTquery (Nogueira et al., 2019) | First-Stage | 0.272 | 64 ms
DeepCT (Dai and Callan, 2019) | First-Stage | 0.239 | 55 ms
doc2query (Nogueira et al., 2019) | First-Stage | 0.218 | 90 ms
BM25 (Anserini) | First-Stage | 0.186 | 50 ms

Data and Trained Models

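We make the following files available for download: RepBERT's top-1000 retrieval results for the dev and eval small query sets, and the trained model checkpoint. Download them from the table below and verify each against its MD5 checksum: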

File | Size | MD5 | Download
--- | --- | --- | ---
repbert.dev.small.top1k.tsv | 127 MB | 0d08617b62a777c3c8b2d42ca5e89a8e | [Google Drive]
repbert.eval.small.top1k.tsv | 125 MB | b56a79138f215292d674f58c694d5206 | [Google Drive]
repbert.ckpt-350000.zip | 386 MB | b59a574f53c92de6a4ddd4b3fbef784a | [Google Drive]
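On Linux, `md5sum FILE` prints the checksum directly. If you prefer to check from Python, the sketch below streams each file through `hashlib.md5` in 1 MiB chunks so the multi-hundred-MB downloads are never fully loaded into memory. The `EXPECTED_MD5` dictionary is copied from the table above; the helper names and the download location are our own illustrative choices.

```python
import hashlib

# Checksums copied from the download table.
EXPECTED_MD5 = {
    "repbert.dev.small.top1k.tsv": "0d08617b62a777c3c8b2d42ca5e89a8e",
    "repbert.eval.small.top1k.tsv": "b56a79138f215292d674f58c694d5206",
    "repbert.ckpt-350000.zip": "b59a574f53c92de6a4ddd4b3fbef784a",
}

def md5sum(path, chunk_size=1 << 20):
    """Return the MD5 hex digest of a file, reading it chunk by chunk."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read() returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify(path, expected):
    """True if the file's MD5 matches the expected hex digest."""
    return md5sum(path) == expected.lower()
```

For example, assuming the checkpoint was saved to ./data: `verify("./data/repbert.ckpt-350000.zip", EXPECTED_MD5["repbert.ckpt-350000.zip"])`.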

Replicating Results with Provided Trained Model

We provide instructions on how to replicate RepBERT retrieval results using the provided trained model.

First, make sure 🤗 Transformers is installed, then clone this repository:

pip install transformers
git clone https://github.com/jingtaozhan/RepBERT-Index
cd RepBERT-Index

Next, download collectionandqueries.tar.gz from MSMARCO-Passage-Ranking. It contains passages, queries, and qrels.

mkdir data
cd data
wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
mkdir msmarco-passage
tar xvfz collectionandqueries.tar.gz -C msmarco-passage
cd ..

To confirm the download, collectionandqueries.tar.gz should have the MD5 checksum 31644046b18952c1386cd4564ba2ae69.

To avoid repeating work during training and testing, we tokenize the queries and passages in advance. This takes some time (about 3-4 hours). We also convert the tokenized passages to a numpy memmap array, which greatly reduces memory overhead at run time.

python convert_text_to_tokenized.py --tokenize_queries --tokenize_collection
python convert_collection_to_memmap.py
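The memmap idea can be sketched as follows: store every passage's token ids in one fixed-width on-disk array (truncated or padded to `max_len`) plus a separate array of true lengths, so a passage can be read by row index without loading the whole collection into RAM. The layout, dtype, file names, and function names here are hypothetical; the actual format written by convert_collection_to_memmap.py is not described in this README.

```python
import numpy as np

def write_token_memmap(token_id_lists, max_len, ids_path, lengths_path, pad_id=0):
    """Write variable-length token-id lists into a padded 2-D int32 memmap.

    Hypothetical layout: row i holds passage i's ids (truncated to max_len,
    then padded with pad_id); lengths[i] stores how many ids are real.
    """
    n = len(token_id_lists)
    ids = np.memmap(ids_path, dtype=np.int32, mode="w+", shape=(n, max_len))
    lengths = np.memmap(lengths_path, dtype=np.int32, mode="w+", shape=(n,))
    for i, toks in enumerate(token_id_lists):
        toks = toks[:max_len]
        ids[i, :len(toks)] = toks
        ids[i, len(toks):] = pad_id
        lengths[i] = len(toks)
    # Make sure everything is written to disk.
    ids.flush()
    lengths.flush()
    return n

def read_passage(ids_path, lengths_path, n, max_len, i):
    """Read passage i back; only the pages that are touched get loaded."""
    ids = np.memmap(ids_path, dtype=np.int32, mode="r", shape=(n, max_len))
    lengths = np.memmap(lengths_path, dtype=np.int32, mode="r", shape=(n,))
    return ids[i, :lengths[i]].tolist()
```

The trade-off of a fixed-width layout is some wasted space on short passages, in exchange for O(1) random access by passage id.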

Download the provided model repbert.ckpt-350000.zip, put it in ./data, and unzip it. The directory ./data/ckpt-350000 should then contain two files: pytorch_model.bin and config.json.
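If you want to script this step, the sketch below extracts the archive with Python's `zipfile` module and reports which of the two expected files are missing from ./data/ckpt-350000. It assumes, as the paragraph above implies, that the archive's top-level folder is ckpt-350000. Parsing config.json is only a quick sanity check that the file is readable JSON; the function name is our own.

```python
import json
import zipfile
from pathlib import Path

EXPECTED_FILES = ("pytorch_model.bin", "config.json")

def unzip_and_check(zip_path, dest_dir, subdir="ckpt-350000"):
    """Extract the checkpoint archive and list any expected files that are missing."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    ckpt_dir = Path(dest_dir) / subdir
    missing = [name for name in EXPECTED_FILES if not (ckpt_dir / name).is_file()]
    config = None
    if "config.json" not in missing:
        # Sanity check: config.json should parse as JSON.
        with open(ckpt_dir / "config.json") as f:
            config = json.load(f)
    return ckpt_dir, missing, config
```

Usage: `ckpt_dir, missing, config = unzip_and_check("./data/repbert.ckpt-350000.zip", "./data")`; an empty `missing` list means both files are in place.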

Next, you need to precompute the representations of passages and queries.

python precompute.py --load_model_path ./data/ckpt-350000 --task doc
python precompute.py --load_model_path ./data/ckpt-350000 --task query_dev.small
python precompute.py --load_model_path ./data/ckpt-350000 --task query_eval.small
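Precomputing means every passage and query is encoded into one fixed-length vector once, offline, so retrieval later needs only inner products. A common way to collapse BERT's per-token outputs into a single vector is to average them over the non-padding tokens. The NumPy sketch below shows that masked mean pooling; treat it as an assumption about how precompute.py builds its vectors rather than a description of it (the paper gives the exact representation).

```python
import numpy as np

def masked_mean_pool(token_embs, attention_mask):
    """Average token embeddings over real tokens only.

    token_embs: (batch, seq_len, dim) encoder outputs.
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding.
    Returns (batch, dim): one fixed-length vector per input text.
    """
    mask = attention_mask[:, :, None].astype(token_embs.dtype)
    summed = (token_embs * mask).sum(axis=1)
    # Clip the token count to at least 1 to avoid dividing by zero.
    counts = np.clip(mask.sum(axis=1), 1.0, None)
    return summed / counts
```

Without the mask, padding positions would pull every vector toward the padding embedding, and texts of different lengths would be averaged inconsistently.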

Finally, you can retrieve passages for the queries in the dev set (or eval set). multi_retrieve.py uses the GPUs specified by the --gpus argument and distributes the representations of all passages evenly among them. If your CUDA memory is limited, you can use --per_gpu_doc_num to specify the number of passages assigned to each GPU.
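To pick --gpus and --per_gpu_doc_num, it helps to estimate how much GPU memory the passage embeddings alone occupy. The arithmetic below uses the 8,841,823 passages in the MS MARCO collection and assumes 768-dimensional float32 embeddings (BERT-base size), which this README does not state. It ignores activations and score buffers, and it also assumes that, on a single GPU, retrieval loops over the collection in rounds of at most --per_gpu_doc_num passages.

```python
import math

NUM_PASSAGES = 8_841_823   # passages in MS MARCO collection.tsv
EMB_DIM = 768              # assumption: BERT-base-sized embeddings
BYTES_PER_VALUE = 4        # assumption: float32 embeddings

def gib(num_passages, dim=EMB_DIM, bytes_per_value=BYTES_PER_VALUE):
    """GiB needed just to hold num_passages embeddings."""
    return num_passages * dim * bytes_per_value / 2**30

def even_split(num_gpus, num_passages=NUM_PASSAGES):
    """Passages per GPU when the collection is spread evenly over num_gpus."""
    return math.ceil(num_passages / num_gpus)

def num_rounds(per_gpu_doc_num, num_passages=NUM_PASSAGES):
    """Rounds needed on one GPU if at most per_gpu_doc_num passages fit at once."""
    return math.ceil(num_passages / per_gpu_doc_num)

per_gpu = even_split(5)  # e.g. --gpus 0,1,2,3,4
print(per_gpu, round(gib(per_gpu), 2))
print(num_rounds(1_800_000), round(gib(1_800_000), 2))
```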

python multi_retrieve.py --query_embedding_dir ./data/precompute/query_dev.small_embedding --output_path ./data/retrieve/repbert.dev.small.top1k.tsv --hit 1000 --gpus 0,1,2,3,4
python ms_marco_eval.py ./data/msmarco-passage/qrels.dev.small.tsv ./data/retrieve/repbert.dev.small.top1k.tsv
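Splitting passages across GPUs works because top-k is mergeable: each shard computes its own top-k, and the global top-k is the top-k of the concatenated per-shard candidates. The NumPy sketch below shows this with arrays standing in for GPU tensors; it illustrates the technique and is not how multi_retrieve.py is necessarily written. `k` plays the role of --hit.

```python
import numpy as np

def topk_one_shard(query_embs, shard_embs, shard_offset, k):
    """Per-query top-k (scores, global passage ids) within one shard."""
    scores = query_embs @ shard_embs.T            # (n_queries, shard_size)
    k = min(k, shard_embs.shape[0])
    # argpartition selects the k largest per row without a full sort.
    idx = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    top_scores = np.take_along_axis(scores, idx, axis=1)
    return top_scores, idx + shard_offset

def sharded_retrieve(query_embs, doc_embs, n_shards, k):
    """Split passages evenly into shards, search each, then merge the results."""
    bounds = np.linspace(0, doc_embs.shape[0], n_shards + 1).astype(int)
    all_scores, all_ids = [], []
    for start, end in zip(bounds[:-1], bounds[1:]):
        if end == start:
            continue  # empty shard when there are more shards than passages
        s, ids = topk_one_shard(query_embs, doc_embs[start:end], start, k)
        all_scores.append(s)
        all_ids.append(ids)
    scores = np.concatenate(all_scores, axis=1)
    ids = np.concatenate(all_ids, axis=1)
    # Final merge: sort all shard candidates and keep the best k.
    order = np.argsort(-scores, axis=1)[:, :k]
    return np.take_along_axis(ids, order, axis=1), np.take_along_axis(scores, order, axis=1)
```

Only k candidates per query leave each shard, so the merge step stays small even when each shard holds millions of passages.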

You can also retrieve the passages with only one GPU.

export CUDA_VISIBLE_DEVICES=0
python retrieve.py --query_embedding_dir ./data/precompute/query_dev.small_embedding --output_path ./data/retrieve/repbert.dev.small.top1k.tsv --hit 1000 --per_gpu_doc_num 1800000
python ms_marco_eval.py ./data/msmarco-passage/qrels.dev.small.tsv ./data/retrieve/repbert.dev.small.top1k.tsv

The results should be:

#####################
MRR @10: 0.3038783713103188
QueriesRanked: 6980
#####################
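MRR@10 is the mean, over queries, of 1/rank of the first relevant passage among the top 10 results (0 if none appears). The sketch below recomputes it from qrels lines (`qid 0 pid 1`, tab-separated) and run lines assumed to be `qid pid rank`; it divides by the number of queries in the qrels. ms_marco_eval.py remains the reference implementation, and these function names are our own.

```python
from collections import defaultdict

def parse_qrels(lines):
    """'qid<TAB>0<TAB>pid<TAB>label' lines -> {qid: set of relevant pids}."""
    rel = defaultdict(set)
    for line in lines:
        qid, _, pid, label = line.split()
        if int(label) > 0:
            rel[qid].add(pid)
    return rel

def parse_run(lines):
    """Assumed 'qid<TAB>pid<TAB>rank' lines -> {qid: [pids in rank order]}."""
    ranked = defaultdict(list)
    for line in lines:
        qid, pid, rank = line.split()
        ranked[qid].append((int(rank), pid))
    return {qid: [pid for _, pid in sorted(pairs)] for qid, pairs in ranked.items()}

def mrr_at_10(qrels, run):
    """Mean reciprocal rank of the first relevant passage within the top 10."""
    if not qrels:
        return 0.0
    total = 0.0
    for qid, pids in run.items():
        for rank, pid in enumerate(pids[:10], start=1):
            if pid in qrels.get(qid, ()):
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(qrels)
```

Note that the dev-set MRR@10 here (0.3039) and the leaderboard figure (0.294) come from different query sets, so they are not expected to match.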

Train RepBERT

To train RepBERT, first download qidpidtriples.train.full.tsv.gz from MSMARCO-Passage-Ranking.

cd ./data/msmarco-passage
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.tsv.gz

Extract it and use the shuf command to generate a smaller file containing a random 10% of the triples:

gunzip qidpidtriples.train.full.tsv.gz
shuf ./qidpidtriples.train.full.tsv -o ./qidpidtriples.train.small.tsv -n 26991900
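If shuf is not available on your system, reservoir sampling (Algorithm R) draws a uniform random sample of k lines in a single pass over a file of unknown length. The sketch below is a generic alternative to the shuf command above, not part of this repository. It keeps all k sampled lines in memory, which for 26,991,900 lines can amount to several GB in Python.

```python
import random

def reservoir_sample(lines, k, seed=None):
    """Uniformly sample k items from an iterable of unknown length (Algorithm R)."""
    rng = random.Random(seed)
    reservoir = []
    for i, line in enumerate(lines):
        if i < k:
            reservoir.append(line)
        else:
            # Item i replaces a random slot with probability k / (i + 1).
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = line
    # Shuffle so the output order is random, like shuf's.
    rng.shuffle(reservoir)
    return reservoir
```

Usage: open ./qidpidtriples.train.full.tsv, pass the file object with `k=26_991_900`, and write the returned lines with `writelines` to ./qidpidtriples.train.small.tsv.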

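Return to the repository root and start training. Note that the evaluation run during training (--evaluate_during_training) reports reranking results, not full first-stage retrieval results.

cd ../..
python ./train.py --task train --evaluate_during_training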
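This distinction matters when reading those numbers. In reranking, only a fixed candidate list per query is scored, so a relevant passage outside that list can never be ranked; full retrieval scores the whole collection. The NumPy sketch below shows the reranking case. The candidate source (for example, a BM25 top-1000 list) and the function name are hypothetical and not taken from train.py.

```python
import numpy as np

def rerank(query_emb, candidate_ids, doc_embs, top_n=10):
    """Re-order a fixed candidate list by inner-product score.

    Only the given candidates are scored: a relevant passage missing
    from candidate_ids can never appear in the output.
    """
    candidate_ids = np.asarray(candidate_ids)
    scores = doc_embs[candidate_ids] @ query_emb
    order = np.argsort(-scores)[:top_n]
    return candidate_ids[order].tolist()
```

With one-hot passage embeddings and query `[0.1, 0.9, 0.5, 0.0]`, reranking candidates `[0, 2, 3]` returns `[2, 0, 3]`; passage 1, the best match overall, is never considered because it is not a candidate.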