# Diverse Weight Averaging for Out-of-Distribution Generalization, NeurIPS 2022

Official PyTorch implementation of DiWA | paper, openreview

Alexandre Ramé, Matthieu Kirchmeyer, Thibaud Rahier, Alain Rakotomamonjy, Patrick Gallinari, Matthieu Cord
## TL;DR
To improve out-of-distribution generalization, we average diverse weights obtained from different training runs; this strategy is motivated by an extension of the bias-variance theory to weight averaging and is state-of-the-art on DomainBed.
## Abstract
Standard neural networks struggle to generalize under distribution shifts. For out-of-distribution generalization in computer vision, the best current approach averages the weights along a training run. In this paper, we propose Diverse Weight Averaging (DiWA) that makes a simple change to this strategy: DiWA averages the weights obtained from several independent training runs rather than from a single run. Perhaps surprisingly, averaging these weights performs well under soft constraints despite the network's nonlinearities. The main motivation behind DiWA is to increase the functional diversity across averaged models. Indeed, models obtained from different runs are more diverse than those collected along a single run thanks to differences in hyperparameters and training procedures. We motivate the need for diversity by a new bias-variance-covariance-locality decomposition of the expected error, exploiting similarities between DiWA and standard functional ensembling. Moreover, this decomposition highlights that DiWA succeeds when the variance term dominates, which we show happens when the marginal distribution changes at test time. Experimentally, DiWA consistently improves the state of the art on the competitive DomainBed benchmark without inference overhead.
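In short, given the weights of several models fine-tuned from a shared initialization, DiWA evaluates the single network whose weights are their uniform average, so inference costs a single forward pass. Below is a minimal sketch of that averaging step; the checkpoint paths and the surrounding `network` are hypothetical, and the repository's actual implementation lives in `domainbed/scripts/diwa.py` and `domainbed/algorithms_inference.py`.

```python
import torch

def average_state_dicts(state_dicts):
    """Uniformly average the parameters of several state dicts (same architecture)."""
    return {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in state_dicts[0]
    }

# Hypothetical usage: load the best checkpoint of each independent ERM run,
# average the weights once, then evaluate the averaged network.
# paths = ["run_0/best.pkl", "run_1/best.pkl", "run_2/best.pkl"]
# averaged = average_state_dicts([torch.load(p) for p in paths])
# network.load_state_dict(averaged)
```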
## DomainBed

Our code is adapted from the open-source DomainBed GitHub repository, a PyTorch benchmark that includes datasets and algorithms for out-of-distribution generalization. It was introduced in In Search of Lost Domain Generalization, ICLR 2021.
In addition to the newly-added `domainbed/scripts/diwa.py` and `domainbed/algorithms_inference.py` files, we made only a few modifications to this codebase, all preceded by `## DiWA ##`:

- in `domainbed/hparams_registry.py`, to define our mild hyperparameter ranges.
- in `domainbed/train.py`, to handle the shared initialization and save the weights of the epoch with the highest validation accuracy.
- in `domainbed/algorithms.py`, to handle the shared initialization and the linear probing approach, and to implement the MA baseline.
- in `domainbed/datasets.py`, to define the checkpoint frequency.
- in `domainbed/scripts/sweep.py`, to be able to force the test environment.
- in `domainbed/lib/misc.py`, to include some tools.
Then you should be able to reproduce our main experiment (Table 1) on the DomainBed benchmark.
## Requirements
- python == 3.7.10
- torch == 1.8.1
- torchvision == 0.9.1
- numpy == 1.20.2
## Datasets
We ran DiWA on the following datasets:
- VLCS (Fang et al., 2013)
- PACS (Li et al., 2017)
- OfficeHome (Venkateswara et al., 2017)
- A TerraIncognita (Beery et al., 2018) subset
- DomainNet (Peng et al., 2019)
- Colored MNIST (Arjovsky et al., 2019)
You can download the datasets with the following command:

```sh
python3 -m domainbed.scripts.download --data_dir=/my/data/dir
```
## DiWA Procedure Details

Our training procedure consists of three stages.
### Set the initialization
First, we need to fix the initialization.
```sh
python3 -m domainbed.scripts.train \
    --data_dir=/my/data/dir/ \
    --algorithm ERM \
    --dataset OfficeHome \
    --test_env ${test_env} \
    --init_step \
    --path_for_init ${path_for_init} \
    --steps ${steps}
```
In the paper, we proposed two initialization procedures:

- random initialization: set `steps` to `-1`; there will be no training.
- linear probing (ICLR 2022): set `steps` to `0`; only the classifier will be trained.
The initialization is then saved at `${path_for_init}`, to be used in the subsequent sweep.
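For intuition, here is a simplified sketch of the linear probing initialization: freeze the pretrained featurizer, train only the classifier, and save the result as the shared starting point for every run in the sweep. The ResNet-50 stand-in, the learning rate, and the file name are hypothetical; the actual logic lives in `domainbed/train.py` and `domainbed/algorithms.py`.

```python
import torch
import torch.nn as nn
from torchvision import models

# Hypothetical stand-in for DomainBed's featurizer: ImageNet-pretrained ResNet-50.
num_classes = 65  # OfficeHome has 65 classes
featurizer = models.resnet50(pretrained=True)
featurizer.fc = nn.Identity()  # keep only the 2048-d features
classifier = nn.Linear(2048, num_classes)
network = nn.Sequential(featurizer, classifier)

# Linear probing: freeze the featurizer so only the classifier is trained.
for param in featurizer.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
# ... train the classifier on the source domains here ...

# Save the shared initialization loaded by every subsequent ERM run.
torch.save(network.state_dict(), "shared_init.pkl")
```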
### Launch ERM training

Second, we launch several ERM runs following the mild hyperparameter distributions from `domainbed/hparams_registry.py`, as defined in Table 5 of Appendix F.1. To do so, we leverage the native `sweep` script from DomainBed.
```sh
python -m domainbed.scripts.sweep launch \
    --data_dir=/my/data/dir/ \
    --output_dir=/my/sweep/output/path \
    --command_launcher multi_gpu \
    --datasets OfficeHome \
    --test_env ${test_env} \
    --path_for_init ${path_for_init} \
    --algorithms ERM \
    --n_hparams 20 \
    --n_trials 3
```
### Average the diverse weights
Finally, we average the weights obtained from this grid search.
```sh
python -m domainbed.scripts.diwa \
    --data_dir=/my/data/dir/ \
    --output_dir=/my/sweep/output/path \
    --dataset OfficeHome \
    --test_env ${test_env} \
    --weight_selection ${weight_selection} \
    --trial_seed ${trial_seed}
```
In the paper, we proposed three weight selection procedures (a simplified sketch of the two selection rules follows this list):

- DiWA-restricted: set `weight_selection` to `restricted` and `trial_seed` to an integer between `0` and `2`.
- DiWA-uniform: set `weight_selection` to `uniform` and `trial_seed` to an integer between `0` and `2`.
- DiWA$^\dagger$-uniform: set `weight_selection` to `uniform` and `trial_seed` to `-1`.
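The two selection rules can be summarized as below. This is a simplified sketch, assuming `average_state_dicts` from the earlier sketch and a hypothetical `val_accuracy(state_dict)` callable that scores a set of weights on the validation split; the real logic is in `domainbed/scripts/diwa.py`.

```python
def uniform_selection(checkpoints):
    # DiWA-uniform: average every checkpoint collected from the sweep.
    return average_state_dicts(checkpoints)

def restricted_selection(checkpoints, val_accuracy):
    # DiWA-restricted: rank checkpoints by individual validation accuracy, then
    # greedily keep each one only if adding it does not hurt the validation
    # accuracy of the running average (in the spirit of the greedy model soup).
    ranked = sorted(checkpoints, key=val_accuracy, reverse=True)
    selected = [ranked[0]]
    best_acc = val_accuracy(average_state_dicts(selected))
    for ckpt in ranked[1:]:
        candidate_acc = val_accuracy(average_state_dicts(selected + [ckpt]))
        if candidate_acc >= best_acc:
            selected.append(ckpt)
            best_acc = candidate_acc
    return average_state_dicts(selected)
```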
### Weight averaging from a single run

You can reproduce the Moving Average (MA) baseline by replacing `ERM` with `MA` as the algorithm argument; a sketch of the underlying moving average follows the command below.
```sh
python -m domainbed.scripts.sweep launch \
    --data_dir=/my/data/dir/ \
    --output_dir=/my/sweep/output/path \
    --command_launcher multi_gpu \
    --datasets OfficeHome \
    --test_env ${test_env} \
    --algorithms MA \
    --n_hparams 20 \
    --n_trials 3
```
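For reference, MA maintains a uniform running average of the weights encountered along a single training run and evaluates that average instead of the final weights. A minimal sketch of the update, where the class name and `start_step` burn-in are hypothetical (the repository's version is in `domainbed/algorithms.py`):

```python
import copy
import torch

class MovingAverage:
    """Uniform running average of the weights seen along one training run."""

    def __init__(self, network, start_step=100):
        self.network_ma = copy.deepcopy(network)  # holds the averaged weights
        self.start_step = start_step              # burn-in before averaging starts
        self.count = 0                            # number of checkpoints averaged

    def update(self, network, step):
        if step < self.start_step:
            # During burn-in, simply track the current weights.
            self.network_ma.load_state_dict(network.state_dict())
            return
        # Running mean: avg <- (avg * count + current) / (count + 1).
        with torch.no_grad():
            for p_ma, p in zip(self.network_ma.parameters(), network.parameters()):
                p_ma.mul_(self.count).add_(p).div_(self.count + 1)
        self.count += 1
```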
To view the results of your sweep:

```sh
python -m domainbed.scripts.collect_results --input_dir=/my/sweep/output/path
```
## Results
DiWA sets a new state of the art on DomainBed.
| Algorithm | Weight selection | Init | PACS | VLCS | OfficeHome | TerraInc | DomainNet | Avg |
|---|---|---|---|---|---|---|---|---|
| ERM | N/A | Random | 85.5 | 77.5 | 66.5 | 46.1 | 40.9 | 63.3 |
| Coral | N/A | Random | 86.2 | 78.8 | 68.7 | 47.6 | 41.5 | 64.6 |
| SWAD | Overfit-aware | Random | 88.1 | 79.1 | 70.6 | 50.0 | 46.5 | 66.9 |
| MA | Uniform | Random | 87.5 | 78.2 | 70.6 | 50.3 | 46.0 | 66.5 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ERM | N/A | Random | 85.5 | 77.6 | 67.4 | 48.3 | 44.1 | 64.6 |
| DiWA | Restricted | Random | 87.9 | 79.2 | 70.5 | 50.5 | 46.7 | 67.0 |
| DiWA | Uniform | Random | 88.8 | 79.1 | 71.0 | 48.9 | 46.1 | 66.8 |
| DiWA$^{\dagger}$ | Uniform | Random | 89.0 | 79.4 | 71.6 | 49.0 | 46.3 | 67.1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ERM | N/A | LP | 85.9 | 78.1 | 69.4 | 50.4 | 44.3 | 65.6 |
| DiWA | Restricted | LP | 88.0 | 78.5 | 71.5 | 51.6 | 47.7 | 67.5 |
| DiWA | Uniform | LP | 88.7 | 78.4 | 72.1 | 51.4 | 47.4 | 67.6 |
| DiWA$^{\dagger}$ | Uniform | LP | 89.0 | 78.6 | 72.8 | 51.9 | 47.7 | 68.0 |
## Citation
If you find this code useful for your research, please consider citing our work:
```bibtex
@inproceedings{rame2022diwa,
    title = {Diverse Weight Averaging for Out-of-Distribution Generalization},
    author = {Rame, Alexandre and Kirchmeyer, Matthieu and Rahier, Thibaud and Rakotomamonjy, Alain and Gallinari, Patrick and Cord, Matthieu},
    year = {2022},
    booktitle = {NeurIPS}
}
```
Correspondence to alexandre.rame at sorbonne-universite dot fr