Diverse Weight Averaging for Out-of-Distribution Generalization, NeurIPS 2022

Official PyTorch implementation of DiWA | paper, openreview

Alexandre Ramé, Matthieu Kirchmeyer, Thibaud Rahier, Alain Rakotomamonjy, Patrick Gallinari, Matthieu Cord

TL;DR

To improve out-of-distribution generalization, we average diverse weights obtained from different training runs; this strategy is motivated by an extension of the bias-variance theory to weight averaging and is state-of-the-art on DomainBed.

Abstract

Standard neural networks struggle to generalize under distribution shifts. For out-of-distribution generalization in computer vision, the best current approach averages the weights along a training run. In this paper, we propose Diverse Weight Averaging (DiWA) that makes a simple change to this strategy: DiWA averages the weights obtained from several independent training runs rather than from a single run. Perhaps surprisingly, averaging these weights performs well under soft constraints despite the network's nonlinearities. The main motivation behind DiWA is to increase the functional diversity across averaged models. Indeed, models obtained from different runs are more diverse than those collected along a single run thanks to differences in hyperparameters and training procedures. We motivate the need for diversity by a new bias-variance-covariance-locality decomposition of the expected error, exploiting similarities between DiWA and standard functional ensembling. Moreover, this decomposition highlights that DiWA succeeds when the variance term dominates, which we show happens when the marginal distribution changes at test time. Experimentally, DiWA consistently improves the state of the art on the competitive DomainBed benchmark without inference overhead.
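In schematic form (notation simplified here; see the paper for the exact statement), this bias-variance-covariance-locality decomposition of the expected error of the weight average $f_{\mathrm{WA}}$ of $M$ models reads:

$$
\mathbb{E}\big[(f_{\mathrm{WA}}(x) - y)^2\big] \;\approx\; \mathrm{bias}^2 \;+\; \frac{1}{M}\,\mathrm{var} \;+\; \Big(1 - \frac{1}{M}\Big)\,\mathrm{cov} \;+\; \underbrace{O(\bar{\Delta}^2)}_{\text{locality}}
$$

Increasing $M$ shrinks the variance term, functional diversity across runs shrinks the covariance term, and the locality term requires the averaged weights to remain close: hence DiWA combines diverse runs with a shared initialization.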

DomainBed

Our code is adapted from the open-source DomainBed codebase, a PyTorch benchmark of datasets and algorithms for out-of-distribution generalization, introduced in In Search of Lost Domain Generalization, ICLR 2021.

In addition to the newly added domainbed/scripts/diwa.py and domainbed/algorithms_inference.py files, we made only a few modifications to this codebase, all preceded by ## DiWA ##.

Following the steps below, you should be able to reproduce our main experiment (Table 1) on the DomainBed benchmark.

Requirements

Datasets

We ran DiWA on the following datasets: PACS, VLCS, OfficeHome, TerraIncognita and DomainNet.

You can download the datasets with the following command:

python3 -m domainbed.scripts.download --data_dir=/my/data/dir

DiWA Procedure Details

Our training procedure consists of three stages.

Set the initialization

First, we fix the initialization shared across all runs.

python3 -m domainbed.scripts.train\
       --data_dir=/my/data/dir/\
       --algorithm ERM\
       --dataset OfficeHome\
       --test_env ${test_env}\
       --init_step\
       --path_for_init ${path_for_init}\
       --steps ${steps}

In the paper, we proposed two initialization procedures for the classifier (sketched below): the standard random initialization, and Linear Probing (LP), where the classifier is first trained on top of the frozen pretrained featurizer.

The initialization is then saved at ${path_for_init}, to be used in the subsequent sweep.
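The snippet below sketches what these two procedures amount to. It is a simplified illustration, not the exact DomainBed code: the helper name, the training loop and the saved format are hypothetical.

```python
# Simplified sketch of the two initialization procedures (hypothetical helper,
# not the exact DomainBed code). Random: save the pretrained featurizer with a
# randomly initialized classifier. LP (linear probing): first train only the
# classifier on top of the frozen featurizer.
import torch
import torch.nn as nn
import torchvision.models as models

def build_shared_init(num_classes, linear_probe=False, loader=None):
    featurizer = models.resnet50(pretrained=True)  # ImageNet-pretrained features
    featurizer.fc = nn.Identity()
    classifier = nn.Linear(2048, num_classes)      # randomly initialized head

    if linear_probe and loader is not None:
        for p in featurizer.parameters():          # freeze the featurizer
            p.requires_grad = False
        opt = torch.optim.Adam(classifier.parameters(), lr=1e-3)
        for x, y in loader:                        # one pass of linear probing
            opt.zero_grad()
            nn.functional.cross_entropy(classifier(featurizer(x)), y).backward()
            opt.step()

    return {"featurizer": featurizer.state_dict(),
            "classifier": classifier.state_dict()}

# e.g., torch.save(build_shared_init(65, linear_probe=True, loader=loader),
#                  path_for_init)  # OfficeHome has 65 classes
```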

Launch ERM training

Second, we launch several ERM runs following the hyperparameter distributions defined here and in Table 5 of Appendix F.1. To do so, we leverage DomainBed's native sweep script.

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher multi_gpu\
       --datasets OfficeHome\
       --test_env ${test_env}\
       --path_for_init ${path_for_init}\
       --algorithms ERM\
       --n_hparams 20\
       --n_trials 3
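For intuition, a DomainBed-style sweep draws a random hyperparameter configuration per run, roughly as below; the ranges shown are only illustrative, the exact mild distributions DiWA uses are those of Table 5.

```python
# Illustration of per-run random hyperparameter sampling in a DomainBed-style
# sweep; the ranges below are examples, not the exact Table 5 distributions.
import numpy as np

def sample_hparams(seed):
    r = np.random.RandomState(seed)
    return {
        "lr": 10 ** r.uniform(-5, -3.5),          # log-uniform learning rate
        "weight_decay": 10 ** r.uniform(-6, -2),  # log-uniform weight decay
        "dropout": r.choice([0.0, 0.1, 0.5]),     # categorical dropout rate
    }

configs = [sample_hparams(seed) for seed in range(20)]  # cf. --n_hparams 20
```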

Average the diverse weights

Finally, we average the weights obtained from this grid search.

python -m domainbed.scripts.diwa\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --dataset OfficeHome\
       --test_env ${test_env}\
       --weight_selection ${weight_selection}\
       --trial_seed ${trial_seed}
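At its core, this step is a uniform average of the checkpoints' state dicts. A minimal sketch is below; diwa.py additionally handles weight selection and evaluation, and `checkpoint_paths` is an assumed list of saved state-dict paths.

```python
# Minimal sketch of uniform weight averaging over checkpoints from the sweep
# (illustrative helper, not the exact diwa.py code).
import torch

def average_weights(checkpoint_paths):
    avg = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(checkpoint_paths) for k, v in avg.items()}
```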

In the paper, we proposed three weight-selection procedures (sketched below): uniform, which averages the weights from all runs with the given ${trial_seed}; restricted, which greedily keeps a run's weights only if they improve the average's validation accuracy; and DiWA$^{\dagger}$, which uniformly averages the weights from all runs across the three trial seeds.
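A hedged sketch of the restricted (greedy) selection, reusing `average_weights` from the sketch above; `evaluate`, returning the validation accuracy of an averaged state dict, is an assumed helper.

```python
# Sketch of the "restricted" greedy selection: rank checkpoints by individual
# validation accuracy, then keep each one only if adding it does not degrade
# the validation accuracy of the running weight average. `evaluate` is an
# assumed helper scoring an averaged state dict on held-out validation data.
def restricted_selection(checkpoint_paths, val_accs, evaluate):
    ranked = [p for _, p in sorted(zip(val_accs, checkpoint_paths), reverse=True)]
    selected, best = [ranked[0]], evaluate(average_weights(ranked[:1]))
    for candidate in ranked[1:]:
        acc = evaluate(average_weights(selected + [candidate]))
        if acc >= best:
            selected.append(candidate)
            best = acc
    return selected
```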

Weight averaging from a single run

You can reproduce the Moving Average (MA) baseline by replacing ERM with MA as the algorithm argument.

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher multi_gpu\
       --datasets OfficeHome\
       --test_env ${test_env}\
       --algorithms MA\
       --n_hparams 20\
       --n_trials 3
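For reference, MA maintains a running uniform average of the weights visited along a single run. A simplified sketch follows; the actual algorithm lives in domainbed/algorithms.py, and the burn-in step count here is illustrative.

```python
# Simplified sketch of the MA baseline: an incremental uniform average of the
# weights along one training run, after a short burn-in (illustrative value).
import copy

class MovingAverage:
    def __init__(self, network, burn_in_steps=100):
        self.network_ma = copy.deepcopy(network)
        self.burn_in_steps = burn_in_steps
        self.count = 0

    def update(self, network, step):
        if step < self.burn_in_steps:
            # before burn-in, just track the current weights
            self.network_ma.load_state_dict(network.state_dict())
            return
        self.count += 1
        ma_state = self.network_ma.state_dict()
        for k, v in network.state_dict().items():
            # incremental mean: m_{n+1} = m_n + (w - m_n) / (n + 1)
            ma_state[k] = ma_state[k] + (v - ma_state[k]) / (self.count + 1)
        self.network_ma.load_state_dict(ma_state)
```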

Then to view the results of your sweep:

python -m domainbed.scripts.collect_results --input_dir=/my/sweep/output/path

Results

DiWA sets a new state of the art on DomainBed. The table below reports out-of-distribution test accuracies (%).

| Algorithm | Weight selection | Init | PACS | VLCS | OfficeHome | TerraInc | DomainNet | Avg |
|---|---|---|---|---|---|---|---|---|
| ERM | N/A | Random | 85.5 | 77.5 | 66.5 | 46.1 | 40.9 | 63.3 |
| Coral | N/A | Random | 86.2 | 78.8 | 68.7 | 47.6 | 41.5 | 64.6 |
| SWAD | Overfit-aware | Random | 88.1 | 79.1 | 70.6 | 50.0 | 46.5 | 66.9 |
| MA | Uniform | Random | 87.5 | 78.2 | 70.6 | 50.3 | 46.0 | 66.5 |
| ERM | N/A | Random | 85.5 | 77.6 | 67.4 | 48.3 | 44.1 | 64.6 |
| DiWA | Restricted | Random | 87.9 | 79.2 | 70.5 | 50.5 | 46.7 | 67.0 |
| DiWA | Uniform | Random | 88.8 | 79.1 | 71.0 | 48.9 | 46.1 | 66.8 |
| DiWA$^{\dagger}$ | Uniform | Random | 89.0 | 79.4 | 71.6 | 49.0 | 46.3 | 67.1 |
| ERM | N/A | LP | 85.9 | 78.1 | 69.4 | 50.4 | 44.3 | 65.6 |
| DiWA | Restricted | LP | 88.0 | 78.5 | 71.5 | 51.6 | 47.7 | 67.5 |
| DiWA | Uniform | LP | 88.7 | 78.4 | 72.1 | 51.4 | 47.4 | 67.6 |
| DiWA$^{\dagger}$ | Uniform | LP | 89.0 | 78.6 | 72.8 | 51.9 | 47.7 | 68.0 |

Citation

If you find this code useful for your research, please consider citing our work:

@inproceedings{rame2022diwa,
  title   = {Diverse Weight Averaging for Out-of-Distribution Generalization},
  author  = {Rame, Alexandre and Kirchmeyer, Matthieu and Rahier, Thibaud and Rakotomamonjy, Alain and Gallinari, Patrick and Cord, Matthieu},
  year    = {2022},
  booktitle = {NeurIPS}
}

Correspondence to alexandre.rame at sorbonne-universite dot fr