RetroMAE

Codebase for RetroMAE and beyond.

What's New

Released Models

We have uploaded some checkpoints to the Hugging Face Hub.

| Model | Description | Link |
|-------|-------------|------|
| RetroMAE | Pre-trained on Wikipedia and BookCorpus | Shitao/RetroMAE |
| RetroMAE_MSMARCO | Pre-trained on the MSMARCO passage corpus | Shitao/RetroMAE_MSMARCO |
| RetroMAE_MSMARCO_finetune | RetroMAE_MSMARCO fine-tuned on the MSMARCO passage data | Shitao/RetroMAE_MSMARCO_finetune |
| RetroMAE_MSMARCO_distill | RetroMAE_MSMARCO fine-tuned on the MSMARCO passage data by minimizing the KL-divergence with a cross-encoder | Shitao/RetroMAE_MSMARCO_distill |
| RetroMAE_BEIR | RetroMAE fine-tuned on the MSMARCO passage data for BEIR (using the official negatives provided by BEIR) | Shitao/RetroMAE_BEIR |

You can load them easily using the identifiers listed above. For example:

from transformers import AutoModel
model = AutoModel.from_pretrained('Shitao/RetroMAE')
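
These checkpoints are standard BERT-style encoders, so sentence embeddings can be obtained with the usual transformers API. Below is a minimal sketch that takes the [CLS] hidden state as the dense representation (a common choice for RetroMAE-style bi-encoders; the input sentence is only an illustration):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Shitao/RetroMAE')
model = AutoModel.from_pretrained('Shitao/RetroMAE')
model.eval()

# Tokenize a sentence and take the [CLS] hidden state as its embedding.
inputs = tokenizer('hello world', return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
embedding = outputs.last_hidden_state[:, 0]  # shape: [1, hidden_size]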

State of the Art Performance

RetroMAE provides a strong initialization for dense retrievers: after fine-tuning with in-domain data, it yields high-quality supervised retrieval performance in the corresponding scenario. It also substantially improves the pre-trained model's transferability, leading to superior zero-shot performance on out-of-domain datasets.

MSMARCO Passage

| Model | MRR@10 | Recall@1000 |
|-------|--------|-------------|
| BERT | 0.346 | 0.964 |
| RetroMAE | 0.382 | 0.981 |

| Model | MRR@10 | Recall@1000 |
|-------|--------|-------------|
| coCondenser | 0.382 | 0.984 |
| RetroMAE | 0.393 | 0.985 |
| RetroMAE (distillation) | 0.416 | 0.988 |

BEIR Benchmark

| Model | Avg NDCG@10 (18 datasets) |
|-------|---------------------------|
| BERT | 0.371 |
| Condenser | 0.407 |
| RetroMAE | 0.452 |
| RetroMAE v2 | 0.491 |

Installation

git clone https://github.com/staoxiao/RetroMAE.git
cd RetroMAE
pip install .

For development, install as editable:

pip install -e .

Workflow

This repo covers two stages: pre-training and fine-tuning. First, pre-train RetroMAE on a general corpus (or the downstream corpus) with the masked language modeling loss. Then fine-tune RetroMAE on the downstream dataset with a contrastive loss. For better performance, you can also fine-tune RetroMAE by distilling the scores provided by a cross-encoder. For the detailed workflow, please refer to our examples.
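
For intuition, the fine-tuning stage typically uses an in-batch-negative contrastive (InfoNCE-style) objective over query and passage embeddings. The sketch below is only illustrative; the exact loss in this repo (e.g. temperature, handling of hard negatives) may differ:

import torch
import torch.nn.functional as F

def contrastive_loss(q_emb, p_emb, temperature=0.05):
    # q_emb: [batch, dim] query embeddings
    # p_emb: [batch, dim] embeddings of each query's positive passage;
    #        the other passages in the batch act as negatives.
    scores = q_emb @ p_emb.T / temperature                     # [batch, batch] similarity matrix
    labels = torch.arange(q_emb.size(0), device=q_emb.device)  # positives sit on the diagonal
    return F.cross_entropy(scores, labels)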

Pretrain

torchrun --nproc_per_node 8 \
  -m pretrain.run \
  --output_dir {path to save ckpt} \
  --data_dir {your data} \
  --do_train True \
  --model_name_or_path bert-base-uncased \
  --pretrain_method {retromae or dupmae}
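
For intuition, RetroMAE-style masked auto-encoding applies two different masks to the same sentence: a moderate mask on the encoder side and a much heavier mask on the decoder side, which must be reconstructed from the single sentence embedding. The toy helper below only illustrates that asymmetry; the ratios and masking logic are not taken from this repo's implementation:

import random

MASK_ID = 103  # [MASK] id in the bert-base-uncased vocabulary

def mask_tokens(token_ids, mask_ratio):
    # Toy helper: replace a random subset of tokens with [MASK].
    masked = list(token_ids)
    for i in random.sample(range(len(masked)), int(len(masked) * mask_ratio)):
        masked[i] = MASK_ID
    return masked

token_ids = list(range(1000, 1030))                 # stand-in for a tokenized sentence
enc_input = mask_tokens(token_ids, mask_ratio=0.3)  # moderate masking for the encoder
dec_input = mask_tokens(token_ids, mask_ratio=0.5)  # aggressive masking for the decoder
# The encoder builds a sentence embedding from enc_input; the decoder has to
# recover the original tokens from dec_input plus that single embedding.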

Finetune

torchrun --nproc_per_node 8 \
  -m bi_encoder.run \
  --output_dir {path to save ckpt} \
  --model_name_or_path Shitao/RetroMAE \
  --do_train \
  --corpus_file ./data/BertTokenizer_data/corpus \
  --train_query_file ./data/BertTokenizer_data/train_query \
  --train_qrels ./data/BertTokenizer_data/train_qrels.txt \
  --neg_file ./data/train_negs.tsv
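
After fine-tuning (or with the released RetroMAE_MSMARCO_finetune checkpoint), retrieval reduces to encoding queries and passages and ranking them by similarity. A minimal sketch, assuming [CLS] pooling and dot-product scoring (the query and passage texts are illustrative):

import torch
from transformers import AutoModel, AutoTokenizer

model_name = 'Shitao/RetroMAE_MSMARCO_finetune'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

def encode(texts):
    # Encode a list of texts into [CLS] embeddings.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        return model(**batch).last_hidden_state[:, 0]

query_emb = encode(['what is a dense retriever'])
passage_embs = encode(['A dense retriever maps queries and passages into vectors.',
                       'A passage about an unrelated topic.'])
scores = query_emb @ passage_embs.T  # higher dot product = more relevant
print(scores)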

Examples

Citation

If you find our work helpful, please consider citing us:

@inproceedings{RetroMAE,
  title={RetroMAE: Pre-Training Retrieval-oriented Language Models Via Masked Auto-Encoder},
  author={Shitao Xiao and Zheng Liu and Yingxia Shao and Zhao Cao},
  url={https://arxiv.org/abs/2205.12035},
  booktitle={EMNLP},
  year={2022},
}