
Convolutional State Space Models for Long-Range Spatiotemporal Modeling

This repository provides the official JAX implementation for the paper:

Convolutional State Space Models for Long-Range Spatiotemporal Modeling [arXiv]

Jimmy T.H. Smith, Shalini De Mello, Jan Kautz, Scott Linderman, Wonmin Byeon, NeurIPS 2023.


For business inquiries, please visit the NVIDIA website and submit the form: NVIDIA Research Licensing.


We introduce ConvSSM, an efficient method for long-range spatiotemporal sequence modeling. It is parallelizable and overcomes major limitations of traditional ConvRNNs (e.g., vanishing/exploding gradients), while providing unbounded context and faster autoregressive generation than Transformers. It performs similarly to or better than Transformers and ConvLSTM on long-horizon video prediction tasks, trains up to 3× faster than ConvLSTM, and generates samples up to 400× faster than Transformers. We provide results for the long-horizon Moving-MNIST generation task and for long-range 3D environment benchmarks (DMLab, Minecraft, and Habitat).
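The parallelizability claim rests on the state update being linear. As a minimal sketch, assume a diagonal (elementwise) state transition as in S5, so each state element follows x_k = a_k * x_{k-1} + b_k, where b_k stands for the already-convolved input at step k. Composing these affine maps is associative, so all states can be computed with a prefix scan in O(log L) parallel rounds instead of L sequential steps. The function names below are hypothetical and pure Python is used for clarity; this is not the repository's implementation.

```python
"""Sequential vs. parallel-scan evaluation of a linear recurrence (sketch).

Assumes a diagonal state transition, so each state element evolves as
x_k = a_k * x_{k-1} + b_k with a zero initial state. Names are hypothetical.
"""


def combine(earlier, later):
    """Compose two affine maps x -> a*x + b, applying `earlier` first."""
    a_e, b_e = earlier
    a_l, b_l = later
    return (a_l * a_e, a_l * b_e + b_l)


def sequential_states(a, b):
    """Reference: step through the recurrence one element at a time."""
    states, x = [], 0.0
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        states.append(x)
    return states


def parallel_scan_states(a, b):
    """Hillis-Steele inclusive scan with `combine`.

    Each round's updates depend only on the previous round, so they could
    run in parallel; ceil(log2(L)) rounds are needed in total.
    """
    elems = list(zip(a, b))
    step = 1
    while step < len(elems):
        elems = [
            elems[i] if i < step else combine(elems[i - step], elems[i])
            for i in range(len(elems))
        ]
        step *= 2
    # With a zero initial state, x_k is the offset of the composed map.
    return [offset for _, offset in elems]


if __name__ == "__main__":
    a = [0.9, 0.5, 1.1, 0.7, 0.3]
    b = [1.0, 2.0, -1.0, 0.5, 3.0]
    print(sequential_states(a, b))
    print(parallel_scan_states(a, b))
```

Both functions return the same states up to floating-point rounding. In JAX, the same `combine` could be passed to `jax.lax.associative_scan` to run the scan on accelerator arrays.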


The repository builds on the training pipeline from TECO.


Installation

You will need to install JAX following the instructions here. We used JAX version 0.3.21.

pip install --upgrade "jax[cuda]==0.3.21" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then install the rest of the dependencies with:

sudo apt-get update && sudo apt-get install -y ffmpeg
pip install -r requirements.txt
pip install -e .

Datasets

For Moving-MNIST:

  1. Download the MNIST binary file:
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -O data/moving-mnist-pytorch/train-images-idx3-ubyte.gz
  2. Use the script in data/moving-mnist-pytorch to generate the Moving-MNIST data.
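Before running the generation script, it can help to confirm that the download is intact. The MNIST image file uses the IDX format: a big-endian header with the magic number 2051 (0x00000803), the image count, and the row and column sizes, followed by one unsigned byte per pixel. The helper below is a hypothetical sketch using only the Python standard library; it is not part of the repository. For the standard MNIST training file it should report 60000 images of 28×28 pixels.

```python
import gzip
import struct

IDX3_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions


def read_idx3_header(path):
    """Return (num_images, rows, cols) from a gzipped IDX3 file.

    Hypothetical helper; raises ValueError if the file is not IDX3 or
    if the payload size does not match the header.
    """
    with gzip.open(path, "rb") as f:
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != IDX3_MAGIC:
            raise ValueError(f"unexpected magic number {magic:#010x}")
        payload = len(f.read())
    if payload != count * rows * cols:
        raise ValueError("file is truncated or has trailing data")
    return count, rows, cols


if __name__ == "__main__":
    path = "data/moving-mnist-pytorch/train-images-idx3-ubyte.gz"
    print(read_idx3_header(path))
```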

For the 3D environment tasks:

We used the scripts from the TECO repository to download the DMLab and Habitat datasets. See the TECO repository for details on the datasets.


The data should be split into 'train' and 'test' folders.
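A quick way to confirm this layout before training is to check that both split folders exist and are non-empty. The sketch below assumes the splits sit directly under the dataset directory passed with -d (e.g. datasets/dmlab/train and datasets/dmlab/test); the helper name is hypothetical, and it does not check the file format inside each split.

```python
from pathlib import Path


def check_split_layout(dataset_dir, splits=("train", "test")):
    """Return a list of problems found; an empty list means the layout looks fine.

    Hypothetical helper: assumes each split is a subfolder of dataset_dir.
    """
    root = Path(dataset_dir)
    problems = []
    for split in splits:
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
        elif not any(split_dir.iterdir()):
            problems.append(f"empty folder: {split_dir}")
    return problems


if __name__ == "__main__":
    for problem in check_split_layout("datasets/dmlab"):
        print(problem)
```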


Pretrained VQ-GANs

Pretrained VQ-GAN checkpoints for each dataset can be found here. Note that these are also from TECO.


Pretrained ConvS5 checkpoints

Pretrained ConvS5 checkpoints for each dataset can be found here. Download each checkpoint into the corresponding checkpoint directory; by default, this is logs/<output_dir>/checkpoints/.

| Dataset | Checkpoint | Config (relative to configs/) |
| --- | --- | --- |
| Moving-MNIST 300 | link | Moving-MNIST/300_train_len/mnist_convS5_novq.yaml |
| Moving-MNIST 600 | link | Moving-MNIST/600_train_len/mnist_convS5_novq.yaml |
| DMLab | link | 3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml |
| Habitat | link | 3D_ENV_BENCHMARK/habitat/habitat_teco_convS5.yaml |
| Minecraft | link | 3D_ENV_BENCHMARK/minecraft/minecraft_teco_convS5.yaml |

Training

Before training, update the paths in the corresponding config files so that they point to your dataset and VQ-GAN directories.

To train, run: python scripts/train.py -d <dataset_dir> -o <output_dir> -c <path_to_config_yaml>

Example for training ConvS5 on DMLab:

python scripts/train.py -d datasets/dmlab -o dmlab_convs5 -c configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml

Note: we only used data-parallel training in our experiments. Model-parallel training would require implementing it with JAX xmap or pjit/jit. See this folder in the TECO repo for an example using xmap.
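In data-parallel training, every device holds a full copy of the model, computes gradients on its own shard of the batch, and the gradients are averaged across devices before the update. The pure-Python sketch below illustrates why this matches full-batch training for a mean loss with equal-sized shards, using a one-parameter least-squares model. It is a conceptual illustration with hypothetical names, not the repository's JAX training code.

```python
"""Data parallelism in miniature: averaged shard gradients equal the full-batch gradient.

Model: y_hat = w * x with mean squared error. All names are hypothetical.
"""


def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w, xs, ys, num_devices):
    """Split the batch into equal shards, take per-shard gradients, then average.

    The averaging step plays the role of the cross-device all-reduce.
    """
    if len(xs) % num_devices:
        raise ValueError("batch size must be divisible by num_devices")
    shard = len(xs) // num_devices
    grads = [
        mse_grad(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(num_devices)
    ]
    return sum(grads) / num_devices


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.8, 16.3]
    print(mse_grad(0.5, xs, ys), data_parallel_grad(0.5, xs, ys, num_devices=4))
```

Equal shard sizes matter here: with unequal shards, a plain average of shard means would weight the samples unevenly.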

Our runs were performed in a multinode NVIDIA V100 32GB GPU environment.


Evaluation

To evaluate, run: python scripts/eval.py -d <dataset_dir> -o <output_dir> -c <path_to_eval_config_yaml>

Example for evaluating ConvS5 on DMLab:

python scripts/eval.py -d datasets/dmlab -o dmlab_convs5 -c configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5_eval.yaml

This performs the sampling required to compute the evaluation metrics and saves the sampled videos as npz files.
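An npz file is a zip archive containing one .npy file per array, so its contents can be inspected without loading the arrays. The sketch below uses only the standard library to read each array's name, shape, and dtype from the .npy headers. It assumes the archive was written by NumPy (np.savez or np.savez_compressed) and makes no assumption about which arrays the evaluation script stores; the function names are hypothetical. In practice, numpy.load(path) gives the same information along with the data.

```python
import ast
import struct
import zipfile

NPY_MAGIC = b"\x93NUMPY"


def _read_npy_header(f):
    """Parse the header dict at the start of an open .npy stream."""
    if f.read(6) != NPY_MAGIC:
        raise ValueError("not a .npy file")
    major, _minor = f.read(2)
    if major == 1:
        (header_len,) = struct.unpack("<H", f.read(2))
    else:  # format versions 2.0 and 3.0 use a 4-byte length
        (header_len,) = struct.unpack("<I", f.read(4))
    encoding = "utf-8" if major >= 3 else "latin1"
    return ast.literal_eval(f.read(header_len).decode(encoding))


def describe_npz(path):
    """Map each array name in an .npz archive to (shape, dtype string)."""
    info = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            if not name.endswith(".npy"):
                continue
            with zf.open(name) as f:
                header = _read_npy_header(f)
            info[name[: -len(".npy")]] = (tuple(header["shape"]), header["descr"])
    return info
```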

For FVD evaluation, run: python scripts/compute_fvd.py <path_to_npz>

Example for ConvS5 on DMLab:

python scripts/compute_fvd.py logs/dmlab_convs5/samples_36
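For reference, FVD is the Fréchet distance between Gaussian fits to features of real and generated videos, where the features come from a pretrained video classifier (I3D in the original FVD definition). With means \(\mu_r, \mu_g\) and covariances \(\Sigma_r, \Sigma_g\) of the real and generated features, the standard definition is below (lower is better). This describes the metric in general; the feature extractor settings and sample counts used by compute_fvd.py are not specified here.

$$
\mathrm{FVD} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$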

For PSNR, SSIM, and LPIPS, run: python scripts/compute_metrics.py <path_to_npz>

Example for ConvS5 on DMLab:

python scripts/compute_metrics.py logs/dmlab_convs5/samples_action_144
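Of these metrics, PSNR has a simple closed form: PSNR = 10 · log10(MAX² / MSE), where MAX is the largest possible pixel value and MSE is the mean squared error between the reference and predicted frames. The minimal sketch below computes it for flattened frames in pure Python. The default pixel range (MAX = 255) and the per-frame averaging in `mean_psnr` are assumptions, not taken from compute_metrics.py, and both function names are hypothetical. SSIM and LPIPS are more involved and are not sketched here.

```python
import math


def psnr(reference, prediction, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences.

    Returns math.inf for identical inputs (zero MSE).
    """
    if len(reference) != len(prediction) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((r - p) ** 2 for r, p in zip(reference, prediction)) / len(reference)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / mse)


def mean_psnr(ref_frames, pred_frames, max_value=255.0):
    """Average per-frame PSNR over a video (one possible convention)."""
    scores = [psnr(r, p, max_value) for r, p in zip(ref_frames, pred_frames)]
    return sum(scores) / len(scores)
```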

Citation

Please use the following when citing our work:

@inproceedings{
      smith2023convolutional,
      title={Convolutional State Space Models for Long-Range Spatiotemporal Modeling},
      author={Jimmy T.H. Smith and Shalini De Mello and Jan Kautz and Scott Linderman and Wonmin Byeon},
      booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
      year={2023},
      url={https://openreview.net/forum?id=1ZvEtnrHS1}
}

License

Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE file for details.

Please reach out if you have any questions.

-- The ConvS5 authors.