Home

Awesome

nurd-code-public

Companion code for the paper Out-of-distribution Generalization in the Presence of Nuisance-Induced Spurious Correlations, ICLR 2022 (ArXiv).

Email aahlad@nyu.edu to get the datasets used in the paper. You can also generate them using save_processed_dataset_as_pt.py which takes data and stores them as pytorch tensors to make training very fast. If you want to use the raw data, please make a dataset with the functionality of the class XYZ_DatasetWithIndices in the dataloaders.py file.

example run for the chest X-ray experiment with default distillation

time python  nurd_reweighting.py --prefix=DUMMY --dataset=joint --workers=0 --dist_epochs=10 --nr_epochs=15 --nr_batch_size=1000 --dist_batch_size=1000 --seed=500 --rho=0.9 --img_side=32 --label_balance_method=downsample --rho_test=0.9 --border=6 --pred_model_type=small --weight_model_type=small --num_folds=2 --theta_lr=0.001 --gamma_lr=0.001 --phi_lr=0.001 --nr_lr=0.001 --dist_decay=0.0 --phi_decay=0.0 --nr_decay=0.01 --nr_strategy=weight --debug=2

example run for waterbirds with specified distillation

python nurd_reweighting.py --prefix=DUMMY --dataset=waterbirds --workers=0 --dist_epochs=1 --nr_epochs=2 --nr_batch_size=300 --dist_batch_size=300 --seed=500 --rho=0.9 --img_side=224 --label_balance_method=downsample --rho_test=0.9 --border=7 --pred_model_type=resnet_color --weight_model_type=resnet_color --num_folds=2 --theta_lr=0.001 --gamma_lr=0.0005 --phi_lr=0.001 --nr_lr=0.001 --dist_decay=0.01 --phi_decay=0.01 --nr_decay=0.01 --nr_strategy=weight --debug=2 --add_pred_suffix=_LAM1_FRAC2_RR1 --load_weights --lambda_=1 --max_lambda_=1 --frac_phi_steps=2 --randomrestart=1

Use --nr_only to learn only the weight model. Weights are saved by default and can be loaded using --load_weights.

required folders

Create the following folders in the directory for running the scripts.

LOGS/
SAVED_DATA/
SAVED_MODELS/
cub/

cub/ is the directory to put the waterbirds data.