<p align="center"> <h1 align="center">Diffusion Prior-Based Amortized Variational Inference for Noisy Inverse Problems</h1> <p align="center">Sojin Lee*, Dogyun Park*, Inho Kong, Hyunwoo J. Kim†. </p> <h2 align="center">ECCV 2024 Oral</h2> <h3 align="center"> <a href="https://mlvlab.github.io/DAVI-project/" target='_blank'><img src="https://img.shields.io/badge/🐳-Project%20Page-blue"></a> <a href="https://www.arxiv.org/pdf/2407.16125" target='_blank'><img src="https://img.shields.io/badge/arXiv-2407.16125-b31b1b.svg"></a> </h3> </p>

This repository contains the official PyTorch implementation of DAVI: Diffusion Prior-Based Amortized Variational Inference for Noisy Inverse Problems accepted at ECCV 2024 as an oral presentation.

Our framework performs efficient posterior sampling with a single evaluation of a neural network and generalizes to both seen and unseen measurements without test-time optimization. We provide five image restoration tasks (Gaussian deblurring, 4x super-resolution, box inpainting, denoising, and colorization) on two benchmark datasets (FFHQ and ImageNet).
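Each task is a noisy inverse problem: recover an image x from a measurement y = A(x) + n, where A is the degradation operator and n is noise, assumed here to be additive Gaussian. The NumPy sketch below illustrates that measurement model. The operators, their parameters, and the task keys (deblur, sr4x, box_inpaint, denoise, colorize) are hypothetical simplifications chosen for illustration, not the degradations implemented in this repository.

```python
import numpy as np

# Illustrative stand-ins for the five tasks, NOT the operators used in this repository.
# Images are float arrays of shape (H, W, C) with values in [0, 1].


def gaussian_blur(x, sigma=1.0, radius=3):
    """Separable Gaussian blur with reflect padding (kernel length 2 * radius + 1)."""
    t = np.arange(-radius, radius + 1, dtype=np.float64)
    kernel = np.exp(-t ** 2 / (2 * sigma ** 2))
    kernel /= kernel.sum()

    def blur_1d(v):
        # Pad so that 'valid' convolution returns a vector of the original length.
        return np.convolve(np.pad(v, radius, mode="reflect"), kernel, mode="valid")

    # Blur along rows (axis 0), then along columns (axis 1), channel by channel.
    return np.apply_along_axis(blur_1d, 1, np.apply_along_axis(blur_1d, 0, x))


def sr_average_pool(x, factor=4):
    """Downsample by averaging non-overlapping factor x factor blocks."""
    h, w, c = x.shape
    return x.reshape(h // factor, factor, w // factor, factor, c).mean(axis=(1, 3))


def box_inpaint(x, top, left, size):
    """Zero out a size x size box (the masked pixels are unobserved)."""
    y = x.copy()
    y[top:top + size, left:left + size, :] = 0.0
    return y


def grayscale(x):
    """Colorization measurement: average the color channels."""
    return x.mean(axis=2, keepdims=True)


# Hypothetical task keys -> degradation operator A
OPERATORS = {
    "deblur": gaussian_blur,
    "sr4x": sr_average_pool,
    "box_inpaint": lambda x: box_inpaint(x, x.shape[0] // 4, x.shape[1] // 4, x.shape[0] // 2),
    "denoise": lambda x: x,  # identity operator: only the noise corrupts x
    "colorize": grayscale,
}


def measure(x, task, noise_sigma, rng):
    """Noisy measurement y = A(x) + n with n ~ N(0, noise_sigma**2 I)."""
    y = OPERATORS[task](x)
    return y + rng.normal(0.0, noise_sigma, size=y.shape)
```

For example, `measure(x, "sr4x", 0.05, np.random.default_rng(0))` turns a 256x256 RGB image into a noisy 64x64 measurement. DAVI's goal is to map such a y back to a sample of x in a single network evaluation.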

<div align="center"> <img src="asset/main.png" width="700px" /> </div>

Setup

Please follow these steps to set up the repository.

1. Clone the Repository

git clone https://github.com/mlvlab/DAVI.git
cd DAVI

2. Install Environment

conda create -n DAVI python==3.8
conda activate DAVI
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install accelerate ema_pytorch matplotlib piq scikit-image pytorch-fid wandb
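After installation, a small standard-library script such as the one below can confirm that the packages are importable inside the activated DAVI environment. It is a convenience sketch, not part of the repository; note that scikit-image and pytorch-fid are imported as skimage and pytorch_fid, so the check maps pip names to module names.

```python
import importlib.util
import sys

# pip package name -> module name used in `import`
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "accelerate": "accelerate",
    "ema_pytorch": "ema_pytorch",
    "matplotlib": "matplotlib",
    "piq": "piq",
    "scikit-image": "skimage",
    "pytorch-fid": "pytorch_fid",
    "wandb": "wandb",
}


def missing_packages(required=REQUIRED):
    """Return the pip names of packages whose module cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


def report():
    print("Python", sys.version.split()[0])
    missing = missing_packages()
    print("Missing packages:", ", ".join(missing) if missing else "none")
    if "torch" not in missing:
        import torch  # imported only when present, so the check never crashes on import
        print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())


if __name__ == "__main__":
    report()
```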

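3. Download Pre-trained Models and Official Checkpoints

We use the pre-trained diffusion models for FFHQ (ffhq_10m.pt) and ImageNet (256x256_diffusion_uncond.pt), released by DPS and guided_diffusion, respectively. Place them, together with the official DAVI checkpoints used for evaluation, as shown in the overall directory layout below.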

4. Prepare Data

For amortized optimization, we use the FFHQ 49K dataset and the ImageNet 130K dataset, which are subsets of the training datasets used for the pre-trained models. These subsets are distinct from the validation datasets (ffhq_1K and imagenet_val_1K) used for evaluation.
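If you assemble these folders yourself, one hedged way to spot accidental overlap between a training subset and its evaluation set is to compare file contents by hash. The sketch below is not the authors' split procedure, which is not described here; it only detects byte-identical image files (copies that were renamed still match, but re-encoded or resized copies do not).

```python
import hashlib
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def image_digests(folder):
    """Map the SHA-256 digest of each image file's bytes to its path (searched recursively)."""
    digests = {}
    for path in sorted(Path(folder).rglob("*")):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            digests[hashlib.sha256(path.read_bytes()).hexdigest()] = path
    return digests


def shared_images(train_dir, eval_dir):
    """Return evaluation-set files whose bytes also appear in the training set."""
    train, evaluation = image_digests(train_dir), image_digests(eval_dir)
    return sorted(str(evaluation[d]) for d in train.keys() & evaluation.keys())


if __name__ == "__main__":
    for train_dir, eval_dir in [("data/ffhq_49K", "data/ffhq_1K"),
                                ("data/imagenet_130K", "data/imagenet_val_1K")]:
        if Path(train_dir).is_dir() and Path(eval_dir).is_dir():
            print(eval_dir, "overlap:", len(shared_images(train_dir, eval_dir)))
```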

Overall directory structure

├── results
│
├── model
│   ├── ffhq_10m.pt                  # FFHQ pre-trained model (for training)
│   ├── 256x256_diffusion_uncond.pt  # ImageNet pre-trained model (for training)
│   └── official_ckpt                # official checkpoints (for evaluation)
│       ├── ffhq
│       │   ├── gaussian_ema.pt
│       │   ├── sr_averagepooling_ema.pt
│       │   └── ...
│       └── imagenet
│           ├── gaussian_ema.pt
│           ├── sr_averagepooling_ema.pt
│           └── ...
│
└── data                             # training and evaluation sets
    ├── ffhq_1K                      # FFHQ evaluation
    ├── imagenet_val_1K              # ImageNet evaluation
    ├── ffhq_49K                     # FFHQ training
    ├── imagenet_130K                # ImageNet training
    └── y_npy
        ├── ffhq_1k_npy
        │   ├── gaussian
        │   ├── sr_averagepooling
        │   └── ...
        └── imagenet_val_1k_npy
            ├── gaussian
            ├── sr_averagepooling
            └── ...
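Before launching anything, a quick check that the expected files and folders are in place can save a failed run. The sketch below lists only the entries named explicitly in the layout (the "..." entries are left out; extend TASKS for other checkpoints). The weights_dir parameter defaults to model because the eval.py and train.py commands in this README pass paths under model/; change it if your weights live elsewhere.

```python
from pathlib import Path

DATASETS = ("ffhq", "imagenet")
TASKS = ("gaussian", "sr_averagepooling")  # extend with the other tasks' checkpoint names


def expected_paths(root=".", weights_dir="model"):
    """Return (files, dirs) implied by the directory layout for the listed tasks."""
    root = Path(root)
    weights = root / weights_dir
    files = [weights / "ffhq_10m.pt", weights / "256x256_diffusion_uncond.pt"]
    files += [weights / "official_ckpt" / ds / f"{task}_ema.pt"
              for ds in DATASETS for task in TASKS]
    dirs = [root / "data" / name
            for name in ("ffhq_1K", "imagenet_val_1K", "ffhq_49K", "imagenet_130K")]
    return files, dirs


def missing_paths(root=".", weights_dir="model"):
    """Return the expected files and directories that do not exist yet."""
    files, dirs = expected_paths(root, weights_dir)
    missing = [f for f in files if not f.is_file()] + [d for d in dirs if not d.is_dir()]
    return [str(p) for p in missing]


if __name__ == "__main__":
    for path in missing_paths():
        print("missing:", path)
```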

Evaluation

1. Restore degraded images

accelerate launch --num_processes=1 eval.py --eval_dir data/ffhq_1K --deg gaussian --perturb_h 0.1 --ckpt model/official_ckpt/ffhq/gaussian_ema.pt
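To run the same evaluation for other datasets or tasks, a small helper can assemble the command from the flags shown above. This is a hypothetical convenience sketch: it assumes the --deg value equals the checkpoint file name without "_ema.pt" (true for the gaussian example, unverified for the other tasks), and it reuses perturb_h=0.1 from the example.

```python
import shlex

EVAL_DIRS = {"ffhq": "data/ffhq_1K", "imagenet": "data/imagenet_val_1K"}


def eval_command(dataset, deg, perturb_h=0.1, weights_dir="model", num_processes=1):
    """Build the eval.py launch command for one dataset/task pair (hypothetical helper)."""
    # Assumption: the checkpoint is named "<deg>_ema.pt", as in the gaussian example.
    ckpt = f"{weights_dir}/official_ckpt/{dataset}/{deg}_ema.pt"
    return ["accelerate", "launch", f"--num_processes={num_processes}", "eval.py",
            "--eval_dir", EVAL_DIRS[dataset], "--deg", deg,
            "--perturb_h", str(perturb_h), "--ckpt", ckpt]


if __name__ == "__main__":
    cmd = eval_command("ffhq", "gaussian")
    # Print the shell form; pass `cmd` to subprocess.run(cmd, check=True) to launch it.
    print(shlex.join(cmd))
```

With the defaults, the printed command is identical to the FFHQ Gaussian-deblurring command above.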

2. Evaluate PSNR, LPIPS, and FID
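The repository's exact evaluation script is not described in this section. As a hedged sketch, PSNR can be computed directly from its definition, 10 * log10(data_range^2 / MSE), assuming restored and reference images are float arrays on the same scale (for example [0, 1], hence data_range=1.0). For the other two metrics, the packages installed during setup provide tools: piq includes an LPIPS implementation (piq.LPIPS), and pytorch-fid computes FID between two image folders via `python -m pytorch_fid <restored_dir> <reference_dir>`. Scores computed this way are not guaranteed to match the paper's protocol.

```python
import numpy as np


def psnr(restored, reference, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    restored = np.asarray(restored, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    if restored.shape != reference.shape:
        raise ValueError(f"shape mismatch: {restored.shape} vs {reference.shape}")
    mse = np.mean((restored - reference) ** 2)
    return float("inf") if mse == 0 else float(10.0 * np.log10(data_range ** 2 / mse))


def mean_psnr(pairs, data_range=1.0):
    """Average PSNR over an iterable of (restored, reference) image pairs."""
    scores = [psnr(r, g, data_range) for r, g in pairs]
    return sum(scores) / len(scores)
```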

Training with Multiple GPUs

accelerate launch --multi_gpu --num_processes=4 train.py --data_dir data/ffhq_49K/ --model_path model/ffhq_10m.pt --deg gaussian --t_ikl 400 --weight_con 0.5 --reg_coeff 0.25 --perturb_h 0.1
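Each training set is paired with the pre-trained model it was drawn from: FFHQ training (data/ffhq_49K/) starts from ffhq_10m.pt, and ImageNet training (data/imagenet_130K/) from 256x256_diffusion_uncond.pt. The hypothetical helper below builds the launch command under that pairing; with --multi_gpu, accelerate starts --num_processes processes for data-parallel training. Its hyperparameter defaults simply copy the FFHQ Gaussian-deblurring example above; this README does not state whether other tasks or ImageNet use the same values.

```python
import shlex

# dataset -> (training data directory, pre-trained diffusion model file)
PRETRAINED = {
    "ffhq": ("data/ffhq_49K/", "ffhq_10m.pt"),
    "imagenet": ("data/imagenet_130K/", "256x256_diffusion_uncond.pt"),
}


def train_command(dataset, deg, num_gpus=4, weights_dir="model",
                  t_ikl=400, weight_con=0.5, reg_coeff=0.25, perturb_h=0.1):
    """Build the train.py launch command (defaults copied from the FFHQ gaussian example)."""
    data_dir, model_file = PRETRAINED[dataset]
    return ["accelerate", "launch", "--multi_gpu", f"--num_processes={num_gpus}", "train.py",
            "--data_dir", data_dir, "--model_path", f"{weights_dir}/{model_file}",
            "--deg", deg, "--t_ikl", str(t_ikl), "--weight_con", str(weight_con),
            "--reg_coeff", str(reg_coeff), "--perturb_h", str(perturb_h)]


if __name__ == "__main__":
    print(shlex.join(train_command("ffhq", "gaussian")))
```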