# ResAD: A Simple Framework for Class Generalizable Anomaly Detection (NeurIPS 2024 Spotlight)

PyTorch implementation of the NeurIPS 2024 spotlight paper [ResAD: A Simple Framework for Class Generalizable Anomaly Detection](https://arxiv.org/abs/2410.20047).
<img src="./figures/framework.jpg" width="800">Intuitive illustration of class-generalizable anomaly detection and conceptual illustration of residual feature learning:
<img src="./figures/motivation.jpg" width="800">Update: The code in this repository is not exactly the same with the method in the NeurIPS conference paper. We further improved the method and open-source the improved code, which can achieve better performance and is also more robust. For more details, please see our journal version paper (due to some review reasons, the journal version paper is not currently open for access).
## Installation
Install all packages (the same versions as ours) with the following command:
```bash
$ pip3 install -r requirements.txt
```
## Download Few-Shot Reference Samples
First, you need to download the few-shot normal reference samples. Please download them from Data and put the data in the `./data` directory.
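Based on the `--few_shot_dir` paths used in the commands below, the data directory is expected to look roughly as follows (the layout below the dataset folders is an assumption; keep whatever structure the downloaded archive provides):

```
data/
├── 4shot/
│   ├── mvtec/
│   ├── btad/
│   ├── visa/
│   └── mvtec3d/
└── 8shot/
    └── mvtec/
```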
## Download Datasets
Please download the MVTecAD dataset from MVTecAD dataset, the VisA dataset from VisA dataset, the BTAD dataset from BTAD dataset, the MVTec3D dataset from MVTec3D dataset, the MPDD dataset from MPDD dataset, the MVTecLOCO dataset from MVTecLOCO dataset, and the BraTS dataset from BraTS dataset.
## Creating Reference Features
Please run the following commands to extract the reference features used during testing:
```bash
# For MVTecAD
python extract_ref_features.py --dataset mvtec --few_shot_dir ./data/4shot/mvtec --save_dir ./ref_features/w50/mvtec_4shot

# For BTAD
python extract_ref_features.py --dataset btad --few_shot_dir ./data/4shot/btad --save_dir ./ref_features/w50/btad_4shot

# For VisA
python extract_ref_features.py --dataset visa --few_shot_dir ./data/4shot/visa --save_dir ./ref_features/w50/visa_4shot

# For MVTec3D
python extract_ref_features.py --dataset mvtec3d --few_shot_dir ./data/4shot/mvtec3d --save_dir ./ref_features/w50/mvtec3d_4shot
```
## Training and Evaluating
In this repository, we use `wide_resnet50` as the feature extractor by default.
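For intuition, here is a minimal sketch of how multi-level features can be taken from a `wide_resnet50_2` backbone with torchvision (a generic illustration, not the exact extractor code of this repository; the tapped layers and input size are assumptions):

```python
import torch
from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights
from torchvision.models.feature_extraction import create_feature_extractor

# Frozen ImageNet-pretrained backbone; which layers are tapped is an assumption.
backbone = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.IMAGENET1K_V1).eval()
extractor = create_feature_extractor(
    backbone, return_nodes={"layer1": "feat1", "layer2": "feat2", "layer3": "feat3"}
)

with torch.no_grad():
    feats = extractor(torch.randn(1, 3, 224, 224))  # dummy image batch
for name, f in feats.items():
    print(name, tuple(f.shape))  # multi-scale feature maps
```

Multi-scale feature maps like these are the kind of input from which residual features are computed against the few-shot references.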
- Run the following command for training and evaluating on MVTecAD:

```bash
python main.py --setting visa_to_mvtec --train_dataset_dir /path/to/your/dataset --test_dataset_dir /path/to/your/dataset --test_ref_feature_dir ./ref_features/w50/mvtec_4shot --num_ref_shot 4 --device cuda:0
```

- Run the following command for training and evaluating on BTAD:

```bash
python main.py --setting mvtec_to_btad --train_dataset_dir /path/to/your/dataset --test_dataset_dir /path/to/your/dataset --test_ref_feature_dir ./ref_features/w50/btad_4shot --num_ref_shot 4 --device cuda:0
```

- Run the following command for training and evaluating on VisA:

```bash
python main.py --setting mvtec_to_visa --train_dataset_dir /path/to/your/dataset --test_dataset_dir /path/to/your/dataset --test_ref_feature_dir ./ref_features/w50/visa_4shot --num_ref_shot 4 --device cuda:0
```

- Run the following command for training and evaluating on MVTec3D:

```bash
python main.py --setting mvtec_to_mvtec3d --train_dataset_dir /path/to/your/dataset --test_dataset_dir /path/to/your/dataset --test_ref_feature_dir ./ref_features/w50/mvtec3d_4shot --num_ref_shot 4 --device cuda:0
```
Please note that `--num_ref_shot` should be less than or equal to 4, as we only extract reference features from 4-shot reference samples.
Normally, you can obtain the following results (under the 4-shot setting; please train for sufficiently many epochs, about 60):
| Dataset | Image AUC (%) | Pixel AUC (%) |
|---|---|---|
| MVTecAD | 91.0 | 96.0 |
| VisA | 89.4 | 96.8 |
| BTAD | 94.7 | 97.2 |
| MVTec3D | 72.0 | 97.7 |
We also provide a script `main_all.py` for testing the 2-, 4-, and 8-shot settings simultaneously. You need to extract reference features from the 8-shot samples in the `./data/8shot` directory. The commands are:
```bash
# Extract reference features
python extract_ref_features.py --dataset mvtec --few_shot_dir ./data/8shot/mvtec --save_dir ./ref_features/w50/mvtec_8shot

# Training and evaluating
python main_all.py --setting visa_to_mvtec --train_dataset_dir /path/to/your/dataset --test_dataset_dir /path/to/your/dataset --test_ref_feature_dir ./ref_features/w50/mvtec_8shot --device cuda:0
```
## Download ImageBind Checkpoint
To use ImageBind as the feature extractor, you can download the pre-trained ImageBind model from this link. After downloading, please put the downloaded file (`imagebind_huge.pth`) in the `./pretrained_weights/imagebind/` directory. For creating reference features, please replace the `main()` function with `main2()` in the `extract_ref_features.py` script; the commands are similar. For training and evaluating, please replace `main.py` with `main_ib.py` in the commands above.
- Run the following command for training and evaluating on MVTecAD:

```bash
python main_ib.py --setting visa_to_mvtec --train_dataset_dir /path/to/your/dataset --test_dataset_dir /path/to/your/dataset --test_ref_feature_dir ./ref_features/ib/mvtec_4shot --num_ref_shot 4 --device cuda:0
```
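As a quick sanity check that the checkpoint is in the right place, the model can be loaded roughly as follows (a sketch assuming the facebookresearch/ImageBind package layout; the actual loading code in `main_ib.py` may differ):

```python
import torch
from imagebind.models import imagebind_model

# Build the architecture without auto-downloading weights, then load the
# manually downloaded checkpoint (path from the instructions above).
model = imagebind_model.imagebind_huge(pretrained=False)
state = torch.load("./pretrained_weights/imagebind/imagebind_huge.pth", map_location="cpu")
model.load_state_dict(state)
model.eval()
```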
## Citation
If you find this repository useful, please consider citing our work:
```
@inproceedings{ResAD,
  title={ResAD: A Simple Framework for Class Generalizable Anomaly Detection},
  author={Xincheng Yao and Zixin Chen and Gao Chao and Guangtao Zhai and Chongyang Zhang},
  booktitle={Thirty-Eighth Annual Conference on Neural Information Processing Systems, NeurIPS 2024},
  year={2024},
  url={https://arxiv.org/abs/2410.20047}
}
```
If you are interested in our work, you can also follow our previous works: BGAD (CVPR 2023), PMAD (AAAI 2023), FOD (ICCV 2023), and HGAD (ECCV 2024), or follow our GitHub page xcyao00.