transformer-domain-generalization

This repository provides the official PyTorch implementation of the following paper:

A Re-Parameterized Vision Transformer (ReVT) for Domain-Generalized Semantic Segmentation

Abstract: The task of semantic segmentation requires a model to assign semantic labels to each pixel of an image. However, the performance of such models degrades when deployed in an unseen domain with different data distributions compared to the training domain. We present a new augmentation-driven approach to domain generalization for semantic segmentation using a re-parameterized vision transformer (ReVT) with weight averaging of multiple models after training. We evaluate our approach on several benchmark datasets and achieve state-of-the-art mIoU performance of 47.3% (prior art: 46.3%) for small models and of 50.1% (prior art: 47.8%) for midsized models on commonly used benchmark datasets. At the same time, our method requires fewer parameters and reaches a higher frame rate than the best prior art. It is also easy to implement and, unlike network ensembles, does not add any computational complexity during inference.

Getting Started

Our code is heavily derived from SegFormer (NeurIPS 2021). If you use this code in your research, please also cite their work.

Requirements

We ran our experiments on a GPU cluster managed by Slurm, so all of our experiment scripts are Slurm scripts. Nevertheless, it should be possible to reproduce our results without Slurm with only minor changes.

Other requirements are:

Installation

Clone this repository into <User_Home_dir>/work/.

cd ~/
mkdir -p ./work
cd work
git clone <repository-url> revt-domain-generalization
cd revt-domain-generalization

Create a conda environment containing all packages listed in requirements.txt:

conda create --name transformer-domain-generalization --file ./requirements.txt

Datasets

We used the dataset structure from SegFormer, which is based on MMSegmentation v0.13.0. All datasets are implemented as custom datasets with the following structure:

├── data
│   ├── my_dataset
│   │   ├── img_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{img_suffix}
│   │   │   │   ├── yyy{img_suffix}
│   │   │   │   ├── zzz{img_suffix}
│   │   │   ├── val
│   │   ├── ann_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{seg_map_suffix}
│   │   │   │   ├── yyy{seg_map_suffix}
│   │   │   │   ├── zzz{seg_map_suffix}
│   │   │   ├── val

For more details on how to structure the datasets, please refer to the MMSegmentation v0.13.0 documentation.
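The tree above pairs each image xxx{img_suffix} with the annotation xxx{seg_map_suffix}, i.e. both files share a name once their suffixes are removed. Before training, it can help to check that every image has an annotation and vice versa. The following is a minimal sketch under that assumption; it is not part of this repository, `check_split` is a hypothetical helper, and the suffixes are whatever your own dataset config uses.

```python
from pathlib import Path
from typing import Dict, List, Set


def _sample_names(directory: Path, suffix: str) -> Set[str]:
    # Relative path without the suffix, e.g. "xxx" for "xxx{suffix}".
    # Subdirectories are scanned as well (rglob), giving names like "city/xxx".
    return {
        p.relative_to(directory).as_posix()[: -len(suffix)]
        for p in directory.rglob("*")
        if p.is_file() and p.name.endswith(suffix)
    }


def check_split(dataset_root: Path, split: str, img_suffix: str,
                seg_map_suffix: str) -> Dict[str, List[str]]:
    """Report samples in img_dir/<split> and ann_dir/<split> that lack a partner."""
    if not img_suffix or not seg_map_suffix:
        raise ValueError("suffixes must be non-empty")
    images = _sample_names(dataset_root / "img_dir" / split, img_suffix)
    labels = _sample_names(dataset_root / "ann_dir" / split, seg_map_suffix)
    return {
        "missing_annotation": sorted(images - labels),
        "missing_image": sorted(labels - images),
    }
```

For example, `check_split(Path("data/my_dataset"), "train", ".png", "_labelTrainIds.png")` would list mismatches for the training split; the suffix values here are only placeholders.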

We use the following datasets:

Download ImageNet-Pretrained Weights

We used the ImageNet-pretrained weights provided by the SegFormer repository. Please download the weights and place them in the folder ./pretrained.

Usage

Reproduce Experiments with the Provided Slurm Scripts

All included Slurm scripts work out of the box and can be submitted with:

cd ~/work/revt-domain-generalization/<location-of-slurm-script>
sbatch ./<name-of-slurm-script>

All Slurm scripts are located in ./slurm_scripts_runs/, which is structured as follows:

├── slurm_scripts_runs
│   ├── templates
│   │   ├── test.sh
│   │   ├── train.sh
│   │   ├── test_ensemble.sh
│   │   ├── test_reparam.sh
│   ├── <Model_Type1>
│   │   ├── <Train-DatasetA>
│   │   │   ├── test
│   │   │   │   ├── standard
│   │   │   │   │   ├── <Test_Method_I>
│   │   │   │   │   ├── ...
│   │   │   │   ├── ensemble
│   │   │   │   │   ├── Ensemble(1,1,1)
│   │   │   │   │   ├── Ensemble(1,4,6)
│   │   │   │   │   ├── ...
│   │   │   │   ├── re_param
│   │   │   │   │   ├── ReVT(1,1,1)
│   │   │   │   │   ├── ReVT(1,4,6)
│   │   │   │   │   ├── ...
│   │   │   ├── train
│   │   │   │   ├── <Method_I>
│   │   │   │   ├── <Method_II>
│   │   │   │   ├── ...
│   │   ├── <Train-DatasetB>
│   │   │   ├── ...
│   │   ├── ...
│   ├── <Model_Type2>
│   │   ├── <Train-DatasetA>
│   │   │   ├── ...
│   │   ├── ...
│   ├── ...

For example:

Use the Re-Parameterization Script

The script for applying ReVT to checkpoints is implemented generically and should therefore work with most PyTorch models. The only requirement is that each checkpoint is, at its top level, a dict containing the key "state_dict". All weights in the "state_dict" whose names match the specified regex are averaged.
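The averaging step can be sketched in a few lines. This is a minimal illustration, not the code of tools/model_reparameterization.py: `average_state_dicts` is a hypothetical name, it assumes a uniform average, and it assumes the regex is applied with Python's `re.match` (anchored at the start of each key). Values only need to support `+` and `/`, so the same function works on Python floats and on floating-point torch tensors.

```python
import re
from typing import Any, Dict, Mapping, Sequence


def average_state_dicts(state_dicts: Sequence[Mapping[str, Any]],
                        weights_filter: str) -> Dict[str, Any]:
    """Merge several state dicts into one by uniform weight averaging.

    Entries whose key matches `weights_filter` (re.match, i.e. anchored at the
    start of the key) are averaged over all state dicts; every other entry is
    copied unchanged from the first state dict.
    """
    if not state_dicts:
        raise ValueError("at least one state dict is required")
    pattern = re.compile(weights_filter)
    first = state_dicts[0]
    merged: Dict[str, Any] = {}
    for key, value in first.items():
        if pattern.match(key):
            # Sum the same parameter over all models; every state dict is
            # assumed to contain each key of the first one.
            total = value
            for other in state_dicts[1:]:
                total = total + other[key]
            merged[key] = total / len(state_dicts)
        else:
            # Non-matching weights (e.g. the decode head) come from the first model.
            merged[key] = value
    return merged
```

With PyTorch, the inputs would be the "state_dict" entries of the loaded checkpoints, e.g. `torch.load(path, map_location="cpu")["state_dict"]` for each file.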

You can merge an arbitrary number of checkpoints with:

python ./tools/model_reparameterization.py <Destination_File> --checkpoints <list of checkpoint files> --cpu-only --weights-filter <Regex>

Example: The following command applies ReVT to three baseline models (A, B, C) to create a combined checkpoint file. All weights in the "state_dict" that match the regex "backbone.*" are averaged; all other weights are copied from the first checkpoint file given (A).

python ./tools/model_reparameterization.py ./work_dir/ReVT_B5/BaselineABC.pth \
    --checkpoints \
        ./work_dir/SegFormerB5/gta_dev/BaselineA.pth \
        ./work_dir/SegFormerB5/gta_dev/BaselineB.pth \
        ./work_dir/SegFormerB5/gta_dev/BaselineC.pth \
    --cpu-only --weights-filter "backbone.*"
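Written as a formula, and assuming a uniform average, the merge in this example sets every entry k of the resulting "state_dict" as follows. Here M is the number of checkpoints, θ_k^(m) is entry k of checkpoint m, and 𝒦 is the set of keys matching the regex. For instance, a backbone weight with values 1.0, 2.0 and 6.0 in A, B and C becomes (1.0 + 2.0 + 6.0) / 3 = 3.0.

```latex
\theta_k^{\mathrm{ReVT}} =
\begin{cases}
  \dfrac{1}{M} \displaystyle\sum_{m=1}^{M} \theta_k^{(m)}, & k \in \mathcal{K}, \\[1ex]
  \theta_k^{(1)}, & \text{otherwise,}
\end{cases}
\qquad M = 3 \ (\text{A, B, C}), \quad
\mathcal{K} = \{\, k : k \text{ matches } \texttt{backbone.*} \,\}.
```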

Trained weights

The weights can be obtained from the following links:

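| Model Type   | Train Dataset       | Method       | mIoU Cityscapes | mIoU BDD | Link |
|--------------|---------------------|--------------|-----------------|----------|------|
| SegFormer B2 | GTA train split     | Baseline     | 41.73% | 38.77% | will be added soon |
| SegFormer B2 | GTA train split     | ReVT {1,4,6} | 46.27% | 43.29% | will be added soon |
| SegFormer B2 | GTA train split     | ReVT {4,5,6} | 45.55% | 43.43% | will be added soon |
| SegFormer B2 | SYNTHIA train split | Baseline     | 39.71% | 29.76% | will be added soon |
| SegFormer B2 | SYNTHIA train split | ReVT {1,4,6} | 40.91% | 34.53% | will be added soon |
| SegFormer B2 | SYNTHIA train split | ReVT {4,5,6} | 41.09% | 35.18% | will be added soon |
| SegFormer B3 | GTA train split     | Baseline     | 43.92% | 42.96% | will be added soon |
| SegFormer B3 | GTA train split     | ReVT {1,4,6} | 48.33% | 48.17% | will be added soon |
| SegFormer B3 | GTA train split     | ReVT {4,5,6} | 47.95% | 48.26% | will be added soon |
| SegFormer B3 | SYNTHIA train split | Baseline     | 42.43% | 33.33% | will be added soon |
| SegFormer B3 | SYNTHIA train split | ReVT {1,4,6} | 44.97% | 38.65% | will be added soon |
| SegFormer B3 | SYNTHIA train split | ReVT {4,5,6} | 45.26% | 38.73% | will be added soon |
| SegFormer B5 | GTA train split     | Baseline     | 45.31% | 43.32% | https://drive.google.com/drive/folders/1clnJptm58PrLEGooB2cyUpbdqU0_d8Oq?usp=sharing |
| SegFormer B5 | GTA train split     | ReVT {1,4,6} | 49.96% | 48.01% | https://drive.google.com/drive/folders/1MuvBVyNveIxs5v011AdIp_i4dp5HU4sd?usp=sharing |
| SegFormer B5 | GTA train split     | ReVT {4,5,6} | 49.55% | 48.11% | https://drive.google.com/drive/folders/1hhmNDwGRAd_F9Rb198L2pPaDn0rlf3Ec?usp=sharing |
| SegFormer B5 | SYNTHIA train split | Baseline     | 45.07% | 35.19% | will be added soon |
| SegFormer B5 | SYNTHIA train split | ReVT {1,4,6} | 46.28% | 40.30% | will be added soon |
| SegFormer B5 | SYNTHIA train split | ReVT {4,5,6} | 45.08% | 39.62% | will be added soon |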
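To read the table as improvements over the matching baseline, the differences can be computed directly from the values above. The sketch below copies the mIoU values from the table (model names shortened to B2/B3/B5 and train splits to GTA/SYNTHIA) and subtracts the baseline of the same model and train split; `gains_over_baseline` is a hypothetical helper, not part of the repository.

```python
from typing import Dict, List, Tuple

# (model, train split, method, mIoU Cityscapes %, mIoU BDD %), copied from the table.
Row = Tuple[str, str, str, float, float]
ROWS: List[Row] = [
    ("B2", "GTA", "Baseline", 41.73, 38.77),
    ("B2", "GTA", "ReVT {1,4,6}", 46.27, 43.29),
    ("B2", "GTA", "ReVT {4,5,6}", 45.55, 43.43),
    ("B2", "SYNTHIA", "Baseline", 39.71, 29.76),
    ("B2", "SYNTHIA", "ReVT {1,4,6}", 40.91, 34.53),
    ("B2", "SYNTHIA", "ReVT {4,5,6}", 41.09, 35.18),
    ("B3", "GTA", "Baseline", 43.92, 42.96),
    ("B3", "GTA", "ReVT {1,4,6}", 48.33, 48.17),
    ("B3", "GTA", "ReVT {4,5,6}", 47.95, 48.26),
    ("B3", "SYNTHIA", "Baseline", 42.43, 33.33),
    ("B3", "SYNTHIA", "ReVT {1,4,6}", 44.97, 38.65),
    ("B3", "SYNTHIA", "ReVT {4,5,6}", 45.26, 38.73),
    ("B5", "GTA", "Baseline", 45.31, 43.32),
    ("B5", "GTA", "ReVT {1,4,6}", 49.96, 48.01),
    ("B5", "GTA", "ReVT {4,5,6}", 49.55, 48.11),
    ("B5", "SYNTHIA", "Baseline", 45.07, 35.19),
    ("B5", "SYNTHIA", "ReVT {1,4,6}", 46.28, 40.30),
    ("B5", "SYNTHIA", "ReVT {4,5,6}", 45.08, 39.62),
]


def gains_over_baseline(rows: List[Row]) -> Dict[Tuple[str, str, str], Tuple[float, float]]:
    """Absolute mIoU gain (percentage points) of each method over its baseline."""
    baselines = {(m, d): (cs, bdd) for m, d, meth, cs, bdd in rows if meth == "Baseline"}
    gains = {}
    for m, d, meth, cs, bdd in rows:
        if meth == "Baseline":
            continue
        base_cs, base_bdd = baselines[(m, d)]
        # Round to two decimals, the precision used in the table.
        gains[(m, d, meth)] = (round(cs - base_cs, 2), round(bdd - base_bdd, 2))
    return gains


if __name__ == "__main__":
    for key, (d_cs, d_bdd) in gains_over_baseline(ROWS).items():
        print(*key, f"Cityscapes {d_cs:+.2f}", f"BDD {d_bdd:+.2f}")
```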