Benchmarking Test-Time Adaptation against Distribution Shifts in Image Classification

Prerequisites

To use this repository, we provide a conda environment:

conda update conda
conda env create -f environment.yaml
conda activate Benchmark_TTA 

Structure of Project

This project contains several directories. Their roles are listed as follows:

Run

This repository allows you to study a wide range of datasets, models, settings, and methods. A quick overview is given below:

Get Started

To run one of the following benchmarks, the corresponding datasets need to be downloaded.

Next, specify the root folder for all datasets by setting _C.DATA_DIR = "./data" in the file conf.py.
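
For reference, the relevant line in conf.py looks roughly like the following sketch, assuming a yacs-style CfgNode (which the _C naming suggests); the surrounding options in the real file will differ.

    # conf.py (sketch): assuming a yacs-style configuration object.
    from yacs.config import CfgNode as CN

    _C = CN()

    # Root folder that contains all the downloaded benchmark datasets.
    _C.DATA_DIR = "./data"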

The best parameters for each method and dataset are saved in ./best_cfgs.

Download the checkpoints of the pre-trained models and the data loading sequences from here and put them in ./ckpt.

How to reproduce

The entry script for SHOT, NRC, and PLUE is SFDA-eva.sh.

To evaluate these methods, modify DATASET and METHOD in SFDA-eva.sh and then run:

bash SFDA-eva.sh

The entry script for the other algorithms is test-time-eva.sh.

To evaluate these methods, modify DATASET and METHOD in test-time-eva.sh and then run:

bash test-time-eva.sh

Add your own algorithm, dataset and model

We decouple the loading of datasets, models, and methods, so you can add each of them to our benchmark completely independently.

To add an algorithm

  1. Add a Python file Algorithm_XX.py for your algorithm in ./src/methods/ (see the sketch after this list).

  2. Add a setup function setup_XX(model, cfg) for your algorithm in ./src/methods/setup.py.

  3. Add a branch with your setup code around line 22 of ./test-time.py, like

        elif cfg.MODEL.ADAPTATION == "XX":
            model, param_names = setup_XX(base_model, cfg)
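
As a rough illustration of how steps 1-3 fit together, here is a minimal sketch. The class XXMethod, its constructor, and the way param_names is built are placeholders for illustration, not the benchmark's actual interface.

    # ./src/methods/Algorithm_XX.py (sketch) -- names and signatures are illustrative only.
    import torch.nn as nn

    class XXMethod(nn.Module):
        """Wraps a base model and adapts it on each incoming test batch."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            # Run your adaptation step(s) here before predicting.
            return self.model(x)

    # ./src/methods/setup.py (sketch)
    def setup_XX(model, cfg):
        """Wrap the base model for the new method; read any hyperparameters you need from cfg."""
        adapted_model = XXMethod(model)
        param_names = [name for name, _ in adapted_model.named_parameters()]
        return adapted_model, param_names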
    

To add a dataset

  1. Write a function load_dataset_name() in ./src/data/data.py that loads your dataset Dataset_new (see the sketch after this list).

  2. Define the transforms used to load your dataset in the function get_transform() in ./src/data/data.py.

  3. Add a branch that loads your dataset to the function load_dataset() in ./src/data/data.py, like

        elif dataset == 'dataset_name':
            return load_dataset_name(root=root, batch_size=batch_size, workers=workers, split=split, transforms=transforms,
                                 ckpt=ckpt)
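
A minimal sketch of such a loader is given below. The directory layout root/dataset_name/<split>, the use of a torchvision ImageFolder, and the return values are assumptions for illustration; match the signature and return values of the existing loaders in ./src/data/data.py.

    # ./src/data/data.py (sketch) -- illustrative loader for the hypothetical Dataset_new.
    import os
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder

    def load_dataset_name(root, batch_size, workers, split, transforms, ckpt=None):
        """Build a DataLoader for Dataset_new stored under root/dataset_name/<split>."""
        data_dir = os.path.join(root, "dataset_name", split)
        dataset = ImageFolder(data_dir, transform=transforms)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=workers, pin_memory=True)
        return dataset, loader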
    

To add a model

  1. Just add the code for loading your model to the load_model() function in ./src/model/load_model.py, like

        elif model_name == 'model_new':
            model = ...  # the code for loading your model
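
For instance, the new branch could call a small helper that builds a torchvision backbone and loads your own weights, as in the sketch below; the helper name, checkpoint path, and num_classes are placeholders.

    # Sketch of a helper the 'model_new' branch in load_model() might call; names and paths are placeholders.
    import torch
    from torchvision.models import resnet50

    def load_model_new(num_classes, ckpt_path='./ckpt/model_new.pth'):
        """Build the new backbone and load its pretrained weights from ckpt_path."""
        model = resnet50(num_classes=num_classes)
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict)
        return model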
    

You can cite our work as follows:

@article{yu2023benchmarking,
  title={Benchmarking test-time adaptation against distribution shifts in image classification},
  author={Yu, Yongcan and Sheng, Lijun and He, Ran and Liang, Jian},
  journal={arXiv preprint arXiv:2307.03133},
  year={2023}
}

Acknowledgements