Domain Generalization via Gradient Surgery
This repository contains the source code corresponding to the paper "Domain Generalization via Gradient Surgery" (ICCV 2021). You can check out our paper here: https://arxiv.org/abs/2108.01621.
<p align="center"><img width="550" src="grad_surgery.png"></p>
Instructions
This project uses Python 3.8.10 and PyTorch 1.10.0.
Data:
- Download the PACS (Li et al., 2017), VLCS (Fang et al., 2013) and Office-Home (Venkateswara et al., 2017) datasets and put them in data/raw/.
- Resize images and generate training, validation and test splits. Run ./00_prepare_data.sh after installing the project environment (instructions below).
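00_prepare_data.sh is the supported way to build the splits. Purely to illustrate the splitting step (resizing is omitted), the snippet below does a seeded shuffle-and-cut in plain Python. The function name make_splits, the 80/10/10 fractions, the seed and the data/raw/PACS/photo/ file layout are all assumptions for illustration, not what the script actually does.

```python
import random
from typing import Dict, List, Sequence


def make_splits(
    files: Sequence[str],
    val_frac: float = 0.1,
    test_frac: float = 0.1,
    seed: int = 0,
) -> Dict[str, List[str]]:
    """Shuffle file paths with a fixed seed and cut them into train/val/test.

    Hypothetical helper: the fractions and seed are illustrative defaults,
    not the values used by 00_prepare_data.sh.
    """
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("val_frac and test_frac must be non-negative and sum to < 1")
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # local RNG: reproducible, no global state
    n_val = int(len(shuffled) * val_frac)
    n_test = int(len(shuffled) * test_frac)
    return {
        "test": shuffled[:n_test],
        "val": shuffled[n_test:n_test + n_val],
        "train": shuffled[n_test + n_val:],  # everything left over goes to training
    }


# Hypothetical file names for one PACS domain.
files = [f"data/raw/PACS/photo/img_{i:03d}.jpg" for i in range(100)]
splits = make_splits(files)
print({name: len(paths) for name, paths in splits.items()})
```

Using a dedicated random.Random(seed) keeps the split reproducible across runs without touching Python's global random state.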
Project environment:
- Create and activate a virtual environment: 1) python3 -m venv env, 2) source env/bin/activate
- Install the required packages: pip install -r requirements.txt
- Install the project modules (src): pip install -e .
Simulations:
To run simulations across all datasets (PACS, VLCS and Office-Home) and methods (Deep-All, Agr-Sum, Agr-Rand and PCGrad), execute ./01_run_trials.sh.
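The methods differ mainly in how the per-domain gradients are merged into a single update. As a hedged illustration based on our reading of the paper, not on the code in src/, the NumPy sketch below treats each row of grads as one source domain's flattened gradient (an assumed layout). Deep-All, which trains on all source domains pooled together, is approximated by summing the rows; Agr-Sum keeps only the components whose sign agrees across all domains and zeroes the rest; PCGrad (Yu et al., 2020) removes from each gradient its component along any other domain's gradient it conflicts with, then sums. Agr-Rand is not sketched, and the function names are hypothetical.

```python
import numpy as np


def deep_all(grads: np.ndarray) -> np.ndarray:
    """Deep-All (approximation): sum the per-domain gradients, one row per domain."""
    return grads.sum(axis=0)


def agr_sum(grads: np.ndarray) -> np.ndarray:
    """Agr-Sum: sum only the components whose sign agrees across all domains.

    Components with conflicting signs (including an exact zero in some domains
    but not in others) are set to zero.
    """
    signs = np.sign(grads)
    agree = np.all(signs == signs[0], axis=0)  # True where every domain shares a sign
    return np.where(agree, grads.sum(axis=0), 0.0)


def pcgrad(grads: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """PCGrad: project each gradient away from the gradients it conflicts with, then sum."""
    projected = grads.astype(float)  # astype returns a copy, so grads is untouched
    n_domains = grads.shape[0]
    for i in range(n_domains):
        for j in rng.permutation(n_domains):  # visit the other domains in random order
            if j == i:
                continue
            dot = projected[i] @ grads[j]
            if dot < 0:  # negative inner product: the two gradients conflict
                projected[i] -= dot / (grads[j] @ grads[j]) * grads[j]
    return projected.sum(axis=0)


# Toy gradients for two source domains over three parameters.
g = np.array([[1.0, 2.0, -1.0],
              [3.0, -1.0, -2.0]])
print(deep_all(g))
print(agr_sum(g))
print(pcgrad(np.array([[1.0, 0.0], [-1.0, 1.0]]), np.random.default_rng(0)))
```

For the toy gradients, Deep-All returns (4, 1, -3), while Agr-Sum zeroes the second component, whose sign differs between the two domains, giving (4, 0, -3). In the last call the two gradients conflict, so PCGrad projects each onto the normal plane of the other and returns (0.5, 1.5). In a real training loop the rows would be the flattened per-domain gradients of the network's parameters.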
If you want to run a particular combination of dataset and method, use the train_model.py script. For example, the following command:
python scripts/train_model.py \
--data_dir=data/processed \
--results_dir=results/train \
--dataset=PACS \
--method=deep-all
will run Deep-All on PACS and save the results in results/train.
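To launch a subset of the dataset/method combinations without editing 01_run_trials.sh, the commands can be built programmatically. This is a hypothetical helper, not part of the repository: only the PACS dataset name and the deep-all method value are confirmed by the example above, so the other --dataset and --method strings are guesses to check against scripts/train_model.py.

```python
import shlex
from itertools import product
from typing import List

# Only "PACS" and "deep-all" appear in this README; the other values are guesses.
DATASETS = ["PACS", "VLCS", "Office-Home"]
METHODS = ["deep-all", "agr-sum", "agr-rand", "pcgrad"]


def train_command(dataset: str, method: str,
                  data_dir: str = "data/processed",
                  results_dir: str = "results/train") -> List[str]:
    """Build the argument list for one scripts/train_model.py run."""
    return [
        "python", "scripts/train_model.py",
        f"--data_dir={data_dir}",
        f"--results_dir={results_dir}",
        f"--dataset={dataset}",
        f"--method={method}",
    ]


commands = [train_command(d, m) for d, m in product(DATASETS, METHODS)]
for cmd in commands:
    # Print a shell-ready line; pass cmd to subprocess.run(cmd, check=True) to execute it.
    print(shlex.join(cmd))
```

Keeping each command as a list of arguments (rather than one string) avoids shell-quoting issues when it is handed to subprocess.run; shlex.join is available from Python 3.8, the version this project uses.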
Reference
- Mansilla, L., Echeveste, R., Milone, D. H., & Ferrante, E. (2021). Domain generalization via gradient surgery. In Proceedings of the IEEE/CVF International Conference on Computer Vision (pp. 6630-6638).