<div align="center">
<h1>Contrastive Mean-Shift Learning for Generalized Category Discovery</h1>
<h3><a href="http://sua-choi.github.io">Sua Choi</a> &nbsp; <a href="http://dahyun-kang.github.io">Dahyun Kang</a> &nbsp; <a href="https://cvlab.postech.ac.kr/~mcho">Minsu Cho</a></h3>
<h4>Pohang University of Science and Technology (POSTECH)</h4>
<h4>[<a href="http://arxiv.org/abs/2404.09451">Paper</a>] [<a href="https://cvlab.postech.ac.kr/research/cms">Project page</a>]</h4>
</div>
<br />
<div align="center">
<img src="data/assets/overview.png" alt="overview" width="80%"/>
</div>

## Environment installation
This project is built upon a standard Python/PyTorch environment. The package requirements can be installed via `requirements.txt`:

```bash
pip install -r requirements.txt
```
## Datasets
We use fine-grained benchmarks in this paper, including CUB, Stanford Cars, FGVC-Aircraft, and Herbarium19, as well as the generic object recognition datasets CIFAR-100 and ImageNet-100.
Please follow the Generalized Category Discovery repo (see Related Repos below) to set up the data.
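For convenience, CIFAR-100 can also be fetched with torchvision. The snippet below is a hypothetical helper, not part of this repository; the fine-grained benchmarks must be downloaded from their official sources. It produces the `cifar100/cifar-100-python/` layout shown in the file structure below.

```python
# Hypothetical helper (not part of this repo): download CIFAR-100 into
# DATASET_ROOT/cifar100/, which yields cifar100/cifar-100-python/ as below.
from torchvision.datasets import CIFAR100

root = "/path/to/DATASET_ROOT/cifar100"          # replace with your DATASET_ROOT
CIFAR100(root=root, train=True, download=True)   # train split
CIFAR100(root=root, train=False, download=True)  # test split
```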
Download the datasets, SSB splits, and the pretrained backbone following the file structure below, and set `DATASET_ROOT={YOUR DIRECTORY}` in `config.py`.
```
DATASET_ROOT/
├── cifar100/
│   ├── cifar-100-python/
│   │   ├── meta/
│   ├── ...
├── CUB_200_2011/
│   ├── attributes/
│   ├── ...
├── ...

CMS/
├── data/
│   ├── ssb_splits/
├── models/
│   ├── dino_vitbase16_pretrain.pth
├── ...
```
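For reference, the line to edit in `config.py` would look roughly as follows (a sketch; the actual file contains further paths and settings):

```python
# config.py (sketch): only the dataset root referenced above is shown.
DATASET_ROOT = "/path/to/DATASET_ROOT"  # replace with your directory
```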
## Training

```bash
bash bash_scripts/contrastive_meanshift_training.sh
```

Example bash commands for training are as follows:
```bash
# GCD
python -m methods.contrastive_meanshift_training \
    --dataset_name 'cub' \
    --lr 0.05 \
    --temperature 0.25 \
    --wandb

# Inductive GCD
python -m methods.contrastive_meanshift_training \
    --dataset_name 'cub' \
    --lr 0.05 \
    --temperature 0.25 \
    --inductive \
    --wandb
```
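For intuition, contrastive mean-shift learning shifts each embedding toward the mean of its nearest neighbors before contrasting. The following is only a conceptual sketch of a single mean-shift step on L2-normalized embeddings with a simple k-nearest-neighbor kernel; see the paper for the exact formulation used in training.

```python
import torch
import torch.nn.functional as F

def mean_shift_step(z: torch.Tensor, k: int = 8) -> torch.Tensor:
    """One mean-shift iteration: move each L2-normalized embedding toward
    the mean of its k nearest neighbors, then re-project onto the sphere."""
    z = F.normalize(z, dim=1)        # (N, D) unit-norm embeddings
    sim = z @ z.t()                  # (N, N) cosine similarities
    _, idx = sim.topk(k, dim=1)      # k nearest neighbors (incl. self)
    shifted = z[idx].mean(dim=1)     # neighborhood mean, (N, D)
    return F.normalize(shifted, dim=1)
```

This sketch omits batching, kernel weighting, and other training details of the actual method.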
## Evaluation

```bash
bash bash_scripts/meanshift_clustering.sh
```

An example bash command for evaluation is as follows; set `model_name` to the name of your trained checkpoint:
```bash
python -m methods.meanshift_clustering \
    --dataset_name 'cub' \
    --model_name 'cub_best'
```
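Evaluation follows the standard GCD protocol: clustering accuracy under an optimal one-to-one (Hungarian) matching between predicted clusters and ground-truth classes. Below is a minimal sketch of that metric; the repository's own evaluation code may differ in detail.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def clustering_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Clustering accuracy: best one-to-one matching between predicted
    cluster indices and ground-truth labels (Hungarian algorithm)."""
    n = int(max(y_pred.max(), y_true.max())) + 1
    w = np.zeros((n, n), dtype=np.int64)           # contingency matrix
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1
    row, col = linear_sum_assignment(w.max() - w)  # maximize matched counts
    return w[row, col].sum() / y_pred.size
```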
## Results and checkpoints
Experimental results on the GCD task (accuracy in %):

| Dataset | All | Old | Novel | Checkpoints |
|---|:---:|:---:|:---:|:---:|
| CIFAR100 | 82.3 | 85.7 | 75.5 | [link](https://drive.google.com/drive/folders/1LK9dSV4lDdu9kdvCkETGV_JjxhyP3tvp?usp=drive_link) |
| ImageNet100 | 84.7 | 95.6 | 79.2 | [link](https://drive.google.com/drive/folders/1ODWtXTNoTch-hjPgZ20d5JQEprS9WYS5?usp=drive_link) |
| CUB | 68.2 | 76.5 | 64.0 | [link](https://drive.google.com/drive/folders/1YxrZJkmMf_QG4ELCTUP5EwlZk86eGPYc?usp=drive_link) |
| Stanford Cars | 56.9 | 76.1 | 47.6 | [link](https://drive.google.com/drive/folders/1ci5LpkjGOvwUxYW28CV4nurU_u3liiqe?usp=drive_link) |
| FGVC-Aircraft | 56.0 | 63.4 | 52.3 | [link](https://drive.google.com/drive/folders/1Sf7es_0O2UIeaZDaQv_UTjPVaVfELvVi?usp=drive_link) |
| Herbarium19 | 36.4 | 54.9 | 26.4 | [link](https://drive.google.com/drive/folders/1oDpp7bjLyRcA620xuuk5rv6JXGGDwYbX?usp=drive_link) |

Experimental results on the inductive GCD task (accuracy in %):

| Dataset | All | Old | Novel | Checkpoints |
|---|:---:|:---:|:---:|:---:|
| CIFAR100 | 80.7 | 84.4 | 65.9 | [link](https://drive.google.com/drive/folders/1nBsLIE6kxMMtq5tNArqdIm4xYPQ7cPtY?usp=drive_link) |
| ImageNet100 | 85.7 | 95.7 | 75.8 | [link](https://drive.google.com/drive/folders/1ZtE3CwJOj1dEZ_0Vch5ocBzd4YFCWxXH?usp=drive_link) |
| CUB | 69.7 | 76.5 | 63.0 | [link](https://drive.google.com/drive/folders/1iea01VbkCymHeXrn-gtGbllY9nbRFBE1?usp=drive_link) |
| Stanford Cars | 57.8 | 75.2 | 41.0 | [link](https://drive.google.com/drive/folders/1BJPWYHlILZrYgyyoiY-Hi4VYJ2Cplkvc?usp=drive_link) |
| FGVC-Aircraft | 53.3 | 62.7 | 43.8 | [link](https://drive.google.com/drive/folders/1dsm2ICGnSS4QZJN5Pda9fqGYzpZVValA?usp=drive_link) |
| Herbarium19 | 46.2 | 53.0 | 38.9 | [link](https://drive.google.com/drive/folders/1EjSwrX06f964TF9utvLQewSivOvJEsxf?usp=drive_link) |

## Citation
If you find our code or paper useful, please consider citing our paper:
```bibtex
@inproceedings{choi2024contrastive,
  title={Contrastive Mean-Shift Learning for Generalized Category Discovery},
  author={Choi, Sua and Kang, Dahyun and Cho, Minsu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2024}
}
```
## Related Repos
The codebase is largely built on Generalized Category Discovery and PromptCAL.
## Acknowledgements
This work was supported by the NRF grant (NRF-2021R1A2C3012728 (50%)) and the IITP grants (2022-0-00113: Developing a Sustainable Collaborative Multi-modal Lifelong Learning Framework (45%), 2019-0-01906: AI Graduate School Program at POSTECH (5%)) funded by the Ministry of Science and ICT, Korea.