Learning Energy-Based Prior Model with Diffusion-Amortized MCMC

<img src="teaser.png" alt="teaser" width="50%" /> <img src="toy_example.png" alt="toy_example" width="50%" />

[Paper] [Code]

The official code repository for the NeurIPS 2023 paper "Learning Energy-Based Prior Model with Diffusion-Amortized MCMC".

Installation

The implementation depends on the following commonly used packages, all of which can be installed via conda.

| Package | Version |
| --- | --- |
| PyTorch | 1.10.0 |
| pytorch-fid | 0.2.1 |
| pytorch-fid-wrapper | 0.0.4 |
| numpy | 1.21.0 |

Please refer to this repo if you're having trouble installing pytorch-fid-wrapper.
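
To confirm that an environment matches the versions listed above, installed versions can be compared against the pinned ones. The sketch below is not part of the repository: `check_versions` and `installed_version` are hypothetical helper names, and the distribution names `torch`, `pytorch-fid`, `pytorch-fid-wrapper`, and `numpy` are assumptions about how the packages are registered in your environment.

```python
from importlib import metadata

# Versions taken from the list above; the distribution names are assumptions.
REQUIRED = {
    "torch": "1.10.0",
    "pytorch-fid": "0.2.1",
    "pytorch-fid-wrapper": "0.0.4",
    "numpy": "1.21.0",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_versions(required=REQUIRED, lookup=installed_version):
    """Map each missing or mismatched package to a (required, found) pair."""
    mismatches = {}
    for name, wanted in required.items():
        found = lookup(name)
        # Drop local build tags such as "+cu113" before comparing.
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in check_versions().items():
        print(f"{name}: required {wanted}, found {found}")
```

An empty result means every listed package is present at the listed version; anything printed is a package to inspect first if a script fails to import.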

Datasets and Pre-trained Weights

Pre-trained models are available at: https://drive.google.com/drive/folders/18UT4u4vco5TaEJx3HqksXyKP5l_jovUU?usp=sharing

Training

Image Reconstruction and Generation

# Under the root folder
CUDA_VISIBLE_DEVICES=<GPU_ID> python train_gen_recon.py --dataset <DATASET_ALIAS> --seed <RANDOM_SEED> --log_path <PATH_FOR_TRAINED_WEIGHTS_AND_VIS> --data_path <PATH_TO_DATASETS>

Use the log_path argument to choose where the trained weights and visualization results are saved. Available dataset aliases are svhn, cifar10, celeba64, and celebaHQ. data_path indicates the dataset location; L48-107 of train_gen_recon.py give more details on how to set up the data_path argument. Other available arguments are listed at L352-405 of train_gen_recon.py.

Anomaly Detection

# Under the root folder
CUDA_VISIBLE_DEVICES=<GPU_ID> python train_anomaly_det.py --seed <RANDOM_SEED> --label <HELDOUT_DIGIT> --log_path <PATH_FOR_TRAINED_WEIGHTS_AND_VIS> --data_path <PATH_TO_DATASETS>

The label argument indicates the held-out digit in the MNIST dataset used for anomaly detection. Available options are 1, 4, 5, 7, and 9. data_path indicates the dataset location; L58-62 of train_anomaly_det.py give more details on how to set up the data_path argument.

Running these training scripts will automatically create the folders for the trained weights and other intermediate results in the log_path.

Evaluation

To evaluate the pre-trained weights, use the following scripts.

Image Reconstruction and Generation

# Under the root folder
CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_gen_recon.py --dataset svhn --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --e_l_step_size 0.4 --g_llhd_sigma 0.1

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_gen_recon.py --dataset cifar10 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --e_l_step_size 1.6 --g_llhd_sigma 0.1

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_gen_recon.py --dataset celeba64 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --e_l_step_size 0.4 --g_llhd_sigma 0.1

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_gen_recon.py --dataset celebaHQ --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --e_l_step_size 0.4 --g_llhd_sigma 1.0
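
The four commands above differ only in the dataset alias and two hyperparameters. A minimal sketch that keeps those values in one place and assembles the argument list could look like the following. `EVAL_SETTINGS`, `build_eval_command`, and `run_eval` are hypothetical names that are not part of the repository; the flags and values are copied from the commands above.

```python
import os
import subprocess

# (e_l_step_size, g_llhd_sigma) per dataset alias, copied from the commands above.
EVAL_SETTINGS = {
    "svhn": (0.4, 0.1),
    "cifar10": (1.6, 0.1),
    "celeba64": (0.4, 0.1),
    "celebaHQ": (0.4, 1.0),
}


def build_eval_command(dataset, resume_path, data_path):
    """Assemble the eval_gen_recon.py argument list for one dataset alias."""
    step_size, sigma = EVAL_SETTINGS[dataset]
    return [
        "python", "eval_gen_recon.py",
        "--dataset", dataset,
        "--resume_path", resume_path,
        "--data_path", data_path,
        "--e_l_step_size", str(step_size),
        "--g_llhd_sigma", str(sigma),
    ]


def run_eval(dataset, resume_path, data_path, gpu_id="0"):
    """Run the evaluation from the repository root on the chosen GPU."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
    cmd = build_eval_command(dataset, resume_path, data_path)
    return subprocess.run(cmd, env=env, check=True)
```

Calling `run_eval` once per key of `EVAL_SETTINGS` would reproduce the four commands, provided the weights and data paths are filled in for each dataset.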

Anomaly Detection

# Under the root folder
CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_anomaly_det.py --label 1 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --g_llhd_sigma 0.1

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_anomaly_det.py --label 4 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --g_llhd_sigma 1.0

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_anomaly_det.py --label 5 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --g_llhd_sigma 1.0

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_anomaly_det.py --label 7 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --g_llhd_sigma 1.0

CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_anomaly_det.py --label 9 --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --g_llhd_sigma 1.0

StyleGAN Inversion

# Under the root folder
CUDA_VISIBLE_DEVICES=<GPU_ID> python eval_stylegan_inv.py --dataset <DATASET_ALIAS> --resume_path <PATH_TO_TRAINED_WEIGHTS> --data_path <PATH_TO_DATASETS> --pretrained_G_path <TO_SPECIFY> --pretrained_E_path <TO_SPECIFY> --pretrained_F_path <TO_SPECIFY>

For StyleGAN inversion, pretrained_G_path is the path to the pre-trained generator weights, pretrained_E_path is the path to the encoder weights, and pretrained_F_path is the path to the VGG model used for the perceptual loss. Available dataset aliases are ffhq and lsun_tower.

Toy Example

Run the Code

To train the model on the toy example, one can run the following command in the toy_example folder.

CUDA_VISIBLE_DEVICES=<DEVICE_ID> python toy_example.py --seed <RANDOM_SEED_TO_SPECIFY> 

Here, the --seed argument specifies the random seed, which determines the ground-truth posterior distribution. The script will automatically create a logs/toy/<TIMESTAMP> folder inside the toy_example folder, where <TIMESTAMP> is the time at which the training process started.

Important Tips about Training

For most random seeds, we observed that the learned sampler achieves a decent approximation of the ground-truth posterior distribution, obtained by long-run Langevin dynamics, within 300-3000 training iterations. This takes from several minutes to about an hour on an NVIDIA RTX A6000 GPU, and the training process uses ~2GB of GPU memory. In some extreme cases, more training iterations may be needed to produce decent results.

For some random seeds, the default 1000-step Langevin dynamics used to sample the ground-truth posterior distribution might not converge. One possible sign of this is that g_loss (avg) Q (the reconstruction error obtained with learned posterior samples) is significantly lower than g_loss (avg) L (the reconstruction error obtained with Langevin dynamics samples). In that case, consider using 2000 or more steps by modifying the g_l_steps argument of the sample_langevin_post_z function at L277 of toy_example/toy_example.py.
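
To illustrate why the number of Langevin steps matters, here is a minimal NumPy sketch of unadjusted Langevin dynamics on a one-dimensional Gaussian target. It is not the repository's sample_langevin_post_z: the target distribution, the step size, and the names `langevin_sample` and `grad_log_p` are assumptions chosen only to make the effect easy to see.

```python
import numpy as np


def langevin_sample(grad_log_p, z0, n_steps, step_size, rng):
    """Unadjusted Langevin dynamics: z <- z + (s^2 / 2) * grad log p(z) + s * noise."""
    z = np.array(z0, dtype=float)
    for _ in range(n_steps):
        noise = rng.standard_normal(z.shape)
        z = z + 0.5 * step_size ** 2 * grad_log_p(z) + step_size * noise
    return z


# Toy target (an assumption for illustration): a Gaussian with mean 2 and unit variance.
TARGET_MEAN = 2.0


def grad_log_p(z):
    """Gradient of the log-density of the toy Gaussian target."""
    return -(z - TARGET_MEAN)


rng = np.random.default_rng(0)
z0 = rng.standard_normal(2000)  # 2000 independent chains started from N(0, 1)

short_run = langevin_sample(grad_log_p, z0, n_steps=100, step_size=0.1, rng=rng)
long_run = langevin_sample(grad_log_p, z0, n_steps=1000, step_size=0.1, rng=rng)

# The short run is still biased toward the initialization;
# the long run sits close to the target mean.
print(f"mean after 100 steps:  {short_run.mean():.2f}")
print(f"mean after 1000 steps: {long_run.mean():.2f}")
```

With this step size, 100 steps leave the chains well short of the target mean, while 1000 steps bring them close to it. Samples from an unconverged chain are a poor reference, which is the situation in which raising the step count (g_l_steps in the toy example) helps.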

Citation

@article{yu2023learning,
  title={Learning Energy-Based Prior Model with Diffusion-Amortized MCMC},
  author={Yu, Peiyu and Zhu, Yaxuan and Xie, Sirui and Ma, Xiaojian and Gao, Ruiqi and Zhu, Song-Chun and Wu, Ying Nian},
  journal={arXiv preprint arXiv:2310.03218},
  year={2023}
}