<p align="center">Differentially Private Diffusion Models <br><br> TMLR 2023</p>

<div align="center"> <a href="https://timudk.github.io/" target="_blank">Tim&nbsp;Dockhorn</a> &emsp; <b>&middot;</b> &emsp; <a href="https://orcid.org/0000-0001-6579-6044" target="_blank">Tianshi&nbsp;Cao</a> &emsp; <b>&middot;</b> &emsp; <a href="http://latentspace.cc/" target="_blank">Arash&nbsp;Vahdat</a> &emsp; <b>&middot;</b> &emsp; <a href="https://karstenkreis.github.io/" target="_blank">Karsten&nbsp;Kreis</a> <br> <br> <a href="https://arxiv.org/abs/2210.09929" target="_blank">Paper</a> &emsp; <a href="https://nv-tlabs.github.io/DPDM/" target="_blank">Project&nbsp;Page</a> </div> <br><br>

## Requirements

DPDM is built using PyTorch 1.11.0 and CUDA 11.3. Please use the following command to install the requirements:

```shell
pip install -r requirements.txt
```
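If your environment does not already provide PyTorch 1.11.0 with CUDA 11.3, one way to install matching wheels (assuming `pip` and the official PyTorch wheel index) is:

```shell
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
```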

## Pretrained checkpoints

We provide pre-trained checkpoints for all models presented in the paper here. You can sample from the checkpoint using the following command:

```shell
python main.py --mode eval --workdir <new_directory> --config <config_file> -- model.ckpt=<checkpoint_path>
```

| Dataset | Resolution | Epsilon | Config | Checkpoint | Class-conditional |
| --- | --- | --- | --- | --- | --- |
| MNIST | 28 | 0.2 | `configs/mnist_28/sample_eps_0.2.yaml` | `mnist_v_0.2.pth` | Yes |
| MNIST | 28 | 1.0 | `configs/mnist_28/sample_eps_1.0.yaml` | `mnist_edm_1.0.pth` | Yes |
| MNIST | 28 | 10.0 | `configs/mnist_28/sample_eps_10.0.yaml` | `mnist_edm_10.0.pth` | Yes |
| Fashion-MNIST | 28 | 0.2 | `configs/fmnist_28/sample_eps_0.2.yaml` | `fmnist_v_0.2.pth` | Yes |
| Fashion-MNIST | 28 | 1.0 | `configs/fmnist_28/sample_eps_1.0.yaml` | `fmnist_edm_1.0.pth` | Yes |
| Fashion-MNIST | 28 | 10.0 | `configs/fmnist_28/sample_eps_10.0.yaml` | `fmnist_edm_10.0.pth` | Yes |
| CelebA | 32 | 1.0 | `configs/celeba_32/sample_eps_1.0.yaml` | `celeba_edm_1.0.pth` | No |
| CelebA | 32 | 10.0 | `configs/celeba_32/sample_eps_10.0.yaml` | `celeba_edm_10.0.pth` | No |
| CelebA | 64 | 10.0 | `configs/celeba_64/sample_eps_10.0.yaml` | `celeba64_edm_10.0.pth` | No |
| CIFAR-10 | 32 | 10.0 | `configs/cifar10_32/sample_eps_10.0.yaml` | `cifar10_edm_10.0.pth` | Yes |
| ImageNet | 32 | 10.0 | `configs/imagenet_32/sample_eps_10.0.yaml` | `imagenet_edm_10.0.pth` | Yes |

By default, the above command generates 16 samples using DDIM with 50 steps on a single GPU. You can modify any entry of the config file directly through the command line; for example, append `sampler.num_steps=20` to the above command to change the number of DDIM steps from the default of 50 to 20.

For class-conditional sampling, you can set `sampler.labels` to a single integer between 0 and `num_classes - 1` (inclusive) to generate all samples from that class. By default, the configs set `sampler.labels=num_classes`, which samples classes at random.
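For example, the following draws class-conditional MNIST samples of class 3 using 20 DDIM steps; the workdir and checkpoint paths are illustrative:

```shell
python main.py --mode eval --workdir runs/mnist_eval --config configs/mnist_28/sample_eps_1.0.yaml -- model.ckpt=checkpoints/mnist_edm_1.0.pth sampler.labels=3 sampler.num_steps=20
```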

All class-conditional models are trained for classifier-free guidance by dropping out the label 10% of the time. By default, the guidance strength is set to 0, but you may change it by setting `sampler.guid_scale`.
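For intuition, here is a minimal sketch of how classifier-free guidance combines the two predictions (the function and argument names are hypothetical, not this repository's API):

```python
def guided_prediction(denoiser, x, sigma, label, null_label, guid_scale):
    # Classifier-free guidance: extrapolate from the unconditional prediction
    # toward the conditional one; guid_scale = 0 recovers the conditional model.
    pred_cond = denoiser(x, sigma, label)         # label-conditioned prediction
    pred_uncond = denoiser(x, sigma, null_label)  # label dropped out
    return pred_cond + guid_scale * (pred_cond - pred_uncond)
```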

To reproduce the results in the paper, you may need to adjust the sampler settings. For example, to obtain the best FID for MNIST at eps=1.0, use the following sampler settings (see Tables 13 and 14 in the paper):

```shell
python main.py --mode eval --workdir <new_directory> --config configs/mnist_28/sample_eps_1.0.yaml -- model.ckpt=<checkpoint_path> sampler.type=edm sampler.s_churn=100. sampler.s_min=0.05 sampler.s_max=50. sampler.num_steps=1000
```
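For intuition, `sampler.s_churn`, `sampler.s_min`, and `sampler.s_max` control the stochastic "churn" substep of the EDM sampler (Karras et al., 2022). A minimal sketch of that substep, following the standard EDM formulation rather than this repository's exact implementation:

```python
import math
import torch

def churn_substep(x, sigma, num_steps, s_churn, s_min, s_max, s_noise=1.0):
    # Temporarily raise the noise level from sigma to sigma_hat by injecting
    # fresh Gaussian noise, but only while s_min <= sigma <= s_max.
    gamma = min(s_churn / num_steps, math.sqrt(2.0) - 1.0) if s_min <= sigma <= s_max else 0.0
    sigma_hat = sigma * (1.0 + gamma)
    x_hat = x + math.sqrt(max(sigma_hat**2 - sigma**2, 0.0)) * s_noise * torch.randn_like(x)
    return x_hat, sigma_hat  # the deterministic update then continues from (x_hat, sigma_hat)
```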

## Training your own models

### Data preparations

First, create the following two folders:

```shell
mkdir -p data/raw/
mkdir -p data/processed/
```

Afterwards, run the following commands to download and prepare the data used for training. MNIST and Fashion-MNIST are downloaded and processed automatically.

<details><summary>CIFAR-10</summary>

```shell
wget -P data/raw/ https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
python dataset_tool.py --source data/raw/cifar-10-python.tar.gz --dest data/processed/cifar10.zip
```

</details>

<details><summary>ImageNet</summary>

First download the ImageNet Object Localization Challenge data, then run the following command:

```shell
python dataset_tool.py --source=data/raw/imagenet/ILSVRC/Data/CLS-LOC/train --dest=data/processed/imagenet.zip --resolution=32x32 --transform=center-crop
```

</details>
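Roughly, `--transform=center-crop --resolution=32x32` crops the largest centered square and downsamples it. A minimal Pillow sketch of that operation (`dataset_tool.py` itself may differ in details such as the resampling filter):

```python
from PIL import Image

def center_crop_resize(path, res=32):
    # Crop the largest centered square, then resize it to res x res.
    img = Image.open(path).convert("RGB")
    s = min(img.size)
    left, top = (img.width - s) // 2, (img.height - s) // 2
    return img.crop((left, top, left + s, top + s)).resize((res, res), Image.LANCZOS)
```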

### FID evaluation

Before training, you should compute FID statistics of the data:

```shell
python precompute_data_mnist_fid_statistics.py
python precompute_data_mnist_fid_statistics.py --test
python precompute_data_mnist_fid_statistics.py --is_fmnist
python precompute_data_mnist_fid_statistics.py --is_fmnist --test
python compute_fid_statistics.py --path data/processed/cifar10.zip --file cifar10.npz
python compute_fid_statistics.py --path data/processed/imagenet.zip --file imagenet.npz
python compute_fid_statistics.py --path data/processed/celeba.zip --file celeba.npz
```
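For reference, FID compares Gaussian fits (mean and covariance of Inception features) of generated and real samples. A minimal sketch of the distance itself; the `mu`/`sigma` array names inside the precomputed `.npz` files are an assumption:

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2).real  # drop tiny imaginary parts
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

# Hypothetical usage with precomputed statistics:
# stats = np.load("cifar10.npz"); mu_data, sigma_data = stats["mu"], stats["sigma"]
```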

### DPDM training

We provide configurations to reproduce our models here. You may use the following command for training:

```shell
python main.py --mode train --workdir <new_directory> --config <config_file>
```

Our models are trained on eight NVIDIA A100 (80GB) GPUs; you can set the number of GPUs per node via `setup.n_gpus_per_node`. We use large batch sizes, which are split over multiple iterations; to reduce the required GPU memory, you may increase the flag `train.n_splits` (by a multiple of 2).
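Conceptually, `train.n_splits` works like gradient accumulation: each large batch is divided into chunks whose gradients are accumulated before a single optimizer step. A minimal sketch under that assumption; the names are illustrative, and the actual DPDM loop additionally performs DP-SGD per-example gradient clipping and noising:

```python
def train_step(model, optimizer, loss_fn, batch, n_splits):
    # Split one large batch (here, a single tensor) into n_splits chunks
    # to bound peak GPU memory; gradients accumulate across the chunks.
    optimizer.zero_grad()
    for micro_batch in batch.chunk(n_splits):
        loss = loss_fn(model, micro_batch) / n_splits  # keep the effective loss scale
        loss.backward()
    optimizer.step()  # one update for the full batch
```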

## Citation

If you find the provided code or checkpoints useful for your research, please consider citing our paper:

```bibtex
@article{dockhorn2022differentially,
    title={{Differentially Private Diffusion Models}},
    author={Tim Dockhorn and Tianshi Cao and Arash Vahdat and Karsten Kreis},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2023},
    url={https://openreview.net/forum?id=ZPpQk7FJXF}
}
```

## License

Copyright © 2023, NVIDIA Corporation. All rights reserved.

The code of this work is made available under the NVIDIA Source Code License. Please see our main LICENSE file.

Pre-trained checkpoints are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

### License Dependencies

For any code dependencies related to StyleGAN3 (`stylegan3/`, `torch_utils/`, and `dnnlib/`), the license is the NVIDIA Source Code License by NVIDIA Corporation; see the StyleGAN3 LICENSE.

Our DPDMs are based on the DDPM++ architecture, which is released under the Apache License 2.0.