
<p align="center"> <a href="https://pypi.org/project/pythae/"> <img src='https://badge.fury.io/py/pythae.svg' alt='PyPI version' /> </a> <a> <img src='https://img.shields.io/badge/python-3.7%7C3.8%7C3.9%2B-blueviolet' alt='Python' /> </a> <a href='https://pythae.readthedocs.io/en/latest/?badge=latest'> <img src='https://readthedocs.org/projects/pythae/badge/?version=latest' alt='Documentation Status' /> </a> <a href='https://opensource.org/licenses/Apache-2.0'> <img src='https://img.shields.io/github/license/clementchadebec/benchmark_VAE?color=blue' alt='License' /> </a><br> <a> <img src='https://img.shields.io/badge/code%20style-black-black' alt='Code style: black' /> </a> <a href="https://codecov.io/gh/clementchadebec/benchmark_VAE"> <img src="https://codecov.io/gh/clementchadebec/benchmark_VAE/branch/main/graph/badge.svg?token=KEM7KKISXJ" alt='Coverage' /> </a> <a href="https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/overview_notebook.ipynb"> <img src="https://colab.research.google.com/assets/colab-badge.svg" alt='Open In Colab' /> </a> </p> <p align="center"> <a href="https://pythae.readthedocs.io/en/latest/">Documentation</a> </p>

pythae

This library implements some of the most common (Variational) Autoencoder models under a unified implementation. In particular, it provides the ability to run benchmark experiments and comparisons by training the models with the same autoencoding neural network architecture. The "make your own autoencoder" feature allows you to train any of these models with your own data and your own encoder and decoder neural networks. The library integrates experiment monitoring tools such as wandb, mlflow and comet-ml 🧪, and allows model sharing and loading from the HuggingFace Hub 🤗 in a few lines of code.

News 📢

As of v0.1.0, Pythae now supports distributed training using PyTorch's DDP. You can train your favorite VAE faster and on larger datasets, still in just a few lines of code. See our speed-up benchmark.


Installation

To install the latest stable release of this library, run the following using pip:

$ pip install pythae

To install the latest GitHub version of this library, run the following using pip:

$ pip install git+https://github.com/clementchadebec/benchmark_VAE.git

or alternatively you can clone the GitHub repo to access the tests, tutorials and scripts:

$ git clone https://github.com/clementchadebec/benchmark_VAE.git

and install the library:

$ cd benchmark_VAE
$ pip install -e .

Available Models

Below is the list of the models currently implemented in the library.

| Models | Training example | Paper | Official Implementation |
|:---|:---:|:---:|:---:|
| Autoencoder (AE) | Open In Colab | | |
| Variational Autoencoder (VAE) | Open In Colab | link | |
| Beta Variational Autoencoder (BetaVAE) | Open In Colab | link | |
| VAE with Linear Normalizing Flows (VAE_LinNF) | Open In Colab | link | |
| VAE with Inverse Autoregressive Flows (VAE_IAF) | Open In Colab | link | link |
| Disentangled Beta Variational Autoencoder (DisentangledBetaVAE) | Open In Colab | link | |
| Disentangling by Factorising (FactorVAE) | Open In Colab | link | |
| Beta-TC-VAE (BetaTCVAE) | Open In Colab | link | link |
| Importance Weighted Autoencoder (IWAE) | Open In Colab | link | link |
| Multiply Importance Weighted Autoencoder (MIWAE) | Open In Colab | link | |
| Partially Importance Weighted Autoencoder (PIWAE) | Open In Colab | link | |
| Combination Importance Weighted Autoencoder (CIWAE) | Open In Colab | link | |
| VAE with perceptual metric similarity (MSSSIM_VAE) | Open In Colab | link | |
| Wasserstein Autoencoder (WAE) | Open In Colab | link | link |
| Info Variational Autoencoder (INFOVAE_MMD) | Open In Colab | link | |
| VAMP Autoencoder (VAMP) | Open In Colab | link | link |
| Hyperspherical VAE (SVAE) | Open In Colab | link | link |
| Poincaré Disk VAE (PoincareVAE) | Open In Colab | link | link |
| Adversarial Autoencoder (Adversarial_AE) | Open In Colab | link | |
| Variational Autoencoder GAN (VAEGAN) | Open In Colab | link | link |
| Vector Quantized VAE (VQVAE) | Open In Colab | link | link |
| Hamiltonian VAE (HVAE) | Open In Colab | link | link |
| Regularized AE with L2 decoder param (RAE_L2) | Open In Colab | link | link |
| Regularized AE with gradient penalty (RAE_GP) | Open In Colab | link | link |
| Riemannian Hamiltonian VAE (RHVAE) | Open In Colab | link | link |
| Hierarchical Residual Quantization (HRQVAE) | Open In Colab | link | link |

See reconstruction and generation results for all aforementioned models in the Results section below.
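All models are built and trained through the same interface: each model class comes with a paired configuration class, so switching models is just a matter of swapping both classes. Below is a minimal sketch, assuming a BetaVAE whose model-specific hyperparameter is beta.

>>> from pythae.models import BetaVAE, BetaVAEConfig
>>> # Each model pairs with a config class: here BetaVAE/BetaVAEConfig instead of VAE/VAEConfig
>>> my_beta_vae_config = BetaVAEConfig(
...	input_dim=(1, 28, 28),
...	latent_dim=10,
...	beta=5. # model-specific hyperparameter (assumption: default value works for your data)
... )
>>> my_beta_vae_model = BetaVAE(model_config=my_beta_vae_config)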

Available Samplers

Below is the list of the samplers currently implemented in the library.

| Samplers | Models | Paper | Official Implementation |
|:---|:---:|:---:|:---:|
| Normal prior (NormalSampler) | all models | link | |
| Gaussian mixture (GaussianMixtureSampler) | all models | link | link |
| Two stage VAE sampler (TwoStageVAESampler) | all VAE based models | link | link |
| Unit sphere uniform sampler (HypersphereUniformSampler) | SVAE | link | link |
| Poincaré Disk sampler (PoincareDiskSampler) | PoincareVAE | link | link |
| VAMP prior sampler (VAMPSampler) | VAMP | link | link |
| Manifold sampler (RHVAESampler) | RHVAE | link | link |
| Masked Autoregressive Flow Sampler (MAFSampler) | all models | link | link |
| Inverse Autoregressive Flow Sampler (IAFSampler) | all models | link | link |
| PixelCNN (PixelCNNSampler) | VQVAE | link | |

Reproducibility

We validate the implementations by reproducing some results presented in the original publications when the official code has been released or when enough details about the experimental section of the papers were available. See reproducibility for more details.

Launching a model training

To launch a model training, you only need to call a TrainingPipeline instance.

>>> from pythae.pipelines import TrainingPipeline
>>> from pythae.models import VAE, VAEConfig
>>> from pythae.trainers import BaseTrainerConfig

>>> # Set up the training configuration
>>> my_training_config = BaseTrainerConfig(
...	output_dir='my_model',
...	num_epochs=50,
...	learning_rate=1e-3,
...	per_device_train_batch_size=200,
...	per_device_eval_batch_size=200,
...	train_dataloader_num_workers=2,
...	eval_dataloader_num_workers=2,
...	steps_saving=20,
...	optimizer_cls="AdamW",
...	optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
...	scheduler_cls="ReduceLROnPlateau",
...	scheduler_params={"patience": 5, "factor": 0.5}
... )
>>> # Set up the model configuration 
>>> my_vae_config = VAEConfig(
...	input_dim=(1, 28, 28),
...	latent_dim=10
... )
>>> # Build the model
>>> my_vae_model = VAE(
...	model_config=my_vae_config
... )
>>> # Build the Pipeline
>>> pipeline = TrainingPipeline(
... 	training_config=my_training_config,
... 	model=my_vae_model
... )
>>> # Launch the Pipeline
>>> pipeline(
...	train_data=your_train_data, # must be torch.Tensor, np.array or torch datasets
...	eval_data=your_eval_data # must be torch.Tensor, np.array or torch datasets
... )

At the end of training, the best model weights, the model configuration and the training configuration are stored in a final_model folder available at my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss (with my_model being the output_dir argument of the BaseTrainerConfig). If you also set the steps_saving argument, as above, folders named checkpoint_epoch_k, containing the best model weights, optimizer, scheduler, configuration and training configuration at epoch k, will appear in the same directory.
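The stored model can then be reloaded with AutoModel.load_from_folder (also used in the generation section below); a minimal sketch, where the folder name is illustrative:

>>> from pythae.models import AutoModel
>>> # Reload the best model saved at the end of training
>>> my_trained_vae = AutoModel.load_from_folder(
...	'my_model/VAE_training_YYYY-MM-DD_hh-mm-ss/final_model'
... )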

Launching a training on benchmark datasets

We also provide a training script example here that can be used to train the models on benchmark datasets (mnist, cifar10, celeba ...). The script can be launched with the following command line:

python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'

See README.md for further details on this script.

Launching data generation

Using the GenerationPipeline

The easiest way to generate data from a trained model is to use the built-in GenerationPipeline provided in Pythae. Say you want to generate 100 samples using a MAFSampler: all you have to do is 1) reload the trained model, 2) define the sampler's configuration and 3) create and launch the GenerationPipeline as follows

>>> from pythae.models import AutoModel
>>> from pythae.samplers import MAFSamplerConfig
>>> from pythae.pipelines import GenerationPipeline
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> my_sampler_config = MAFSamplerConfig(
...	n_made_blocks=2,
...	n_hidden_in_made=3,
...	hidden_size=128
... )
>>> # Build the pipeline
>>> pipe = GenerationPipeline(
...	model=my_trained_vae,
...	sampler_config=my_sampler_config
... )
>>> # Launch data generation
>>> generated_samples = pipe(
...	num_samples=100, # number of samples to generate
...	return_gen=True, # if False, returns nothing
...	train_data=train_data, # needed to fit the sampler
...	eval_data=eval_data, # needed to fit the sampler
...	training_config=BaseTrainerConfig(num_epochs=200) # TrainingConfig used to fit the sampler
... )

Using the Samplers

Alternatively, you can launch the data generation process from a trained model directly with the sampler. For instance, to generate new data with your sampler, run the following.

>>> from pythae.models import AutoModel
>>> from pythae.samplers import NormalSampler
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> my_sampler = NormalSampler(
...	model=my_trained_vae
... )
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )

If you set output_dir to a specific path, the generated images will be saved as .png files named 00000000.png, 00000001.png ... The samplers can be used with any model as long as they are suited to it. For instance, a GaussianMixtureSampler instance can be used to generate from any model, but a VAMPSampler will only be usable with a VAMP model. Check here to see which samplers apply to your model. Be careful: some samplers, such as the GaussianMixtureSampler, need to be fitted by calling the fit method before use. Below is an example for the GaussianMixtureSampler.

>>> from pythae.models import AutoModel
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> gmm_sampler_config = GaussianMixtureSamplerConfig(
...	n_components=10
... )
>>> my_sampler = GaussianMixtureSampler(
...	sampler_config=gmm_sampler_config,
...	model=my_trained_vae
... )
>>> # Fit the sampler
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )

Define your own Autoencoder architecture

Pythae provides you the possibility to define your own neural networks within the VAE models. For instance, say you want to train a Wasserstein AE with a specific encoder and decoder, you can do the following:

>>> import torch
>>> from pythae.models.nn import BaseEncoder, BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput
>>> class My_Encoder(BaseEncoder):
...	def __init__(self, args=None): # Args is a ModelConfig instance
...		BaseEncoder.__init__(self)
...		self.layers = my_nn_layers()
...		
...	def forward(self, x:torch.Tensor) -> ModelOutput:
...		out = self.layers(x)
...		output = ModelOutput(
...			embedding=out # Set the output from the encoder in a ModelOutput instance 
...		)
...		return output
>>> class My_Decoder(BaseDecoder):
...	def __init__(self, args=None):
...		BaseDecoder.__init__(self)
...		self.layers = my_nn_layers()
...		
...	def forward(self, x:torch.Tensor) -> ModelOutput:
...		out = self.layers(x)
...		output = ModelOutput(
...			reconstruction=out # Set the output from the decoder in a ModelOutput instance
...		)
...		return output
...
>>> my_encoder = My_Encoder()
>>> my_decoder = My_Decoder()

And now build the model

>>> from pythae.models import WAE_MMD, WAE_MMD_Config
>>> # Set up the model configuration 
>>> my_wae_config = model_config = WAE_MMD_Config(
...	input_dim=(1, 28, 28),
...	latent_dim=10
... )
...
>>> # Build the model
>>> my_wae_model = WAE_MMD(
...	model_config=my_wae_config,
...	encoder=my_encoder, # pass your encoder as argument when building the model
...	decoder=my_decoder # pass your decoder as argument when building the model
... )

important note 1: For all AE-based models (AE, WAE, RAE_L2, RAE_GP), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings under the key embedding. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.

important note 2: For all VAE-based models (VAE, BetaVAE, IWAE, HVAE, VAMP, RHVAE), both the encoder and decoder must return a ModelOutput instance. For the encoder, the ModelOutput instance must contain the embeddings and the log-covariance matrices (of shape batch_size x latent_space_dim) under the keys embedding and log_covariance respectively. For the decoder, the ModelOutput instance must contain the reconstructions under the key reconstruction.
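As an illustration, a VAE-compatible encoder returning both keys could look as follows; this is a minimal sketch in which my_nn_layers, my_embedding_layer and my_log_var_layer are hypothetical placeholders for your own layers.

>>> import torch
>>> from pythae.models.nn import BaseEncoder
>>> from pythae.models.base.base_utils import ModelOutput
>>> class My_VAE_Encoder(BaseEncoder):
...	def __init__(self, args=None): # args is a ModelConfig instance
...		BaseEncoder.__init__(self)
...		self.layers = my_nn_layers() # shared feature extractor (placeholder)
...		self.embedding_layer = my_embedding_layer() # maps features to the mean (placeholder)
...		self.log_var_layer = my_log_var_layer() # maps features to the log-covariance (placeholder)
...
...	def forward(self, x: torch.Tensor) -> ModelOutput:
...		out = self.layers(x)
...		return ModelOutput(
...			embedding=self.embedding_layer(out), # shape: batch_size x latent_space_dim
...			log_covariance=self.log_var_layer(out) # shape: batch_size x latent_space_dim
...		)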

Using benchmark neural nets

You can also find predefined neural network architectures for the most common datasets (i.e. MNIST, CIFAR, CELEBA ...) that can be loaded as follows:

>>> from pythae.models.nn.benchmark.mnist import (
...	Encoder_Conv_AE_MNIST, # For AE based model (only return embeddings)
...	Encoder_Conv_VAE_MNIST, # For VAE based model (return embeddings and log_covariances)
...	Decoder_Conv_AE_MNIST
... )

Replace mnist by cifar or celeba to access the other neural nets.
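For instance, a VAE using the predefined MNIST networks could be built as follows; a minimal sketch, assuming (as for custom networks above) that the predefined networks take the model configuration as argument.

>>> from pythae.models import VAE, VAEConfig
>>> from pythae.models.nn.benchmark.mnist import (
...	Encoder_Conv_VAE_MNIST,
...	Decoder_Conv_AE_MNIST
... )
>>> my_vae_config = VAEConfig(
...	input_dim=(1, 28, 28),
...	latent_dim=16
... )
>>> my_vae_model = VAE(
...	model_config=my_vae_config,
...	encoder=Encoder_Conv_VAE_MNIST(my_vae_config), # predefined conv encoder for MNIST
...	decoder=Decoder_Conv_AE_MNIST(my_vae_config) # predefined conv decoder for MNIST
... )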

Distributed Training with Pythae

As of v0.1.0, Pythae supports distributed training using PyTorch's DDP, allowing you to train your favorite VAE faster and on larger datasets using multi-gpu and/or multi-node training.

To do so, you can build a python script that will then be launched by a launcher (such as srun on a cluster). The only thing needed in the script is to specify some elements relative to the distributed environment (such as the number of nodes/gpus) directly in the training configuration, as follows:

>>> training_config = BaseTrainerConfig(
...     num_epochs=10,
...     learning_rate=1e-3,
...     per_device_train_batch_size=64,
...     per_device_eval_batch_size=64,
...     train_dataloader_num_workers=8,
...     eval_dataloader_num_workers=8,
...     dist_backend="nccl", # distributed backend
...     world_size=8, # number of gpus to use (n_nodes x n_gpus_per_node)
...     rank=5, # global gpu id
...     local_rank=1, # gpu id within a node
...     master_addr="localhost", # master address
...     master_port="12345" # master port
... )

See this example script that defines a multi-gpu VQVAE training on the ImageNet dataset. Please note that the way the distributed environment variables (world_size, rank ...) are recovered may be specific to the cluster and launcher you use.
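For instance, with a launcher that exports the standard PyTorch distributed environment variables (e.g. torchrun sets WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT), the configuration could be filled as in the sketch below; adapt the variable names to your own launcher.

>>> import os
>>> from pythae.trainers import BaseTrainerConfig
>>> training_config = BaseTrainerConfig(
...     num_epochs=10,
...     dist_backend="nccl", # distributed backend
...     world_size=int(os.environ.get("WORLD_SIZE", 1)), # set by the launcher
...     rank=int(os.environ.get("RANK", 0)), # global gpu id
...     local_rank=int(os.environ.get("LOCAL_RANK", 0)), # gpu id within the node
...     master_addr=os.environ.get("MASTER_ADDR", "localhost"),
...     master_port=os.environ.get("MASTER_PORT", "12345")
... )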

Benchmark

Below are the training times for a Vector Quantized VAE (VQ-VAE) with Pythae: 100 epochs on MNIST on V100 16GB GPU(s), 50 epochs on FFHQ (1024x1024 images) and 20 epochs on ImageNet-1k on V100 32GB GPU(s).

| | Train Data | 1 GPU | 4 GPUs | 2x4 GPUs |
|:---|:---|:---:|:---:|:---:|
| MNIST (VQ-VAE) | 28x28 images (50k) | 235.18 s | 62.00 s | 35.86 s |
| FFHQ 1024x1024 (VQVAE) | 1024x1024 RGB images (60k) | 19h 1min | 5h 6min | 2h 37min |
| ImageNet-1k 128x128 (VQVAE) | 128x128 RGB images (~ 1.2M) | 6h 25min | 1h 41min | 51min 26s |

For each dataset, we provide the benchmarking scripts here.

Sharing your models with the HuggingFace Hub 🤗

Pythae also allows you to share your models on the HuggingFace Hub. To do so, you will need:

$ python -m pip install huggingface_hub
$ huggingface-cli login

Uploading a model to the Hub

Any pythae model can be easily uploaded using the method push_to_hf_hub

>>> my_vae_model.push_to_hf_hub(hf_hub_path="your_hf_username/your_hf_hub_repo")

Note: If your_hf_hub_repo already exists and is not empty, files will be overwritten. If the repo your_hf_hub_repo does not exist, a repo with the same name will be created.

Downloading models from the Hub

Equivalently, you can download or reload any Pythae model directly from the Hub using the method load_from_hf_hub

>>> from pythae.models import AutoModel
>>> my_downloaded_vae = AutoModel.load_from_hf_hub(hf_hub_path="path_to_hf_repo")

Monitoring your experiments with wandb 🧪

Pythae also integrates the experiment tracking tool wandb, allowing users to store their configs, monitor their trainings and compare runs through a graphical interface. To be able to use this feature you will need:

$ pip install wandb
$ wandb login

Creating a WandbCallback

Launching an experiment monitoring with wandb in pythae is pretty simple. The only thing a user needs to do is create a WandbCallback instance...

>>> # Create your callback
>>> from pythae.trainers.training_callbacks import WandbCallback
>>> callbacks = [] # the TrainingPipeline expects a list of callbacks
>>> wandb_cb = WandbCallback() # Build the callback 
>>> # SetUp the callback 
>>> wandb_cb.setup(
...	training_config=your_training_config, # training config
...	model_config=your_model_config, # model config
...	project_name="your_wandb_project", # specify your wandb project
...	entity_name="your_wandb_entity", # specify your wandb entity
... )
>>> callbacks.append(wandb_cb) # Add it to the callbacks list

...and then pass it to the TrainingPipeline.

>>> pipeline = TrainingPipeline(
...	training_config=config,
...	model=model
... )
>>> pipeline(
...	train_data=train_dataset,
...	eval_data=eval_dataset,
...	callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!
... )
>>> # You can log to https://wandb.ai/your_wandb_entity/your_wandb_project to monitor your training

See the detailed tutorial

Monitoring your experiments with mlflow 🧪

Pythae also integrates the experiment tracking tool mlflow, allowing users to store their configs, monitor their trainings and compare runs through a graphical interface. To be able to use this feature you will need:

$ pip install mlflow

Creating a MLFlowCallback

Launching an experiment monitoring with mlflow in pythae is pretty simple. The only thing a user needs to do is create a MLFlowCallback instance...

>>> # Create your callback
>>> from pythae.trainers.training_callbacks import MLFlowCallback
>>> callbacks = [] # the TrainingPipeline expects a list of callbacks
>>> mlflow_cb = MLFlowCallback() # Build the callback 
>>> # SetUp the callback 
>>> mlflow_cb.setup(
...	training_config=your_training_config, # training config
...	model_config=your_model_config, # model config
...	run_name="mlflow_cb_example", # specify your mlflow run
... )
>>> callbacks.append(mlflow_cb) # Add it to the callbacks list

...and then pass it to the TrainingPipeline.

>>> pipeline = TrainingPipeline(
...	training_config=config,
...	model=model
... )
>>> pipeline(
...	train_data=train_dataset,
...	eval_data=eval_dataset,
...	callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!
... )

You can then visualize your metrics by running the following command in the directory where the ./mlruns folder is located:

$ mlflow ui 

See the detailed tutorial

Monitoring your experiments with comet_ml 🧪

Pythae also integrates the experiment tracking tool comet_ml, allowing users to store their configs, monitor their trainings and compare runs through a graphical interface. To be able to use this feature you will need:

$ pip install comet_ml

Creating a CometCallback

Launching an experiment monitoring with comet_ml in pythae is pretty simple. The only thing a user needs to do is create a CometCallback instance...

>>> # Create your callback
>>> from pythae.trainers.training_callbacks import CometCallback
>>> callbacks = [] # the TrainingPipeline expects a list of callbacks
>>> comet_cb = CometCallback() # Build the callback 
>>> # SetUp the callback 
>>> comet_cb.setup(
...	training_config=training_config, # training config
...	model_config=model_config, # model config
...	api_key="your_comet_api_key", # specify your comet api-key
...	project_name="your_comet_project", # specify your comet project
...	#offline_run=True, # run in offline mode
...	#offline_directory='my_offline_runs' # set the directory to store the offline runs
... )
>>> callbacks.append(comet_cb) # Add it to the callbacks list

...and then pass it to the TrainingPipeline.

>>> pipeline = TrainingPipeline(
...	training_config=config,
...	model=model
... )
>>> pipeline(
...	train_data=train_dataset,
...	eval_data=eval_dataset,
...	callbacks=callbacks # pass the callbacks to the TrainingPipeline and you are done!
... )
>>> # You can log to https://comet.com/your_comet_username/your_comet_project to monitor your training

See the detailed tutorial

Getting your hands on the code

To help you understand how pythae works and how you can train your models with this library, we also provide tutorials:

Dealing with issues 🛠️

If you experience any issues while running the code or want to request new features/models to be implemented, please open an issue on GitHub.

Contributing 🚀

You want to contribute to this library by adding a model, a sampler or simply fixing a bug? That's awesome! Thank you! Please see CONTRIBUTING.md to follow the main contributing guidelines.

Results

Reconstruction

First, let's have a look at the reconstructed samples taken from the evaluation set.

[Image table: reconstructions on MNIST and CELEBA (first row: the eval data itself) for each model — AE, VAE, Beta-VAE, VAE Lin NF, VAE IAF, Disentangled Beta-VAE, FactorVAE, BetaTCVAE, IWAE, MSSSIM_VAE, WAE, INFO VAE, VAMP, SVAE, Adversarial_AE, VAE_GAN, VQVAE, HVAE, RAE_L2, RAE_GP, RHVAE.]

Generation

Here, we show the generated samples using each model implemented in the library and different samplers.

[Image table: samples generated on MNIST and CELEBA for each model/sampler pair — AE + GaussianMixtureSampler, VAE + NormalSampler, VAE + GaussianMixtureSampler, VAE + TwoStageVAESampler, VAE + MAFSampler, Beta-VAE + NormalSampler, VAE Lin NF + NormalSampler, VAE IAF + NormalSampler, Disentangled Beta-VAE + NormalSampler, FactorVAE + NormalSampler, BetaTCVAE + NormalSampler, IWAE + NormalSampler, MSSSIM_VAE + NormalSampler, WAE + NormalSampler, INFO VAE + NormalSampler, SVAE + HypersphereUniformSampler, VAMP + VAMPSampler, Adversarial_AE + NormalSampler, VAEGAN + NormalSampler, VQVAE + MAFSampler, HVAE + NormalSampler, RAE_L2 + GaussianMixtureSampler, RAE_GP + GaussianMixtureSampler, RHVAE + RHVAESampler.]

Citation

If you find this work useful or use it in your research, please consider citing us:

@inproceedings{chadebec2022pythae,
 author = {Chadebec, Cl\'{e}ment and Vincent, Louis and Allassonni\`{e}re, St\'{e}phanie},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
 pages = {21575--21589},
 publisher = {Curran Associates, Inc.},
 title = {Pythae: Unifying Generative Autoencoders in Python - A Benchmarking Use Case},
 volume = {35},
 year = {2022}
}