On Feature Learning in the Presence of Spurious Correlations
This repository contains experiments for the NeurIPS 2022 paper On Feature Learning in the Presence of Spurious Correlations by Pavel Izmailov, Polina Kirichenko, Nate Gruver and Andrew Gordon Wilson.
Introduction
Deep classifiers are known to rely on spurious features — patterns which are correlated with the target on the training data but not inherently relevant to the learning problem, such as the image backgrounds when classifying the foregrounds. In this paper we evaluate the amount of information about the core (non-spurious) features that can be decoded from the representations learned by standard empirical risk minimization (ERM) and specialized group robustness training. Following recent work on Deep Feature Reweighting (DFR), we evaluate the feature representations by re-training the last layer of the model on a held-out set where the spurious correlation is broken. On multiple vision and NLP problems, we show that the features learned by simple ERM are highly competitive with the features learned by specialized group robustness methods targeted at reducing the effect of spurious correlations. Moreover, we show that the quality of learned feature representations is greatly affected by the design decisions beyond the training method, such as the model architecture and pre-training strategy. On the other hand, we find that strong regularization is not necessary for learning high quality feature representations. Finally, using insights from our analysis, we significantly improve upon the best results reported in the literature on the popular Waterbirds, CelebA hair color prediction and WILDS-FMOW problems, achieving 97%, 92% and 50% worst-group accuracies, respectively.
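As a reference point, DFR evaluation amounts to retraining only the final linear layer on frozen features. Below is a minimal sketch of the idea with synthetic stand-in data (the variable names are hypothetical; the actual tuning and preprocessing live in dfr_evaluate_spurious.py):

import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical stand-ins for frozen base-model features, class labels and
# group labels of a held-out set where the spurious correlation is broken.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1000, 2048))
labels = rng.integers(0, 2, size=1000)
groups = rng.integers(0, 4, size=1000)

# Retrain only the last (linear) layer on the held-out features;
# the regularization strength is tuned in practice.
clf = LogisticRegression(max_iter=1000)
clf.fit(embeddings, labels)

# Report the worst accuracy across groups.
worst_group_acc = min(
    clf.score(embeddings[groups == g], labels[groups == g])
    for g in np.unique(groups)
)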
Please cite our paper if you find it helpful in your work:
@article{izmailov2022feature,
title={On Feature Learning in the Presence of Spurious Correlations},
author={Izmailov, Pavel and Kirichenko, Polina and Gruver, Nate and Wilson, Andrew Gordon},
journal={arXiv preprint arXiv:2210.11369},
year={2022}
}
File Structure
.
+-- data/
| +-- __init__.py
| +-- augmix_transforms.py (AugMix augmentation)
| +-- data_transforms.py (Preprocessing and data augmentation)
| +-- dataloaders.py (RWY and RWG samplers and MixUp)
| +-- datasets.py (Dataset classes)
+-- dataset_files/utils_glue.py (File to copy into the MultiNLI dataset directory)
+-- group_DRO/ (Group DRO codebase)
+-- models/
| +-- __init__.py (Most of the vision models)
| +-- preresnet.py (Preactivation ResNet; not used in the experiments)
| +-- text_models.py (BERT classifier model)
| +-- vissl_models.py (Contrastive models from vissl)
+-- optimizers/
| +-- __init__.py (SGD, AdamW and LR schedulers)
+-- utils/
| +-- __init__.py
| +-- common_utils.py (Common utilities used in different scripts)
| +-- logging_utils.py (Logging-related utilities)
| +-- supervised_utils.py (Utilities for supervised training)
+-- train_supervised.py (Train base models)
+-- dfr_evaluate_spurious.py (Tune and evaluate DFR for a given base model)
+-- dfr_evaluate_auroc.py (Tune and evaluate DFR on the CXR dataset)
Requirements
The code is implemented in Python with PyTorch. The model implementations rely on the torchvision, timm and transformers packages (see Code References below), and the WILDS datasets additionally require the wilds package.
Data access
Waterbirds and CelebA
Please follow the instructions in the DFR repo to prepare the Waterbirds and CelebA datasets.
Civil Comments and MultiNLI
The Civil Comments dataset is downloaded automatically when you run experiments; no manual preparation is needed.
To run experiments on the MultiNLI dataset, please manually download and unzip the dataset from this link.
Further, please copy dataset_files/utils_glue.py to the root directory of the dataset.
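For example, as a one-off Python snippet (<DATA_DIR> is a placeholder for your MultiNLI root directory):

import shutil
# Copy the GLUE utilities file into the dataset root.
shutil.copy("dataset_files/utils_glue.py", "<DATA_DIR>/utils_glue.py")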
WILDS-FMOW
To run experiments on the FMOW dataset, you first need to run wilds.get_dataset(dataset="fmow", download=True, root_dir=<ROOT DIR>) from a Python console or a Jupyter notebook to download the data.
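For example, a minimal download script (assuming /datasets as the root directory):

import wilds
# Downloads FMOW into /datasets/fmow_v1.1/ on first use.
wilds.get_dataset(dataset="fmow", download=True, root_dir="/datasets")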
CXR
The chest drain labels for the CXR dataset are not publicly available, so we cannot share the code for preparing this dataset.
Example commands
Waterbirds:
python3 train_supervised.py --output_dir=logs/waterbirds/erm_seed1 \
--num_epochs=100 --eval_freq=1 --save_freq=100 --seed=1 \
--weight_decay=1e-4 --batch_size=32 --init_lr=3e-3 \
--scheduler=cosine_lr_scheduler --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained
python3 dfr_evaluate_spurious.py --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained \
--ckpt_path=logs/waterbirds/erm_seed1/final_checkpoint.pt \
--result_path=wb_erm_seed1_dfr.pkl --predict_spurious
CelebA:
python3 train_supervised.py --output_dir=logs/celeba/erm_seed1 \
--num_epochs=20 --eval_freq=1 --save_freq=100 --seed=1 \
--weight_decay=1e-4 --batch_size=100 --init_lr=3e-3 \
--scheduler=cosine_lr_scheduler --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained
python3 dfr_evaluate_spurious.py --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=SpuriousCorrelationDataset --model=imagenet_resnet50_pretrained \
--ckpt_path=logs/celeba/erm_seed1/final_checkpoint.pt \
--result_path=celeba_erm_seed1_dfr.pkl --predict_spurious
FMOW
python3 train_supervised.py --output_dir=logs/fmow/erm_seed1 \
--num_epochs=20 --eval_freq=5 --save_freq=100 --seed=1 \
--weight_decay=1e-4 --batch_size=100 --init_lr=3e-3 \
--scheduler=cosine_lr_scheduler --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=WildsFMOW --model=imagenet_resnet50_pretrained
python3 dfr_evaluate_spurious.py --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=WildsFMOW --model=imagenet_resnet50_pretrained \
--ckpt_path=logs/fmow/erm_seed1/final_checkpoint.pt \
--result_path=fmow_erm_seed1_dfr.pkl --predict_spurious
CXR
python3 train_supervised.py --output_dir=logs/cxr/erm_seed1 \
--num_epochs=20 --eval_freq=5 --save_freq=100 --seed=1 \
--weight_decay=1e-4 --batch_size=100 --init_lr=3e-3 \
--scheduler=cosine_lr_scheduler --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=SpuriousCorrelationDataset \
--model=imagenet_densenet121_pretrained
python3 dfr_evaluate_auroc.py --data_dir=<DATA_DIR> \
--data_transform=AugWaterbirdsCelebATransform \
--dataset=SpuriousCorrelationDataset --model=imagenet_densenet121_pretrained \
--ckpt_path=logs/cxr/erm_seed1/final_checkpoint.pt \
--result_path=cxr_erm_seed1_dfr.pkl
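Since CXR performance is measured with AUROC rather than accuracy, the evaluation differs from the other datasets. A generic sketch of per-group AUROC with synthetic stand-in data (hypothetical arrays, not the script's exact logic):

import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical stand-ins for model scores, binary labels and group labels.
rng = np.random.default_rng(0)
scores = rng.random(500)
labels = rng.integers(0, 2, size=500)
groups = rng.integers(0, 2, size=500)

# Worst AUROC across groups, analogous to worst-group accuracy.
worst_group_auroc = min(
    roc_auc_score(labels[groups == g], scores[groups == g])
    for g in np.unique(groups)
)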
Civil Comments
python3 train_supervised.py --output_dir=logs/civilcomments/erm_seed1 \
--num_epochs=10 --eval_freq=1 --save_freq=10 --seed=1 \
--weight_decay=1.e-4 --batch_size=16 --init_lr=1e-5 \
--scheduler=bert_lr_scheduler --data_dir=<DATA_DIR> \
--data_transform=BertTokenizeTransform \
--dataset=WildsCivilCommentsCoarse --model=bert_pretrained \
--optimizer=bert_adamw_optimizer
python3 dfr_evaluate_spurious.py --data_dir=<DATA_DIR> \
--data_transform=BertTokenizeTransform \
--dataset=WildsCivilCommentsCoarse --model=bert_pretrained \
--ckpt_path=logs/civilcomments/erm_seed1/final_checkpoint.pt \
--result_path=civilcomments_erm_seed1_dfr.pkl --predict_spurious
MultiNLI
python3 train_supervised.py --output_dir=logs/multinli/erm_seed1/ \
--num_epochs=10 --eval_freq=1 --save_freq=10 --seed=1 \
--weight_decay=1.e-4 --batch_size=16 --init_lr=1e-5 \
--scheduler=bert_lr_scheduler --data_dir=<DATA_DIR> \
--data_transform=None --dataset=MultiNLIDataset --model=bert_pretrained \
--optimizer=bert_adamw_optimizer
python3 dfr_evaluate_spurious.py --data_dir=<DATA_DIR> \
--data_transform=None --dataset=MultiNLIDataset --model=bert_pretrained \
--ckpt_path=logs/multinli/erm_seed1/final_checkpoint.pt \
--result_path=multinli_erm_seed1_dfr.pkl --predict_spurious
<DATA_DIR> should be the root directory of the dataset, e.g. /datasets/fmow_v1.1/. We provide example --output_dir and --result_path argument values, but you can change them as convenient; these arguments define the locations of the logs and checkpoints, and of the DFR results file, respectively.
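The results file is a standard pickle; you can inspect it along these lines (the exact contents depend on the evaluation script):

import pickle
with open("wb_erm_seed1_dfr.pkl", "rb") as f:
    results = pickle.load(f)
print(results)  # structure depends on dfr_evaluate_spurious.py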
Other architectures and augmentations
You can specify the base model with the --model flag. For example:
- ResNet-50 pretrained on ImageNet: --model=imagenet_resnet50_pretrained
- ResNet-50 initialized randomly: --model=imagenet_resnet50
- ConvNext XLarge pretrained on ImageNet22k: --model=imagenet_convnext_xlarge_in22k_pretrained
- DenseNet-121 pretrained on ImageNet: --model=imagenet_densenet121_pretrained
Note that for some of the models (some ConvNext models, MAE, DINO), you need to manually download the checkpoints and put them in the /ckpts/ directory, or, alternatively, change the paths in the models/__init__.py file.
You can specify the data augmentation policy with the --data_transform flag:
- No augmentation: --data_transform=NoAugWaterbirdsCelebATransform
- Default augmentation: --data_transform=AugWaterbirdsCelebATransform
- Random Erase augmentation: --data_transform=ImageNetRandomErasingTransform
- AugMix augmentation: --data_transform=ImageNetAugmixTransform
You can apply MixUp with any augmentation policy by using the --mixup flag.
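For reference, MixUp trains on convex combinations of pairs of examples and their labels. A generic sketch of the idea (not the exact implementation in data/dataloaders.py):

import torch

def mixup_batch(x, y_onehot, alpha=0.2):
    # Sample a mixing coefficient from Beta(alpha, alpha) and combine each
    # example with a randomly permuted partner from the same batch.
    lam = torch.distributions.Beta(alpha, alpha).sample()
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1 - lam) * x[perm]
    y_mixed = lam * y_onehot + (1 - lam) * y_onehot[perm]
    return x_mixed, y_mixed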
For a full list of models and augmentations available, run:
python3 train_supervised.py -h
RWY and RWG
To run the RWY method, add --reweight_classes; to run the RWG method, add --reweight_groups.
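For reference, these methods reweight the training distribution so that classes (RWY) or groups (RWG) are sampled approximately uniformly. A generic sketch of class-reweighted sampling (hypothetical labels; the repo's samplers live in data/dataloaders.py):

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical class labels; with --reweight_groups, group labels are used instead.
labels = torch.tensor([0, 0, 0, 1])
class_counts = torch.bincount(labels).float()
weights = 1.0 / class_counts[labels]  # rare classes are sampled more often
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
# Pass sampler=sampler to the DataLoader to draw class-balanced batches.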
You can then evaluate the saved model weights with DFR analogously to the commands for ERM training above.
Group-DRO
We modified the group-DRO code to be compatible with our model and dataset implementations.
You can run experiments with group-DRO using the standard group-DRO commands, as described in the group-DRO codebase, but adding the --dfr_data and --dfr_model flags.
With these flags, you can use all the models and datasets implemented in our repo.
For example, to run on the FMOW dataset, you can use the following command (the values of the -t and -c flags are irrelevant here):
python3 run_expt.py -s confounder -d FMOW --model imagenet_resnet50_pretrained \
-t target -c confounder \
--root_dir <DATA_DIR> --robust --save_best --save_last --save_step 200 \
--batch_size 100 --n_epochs 20 --gamma 0.1 --augment_data --lr 0.001 \
--weight_decay 0.001 --generalization_adjustment 0 --seed 1 \
--log_dir logs/fmow/gdro_seed1 --dfr_data --dfr_model
This command should be run from the group_DRO folder.
You can then evaluate the saved model weights with DFR analogously to the commands for ERM training above.
Code References
- We used the DFR codebase as the basis for our code. The DFR codebase is in turn based on the group-DRO codebase.
- We also include a modified version of the group-DRO code in the group_DRO folder; this folder is ported from the group-DRO codebase with minor modifications.
- Our model implementations are based on the torchvision, timm and transformers packages.
- Our implementation of AugMix is ported from the AugMix repo.