Born-Identity-Network
TensorFlow implementation of Born Identity Network: Multi-way Counterfactual Map Generation to Explain a Classifier's Decision.
Overall framework
- The goal of Born-Identity-Network (BIN) is to induce counterfactual reasoning, dependent on the target condition, from a pre-trained model.
- There are two major components of BIN: Counterfactual Map Generator (CMG) and Target Attribution Network (TAN).
- The CMG synthesizes a counterfactual map conditioned on an arbitrary target label, while the TAN works towards enforcing the target label's attributes on the synthesized map.
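To make the idea of a counterfactual map concrete, here is a toy sketch that is not the BIN model itself. It assumes the map is added to the input to form the counterfactual, and it swaps the learned CMG for a closed-form step on a hypothetical two-class linear classifier. All names (`score`, `predict`, `counterfactual_map`, `weights`, `bias`, `margin`) are illustrative and do not come from this repository.

```python
# Toy illustration only: a hand-written linear classifier stands in for the
# pre-trained model, and a closed-form step stands in for the learned CMG.

def score(x, weights, bias):
    """Linear decision score; positive means class 1, otherwise class 0."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def predict(x, weights, bias):
    """Return the predicted class (0 or 1) of the toy classifier."""
    return 1 if score(x, weights, bias) > 0 else 0


def counterfactual_map(x, target, weights, bias, margin=1.0):
    """Return a map m such that x + m is classified as `target`.

    Assumption: the counterfactual is formed by adding the map to the input.
    The map moves x along the weight vector until the score reaches
    +margin (target 1) or -margin (target 0).
    """
    if predict(x, weights, bias) == target:
        # Already the target class: nothing needs to change.
        return [0.0 for _ in x]
    desired = margin if target == 1 else -margin
    norm_sq = sum(w * w for w in weights)
    step = (desired - score(x, weights, bias)) / norm_sq
    return [step * w for w in weights]


weights, bias = [1.0, -2.0], 0.5
x = [0.0, 1.0]  # score is -1.5, so the toy classifier predicts class 0
cf_map = counterfactual_map(x, target=1, weights=weights, bias=bias)
x_cf = [xi + mi for xi, mi in zip(x, cf_map)]  # counterfactual input
```

In this toy setting the map is nonzero only where the input must change to reach the target class, which is the property that makes such a map readable as an explanation.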
Results
Counterfactual visual explanations
Extra interpolation using 3D Shapes
Requirements
tensorflow (2.2.0)
tensorboard (2.2.2)
tensorflow-addons (0.11.0)
tqdm (4.48.0)
matplotlib (3.3.0)
numpy (1.19.0)
scikit-learn (0.23.2)
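As a convenience, the listed versions can be compared against the current environment. The script below is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+); the helper names `installed_version` and `find_mismatches` are hypothetical, and treating the listed versions as exact pins is an assumption that may be stricter than the code needs.

```python
from importlib import metadata

# Versions copied from the requirements list above.
PINNED = {
    "tensorflow": "2.2.0",
    "tensorboard": "2.2.2",
    "tensorflow-addons": "0.11.0",
    "tqdm": "4.48.0",
    "matplotlib": "3.3.0",
    "numpy": "1.19.0",
    "scikit-learn": "0.23.2",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for package, expected in pinned.items():
        found = lookup(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for name, (expected, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
```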
Datasets
Place each dataset in the "data_path" set in the corresponding Config.py
- Handwritten digits data (MNIST)
- 3D Geometric shape data
- Alzheimer’s Disease Neuroimaging Initiative (ADNI)
How to run
Mode:
#0 Pre-training a classifier
#1 Training the counterfactual map generator
- Pre-training a classifier
training.py --mode=0
- Training the counterfactual map generator
- Set the pre-trained classifier and encoder weight paths (cls_weight_path, enc_weight_path) in Config.py; these weights are frozen during training
- Change the mode from 0 to 1 in Config.py
training.py --mode=1
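The `--mode` flag used in the commands above could be parsed along the following lines. This is a sketch with `argparse`, not the repository's actual `training.py`; the function names `parse_args` and `describe`, and the default of 0, are assumptions made for illustration.

```python
import argparse


def parse_args(argv=None):
    """Parse a --mode flag like the one used in the commands above."""
    parser = argparse.ArgumentParser(description="BIN training entry point (sketch)")
    parser.add_argument(
        "--mode",
        type=int,
        choices=[0, 1],
        default=0,  # assumed default; the real script may differ
        help="0: pre-train the classifier, 1: train the counterfactual map generator",
    )
    return parser.parse_args(argv)


def describe(mode):
    """Return a label for the selected stage (hypothetical helper)."""
    stages = {
        0: "pre-training a classifier",
        1: "training the counterfactual map generator",
    }
    return stages[mode]


# Equivalent to running: training.py --mode=1
args = parse_args(["--mode=1"])
stage = describe(args.mode)
```

Restricting `choices` to 0 and 1 makes an unsupported mode fail at startup rather than partway through a run.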
Config.py of each dataset
data_path = Raw dataset path
save_path = Storage path to save results such as tensorboard event files, model weights, etc.
cls_weight_path = Pre-trained classifier weight path obtained in mode #0
enc_weight_path = Pre-trained encoder weight path obtained in mode #0
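A Config.py with these four fields might look like the sketch below. Every path is a placeholder to be replaced, and the `missing_paths` helper is hypothetical; it only illustrates the assumption that mode 1 needs the weights saved in mode 0.

```python
# Hypothetical Config.py layout; every path below is a placeholder.
import os

data_path = "/path/to/dataset"            # raw dataset path
save_path = "/path/to/results"            # TensorBoard event files, model weights, etc.
cls_weight_path = "/path/to/cls_weights"  # classifier weights obtained in mode 0
enc_weight_path = "/path/to/enc_weights"  # encoder weights obtained in mode 0


def missing_paths(mode):
    """List the config entries whose paths do not exist for the given mode."""
    required = {"data_path": data_path}
    if mode == 1:
        # Assumed: mode 1 additionally needs the weights produced in mode 0.
        required["cls_weight_path"] = cls_weight_path
        required["enc_weight_path"] = enc_weight_path
    return [name for name, path in required.items() if not os.path.exists(path)]
```

Calling `missing_paths(1)` before starting generator training gives an early warning if the mode 0 weights have not been saved yet.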