
šŸ“ SURE (CVPR 2024 & ECCV 2024 OOD-CV Challenge Winner)

Introduction

This is the official implementation of our CVPR 2024 paper "SURE: SUrvey REcipes for building reliable and robust deep networks". Our recipes are powerful tools in addressing real-world challenges, such as long-tailed classification, learning with noisy labels, data corruption and out-of-distribution detection.

Links: arXiv · Project page · Google Drive · Poster

News

<p align="center"> <img src="img/Teaser.png" width="1000px" alt="teaser"> </p>

Table of Contents

1. Overview of recipes

<p align="center"> <img src="img/recipes.png" width="1000px" alt="method"> </p>

2. Visual Results

<p align="center"> <img src="img/confidence.png" width="1000px" alt="method"> </p> <p align="center"> <img src="img/ood.png" width="650px" alt="method"> </p>

3. Installation

3.1. Environment

Our model can be trained on a single RTX 4090 GPU (24 GB).

conda env create -f environment.yml
conda activate u

The code was tested on Python 3.9 and PyTorch 1.13.0.

3.2. Datasets

3.2.1 CIFAR and Tiny-ImageNet

cd data
bash download_cifar.sh

The directory structure should be:

./data/CIFAR10/
├── train
├── val
└── test

3.2.2 ImageNet1k and ImageNet21k

3.2.3 Animal-10N and Food-101N

./data/Animal10N/
├── train
└── test
./data/Food-101N/
├── train
└── test

3.2.4 CIFAR-LT

./data/CIFAR10_LT/
├── train
└── test

3.2.5 CIFAR10-C

./data/CIFAR-10-C/
├── brightness.npy
├── contrast.npy
├── defocus_blur.npy
...

3.2.6 Stanford CARS

./data/CARS/
├── train
└── test

4. Quick Start

4.1 Failure Prediction

<details> <summary> Take an example from run/CIFAR10/wideresnet.sh: </summary> <details> <summary> MSP </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> <details> <summary> RegMixup </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.5 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.5 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> <details> <summary> CRL </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0.5 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name baseline \
  --crl-weight 0.5 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> <details> <summary> SAM </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> <details> <summary> SWA </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> <details> <summary> FMFP </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> <details> <summary> SURE </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0.5 \
  --mixup-weight 0.5 \
  --mixup-beta 10 \
  --use-cosine \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name wrn \
  --optim-name fmfp \
  --crl-weight 0.5 \
  --mixup-weight 0.5 \
  --use-cosine \
  --save-dir ./CIFAR10_out/wrn_out \
  Cifar10
</details> </details>
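The baselines above differ mainly in how a confidence score is obtained at test time. MSP, for instance, simply takes the maximum softmax probability as the confidence estimate. A minimal standalone sketch (not repository code) of that score:

```python
import numpy as np

def msp_confidence(logits):
    """Maximum softmax probability: the confidence score used by the MSP baseline."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))  # numerically stable softmax
    probs = e / e.sum(axis=-1, keepdims=True)
    return probs.max(axis=-1)

# Toy logits: one confident prediction, one near-uniform one.
logits = np.array([[4.0, 0.0, 0.0],
                   [0.1, 0.0, 0.0]])
scores = msp_confidence(logits)  # high score for the first sample, low for the second
```

Failure prediction then asks how well this score separates correct from incorrect predictions (e.g. via AURC or FPR@95, as reported by test.py).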

Note that the DeiT recipes below use different hyper-parameters (smaller batch size, fewer epochs, lower learning rate):

<details> <summary> Take an example from run/CIFAR10/deit.sh: </summary> <details> <summary> MSP </summary>
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
</details> <details> <summary> RegMixup </summary>
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name baseline \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
</details> <details> <summary> CRL </summary>
 python3 main.py \
 --batch-size 64 \
 --gpu 5 \
 --epochs 50 \
 --lr 0.01 \
 --weight-decay 5e-5 \
 --nb-run 3 \
 --model-name deit \
 --optim-name baseline \
 --crl-weight 0.2 \
 --mixup-weight 0 \
 --mixup-beta 10 \
 --save-dir ./CIFAR10_out/deit_out \
 Cifar10
 
 python3 test.py \
 --batch-size 64 \
 --gpu 5 \
 --nb-run 3 \
 --model-name deit \
 --optim-name baseline \
 --crl-weight 0.2 \
 --mixup-weight 0 \
 --save-dir ./CIFAR10_out/deit_out \
 Cifar10
</details> <details> <summary> SAM </summary>
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name sam \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
</details> <details> <summary> SWA </summary>
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --swa-epoch-start 0 \
  --swa-lr 0.004 \
  --nb-run 3 \
  --model-name deit \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name swa \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
</details> <details> <summary> FMFP </summary>
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --swa-epoch-start 0 \
  --swa-lr 0.004 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
</details> <details> <summary> SURE </summary>
  python3 main.py \
  --batch-size 64 \
  --gpu 5 \
  --epochs 50 \
  --lr 0.01 \
  --weight-decay 5e-5 \
  --swa-epoch-start 0 \
  --swa-lr 0.004 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --mixup-beta 10 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
  
  python3 test.py \
  --batch-size 64 \
  --gpu 5 \
  --nb-run 3 \
  --model-name deit \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 0.2 \
  --save-dir ./CIFAR10_out/deit_out \
  Cifar10
</details> </details> <details> <summary> The results of failure prediction. </summary> <p align="center"> <img src="img/main_results.jpeg" width="1000px" alt="method"> </p> </details>

4.2 Long-tailed classification

<details> <summary> Take an example from run/CIFAR10_LT/resnet32.sh: </summary> <details> <summary> Imbalance factor = 10 </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --mixup-beta 10 \
  --use-cosine \
  --save-dir ./CIFAR10_LT/res32_out \
  Cifar10_LT
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --use-cosine \
  --save-dir ./CIFAR10_LT/res32_out \
  Cifar10_LT
</details> <details> <summary> Imbalance factor = 50 </summary>
  python3 main.py \
  --batch-size 128 \
  --gpu 0 \
  --epochs 200 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --mixup-beta 10 \
  --use-cosine \
  --save-dir ./CIFAR10_LT_50/res32_out \
  Cifar10_LT_50
  
  python3 test.py \
  --batch-size 128 \
  --gpu 0 \
  --nb-run 3 \
  --model-name resnet32 \
  --optim-name fmfp \
  --crl-weight 0 \
  --mixup-weight 1 \
  --use-cosine \
  --save-dir ./CIFAR10_LT_50/res32_out \
  Cifar10_LT_50
  
</details> <details> <summary> Imbalance factor = 100 </summary>
python3 main.py \
--batch-size 128 \
--gpu 0 \
--epochs 200 \
--nb-run 3 \
--model-name resnet32 \
--optim-name fmfp \
--crl-weight 0 \
--mixup-weight 1 \
--mixup-beta 10 \
--use-cosine \
--save-dir ./CIFAR10_LT_100/res32_out \
Cifar10_LT_100

python3 test.py \
--batch-size 128 \
--gpu 0 \
--nb-run 3 \
--model-name resnet32 \
--optim-name fmfp \
--crl-weight 0 \
--mixup-weight 1 \
--use-cosine \
--save-dir ./CIFAR10_LT_100/res32_out \
Cifar10_LT_100
</details> </details>
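The imbalance factor is the ratio between the most and least frequent class. CIFAR-LT variants are commonly built by decaying per-class sample counts exponentially from the head class to the tail class; a hedged sketch of that construction (the repo's data loader may differ in details):

```python
def long_tail_counts(n_max, num_classes, imbalance_factor):
    """Per-class sample counts decaying exponentially from n_max to n_max / imbalance_factor."""
    return [int(n_max * imbalance_factor ** (-i / (num_classes - 1)))
            for i in range(num_classes)]

# CIFAR10-LT with imbalance factor 100: the head class keeps 5000 images, the tail class 50.
counts = long_tail_counts(5000, 10, 100)
```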

You can run the second-stage uncertainty-aware re-weighting with:

python3 finetune.py \
--batch-size 128 \
--gpu 5 \
--nb-run 1 \
--model-name resnet32 \
--optim-name fmfp \
--fine-tune-lr 0.005 \
--reweighting-type exp \
--t 1 \
--crl-weight 0 \
--mixup-weight 1 \
--mixup-beta 10 \
--fine-tune-epochs 50 \
--use-cosine \
--save-dir ./CIFAR100LT_100_out/51.60 \
Cifar100_LT_100
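With `--reweighting-type exp`, low-confidence samples receive larger loss weights during fine-tuning, with the temperature `--t` controlling how aggressive the re-weighting is. A hypothetical sketch of what an exponential weighting rule of this form can look like (the repository's exact definition may differ):

```python
import numpy as np

def exp_reweight(confidence, t=1.0):
    """Hypothetical exponential re-weighting: lower-confidence samples get larger loss weights."""
    w = np.exp(-t * np.asarray(confidence))
    return w / w.mean()  # normalize so the average weight stays 1

# Three samples: well-classified, borderline, and hard; the hard one is up-weighted.
weights = exp_reweight([0.95, 0.60, 0.10], t=1.0)
```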
<details> <summary> The results of long-tailed classification. </summary> <p align="center"> <img src="img/long-tail.jpeg" width="600px" alt="method"> </p> </details>

4.3 Learning with noisy labels

<details> <summary> Animal-10N </summary>
 python3 main.py \
 --batch-size 128 \
 --gpu 0 \
 --epochs 200 \
 --nb-run 1 \
 --model-name vgg19bn \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --mixup-beta 10 \
 --use-cosine \
 --save-dir ./Animal10N_out/vgg19bn_out \
 Animal10N
 
 python3 test.py \
 --batch-size 128 \
 --gpu 0 \
 --nb-run 1 \
 --model-name vgg19bn \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --use-cosine \
 --save-dir ./Animal10N_out/vgg19bn_out \
 Animal10N
</details> <details> <summary> Food-101N </summary>
 python3 main.py \
 --batch-size 64 \
 --gpu 0 \
 --epochs 30 \
 --nb-run 1 \
 --model-name resnet50 \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --mixup-beta 10 \
 --lr 0.01 \
 --swa-lr 0.005 \
 --swa-epoch-start 22 \
 --use-cosine \
 --save-dir ./Food101N_out/resnet50_out \
 Food101N
 
 python3 test.py \
 --batch-size 64 \
 --gpu 0 \
 --nb-run 1 \
 --model-name resnet50 \
 --optim-name fmfp \
 --crl-weight 0.2 \
 --mixup-weight 1 \
 --use-cosine \
 --save-dir ./Food101N_out/resnet50_out \
 Food101N
</details> <details> <summary> The results of learning with noisy labels. </summary> <p align="center"> <img src="img/label-noise.jpeg" width="600px" alt="method"> </p> </details>

4.4 Robustness under data corruption

Evaluation on CIFAR-10-C is built into test.py; the snippet below runs every corruption at all five severity levels and gathers the metrics into a nested dictionary:

if args.data_name == 'cifar10':
    # Evaluate the model on every CIFAR-10-C corruption at severities 1-5.
    cor_results_storage = test_cifar10c_corruptions(net, args.corruption_dir, transform_test,
                                                    args.batch_size, metrics, logger)
    # Reshape into {corruption: {severity: {metric: value}}}.
    cor_results = {corruption: {
                   severity: {
                   metric: cor_results_storage[corruption][severity][metric][0] for metric in metrics} for severity
                   in range(1, 6)} for corruption in data.CIFAR10C.CIFAR10C.cifarc_subsets}
    cor_results_all_models[f"model_{r + 1}"] = cor_results
<details> <summary> The results of failure prediction under distribution shift. </summary> <p align="center"> <img src="img/data-corruption.jpeg" width="1000px" alt="method"> </p> </details>
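Each CIFAR-10-C `.npy` file holds 50,000 corrupted versions of the test set: 10,000 images per severity level, stacked in order of increasing severity. A single severity can therefore be recovered by slicing; a standalone sketch (not repository code):

```python
import numpy as np

def load_cifar10c_severity(path, severity):
    """Slice out the 10,000 test images for one severity level (1-5) from a CIFAR-10-C file."""
    images = np.load(path)                # severities 1..5 stacked, 10,000 images each
    start = (severity - 1) * 10000
    return images[start:start + 10000]
```

For example, `load_cifar10c_severity("./data/CIFAR-10-C/brightness.npy", 3)` returns the severity-3 brightness-corrupted test set.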

4.5 Out-of-distribution detection

<details> <summary> The results of out-of-distribution detection. </summary> <p align="center"> <img src="img/ood_results.png" width="800px" alt="method"> </p> </details>

5. Citation

If our project is helpful for your research, please consider citing:

@InProceedings{Li_2024_CVPR,
    author    = {Li, Yuting and Chen, Yingyi and Yu, Xuanlong and Chen, Dexiong and Shen, Xi},
    title     = {SURE: SUrvey REcipes for building reliable and robust deep networks},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {17500-17510}
}

@article{Li2024sureood,
    author    = {Li, Yang and Sha, Youyang and Wu, Shengliang and Li, Yuting and Yu, Xuanlong and Huang, Shihua and Cun, Xiaodong and Chen, Yingyi and Chen, Dexiong and Shen, Xi},
    title     = {SURE-OOD: Detecting OOD samples with SURE},
    month     = {September},
    year      = {2024},
}

6. Acknowledgement

Our code builds on FMFP and OpenMix. Thanks for their awesome work.