Segmentation Models PyTorch 3D

Python library with Neural Networks for Volume (3D) Segmentation based on PyTorch.

This library is based on the popular Segmentation Models PyTorch library for 2D images; most of its documentation applies directly here.

Installation
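
Install the latest release from PyPI (assuming the package name follows the usual convention of the import name with hyphens; check the project repository if the command fails):

pip install segmentation-models-pytorch-3d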

Quick start

A segmentation model is just a PyTorch nn.Module, which can be created as easily as:

import segmentation_models_pytorch_3d as smp
import torch

model = smp.Unet(
    encoder_name="efficientnet-b0", # choose encoder, e.g. resnet34
    in_channels=1,                  # model input channels (1 for gray-scale volumes, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

# Shape of input (B, C, H, W, D). B - batch size, C - channels, H - height, W - width, D - depth
res = model(torch.randn(4, 1, 64, 64, 64)) 
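
For the call above, res has shape (4, 3, 64, 64, 64): the Unet decoder restores the input resolution, and the model returns raw logits with one channel per class, as in the 2D SMP API. Since the model is an ordinary nn.Module, it drops into a standard PyTorch training loop. A minimal sketch using a plain PyTorch loss (the random tensors below stand in for a real dataset):

import segmentation_models_pytorch_3d as smp
import torch

model = smp.Unet(encoder_name="efficientnet-b0", in_channels=1, classes=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()  # expects raw logits and integer class labels

volumes = torch.randn(2, 1, 64, 64, 64)        # (B, C, H, W, D)
labels = torch.randint(0, 3, (2, 64, 64, 64))  # one class index per voxel

logits = model(volumes)                        # (B, classes, H, W, D)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()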

Models

Architectures

Encoders

The following is a list of encoders supported by the library. Select the appropriate encoder family and click to expand its table, then choose a specific encoder and its pre-trained weights (the encoder_name and encoder_weights parameters). A usage example follows the tables.

<details> <summary style="margin-left: 25px;">ResNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| resnet18 | imagenet / ssl / swsl | 11M |
| resnet34 | imagenet | 21M |
| resnet50 | imagenet / ssl / swsl | 23M |
| resnet101 | imagenet | 42M |
| resnet152 | imagenet | 58M |

</div> </details> <details> <summary style="margin-left: 25px;">ResNeXt</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| resnext50_32x4d | imagenet / ssl / swsl | 22M |
| resnext101_32x4d | ssl / swsl | 42M |
| resnext101_32x8d | imagenet / instagram / ssl / swsl | 86M |
| resnext101_32x16d | instagram / ssl / swsl | 191M |
| resnext101_32x32d | instagram | 466M |
| resnext101_32x48d | instagram | 826M |

</div> </details> <details> <summary style="margin-left: 25px;">SE-Net</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| senet154 | imagenet | 113M |
| se_resnet50 | imagenet | 26M |
| se_resnet101 | imagenet | 47M |
| se_resnet152 | imagenet | 64M |
| se_resnext50_32x4d | imagenet | 25M |
| se_resnext101_32x4d | imagenet | 46M |

</div> </details> <details> <summary style="margin-left: 25px;">DenseNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| densenet121 | imagenet | 6M |
| densenet169 | imagenet | 12M |
| densenet201 | imagenet | 18M |
| densenet161 | imagenet | 26M |

</div> </details> <details> <summary style="margin-left: 25px;">EfficientNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| efficientnet-b0 | imagenet | 4M |
| efficientnet-b1 | imagenet | 6M |
| efficientnet-b2 | imagenet | 7M |
| efficientnet-b3 | imagenet | 10M |
| efficientnet-b4 | imagenet | 17M |
| efficientnet-b5 | imagenet | 28M |
| efficientnet-b6 | imagenet | 40M |
| efficientnet-b7 | imagenet | 63M |

</div> </details> <details> <summary style="margin-left: 25px;">DPN</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| dpn68 | imagenet | 11M |
| dpn68b | imagenet+5k | 11M |
| dpn92 | imagenet+5k | 34M |
| dpn98 | imagenet | 58M |
| dpn107 | imagenet+5k | 84M |
| dpn131 | imagenet | 76M |

</div> </details> <details> <summary style="margin-left: 25px;">VGG</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|---|---|---|
| vgg11 | imagenet | 9M |
| vgg11_bn | imagenet | 9M |
| vgg13 | imagenet | 9M |
| vgg13_bn | imagenet | 9M |
| vgg16 | imagenet | 14M |
| vgg16_bn | imagenet | 14M |
| vgg19 | imagenet | 20M |
| vgg19_bn | imagenet | 20M |

</div> </details> <details> <summary style="margin-left: 25px;">Mix Vision Transformer</summary> <div style="margin-left: 25px;">

Backbone from SegFormer pretrained on ImageNet! It can be used with the other decoders in the package: combine Mix Vision Transformer with Unet, FPN and others!

| Encoder | Weights | Params, M |
|---|---|---|
| mit_b0 | imagenet | 3M |
| mit_b1 | imagenet | 13M |
| mit_b2 | imagenet | 24M |
| mit_b3 | imagenet | 44M |
| mit_b4 | imagenet | 60M |
| mit_b5 | imagenet | 81M |

</div> </details> <details> <summary style="margin-left: 25px;">MobileOne</summary> <div style="margin-left: 25px;">

Apple's "sub-one-ms" backbone pretrained on ImageNet! It can be used with all decoders.

Note: in the official GitHub repo the s0 variant has additional num_conv_branches, leading to more params than s1.

| Encoder | Weights | Params, M |
|---|---|---|
| mobileone_s0 | imagenet | 4.6M |
| mobileone_s1 | imagenet | 4.0M |
| mobileone_s2 | imagenet | 6.5M |
| mobileone_s3 | imagenet | 8.8M |
| mobileone_s4 | imagenet | 13.6M |

</div> </details>
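
For example, to pair an encoder from the tables with its pre-trained weights (a minimal sketch; any encoder_name / encoder_weights pair listed above can be substituted):

import segmentation_models_pytorch_3d as smp

model = smp.Unet(
    encoder_name="resnet34",       # from the ResNet table above
    encoder_weights="imagenet",    # weights listed in the same row
    in_channels=1,
    classes=3,
)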

Timm 3D encoders

Encoders from the timm_3d library are also supported; see that repository for the full list. To use them, add the tu- prefix to the encoder name. Example:

import segmentation_models_pytorch_3d as smp

encoder_name = 'tu-maxvit_base_tf_224.in21k'
model = smp.Unet(
    encoder_name=encoder_name,
    encoder_weights=None,  # train from scratch; no pre-trained weights loaded
    in_channels=3,
    classes=1,
)

Notes for 3D version

Input size

The recommended input size for a backbone can be calculated as K = N ** (2/3), where N is the input size of the same model in its 2D variant. This keeps the voxel count of a K×K×K volume close to the pixel count of an N×N image (K^3 = N^2).

For example, for N = 512, K = 64 exactly; for N = 224, K ≈ 37, which is rounded down to 32.
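
This rule of thumb is easy to script. A minimal sketch (the power-of-two rounding is an assumption inferred from the examples above, not a rule enforced by the library):

import math

def recommended_3d_size(n_2d: int) -> int:
    # K = N ** (2/3) keeps the voxel count of a K^3 volume
    # close to the pixel count of an N^2 image.
    k = n_2d ** (2 / 3)
    # Round down to the nearest power of two (assumed convention).
    return 2 ** int(math.log2(k))

print(recommended_3d_size(224))  # 32
print(recommended_3d_size(512))  # 64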

Strides

The typical stride in the 2D case is 2 for both H and W, applied once per encoder stage (in almost all cases 5 stages), so a (224, 224) input is reduced to (7, 7) at the final layers. In the 3D case, inputs can be very large and often anisotropic, so it is sometimes useful to control the stride of each dimension independently. For this you can use the strides argument, whose default value is strides=((2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)). Example:

Let's say your input data has size (224, 128, 12). You can use strides like this: ((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3)). The encoder's final feature map then has shape (7, 4, 1), since 224/(2·4·2·2·1) = 7, 128/(2·2·2·2·2) = 4 and 12/(2·1·2·1·3) = 1:

import segmentation_models_pytorch_3d as smp
import torch

model = smp.Unet(
    encoder_name="resnet50",
    in_channels=1,
    strides=((2, 2, 2), (4, 2, 1), (2, 2, 2), (2, 2, 1), (1, 2, 3)),  # one (H, W, D) stride per encoder stage
    classes=3,
)

res = model(torch.randn(4, 1, 224, 128, 12))  # input shape (B, C, H, W, D)

Note: strides are currently supported only by resnet-family and densenet encoders with the Unet decoder.

Related repositories

Citation

If you find this code useful, please cite it as:

@article{solovyev20223d,
  title={3D convolutional neural networks for stalled brain capillary detection},
  author={Solovyev, Roman and Kalinin, Alexandr A and Gabruseva, Tatiana},
  journal={Computers in Biology and Medicine},
  volume={141},
  pages={105089},
  year={2022},
  publisher={Elsevier},
  doi={10.1016/j.compbiomed.2021.105089}
}

To Do List