Segmentation Models is a Python library of neural networks for image segmentation, based on PyTorch.
The main features of this library are:
- High-level API (just two lines to create a neural network)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 35 available encoders for each architecture
- All encoders have pretrained weights for faster and better convergence
Since the library is built on PyTorch, a segmentation model it creates is just a PyTorch nn.Module, and creating one is as easy as:

```python
import segmentation_models_pytorch as smp

model = smp.Unet()
```
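A practical detail when feeding images to a model: encoder-decoder architectures such as Unet downsample the input several times and then upsample it again, so, as far as we understand, the input height and width must be divisible by the encoder's total stride (assumed here to be 32, i.e. five stride-2 stages as in the ResNet family). The minimal sketch below zero-pads an HWC NumPy image up to the next multiple; `pad_to_multiple` is a hypothetical helper, not part of the library.

```python
import numpy as np


def pad_to_multiple(image: np.ndarray, multiple: int = 32) -> np.ndarray:
    """Zero-pad an HWC (or HW) image so height and width are divisible by `multiple`.

    Assumption: the encoder downsamples by a total factor of 32, so the decoder
    can only align its skip connections when H and W are multiples of 32.
    """
    h, w = image.shape[:2]
    pad_h = (-h) % multiple  # rows needed to reach the next multiple
    pad_w = (-w) % multiple  # columns needed to reach the next multiple
    # Pad only at the bottom and right; any channel axes are left untouched.
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode="constant")


img = np.zeros((360, 480, 3), dtype=np.uint8)
print(pad_to_multiple(img).shape)  # (384, 480, 3)
```

Padding only the bottom and right edges keeps the original pixel coordinates unchanged, so a prediction can be cropped back with `pred[..., :h, :w]`.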
Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:

```python
model = smp.Unet('resnet34', encoder_weights='imagenet')
```
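To compare backbones by size, one option is to count the parameters of the resulting model. The sketch below works on any PyTorch `nn.Module`, so it should also apply to models created as above; `count_parameters` is a hypothetical helper, and a tiny `nn.Conv2d` stands in for a real segmentation model so the example runs without downloading weights.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters in a module (only trainable ones by default)."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Stand-in module; with the library installed you could pass e.g. smp.Unet('resnet18')
# and smp.Unet('resnet34') instead and compare the two numbers.
layer = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
print(count_parameters(layer))  # 224 = 3*8*3*3 weights + 8 biases
```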
Change the number of output classes in the model:

```python
model = smp.Unet('resnet34', classes=3, activation='softmax')
```
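With `classes=3` the network outputs one channel per class. A common way to turn that output into a label map, sketched below with NumPy under the assumption of PyTorch's channels-first `(N, C, H, W)` layout, is a softmax over the class axis followed by an argmax; the function names are hypothetical.

```python
import numpy as np


def softmax_over_classes(logits: np.ndarray) -> np.ndarray:
    """Softmax along the class axis (axis 1) of an (N, C, H, W) array."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def logits_to_mask(logits: np.ndarray) -> np.ndarray:
    """Turn (N, C, H, W) class scores into an (N, H, W) map of class indices."""
    return softmax_over_classes(logits).argmax(axis=1)


rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 3, 4, 4))  # batch of 2, 3 classes, 4x4 pixels
probs = softmax_over_classes(logits)
mask = logits_to_mask(logits)
print(probs.shape, mask.shape)  # (2, 3, 4, 4) (2, 4, 4)
```

Because softmax preserves the ordering within each pixel's score vector, taking the argmax of the raw logits gives the same mask; the probabilities are only needed when you also want confidence values.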
All models have pretrained encoders, so you have to preprocess your data the same way it was preprocessed during weight pretraining:

```python
from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
```
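For orientation, the sketch below shows what ImageNet-style preprocessing typically involves: scaling pixel values to [0, 1], then subtracting a per-channel mean and dividing by a per-channel standard deviation. The mean/std values are the standard ImageNet statistics and are an assumption here; the function returned by `get_preprocessing_fn` uses the settings of the chosen encoder's pretrained weights, so prefer it in real code. To our understanding it operates on channels-last (HWC) arrays, which is why the transpose to the channels-first layout PyTorch expects is a separate final step. `imagenet_preprocess` is a hypothetical name.

```python
import numpy as np

# Assumption: standard ImageNet RGB statistics for inputs scaled to [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def imagenet_preprocess(image: np.ndarray) -> np.ndarray:
    """Normalize an HWC uint8 RGB image and return a float32 CHW array."""
    x = image.astype(np.float32) / 255.0             # scale 0..255 -> 0..1
    x = (x - IMAGENET_MEAN) / IMAGENET_STD           # per-channel standardization
    return x.transpose(2, 0, 1).astype(np.float32)   # HWC -> CHW for PyTorch


image = np.full((2, 2, 3), 255, dtype=np.uint8)  # a tiny all-white image
print(imagenet_preprocess(image).shape)  # (3, 2, 2)
```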
- Training a model for car segmentation on the CamVid dataset: here.
- Training a model with Catalyst (a high-level framework for PyTorch): here.
Type | Encoder names |
---|---|
VGG | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
DenseNet | densenet121, densenet169, densenet201, densenet161 |
DPN | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
Inception | inceptionresnetv2 |
ResNet | resnet18, resnet34, resnet50, resnet101, resnet152 |
ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
SENet | senet154 |
Weights name | Encoder names |
---|---|
imagenet+5k | dpn68b, dpn92, dpn107 |
imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, inceptionresnetv2, resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 |
| | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
- `model.encoder` - pretrained backbone that extracts features at different spatial resolutions
- `model.decoder` - segmentation head; depends on the model architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
- `model.activation` - output activation function, one of `sigmoid`, `softmax`
- `model.forward(x)` - sequentially passes `x` through the model's encoder and decoder (returns logits!)
- `model.predict(x)` - inference method: switches the model to `.eval()` mode, calls `.forward(x)` and applies the activation function under `torch.no_grad()`
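The `forward`/`predict` split described above can be sketched as a standalone function. This illustrates the listed behaviour and is not the library's code: `predict` here is hypothetical, it defaults to a sigmoid (the binary case), and a one-layer `nn.Sequential` stands in for a segmentation model.

```python
import torch
import torch.nn as nn


def predict(model: nn.Module, x: torch.Tensor, activation=torch.sigmoid) -> torch.Tensor:
    """Run inference: eval mode, no gradient tracking, forward pass, then activation."""
    model.eval()  # disable dropout and use running BatchNorm statistics
    with torch.no_grad():  # no autograd graph is built during inference
        logits = model(x)  # the forward pass returns raw, unbounded scores
        return activation(logits)


# Hypothetical stand-in: 3 input channels -> 1 output channel, same spatial size.
net = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1))
probs = predict(net, torch.randn(2, 3, 32, 32))
print(probs.shape)  # torch.Size([2, 1, 32, 32])
```

For a multi-class model, pass `activation=lambda t: torch.softmax(t, dim=1)` so probabilities are normalized across the class channel. During training, feed the raw logits from `forward` to a loss that expects them, such as `nn.BCEWithLogitsLoss` or `nn.CrossEntropyLoss`.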
PyPI version:

```bash
$ pip install segmentation-models-pytorch
```

Latest version from source:

```bash
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
```
The project is distributed under the MIT License.
Run the test suite in Docker:

```bash
$ docker build -f docker/Dockerfile.dev -t smp:dev .
$ docker run --rm smp:dev pytest -p no:cacheprovider
```