All in Tokens: Unifying Output Space of Visual Tasks via Soft Token

By Jia Ning*, Chen Li*, Zheng Zhang*, Zigang Geng, Qi Dai, Kun He, Han Hu

Introduction

AiT, initially described on arXiv, is a framework that unifies the output space of visual tasks. We demonstrate a single unified model that simultaneously handles two typical visual tasks, instance segmentation and depth estimation, which have discrete/fixed-length and continuous/varied-length outputs, respectively. We propose several new techniques that take into account the particularity of visual tasks: 1) Soft tokens. We employ soft tokens to represent the task output. Unlike hard tokens in the common VQ-VAE, which are assigned one-hot to discrete codebooks/vocabularies, soft tokens are assigned softly to the codebook embeddings. Soft tokens can improve the accuracy of both next-token inference and decoding of the task output; 2) Mask augmentation. Many visual tasks have corrupted, undefined, or invalid values in their label annotations, e.g., occluded areas of depth maps. We show that a mask augmentation technique can greatly benefit these tasks. With these new techniques and other designs, we show that the proposed general-purpose task solver can perform both instance segmentation and depth estimation well. In particular, we achieve 0.275 RMSE on the specific task of NYUv2 depth estimation, setting a new record on this benchmark.
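
The difference between hard and soft tokens, and the idea behind mask augmentation, can be illustrated with a minimal NumPy sketch. This is not the code in this repository: the function names (`hard_tokens`, `soft_tokens`, `mask_augment`), the array shapes, and the pixel-wise masking scheme are all assumptions chosen for illustration.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def hard_tokens(logits, codebook):
    """One-hot assignment: each position takes its single best codebook entry.

    logits:   (num_positions, codebook_size) scores over codebook entries
    codebook: (codebook_size, embed_dim) embedding table
    """
    return codebook[logits.argmax(axis=-1)]

def soft_tokens(logits, codebook):
    """Soft assignment: each position is a probability-weighted mix of entries."""
    return softmax(logits) @ codebook

def mask_augment(label, valid, mask_ratio, rng):
    """Hypothetical pixel-wise mask augmentation (an assumed simplification).

    Randomly hides a fraction of the valid label pixels so that a model
    trained on the result has to cope with missing annotations.
    Returns the masked label (hidden pixels set to 0) and the new validity mask.
    """
    keep = rng.random(label.shape) >= mask_ratio
    new_valid = valid & keep
    return np.where(new_valid, label, 0.0), new_valid

# Toy example: 3 codebook entries of dimension 2, 2 token positions.
codebook = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
logits = np.array([[0.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
hard = hard_tokens(logits, codebook)  # picks exactly one entry per position
soft = soft_tokens(logits, codebook)  # blends entries by predicted probability
```

When the predicted distribution is sharply peaked, the soft token is close to the hard token; when it is spread out, the soft token interpolates between codebook embeddings instead of committing to a single entry.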

Results and Models

Results on COCO instance segmentation

| <div style="width: 100pt">Model</div> | Box AP | Mask AP | VQ-VAE Model | Task-Solver Model |
| --- | --- | --- | --- | --- |
| AiT(SwinV2-B) | 43.3 | 34.2 | vqvae_insseg.pt | model |
| AiT(SwinV2-B) w/o soft token | 43.6 | 31.1 (-3.1) | vqvae_insseg.pt | model |

Results on NYUv2 depth estimation

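| <div style="width: 100pt">Model</div> | D1 | D2 | D3 | Abs Rel | RMSE | Log10 | VQ-VAE <br> Model | Task-Solver <br> Model |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| AiT(SwinV2-B) | 0.934 | 0.991 | 0.998 | 0.087 | 0.305 | 0.037 | vqvae_depth.pt | model |
| AiT-P(SwinV2-B) | 0.940 | 0.992 | 0.998 | 0.085 | 0.301 | 0.036 | vqvae_depth.pt | model |
| AiT(SwinV2-B) w/o soft token | 0.932 | 0.991 | 0.998 | 0.089 | 0.318 | 0.038 | vqvae_depth.pt | model |
| AiT(SwinV2-L) | 0.949 | 0.993 | 0.999 | 0.079 | 0.284 | 0.034 | vqvae_depth.pt | model |
| AiT-P(SwinV2-L) | 0.954 | 0.994 | 0.999 | 0.076 | 0.275 | 0.033 | vqvae_depth.pt | model |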
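
For readers unfamiliar with the column names, the sketch below computes the depth metrics as they are commonly defined in the monocular depth estimation literature (D1/D2/D3 as threshold accuracies at 1.25, 1.25², 1.25³). It is an illustration under that assumption, not the evaluation code of this repository; the function name `depth_metrics` and the rule "ground truth > 0 means valid" are hypothetical choices.

```python
import numpy as np

def depth_metrics(pred, gt, valid=None):
    """Common depth-estimation metrics over valid pixels.

    pred, gt: arrays of predicted / ground-truth depth (pred must be > 0).
    valid:    optional boolean mask; defaults to gt > 0 (an assumption).
    """
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if valid is None:
        valid = gt > 0
    pred, gt = pred[valid], gt[valid]

    # Threshold accuracy: fraction of pixels whose ratio error is below 1.25^k.
    ratio = np.maximum(pred / gt, gt / pred)
    return {
        "d1": float((ratio < 1.25).mean()),
        "d2": float((ratio < 1.25 ** 2).mean()),
        "d3": float((ratio < 1.25 ** 3).mean()),
        # Mean absolute error relative to the ground-truth depth.
        "abs_rel": float((np.abs(pred - gt) / gt).mean()),
        # Root mean squared error in the depth unit of the dataset.
        "rmse": float(np.sqrt(((pred - gt) ** 2).mean())),
        # Mean absolute error between log10 depths.
        "log10": float(np.abs(np.log10(pred) - np.log10(gt)).mean()),
    }

# Toy example: one pixel is off by a factor of 2, the other is exact.
metrics = depth_metrics([2.0, 2.0], [1.0, 2.0])
```

Higher is better for D1, D2, and D3; lower is better for Abs Rel, RMSE, and Log10.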

Joint training results on COCO and NYUv2

| <div style="width: 100pt">Model</div> | Box AP | Mask AP | RMSE | VQ-VAE Model | Task-Solver <br> Model |
| --- | --- | --- | --- | --- | --- |
| AiT(SwinV2-B) | 42.2 | 34.1 | 0.310 | vqvae_depth.pt/vqvae_insseg.pt | model |

Usage

Installation

We recommend pytorch>=1.10; the other required packages are listed in requirements.txt. To install boundary-iou-api, use the following command:

git clone https://github.com/bowenc0221/boundary-iou-api && cd boundary-iou-api && pip install -e .

Data/Pre-training model Preparation

  1. Download the NYU Depth V2 dataset, the COCO dataset, and our preprocessed box-cropped binary instance masks (named maskcoco), then organize the data according to the following directory structure:
AiT
├── ait
├── vae
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   ├── maskcoco
│   ├── nyu_depth_v2
  2. Create the data links using the following commands (run from the AiT root directory):
ln -s ../data ait/data
ln -s ../data vae/data
  3. Download the pre-trained backbone models swin_v2_base_densesimmim.pth and swin_v2_large_densesimmim.pth.
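
Before launching a long training job, it can help to confirm that the directory layout matches the tree above. The following is a small, hypothetical helper (not part of this repository); `EXPECTED_DIRS` and `missing_dirs` are names introduced here, and the list simply mirrors the directories shown in step 1.

```python
from pathlib import Path

# Directories from the layout shown above, relative to the AiT root.
EXPECTED_DIRS = [
    "ait",
    "vae",
    "data/coco/annotations",
    "data/coco/train2017",
    "data/coco/val2017",
    "data/coco/test2017",
    "data/maskcoco",
    "data/nyu_depth_v2",
]

def missing_dirs(root):
    """Return the expected directories that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED_DIRS if not (root / rel).is_dir()]
```

Calling `missing_dirs(".")` from the AiT root returns an empty list when everything is in place, and otherwise lists the relative paths that still need to be created or downloaded.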

Training

Training VQ-VAE on depth estimation:

cd vae
python -m torch.distributed.launch --nproc_per_node=${N_GPUS} train_depth_vqvae_dist.py  configs/depth/ait_depth_vqvae.py --cfg-options <custom-configs>

Training VQ-VAE on instance segmentation:

cd vae
python -m torch.distributed.launch --nproc_per_node=${N_GPUS} train_insseg_vqvae_dist.py  configs/insseg/ait_insseg_vqvae.py --cfg-options <custom-configs>

Training task-solver on depth estimation:

cd ait

# Train auto-regressive model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt # for AR training

# Train parallel model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_parallel_depthonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt # for parallel training

Training task-solver on object detection

cd ait
python -m torch.distributed.launch --nproc_per_node=16 --nnodes=2 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} code/train.py configs/swinv2b_640reso_detonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth

Note: To save training cost, we use the pre-trained object detection model to initialize the instance segmentation and joint-training models. Please download the pre-trained model (ait_det_swinv2b_wodec.pth) before training in the instance segmentation or joint-training setting.

Training task-solver on instance segmentation

python -m torch.distributed.launch --nproc_per_node=16 code/train.py configs/swinv2b_640reso_inssegonly.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt load_from=ait_det_swinv2b_wodec.pth

Joint training on instance segmentation and depth estimation

python -m torch.distributed.launch --nproc_per_node=16 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} code/train.py configs/swinv2b_640reso_joint.py --cfg-options model.backbone.init_cfg.checkpoint=swin_v2_base_densesimmim.pth model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt load_from=ait_det_swinv2b_wodec.pth
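
The training commands above all share one shape: launcher flags, then the script and config, then `--cfg-options` followed by `key=value` overrides. When scripting several runs, a small wrapper can assemble them consistently. The sketch below is a hypothetical convenience helper, not part of this repository; `build_launch_command` is a name introduced here, and it only emits flags that already appear in this README (including the `--eval` flag used in the Inference section).

```python
import shlex

def build_launch_command(script, config, nproc_per_node, cfg_options=None,
                         nnodes=1, node_rank=None, master_addr=None,
                         master_port=None, eval_checkpoint=None):
    """Assemble a torch.distributed.launch command as a list of arguments."""
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={nproc_per_node}"]
    if nnodes > 1:
        # Multi-node runs need the rendezvous details for every node.
        if node_rank is None or master_addr is None or master_port is None:
            raise ValueError("multi-node runs need node_rank, master_addr, master_port")
        cmd += [f"--nnodes={nnodes}", f"--node_rank={node_rank}",
                f"--master_addr={master_addr}", f"--master_port={master_port}"]
    cmd += [script, config]
    if cfg_options:
        # Overrides are passed as space-separated key=value pairs.
        cmd.append("--cfg-options")
        cmd += [f"{key}={value}" for key, value in cfg_options.items()]
    if eval_checkpoint is not None:
        cmd += ["--eval", eval_checkpoint]
    return cmd

# Reproduces the auto-regressive depth-only training command shown earlier.
depth_cmd = build_launch_command(
    "code/train.py",
    "configs/swinv2b_480reso_depthonly.py",
    nproc_per_node=8,
    cfg_options={
        "model.backbone.init_cfg.checkpoint": "swin_v2_base_densesimmim.pth",
        "model.task_heads.depth.vae_cfg.pretrained": "vqvae_depth.pt",
    },
)
depth_cmd_str = shlex.join(depth_cmd)
```

The resulting list can be passed to `subprocess.run(depth_cmd, cwd="ait", check=True)`, which avoids shell-quoting issues; `shlex.join` is used here only to produce a printable form.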

Inference

Evaluate on depth estimation

cd ait

# Evaluating auto-regressive model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

# Evaluating parallel model
python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_480reso_parallel_depthonly.py --cfg-options model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

Evaluate on instance segmentation

cd ait

python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_640reso_inssegonly.py --cfg-options model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt --eval <model_checkpoint>

Evaluate on both depth estimation and instance segmentation

cd ait

python -m torch.distributed.launch --nproc_per_node=8 code/train.py configs/swinv2b_640reso_joint.py --cfg-options model.task_heads.insseg.vae_cfg.pretrained=vqvae_insseg.pt model.task_heads.depth.vae_cfg.pretrained=vqvae_depth.pt --eval <model_checkpoint>

Citation

@article{ning2023all,
  title={All in Tokens: Unifying Output Space of Visual Tasks via Soft Token},
  author={Ning, Jia and Li, Chen and Zhang, Zheng and Geng, Zigang and Dai, Qi and He, Kun and Hu, Han},
  journal={arXiv preprint arXiv:2301.02229},
  year={2023}
}