# VisRAG: Vision-based Retrieval-augmented Generation on Multi-modality Documents
<p align="center">• <a href="#-introduction"> 📖 Introduction </a> • <a href="#-news">🎉 News</a> • <a href="#-visrag-pipeline">✨ VisRAG Pipeline</a> • <a href="#%EF%B8%8F-setup">⚙️ Setup</a> • <a href="#%EF%B8%8F-training">⚡️ Training</a> </p>
<p align="center">• <a href="#-evaluation">📃 Evaluation</a> • <a href="#-usage">🔧 Usage</a> • <a href="#-license">📄 License</a> • <a href="#-contact">📧 Contact</a> • <a href="#-star-history">📈 Star History</a> </p>

## 📖 Introduction
VisRAG is a novel vision-language model (VLM)-based RAG pipeline. Instead of first parsing the document to obtain text, the pipeline directly embeds the document as an image using a VLM and then retrieves it to enhance the generation of a VLM. Compared to traditional text-based RAG, VisRAG maximizes the retention and utilization of the information in the original documents, eliminating the information loss introduced by parsing.
<p align="center"><img width=800 src="assets/main_figure.png"/></p>

## 🎉 News
- 20241111: Released our VisRAG Pipeline on GitHub, now supporting visual understanding across multiple PDF documents.
- 20241104: Released our VisRAG Pipeline on Hugging Face Space.
- 20241031: Released our VisRAG Pipeline on Colab. Released code for converting files to images, which can be found at `scripts/file2img`.
- 20241015: Released our train data and test data, which can be found in the VisRAG Collection on Hugging Face, referenced at the beginning of this page.
- 20241014: Released our Paper on arXiv. Released our Model on Hugging Face. Released our Code on GitHub.
## ✨ VisRAG Pipeline
### VisRAG-Ret
VisRAG-Ret is a document embedding model built on MiniCPM-V 2.0, a vision-language model that integrates SigLIP as the vision encoder and MiniCPM-2B as the language model.
### VisRAG-Gen
In the paper, we use MiniCPM-V 2.0, MiniCPM-V 2.6 and GPT-4o as the generators. In practice, you can use any VLM you like!
## ⚙️ Setup

```bash
conda create --name VisRAG python==3.10.8
conda install nvidia/label/cuda-11.8.0::cuda-toolkit
cd VisRAG
pip install -r requirements.txt
pip install -e .
cd timm_modified
pip install -e .
cd ..
```
Note: `timm_modified` is an enhanced version of the `timm` library that supports gradient checkpointing, which we use in our training process to reduce memory usage.
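As a quick, optional sanity check (a sketch only, assuming the `VisRAG` conda environment is active and the editable installs above succeeded), you can verify that the key packages import:

```bash
# Optional: confirm torch and the modified timm are importable in the current environment.
python -c "import torch, timm; print(torch.__version__, timm.__version__)"
```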
## ⚡️ Training
### VisRAG-Ret
Our training dataset for VisRAG-Ret comprises 362,110 query-document (Q-D) pairs: the train sets of openly available academic datasets (34%) and a synthetic dataset of pages from web-crawled PDF documents augmented with VLM-generated (GPT-4o) pseudo-queries (66%).
```bash
bash scripts/train_retriever/train.sh 2048 16 8 0.02 1 true false config/deepspeed.json 1e-5 false wmean causal 1 true 2 false <model_dir> <repo_name_or_path>
```
Note:
- Our training data can be found in the `VisRAG` collection on Hugging Face, referenced at the beginning of this page. Please note that we have separated the `In-domain-data` and `Synthetic-data` due to their distinct differences. If you wish to train with the complete dataset, you'll need to merge and shuffle them manually.
- The parameters listed above are those used in our paper and can be used to reproduce the results.
- `<repo_name_or_path>` can be any of the following: `openbmb/VisRAG-Ret-Train-In-domain-data`, `openbmb/VisRAG-Ret-Train-Synthetic-data`, the directory path of a repository downloaded from Hugging Face, or the directory containing your own training data.
- If you wish to train using your own dataset, remove the `--from_hf_repo` line from the `train.sh` script. Additionally, ensure that your dataset directory contains a `metadata.json` file, which must include a `length` field specifying the total number of samples in the dataset (see the sketch after this list).
- Our training framework is modified based on OpenMatch.
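As an illustration only, a minimal `metadata.json` for a custom dataset could be created as follows. The directory name and sample count are placeholders; the only documented requirement is the `length` field.

```bash
# Hypothetical custom dataset directory -- adjust the path and the sample count to your data.
mkdir -p my_visrag_train_data
cat > my_visrag_train_data/metadata.json << 'EOF'
{"length": 100000}
EOF
```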
### VisRAG-Gen
The generation part does not use any fine-tuning; we directly use off-the-shelf LLMs/VLMs for generation.
## 📃 Evaluation
### VisRAG-Ret
```bash
bash scripts/eval_retriever/eval.sh 512 2048 16 8 wmean causal ArxivQA,ChartQA,MP-DocVQA,InfoVQA,PlotQA,SlideVQA <ckpt_path>
```
Note:
- Our test data can be found in the `VisRAG` Collection on Hugging Face, referenced at the beginning of this page.
- The parameters listed above are those used in our paper and can be used to reproduce the results.
- The evaluation script is configured to use datasets from Hugging Face by default. If you prefer to evaluate using locally downloaded dataset repositories, modify the `CORPUS_PATH`, `QUERY_PATH`, and `QRELS_PATH` variables in the evaluation script to point to the local repository directories (see the sketch after this list).
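For example, the relevant lines inside `scripts/eval_retriever/eval.sh` might be edited along these lines. The paths are placeholders, and the exact layout depends on how you downloaded the dataset repositories; check the script itself for how these variables are consumed.

```bash
# Hypothetical local paths -- replace with wherever you downloaded the dataset repos.
CORPUS_PATH=/path/to/local/corpus
QUERY_PATH=/path/to/local/queries
QRELS_PATH=/path/to/local/qrels
```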
### VisRAG-Gen
There are three settings in our generation: text-based generation, single-image-VLM-based generation and multi-image-VLM-based generation. Under single-image-VLM-based generation, there are two additional settings: page concatenation and weighted selection. For detailed information about these settings, please refer to our paper.
```bash
python scripts/generate/generate.py \
    --model_name <model_name> \
    --dataset_name <dataset_name> \
    --rank <process_rank> \
    --world_size <world_size> \
    --use_positive_sample <use_positive_sample> \
    --topk <number of docs retrieved for generation> \
    --results_root_dir <retrieval_results_dir> \
    --task_type <task_type> \
    --concatenate_type <image_concatenate_type>
```
Note:
- `use_positive_sample` indicates whether to use the retrieved documents or just the positive document for the query.
- `topk` and `results_root_dir` are only needed when `use_positive_sample` is set to 0. The `results_root_dir` should be organized as follows: `results_root_dir/dataset_name/*.trec`.
- `concatenate_type` is needed only when `task_type` is set to `page_concatenation`. It specifies the type of concatenation used to combine several images (see the example invocation after these notes).
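As an illustration only, a retrieval-augmented run with page concatenation might look like the following. The model name, dataset name, and concatenation type are placeholders, not the only supported options; consult `scripts/generate/generate.py` for the accepted values.

```bash
# Illustrative single-process run (rank 0 of 1) using the top-3 retrieved documents.
# All values below are placeholders.
python scripts/generate/generate.py \
    --model_name MiniCPM-V-2.6 \
    --dataset_name ArxivQA \
    --rank 0 \
    --world_size 1 \
    --use_positive_sample 0 \
    --topk 3 \
    --results_root_dir ./retrieval_results \
    --task_type page_concatenation \
    --concatenate_type horizontal
```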
## 🔧 Usage
### VisRAG-Ret
Model on Hugging Face: https://huggingface.co/openbmb/VisRAG-Ret
```python
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import os

def weighted_mean_pooling(hidden, attention_mask):
    # Position-weighted mean pooling: later tokens receive larger weights via the cumulative sum.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps

@torch.no_grad()
def encode(text_or_image_list):
    # Queries (strings) and documents (images) share the same encoder; pass one modality at a time.
    if isinstance(text_or_image_list[0], str):
        inputs = {
            "text": text_or_image_list,
            "image": [None] * len(text_or_image_list),
            "tokenizer": tokenizer
        }
    else:
        inputs = {
            "text": [''] * len(text_or_image_list),
            "image": text_or_image_list,
            "tokenizer": tokenizer
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    reps = weighted_mean_pooling(hidden, attention_mask)
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings

model_name_or_path = "openbmb/VisRAG-Ret"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, trust_remote_code=True)
model.eval()

script_dir = os.path.dirname(os.path.realpath(__file__))
queries = ["What does a dog look like?"]
passages = [
    Image.open(os.path.join(script_dir, 'test_image/cat.jpeg')).convert('RGB'),
    Image.open(os.path.join(script_dir, 'test_image/dog.jpg')).convert('RGB'),
]

INSTRUCTION = "Represent this query for retrieving relevant documents: "
queries = [INSTRUCTION + query for query in queries]

embeddings_query = encode(queries)
embeddings_doc = encode(passages)

# Cosine similarities (queries x passages), since both embeddings are L2-normalized.
scores = (embeddings_query @ embeddings_doc.T)
print(scores.tolist())
```
### VisRAG-Gen
For `VisRAG-Gen`, you can try a simple demonstration of the VisRAG Pipeline on Google Colab, which includes both `VisRAG-Ret` and `VisRAG-Gen`.
## 📄 License
- The code in this repo is released under the Apache-2.0 License.
- The usage of VisRAG-Ret model weights must strictly follow MiniCPM Model License.md.
- The models and weights of VisRAG-Ret are completely free for academic research. After filling out a "questionnaire" for registration, VisRAG-Ret weights are also available for free commercial use.
## 📧 Contact
- Shi Yu: yus21@mails.tsinghua.edu.cn
- Chaoyue Tang: tcy006@gmail.com