<img src="website/tesla.gif" width="40" height="40" style="vertical-align: bottom"/> <b>TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation</b>

This repository contains the official PyTorch implementation of the CVPR 2023 paper "TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation" by Devavrat Tomar, Guillaume Vray, Behzad Bozorgtabar, and Jean-Philippe Thiran.

[arXiv] [Project]

Abstract

Most recent test-time adaptation methods focus only on classification tasks, use specialized network architectures, destroy model calibration, or rely on lightweight information from the source domain. To tackle these issues, this paper proposes a novel Test-time Self-Learning method with automatic Adversarial augmentation, dubbed TeSLA, for adapting a pre-trained source model to the unlabeled streaming test data. In contrast to conventional self-learning methods based on cross-entropy, we introduce a new test-time loss function through an implicitly tight connection with the mutual information and online knowledge distillation. Furthermore, we propose a learnable efficient adversarial augmentation module that further enhances online knowledge distillation by simulating high-entropy augmented images. Our method achieves state-of-the-art classification and segmentation results on several benchmarks and types of domain shifts, particularly on challenging measurement shifts of medical images. TeSLA also benefits from several desirable properties compared to competing methods in terms of calibration, uncertainty metrics, insensitivity to model architectures, and source training strategies, all supported by extensive ablations.
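For intuition about the "connection with the mutual information" mentioned above, the sketch below shows the generic mutual-information surrogate that many test-time adaptation objectives build on (maximize the entropy of the batch-marginal prediction while minimizing the average per-sample entropy). It is only an illustration of this family of objectives, not TeSLA's actual loss, which is derived in the paper together with the knowledge-distillation term.

```python
import torch
import torch.nn.functional as F

def mutual_information_surrogate(logits: torch.Tensor) -> torch.Tensor:
    """Generic MI-style objective: H(mean prediction) - mean per-sample entropy.
    Illustrative only; TeSLA's test-time loss is defined in the paper."""
    p = F.softmax(logits, dim=1)                      # per-sample class probabilities
    p_mean = p.mean(dim=0)                            # marginal prediction over the batch
    h_marginal = -(p_mean * torch.log(p_mean + 1e-8)).sum()
    h_per_sample = -(p * torch.log(p + 1e-8)).sum(dim=1).mean()
    return h_marginal - h_per_sample                  # maximize this (or minimize its negative)
```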

Overview of TeSLA Framework

<img src="website/tesla_overview.svg">

(a) The student model <img src="website/svgs/deb18c89b908abf80bef809cbdcbae2d.svg#gh-light-mode-only" align=middle width=14.252356799999989pt height=22.831056599999986pt/><img src="website/svgs_dark/deb18c89b908abf80bef809cbdcbae2d.svg#gh-dark-mode-only" align=middle width=14.252356799999989pt height=22.831056599999986pt/> is adapted on the test images by minimizing the proposed test-time objective <img src="website/svgs/a8c95121d37068acdbc35e9975f50c86.svg#gh-light-mode-only" align=middle width=22.31974139999999pt height=22.465723500000017pt/><img src="website/svgs_dark/a8c95121d37068acdbc35e9975f50c86.svg#gh-dark-mode-only" align=middle width=22.31974139999999pt height=22.465723500000017pt/>. The high-quality soft-pseudo labels required by <img src="website/svgs/a8c95121d37068acdbc35e9975f50c86.svg#gh-light-mode-only" align=middle width=22.31974139999999pt height=22.465723500000017pt/><img src="website/svgs_dark/a8c95121d37068acdbc35e9975f50c86.svg#gh-dark-mode-only" align=middle width=22.31974139999999pt height=22.465723500000017pt/> are obtained from the exponentially weighted averaged teacher model <img src="website/svgs/5c7704963fa9ece758ae7def4b308098.svg#gh-light-mode-only" align=middle width=13.01377934999999pt height=22.831056599999986pt/><img src="website/svgs_dark/5c7704963fa9ece758ae7def4b308098.svg#gh-dark-mode-only" align=middle width=13.01377934999999pt height=22.831056599999986pt/> and refined using the proposed Soft-Pseudo Label Refinement (PLR) on the corresponding test images. The soft-pseudo labels are further utilized for teacher-student knowledge distillation via <img src="website/svgs/9ca5d7ed36b5da46a0cde6b76ae0a92a.svg#gh-light-mode-only" align=middle width=25.50469679999999pt height=22.465723500000017pt/><img src="website/svgs_dark/9ca5d7ed36b5da46a0cde6b76ae0a92a.svg#gh-dark-mode-only" align=middle width=25.50469679999999pt height=22.465723500000017pt/> on the adversarially augmented views of the test images. (b) The adversarial augmentations are obtained by applying learned sub-policies sampled i.i.d. from <img src="website/svgs/865a2c771b7419b8742c1a4a04cc5584.svg#gh-light-mode-only" align=middle width=10.045686749999991pt height=22.648391699999998pt/><img src="website/svgs_dark/865a2c771b7419b8742c1a4a04cc5584.svg#gh-dark-mode-only" align=middle width=10.045686749999991pt height=22.648391699999998pt/> using the probability distribution <img src="website/svgs/df5a289587a2f0247a5b97c1e8ac58ca.svg#gh-light-mode-only" align=middle width=12.83677559999999pt height=22.465723500000017pt/><img src="website/svgs_dark/df5a289587a2f0247a5b97c1e8ac58ca.svg#gh-dark-mode-only" align=middle width=12.83677559999999pt height=22.465723500000017pt/> with their corresponding magnitudes selected from <img src="website/svgs/fb97d38bcc19230b0acd442e17db879c.svg#gh-light-mode-only" align=middle width=17.73973739999999pt height=22.465723500000017pt/><img src="website/svgs_dark/fb97d38bcc19230b0acd442e17db879c.svg#gh-dark-mode-only" align=middle width=17.73973739999999pt height=22.465723500000017pt/>.
The parameters <img src="website/svgs/fb97d38bcc19230b0acd442e17db879c.svg#gh-light-mode-only" align=middle width=17.73973739999999pt height=22.465723500000017pt/><img src="website/svgs_dark/fb97d38bcc19230b0acd442e17db879c.svg#gh-dark-mode-only" align=middle width=17.73973739999999pt height=22.465723500000017pt/> and <img src="website/svgs/df5a289587a2f0247a5b97c1e8ac58ca.svg#gh-light-mode-only" align=middle width=12.83677559999999pt height=22.465723500000017pt/><img src="website/svgs_dark/df5a289587a2f0247a5b97c1e8ac58ca.svg#gh-dark-mode-only" align=middle width=12.83677559999999pt height=22.465723500000017pt/> of the augmentation module are updated by the unbiased gradient estimator of the loss <img src="website/svgs/10b6ebc26c060d3fcbcc764955f8476f.svg#gh-light-mode-only" align=middle width=35.03099654999999pt height=22.465723500000017pt/><img src="website/svgs_dark/10b6ebc26c060d3fcbcc764955f8476f.svg#gh-dark-mode-only" align=middle width=35.03099654999999pt height=22.465723500000017pt/> computed on the augmented test images.
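For readers who prefer code to diagrams, below is a minimal PyTorch sketch of the teacher-student loop in (a). All names here (`plr_refine`, `augment`, `adaptation_step`, the momentum value) are illustrative placeholders rather than this repository's API; the sketch covers only the EMA teacher update and the knowledge-distillation term on augmented views, while the full method also optimizes the test-time objective on the clean images and learns the augmentation module of (b) as described in the paper.

```python
import torch
import torch.nn.functional as F

def ema_update(teacher, student, momentum=0.999):
    # Teacher is an exponentially weighted average of the student (momentum is a placeholder value).
    with torch.no_grad():
        for t_p, s_p in zip(teacher.parameters(), student.parameters()):
            t_p.mul_(momentum).add_(s_p, alpha=1.0 - momentum)

def adaptation_step(student, teacher, optimizer, x, augment, plr_refine):
    # 1) Teacher produces soft pseudo-labels on the clean test batch; `plr_refine`
    #    stands in for the paper's soft pseudo-label refinement (PLR) step.
    with torch.no_grad():
        q = F.softmax(teacher(x), dim=1)
        q = plr_refine(q, x)
    # 2) Student is distilled on adversarially augmented views of the same batch.
    x_aug = augment(x)                                  # learned sub-policies from (b)
    log_p = F.log_softmax(student(x_aug), dim=1)
    loss_kd = F.kl_div(log_p, q, reduction="batchmean")
    optimizer.zero_grad()
    loss_kd.backward()
    optimizer.step()
    # 3) Teacher tracks the adapted student via EMA.
    ema_update(teacher, student)
    return loss_kd.item()
```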

Requirements

First install Anaconda (Python >= 3.8) using this link. Then create the TeSLA conda environment and install the dependencies by running the following commands:

conda create --name TeSLA python=3.8
conda activate TeSLA
conda install pip
pip install -r requirements.txt

Activate the TeSLA environment as:

conda activate TeSLA

Datasets Download Links

| Dataset Name | Download Link | Extract to Relative Path |
| --- | --- | --- |
| CIFAR-10C | click here | ../Datasets/cifar_dataset/CIFAR-10-C/ |
| CIFAR-100C | click here | ../Datasets/cifar_dataset/CIFAR-100-C/ |
| ImageNet-C | click here | ../Datasets/imagenet_dataset/ |
| VisDA-C | click here | ../Datasets/visda_dataset |
| Kather | click here | ../Datasets/Kather/kather2016 |
| VisDA-S | click here | ../Datasets/visda_segmentation_dataset |
| (MRI) Spinal Cord | click here | ../Datasets/MRI/SpinalCord |
| (MRI) Prostate | click here | ../Datasets/MRI/Prostate |
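As a quick sanity check of where the archives should end up, the snippet below (not part of the repository) simply verifies that the relative paths listed in the table above exist, assuming it is run from the repository root:

```python
from pathlib import Path

# Relative extract paths copied from the table above; they sit one level above the repo root.
expected = [
    "../Datasets/cifar_dataset/CIFAR-10-C/",
    "../Datasets/cifar_dataset/CIFAR-100-C/",
    "../Datasets/imagenet_dataset/",
    "../Datasets/visda_dataset",
    "../Datasets/Kather/kather2016",
    "../Datasets/visda_segmentation_dataset",
    "../Datasets/MRI/SpinalCord",
    "../Datasets/MRI/Prostate",
]

for path in expected:
    print(("ok      " if Path(path).exists() else "missing ") + path)
```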

Pre-trained Source Models Links

Classification Task

| Dataset Name | Download Link | Extract to Relative Path |
| --- | --- | --- |
| CIFAR-10 | click here | ../Source_classifiers/cifar10 |
| CIFAR-100 | click here | ../Source_classifiers/cifar100 |
| ImageNet | PyTorch Default | |
| VisDA-C | click here | ../Source_classifier/VisDA |
| Kather | click here | ../Source_classifier/Kather |

Segmentation Task

| Dataset Name | Download Link | Extract to Relative Path |
| --- | --- | --- |
| VisDA-S | click here | ../Source_Segmentation/VisDA/ |
| MRI (Spinal Cord and Prostate) | click here | ../Source_Segmentation/MRI/ |
<!-- ## **Code for training source models from scratch** The above pre-trained source models can be obtained using the code available at: https://github.com/devavratTomar/tesla_appendix -->

Examples of adapting source models using TeSLA

Classification task on CIFAR, ImageNet, VisDA, and Kather datasets for online and offline adaptation:

(1) Common Image Corruptions: CIFAR-10C

bash scripts_classification/online/cifar10.sh
bash scripts_classification/offline/cifar10.sh

(2) Common Image Corruptions: CIFAR-100C

bash scripts_classification/online/cifar100.sh
bash scripts_classification/offline/cifar100.sh

(3) Common Image Corruptions: ImageNet-C

bash scripts_classification/online/imagenet.sh
bash scripts_classification/offline/imagenet.sh

(4) Synthetic to Real Adaptation: VisDA-C

bash scripts_classification/online/visdac.sh
bash scripts_classification/offline/visdac.sh

(5) Medical Measurement Shifts: Kather

bash scripts_classification/online/kather.sh
bash scripts_classification/offline/kather.sh

Segmentation task on VisDA-S and MRI datasets for online and offline adaptation:

(1) GTA5 to Cityscapes

bash scripts_segmentation/online/cityscapes.sh
bash scripts_segmentation/offline/cityscapes.sh

(2) Domain shifts of MRI

bash scripts_segmentation/online/spinalcord.sh
bash scripts_segmentation/offline/prostate.sh

Citation

If you find our work useful, please consider citing:

@inproceedings{tomar2023TeSLA,
  title={TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation},
  author={Tomar, Devavrat and Vray, Guillaume and Bozorgtabar, Behzad and Thiran, Jean-Philippe},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023}
}