# Your Out-of-Distribution Detection Method is Not Robust!

This repository contains the code for the NeurIPS 2022 paper ["Your Out-of-Distribution Detection Method is Not Robust!"](https://openreview.net/forum?id=YUEP3ZmkL1). Out-of-distribution (OOD) detection has recently gained substantial attention due to the importance of identifying out-of-domain samples for reliability and safety. Although OOD detection methods have advanced considerably, they remain susceptible to adversarial examples. To mitigate this issue, we propose the Adversarially Trained Discriminator (ATD), which uses a pre-trained robust model to extract robust features and a generator model to create OOD samples. ATD significantly outperforms previous methods under adversarial attack.
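As a rough illustration of this pipeline, here is a minimal, hypothetical PyTorch sketch of a discriminator head on top of a frozen robust feature extractor; the class name and dimensions are assumptions, and the repository's actual generator/discriminator architecture (see `train_ATD.py`) may differ:

```python
# Conceptual sketch only: a trainable discriminator head on a frozen,
# adversarially pre-trained backbone. Names and sizes are assumptions.
import torch.nn as nn

class RobustFeatureDiscriminator(nn.Module):
    def __init__(self, backbone: nn.Module, feat_dim: int = 512):
        super().__init__()
        self.backbone = backbone                # pre-trained robust feature extractor
        for p in self.backbone.parameters():
            p.requires_grad_(False)             # freeze it; only the head is trained
        self.head = nn.Sequential(              # small discriminator head
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),                  # logit: in-distribution vs. OOD
        )

    def forward(self, x):
        feats = self.backbone(x)                # robust features (weights frozen above)
        return self.head(feats)                 # gradients still flow to x for attacks
```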
## Illustration

## Experimental Results

<p align="center">
  <img src="images/performance.png" />
</p>

## Preliminaries
The code has been tested under Ubuntu Linux 20.04.3 LTS with Python 3.8.10 and requires some packages to be installed:
## Downloading In-distribution Datasets

- CIFAR: Included in TorchVision.
- TinyImageNet: Download and extract to the `data/tiny-imagenet-200` folder.
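For instance, TorchVision can fetch the CIFAR data on first use. A minimal sketch, assuming the `data` root used throughout this README:

```python
# Minimal sketch: TorchVision downloads CIFAR-10 into ./data on first use.
import torchvision
import torchvision.transforms as T

train_set = torchvision.datasets.CIFAR10("data", train=True, download=True, transform=T.ToTensor())
test_set = torchvision.datasets.CIFAR10("data", train=False, download=True, transform=T.ToTensor())
```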
## Downloading Auxiliary OOD Datasets
## Downloading Out-of-distribution Test Datasets

Links and instructions to download each dataset are provided below:

- MNIST: Included in TorchVision.
- TinyImageNet: Download and extract to the `data/tiny-imagenet-200` folder.
- Places365: Download and extract to `data/val_256`. These two files (`places365_val.txt`, `categories_places365.txt`) should also be downloaded and put in the `data` folder.
- LSUN: Download and extract to `data/LSUN_resize`.
- iSUN: Download and extract to `data/iSUN`.
- Birds: Download and extract to `data/images`.
- Flowers: Download and extract to `data/flowers`. You should add another folder in this directory and move at least one image into it to avoid a TorchVision error (see the loader sketch below).
- COIL-100: Download and extract to `data/coil`.
For example, run the following commands in the root directory to download LSUN:

```bash
cd data
wget https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz
tar -xvzf LSUN_resize.tar.gz
```
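Once extracted, these folder datasets can be read with TorchVision's `ImageFolder`. A hedged sketch (the exact layout inside each archive may vary; `ImageFolder` needs at least one class subdirectory containing images, which is the reason for the extra folder noted for Flowers above):

```python
# Sketch: loading an extracted OOD test folder with TorchVision's ImageFolder.
# ImageFolder expects <root>/<class_name>/<image>; if an archive extracts images
# directly into the root, add one dummy subfolder as described for Flowers.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

ood_set = datasets.ImageFolder("data/LSUN_resize", transform=transforms.ToTensor())
ood_loader = DataLoader(ood_set, batch_size=128, shuffle=False)
```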
## Downloading Pre-Trained Models

We provide pre-trained ATD models on CIFAR-10, CIFAR-100, and TinyImageNet as the in-distribution datasets, which can be downloaded from Google Drive.

To test the TinyImageNet checkpoint, you should also download `weights-best-TI.pt` from the above Google Drive folder into the `models` folder.
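A checkpoint can then be loaded with `torch.load`. A small sketch; the structure of the saved object is an assumption, so inspect it before use:

```python
# Sketch: inspecting a downloaded checkpoint. The keys stored inside
# weights-best-TI.pt are an assumption here; print them to see what is saved.
import torch

state = torch.load("models/weights-best-TI.pt", map_location="cpu")
if isinstance(state, dict):
    print(list(state.keys())[:10])  # peek at the first few saved entries
```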
## Overview of the Code

### Training Options and Description

The options for the training and evaluation code are as follows:
- `run_name`: Used in the checkpoint name.
- `model_type`: {`pix`: feature extractor is not used, `fea`: feature extractor is added to the model}.
- `training_type`: {`clean`: standard training, `adv`: adversarial training}.
- `in_dataset`: {`cifar10`, `cifar100`, `TI`}. CIFAR-10, CIFAR-100, and TinyImageNet are considered as the in-distribution datasets.
- `alpha`: The α coefficient in Equation 9 of the paper.
- `batch_size`: Batch size.
- `num_epochs`: Number of training epochs.
- `eps`: Attack perturbation budget.
- `attack_iters`: Number of iterations in the PGD attack.
- `seed`: Seed used to make the code behaviour deterministic.
- `out_datasets`: OOD datasets used for evaluation.
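To make the flag list concrete, here is a hypothetical `argparse` sketch of these options; the flag names match the list above, while the defaults are taken from the example commands below and are otherwise assumptions:

```python
# Hypothetical reconstruction of the CLI behind train_ATD.py / test_ATD.py;
# flag names follow the README, defaults mirror the example commands below.
import argparse

parser = argparse.ArgumentParser(description="ATD training/evaluation options")
parser.add_argument("--run_name", type=str, default="cifar10", help="used in the checkpoint name")
parser.add_argument("--model_type", choices=["pix", "fea"], default="fea")
parser.add_argument("--training_type", choices=["clean", "adv"], default="adv")
parser.add_argument("--in_dataset", choices=["cifar10", "cifar100", "TI"], default="cifar10")
parser.add_argument("--alpha", type=float, default=0.5, help="alpha coefficient in Eq. 9")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument("--eps", type=float, default=0.0313, help="perturbation budget (~8/255)")
parser.add_argument("--attack_iters", type=int, default=10, help="PGD iterations")
parser.add_argument("--seed", type=int, default=0, help="seed for deterministic behaviour")
parser.add_argument("--out_datasets", nargs="+", default=["mnist", "tiny_imagenet", "LSUN"])
args = parser.parse_args()
```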
### ATD Training

```bash
python train_ATD.py --run_name cifar10 --model_type "fea" --training_type "adv" --in_dataset cifar10 --alpha 0.5 --batch_size 128 --num_epochs 20 --eps 0.0313 --attack_iters 10
```
### ATD Evaluation

```bash
python test_ATD.py --run_name cifar10 --model_type "fea" --in_dataset cifar10 --batch_size 128 --eps 0.0313 --attack_iters 100 --out_datasets 'mnist' 'tiny_imagenet' 'LSUN'
```
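Evaluation attacks the detector with PGD under the `--eps` budget for `--attack_iters` steps. For intuition, a generic L-infinity PGD sketch is below; it assumes the model outputs one in-distribution logit per image, and the actual loss and step size in `test_ATD.py` may differ:

```python
# Illustrative L-infinity PGD; not necessarily the exact attack in test_ATD.py.
# `model` is assumed to return an (N, 1) in-distribution logit, `y` float labels.
import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, eps=0.0313, iters=100, step=None):
    step = step if step is not None else 2.5 * eps / iters             # common heuristic
    x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1)  # random start
    for _ in range(iters):
        x_adv = x_adv.detach().requires_grad_(True)
        loss = F.binary_cross_entropy_with_logits(model(x_adv).squeeze(1), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv + step * grad.sign()                             # ascend the loss
        x_adv = x + (x_adv - x).clamp(-eps, eps)                       # project to the eps-ball
        x_adv = x_adv.clamp(0, 1)                                      # stay in valid pixel range
    return x_adv.detach()
```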
## Acknowledgements

Part of this code is inspired by OpenGAN, ATOM, RobustBench, RobustOverfitting, and HAT.
## Citation

Please cite our work if you use this codebase:

```bibtex
@inproceedings{
azizmalayeri2022your,
title={Your Out-of-Distribution Detection Method is Not Robust!},
author={Mohammad Azizmalayeri and Arshia Soltani Moakhar and Arman Zarei and Reihaneh Zohrabi and Mohammad Taghi Manzuri and Mohammad Hossein Rohban},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=YUEP3ZmkL1}
}
```
## License

Please refer to the LICENSE.