SSProbing

Code for the paper: Trust, but Verify: Using Self-Supervised Probing to Improve Trustworthiness (ECCV'22)

Demo Visualization

We provide a notebook, vis_demo.ipynb, for a demo visualization of self-supervised probing confidence scores.
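To give a feel for what the notebook visualizes, here is a minimal, purely illustrative sketch of a rotation-prediction probing confidence score. The names `backbone` and `probe_head` are placeholders, not the repo's actual API; refer to vis_demo.ipynb for the real computation.

    # Illustrative sketch only (not the repo's implementation): a probing head
    # predicts which of 4 rotations was applied to the input, and the probability
    # it assigns to the true rotation is averaged into a confidence signal.
    import torch
    import torch.nn.functional as F

    def rotation_probing_confidence(backbone, probe_head, x):
        # x: batch of images with shape (N, C, H, W)
        scores = []
        for k in range(4):  # 0, 90, 180, 270 degrees
            x_rot = torch.rot90(x, k, dims=(2, 3))
            logits = probe_head(backbone(x_rot))   # (N, 4) rotation logits
            probs = F.softmax(logits, dim=1)
            scores.append(probs[:, k])             # probability of the true rotation
        return torch.stack(scores, dim=1).mean(dim=1)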

Environment Setup

Set up the environment and install the required packages:

Our environment is partly based on ConfidNet.

    $ git clone https://github.com/valeoai/ConfidNet
    $ pip install -e ConfidNet

After installing the above packages, install the packages listed in requirements.txt:

    pip install -r requirements.txt

JupyterLab is not included in the requirements. If you would like to use it, install it manually.

Data

All datasets should be placed under the data/ directory.

    mkdir data/
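For example, if your setup reads CIFAR-10 from data/ via torchvision (an assumption; adjust to how the scripts actually load data), you can pre-download it with:

    # Convenience sketch: pre-download CIFAR-10 into data/ via torchvision.
    from torchvision.datasets import CIFAR10

    CIFAR10(root="data/", train=True, download=True)
    CIFAR10(root="data/", train=False, download=True)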

OOD-related datasets:

Pre-trained models

The pre-trained model weights we used in the paper can be accessed via this link.
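Once downloaded and unzipped, a checkpoint can be inspected with standard PyTorch loading. This is only a sketch; the exact contents depend on how the model was saved.

    # Minimal sketch for inspecting a downloaded checkpoint.
    import torch

    ckpt = torch.load(
        "snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt",
        map_location="cpu",
    )
    print(type(ckpt))  # e.g. a state_dict (dict of tensors) or a serialized model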

Training the base model

If you prefer to train a new base model and test it, you can follow the guidance below.

Since we use MC-Dropout as one of our baselines, the base models require Dropout components. For details on the model implementation, refer to VGG16 and our model files under train_base/models/.
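As a rough illustration of the Dropout requirement, the sketch below shows a toy classifier with a Dropout layer before the final linear layer; it is not one of the repo's architectures, which live under train_base/models/.

    # Illustrative sketch only: a classifier with a Dropout component,
    # as required for MC-Dropout at test time.
    import torch.nn as nn

    class SmallDropoutNet(nn.Module):
        def __init__(self, num_classes=10, p=0.3):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
            )
            self.classifier = nn.Sequential(
                nn.Flatten(),
                nn.Dropout(p),            # kept active for MC-Dropout sampling
                nn.Linear(64, num_classes),
            )

        def forward(self, x):
            return self.classifier(self.features(x))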

For the VGG16 model training, please refer to ConfidNet.

For the ResNet-18 model training, you can follow the example below:

    $ cd train_base
    $ python train_cifar10.py -e 300

After training, the models are saved under train_base/snapshots/xxx

Applying SSProbing on the pre-trained models

The following examples assume you have downloaded the pre-trained models and unzipped them in the repository directory. Otherwise, adjust the commands to your setup, e.g. point the config path argument to your own saved or trained model path.

    # create the results output directory
    mkdir res_dir/

Misclassification Detection

Example:

    python -u mis_detect.py -c snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt -m mcp -t ./task_configs/rot4_trans5.yaml -sf cifar10_mis_res.txt -se 5

OOD Detection

Example:

    python -u ood_detect.py -c snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt -m mcp -se 5 -sf cifar10_ood_res.text
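For reference, OOD detection is commonly scored as AUROC over confidence scores for in-distribution vs. out-of-distribution samples. The sketch below uses scikit-learn and placeholder scores; the metrics actually reported by ood_detect.py may differ.

    # Sketch of the standard AUROC evaluation for OOD detection.
    import numpy as np
    from sklearn.metrics import roc_auc_score

    conf_in = np.random.rand(1000) * 0.5 + 0.5   # placeholder in-distribution confidences
    conf_out = np.random.rand(1000) * 0.5        # placeholder OOD confidences

    labels = np.concatenate([np.ones_like(conf_in), np.zeros_like(conf_out)])
    scores = np.concatenate([conf_in, conf_out])
    print("AUROC:", roc_auc_score(labels, scores))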

Calibration

Example:

    python -u cal.py -c snapshots/cifar10_resnet18/cifar10_resnet18_dp_baseline_epoch_299.pt -sf cifar10_cal_res.txt
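For reference, a common calibration metric is the expected calibration error (ECE); a minimal sketch is below (cal.py may report additional or different metrics).

    # Sketch of expected calibration error (ECE) over confidence bins.
    import numpy as np

    def expected_calibration_error(confidences, correct, n_bins=15):
        # confidences: predicted confidences in [0, 1]; correct: 0/1 correctness.
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        ece = 0.0
        for lo, hi in zip(bins[:-1], bins[1:]):
            mask = (confidences > lo) & (confidences <= hi)
            if mask.any():
                ece += mask.mean() * abs(correct[mask].mean() - confidences[mask].mean())
        return ece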

For more details, you can also refer to run_mis.sh, run_ood.sh, and run_cal.sh.

Citation

If you find this repo or our work useful for your research, please consider citing the paper:

@inproceedings{deng2022trust,
  title={Trust, but Verify: Using Self-supervised Probing to Improve Trustworthiness},
  author={Deng, Ailin and Li, Shen and Xiong, Miao and Chen, Zhirui and Hooi, Bryan},
  booktitle={European Conference on Computer Vision},
  pages={361--377},
  year={2022},
  organization={Springer}
}

Acknowledgement