Awesome
MASF
<p align="left"> <img src="assets/method.png" width="85%" height="85%"> </p>Domain Generalization via Model-Agnostic Learning of Semantic Features
We study the challenging problem of domain generalization, i.e., training a model on multi-domain source data such that it can directly generalize to unseen target domains. We adopta model-agnostic learning paradigm with gradient-based meta-train and meta-testprocedures to expose the optimization to domain shift. Further, we introduce two complementary losses which explicitly regularize the semantic structure ofthe feature space. Globally, we align a derived soft confusion matrix to preservegeneral knowledge about inter-class relationships. Locally, we promote domain-independent class-specific cohesion and separation of sample features with ametric-learning component.
This is the reference implementation of the domain generalization method described in our paper:
@inproceedings{dou2019domain,
author = {Qi Dou and Daniel C. Castro and Konstantinos Kamnitsas and Ben Glocker},
title = {Domain Generalization via Model-Agnostic Learning of Semantic Features},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2019},
}
If you make use of the code, please cite the paper in any resulting publications.
Setup
Check dependencies in requirements.txt, and necessarily run
pip install -r requirements.txt
Running MASF
Download PACS dataset from here, put in dataroot /path/to/PACS_dataset, put the .txt files in '/path/to/image/filelist' <br> Download the ImageNet pretrained AlexNet weights bvlc_alexnet.npy from here. <br> To run masf with target domain as art_painting
python main.py --dataset pacs --target_domain art_painting --inner_lr 1e-5 --outer_lr 1e-5 --metric_lr 1e-5 --margin 20
Monitoring training with Tensorboard
Tensorboard logs of losses and gradients are stored in /log/, to observe it run
tensorboard --logdir {/log/}
Running on medical data
To run on medical dataset, replace functions of construct_alexnet_weights() and forward_alexnex() to construct_unet_weights() and forward_unet() demoed in the medical folder