A Data-Based Perspective on Transfer Learning

This repository contains the code for our paper:

A Data-Based Perspective on Transfer Learning
Saachi Jain*, Hadi Salman*, Alaa Khaddaj*, Eric Wong, Sung Min Park, Aleksander Madry
Paper - Blog post

@article{jain2022data,
  title={A Data-Based Perspective on Transfer Learning},
  author={Jain, Saachi and Salman, Hadi and Khaddaj, Alaa and Wong, Eric and Park, Sung Min and Madry, Aleksander},
  journal={arXiv preprint arXiv:2207.05739},
  year={2022}
}

The main contents of our repo are described below.

Getting started

Our code relies on the FFCV library. To install it along with the other dependencies, including PyTorch, follow the instructions below.

conda create -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge 
conda activate ffcv
pip install ffcv
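Before launching a long run, a quick sanity check that the environment can import the main dependencies can save time. The snippet below is a minimal sketch and is not part of our repo: the package list is our guess at the main imports the pipeline needs (note that OpenCV is imported as `cv2`), and `missing_packages` is a hypothetical helper name.

```python
import importlib.util


def missing_packages(names=("torch", "torchvision", "ffcv", "numba", "cupy", "cv2")):
    """Return the subset of `names` that cannot be found in this environment."""
    # find_spec locates a top-level module without importing it; None means not found.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing packages:", missing or "none")
    if "torch" not in missing:
        import torch

        # FFCV training in this pipeline expects a GPU.
        print("CUDA available:", torch.cuda.is_available())
```

Run it inside the activated `ffcv` environment; an empty list of missing packages and `CUDA available: True` suggest the install succeeded.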

Full pipeline: Train source model and transfer to various downstream tasks

To train an ImageNet model and transfer it to all the datasets we consider in the paper, simply run:

python src/train_imagenet_class_subset.py \
                        --config-file configs/base_config.yaml \
                        --training.data_root $PATH_TO_DATASETS \
                        --out.output_pkl_dir $OUTDIR

where $OUTDIR is the output directory of your choice, and $PATH_TO_DATASETS is the path where the datasets are stored (see below).

The config file configs/base_config.yaml contains all the hyperparameters needed for this experiment. For example, you can specify which downstream tasks to transfer to, or how many ImageNet classes to train the source model on.
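To launch several runs (for example, with different output directories), it can be convenient to build the training command from Python. The sketch below assumes only what the training command shows: a `--config-file` argument followed by dotted `section.key` overrides such as `--training.data_root`. The helper name `build_command` and the example paths are hypothetical.

```python
import shlex


def build_command(script, config_file, overrides):
    """Build an argv list: a base config file plus dotted `section.key` overrides."""
    argv = ["python", script, "--config-file", config_file]
    for key, value in overrides.items():
        # Each override becomes a flag/value pair, e.g. --training.data_root /data
        argv += [f"--{key}", str(value)]
    return argv


cmd = build_command(
    "src/train_imagenet_class_subset.py",
    "configs/base_config.yaml",
    {"training.data_root": "/path/to/datasets", "out.output_pkl_dir": "/path/to/out"},
)
print(shlex.join(cmd))  # shell-safe string, handy for logging
```

The list itself can be passed to `subprocess.run(cmd, check=True)`, which avoids shell-quoting problems with paths that contain spaces.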

Calculating influences

Use analysis/data_compressors/2_20_compressor.py to compress model results into a summary file. Then use analysis/compute_influences.py to compute the influences. In a notebook, simply run the following code:

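import torch
import compute_influences  # assumes the notebook runs from the analysis/ directory
sf = "<SUMMARY FILE FOLDER>"  # folder written by the compressor script
# `dataset`, `INFLUENCE_KEY`, `keyword` and `val_labels` must be defined beforehand
ds = compute_influences.SummaryFileDataSet(sf, dataset, INFLUENCE_KEY, keyword)
dl = torch.utils.data.DataLoader(ds, batch_size=1024, shuffle=False, drop_last=False)
infl = compute_influences.batch_calculate_influence(dl, len(val_labels), 1000, div=True)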
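To make the idea behind the influence values concrete, here is a small, unbatched sketch. We assume that the influence of a source class on a target example is the difference between the average performance on that example of models trained with the class and of models trained without it. The function `difference_of_means_influence` and its argument layout are hypothetical illustrations, not the implementation of `batch_calculate_influence`.

```python
import numpy as np


def difference_of_means_influence(masks, outputs):
    """Estimate the influence of each source class on each target example.

    masks:   (num_models, num_classes) 0/1 array; masks[m, c] == 1 if class c
             was in the training subset of model m.
    outputs: (num_models, num_targets) array with a per-example metric of each
             model on the target data (e.g. correctness or true-label probability).
    Returns: (num_classes, num_targets) array of
             mean(output | class included) - mean(output | class excluded).
    """
    masks = np.asarray(masks, dtype=float)
    outputs = np.asarray(outputs, dtype=float)
    n_in = masks.sum(axis=0)[:, None]           # number of models that saw each class
    n_out = (1.0 - masks).sum(axis=0)[:, None]  # number of models that did not
    # A class present in every (or no) model gives a 0/0 division and yields NaN.
    mean_in = masks.T @ outputs / n_in
    mean_out = (1.0 - masks).T @ outputs / n_out
    return mean_in - mean_out
```

For instance, with four models and two classes, `difference_of_means_influence([[1, 0], [0, 1], [1, 1], [0, 0]], [[1], [0], [1], [0]])` gives an influence of 1 for class 0 (the target example is solved exactly when class 0 is present) and 0 for class 1.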

Running counterfactual experiment

Once influences have been computed, we can run counterfactual experiments: remove the most positively (top) or most negatively (bottom) influential classes from the source dataset (ImageNet), then apply transfer learning again. To do so, run:

python src/counterfactuals_main.py \
            --config-file configs/base_config.yaml \
            --training.transfer_task ${TASK} \
            --out.output_pkl_dir ${OUT_DIR} \
            --counterfactual.cf_target_dataset ${DATASET} \
            --counterfactual.cf_infl_order_file ${INFL_ORDER_FILE} \
            --data.num_classes -1 \
            --counterfactual.cf_order TOP \
            --counterfactual.cf_num_classes_min ${MIN_STEPS} \
            --counterfactual.cf_num_classes_max ${MAX_STEPS} \
            --counterfactual.cf_num_classes_step ${STEP_SIZE} \
            --counterfactual.cf_type CLASS
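The `cf_order`, `cf_num_classes_min`, `cf_num_classes_max` and `cf_num_classes_step` flags describe a sweep over how many classes to remove. The sketch below illustrates that idea on an in-memory influence matrix; it is not the code in src/counterfactuals_main.py and does not read or write the format of `${INFL_ORDER_FILE}`. We assume that classes are ranked by their influence averaged over all target examples, that "bottom" means the most negative classes first, and that the maximum is inclusive. The function `class_removal_sets` and its arguments are hypothetical.

```python
import numpy as np


def class_removal_sets(influences, order="TOP", k_min=0, k_max=100, step=10):
    """Yield (k, classes_to_remove) for k = k_min, k_min + step, ..., up to k_max.

    influences: (num_classes, num_targets) array of per-class influences on
                the target dataset's examples.
    order:      "TOP" removes the most positively influential classes first,
                "BOTTOM" the most negatively influential ones.
    """
    per_class = np.asarray(influences, dtype=float).mean(axis=1)  # average over targets
    ranking = np.argsort(-per_class, kind="stable")  # most positive influence first
    if order == "BOTTOM":
        ranking = ranking[::-1]  # most negative influence first
    elif order != "TOP":
        raise ValueError(f"unknown order: {order!r}")
    for k in range(k_min, k_max + 1, step):
        yield k, ranking[:k]
```

Each yielded index array would then define one reduced source dataset on which to retrain and transfer again.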

Datasets that we use (see our paper for citations)

We have created an FFCV version of each of these datasets to enable much faster training. We will make these datasets available soon!
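Until our converted files are released, a dataset can be written to FFCV's `.beton` format with FFCV's `DatasetWriter`. The sketch below shows the general pattern for an indexed dataset that returns `(image, label)` pairs; it is not necessarily how we produced our files, and the field names `"image"` and `"label"`, the resolution, and the helper name `write_ffcv_dataset` are assumptions.

```python
def write_ffcv_dataset(dataset, write_path, max_resolution=256, num_workers=16):
    """Convert an indexed (image, label) dataset into an FFCV .beton file.

    Hypothetical helper; field names and resolution are illustrative choices.
    """
    # Imported lazily so this function can be defined without FFCV installed.
    from ffcv.fields import IntField, RGBImageField
    from ffcv.writer import DatasetWriter

    writer = DatasetWriter(
        write_path,
        {
            # Images larger than max_resolution (longest side) are downscaled.
            "image": RGBImageField(max_resolution=max_resolution),
            "label": IntField(),
        },
        num_workers=num_workers,
    )
    # Reads dataset[i] for every index and writes one sample per index.
    writer.from_indexed_dataset(dataset)
```

For example, with torchvision installed, `write_ffcv_dataset(torchvision.datasets.CIFAR10("/tmp/cifar", train=True, download=True), "/tmp/cifar_train.beton")` would write the CIFAR-10 training set; the paths are placeholders.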

Download our data

Coming soon!

Download our pretrained models

Coming soon!

A detailed demo

Coming soon!