
Gaussian Splatting PyTorch Lightning Implementation

Known issues

Features

1. Installation

1.1. Clone repository

# clone repository
git clone https://github.com/yzslab/gaussian-splatting-lightning.git
cd gaussian-splatting-lightning

1.2. Create virtual environment

# create virtual environment
conda create -yn gspl python=3.9 pip
conda activate gspl

1.3. Install PyTorch
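The exact command depends on your CUDA setup; the line below is only an illustrative sketch (the CUDA 11.8 wheel index is an assumption, not a version pinned by this repository), so adjust it to your environment:

# example only: pick the index URL matching your CUDA version
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118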

1.4. Install requirements

pip install -r requirements.txt

1.5. Install optional packages

2. Training

2.1. Basic command

python main.py fit \
    --data.path DATASET_PATH \
    -n EXPERIMENT_NAME

It can detect some dataset types automatically. You can also specify the type explicitly with the --data.parser option. Possible values are: Colmap, Blender, NSVF, Nerfies, MatrixCity, PhotoTourism, SegAnyColmap, Feature3DGSColmap.
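For example, to select the Colmap parser explicitly (the same flag is used with the other parser values listed above):

python main.py fit \
    --data.path DATASET_PATH \
    --data.parser Colmap \
    -n EXPERIMENT_NAME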

<b>[NOTE]</b> By default, only checkpoint files are produced at the end of training. If you need a ply file in vanilla 3DGS's format (which can be loaded by SIBR_viewer or some WebGL/WebGPU based viewers), convert a checkpoint to a ply file after training, e.g. with the sketch below.
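A minimal sketch of that conversion, assuming the repository ships a checkpoint-to-ply utility at utils/ckpt2ply.py (the script name is an assumption; check the utils/ directory for the exact name and arguments):

# assumption: utils/ckpt2ply.py converts a training output to a vanilla 3DGS ply
python utils/ckpt2ply.py TRAINING_OUTPUT_PATH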

2.2. Some useful options

# start the built-in web viewer alongside training
python main.py fit \
    --viewer \
    ...
# use the config tailored to Blender (NeRF-synthetic) datasets
python main.py fit \
    --config configs/blender.yaml \
    ...
# the requirements of the mask:
#   * must be single channel
#   * zero (black) represents a masked pixel (it won't be used to supervise learning)
#   * the mask filename must be the image filename + '.png',
#     e.g.: the mask of '001.jpg' is '001.jpg.png'
... fit \
  --data.parser Colmap \
  --data.parser.mask_dir MASK_DIR_PATH \
  ...

You can use utils/image_downsample.py to downsample your images, e.g. 4x downsample: python utils/image_downsample.py PATH_TO_DIRECTORY_THAT_STORE_IMAGES --factor 4

# it will load images from `images_4` directory
... fit \
  --data.parser Colmap \
  --data.parser.down_sample_factor 4 \
  ...

The rounding mode is specified by --data.parser.down_sample_rounding_mode. Available values are floor, round, round_half_up and ceil. The default is round.
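For example, combining a 4x downsample with floor rounding (both options are described above):

... fit \
  --data.parser Colmap \
  --data.parser.down_sample_factor 4 \
  --data.parser.down_sample_rounding_mode floor \
  ...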

2.3. Use <a href="https://github.com/nerfstudio-project/gsplat">nerfstudio-project/gsplat</a>

python main.py fit \
    --config configs/gsplat.yaml \
    ...

2.4. Multi-GPU training (DDP)

<b>[NOTE]</b> Try the <a href="#216-new-multiple-gpu-training-strategy">New Multiple GPU training strategy</a>, which can be enabled during densification.

<b>[NOTE]</b> Multi-GPU training with the DDP strategy can only be enabled after densification. You can start with single-GPU training, save a checkpoint after densification has finished, then resume from this checkpoint with multi-GPU training enabled.

You will get improved PSNR and SSIM with more GPUs.

# Single GPU at the beginning
python main.py fit \
    --config ... \
    --data.path DATASET_PATH \
    --model.density.densify_until_iter 15000 \
    --max_steps 15000
# Then resume, and enable multi-GPU
python main.py fit \
    --config ... \
    --trainer configs/ddp.yaml \
    --data.path DATASET_PATH \
    --max_steps 30000 \
    --ckpt_path last  # find latest checkpoint automatically, or provide a path to checkpoint file

2.5. <a href="https://ingra14m.github.io/Deformable-Gaussians/">Deformable 3D Gaussians</a>

<video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/177b3fbf-fdd2-490f-b446-433a4d929502"></video>

python main.py fit \
    --config configs/deformable_blender.yaml \
    --data.path ...

2.6. <a href="https://niujinshuchong.github.io/mip-splatting/">Mip-Splatting</a>

Training:

python main.py fit \
    --config configs/mip_splatting_gsplat_v2.yaml \
    --data.path ...

Fuse the 3D smoothing filter to the Gaussian parameters:

python utils/fuse_mip_filter.py \
    TRAINED_MODEL_DIR

2.7. <a href="https://lightgaussian.github.io/">LightGaussian</a>

2.8. <a href="https://ty424.github.io/AbsGS.github.io/">AbsGS</a> / EfficientGS

... fit \
    --config configs/gsplat-absgrad.yaml \
    --data.path ...

2.9. <a href="https://surfsplatting.github.io/">2D Gaussian Splatting</a>

2.10. <a href="https://jumpat.github.io/SAGA/">Segment Any 3D Gaussians</a>

2.11. Large-scale scene reconstruction with partitioning and LoD

(Comparison images: Baseline vs. Partitioning)

The implementation here references <a href="https://waymo.com/research/block-nerf/">Block-NeRF</a>, <a href="https://vastgaussian.github.io/">VastGaussian</a> and <a href="https://dekuliutesla.github.io/citygs/">CityGaussians</a>.

There is no single script that runs the whole pipeline. Please refer to the contents below for how to reconstruct a large-scale scene.

(1) An example pipeline for the <a href="https://storage.cmusatyalab.org/mega-nerf-data/rubble-pixsfm.tgz">Rubble</a> dataset from <a href="https://meganerf.cmusatyalab.org/">MegaNeRF</a>

(2) Utilize multiple GPUs

2.12. Appearance Model

With the appearance model, reconstruction quality can be improved when your images vary in appearance, such as different exposure, white balance, contrast, or even day-and-night changes.

This model assigns an extra feature vector $\boldsymbol{\ell}^{(g)}$ to each 3D Gaussian and an appearance embedding vector $\boldsymbol{\ell}^{(a)}$ to each appearance group. Both are fed into a lightweight MLP to calculate the color.

$$ \mathbf{C} = f \left ( \boldsymbol{\ell}^{(g)}, \boldsymbol{\ell}^{(a)} \right ) $$

Please refer to <a href="https://github.com/yzslab/gaussian-splatting-lightning/blob/main/internal/renderers/gsplat_appearance_embedding_renderer.py">internal/renderers/gsplat_appearance_embedding_renderer.py</a> for more details.

Baseline vs. New Model:
<video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/3a990247-b57b-4ba8-8e9d-7346a3bd41e3"></video><video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/afeea69f-ed74-4c50-843a-e5d480eb66ef"></video>
<video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/ab89e4cf-80c0-4e99-88bc-3ec5ca047e19"></video>

2.13. <a href="https://ubc-vision.github.io/3dgs-mcmc/">3DGS-MCMC</a>

... fit \
    --config configs/gsplat-mcmc.yaml \
    --model.density.cap_max MAX_NUM_GAUSSIANS \
    ...

MAX_NUM_GAUSSIANS is the maximum number of Gaussians that will be used.

Refer to <a href="https://github.com/ubc-vision/3dgs-mcmc">ubc-vision/3dgs-mcmc</a>, <a href="https://github.com/yzslab/gaussian-splatting-lightning/tree/main/internal/density_controllers/mcmc_density_controller.py">internal/density_controllers/mcmc_density_controller.py</a> and <a href="https://github.com/yzslab/gaussian-splatting-lightning/tree/main/internal/metrics/mcmc_metrics.py">internal/metrics/mcmc_metrics.py</a> for more details.

2.14. Feature distillation

<details> <summary> Click me </summary>

This comes from <a href="https://feature-3dgs.github.io/">Feature 3DGS</a>, but a two-stage optimization is used here rather than joint optimization.

</details>

2.15. In the wild


Introduction

Based on the Appearance Model (2.12.) above, this model can produce a visibility map for every training view, indicating whether each pixel belongs to a transient object.

The idea of the visibility map is a bit like <a href="https://rover-xingyu.github.io/Ha-NeRF/">Ha-NeRF</a>, but rather than using positional encoding for pixel coordinates, a 2D dense grid encoding is used here to accelerate training.

Please refer to <a href="https://rover-xingyu.github.io/Ha-NeRF/">Ha-NeRF</a>, internal/renderers/gsplat_appearance_embedding_visibility_map_renderer.py and internal/metrics/visibility_map_metrics.py for more details.

<b>[NOTE]</b> Though it shows the capability to distinguish pixels belonging to transient objects, it may not be able to remove some artifacts/floaters belonging to transients, and it may also treat under-reconstructed regions as transient.

Usage

pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

Download the PhotoTourism dataset from <a href="https://www.cs.ubc.ca/~kmyi/imw2020/data.html">here</a> and the split file from the "Additional links" <a href="https://nerf-w.github.io/">here</a>. The split file should be placed alongside the dense directory of the PhotoTourism dataset, e.g.:

├── brandenburg_gate
  ├── dense  # colmap database
      ├── images
          ├── ...
      ├── sparse
      ...
  ├── brandenburg.tsv  # split file

[Optional] Downsample the images by 2x: python utils/image_downsample.py data/brandenburg_gate/dense/images --factor 2

python main.py fit \
    --config configs/appearance_embedding_visibility_map_renderer/view_independent-2x_ds.yaml \
    --data.path data/brandenburg_gate \
    -n brandenburg_gate

If you have not downsized the images, remember to add --data.parser.down_sample_factor 1 to the command above, as shown below.
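For example, with original-resolution images the training command above becomes:

python main.py fit \
    --config configs/appearance_embedding_visibility_map_renderer/view_independent-2x_ds.yaml \
    --data.path data/brandenburg_gate \
    --data.parser.down_sample_factor 1 \
    -n brandenburg_gate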

python main.py validate \
   --save_val \
   --val_train \
   --config outputs/brandenburg_gate/lightning_logs/version_0/config.yaml  # you may need to change this path

Then you can find the rendered masks and images in outputs/brandenburg_gate/val.

2.16. New Multiple GPU training strategy

Introduction

This is a bit like a simplified version of <a href="https://daohanlu.github.io/scaling-up-3dgs/">Scaling Up 3DGS</a>.

In the implementation here, Gaussians are stored, projected, and have their colors calculated in a distributed manner, and each GPU rasterizes a whole image for a different camera. There is no pixel-wise distribution currently.

This strategy works with densification enabled.

<b>[NOTE]</b> <details> <summary>Metrics on the MipNeRF360 dataset</summary> One batch per GPU, 30K iterations, no other hyperparameters changed. </details>

Usage

python main.py fit \
    --config configs/distributed.yaml \
    ...

By default, all processes will hold a (redundant) replica of the dataset in memory, which may cause CPU OOM. You can avoid this by adding the option --data.distributed true, so that each process loads a different subset of the dataset.
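For example, combining the distributed config with per-process dataset loading (both options are taken from above):

python main.py fit \
    --config configs/distributed.yaml \
    --data.distributed true \
    ...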

# after training, merge the checkpoints saved by the different processes
python utils/merge_distributed_ckpts.py outputs/TRAINED_MODEL_DIR
# then start the viewer with the merged checkpoint
python viewer.py outputs/TRAINED_MODEL_DIR/checkpoints/MERGED_CHECKPOINT_FILE

2.17. <a href="https://spotlesssplats.github.io/">SpotLessSplats</a>

<b>[NOTE]</b> No utilization-based pruning (Sec. 4.2.3 of the paper) or appearance modeling (Sec. 4.2.4 of the paper) is implemented.

2.18. Depth Regularization with <a href="https://depth-anything-v2.github.io/">Depth Anything V2</a>

This is implemented with reference to <a href="https://repo-sam.inria.fr/fungraph/hierarchical-3d-gaussians/">Hierarchical 3DGS</a>.

Baseline vs. DepthReg vs. DepthReg + Appearance Model:
<video src="https://github.com/user-attachments/assets/138290ca-6c19-4dc0-81c0-f5b1fd7dbb04"></video><video src="https://github.com/user-attachments/assets/4f6b04f7-c889-4d80-b32d-32339fe5ddb7"></video><video src="https://github.com/user-attachments/assets/68c57124-87c0-4eb6-8e2e-4457103beee2"></video>

2.19. <a href="https://r4dl.github.io/StopThePop/">StopThePop</a>

2.20. Scale Regularization

The scales of Gaussians can grow to unreasonable values after densification. For example, some elongated, needle-like Gaussians become almost longer than your scene and appear as artifacts from many viewpoints. This regularization, consisting of a max-scale loss and a scale-ratio loss, can avoid this. Take a look at <a href="https://github.com/yzslab/gaussian-splatting-lightning/blob/main/internal/metrics/scale_regularization_metrics.py">internal/metrics/scale_regularization_metrics.py</a> for more details.

Usage:

python main.py fit \
    --config configs/scale_reg.yaml \
    --model.metric.max_scale 1. \
    ...

--model.metric.max_scale is a scene-specific hyperparameter. The regularization is applied to Gaussians whose scales exceed it. It should be greater than percent_dense * camera_extent, where percent_dense is 0.01 by default and camera_extent is printed as spatial_lr_scale=... at the beginning of training. If you are not sure what value to use, set it to a very large value, e.g. 2048, to disable the max-scale loss.
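For example (the 5.0 below is a hypothetical value for illustration only): if training prints spatial_lr_scale=5.0, then percent_dense * camera_extent = 0.01 * 5.0 = 0.05, so --model.metric.max_scale should be set somewhat above 0.05.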

2.21. <a href="https://humansensinglab.github.io/taming-3dgs/">Taming 3DGS</a>

There are two implementations: one is based on gsplat v1, and the other is the vanilla one. The gsplat v1 based implementation currently does not include "Backpropagation with Per-Splat Parallelization."

3. Evaluation

Per-image metrics will be saved to TRAINING_OUTPUT/metrics as a csv file.

Evaluate on validation set

python main.py validate \
    --config outputs/lego/config.yaml

On test set

python main.py test \
    --config outputs/lego/config.yaml

On train set

python main.py validate \
    --config outputs/lego/config.yaml \
    --val_train

Save the images rendered during evaluation/test

python main.py <validate or test> \
    --config outputs/lego/config.yaml \
    --save_val

Then you can find the images in outputs/lego/<val or test>.

4. Web Viewer

Transform | Camera Path | Edit:
<video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/de1ff3c3-a27a-4600-8c76-ab6551df6fca"></video><video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/3f87243d-d9a1-41e2-9d51-225735925db4"></video><video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/7cf0ccf2-44e9-4fc9-87cc-740b7bbda488"></video>

4.1 Basic usage

python viewer.py TRAINING_OUTPUT_PATH
# e.g.: 
#   python viewer.py outputs/lego/
#   python viewer.py outputs/lego/checkpoints/epoch=300-step=30000.ckpt
#   python viewer.py outputs/lego/baseline/point_cloud/iteration_30000/point_cloud.ply  # only works with VanillaRenderer

4.2 Load multiple models and enable transform options

python viewer.py \
    outputs/garden \
    outputs/lego \
    outputs/Synthetic_NSVF/Palace/point_cloud/iteration_30000/point_cloud.ply \
    --enable_transform

4.3 Load model trained by other implementations

<b>[NOTE]</b> The commands in this section are only designed for outputs of third-party implementations.

python viewer.py \
    Deformable-3D-Gaussians/outputs/lego \
    --vanilla_deformable \
    --reorient disable  # change to enable when loading a real-world scene
python viewer.py \
    4DGaussians/outputs/lego \
    --vanilla_gs4d
# Install `diff-surfel-rasterization` first
pip install git+https://github.com/hbb1/diff-surfel-rasterization.git@e0ed0207b3e0669960cfad70852200a4a5847f61
# Then start viewer
python viewer.py \
    2d-gaussian-splatting/outputs/Truck \
    --vanilla_gs2d
python viewer.py \
    SegAnyGAussians/outputs/Truck \
    --vanilla_seganygs
python viewer.py \
    mip-splatting/outputs/bicycle \
    --vanilla_mip

5. F.A.Q.

<b>Q: </b> The viewer shows my scene in an unexpected orientation; how do I rotate the camera, like the U and O keys in SIBR_viewer?

<b>A: </b> Check the Orientation Control on the right panel, rotate the camera frustum in the scene to the orientation you want, then click Apply Up Direction. <video src="https://github.com/yzslab/gaussian-splatting-lightning/assets/564361/7e9198b5-d853-4800-aac2-1774640a8874"></video>

<br/>

Alternatively, you can click the 'Reset up direction' button, and the viewer will then use your current orientation as the reference.

<b>Q: </b> The web viewer is slow (low fps, far from real-time).

<b>A: </b> This is expected due to the overhead of transferring images over the network. You can get around 10 fps at 1080p resolution, which is enough to inspect the reconstruction quality.