

Diffusion Model as Representation Learner

This repository contains the official implementation of the ICCV 2023 paper

Diffusion Model as Representation Learner Xingyi Yang, Xinchao Wang

[arxiv] [code]


In this paper, we conduct an in-depth investigation of the representation power of diffusion probabilistic models (DPMs) and propose RepFusion, a knowledge transfer paradigm that leverages the knowledge acquired by generative DPMs for recognition tasks. RepFusion extracts representations at different time steps from off-the-shelf DPMs and dynamically employs them as supervision for student networks, with the optimal time step determined through reinforcement learning.
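
The time-step selection described above can be pictured with a toy policy-gradient loop. The sketch below is not the RepFusion implementation: it assumes a softmax policy over a handful of candidate time steps and a hypothetical `toy_reward` standing in for whatever signal the student's training provides, and it uses plain REINFORCE with a moving-average baseline. All names and values are illustrative assumptions.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def learn_timestep_policy(timesteps, reward_fn, steps=3000, lr=0.1, seed=0):
    """Toy REINFORCE loop: learn a categorical policy over candidate time steps."""
    rng = random.Random(seed)
    logits = [0.0] * len(timesteps)  # start from a uniform policy
    baseline = 0.0  # moving average of rewards, reduces gradient variance
    for _ in range(steps):
        probs = softmax(logits)
        idx = rng.choices(range(len(timesteps)), weights=probs)[0]
        reward = reward_fn(timesteps[idx])
        advantage = reward - baseline
        baseline += 0.05 * (reward - baseline)
        # Gradient of log pi(idx) w.r.t. the logits is one_hot(idx) - probs.
        for k in range(len(logits)):
            indicator = 1.0 if k == idx else 0.0
            logits[k] += lr * advantage * (indicator - probs[k])
    return softmax(logits)


def toy_reward(t):
    """Hypothetical reward that peaks at t = 400 (an arbitrary choice)."""
    return 1.0 - abs(t - 400) / 400.0


candidate_timesteps = [0, 200, 400, 600, 800]  # assumed candidates
policy = learn_timestep_policy(candidate_timesteps, toy_reward)
best = candidate_timesteps[policy.index(max(policy))]
print("learned policy:", [round(p, 3) for p in policy])
print("preferred time step:", best)
```

In a real setting the reward would come from the student's training signal rather than a fixed function, and the policy could depend on the input; this sketch only shows the mechanics of learning a preference over time steps.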

File Organization

This repository contains the code for distillation and for the three downstream tasks: classification, segmentation, and landmark detection.

├── classification_distill/ 
    # code for image classification 
    # and knowledge distillation
    ├── configs/
        ├── <DATASET>-<DISTILL_LOSS>/
            # config file for RepFusion on <DATASET> 
            # with <DISTILL_LOSS> as loss function 
            # and <BACKBONE> as architecture
        ├── baseline/
    ├── mmcls/
        ├── models/
            ├── guided_diffusion/
                # code taken from the guided diffusion repo
            ├── classifiers/
                ├── kd.py
                    # distillation baselines
                ├── repfusion.py
                    # core code for distillation from diffusion model

├── landmark/
    # code for facial landmark detection 
    ├── configs/face/2d_kpt_sview_rgb_img/topdown_heatmap/wflw

├── segmentation/
    # code for face parsing
    ├── configs/
        ├── celebahq_mask/


We mainly depend on four packages:

  1. mmclassification. Please install the environment following INSTALL.
  2. mmsegmentation. Please install the environment following INSTALL.
  3. mmpose. Please install the environment following INSTALL.
  4. diffusers. Install via pip install --upgrade diffusers[torch], or see the official repo for help.
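
Before training, it can help to confirm that all four packages are importable in the active environment. The following is a minimal sketch using only the standard library; the import names (`mmcls`, `mmseg`, `mmpose`, `diffusers`) are assumptions based on how these projects are usually packaged, and the helper `check_dependencies` is hypothetical, not part of this repository.

```python
import importlib.util

# Import names are assumptions based on the usual packaging of each project.
REQUIRED_MODULES = {
    "mmclassification": "mmcls",
    "mmsegmentation": "mmseg",
    "mmpose": "mmpose",
    "diffusers": "diffusers",
}


def check_dependencies(modules=REQUIRED_MODULES):
    """Return {package name: True/False} depending on whether it can be found."""
    status = {}
    for package, module in modules.items():
        # find_spec returns None when a top-level module is not installed.
        status[package] = importlib.util.find_spec(module) is not None
    return status


for package, found in check_dependencies().items():
    print(f"{package:<18} {'found' if found else 'MISSING'}")
```

The check only verifies that each module can be located; it does not validate versions or CUDA support.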

Data Preparation

We use four datasets in our paper. Please put each of them under data/<DATASET>.

  1. CelebAMask-HQ. Please follow the guidelines in the official repo.
  2. WFLW. Please download the images from WFLW Dataset and the annotation files from wflw_annotations.
  3. TinyImageNet. Please download the dataset using this script.
  4. CIFAR10. mmcls will download it automatically.
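
A quick way to verify the data/<DATASET> layout is to list which dataset folders are missing. This is a small sketch under stated assumptions: the folder names in `EXPECTED_DATASETS` are hypothetical placeholders (use the names your configs actually expect), and `missing_datasets` is an illustrative helper, not part of this repository.

```python
from pathlib import Path

# Folder names are hypothetical; use whatever <DATASET> names your configs expect.
EXPECTED_DATASETS = ["CelebAMask-HQ", "WFLW", "TinyImageNet", "CIFAR10"]


def missing_datasets(data_root="data", names=EXPECTED_DATASETS):
    """Return the dataset folders that do not exist under data_root."""
    root = Path(data_root)
    return [name for name in names if not (root / name).is_dir()]


missing = missing_datasets()
if missing:
    print("Missing dataset folders:", ", ".join(missing))
else:
    print("All dataset folders found.")
```

Since CIFAR10 is downloaded automatically, its folder may legitimately be reported as missing before the first run.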

Teacher Checkpoints


# <CONFIG_NAME>: config path for distillation 
# <GPU_NUMS>: num of gpus for training
cd classification_distill
bash tools/dist_train.sh <CONFIG_NAME> <GPU_NUMS>

# In the downstream task config, load the distilled checkpoint:
model = dict(
    ...
    checkpoint=<CHECKPOINT_PATH>,
    # Put the distilled checkpoint here
    ...
)

# <CONFIG_NAME>: config path for the downstream task
# <GPU_NUMS>: num of gpus for training
# <TASK_NAME>: either 'classification_distill', 'segmentation' or 'landmark'
cd <TASK_NAME>  # relative to the repository root
bash tools/dist_train.sh <CONFIG_NAME> <GPU_NUMS>
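
When launching several runs, the same command can be assembled programmatically. The sketch below is an illustrative wrapper, not part of this repository: `build_train_command` is a hypothetical helper, `configs/example.py` is a placeholder path, and the only things taken from the instructions above are the three task directory names and the `tools/dist_train.sh <CONFIG_NAME> <GPU_NUMS>` call.

```python
import shlex

TASKS = ("classification_distill", "segmentation", "landmark")


def build_train_command(task, config, num_gpus):
    """Return (working directory, argv) for one training run."""
    if task not in TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {TASKS}")
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return task, ["bash", "tools/dist_train.sh", config, str(num_gpus)]


# 'configs/example.py' is a placeholder, not a real config in this repository.
cwd, argv = build_train_command("segmentation", "configs/example.py", 4)
print(f"(cd {cwd} && {shlex.join(argv)})")
# To actually launch it: subprocess.run(argv, cwd=cwd, check=True)
```

Returning the working directory separately keeps the command itself identical across tasks, mirroring the cd-then-run pattern used above.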


    author    = {Xingyi Yang and Xinchao Wang},
    title     = {Diffusion Model as Representation Learner},
    journal   = {International Conference on Computer Vision (ICCV)},
    year      = {2023},