Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability

This repository contains code for running and replicating the experiments from Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability. It is a modified fork of Classifier-Agnostic Saliency Map Extraction, which was in turn originally forked from the PyTorch ImageNet training example.

<p align="center"> <br> <img src="./saliency_overview.png"/> <br> <p>

(A) Overview of the training setup for our final model. The masker is trained to maximize masked-in classification accuracy and masked-out prediction entropy.
(B) Masker architecture. The masker takes as input the hidden activations of different layers of the ResNet-50 and produces a mask of the same resolution as the input image.
(C) Few-shot training of masker. Performance drops only slightly when trained on much fewer examples compared to the full training procedure.
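The masked-in / masked-out objective from (A) can be sketched as follows. This is a minimal illustration, not the repository's actual loss code; `masker_objective` and its argument names are hypothetical:

```python
import torch
import torch.nn.functional as F

def masker_objective(logits_in, logits_out, targets):
    """Hypothetical sketch of the two-part masker objective.

    logits_in:  classifier logits for the masked-in image
    logits_out: classifier logits for the masked-out image
    targets:    ground-truth class indices
    """
    # Maximize masked-in classification accuracy -> minimize cross-entropy.
    loss_in = F.cross_entropy(logits_in, targets)
    # Maximize masked-out prediction entropy -> subtract the entropy term.
    probs_out = F.softmax(logits_out, dim=-1)
    entropy_out = -(probs_out * torch.log(probs_out + 1e-8)).sum(dim=-1).mean()
    return loss_in - entropy_out
```

Minimizing this quantity pushes the masked-in image toward confident, correct predictions while pushing the masked-out image toward an uninformative (high-entropy) prediction.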

Software requirements

pytorch==1.4.0
torchvision==0.5.0
opencv-python==4.1.2.30
beautifulsoup4==4.8.1
tqdm==4.35.0
pandas==0.24.2
scikit-learn==0.20.2
scipy==1.3.0 

In addition, clone https://github.com/zphang/zutils and add the repository to your PYTHONPATH.

Additional requirements

Data requirements

Running the code

We will assume that experiments will be run in the following folder:

export EXP_DIR=/path/to/experiments

Data Preparation

To facilitate easy subsetting and label shuffling for the ImageNet training set, we write JSON files containing the paths to the example images and their corresponding labels. These are consumed by a modified ImageNet PyTorch Dataset.

Run the following command:

python casme/tasks/imagenet/preproc.py \
    --train_path ${IMAGENET_PATH}/train \
    --val_path ${IMAGENET_PATH}/val \
    --val_annotation_path ${IMAGENET_ANN}/val \
    --output_base_path ${EXP_DIR}/metadata
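The metadata format can be illustrated as follows. This is a hypothetical sketch; the exact keys written by preproc.py may differ:

```python
import json

# Hypothetical example of the metadata a modified ImageNet Dataset could
# consume; the actual schema produced by preproc.py may differ.
train_metadata = [
    {"path": "train/n01440764/n01440764_10026.JPEG", "label": 0},
    {"path": "train/n01443537/n01443537_10007.JPEG", "label": 1},
]

serialized = json.dumps(train_metadata)

# Subsetting is then just a list slice, and label shuffling permutes the
# "label" values, with no change to the on-disk ImageNet layout.
subset = json.loads(serialized)[:1]
```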

This script writes the metadata JSON files under ${EXP_DIR}/metadata (such as train.json, val.json, and val_bboxes.json) that are consumed by the training and evaluation commands below.

Training

To train a FIX or CA model, you can run:

python train_casme.py \
    --train_json ${EXP_DIR}/metadata/train.json \
    --val_json ${EXP_DIR}/metadata/val.json \
    --ZZsrc ./assets/fix.json \
    --masker_use_layers 3,4 \
    --output_path ${EXP_DIR}/runs/ \
    --epochs 60 --lrde 20 \
    --name fix

python train_casme.py \
    --train_json ${EXP_DIR}/metadata/train.json \
    --val_json ${EXP_DIR}/metadata/val.json \
    --ZZsrc ./assets/ca.json \
    --masker_use_layers 3,4 \
    --output_path ${EXP_DIR}/runs/ \
    --epochs 60 --lrde 20 \
    --name ca
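The layer-selection idea behind --masker_use_layers can be sketched as follows. This is an illustrative toy module, not the repository's Masker; the 1024- and 2048-channel inputs correspond to the outputs of ResNet-50 layers 3 and 4:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMasker(nn.Module):
    """Toy sketch: fuse hidden activations from selected backbone layers
    into a single-channel mask at input resolution."""

    def __init__(self, in_channels=(1024, 2048), out_size=224):
        super().__init__()
        # 1x1 convolutions reduce each activation map to a small common width.
        self.reduce = nn.ModuleList(
            nn.Conv2d(c, 16, kernel_size=1) for c in in_channels
        )
        self.head = nn.Conv2d(16 * len(in_channels), 1, kernel_size=1)
        self.out_size = out_size

    def forward(self, activations):
        # Upsample each reduced map to the input resolution, concatenate,
        # and squash to a mask with values in [0, 1].
        feats = [
            F.interpolate(reduce(a), size=(self.out_size, self.out_size),
                          mode="bilinear", align_corners=False)
            for reduce, a in zip(self.reduce, activations)
        ]
        return torch.sigmoid(self.head(torch.cat(feats, dim=1)))
```

Because the mask is produced from intermediate activations rather than the raw image, the masker reuses the classifier's features, which keeps it lightweight.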

Evaluation

To evaluate the model on WSOL metrics and Saliency Metric, run:

python casme/tasks/imagenet/score_bboxes.py \
    --val_json ${EXP_DIR}/metadata/val.json \
    --mode casme \
    --bboxes_path ${EXP_DIR}/metadata/val_bboxes.json \
    --casm_path ${EXP_DIR}/runs/ca/epoch_XXX.chk \
    --output_path ${EXP_DIR}/runs/ca/metrics/scores.json

where epoch_XXX.chk corresponds to the model checkpoint you want to evaluate. Change the val_json and bboxes_path paths to evaluate on the Train-Validation or Validation sets respectively. Note that the mode should be casme regardless of whether you are using FIX or CA models.

The output JSON looks something like this:

{
  "F1": 0.6201832851563015,
  "F1a": 0.5816041554785251,
  "OM": 0.48426,
  "LE": 0.35752,
  "SM": 0.523097248590095,
  "SM1": -0.5532185246243142,
  "SM2": -1.076315772478443,
  "top1": 75.222,
  "top5": 92.488,
  "sm_acc": 74.124,
  "binarized": 0.4486632848739624,
  "avg_mask": 0.44638757080078123,
  "std_mask": 0.1815464876794815,
  "entropy": 0.034756517103545534,
  "tv": 0.006838996527194977
}
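For downstream analysis, the scores file can be read with standard JSON tooling. A minimal sketch, where the key names follow the sample output above and the inline string stands in for ${EXP_DIR}/runs/ca/metrics/scores.json:

```python
import json

# Inline stand-in for the scores.json written by score_bboxes.py; in
# practice, load the file from the evaluation output path instead.
scores_json = '{"F1": 0.6202, "OM": 0.48426, "LE": 0.35752, "SM": 0.5231, "top1": 75.222}'
scores = json.loads(scores_json)

# Pull out selected WSOL metrics for comparison across runs.
summary = {k: scores[k] for k in ("F1", "OM", "LE", "SM")}
```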

To evaluate the model on PxAP, run:

python casme/tasks/imagenet/wsoleval.py \
    --cam_loader casme \
    --casm_base_path ${EXP_DIR}/runs/ca/epoch_XXX.chk \
    --casme_load_mode specific \
    --dataset OpenImages \
    --dataset_split test \
    --dataset_path ${WSOLEVAL_PATH}/dataset \
    --metadata_path ${WSOLEVAL_PATH}/metadata \
    --output_base_path ${EXP_DIR}/runs/ca/metrics/scores.json

where WSOLEVAL_PATH is the location to which wsolevaluation has been cloned, after running its relevant dataset download scripts.

Pretrained Checkpoints

Reference

If you found this code useful, please cite the following paper:

Jason Phang, Jungkyu Park, Krzysztof J. Geras "Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability." arXiv preprint arXiv:2010.09750 (2020).

@article{phang2020investigating,
  title={Investigating and Simplifying Masking-based Saliency Methods for Model Interpretability},
  author={Phang, Jason and Park, Jungkyu and Geras, Krzysztof J},
  journal={arXiv preprint arXiv:2010.09750},
  year={2020}
}