Wonder3D -> Wonder3D++

Single Image to 3D using Cross-Domain Diffusion (CVPR 2024 Highlight).

Now extended to Wonder3D++!

Paper | Project page | Hugging Face Demo | Colab from @camenduru

Wonder3D++ reconstructs highly-detailed textured meshes from a single-view image in only 3 minutes. Wonder3D++ first generates consistent multi-view normal maps with corresponding color images via a cross-domain diffusion model and then leverages a cascaded 3D mesh extraction method to achieve fast and high-quality reconstruction.

Collaborations

Our overarching mission is to enhance the speed, affordability, and quality of 3D AIGC, making the creation of 3D content accessible to all. While significant progress has been achieved in recent years, we acknowledge there is still a substantial journey ahead. We enthusiastically invite you to engage in discussions and explore potential collaborations in any capacity. <span style="color:red">If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (xxlong@connect.hku.hk)</span>.

News

Preparation for inference

Linux System Setup.

conda create -n wonder3d_plus python=3.10
conda activate wonder3d_plus
pip install -r requirements.txt

Prepare the training data

See render_codes/README.md.

Training

Here we provide two training scripts, train_mvdiffusion_mixed.py and train_mvdiffusion_joint.py.

Our Multi-Stage Training Scheme has three stages:

  1. We initially remove the domain switcher and cross-domain attention layers, modifying only the self-attention layers into a multi-view design;
  2. We introduce the domain switcher, fine-tuning the model to generate either multi-view colors or normals from a single-view color image, guided by the domain switcher;
  3. We add cross-domain attention modules into the SD model and only optimize the newly added parameters (see the sketch below).
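
To make the stage-3 setup concrete, here is a minimal PyTorch sketch (not the repository's actual code) of freezing the pretrained SD weights and optimizing only the newly added cross-domain attention parameters; the "cross_domain" substring used to select those modules is an assumed naming convention.

import torch

def stage3_trainable_params(unet: torch.nn.Module):
    # Freeze all pretrained SD weights, then re-enable gradients only for the
    # newly added cross-domain attention parameters.
    trainable = []
    for name, param in unet.named_parameters():
        if "cross_domain" in name:      # hypothetical naming of the new modules
            param.requires_grad_(True)
            trainable.append(param)
        else:
            param.requires_grad_(False)
    return trainable

# e.g. optimizer = torch.optim.AdamW(stage3_trainable_params(unet), lr=1e-5)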

You need to modify root_dir in the config files configs/train/stage1-mix-6views-lvis.yaml and configs/train/stage2-joint-6views-lvis.yaml so that it points to your training data.
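
As a quick sanity check before launching training, a small script like the following can confirm the configured paths exist. This is only a sketch; it assumes root_dir is a top-level key in those YAML files.

import os
import yaml

for cfg_path in ("configs/train/stage1-mix-6views-lvis.yaml",
                 "configs/train/stage2-joint-6views-lvis.yaml"):
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    root_dir = cfg.get("root_dir")  # assumed to be a top-level key in the config
    print(f"{cfg_path}: root_dir={root_dir} exists={os.path.isdir(str(root_dir))}")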

# stage 1:
accelerate launch --config_file 8gpu.yaml train_mvdiffusion_mixed.py --config configs/train/stage1-mixed-wo-switcher.yaml

# stage 2:
accelerate launch --config_file 8gpu.yaml train_mvdiffusion_mixed.py --config configs/train/stage1-mixed-6views-image-normal.yaml

# stage 3:
accelerate launch --config_file 8gpu.yaml train_mvdiffusion_joint_stage3.py --config configs/train/stage3-joint-6views-image-normal.yaml

To train our multi-view enhancement module:

accelerate launch --config_file 8gpu.yaml train_controlnet.py  --config configs/train-controlnet/mv_controlnet_train_joint.yaml

Inference

  1. Make sure you have downloaded the following models.

Wonder3D_plus
|-- ckpts
    |-- mvdiffusion
    |-- mv_controlnet
    |-- scheduler
    |-- vae
    ...

You can also download the models with a Python script:

from huggingface_hub import snapshot_download
snapshot_download(repo_id='flamehaze1115/Wonder3D_plus', local_dir="./ckpts")
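
As an optional sanity check, a short snippet like this verifies that the expected checkpoint subfolders are in place (a sketch; the folder names are taken from the tree above):

import os

expected = ["mvdiffusion", "mv_controlnet", "scheduler", "vae"]
missing = [d for d in expected if not os.path.isdir(os.path.join("ckpts", d))]
print("missing checkpoint subfolders:", missing or "none")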
  2. [Optional] Download the SAM model and put it in the sam_pt folder.

Wonder3D_plus
|-- sam_pt
    |-- sam_vit_h_4b8939.pth
  3. Predict the foreground mask as the alpha channel. We use Clipdrop to segment the foreground object interactively. You may also use rembg to remove the background:
# !pip install rembg
import rembg
from PIL import Image

image = Image.open("example_images/owl.png")  # your input RGB image
result = rembg.remove(image)                  # RGBA result with the predicted alpha mask
result.show()
  4. Run Wonder3D++ in an end-to-end manner. You can then check the results in the folder ./outputs. (We use rembg to remove the backgrounds of the results, but the segmentations are not always perfect. Consider using Clipdrop to get masks for the generated normal maps and color images, since the quality of the masks significantly influences the reconstructed mesh quality.)
python run.py --input_path {Path to input image or directory} \
            --output_path {Your output path} \
            --crop_size {Default to 192. Crop size of the input image; a relative number that assumes the input image resolution is 256.} \
            --camera_type {The projection type of the input image, choose from 'ortho' and 'persp'.} \
            --num_refine {Number of iterative refinement steps, default 2.}

See an example:

python run.py --input_path example_images/owl.png \
--camera_type ortho \
--output_path outputs 

Acknowledgement

We have intensively borrowed code from the following repositories. Many thanks to the authors for sharing their code.

License

Wonder3D++ and Wonder3D are under AGPL-3.0, so any downstream solutions and products (including cloud services) that include Wonder3D code or a trained model (whether pretrained or custom trained) should be open-sourced to comply with the AGPL conditions. If you have any questions about the usage of Wonder3D++, please contact us first.