# tensorf-jax
JAX implementation of Tensorial Radiance Fields, written as an exercise.
```bibtex
@misc{TensoRF,
  title={TensoRF: Tensorial Radiance Fields},
  author={Anpei Chen and Zexiang Xu and Andreas Geiger and Jingyi Yu and Hao Su},
  year={2022},
  eprint={2203.09517},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
We don't attempt to reproduce the original paper exactly, but can achieve decent results after 5~10 minutes of training:
As proposed, TensoRF only supports scenes that fit in a fixed-size bounding box. We've also added basic support for unbounded "real" scenes via mip-NeRF 360-inspired scene contraction[^1]. From nerfstudio's "dozer" dataset:
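The contraction step can be sketched as follows. This is a minimal illustration of mip-NeRF 360-style scene contraction using the $L^\infty$ norm variant mentioned in the footnote, not the repository's exact code:

```python
import jax.numpy as jnp


def contract(x: jnp.ndarray) -> jnp.ndarray:
    """Map unbounded points into a bounded region (here, the [-2, 2] box).

    Points with L-infinity norm <= 1 are left unchanged; points farther out
    are squashed so the whole scene fits in a fixed-size bounding box.
    Sketch only -- the actual implementation may differ.
    """
    # L-infinity norm of each point, keeping the last axis for broadcasting.
    norm = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
    # mip-NeRF 360 contraction: identity inside the unit ball, (2 - 1/||x||)
    # times the unit direction outside it.
    return jnp.where(norm <= 1.0, x, (2.0 - 1.0 / norm) * x / norm)
```

Because the contracted coordinates always land inside a fixed box, the standard bounded-scene TensoRF grid can be applied to them directly.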
## Instructions
1. Download the `nerf_synthetic` dataset: Google Drive. With the default training script arguments, we expect this to be extracted to `./data`, e.g. `./data/nerf_synthetic/lego`.

2. Install dependencies. Probably you want the GPU version of JAX; see the official instructions. Then:

   ```bash
   pip install -r requirements.txt
   ```

3. To print training options:

   ```bash
   python ./train_lego.py --help
   ```

4. To monitor training, we use Tensorboard:

   ```bash
   tensorboard --logdir=./runs/
   ```

5. To render:

   ```bash
   python ./render_360.py --help
   ```
## Differences from the PyTorch implementation
Things aren't totally matched to the official implementation:
- The official implementation relies heavily on masking operations to improve runtime (for example, by using a weight threshold for sampled points). These require dynamic shapes and are currently difficult to implement in JAX, so we replace them with workarounds like weighted sampling.
- Several training details that would likely improve performance are not yet implemented: bounding box refinement, ray filtering, regularization, etc.
- We include mixed-precision training, which can significantly improve training throughput. (Whether this translates into faster wall-clock convergence is unclear.)
## References
Implementation details are based loosely on the original PyTorch implementation, apchenstu/TensoRF.
unixpickle/learn-nerf and google-research/jaxnerf were also really helpful for understanding core NeRF concepts and connecting them to JAX!
## To-do
- Main implementation
  - Point sampling
  - Feature MLP
  - Rendering
  - VM decomposition
    - Basic implementation
    - Vectorized
- Dataloading
  - Blender
  - nerfstudio
    - Basics
    - Fisheye support
    - Compute samples without undistorting images (throws away a lot of pixels)
  - Tricks for real data
    - Scene contraction (~mip-NeRF 360)
    - Camera embeddings
- Training
  - Learning rate scheduler
    - ADAM + grouped LR
    - Exponential decay
    - Reset decay after upsampling
  - Running
  - Checkpointing
  - Logging
    - Loss
    - PSNR
    - Test metrics
    - Test images
    - Render previews
  - Ray filtering
  - Bounding box refinement
  - Incremental upsampling
  - Regularization terms
- Performance
  - Weight thresholding for computing appearance features
    - per ray top-k
    - global top-k (bad & deleted)
  - Mixed-precision
    - implemented
    - stable
  - Multi-GPU (should be quick)
- Rendering
  - RGB
  - Depth (median)
  - Depth (mean)
  - Batching
  - Generate some GIFs
- Misc engineering
  - Actions
  - Understand vmap performance differences (details)
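The vmap question in the last item boils down to comparing two equivalent ways of batching a per-ray computation. A minimal sketch of the two styles being compared (the function here is a stand-in, not actual repo code):

```python
import jax
import jax.numpy as jnp


def per_ray(ray: jnp.ndarray) -> jnp.ndarray:
    # Toy per-ray computation on a single (3,) ray; a stand-in for the
    # real rendering math.
    return jnp.sum(ray * ray)


# Style 1: write the single-example function, let vmap add the batch axis.
batched_vmap = jax.jit(jax.vmap(per_ray))

# Style 2: write the batched computation by hand.
batched_manual = jax.jit(lambda rays: jnp.sum(rays * rays, axis=-1))
```

Both compile to the same logical computation, but the XLA programs vmap traces are not always identical to hand-batched ones, which is presumably what the performance investigation is about.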
## Footnotes
[^1]: Same as the original, but with an $L^\infty$ norm instead of an $L^2$ norm.