
<div align="center">

Python library with Neural Networks for Image Segmentation based on PyTorch.


</div>

The main features of this library are:

  * High-level API (just two lines to create a neural network)
  * Multiple architectures (Unet, FPN, and others) for binary and multi-class segmentation
  * A large collection of encoders, all with pre-trained weights for faster and better convergence
  * Support for 500+ encoders from timm

📚 Project Documentation 📚

Visit the Read The Docs Project Page or read the following README to learn more about the Segmentation Models Pytorch (SMP for short) library.

📋 Table of contents

  1. Quick start
  2. Examples
  3. Models
    1. Architectures
    2. Encoders
    3. Timm Encoders
  4. Models API
    1. Input channels
    2. Auxiliary classification output
    3. Depth
  5. Installation
  6. Competitions won with the library
  7. Contributing
  8. Citing
  9. License

⏳ Quick start <a name="start"></a>

1. Create your first Segmentation model with SMP

The segmentation model is just a PyTorch torch.nn.Module, which can be created as easily as:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
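Since the model is a regular torch.nn.Module, you can sanity-check it with a dummy forward pass (a minimal sketch; the 64x64 size is arbitrary, and the single input channel matches in_channels=1 above):

import torch

x = torch.ones([1, 1, 64, 64])  # batch of one single-channel 64x64 image
mask = model(x)                 # -> torch.Size([1, 3, 64, 64]), one channel per class
print(mask.shape)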

2. Configure data preprocessing

All encoders have pretrained weights. Preparing your data the same way as during the weights' pre-training may give you better results (a higher metric score and faster convergence). This is not necessary if you train the whole model, not only the decoder.

from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
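The returned function operates on image arrays; for example, a minimal sketch of applying it before converting to a tensor (the random array below is a stand-in for a real H x W x C image):

import numpy as np
import torch

image = np.random.rand(256, 256, 3)  # stand-in for an RGB image, H x W x C
image = preprocess_input(image)      # normalize with the encoder's pre-training statistics
x = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()  # -> N x C x H x W tensor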

Congratulations! You are done! Now you can train your model with your favorite framework!

💡 Examples <a name="examples"></a>

📦 Models <a name="models"></a>

Architectures <a name="architectures"></a>

Encoders <a name="encoders"></a>

The following is a list of the encoders supported in SMP. Select the appropriate family of encoders, click to expand the table, and choose a specific encoder and its pre-trained weights (the encoder_name and encoder_weights parameters).

<details> <summary style="margin-left: 25px;">ResNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| resnet18 | imagenet / ssl / swsl | 11M |
| resnet34 | imagenet | 21M |
| resnet50 | imagenet / ssl / swsl | 23M |
| resnet101 | imagenet | 42M |
| resnet152 | imagenet | 58M |
</div> </details> <details> <summary style="margin-left: 25px;">ResNeXt</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| resnext50_32x4d | imagenet / ssl / swsl | 22M |
| resnext101_32x4d | ssl / swsl | 42M |
| resnext101_32x8d | imagenet / instagram / ssl / swsl | 86M |
| resnext101_32x16d | instagram / ssl / swsl | 191M |
| resnext101_32x32d | instagram | 466M |
| resnext101_32x48d | instagram | 826M |
</div> </details> <details> <summary style="margin-left: 25px;">ResNeSt</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| timm-resnest14d | imagenet | 8M |
| timm-resnest26d | imagenet | 15M |
| timm-resnest50d | imagenet | 25M |
| timm-resnest101e | imagenet | 46M |
| timm-resnest200e | imagenet | 68M |
| timm-resnest269e | imagenet | 108M |
| timm-resnest50d_4s2x40d | imagenet | 28M |
| timm-resnest50d_1s4x24d | imagenet | 23M |
</div> </details> <details> <summary style="margin-left: 25px;">Res2Ne(X)t</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| timm-res2net50_26w_4s | imagenet | 23M |
| timm-res2net101_26w_4s | imagenet | 43M |
| timm-res2net50_26w_6s | imagenet | 35M |
| timm-res2net50_26w_8s | imagenet | 46M |
| timm-res2net50_48w_2s | imagenet | 23M |
| timm-res2net50_14w_8s | imagenet | 23M |
| timm-res2next50 | imagenet | 22M |
</div> </details> <details> <summary style="margin-left: 25px;">RegNet(x/y)</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| timm-regnetx_002 | imagenet | 2M |
| timm-regnetx_004 | imagenet | 4M |
| timm-regnetx_006 | imagenet | 5M |
| timm-regnetx_008 | imagenet | 6M |
| timm-regnetx_016 | imagenet | 8M |
| timm-regnetx_032 | imagenet | 14M |
| timm-regnetx_040 | imagenet | 20M |
| timm-regnetx_064 | imagenet | 24M |
| timm-regnetx_080 | imagenet | 37M |
| timm-regnetx_120 | imagenet | 43M |
| timm-regnetx_160 | imagenet | 52M |
| timm-regnetx_320 | imagenet | 105M |
| timm-regnety_002 | imagenet | 2M |
| timm-regnety_004 | imagenet | 3M |
| timm-regnety_006 | imagenet | 5M |
| timm-regnety_008 | imagenet | 5M |
| timm-regnety_016 | imagenet | 10M |
| timm-regnety_032 | imagenet | 17M |
| timm-regnety_040 | imagenet | 19M |
| timm-regnety_064 | imagenet | 29M |
| timm-regnety_080 | imagenet | 37M |
| timm-regnety_120 | imagenet | 49M |
| timm-regnety_160 | imagenet | 80M |
| timm-regnety_320 | imagenet | 141M |
</div> </details> <details> <summary style="margin-left: 25px;">GERNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| timm-gernet_s | imagenet | 6M |
| timm-gernet_m | imagenet | 18M |
| timm-gernet_l | imagenet | 28M |
</div> </details> <details> <summary style="margin-left: 25px;">SE-Net</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| senet154 | imagenet | 113M |
| se_resnet50 | imagenet | 26M |
| se_resnet101 | imagenet | 47M |
| se_resnet152 | imagenet | 64M |
| se_resnext50_32x4d | imagenet | 25M |
| se_resnext101_32x4d | imagenet | 46M |
</div> </details> <details> <summary style="margin-left: 25px;">SK-ResNe(X)t</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| timm-skresnet18 | imagenet | 11M |
| timm-skresnet34 | imagenet | 21M |
| timm-skresnext50_32x4d | imagenet | 25M |
</div> </details> <details> <summary style="margin-left: 25px;">DenseNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| densenet121 | imagenet | 6M |
| densenet169 | imagenet | 12M |
| densenet201 | imagenet | 18M |
| densenet161 | imagenet | 26M |
</div> </details> <details> <summary style="margin-left: 25px;">Inception</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| inceptionresnetv2 | imagenet / imagenet+background | 54M |
| inceptionv4 | imagenet / imagenet+background | 41M |
| xception | imagenet | 22M |
</div> </details> <details> <summary style="margin-left: 25px;">EfficientNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| efficientnet-b0 | imagenet | 4M |
| efficientnet-b1 | imagenet | 6M |
| efficientnet-b2 | imagenet | 7M |
| efficientnet-b3 | imagenet | 10M |
| efficientnet-b4 | imagenet | 17M |
| efficientnet-b5 | imagenet | 28M |
| efficientnet-b6 | imagenet | 40M |
| efficientnet-b7 | imagenet | 63M |
| timm-efficientnet-b0 | imagenet / advprop / noisy-student | 4M |
| timm-efficientnet-b1 | imagenet / advprop / noisy-student | 6M |
| timm-efficientnet-b2 | imagenet / advprop / noisy-student | 7M |
| timm-efficientnet-b3 | imagenet / advprop / noisy-student | 10M |
| timm-efficientnet-b4 | imagenet / advprop / noisy-student | 17M |
| timm-efficientnet-b5 | imagenet / advprop / noisy-student | 28M |
| timm-efficientnet-b6 | imagenet / advprop / noisy-student | 40M |
| timm-efficientnet-b7 | imagenet / advprop / noisy-student | 63M |
| timm-efficientnet-b8 | imagenet / advprop | 84M |
| timm-efficientnet-l2 | noisy-student | 474M |
| timm-efficientnet-lite0 | imagenet | 4M |
| timm-efficientnet-lite1 | imagenet | 5M |
| timm-efficientnet-lite2 | imagenet | 6M |
| timm-efficientnet-lite3 | imagenet | 8M |
| timm-efficientnet-lite4 | imagenet | 13M |
</div> </details> <details> <summary style="margin-left: 25px;">MobileNet</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| mobilenet_v2 | imagenet | 2M |
| timm-mobilenetv3_large_075 | imagenet | 1.78M |
| timm-mobilenetv3_large_100 | imagenet | 2.97M |
| timm-mobilenetv3_large_minimal_100 | imagenet | 1.41M |
| timm-mobilenetv3_small_075 | imagenet | 0.57M |
| timm-mobilenetv3_small_100 | imagenet | 0.93M |
| timm-mobilenetv3_small_minimal_100 | imagenet | 0.43M |
</div> </details> <details> <summary style="margin-left: 25px;">DPN</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| dpn68 | imagenet | 11M |
| dpn68b | imagenet+5k | 11M |
| dpn92 | imagenet+5k | 34M |
| dpn98 | imagenet | 58M |
| dpn107 | imagenet+5k | 84M |
| dpn131 | imagenet | 76M |
</div> </details> <details> <summary style="margin-left: 25px;">VGG</summary> <div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
| --- | --- | --- |
| vgg11 | imagenet | 9M |
| vgg11_bn | imagenet | 9M |
| vgg13 | imagenet | 9M |
| vgg13_bn | imagenet | 9M |
| vgg16 | imagenet | 14M |
| vgg16_bn | imagenet | 14M |
| vgg19 | imagenet | 20M |
| vgg19_bn | imagenet | 20M |
</div> </details> <details> <summary style="margin-left: 25px;">Mix Vision Transformer</summary> <div style="margin-left: 25px;">

The backbone from SegFormer, pretrained on ImageNet! It can be used with other decoders from the package; you can combine Mix Vision Transformer with Unet, FPN, and others!

Limitations:

  * encoder is not supported by Linknet and Unet++
  * encoder is supported by FPN only for encoder depth = 5

| Encoder | Weights | Params, M |
| --- | --- | --- |
| mit_b0 | imagenet | 3M |
| mit_b1 | imagenet | 13M |
| mit_b2 | imagenet | 24M |
| mit_b3 | imagenet | 44M |
| mit_b4 | imagenet | 60M |
| mit_b5 | imagenet | 81M |
</div> </details> <details> <summary style="margin-left: 25px;">MobileOne</summary> <div style="margin-left: 25px;">

Apple's "sub-one-ms" backbone, pretrained on ImageNet! It can be used with all decoders.

Note: in the official GitHub repo the s0 variant has additional num_conv_branches, leading to more params than s1.

| Encoder | Weights | Params, M |
| --- | --- | --- |
| mobileone_s0 | imagenet | 4.6M |
| mobileone_s1 | imagenet | 4.0M |
| mobileone_s2 | imagenet | 6.5M |
| mobileone_s3 | imagenet | 8.8M |
| mobileone_s4 | imagenet | 13.6M |
</div> </details>

* ssl, swsl - semi-supervised and weakly-supervised learning on ImageNet (repo).
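For example, any encoder/weights pair from the tables above plugs into any architecture via these two parameters (a minimal sketch reusing the Quick start API):

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="se_resnext50_32x4d",  # any encoder from the tables above
    encoder_weights="imagenet",         # any weights listed for that encoder
    classes=3,
)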

Timm Encoders <a name="timm"></a>


PyTorch Image Models (a.k.a. timm) has a lot of pretrained models and an interface that allows using these models as encoders in SMP; however, not all models are supported.

Total number of supported encoders: 549
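For example, a minimal sketch of selecting a timm encoder by prefixing its name with tu- (the specific encoder name here is an illustrative choice and depends on your installed timm version):

import segmentation_models_pytorch as smp

# timm encoders are selected with the "tu-" prefix before the timm model name
model = smp.Unet(
    encoder_name="tu-mobilenetv3_large_100",  # illustrative timm model name
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)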

🔁 Models API <a name="api"></a>

Input channels

The in_channels parameter allows you to create models that process tensors with an arbitrary number of channels. If you use pretrained weights from ImageNet, the weights of the first convolution will be reused. For the 1-channel case it would be a sum of the weights of the first convolution layer; otherwise, channels would be populated with weights like new_weight[:, i] = pretrained_weight[:, i % 3] and then scaled with new_weight * 3 / new_in_channels.

import torch
import segmentation_models_pytorch as smp

model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))  # single-channel input works out of the box
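The repopulation rule above can be sketched in plain PyTorch (a standalone illustration of the described formula, not SMP's actual code path):

import torch

out_channels, kh, kw = 64, 7, 7
pretrained_weight = torch.randn(out_channels, 3, kh, kw)  # first conv weights pretrained on RGB
new_in_channels = 2
new_weight = torch.empty(out_channels, new_in_channels, kh, kw)
for i in range(new_in_channels):
    new_weight[:, i] = pretrained_weight[:, i % 3]   # cycle through the RGB filters
new_weight = new_weight * 3 / new_in_channels        # rescale to preserve activation magnitude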

Auxiliary classification output

All models support the aux_params parameter, which defaults to None. If aux_params = None, the auxiliary classification output is not created; otherwise the model produces not only a mask, but also a label output with shape (N, C). The classification head consists of GlobalPooling -> Dropout (optional) -> Linear -> Activation (optional) layers, which can be configured via aux_params as follows:

import torch
import segmentation_models_pytorch as smp

aux_params = dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)

x = torch.ones([1, 3, 64, 64])  # example input batch
mask, label = model(x)          # mask: [1, 4, 64, 64], label: [1, 4]

Depth

The encoder_depth parameter specifies the number of downsampling operations in the encoder, so you can make your model lighter by specifying a smaller depth.

model = smp.Unet('resnet34', encoder_depth=4, decoder_channels=[256, 128, 64, 32])
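Note: for decoders configured with a channel list (e.g. Unet, whose decoder_channels defaults to five values), the length of decoder_channels must match encoder_depth, which is why the example above passes four values; the specific channel values are illustrative.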

🛠 Installation <a name="installation"></a>

PyPI version:

$ pip install segmentation-models-pytorch

Latest version from source:

$ pip install git+https://github.com/qubvel/segmentation_models.pytorch

🏆 Competitions won with the library

The Segmentation Models package is widely used in image segmentation competitions. Here you can find competitions, the names of the winners, and links to their solutions.

🤝 Contributing

Install SMP

make install_dev  # create .venv, install SMP in dev mode

Run tests and code checks

make fixup         # Ruff for formatting and lint checks

Update table with encoders

make table        # generate a table with encoders and print to stdout

📝 Citing

@misc{Iakubovskii:2019,
  Author = {Pavel Iakubovskii},
  Title = {Segmentation Models Pytorch},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}

🛡️ License <a name="license"></a>

The project is primarily distributed under the MIT License, while some files are subject to other licenses. Please check LICENSES and the license statements in each file carefully, especially for commercial use.