MIDMs: Matching Interleaved Diffusion Models for Exemplar-based Image Translation

arXiv: https://arxiv.org/abs/2209.11047


Official PyTorch implementation of the AAAI 2023 paper

Junyoung Seo*, Gyuseong Lee*, Seokju Cho, Jiyoung Lee, Seungryong Kim

*Equal contribution


For more information, see the paper on arXiv or the project page.

Preparation

Environment Setup

Clone the Synchronized-BatchNorm-PyTorch repository and copy its sync_batchnorm package into models/:

cd models/
git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
cd ..

Next, download the weights of the VQ autoencoder (f=4, VQ, Z=8192, d=3) here and move model.ckpt and config.yaml to models/vq-f4.
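
As a rough guide to these settings: in LDM's notation, f is the spatial downsampling factor of the autoencoder, Z is the codebook size, and d is the number of latent channels. The minimal sketch below uses hypothetical helpers (not part of this repository) to compute the latent shape for a given input size; the 256×256 input is only an illustrative assumption.

```python
import math


def vq_latent_shape(height, width, f=4, d=3):
    """Return (channels, height, width) of the latent for downsampling factor f
    and d latent channels (hypothetical helper)."""
    if height % f or width % f:
        raise ValueError("input size must be divisible by the downsampling factor f")
    return (d, height // f, width // f)


def bits_per_code(z=8192):
    """Each latent position indexes one of Z codebook entries, i.e. log2(Z) bits."""
    return math.log2(z)


if __name__ == "__main__":
    print(vq_latent_shape(256, 256))  # (3, 64, 64)
    print(bits_per_code())            # 13.0
```

So with f=4, a 256×256 image is compressed to a 64×64 latent grid with 3 channels.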

Then install the dependencies:

conda env create -f environment.yml
conda activate midms

If you already have an environment set up for LDM or Stable Diffusion, you can use it instead.

Pretrained Models

We provide a model fine-tuned on CelebA-HQ (edge-to-face). Download the weights here.

Place the weights as follows:

weights
└── celeba
    ├── midms_celebA_finetuned.pth
    └── pretrained
        ├── config.yaml
        └── model.ckpt
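
Before running inference or training, it can help to confirm that every file from the steps above is in place. The following is a minimal sketch (a hypothetical helper, not part of this repository). It assumes you run it from the repository root and checks only the paths named in this README.

```python
from pathlib import Path

# Files this README asks you to place, relative to the repository root.
REQUIRED_FILES = (
    "models/vq-f4/model.ckpt",
    "models/vq-f4/config.yaml",
    "weights/celeba/midms_celebA_finetuned.pth",
    "weights/celeba/pretrained/config.yaml",
    "weights/celeba/pretrained/model.ckpt",
)


def missing_files(root=".", required=REQUIRED_FILES):
    """Return the required paths that are not regular files under root (hypothetical helper)."""
    root = Path(root)
    return [rel for rel in required if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing files:")
        for rel in missing:
            print("  " + rel)
    else:
        print("All expected weight files are in place.")
```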

Datasets

We use the training and validation sets provided by CoCosNet, which can be downloaded here.

Inference

Prepare the validation dataset as specified above, and run inference.py, e.g.:

python inference.py --benchmark celebahqedge --inference_mode target_fixed --pick 11

where --pick is the index of the condition image (e.g., a sketch). To evaluate the model on the validation set, change --inference_mode from target_fixed to evaluation.
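
To run inference.py on several condition images, or to switch to evaluation, you can assemble the arguments in a script. This is a hypothetical helper, not part of the repository, and it uses only the flags shown above. This README does not say whether evaluation mode also expects --pick, so the sketch appends it only when you pass one. To execute a command, pass the returned list to subprocess.run.

```python
import shlex


def inference_command(benchmark="celebahqedge", inference_mode="target_fixed", pick=None):
    """Return the inference.py argument list (hypothetical helper).

    pick is the index of the condition image; it is only appended when given.
    """
    cmd = ["python", "inference.py",
           "--benchmark", benchmark,
           "--inference_mode", inference_mode]
    if pick is not None:
        cmd += ["--pick", str(pick)]
    return cmd


if __name__ == "__main__":
    # Translate a few condition images one after another (indices are examples).
    for idx in (11, 12, 13):
        print(shlex.join(inference_command(pick=idx)))
    # Evaluation over the validation set.
    print(shlex.join(inference_command(inference_mode="evaluation")))
```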

Training

Before fine-tuning MIDMs, first pretrain an LDM on the desired dataset by following the instructions here, or obtain pretrained weights from the model zoo.

A pretrained VGG model is also required. Download it from the Training section of the CoCosNet repository and move it to models/. We used 8 NVIDIA RTX 3090 GPUs for fine-tuning, which took roughly 5 to 12 hours per dataset.

Run train.py as follows:

torchrun --standalone --nproc_per_node=<NUM_GPU> train.py \
    --benchmark celebahqedge \
    --diffusion_config_path "weights/celeba/pretrained/config.yaml" \
    --diffusion_model_path "weights/celeba/pretrained/model.ckpt" \
    --phase e2e_recurrent --dataroot "/downloaded/dataset/folder" \
    --batch-size <BATCH_SIZE> \
    --snapshots "/path/to/save/results" --warmup_iter 10000
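
If you launch several runs, for example to try different warm-up lengths (the tip below notes they matter), it can help to build the torchrun command programmatically. The sketch below is a hypothetical helper, not part of this repository. It reproduces the flags from the command above, and the GPU count, batch size, warm-up values, and paths in the example are placeholders. This README does not say whether --batch-size is per GPU or global, so check train.py before choosing a value.

```python
import shlex


def train_command(num_gpu, batch_size, dataroot, snapshots,
                  benchmark="celebahqedge",
                  weights_dir="weights/celeba/pretrained",
                  warmup_iter=10000):
    """Return the torchrun argument list for train.py (hypothetical helper).

    The flags mirror the example command in this README; the defaults are the
    values used there.
    """
    return [
        "torchrun", "--standalone", f"--nproc_per_node={num_gpu}", "train.py",
        "--benchmark", benchmark,
        "--diffusion_config_path", f"{weights_dir}/config.yaml",
        "--diffusion_model_path", f"{weights_dir}/model.ckpt",
        "--phase", "e2e_recurrent",
        "--dataroot", dataroot,
        "--batch-size", str(batch_size),
        "--snapshots", snapshots,
        "--warmup_iter", str(warmup_iter),
    ]


if __name__ == "__main__":
    # One command per warm-up setting; all values here are placeholders.
    for warmup in (5000, 10000):
        cmd = train_command(8, 4, "/downloaded/dataset/folder",
                            f"/path/to/save/results/warmup_{warmup}",
                            warmup_iter=warmup)
        print(shlex.join(cmd))
```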

Tips

We found that the number of warm-up iterations and the number of training epochs are important when fine-tuning: training for too long can cause the model to collapse. Adjusting the scaling factors of the perceptual loss and the style loss can also ease the trade-off between the two objectives. Finally, the training code is still being cleaned up; if you encounter any errors or difficulties, please feel free to contact us.
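
The tip about scaling factors amounts to choosing the weights of a weighted sum of loss terms. The minimal sketch below assumes the common formulation: mean squared error between feature maps for the perceptual term, and between normalized Gram matrices for the style term, computed on toy feature lists. The exact losses, VGG layers, and weight names used in MIDMs may differ, and `w_perc` and `w_style` are hypothetical names.

```python
def gram(features):
    """Gram matrix of C feature channels, each a flat list of N activations, normalized by N."""
    n = len(features[0])
    return [[sum(a * b for a, b in zip(fi, fj)) / n for fj in features] for fi in features]


def mse(xs, ys):
    """Mean squared error between two equally sized flat lists."""
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def flatten(matrix):
    return [v for row in matrix for v in row]


def perceptual_loss(feat_out, feat_ref):
    # Compare feature activations position by position.
    return mse(flatten(feat_out), flatten(feat_ref))


def style_loss(feat_out, feat_exemplar):
    # Compare channel correlations, which ignore spatial layout.
    return mse(flatten(gram(feat_out)), flatten(gram(feat_exemplar)))


def total_loss(feat_out, feat_ref, feat_exemplar, w_perc=1.0, w_style=1.0):
    """Weighted sum of the two terms; the weights set the balance between them."""
    return (w_perc * perceptual_loss(feat_out, feat_ref)
            + w_style * style_loss(feat_out, feat_exemplar))


if __name__ == "__main__":
    out = [[1.0, 2.0], [0.0, 1.0]]       # 2 channels x 2 positions (toy values)
    ref = [[1.0, 2.0], [0.0, 1.0]]
    exemplar = [[2.0, 0.0], [0.0, 0.0]]
    print(total_loss(out, ref, exemplar, w_style=2.0))  # 1.25
```

Raising one weight relative to the other shifts how strongly the output is pushed toward each target.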

Acknowledgement

This implementation borrows heavily from the official implementations of LDM and CoCosNet. We are deeply grateful to both projects.

BibTeX

@article{seo2022midms,
  title={MIDMs: Matching Interleaved Diffusion Models for Exemplar-based Image Translation},
  author={Seo, Junyoung and Lee, Gyuseong and Cho, Seokju and Lee, Jiyoung and Kim, Seungryong},
  journal={arXiv preprint arXiv:2209.11047},
  year={2022}
}