PyTorch implementation for "Learning to Rematch Mismatched Pairs for Robust Cross-Modal Retrieval (CVPR 2024)"
Learning to Rematch Mismatched Pairs for Robust Cross-Modal Retrieval
Requirements
- Python 3.8
- torch 1.12
- numpy
- scikit-learn
- pomegranate (see its Install instructions. Note that pomegranate requires Cython=0.29, NumPy, SciPy, NetworkX, and joblib. You can then run python setup.py build and python setup.py install to install it.)
- Punkt Sentence Tokenizer:
import nltk
nltk.download()
> d punkt
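The interactive download above can also be done without a prompt, which may be more convenient on a headless training server. The following is a minimal sketch, not part of this repository: the helper name ensure_punkt is hypothetical, and it assumes nltk is installed and that the resource is registered under tokenizers/punkt.

```python
def ensure_punkt():
    """Download the Punkt tokenizer only if it is not already available.

    Hypothetical helper for illustration; not part of the L2RM code base.
    """
    import nltk  # imported lazily so that defining the helper does not require nltk

    try:
        # Raises LookupError when the resource is missing from every nltk data path.
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        # Non-interactive equivalent of typing "d punkt" in the downloader prompt.
        nltk.download("punkt", quiet=True)


if __name__ == "__main__":
    ensure_punkt()
```

Running the file once before training should leave the tokenizer in the default nltk data directory, so later runs skip the download.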
Datasets
We follow NCR to obtain image features and vocabularies.
Noise (Mismatching) Index
We use the same noise index settings as DECL and RCL, which can be found in noise_index. The mismatching ratio (noise ratio) is set to 0.2, 0.4, 0.6, or 0.8.
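To illustrate what a mismatching ratio means in practice, the sketch below builds a caption index in which a chosen fraction of image-caption pairs is shuffled. It is only an assumption-laden illustration: the function names make_noise_index and mismatch_rate are hypothetical, and the files in noise_index may be generated and stored differently.

```python
import random


def make_noise_index(num_pairs, noise_ratio, seed=0):
    """Return, for each image i, the index of the caption paired with it.

    Hypothetical illustration: a noise_ratio fraction of positions is selected
    and the captions at those positions are shuffled among themselves.
    """
    rng = random.Random(seed)
    index = list(range(num_pairs))            # clean pairing: image i <-> caption i
    num_noisy = int(noise_ratio * num_pairs)  # how many pairs to perturb
    noisy_positions = rng.sample(range(num_pairs), num_noisy)
    shuffled = noisy_positions[:]
    rng.shuffle(shuffled)
    for pos, src in zip(noisy_positions, shuffled):
        index[pos] = src                      # image pos now gets caption src
    return index


def mismatch_rate(index):
    """Fraction of images whose caption index differs from their own index."""
    return sum(1 for i, j in enumerate(index) if i != j) / len(index)


if __name__ == "__main__":
    for ratio in (0.2, 0.4, 0.6, 0.8):
        idx = make_noise_index(1000, ratio)
        print(ratio, mismatch_rate(idx))
```

Because a shuffle can leave a few captions in their original position, the realized mismatch rate in this sketch is at most the requested ratio, and usually very close to it.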
Training and Evaluation
Training new models
Modify the necessary parameters in the training script for your dataset, then run it.
For Flickr30K:
sh train_f30k.sh
For MS-COCO:
sh train_coco.sh
For CC152K:
sh train_cc152k.sh
Evaluation
Modify the necessary parameters, then run the evaluation script.
python main_testing.py
Pre-trained L2RM models
The pre-trained models are available here.
License
Acknowledgements
The code is based on SCAN, SGRAF, NCR, DECL, and KPG-RL licensed under Apache 2.0.