DeepLabv3Plus-Pytorch
DeepLabv3, DeepLabv3+ and pretrained weights on VOC & Cityscapes.
Available Architectures
Specify the model architecture with '--model ARCH_NAME' and set the output stride with '--output_stride OUTPUT_STRIDE'.
DeepLabV3 | DeepLabV3+ |
---|---|
deeplabv3_resnet50 | deeplabv3plus_resnet50 |
deeplabv3_resnet101 | deeplabv3plus_resnet101 |
deeplabv3_mobilenet | deeplabv3plus_mobilenet |
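As a rough illustration of what '--output_stride' controls: the output stride is the ratio between the input resolution and the resolution of the final backbone feature map, so a smaller value gives denser features at a higher compute cost. The helper below estimates that feature map size. The name feature_map_size is hypothetical and not part of this repo, and the rounding rule assumes the usual padded stride-2 layers found in common ResNet and MobileNet implementations.

```python
import math


def feature_map_size(input_size: int, output_stride: int) -> int:
    """Approximate side length of the final backbone feature map.

    Assumes every stride-2 layer is padded so that its output has
    ceil(input / 2) pixels per side; chaining such layers gives
    ceil(input / output_stride) overall.
    """
    return math.ceil(input_size / output_stride)


for stride in (8, 16):
    side = feature_map_size(513, stride)
    print(f"crop 513, output stride {stride}: {side} x {side} feature map")
```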
Available models: Dropbox, Tencent Weiyun
Load the pretrained model:
model.load_state_dict( torch.load( CKPT_PATH )['model_state'] )
Visualize segmentation outputs:
outputs = model(images)
preds = outputs.max(1)[1].detach().cpu().numpy()
colorized_preds = val_dst.decode_target(preds).astype('uint8') # To RGB images, (N, H, W, 3), ranged 0~255, numpy array
# Do whatever you like here with the colorized segmentation maps
colorized_preds = Image.fromarray(colorized_preds[0]) # to PIL Image
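For reference, the colorization step can be understood as a palette lookup. The sketch below builds the standard Pascal VOC color map with the usual bit-shuffling rule and applies it to a tiny label map in pure Python. It is an assumption that val_dst.decode_target behaves equivalently for VOC; voc_colormap and colorize are illustrative names, not functions from this repo.

```python
def voc_colormap(num_colors=256):
    """Build the standard Pascal VOC palette as a list of (R, G, B) tuples."""
    palette = []
    for label in range(num_colors):
        r = g = b = 0
        c = label
        for j in range(8):
            # Spread the bits of the label over the three channels,
            # filling each channel from its most significant bit down.
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        palette.append((r, g, b))
    return palette


def colorize(label_map, palette):
    """Map an H x W grid of class indices to an H x W grid of RGB tuples."""
    return [[palette[label] for label in row] for row in label_map]


palette = voc_colormap()
preds = [[0, 1], [15, 255]]  # background, aeroplane, person, void label
print(colorize(preds, palette))
```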
Atrous Separable Convolution
Atrous Separable Convolution is supported in this repo. We provide a simple tool, network.convert_to_separable_conv, to convert nn.Conv2d to AtrousSeparableConvolution. Please run main.py with '--separable_conv' if it is required. See 'main.py' and 'network/_deeplab.py' for more details.
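To see why the conversion can be worthwhile, the sketch below compares weight counts for a standard convolution and a depthwise-separable replacement (a depthwise k x k convolution followed by a 1 x 1 pointwise convolution). Biases and batch-norm parameters are ignored, and the 3 x 3, 256-channel layer is only an example size, not a figure taken from this repo.

```python
def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Weights in a standard k x k convolution (bias ignored)."""
    return kernel * kernel * in_ch * out_ch


def separable_conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Weights in a depthwise k x k conv followed by a 1 x 1 pointwise conv."""
    depthwise = kernel * kernel * in_ch  # one k x k filter per input channel
    pointwise = in_ch * out_ch           # 1 x 1 conv mixes the channels
    return depthwise + pointwise


standard = conv_params(256, 256, 3)
separable = separable_conv_params(256, 256, 3)
print(f"standard: {standard}, separable: {separable}, "
      f"ratio: {standard / separable:.1f}x")
```

The dilation rate does not appear in either count, because atrous convolution only spaces out the kernel taps without adding weights.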
Datasets
Results
Performance on Pascal VOC2012 Aug (21 classes, 513 x 513)
Training: 513x513 random crop
Validation: 513x513 center crop
Model | Batch Size | FLOPs | train/val OS | mIoU | Dropbox | Tencent Weiyun |
---|---|---|---|---|---|---|
DeepLabV3-MobileNet | 16 | 6.0G | 16/16 | 0.701 | Download | Download |
DeepLabV3-ResNet50 | 16 | 51.4G | 16/16 | 0.769 | Download | Download |
DeepLabV3-ResNet101 | 16 | 72.1G | 16/16 | 0.773 | Download | Download |
DeepLabV3Plus-MobileNet | 16 | 17.0G | 16/16 | 0.711 | Download | Download |
DeepLabV3Plus-ResNet50 | 16 | 62.7G | 16/16 | 0.772 | Download | Download |
DeepLabV3Plus-ResNet101 | 16 | 83.4G | 16/16 | 0.783 | Download | Download |
Performance on Cityscapes (19 classes, 1024 x 2048)
Training: 768x768 random crop
Validation: 1024x2048
Model | Batch Size | FLOPs | train/val OS | mIoU | Dropbox | Tencent Weiyun |
---|---|---|---|---|---|---|
DeepLabV3Plus-MobileNet | 16 | 135G | 16/16 | 0.721 | Download | Download |
DeepLabV3Plus-ResNet101 | 16 | N/A | 16/16 | 0.762 | Download | Coming Soon |
Segmentation Results on Pascal VOC2012 (DeepLabv3Plus-MobileNet)
<div> <img src="samples/1_image.png" width="20%"> <img src="samples/1_target.png" width="20%"> <img src="samples/1_pred.png" width="20%"> <img src="samples/1_overlay.png" width="20%"> </div> <div> <img src="samples/23_image.png" width="20%"> <img src="samples/23_target.png" width="20%"> <img src="samples/23_pred.png" width="20%"> <img src="samples/23_overlay.png" width="20%"> </div> <div> <img src="samples/114_image.png" width="20%"> <img src="samples/114_target.png" width="20%"> <img src="samples/114_pred.png" width="20%"> <img src="samples/114_overlay.png" width="20%"> </div>
Segmentation Results on Cityscapes (DeepLabv3Plus-MobileNet)
<div> <img src="samples/city_1_target.png" width="45%"> <img src="samples/city_1_overlay.png" width="45%"> </div> <div> <img src="samples/city_6_target.png" width="45%"> <img src="samples/city_6_overlay.png" width="45%"> </div>
Visualization of training
Quick Start
1. Requirements
pip install -r requirements.txt
2. Prepare Datasets
Pascal VOC
You can run main.py with the "--download" option to download and extract the Pascal VOC dataset. The default path is './datasets/data':
/datasets
    /data
        /VOCdevkit
            /VOC2012
                /SegmentationClass
                /JPEGImages
                ...
            ...
        /VOCtrainval_11-May-2012.tar
        ...
Pascal VOC trainaug (Recommended!!)
See chapter 4 of [2]
The original dataset contains 1464 (train), 1449 (val), and 1456 (test) pixel-level annotated images. We augment the dataset by the extra annotations provided by [76], resulting in 10582 (trainaug) training images. The performance is measured in terms of pixel intersection-over-union averaged across the 21 classes (mIOU).
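The metric described above can be sketched in a few lines. This pure-Python version accumulates a confusion matrix and averages the per-class IoU. It is a simplified illustration, not the evaluation code used by this repo; treating label 255 as an ignored void label and skipping classes that appear in neither the prediction nor the ground truth are assumptions.

```python
def mean_iou(preds, targets, num_classes, ignore_index=255):
    """Mean intersection-over-union from flat lists of class indices."""
    # confusion[t][p] counts pixels with ground truth t predicted as p
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue  # skip unlabeled / void pixels
        confusion[t][p] += 1

    ious = []
    for c in range(num_classes):
        intersection = confusion[c][c]
        gt_pixels = sum(confusion[c])
        pred_pixels = sum(row[c] for row in confusion)
        union = gt_pixels + pred_pixels - intersection
        if union > 0:  # skip classes absent from both prediction and target
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0


targets = [0, 0, 1, 1, 2, 255]
preds = [0, 1, 1, 1, 2, 2]
print(f"mIoU: {mean_iou(preds, targets, num_classes=3):.4f}")
```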
./datasets/data/train_aug.txt includes names of 10582 trainaug images (val images are excluded). You need to download additional labels from Dropbox or Tencent Weiyun. Those labels come from DrSleep's repo.
Extract trainaug labels (SegmentationClassAug) to the VOC2012 directory.
/datasets
    /data
        /VOCdevkit
            /VOC2012
                /SegmentationClass
                /SegmentationClassAug  # <= the trainaug labels
                /JPEGImages
                ...
            ...
        /VOCtrainval_11-May-2012.tar
        ...
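Before training, it can help to confirm that the extracted files match the layout above. The following sketch checks for the expected sub-directories with pathlib. The function name missing_voc_dirs is hypothetical and not part of this repo, and the default root assumes the './datasets/data' location mentioned earlier.

```python
from pathlib import Path

# Sub-directories expected under VOC2012, following the layout shown above.
EXPECTED_DIRS = ("JPEGImages", "SegmentationClass", "SegmentationClassAug")


def missing_voc_dirs(data_root="./datasets/data"):
    """Return the expected VOC2012 sub-directories that do not exist."""
    voc_root = Path(data_root) / "VOCdevkit" / "VOC2012"
    return [name for name in EXPECTED_DIRS if not (voc_root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_voc_dirs()
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("VOC2012 trainaug layout looks complete.")
```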
3. Train on Pascal VOC2012 Aug
Visualize training (Optional)
Start the visdom server for visualization. Please remove '--enable_vis' if visualization is not needed.
# Run visdom server on port 28333
visdom -port 28333
Train with OS=16
Run main.py with "--year 2012_aug" to train your model on Pascal VOC2012 Aug. You can also parallelize training across 4 GPUs with '--gpu_id 0,1,2,3'.
Note: There is no SyncBN in this repo, so training with multiple GPUs may degrade performance. See PyTorch-Encoding for more details about SyncBN.
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16
Continue training
Run main.py with '--continue_training' to restore the state_dict of optimizer and scheduler from YOUR_CKPT.
python main.py ... --ckpt YOUR_CKPT --continue_training
4. Test
Results will be saved at ./results.
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16 --ckpt checkpoints/best_deeplabv3plus_mobilenet_voc_os16.pth --test_only --save_val_results
Cityscapes
1. Download Cityscapes and extract it to 'datasets/data/cityscapes'
/datasets
    /data
        /cityscapes
            /gtFine
            /leftImg8bit
2. Train your model on Cityscapes
python main.py --model deeplabv3plus_mobilenet --dataset cityscapes --enable_vis --vis_port 28333 --gpu_id 0 --lr 0.1 --crop_size 768 --batch_size 16 --output_stride 16 --data_root ./datasets/data/cityscapes
Reference
[1] Rethinking Atrous Convolution for Semantic Image Segmentation
[2] Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation