Home

Awesome

EvalProtoPNet

Introduction

This is the official implementation of paper "Evaluation and Improvement of Interpretability for Self-Explainable Part-Prototype Networks".

Required Packages

Our required python packages are listed as below:

pytorch==1.12.1+cu113
Augmentor==0.2.9
torchvision==0.13.1+cu113
pillow==8.4.0 (9.3.0 for train)
timm==0.5.4
opencv-python==4.6.0.66
tensorboard==2.9.1
scipy==1.8.1
pandas==1.4.3
matplotlib==3.5.2
scikit-learn==1.1.1
numpy==1.22.0
tqdm==4.66.1

Dataset Preparation

Model Weights

Most model weights will be downloaded automatically, except the ResNet50 backbone on iNaturalist2017. It (BBN.iNaturalist2017.res50.90epoch.best_model.pth) can be downloaded from here and put into the folder pretrained_models.

Training Instructions

Use scripts/train.sh for training:

sh scripts/train.sh $model $num_gpus

Here, $model is the name of backbone chosen from resnet34, resnet152, vgg19, densenet121, densenet161, $num_gpus is the number of GPUs for training (2 GPUs is recommended).

For example, the instruction for training a ResNet34 model with 2 GPUs is as below:

sh scripts/train.sh resnet34 2

Note that when running two scripts at the same time, the variable use_port in scripts/train.sh should be different for them.

Evaluate Consistency Score

The instruction for evaluating the consistency score of a ResNet34 model with checkpoint path $ckpt_path:

python eval_consistency.py \
--base_architecture resnet34 \
--resume $ckpt_path

Evaluate Stability Score

The instruction for evaluating the stability score of a ResNet34 model with checkpoint path $ckpt_path:

python eval_stability.py \
--base_architecture resnet34 \
--resume $ckpt_path

Visualization

The instruction for visualizing a ResNet34 model with checkpoint path $ckpt_path for the images in category $class:

python local_analysis_vis.py \
--base_architecture resnet34 \
--resume $ckpt_path \
--imgclass $class

For example, the instruction for evaluating a ResNet34 model with checkpoint path output_cosine/CUB2011/resnet34/1028-1e-4-adam-12-train/checkpoints/save_model.pth for the images in category 15:

python local_analysis_vis.py \
--base_architecture resnet34 \
--resume output_cosine/CUB2011/resnet34/1028-1e-4-adam-12-train/checkpoints/save_model.pth \
--imgclass 15

The results will be saved in output_view directory.