
Label-Efficient Semantic Segmentation with Diffusion Models

ICLR'2022 [Project page]

Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

This code is based on datasetGAN and guided-diffusion.

Note: use --recurse-submodules when cloning.
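For example, a fresh clone could look like this (the repository URL below is an assumption; substitute the actual one):

```bash
# Clone the repository together with its datasetGAN and guided-diffusion submodules
git clone --recurse-submodules https://github.com/yandex-research/ddpm-segmentation.git

# If the repository was already cloned without submodules, fetch them afterwards
git submodule update --init --recursive
```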

 

Overview

The paper investigates the representations learned by the state-of-the-art DDPMs and shows that they capture high-level semantic information valuable for downstream vision tasks. We design a simple semantic segmentation approach that exploits these representations and outperforms the alternatives in the few-shot operating point.

<div align="center"> <img width="100%" alt="DDPM-based Segmentation" src="https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/figs/new_ddpm_seg_scheme.png"> </div>

 

Updates

3/9/2022:

  1. Improved performance of DDPM-based segmentation by changing:
      Diffusion steps: [50,150,250,350] --> [50,150,250];
      UNet blocks: [6,7,8,9] --> [5,6,7,8,12];
  2. Trained a slightly better DDPM on FFHQ-256;
  3. Added MAE for comparison.

 

Datasets

The evaluation is performed on 6 collected datasets with a few annotated images in the training set: Bedroom-28, FFHQ-34, Cat-15, Horse-21, CelebA-19 and ADE-Bedroom-30. The suffix denotes the number of semantic classes.

datasets.tar.gz (~47Mb)

 

DDPM

Pretrained DDPMs

The models trained on LSUN are adopted from guided-diffusion. The FFHQ-256 model is trained by us using the same model hyperparameters as the LSUN models.

LSUN-Bedroom: lsun_bedroom.pt
FFHQ-256: ffhq.pt (Updated 3/8/2022)
LSUN-Cat: lsun_cat.pt
LSUN-Horse: lsun_horse.pt

Run

  1. Download the datasets:
      bash datasets/download_datasets.sh
  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/ddpm.json
  4. Run: bash scripts/ddpm/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30
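Putting the four steps together for one concrete pair from the lists above (the ffhq checkpoint with the ffhq_34 dataset):

```bash
# 1. Datasets
bash datasets/download_datasets.sh

# 2. DDPM checkpoint (here: FFHQ-256)
bash checkpoints/ddpm/download_checkpoint.sh ffhq

# 3. Make sure the paths in experiments/ffhq_34/ddpm.json point to the downloaded
#    data and checkpoint, then
# 4. Train and evaluate the pixel classifiers
bash scripts/ddpm/train_interpreter.sh ffhq_34
```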

Note: train_interpreter.sh is memory-intensive, since it keeps all training pixel representations in RAM. For example, it requires ~210Gb for 50 training images of 256x256. (See issue)
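As a rough sanity check of that figure, the back-of-the-envelope estimate below assumes float32 storage and a per-pixel feature dimension of roughly 16k; that dimension is an illustrative value implied by the ~210Gb number, and the actual one depends on the chosen UNet blocks and diffusion steps:

```bash
# Back-of-the-envelope RAM estimate for keeping all training pixel features in memory.
# FEATURE_DIM is an assumption chosen to be consistent with the ~210Gb figure above.
IMAGES=50; H=256; W=256; FEATURE_DIM=16000; BYTES_PER_FLOAT=4
echo "$(( IMAGES * H * W * FEATURE_DIM * BYTES_PER_FLOAT / 1024 / 1024 / 1024 )) GiB"
# -> 195 GiB, i.e. roughly the ~210Gb (decimal) reported above
```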

Pretrained pixel classifiers and test predictions are here.

How to improve the performance

 

DatasetDDPM

Synthetic datasets

To download DDPM-produced synthetic datasets (50000 samples, ~7Gb) (updated 3/8/2022):
bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>

Run | Option #1

  1. Download the synthetic dataset:
       bash synthetic-datasets/ddpm/download_synthetic_dataset.sh <dataset_name>
  2. Check paths in experiments/<dataset_name>/datasetDDPM.json
  3. Run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>
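For example, with the pre-generated horse dataset:

```bash
bash synthetic-datasets/ddpm/download_synthetic_dataset.sh horse_21
# check the paths in experiments/horse_21/datasetDDPM.json, then
bash scripts/datasetDDPM/train_deeplab.sh horse_21
```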

Run | Option #2

  1. Download the datasets:
       bash datasets/download_datasets.sh

  2. Download the DDPM checkpoint:
       bash checkpoints/ddpm/download_checkpoint.sh <checkpoint_name>

  3. Check paths in experiments/<dataset_name>/datasetDDPM.json

  4. Train an interpreter on a few DDPM-produced annotated samples:
       bash scripts/datasetDDPM/train_interpreter.sh <dataset_name>

  5. Generate a synthetic dataset:
       bash scripts/datasetDDPM/generate_dataset.sh <dataset_name>
        Please adjust the hyperparameters in this script to the available resources.
        On 8xA100 80Gb, it takes about 12 hours to generate 10000 samples.

  6. Run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>
       One needs to specify the path to the generated data. See comments in the script.

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21
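End to end, Option #2 for the lsun_horse checkpoint and the horse_21 dataset looks like this:

```bash
# 1-2. Real data and DDPM checkpoint
bash datasets/download_datasets.sh
bash checkpoints/ddpm/download_checkpoint.sh lsun_horse

# 3. Check the paths in experiments/horse_21/datasetDDPM.json

# 4. Train an interpreter on the few annotated samples
bash scripts/datasetDDPM/train_interpreter.sh horse_21

# 5. Generate the synthetic dataset (adjust the hyperparameters in the script to
#    the available GPUs; ~12 hours for 10000 samples on 8xA100 80Gb)
bash scripts/datasetDDPM/generate_dataset.sh horse_21

# 6. Train DeepLab on the generated data (set its path in the script, see the comments there)
bash scripts/datasetDDPM/train_deeplab.sh horse_21
```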

 

MAE

Pretrained MAEs

We pretrain MAE models using the official implementation on the LSUN and FFHQ-256 datasets:

LSUN-Bedroom: lsun_bedroom.pth
FFHQ-256: ffhq.pth
LSUN-Cat: lsun_cat.pth
LSUN-Horse: lsun_horse.pth

Training setups:

| Dataset | Backbone | epochs | batch-size | mask-ratio |
| --- | --- | --- | --- | --- |
| LSUN Bedroom | ViT-L-8 | 150 | 1024 | 0.75 |
| LSUN Cat | ViT-L-8 | 200 | 1024 | 0.75 |
| LSUN Horse | ViT-L-8 | 200 | 1024 | 0.75 |
| FFHQ-256 | ViT-L-8 | 400 | 1024 | 0.75 |
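As a rough illustration, the LSUN Bedroom setup above might be launched with the official MAE implementation's main_pretrain.py roughly as follows; this is a hedged sketch, the mae_vit_large_patch8 model name is an assumption (the official repo ships patch-16 variants, so a patch-8 config has to be registered separately), and the data/output paths are placeholders:

```bash
# Hedged sketch: MAE pretraining for the LSUN Bedroom setup
# (ViT-L with patch size 8, 150 epochs, batch size 1024, mask ratio 0.75).
torchrun --nproc_per_node=8 main_pretrain.py \
    --model mae_vit_large_patch8 \
    --mask_ratio 0.75 \
    --epochs 150 \
    --batch_size 64 \
    --accum_iter 2 \
    --data_path /path/to/lsun_bedroom \
    --output_dir ./mae_lsun_bedroom
# effective batch size: 8 GPUs x 64 per GPU x 2 accumulation steps = 1024
```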

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the MAE checkpoint:
       bash checkpoints/mae/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/mae.json
  4. Run: bash scripts/mae/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

 

SwAV

Pretrained SwAVs

We pretrain SwAV models using the official implementation on the LSUN and FFHQ-256 datasets:

| LSUN-Bedroom | FFHQ-256 | LSUN-Cat | LSUN-Horse |
| --- | --- | --- | --- |
| SwAV | SwAV | SwAV | SwAV |
| SwAVw2 | SwAVw2 | SwAVw2 | SwAVw2 |

Training setups:

| Dataset | Backbone | epochs | batch-size | multi-crop | num-prototypes |
| --- | --- | --- | --- | --- | --- |
| LSUN | RN50 | 200 | 1792 | 2x256 + 6x108 | 1000 |
| FFHQ-256 | RN50 | 400 | 2048 | 2x224 + 6x96 | 200 |
| LSUN | RN50w2 | 200 | 1920 | 2x256 + 4x108 | 1000 |
| FFHQ-256 | RN50w2 | 400 | 2048 | 2x224 + 4x96 | 200 |
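Analogously, the LSUN RN50 setup above might correspond to a main_swav.py launch in the official SwAV codebase along these lines; this is a hedged sketch, the min/max crop scales are taken from the official multi-crop scripts rather than from this repo, and the data/output paths are placeholders:

```bash
# Hedged sketch: SwAV pretraining for the LSUN RN50 setup
# (200 epochs, total batch size 1792, multi-crop 2x256 + 6x108, 1000 prototypes).
torchrun --nproc_per_node=8 main_swav.py \
    --arch resnet50 \
    --epochs 200 \
    --batch_size 224 \
    --nmb_crops 2 6 \
    --size_crops 256 108 \
    --min_scale_crops 0.14 0.05 \
    --max_scale_crops 1.0 0.14 \
    --nmb_prototypes 1000 \
    --data_path /path/to/lsun \
    --dump_path ./swav_lsun_rn50
# total batch size: 8 GPUs x 224 per GPU = 1792
```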

Run

  1. Download the datasets:
       bash datasets/download_datasets.sh
  2. Download the SwAV checkpoint:
       bash checkpoints/{swav|swav_w2}/download_checkpoint.sh <checkpoint_name>
  3. Check paths in experiments/<dataset_name>/{swav|swav_w2}.json
  4. Run: bash scripts/{swav|swav_w2}/train_interpreter.sh <dataset_name>

Available checkpoint names: lsun_bedroom, ffhq, lsun_cat, lsun_horse
Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30
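For instance, to evaluate the wider SwAVw2 backbone on the horse dataset:

```bash
bash datasets/download_datasets.sh
bash checkpoints/swav_w2/download_checkpoint.sh lsun_horse
# check the paths in experiments/horse_21/swav_w2.json, then
bash scripts/swav_w2/train_interpreter.sh horse_21
```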

 

DatasetGAN

As opposed to the official implementation, we use more recent StyleGAN2(-ADA) models.

Synthetic datasets

To download GAN-produced synthetic datasets (50000 samples):

bash synthetic-datasets/gan/download_synthetic_dataset.sh <dataset_name>

Run

Since we almost fully adopt the official implementation, we don't provide our reimplementation here. However, one can still reproduce our results:

  1. Download the synthetic dataset:
      bash synthetic-datasets/gan/download_synthetic_dataset.sh <dataset_name>
  2. Change paths in experiments/<dataset_name>/datasetDDPM.json
  3. Change paths and run: bash scripts/datasetDDPM/train_deeplab.sh <dataset_name>

Available dataset names: bedroom_28, ffhq_34, cat_15, horse_21
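For example, for the cat dataset:

```bash
bash synthetic-datasets/gan/download_synthetic_dataset.sh cat_15
# point the paths in experiments/cat_15/datasetDDPM.json to the GAN-produced samples, then
bash scripts/datasetDDPM/train_deeplab.sh cat_15
```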

 

Results

| Method | Bedroom-28 | FFHQ-34 | Cat-15 | Horse-21 | CelebA-19 | ADE-Bedroom-30 |
| --- | --- | --- | --- | --- | --- | --- |
| ALAE | 20.0 ± 1.0 | 48.1 ± 1.3 | -- | -- | 49.7 ± 0.7 | 15.0 ± 0.5 |
| VDVAE | -- | 57.3 ± 1.1 | -- | -- | 54.1 ± 1.0 | -- |
| GAN Inversion | 13.9 ± 0.6 | 51.7 ± 0.8 | 21.4 ± 1.7 | 17.7 ± 0.4 | 51.5 ± 2.3 | 11.1 ± 0.2 |
| GAN Encoder | 22.4 ± 1.6 | 53.9 ± 1.3 | 32.0 ± 1.8 | 26.7 ± 0.7 | 53.9 ± 0.8 | 15.7 ± 0.3 |
| SwAV | 41.0 ± 2.3 | 54.7 ± 1.4 | 44.1 ± 2.1 | 51.7 ± 0.5 | 53.2 ± 1.0 | 30.3 ± 1.5 |
| SwAVw2 | 42.4 ± 1.7 | 56.9 ± 1.3 | 45.1 ± 2.1 | 54.0 ± 0.9 | 52.4 ± 1.3 | 30.6 ± 1.0 |
| MAE | 45.0 ± 2.0 | 58.8 ± 1.1 | 52.4 ± 2.3 | 63.4 ± 1.4 | 57.8 ± 0.4 | 31.7 ± 1.8 |
| DatasetGAN | 31.3 ± 2.7 | 57.0 ± 1.0 | 36.5 ± 2.3 | 45.4 ± 1.4 | -- | -- |
| DatasetDDPM | 47.9 ± 2.9 | 56.0 ± 0.9 | 47.6 ± 1.5 | 60.8 ± 1.0 | -- | -- |
| DDPM | 49.4 ± 1.9 | 59.1 ± 1.4 | 53.7 ± 3.3 | 65.0 ± 0.8 | 59.9 ± 1.0 | 34.6 ± 1.7 |

 

<div> <img width="100%" alt="DDPM-based Segmentation" src="https://storage.yandexcloud.net/yandex-research/ddpm-segmentation/figs/examples.png"> </div>

 

Cite

@misc{baranchuk2021labelefficient,
      title={Label-Efficient Semantic Segmentation with Diffusion Models}, 
      author={Dmitry Baranchuk and Ivan Rubachev and Andrey Voynov and Valentin Khrulkov and Artem Babenko},
      year={2021},
      eprint={2112.03126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}