

CyCLIP — Official PyTorch Implementation

Python 3.10 | PyTorch 1.10 | License: CC BY-NC

<h1 align="center"><img src="./docs/images/intro.png" width="75%"></h1>

This repository contains the official PyTorch implementation of the following paper:

CyCLIP: Cyclic Contrastive Language-Image Pretraining<br> Shashank Goel (UCLA), Hritik Bansal (UCLA), Sumit Bhatia (MDSR Lab, Adobe Systems), Ryan A. Rossi (Adobe Research), Vishwa Vinay (Adobe Research), Aditya Grover (UCLA)<br> https://arxiv.org/abs/2205.14459

Abstract: Recent advances in contrastive representation learning over paired image-text data have led to models such as CLIP that achieve state-of-the-art performance for zero-shot classification and distributional robustness. Such models typically require joint reasoning in the image and text representation spaces for downstream inference tasks. Contrary to prior beliefs, we demonstrate that the image and text representations learned via a standard contrastive objective are not interchangeable and can lead to inconsistent downstream predictions. To mitigate this issue, we formalize consistency and propose CyCLIP, a framework for contrastive representation learning that explicitly optimizes for the learned representations to be geometrically consistent in the image and text space. In particular, we show that consistent representations can be learned by explicitly symmetrizing (a) the similarity between the two mismatched image-text pairs (cross-modal consistency); and (b) the similarity between the image-image pair and the text-text pair (in-modal consistency). Empirically, we show that the improved consistency in CyCLIP translates to significant gains over CLIP, with gains ranging from 10%-24% for zero-shot classification accuracy on standard benchmarks (CIFAR-10, CIFAR-100, ImageNet1K) and 10%-27% for robustness to various natural distribution shifts.
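The two consistency terms in (a) and (b) can be written down directly. The sketch below is a minimal, dependency-free illustration, assuming L2-normalized embeddings and a mean of squared differences over all N^2 index pairs in a batch of N image-text pairs. The function names (`cyclic_penalties`, `cyclic_regularizer`) are hypothetical and not this repository's API, and the repository's training code may scale these terms differently (for example by the logit temperature or the batch size). We also assume the `--cylambda1` and `--cylambda2` training flags set the two weights, but which flag weights which term is not stated here.

```python
import math
from typing import List, Tuple

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit length (CLIP-style models compare normalized embeddings)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cyclic_penalties(images: List[Vector], texts: List[Vector]) -> Tuple[float, float]:
    """Return (in_modal, cross_modal) consistency penalties for N paired embeddings.

    in_modal    = mean over (j, k) of (<I_j, I_k> - <T_j, T_k>)^2
    cross_modal = mean over (j, k) of (<I_j, T_k> - <I_k, T_j>)^2
    """
    if len(images) != len(texts) or not images:
        raise ValueError("expected a non-empty batch of paired embeddings")
    imgs = [l2_normalize(v) for v in images]
    txts = [l2_normalize(v) for v in texts]
    n = len(imgs)
    in_modal = cross_modal = 0.0
    for j in range(n):
        for k in range(n):
            # In-modal: image-image similarity should match text-text similarity.
            in_modal += (dot(imgs[j], imgs[k]) - dot(txts[j], txts[k])) ** 2
            # Cross-modal: sim(I_j, T_k) should match the mismatched pair sim(I_k, T_j).
            cross_modal += (dot(imgs[j], txts[k]) - dot(imgs[k], txts[j])) ** 2
    return in_modal / (n * n), cross_modal / (n * n)


def cyclic_regularizer(images: List[Vector], texts: List[Vector],
                       lambda_in: float, lambda_cross: float) -> float:
    """Weighted sum of both penalties, intended to be added to the CLIP contrastive loss."""
    in_modal, cross_modal = cyclic_penalties(images, texts)
    return lambda_in * in_modal + lambda_cross * cross_modal
```

For example, `cyclic_penalties([[1, 0], [0, 1]], [[1, 0], [1, 1]])` returns approximately `(0.25, 0.25)`: the two images are orthogonal while the two captions are not, and sim(I_0, T_1) differs from sim(I_1, T_0). When every caption embedding points in the same direction as its image embedding, both penalties are zero.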

Acknowledgements

Some portions of the code in this repository are adapted from the following repositories: mlfoundations and openai.

Licenses

You can use, redistribute, and adapt the material for non-commercial purposes, as long as you give appropriate credit by citing our paper and indicating any changes that you've made.

Requirements

Set Up Environment and Install Dependencies

Clone the repository

git clone git@github.com:goel-shashank/CyCLIP.git
cd CyCLIP

Conda (recommended)

Please follow the Anaconda Setup instructions to install Anaconda.

The following commands create a conda environment with all dependencies in ./env inside the repository and then activate it.

conda env create --prefix ./env -f environment.yml
conda activate ./env

Pip

The requirements can be directly installed without creating a conda environment.

pip install -r requirements.txt

Training

python -m src.main --name exp1 \
    --train_data <path to train csv file> \
    --validation_data <path to valid csv file> \
    --image_key <column name of the image paths in the train/validation csv file> \
    --caption_key <column name of the captions in the train/validation csv file> \
    --device_ids 0 1 2 3 --distributed \
    --cylambda1 0.25 --cylambda2 0.25

Your train/validation csv/tsv file should have two columns: the captions and the paths to the corresponding images on your machine. The column names are the values you pass to --caption_key and --image_key. The training script does not download images; to download the CC3M and/or CC12M images from their URLs, use our utils/download.py script.
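A training CSV in this two-column layout can be produced and sanity-checked with Python's standard `csv` module. This is a minimal sketch: the column names `image` and `caption` are assumptions (you would then pass `--image_key image --caption_key caption`), the helper names `write_pairs` and `missing_train_images` are hypothetical, and the `root` argument that image paths are resolved against is also an assumption, since the text does not say whether the loader expects absolute paths or paths relative to some directory.

```python
import csv
import io
import os
from typing import Iterable, List, TextIO, Tuple


def write_pairs(out: TextIO, pairs: Iterable[Tuple[str, str]],
                image_key: str = "image", caption_key: str = "caption") -> None:
    """Write (image_path, caption) pairs as CSV rows under a header row.

    The header names are the values you would pass to --image_key and --caption_key.
    """
    writer = csv.DictWriter(out, fieldnames=[image_key, caption_key])
    writer.writeheader()
    for image_path, caption in pairs:
        # DictWriter quotes captions that contain commas or quote characters.
        writer.writerow({image_key: image_path, caption_key: caption})


def missing_train_images(csv_file: TextIO, image_key: str = "image",
                         root: str = ".") -> List[str]:
    """Return image paths listed in the CSV that are not files under `root`."""
    reader = csv.DictReader(csv_file)
    return [row[image_key] for row in reader
            if not os.path.isfile(os.path.join(root, row[image_key]))]


if __name__ == "__main__":
    buf = io.StringIO()
    write_pairs(buf, [("images/0.jpg", "a dog running on a beach"),
                      ("images/1.jpg", "a red car, parked outside")])
    print(buf.getvalue())
    buf.seek(0)
    print("missing:", missing_train_images(buf))
```

When writing to a real file rather than an in-memory buffer, open it with `newline=""` so the `csv` module controls line endings.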

Inference - ImageNet1K

python -m src.main --name <eval_imagenet_1k> --eval_data_type <dataset> --eval_test_data_dir data/ImageNet1K/validation/ --device_id 0 --checkpoint <ckpts/epoch_64.pt> 

For ImageNet1K, the test data directory must contain a labels.csv file with two columns, image and label, where image holds the path to each image on the local machine.
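Before running the evaluation command, it can help to check that labels.csv has the expected columns and that every listed image exists. The sketch below is hypothetical (the function names are not part of this repository), uses only the standard library, and assumes labels are plain values such as integer class indices, which the text does not specify. Because the text also does not say whether image paths are absolute or relative to the test data directory, the check accepts either.

```python
import csv
import io
import os
from typing import Dict, List, TextIO

REQUIRED_COLUMNS = {"image", "label"}


def parse_labels(f: TextIO) -> List[Dict[str, str]]:
    """Parse labels.csv content and check that the image and label columns exist."""
    reader = csv.DictReader(f)
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"labels.csv is missing column(s): {sorted(missing)}")
    return list(reader)


def read_labels_csv(test_dir: str) -> List[Dict[str, str]]:
    """Open <test_dir>/labels.csv; raises FileNotFoundError if it is absent."""
    path = os.path.join(test_dir, "labels.csv")
    with open(path, newline="") as f:
        return parse_labels(f)


def missing_images(test_dir: str, rows: List[Dict[str, str]]) -> List[str]:
    """List image paths found neither as written nor relative to test_dir."""
    return [row["image"] for row in rows
            if not (os.path.isfile(row["image"])
                    or os.path.isfile(os.path.join(test_dir, row["image"])))]


if __name__ == "__main__":
    # Hypothetical in-memory labels.csv content for demonstration.
    demo = io.StringIO("image,label\nval/0001.JPEG,0\nval/0002.JPEG,1\n")
    rows = parse_labels(demo)
    print(rows[1])
    print(missing_images("data/ImageNet1K/validation/", rows))
```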

Pretrained Checkpoints

You can find the pre-trained checkpoints here.

CyCLIP has been added to the EvalAI leaderboard under the Image Classification Challenge. We highly recommend using it to benchmark your pretrained models. Their repository: https://github.com/Computer-Vision-in-the-Wild/Elevater_Toolkit_IC