PAG

This repo provides the source code and checkpoints for our paper Planning Ahead in Generative Retrieval: Guiding Autoregressive Generation through Simultaneous Decoding. We introduce PAG, a novel optimization and decoding approach that guides the autoregressive generation of document identifiers in generative retrieval models through simultaneous decoding. To this end, PAG constructs a set-based and a sequential identifier for each document. Motivated by the bag-of-words assumption in information retrieval, the set-based identifier is built on lexical tokens. The sequential identifier, on the other hand, is obtained by quantizing relevance-based representations of documents. Extensive experiments on MS MARCO and TREC Deep Learning Track data show that PAG outperforms the state-of-the-art generative retrieval model by a large margin (e.g., 15.6% MRR improvement on MS MARCO) while achieving a 22x speedup in query latency.

<p align="center"> <img align="center" src="./pag_arch.png" width="850" /> </p>
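To make the two identifier types concrete, here is a minimal, illustrative sketch of how one document could be assigned a set-based DocID (top lexical tokens of a sparse vector) and a sequence-based DocID (residual quantization of its relevance-based dense representation, following RIPOR). The function names, the value of m, and the quantization details are assumptions for illustration, not the repository's code.

# Illustrative sketch (not the repo's code): building the two DocID types for one document.
import torch

def set_based_docid(sparse_vec, m=64):
    # top-m lexical tokens of a SPLADE-style sparse vector -> unordered set-based DocID
    token_ids = torch.topk(sparse_vec, k=m).indices
    return set(token_ids.tolist())             # order does not matter (bag of words)

def sequence_based_docid(dense_vec, codebooks):
    # residual quantization of a relevance-based dense vector -> ordered sequence-based DocID
    # codebooks: list of (K, dim) centroid tensors, assumed to be learned beforehand
    codes, residual = [], dense_vec.clone()
    for centroids in codebooks:
        idx = torch.cdist(residual.unsqueeze(0), centroids).argmin().item()
        codes.append(idx)                      # one token per quantization level
        residual = residual - centroids[idx]
    return codes                               # order matters (autoregressive target)

At inference, as described above, the set-based tokens are decoded simultaneously and their scores guide the autoregressive generation of the sequence-based DocID.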

Package installation

Download Files

All necessary files and checkpoints are in PAG-data. If you only want to run inference, you only need to download the following files or folders.
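If you prefer to script the download, something like the following works with huggingface_hub, assuming PAG-data is hosted as a Hugging Face repository. The repo ID and file patterns below are placeholders, not the actual PAG-data location; use the PAG-data link above.

# Hypothetical download sketch; replace repo_id with the actual PAG-data location.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="<org>/PAG-data",                    # placeholder, not a real repo ID
    repo_type="dataset",                         # assumption; adjust if PAG-data is a model repo
    local_dir="./data",
    allow_patterns=["*checkpoint*", "*docid*"],  # example patterns for inference-only files
)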

Inference

We use a single 80GB A100 GPU to run the script. Feel free to use other GPU types, such as a V100, but inference will be slower. Make sure the task variable in line 1 of full_scripts/full_lexical_ripor_evaluate.sh is set to lexical_constrained_retrieve_and_rerank, then run:

bash full_scripts/full_lexical_ripor_evaluate.sh
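If you want a quick programmatic sanity check before launching, the snippet below reads the first line of the script and confirms the task selection. It assumes the task is assigned on line 1 as described above.

# Sanity check: confirm line 1 of the evaluation script selects the right task.
with open("full_scripts/full_lexical_ripor_evaluate.sh") as f:
    first_line = f.readline().strip()
assert "lexical_constrained_retrieve_and_rerank" in first_line, first_line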

Training

All experiments are conducted on 8x 40GB A100 GPUs. The whole training pipeline contains three stages: (1) a generative retrieval (GR) model for set-based DocIDs, (2) a GR model for sequence-based DocIDs, and (3) a unified GR model for set-based & sequence-based DocIDs. Stages (1) and (2) can be trained in parallel.

Stage 1: GR model for set-based DocIDs

This stage contains two phases: pre-training and fine-tuning. For pre-training, we train the GR model as a sparse encoder, then select the top $m$ words from the sparse vector of each document $d$ to form its set-based DocID, which we denote as $\{t^d_1, \ldots, t^d_m\}$. In the fine-tuning phase, we train the GR model for set-based DocID prediction.

Pre-training:

Run script for training the sparse encoder:

bash full_scripts/t5_splade_train.sh 

Once the model is trained, run the following script to get the set-based DocIDs:

bash full_scripts/t5_splade_get_bow_rep.sh
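For reference, a SPLADE-style sparse document vector is typically obtained by applying a log-saturated ReLU to the token-level vocabulary logits and max-pooling over the sequence; the top-m terms of this vector then give the set-based DocID as sketched earlier. A rough sketch under these assumptions (not the repo's model code):

# SPLADE-style sparse vector sketch (illustrative; not the repo's exact model code).
import torch

def splade_pool(token_logits):
    # token_logits: (seq_len, vocab_size) vocabulary logits for one document
    # log-saturated ReLU, then max-pool over positions -> one weight per vocabulary term
    return torch.log1p(torch.relu(token_logits)).max(dim=0).values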

Fine-tuning:

We apply a two-step fine-tuning strategy. The negatives for step 1 come from BM25, and the negatives for step 2 are mined with the step 1 model itself. For step 1 training:

bash full_scripts/t5_full_term_encoder_train.sh

For step 2 training:

bash full_scripts/t5_full_term_encoder_evaluate.sh
python t5_pretrainer/full_preprocess/add_qrel_to_rerank_run.py
bash full_scripts/t5_full_term_encoder_train.sh
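Conceptually, the three commands above retrieve with the step 1 model, attach the qrels to the resulting run, and retrain on the mined hard negatives. A rough sketch of that negative-mining idea (illustrative only; not the exact logic of add_qrel_to_rerank_run.py):

# Rough sketch of hard-negative mining from a run file + qrels (illustrative only).
import random

def build_triples(run, qrels, n_negs=8):
    # run: {qid: [ranked docids]} from the step 1 model; qrels: {qid: set of relevant docids}
    triples = []
    for qid, ranked in run.items():
        positives = qrels.get(qid, set())
        negatives = [d for d in ranked if d not in positives]
        for pos in positives:
            triples.append((qid, pos, random.sample(negatives, min(n_negs, len(negatives)))))
    return triples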

Stage 2: GR model for sequence-based DocIDs

This stage also contains pre-training and fine-tuning phases. The training pipeline is the same as RIPOR [https://arxiv.org/pdf/2311.09134.pdf], except that we don't use progressive training (we found that progressive training requires too much training time and does not significantly influence model effectiveness).

Pre-training:

We treat the GR model as a dense encoder and apply a two-step training strategy:

For step 1:

bash full_scripts/t5_full_dense_train.sh

For step 2:

bash full_scripts/t5_full_dense_evaluate.sh 
bash full_scripts/rerank_for_create_trainset.sh
python t5_pretrainer/full_preprocess/add_qrel_to_rerank_run.py
bash full_scripts/t5_full_dense_train.sh
bash full_scripts/t5_full_dense_evaluate.sh 
bash full_scripts/full_ripor_initial_train.sh
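Between dense pre-training and fine-tuning, the document embeddings produced by the dense encoder are quantized into sequence-based DocIDs (cf. the RIPOR pipeline referenced above). For intuition, here is a rough residual k-means sketch; the cluster count, number of levels, and helper name are illustrative assumptions, not the repo's preprocessing code.

# Sketch: learn codebooks and codes for sequence-based DocIDs via residual k-means (RIPOR-style).
import numpy as np
from sklearn.cluster import KMeans

def train_residual_codebooks(doc_embs, levels=8, k=256):
    # doc_embs: (num_docs, dim) relevance-based document embeddings from the dense encoder
    codebooks, residual = [], doc_embs.copy()
    codes = np.zeros((len(doc_embs), levels), dtype=np.int64)
    for level in range(levels):
        km = KMeans(n_clusters=k, n_init=10).fit(residual)
        codebooks.append(km.cluster_centers_)
        codes[:, level] = km.labels_
        residual = residual - km.cluster_centers_[km.labels_]  # pass the quantization error to the next level
    return codebooks, codes  # codes[i] is the sequence-based DocID of document i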

Fine-tuning:

We apply rank-oriented fine-tuning in this stage. Run the script:

bash full_scripts/full_ripor_direct_lng_knp_train.sh
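For intuition, rank-oriented fine-tuning trains the model to score relevant DocIDs above negatives. A generic pairwise margin loss over sequence scores is sketched below; this is illustrative only, and the actual objective used by the script may differ.

# Generic pairwise ranking sketch (illustrative only; the repo's rank-oriented loss may differ).
import torch
import torch.nn.functional as F

def pairwise_rank_loss(pos_scores, neg_scores, margin=1.0):
    # pos_scores / neg_scores: model scores (e.g., DocID log-likelihoods) for positive / negative documents
    return F.relu(margin - (pos_scores - neg_scores)).mean()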

Stage 3: Unified GR model for set-based & sequence-based DocIDs

We need to merge the weights of the two trained GR models above.

First, change your working directory to t5_pretrainer:

cd t5_pretrainer

Second, run the following command to merge the weights:

python -m full_preprocess.merge_model_weights

Third, change back to the repository root:

cd ..
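For intuition, merging combines the parameters of the set-based and sequence-based GR models into a single checkpoint. The sketch below shows one simple possibility (averaging shared parameters); the paths are placeholders, and full_preprocess.merge_model_weights implements the actual scheme used in the paper.

# Illustrative only: one simple way to combine two checkpoints.
import torch

# placeholder paths: point these at the two trained checkpoints
set_sd = torch.load("path/to/set_based_gr/pytorch_model.bin", map_location="cpu")
seq_sd = torch.load("path/to/seq_based_gr/pytorch_model.bin", map_location="cpu")

merged = {}
for name, param in set_sd.items():
    if name in seq_sd and param.shape == seq_sd[name].shape:
        merged[name] = (param + seq_sd[name]) / 2   # average parameters present in both models
    else:
        merged[name] = param                        # keep parameters specific to one model
merged.update({n: p for n, p in seq_sd.items() if n not in merged})
torch.save(merged, "path/to/merged/pytorch_model.bin")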

Now, we can finally fine-tune the model. Run the following script:

bash full_scripts/full_lexical_ripor_train.sh