3DD-TTA

Official PyTorch Implementation of the Paper "Test-Time Adaptation of 3D Point Clouds via Denoising Diffusion Models"

Test-time adaptation (TTA) of 3D point clouds is essential for addressing discrepancies between training and testing samples, particularly in corrupted point clouds like those from LiDAR data. Adapting models online to distribution shifts is crucial, as training for every variation is impractical. Existing methods often fine-tune models using self-supervised learning or pseudo-labeling, which can result in forgetting source domain knowledge.

We propose 3DD-TTA (3D Denoising Diffusion Test-Time Adaptation), a training-free, online 3D TTA method that adapts the input point cloud itself using a diffusion strategy while keeping the source model unchanged. A Variational Autoencoder (VAE) encodes the corrupted point cloud into a latent space, and a denoising diffusion process then refines this latent back toward the source domain.
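At a high level, the adaptation loop can be sketched as follows. This is an illustrative NumPy sketch, not the repo's API: adapt_point_cloud, encode, denoise_step, decode, and the linear DDPM noise schedule are all hypothetical stand-ins for the pretrained VAE and diffusion model used by the actual code.

```python
import numpy as np

def adapt_point_cloud(points, encode, denoise_step, decode, t_star=35, betas=None):
    """Push a corrupted point cloud back toward the source domain via
    truncated diffusion: encode -> partially noise -> denoise -> decode."""
    if betas is None:
        betas = np.linspace(1e-4, 0.02, 1000)   # standard linear DDPM schedule (assumed)
    alphas_bar = np.cumprod(1.0 - betas)        # cumulative signal-retention factors

    z = encode(points)                          # latent code from the (frozen) VAE
    noise = np.random.randn(*z.shape)
    # Forward-diffuse the latent up to an intermediate step t_star
    z_t = np.sqrt(alphas_bar[t_star]) * z + np.sqrt(1.0 - alphas_bar[t_star]) * noise

    for t in range(t_star, -1, -1):             # truncated reverse (denoising) process
        z_t = denoise_step(z_t, t)

    return decode(z_t)                          # adapted point cloud
```

Starting from an intermediate step rather than pure noise (the demo below uses 35 denoising steps) lets the diffusion remove the corruption while preserving the shape.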

<p align="center"> <img src="images/before-after.gif" alt="TTA of Pointcloud perturbed by Impulse Noise using 3DD-TTA"> </p>

3DD-TTA Process

Key Features:

- Training-free: no fine-tuning, self-supervised updates, or pseudo-labeling at test time.
- Online: corrupted inputs are adapted on the fly as they arrive.
- Source knowledge preserved: the input point cloud is adapted via denoising diffusion while the source model weights stay frozen.

Install:

# Create a new Conda environment named "3dd_tta_env" with Python 3.8
conda create --name 3dd_tta_env python=3.8
conda activate 3dd_tta_env

# Install PyTorch, torchvision, and torchaudio with CUDA 12.1 support
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install all dependencies from requirements.txt
pip install -r requirements.txt

# Compile and install the Earth Mover's Distance (EMD) extension (used for point cloud comparison)
cd ./extensions/emd
python setup.py install --user
cd ../..

# Install the PointNet++ operations library (required for point cloud processing)
cd Pointnet2_PyTorch/pointnet2_ops_lib
pip install .
cd ../..

# Install KNN_CUDA (GPU-accelerated k-nearest neighbor functionality)
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl

# Install OpenAI's CLIP model (used for vision-language tasks)
pip install git+https://github.com/openai/CLIP.git 

# Build and package the project
python build_pkg.py
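After the steps above, a quick sanity check can confirm that the key packages resolve. The import names below are my assumptions for each installed package's top-level module:

```python
import importlib.util

def check_modules(mods):
    """Map each module name to whether it can be found by the import system."""
    return {m: importlib.util.find_spec(m) is not None for m in mods}

# Import names assumed for the packages installed above.
status = check_modules(["torch", "torchvision", "pointnet2_ops", "knn_cuda", "clip"])
for name, ok in status.items():
    print(f"{name}: {'ok' if ok else 'MISSING'}")
```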

Data Preparation

Our code supports three datasets: ModelNet40, ShapeNetCore, and ScanObjectNN. To prepare the data, please follow the instructions in the preparation guide.

Obtaining Pre-trained Source Models

You can download the source model (PointMAE) pretrained on ModelNet40, ShapeNet, and ScanObjectNN from here; place the PointMAE checkpoints in the pointnet_ckpts folder. The pretrained diffusion model can be downloaded from the following link; place the diffusion checkpoints in the lion_ckpts folder.

Qualitative Evaluation of the 3DD-TTA Model

To evaluate the model qualitatively, use the following command:

python demo_3dd_tta.py --diff_ckpt=./lion_ckpts/epoch_10999_iters_2100999.pt --denoising_step=35 --dataset_root=./data/modelnet40_c --corruption=background --sample_id=11

This code will:

  1. Load the sample selected by sample_id, corrupted by the specified corruption type.
  2. Adapt it back toward the source domain.

The output will automatically be saved in the ./outputs/qualitative folder.
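To sweep several corruption types, the demo invocation can be scripted. A minimal sketch: demo_cmd is a hypothetical helper, and every corruption name other than background (which appears in the command above) is assumed.

```python
import shlex
import subprocess

# Subset of corruption names; only "background" is confirmed by the command above.
CORRUPTIONS = ["background", "uniform", "gaussian", "impulse", "upsampling"]

def demo_cmd(corruption, sample_id=11, steps=35):
    """Build the qualitative-demo command line (paths as used in this README)."""
    return (
        "python demo_3dd_tta.py"
        " --diff_ckpt=./lion_ckpts/epoch_10999_iters_2100999.pt"
        f" --denoising_step={steps}"
        " --dataset_root=./data/modelnet40_c"
        f" --corruption={corruption}"
        f" --sample_id={sample_id}"
    )

for c in CORRUPTIONS:
    print(demo_cmd(c))
    # subprocess.run(shlex.split(demo_cmd(c)), check=True)  # uncomment to actually run
```

Uncomment the subprocess line to execute each command; the rendered results land in ./outputs/qualitative as described above.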

Quantitative Evaluation of the 3DD-TTA Model

To evaluate the model quantitatively, use the following command:

python main_3dd_tta.py --batch_size=32 --pointmae_ckpt=./pointnet_ckpts/modelnet_jt.pth --diff_ckpt=./lion_ckpts/epoch_10999_iters_2100999.pt --dataset_name=modelnet-c --dataset_root=./data/modelnet40_c

For additional options and details, refer to the quantitative evaluation instructions.

Results:

Our method demonstrates superior generalization across multiple datasets, including ShapeNet, ModelNet40, and ScanObjectNN.

Classification accuracies on ShapeNet-c

| Methods | uni | gauss | back | impu | ups | rbf | rbf-i | den-d | den-i | shear | rot | cut | dist | occ | lidar | Mean |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Point-MAE (src) | 72.5 | 66.4 | 15.0 | 60.6 | 72.8 | 72.6 | 73.4 | 85.2 | 85.8 | 74.1 | 42.8 | 84.3 | 71.7 | 8.4 | 4.3 | 59.3 |
| DUA | 76.1 | 70.1 | 14.3 | 60.9 | 76.2 | 71.6 | 72.9 | 80.0 | 83.8 | 77.1 | 57.5 | 75.0 | 72.1 | 11.9 | 12.1 | 60.8 |
| TTT-Rot | 74.6 | 72.4 | 23.1 | 59.9 | 74.9 | 73.8 | 75.0 | 81.4 | 82.0 | 69.2 | 49.1 | 79.9 | 72.7 | 14.0 | 12.0 | 60.9 |
| SHOT | 44.8 | 42.5 | 12.1 | 37.6 | 45.0 | 43.7 | 44.2 | 48.4 | 49.4 | 45.0 | 32.6 | 46.3 | 39.1 | 6.2 | 5.9 | 36.2 |
| T3A | 70.0 | 60.5 | 6.5 | 40.7 | 67.8 | 67.2 | 68.5 | 79.5 | 79.9 | 72.7 | 42.9 | 79.1 | 66.8 | 7.7 | 5.6 | 54.4 |
| TENT | 44.5 | 42.9 | 12.4 | 38.0 | 44.6 | 43.3 | 44.3 | 48.7 | 49.4 | 45.7 | 34.8 | 48.6 | 43.0 | 10.0 | 10.9 | 37.4 |
| MATE-S | 77.8 | 74.7 | 4.3 | 66.2 | 78.6 | 76.3 | 75.3 | 86.1 | 86.6 | 79.2 | 56.1 | 84.1 | 76.1 | 12.3 | 13.1 | 63.1 |
| 3DD-TTA (ours) | 81.6 | 80.7 | 77.6 | 77.2 | 85.4 | 76.5 | 75.3 | 86.5 | 88.2 | 76.3 | 50.4 | 85.4 | 76.5 | 14.9 | 14.2 | 69.8 |
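As a quick consistency check, the Mean column is the simple average of the 15 per-corruption accuracies; for example, for the 3DD-TTA row on ShapeNet-c:

```python
# Per-corruption accuracies for 3DD-TTA on ShapeNet-c (values from the table above).
accs = [81.6, 80.7, 77.6, 77.2, 85.4, 76.5, 75.3, 86.5,
        88.2, 76.3, 50.4, 85.4, 76.5, 14.9, 14.2]
mean = sum(accs) / len(accs)
print(round(mean, 1))  # → 69.8, matching the Mean column
```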

Classification accuracies on ModelNet40-c

| Methods | uni | gauss | back | impu | ups | rbf | rbf-i | den-d | den-i | shear | rot | cut | dist | occ | lidar | Mean |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Point-MAE (src) | 62.4 | 57.0 | 32.0 | 58.8 | 72.1 | 61.4 | 64.2 | 75.1 | 80.8 | 67.6 | 31.3 | 70.4 | 64.8 | 36.2 | 29.1 | 57.6 |
| DUA | 65.0 | 58.5 | 14.7 | 48.5 | 68.8 | 62.8 | 63.2 | 62.1 | 66.2 | 68.8 | 46.2 | 53.8 | 64.7 | 41.2 | 36.5 | 54.7 |
| TTT-Rot | 61.3 | 58.3 | 34.5 | 48.9 | 66.7 | 63.6 | 63.9 | 59.8 | 68.6 | 55.2 | 27.3 | 54.6 | 64.0 | 40.0 | 29.1 | 53.0 |
| SHOT | 29.6 | 28.2 | 9.8 | 25.4 | 32.7 | 30.3 | 30.1 | 30.9 | 31.2 | 32.1 | 22.8 | 27.3 | 29.4 | 20.8 | 18.6 | 26.6 |
| T3A | 64.1 | 62.3 | 33.4 | 65.0 | 75.4 | 63.2 | 66.7 | 57.4 | 63.0 | 72.7 | 32.8 | 54.4 | 67.7 | 39.1 | 18.3 | 55.7 |
| TENT | 29.2 | 28.7 | 10.1 | 25.1 | 33.1 | 30.3 | 29.1 | 30.4 | 31.5 | 31.8 | 22.7 | 27.0 | 28.6 | 20.7 | 19.0 | 26.5 |
| MATE-S | 75.0 | 71.1 | 27.5 | 67.5 | 78.7 | 69.5 | 72.0 | 79.1 | 84.5 | 75.4 | 44.4 | 73.6 | 72.9 | 39.7 | 34.2 | 64.3 |
| 3DD-TTA (ours) | 77.5 | 79.1 | 49.9 | 80.3 | 81.8 | 63.8 | 66.9 | 78.5 | 84.7 | 63.7 | 33.4 | 74.7 | 65.2 | 39.9 | 42.2 | 66.1 |