CVAE and VQ-VAE
This is an implementation of the VQ-VAE (Vector Quantized Variational Autoencoder) from Neural Discrete Representation Learning, together with a Convolutional Variational Autoencoder (CVAE), for compressing MNIST and CIFAR-10. The code is based on pytorch/examples/vae.
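In the CVAE, the bottleneck is a diagonal Gaussian. The encoder predicts a mean and a log-variance, a latent is drawn with the reparameterization trick, and the loss adds a KL term to the reconstruction loss. The NumPy sketch below illustrates those two pieces under the assumption of a standard normal prior (as in pytorch/examples/vae). It is not code from this repository, and the function names and shapes are hypothetical.

```python
import numpy as np


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    std = np.exp(0.5 * logvar)           # sigma recovered from the log-variance
    eps = rng.standard_normal(mu.shape)  # noise that does not depend on mu or logvar
    return mu + std * eps


def gaussian_kl(mu, logvar):
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims per sample."""
    return -0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar), axis=-1)


# Hypothetical batch of 4 samples with a 16-dimensional latent space.
rng = np.random.default_rng(0)
mu = rng.standard_normal((4, 16))
logvar = rng.standard_normal((4, 16))
z = reparameterize(mu, logvar, rng)
kl = gaussian_kl(mu, logvar)
```

Sampling through `mu + std * eps` keeps the randomness in `eps`, so in an autograd framework gradients can flow back into `mu` and `logvar`.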
pip install -r requirements.txt
python main.py
Requirements
- Python 3.6 (3.5 may also work)
- PyTorch 0.4
- Additional requirements in requirements.txt
Usage
# For example
python3 main.py --dataset=cifar10 --model=vqvae --data-dir=~/.datasets --epochs=3
Results
All images are taken from the test set. The top row shows the original images; the bottom row shows their reconstructions.
k: number of elements in the dictionary. d: dimension of each dictionary element (the number of channels in the bottleneck).
- MNIST (k=10, d=64)
- CIFAR10 (k=128, d=256)
- ImageNet (k=512, d=128)
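The role of k and d can be illustrated with a NumPy sketch of the vector-quantization step: each d-dimensional encoder vector is replaced by its nearest element of a (k, d) dictionary. This is an illustration, not this repository's code. The function names are hypothetical, `beta=0.25` for the commitment weight is an assumption, and the gradient-related parts (stop-gradient, straight-through estimator) are only described in comments, because NumPy has no autograd.

```python
import numpy as np


def quantize(z_e, codebook):
    """Replace each d-dim encoder vector by its nearest dictionary element.

    z_e: (n, d) encoder outputs; codebook: (k, d) dictionary.
    Returns the chosen indices (n,) and the quantized vectors z_q (n, d).
    """
    # Squared Euclidean distances ||z_e||^2 - 2 z_e.e + ||e||^2, shape (n, k).
    dists = (
        np.sum(z_e ** 2, axis=1, keepdims=True)
        - 2.0 * z_e @ codebook.T
        + np.sum(codebook ** 2, axis=1)
    )
    indices = np.argmin(dists, axis=1)
    return indices, codebook[indices]


def vq_loss_terms(z_e, z_q, beta=0.25):
    """Codebook and commitment terms (forward values only, as mean squared errors).

    With autograd, the codebook term is MSE(sg[z_e], z_q) (updates the dictionary),
    the commitment term is beta * MSE(z_e, sg[z_q]) (updates the encoder), and the
    decoder input uses the straight-through estimator z_e + sg[z_q - z_e].
    Without gradients both terms share the same squared error.
    """
    sq_err = np.mean((z_e - z_q) ** 2)
    return sq_err, beta * sq_err


# Hypothetical MNIST-sized dictionary: k=10 elements of dimension d=64.
rng = np.random.default_rng(0)
codebook = rng.standard_normal((10, 64))
# Encoder outputs that lie close to dictionary elements 3, 7 and 7.
z_e = codebook[[3, 7, 7]] + 0.01 * rng.standard_normal((3, 64))
indices, z_q = quantize(z_e, codebook)
codebook_loss, commitment_loss = vq_loss_terms(z_e, z_q)
```

For a convolutional bottleneck of shape (B, d, H, W), the feature map would first be flattened to (B*H*W, d), e.g. with `z.transpose(0, 2, 3, 1).reshape(-1, d)`, so that every spatial position is quantized independently.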
TODO:
- Implement Continuous Relaxation Training of Discrete Latent Variable Image Models
- Sample using a PixelCNN prior
- Improve results on CIFAR: perform the nearest-neighbor lookup against 10 dictionaries rather than 1
- Improve results on CIFAR: replace MSE with NLL
- Improve results on CIFAR: measure bits/dim
- Compare the architecture with the official one
- Merge the VAE and VQ-VAE code for MNIST and CIFAR into one script
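For the bits/dim item, the usual conversion divides a per-image negative log-likelihood in nats by the number of dimensions and by ln 2. The sketch below assumes the NLL is a discrete log-likelihood over 8-bit pixel values, summed over all pixels and channels of one image; under that convention a uniform model over 256 values scores exactly 8 bits/dim. The names are hypothetical.

```python
import math


def bits_per_dim(nll_nats, num_dims):
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2))


# A CIFAR-10 image has 32 * 32 * 3 = 3072 dimensions.
CIFAR10_DIMS = 32 * 32 * 3
```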
Acknowledgement
Thanks to tf-vaevae, which served as a good reference.