PyTorch segmentation models.

Segmentation Models PyTorch Integration

From the website:

  • High level API (just two lines to create a neural network)
  • 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 104 available encoders
  • All encoders have pre-trained weights for faster and better convergence

See https://github.com/qubvel/segmentation_models.pytorch for API details.

get_pretrained_options

get_pretrained_options(encoder_name)

Return available options for pretrained weights for a given encoder
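The lookup can be pictured as a dictionary query keyed by encoder name. The sketch below is an assumption about how such a helper might work, not deepflash2's code: `ENCODER_REGISTRY` and `get_pretrained_options_sketch` are hypothetical names, and the registry entries are illustrative only.

```python
# A minimal sketch of looking up pretrained-weight options for an encoder.
# ENCODER_REGISTRY is a hypothetical, illustrative stand-in: its entries are
# examples only and are not guaranteed to match smp's real encoder table.
ENCODER_REGISTRY = {
    "resnet34": {"pretrained_settings": {"imagenet": {}}},
    "resnet50": {"pretrained_settings": {"imagenet": {}, "ssl": {}, "swsl": {}}},
    "custom_encoder": {"pretrained_settings": {}},  # no pretrained weights
}


def get_pretrained_options_sketch(encoder_name):
    """Return the names of the pretrained weight sets for `encoder_name`."""
    if encoder_name not in ENCODER_REGISTRY:
        raise ValueError(f"Unknown encoder: {encoder_name!r}")
    settings = ENCODER_REGISTRY[encoder_name]["pretrained_settings"]
    # In this sketch, each key is a candidate value for `encoder_weights`.
    return list(settings.keys())


print(get_pretrained_options_sketch("resnet50"))  # ['imagenet', 'ssl', 'swsl']
```

An empty list signals that an encoder can only be trained from scratch, i.e. with `encoder_weights=None`.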

create_smp_model

create_smp_model(arch, **kwargs)

Create segmentation_models_pytorch model
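One plausible shape for `create_smp_model` is a factory that maps the `arch` string to a model class and forwards `**kwargs` to its constructor. In this stdlib-only sketch, `PlaceholderModel`, `MODEL_REGISTRY` and `create_model_sketch` are hypothetical stand-ins, not smp classes, and the constructor defaults are illustrative.

```python
# A minimal sketch of a name-based model factory. The classes below are
# hypothetical placeholders that only record their keyword arguments.
class PlaceholderModel:
    def __init__(self, encoder_name="resnet34", encoder_weights="imagenet",
                 in_channels=3, classes=1):
        self.kwargs = dict(encoder_name=encoder_name,
                           encoder_weights=encoder_weights,
                           in_channels=in_channels, classes=classes)


class Unet(PlaceholderModel):
    pass


class FPN(PlaceholderModel):
    pass


# Map architecture names to classes so callers can pass a plain string.
MODEL_REGISTRY = {cls.__name__: cls for cls in (Unet, FPN)}


def create_model_sketch(arch, **kwargs):
    """Instantiate the architecture called `arch`, forwarding **kwargs."""
    if arch not in MODEL_REGISTRY:
        raise ValueError(f"Unsupported architecture {arch!r}; "
                         f"choose from {sorted(MODEL_REGISTRY)}")
    return MODEL_REGISTRY[arch](**kwargs)


model = create_model_sketch("Unet", encoder_weights=None, in_channels=1, classes=2)
print(type(model).__name__, model.kwargs)
```

Keeping the architecture as a string makes it easy to store in a checkpoint and rebuild the same model later.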

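import torch
from fastcore.test import test_eq
# ARCHITECTURES, ENCODERS and create_smp_model are provided by this module.

bs = 2
tile_shapes = [512]       # e.g. [512, 1024] for a broader check
in_channels = [1]         # e.g. [1, 3, 4]
classes = [2]             # e.g. [2, 5]
encoders = ENCODERS[1:2]  # e.g. ENCODERS[1:2] + ENCODERS[-1:]

# Every architecture/encoder pair should map an input of shape
# (bs, in_c, ts, ts) to an output of shape (bs, c, ts, ts).
for ts in tile_shapes:
    for in_c in in_channels:
        for c in classes:
            inp = torch.randn(bs, in_c, ts, ts)
            out_shape = [bs, c, ts, ts]
            for arch in ARCHITECTURES:
                for encoder_name in encoders:
                    # encoder_weights=None initialises the encoder randomly (no download)
                    model = create_smp_model(arch=arch,
                                             encoder_name=encoder_name,
                                             encoder_weights=None,
                                             in_channels=in_c,
                                             classes=c)
                    out = model(inp)
                    test_eq(out.shape, out_shape)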
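The nested loops above walk a grid of configurations. Here is a sketch of the same grid built with `itertools.product`, including the output shape each case should produce; the value lists and the `expected_shape` helper are hypothetical, and no model is built here.

```python
# Flatten nested test loops into one list of configurations.
# The lists below are hypothetical; the original cell uses ARCHITECTURES
# and ENCODERS from this module.
from itertools import product

bs = 2
tile_shapes = [512, 1024]
in_channels = [1, 3]
classes = [2, 5]
archs = ["Unet", "FPN"]   # hypothetical subset
encoders = ["resnet34"]   # hypothetical subset


def expected_shape(bs, classes, tile_shape):
    # The model keeps the spatial size and emits one channel per class.
    return [bs, classes, tile_shape, tile_shape]


configs = [
    dict(tile_shape=ts, in_channels=in_c, classes=c, arch=a, encoder_name=e,
         out_shape=expected_shape(bs, c, ts))
    for ts, in_c, c, a, e in product(tile_shapes, in_channels, classes,
                                     archs, encoders)
]
print(len(configs))  # 16
```

Each entry could then drive one `create_smp_model` call and one shape check, which keeps the test body a single flat loop.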

save_smp_model

save_smp_model(model, arch, file, stats=None, pickle_protocol=2)

Save smp model, optionally including stats
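A checkpoint of this kind is essentially a dictionary serialized with pickle. The sketch below assumes a layout with `arch`, `state_dict` and `stats` keys and uses the standard `pickle` module; the key names and `save_checkpoint_sketch` are hypothetical, and the real function may store more or differently named fields.

```python
# A minimal sketch of saving a checkpoint dict at a chosen pickle protocol.
# The key names ('arch', 'state_dict', 'stats') are hypothetical.
import os
import pickle
import tempfile


def save_checkpoint_sketch(state_dict, arch, file, stats=None, pickle_protocol=2):
    # Store the architecture name so the model can be rebuilt on load.
    checkpoint = {"arch": arch, "state_dict": state_dict, "stats": stats}
    with open(file, "wb") as f:
        pickle.dump(checkpoint, f, protocol=pickle_protocol)


path = os.path.join(tempfile.mkdtemp(), "tst.pkl")
save_checkpoint_sketch({"conv.weight": [0.1, 0.2]}, "Unet", path, stats=(1, 1))
with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored["arch"], restored["stats"])  # Unet (1, 1)
```

A low protocol such as 2 trades some efficiency for compatibility with older readers.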

arch = 'Unet'
file = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
save_smp_model(tst, arch, file, stats=stats)

load_smp_model

load_smp_model(file, device=None, strict=True, **kwargs)

Loads smp model from file
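Loading reverses the process: rebuild the architecture named in the checkpoint, copy the weights, and return the stored stats alongside the model. The sketch assumes `strict` behaves as it does in PyTorch's `load_state_dict` (mismatched keys raise an error when `strict=True`); `PlaceholderModel`, `load_checkpoint_sketch` and the checkpoint layout are hypothetical.

```python
# A minimal sketch of rebuilding a model from a checkpoint dict.
# The placeholder model stores its "weights" in a plain dict.
class PlaceholderModel:
    def __init__(self):
        self.state = {"encoder.weight": None, "decoder.weight": None}

    def load_state(self, state_dict, strict=True):
        missing = set(self.state) - set(state_dict)
        unexpected = set(state_dict) - set(self.state)
        # With strict=True, any key mismatch is an error.
        if strict and (missing or unexpected):
            raise RuntimeError(f"missing={sorted(missing)}, "
                               f"unexpected={sorted(unexpected)}")
        # Copy only the keys both sides share; unexpected keys are ignored.
        for key in self.state.keys() & state_dict.keys():
            self.state[key] = state_dict[key]


MODEL_REGISTRY = {"Unet": PlaceholderModel}


def load_checkpoint_sketch(checkpoint, strict=True):
    """Return (model, stats) rebuilt from a checkpoint dict."""
    model = MODEL_REGISTRY[checkpoint["arch"]]()
    model.load_state(checkpoint["state_dict"], strict=strict)
    return model, checkpoint.get("stats")


ckpt = {"arch": "Unet", "stats": (1, 1),
        "state_dict": {"encoder.weight": 1.0, "decoder.weight": 2.0}}
model, stats = load_checkpoint_sketch(ckpt)
print(model.state, stats)
```

`strict=False` is useful when only part of a model, such as the encoder, should be restored.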

tst2, stats2 = load_smp_model(file)
for p1, p2 in zip(tst.parameters(), tst2.parameters()):
    test_eq(p1.detach(), p2.detach())
test_eq(stats, stats2)