<p align="center"> <img src="https://github.com/user-attachments/assets/53a09bd1-c8ac-43c0-80ae-03ba284c94ad" width="150" style="margin-bottom: 0.2;"/> </p> <h3 align="center"><a href="https://arxiv.org/abs/2410.17243"> Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss</a></h3> <h5 align="center"> If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏 </h5> <h5 align="center"> </h5> <div align="center"><img src="https://github.com/user-attachments/assets/2c19838b-43d8-4145-b28c-903f3d76f8ab" width="800" /></div> <details open><summary>💡 Some other multimodal foundation model projects from our team may interest you ✨. </summary><p> <!-- may -->VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding <br> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing <br> <br>
VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs <br> Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing <br> <br>
The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio <br> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing <br> <br>
</p></details>
📰 News
- [2024.10.18] Released the training and evaluation code of Inf-CLIP.
🛠️ Requirements and Installation
Basic Dependencies:
- Python >= 3.8
- PyTorch >= 2.0.0
- CUDA Version >= 11.8
[Remote] Install Inf-CL:
```bash
# remote installing
pip install inf_cl -i https://pypi.org/simple
```
[Local] Install Inf-CL (run from the root of the cloned Inf-CLIP repository):
```bash
pip install -e .
```
Install required packages:
```bash
git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -r requirements.txt
```
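After installation, a small script can confirm that the environment meets the basic dependencies listed above. This is a hypothetical helper, not part of the repository. Note that `torch.version.cuda` reports the CUDA version PyTorch was built with, which is only a proxy for the installed toolkit version.

```python
import importlib.util
import re
import sys


def parse_version(text):
    """Turn '2.1.0+cu118' or '11.8' into a tuple of ints such as (2, 1, 0)."""
    numbers = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)  # keep only the leading digits, e.g. '0a0' -> 0
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def check_environment():
    """Return a list of unmet requirements; an empty list means all checks passed."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append(f"Python >= 3.8 required, found {sys.version.split()[0]}")
    if importlib.util.find_spec("torch") is None:
        problems.append("PyTorch is not installed")
        return problems
    import torch

    if parse_version(torch.__version__) < (2, 0, 0):
        problems.append(f"PyTorch >= 2.0.0 required, found {torch.__version__}")
    cuda = torch.version.cuda  # CUDA version PyTorch was built with; None on CPU-only builds
    if cuda is None or parse_version(cuda) < (11, 8):
        problems.append(f"CUDA >= 11.8 required, found {cuda}")
    if importlib.util.find_spec("inf_cl") is None:
        problems.append("inf_cl is not importable; install it as described above")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("Environment looks OK" if not issues else "\n".join(issues))
```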
⭐ Features
`inf_cl` is the Triton implementation of the Inf-CL loss.
`inf_clip` is the CLIP training codebase with the Inf-CL loss and other training features.
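To illustrate why a tile-wise computation can save memory, here is a simplified NumPy sketch of the general technique: the log-sum-exp over each row of the logit matrix is accumulated one column tile at a time with a running-max (online softmax) update, so only a `b × tile` block of logits exists at once. This is our own illustration, tiling only over the keys, and not the Triton kernel shipped in `inf_cl`; `tiled_contrastive_loss` is a hypothetical name and diagonal labels are assumed.

```python
import numpy as np


def tiled_contrastive_loss(q, k, scale, tile=128):
    """Mean cross-entropy of scale * q @ k.T with diagonal labels, computed tile by tile.

    Assumes q and k are (b, d) arrays whose i-th rows form a positive pair.
    Only a (b, tile) block of logits is held in memory at any time.
    """
    b = q.shape[0]
    running_max = np.full(b, -np.inf)  # per-row max of the logits seen so far
    running_sum = np.zeros(b)          # per-row sum of exp(logit - running_max)
    for start in range(0, k.shape[0], tile):
        block = scale * (q @ k[start:start + tile].T)  # (b, tile) slice of the logits
        new_max = np.maximum(running_max, block.max(axis=1))
        # Rescale the old partial sum to the new max, then add this tile's terms.
        running_sum = (running_sum * np.exp(running_max - new_max)
                       + np.exp(block - new_max[:, None]).sum(axis=1))
        running_max = new_max
    log_sum_exp = np.log(running_sum) + running_max
    positives = scale * np.einsum("ij,ij->i", q, k)  # diagonal logits q_i . k_i
    return float(np.mean(log_sum_exp - positives))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q = rng.standard_normal((1000, 64))
    k = rng.standard_normal((1000, 64))
    q /= np.linalg.norm(q, axis=1, keepdims=True)
    k /= np.linalg.norm(k, axis=1, keepdims=True)
    scale = 1 / 0.07

    # Direct computation that materializes the full 1000 x 1000 logit matrix.
    logits = scale * (q @ k.T)
    full = np.mean(np.log(np.exp(logits).sum(axis=1)) - np.diag(logits))
    print(full, tiled_contrastive_loss(q, k, scale))  # the two values should agree closely
```

The tile size trades peak memory against the number of matrix products; the result does not depend on it beyond floating-point rounding.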
🔑 Usage
Below is a simple example of how to adopt our Inf-CL loss for contrastive learning. Run it with:
```bash
torchrun --nproc_per_node 2 tests/example.py
```
```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3         # Number of heads
    seq_length_q = 32768  # Global number of queries (split across ranks)
    seq_length_k = 32768  # Global number of keys (split across ranks)
    d_model = 256         # Dimension of each head; features have num_heads * d_model = 768 dims

    # Randomly initialize this rank's shard of the inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key
    l = l.requires_grad_()  # Logit scale

    return q, k, l


if __name__ == "__main__":
    # torchrun provides the rendezvous environment variables; initialize the process group
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Single-node launch: the global rank doubles as the local GPU index
    torch.cuda.set_device(rank)

    # Taking image-text contrastive learning as an example: q is the image features,
    # k is the text features, and l is the logit scale.
    q, k, l = create_cl_tensors(rank, world_size)

    # Labels are the diagonal elements by default, i.e.
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())
    print(loss)

    dist.destroy_process_group()
```
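To sanity-check the loss on small inputs without GPUs, a naive NumPy reference can help. It assumes that `cal_inf_loss(q, k, scale=...)` computes the mean cross-entropy of the logits `scale * q @ k.T` with diagonal labels in the q-to-k direction; this is our reading of the example, not a documented guarantee, so compare against the real function before relying on it. `symmetric_clip_loss` is a hypothetical helper showing the usual CLIP average over both directions.

```python
import numpy as np


def reference_contrastive_loss(q, k, scale, labels=None):
    """Mean cross-entropy of the logits scale * q @ k.T.

    Row i's positive is column labels[i]; by default labels = arange(n), i.e. the
    diagonal. Materializes the full logit matrix, so keep the inputs small.
    """
    if labels is None:
        labels = np.arange(q.shape[0])
    logits = scale * (q @ k.T)
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(q.shape[0]), labels].mean())


def symmetric_clip_loss(img, txt, scale):
    """Hypothetical helper: average of the image-to-text and text-to-image losses."""
    return 0.5 * (reference_contrastive_loss(img, txt, scale)
                  + reference_contrastive_loss(txt, img, scale))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.random((8, 16))
    txt = rng.random((8, 16))
    img /= np.linalg.norm(img, axis=1, keepdims=True)
    txt /= np.linalg.norm(txt, axis=1, keepdims=True)
    scale = 1 / 0.07  # same value as l.exp() in the example above
    print(reference_contrastive_loss(img, txt, scale), symmetric_clip_loss(img, txt, scale))
```

With `scale = 0` every logit is equal, so the loss reduces to `log(n)`, which makes a convenient first check.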
🚀 Main Results
Memory Cost
<p><img src="https://github.com/user-attachments/assets/05dd3fea-0a93-4716-b321-0a94965e1fbe" width="800" /></p>
\* denotes adopting the "data offload" strategy.
Max Supported Batch Size
<p><img src="https://github.com/user-attachments/assets/eb38fb90-3b7e-4696-b078-b7766893f758" width="800" /></p>
Speed
<p><img src="https://github.com/user-attachments/assets/da72e99b-508b-450a-b12e-401d4991291a" width="800" /></p>
Batch Size Scaling
<p><img src="https://github.com/user-attachments/assets/5b55fa98-6558-4509-9b66-e290ecf77b41" width="800" /></p>
Training with a larger data scale requires a larger batch size.
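The figures above are images, but the memory pressure behind them can be checked with simple arithmetic: a naive contrastive loss materializes a `b × b` similarity matrix, so its memory grows quadratically with the batch size `b`. The sketch below (plain Python, float32 assumed) gives 4 GiB at `b = 32,768` and 256 GiB at `b = 262,144` for the logits alone, before counting their gradient or any model activations.

```python
GIB = 1024 ** 3


def logit_matrix_bytes(batch_size, bytes_per_element=4):
    """Bytes to hold one full batch_size x batch_size similarity matrix (float32 by default)."""
    return batch_size * batch_size * bytes_per_element


if __name__ == "__main__":
    for batch_size in (32_768, 262_144):  # 2**15 and 2**18
        gib = logit_matrix_bytes(batch_size) / GIB
        print(f"batch size {batch_size:,}: {gib:.0f} GiB for the float32 logit matrix alone")
```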
🗝️ Training & Evaluation
Quick Start
To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP model and evaluate it on the mainstream CLIP benchmarks.
- Training Data Structure:
```
Inf-CLIP
├── datasets
│   ├── cc3m/  # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 0301.tar
│   ├── cc12m/  # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 1044.tar
│   ├── laion400m/  # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
│   │   ├── 00000.tar
│   │   ├── 00001.tar
│   │   ├── ...
│   │   └── 41407.tar
```
- Command:
```bash
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
```
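Inf-CLIP's codebase is adapted from OpenCLIP, whose `--train-data` option accepts webdataset shard lists in brace notation such as `datasets/cc3m/{0000..0301}.tar`. Assuming the training scripts pass shards this way (check the script for your dataset to confirm), the hypothetical helper below derives such a pattern from a shard directory laid out as above and fails early if a shard inside the range is missing.

```python
from pathlib import Path


def shard_pattern(dataset_dir):
    """Summarize a directory of numbered .tar shards as a brace pattern.

    Assumes equally zero-padded numeric names (0000.tar ... 0301.tar) and returns,
    for example, 'datasets/cc3m/{0000..0301}.tar'.
    """
    directory = Path(dataset_dir)
    stems = sorted(p.stem for p in directory.glob("*.tar"))
    if not stems:
        raise FileNotFoundError(f"no .tar shards found in {directory}")
    present = {int(stem) for stem in stems}  # raises ValueError on non-numeric names
    first, last = int(stems[0]), int(stems[-1])
    missing = [n for n in range(first, last + 1) if n not in present]
    if missing:
        raise ValueError(f"{len(missing)} shard(s) missing from {directory}, e.g. number {missing[0]}")
    return f"{directory.as_posix()}/{{{stems[0]}..{stems[-1]}}}.tar"


if __name__ == "__main__":
    for name in ("cc3m", "cc12m", "laion400m"):
        path = Path("datasets") / name
        if path.is_dir():
            print(name, shard_pattern(path))
```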
- Evaluation Data Structure:
```
Inf-CLIP
├── datasets
│   ├── imagenet-1k/  # download val_images.tar.gz of ImageNet
│   │   └── val/
│   │       ├── n01440764
│   │       ├── n01443537
│   │       ├── ...
│   │       └── n15075141
│   ├── clip-benchmark/  # bash datasets/benchmarks_download.sh
│   │   ├── wds_mscoco_captions
│   │   ├── wds_flickr8k
│   │   ├── wds_flickr30k
│   │   ├── wds_imagenet1k
│   │   ├── wds_imagenetv2
│   │   ├── wds_imagenet_sketch
│   │   ├── wds_imagenet-a
│   │   ├── wds_imagenet-r
│   │   ├── wds_imagenet-o
│   │   └── wds_objectnet
```
- Command:
```bash
# imagenet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
```
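For readers new to CLIP evaluation: ImageNet zero-shot classification usually encodes one text prompt per class (for example "a photo of a {class name}", often averaged over several templates) and assigns each image to the class whose text embedding has the highest cosine similarity. The NumPy sketch below shows only that scoring step on precomputed features; `zero_shot_top1` is a hypothetical helper, not the code behind `scripts/imagenet_eval.sh`, and feature extraction is left out.

```python
import numpy as np


def zero_shot_top1(image_features, class_text_features, labels):
    """Top-1 accuracy of CLIP-style zero-shot classification.

    image_features: (n_images, d) image embeddings.
    class_text_features: (n_classes, d), one text embedding per class.
    labels: (n_images,) integer class indices.
    """
    img = image_features / np.linalg.norm(image_features, axis=1, keepdims=True)
    txt = class_text_features / np.linalg.norm(class_text_features, axis=1, keepdims=True)
    predictions = (img @ txt.T).argmax(axis=1)  # class with the highest cosine similarity
    return float(np.mean(predictions == labels))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    class_emb = rng.standard_normal((10, 32))  # stand-ins for encoded class prompts
    labels = rng.integers(0, 10, size=200)
    images = class_emb[labels] + 0.5 * rng.standard_normal((200, 32))  # noisy copies
    print(f"top-1 accuracy: {zero_shot_top1(images, class_emb, labels):.3f}")
```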
📑 Citation
If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:
```bibtex
@article{damovl2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:2410.17243},
  year={2024},
  url={https://arxiv.org/abs/2410.17243}
}
```
👍 Acknowledgement
The codebase of Inf-CLIP is adapted from OpenCLIP. We are also grateful to the projects from which our Inf-CL arose.
🔒 License
This project is released under the Apache 2.0 license, as found in the LICENSE file. The service is a research preview intended for non-commercial use ONLY, subject to the model licenses of CLIP, the Terms of Use of data generated by OpenAI, and LAION. Please get in touch with us if you find any potential violations.