Reproducing Neural Discrete Representation Learning
Course Project for IFT 6135 - Representation Learning
Project Report link: final_project.pdf
Instructions
- To train the VQ-VAE with the default arguments discussed in the report, execute:
python vqvae.py --data-folder /tmp/miniimagenet --output-folder models/vqvae
- To train the PixelCNN prior on the latents, execute:
python pixelcnn_prior.py --data-folder /tmp/miniimagenet --model models/vqvae --output-folder models/pixelcnn_prior
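Training the VQ-VAE (the first command above) minimizes the objective from the original VQ-VAE paper: a reconstruction term, a codebook term ||sg[z_e] - e||^2 and a commitment term beta * ||z_e - sg[e]||^2, where sg is stop-gradient. Below is a minimal sketch of that objective. It assumes an MSE reconstruction loss and beta = 0.25 (the paper's value; the default used by vqvae.py may differ), and the function name vqvae_loss is hypothetical, not taken from the repo.

```python
import torch
import torch.nn.functional as F


def vqvae_loss(x, x_recon, z_e, z_q, beta=0.25):
    """Hypothetical sketch of the VQ-VAE objective.

    x, x_recon: input images and decoder reconstructions.
    z_e: encoder outputs before quantization.
    z_q: selected codebook vectors, same shape as z_e.
    beta: commitment weight (0.25 in the paper; assumed here).
    """
    # Reconstruction term: trains the decoder (and the encoder, via straight-through).
    recon = F.mse_loss(x_recon, x)
    # Codebook term: pulls codebook vectors towards the frozen encoder outputs.
    codebook = F.mse_loss(z_q, z_e.detach())
    # Commitment term: keeps encoder outputs close to their frozen codes.
    commitment = F.mse_loss(z_e, z_q.detach())
    return recon + codebook + beta * commitment
```

For the reconstruction gradient to reach the encoder, x_recon has to be decoded from a straight-through version of the codes, for example z_e + (z_q - z_e).detach(), rather than from z_q directly.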
Datasets Tested
Image
- MNIST
- FashionMNIST
- CIFAR10
- Mini-ImageNet
Video
- Atari 2600 - Boxing (OpenAI Gym) code
Reconstructions from VQ-VAE
The top 4 rows are original images; the bottom 4 rows are their reconstructions.
MNIST
Fashion MNIST
Class-conditional samples from the VQ-VAE with a PixelCNN prior on the latents
MNIST
Fashion MNIST
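The samples above come from a PixelCNN prior that models the grid of discrete latent codes autoregressively. The core ingredient of any PixelCNN is a masked convolution whose kernel only sees positions above and to the left of the current one (raster order). A minimal sketch of that masking, assuming a plain single-stack PixelCNN; the class name MaskedConv2d is hypothetical and this is not the repo's module.

```python
import torch
import torch.nn as nn


class MaskedConv2d(nn.Conv2d):
    """Hypothetical Conv2d that cannot see current or future positions.

    mask_type "A" also hides the centre position (first layer);
    mask_type "B" lets the centre through (later layers).
    """

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in ("A", "B")
        _, _, h, w = self.weight.shape
        mask = torch.ones_like(self.weight)
        # In the centre row, zero the centre (type A only) and everything to its right.
        mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0
        # Zero every row below the centre.
        mask[:, :, h // 2 + 1:, :] = 0
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Apply the mask on every call so masked weights never contribute.
        return nn.functional.conv2d(x, self.weight * self.mask, self.bias,
                                    self.stride, self.padding, self.dilation, self.groups)
```

To use it on latents, the (B, H, W) grid of codebook indices would first be embedded, e.g. with nn.Embedding(K, C), and permuted to (B, C, H, W). Stacks of plain masked convolutions have a blind spot in their receptive field; gated PixelCNN variants avoid it and can additionally be conditioned on a class label.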
Comments
- We noticed that implementing our own VectorQuantization PyTorch function sped up training of the VQ-VAE by nearly 3x. The slower but simpler code is in this commit.
- We added some basic tests for the vector quantization functions (based on pytest). To run these tests:
py.test . -vv
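A minimal sketch of what such a custom vector quantization function can look like, assuming flattened encoder outputs of shape (N, D) and a codebook of shape (K, D). The forward pass picks the nearest code for every row; the backward pass copies the gradient straight through to the encoder outputs and accumulates the exact gradient of the lookup into the selected codebook rows. The class name VectorQuantizationST is hypothetical, and this is not necessarily how the repo implements it.

```python
import torch
from torch.autograd import Function


class VectorQuantizationST(Function):
    """Hypothetical nearest-codebook lookup with a straight-through gradient."""

    @staticmethod
    def forward(ctx, inputs, codebook):
        # Euclidean distances between every input row and every code: (N, K).
        indices = torch.cdist(inputs, codebook).argmin(dim=1)
        quantized = codebook.index_select(0, indices)
        ctx.save_for_backward(indices)
        ctx.codebook_shape = codebook.shape
        # Integer indices carry no gradient.
        ctx.mark_non_differentiable(indices)
        return quantized, indices

    @staticmethod
    def backward(ctx, grad_quantized, grad_indices):
        (indices,) = ctx.saved_tensors
        grad_inputs = grad_codebook = None
        if ctx.needs_input_grad[0]:
            # Straight-through: treat d(quantized)/d(inputs) as the identity.
            grad_inputs = grad_quantized.clone()
        if ctx.needs_input_grad[1]:
            # Each selected code receives the sum of the gradients of the rows mapped to it.
            grad_codebook = torch.zeros(ctx.codebook_shape, dtype=grad_quantized.dtype,
                                        device=grad_quantized.device)
            grad_codebook.index_add_(0, indices, grad_quantized)
        return grad_inputs, grad_codebook
```

Encoder outputs of shape (B, D, H, W) would be permuted to (B, H, W, D) and reshaped to (-1, D) before calling VectorQuantizationST.apply(inputs, codebook). The returned indices, reshaped to (B, H, W), are the discrete latents a prior such as PixelCNN is trained on.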
Authors
- Rithesh Kumar
- Tristan Deleu
- Evan Racah