PyTorch segmentation models.
Segmentation Models PyTorch Integration
From the website:
- High level API (just two lines to create a neural network)
- 9 model architectures for binary and multiclass segmentation (including the legendary Unet)
- 104 available encoders
- All encoders have pre-trained weights for faster and better convergence
See https://github.com/qubvel/segmentation_models.pytorch for API details.
# Smoke test: every architecture/encoder combination must build and map an
# input of shape (bs, in_c, ts, ts) to an output of shape (bs, c, ts, ts).
from itertools import product

bs = 2
tile_shapes = [512]  # 1024 is another size to try
in_channels = [1]  # 3 and 4 input channels are other options to try
classes = [2]  # 5 classes is another option to try
encoders = ENCODERS[1:2]  # only one encoder (index 1) is exercised here

for ts, in_c, c in product(tile_shapes, in_channels, classes):
    # One random input per (tile size, channels, classes) config, shared by
    # all architecture/encoder pairs below.
    inp = torch.randn(bs, in_c, ts, ts)
    out_shape = [bs, c, ts, ts]
    for arch, encoder_name in product(ARCHITECTURES, encoders):
        model = create_smp_model(arch=arch,
                                 encoder_name=encoder_name,
                                 encoder_weights=None,
                                 in_channels=in_c,
                                 classes=c)
        # Only the output shape is checked, so skip autograd bookkeeping:
        # recording graphs for large forward passes wastes memory.
        with torch.no_grad():
            out = model(inp)
        test_eq(out.shape, out_shape)
# Round-trip test: a model written with save_smp_model must load back via
# load_smp_model with identical parameters and the same `stats` value.
import os

arch = 'Unet'
file = 'tst.pth'
stats = (1, 1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
try:
    save_smp_model(tst, arch, file, stats=stats)
    tst2, stats2 = load_smp_model(file)
    params1, params2 = list(tst.parameters()), list(tst2.parameters())
    # zip() stops at the shorter sequence; compare the counts first so a
    # reloaded model with missing parameters cannot pass silently.
    test_eq(len(params1), len(params2))
    for p1, p2 in zip(params1, params2):
        test_eq(p1.detach(), p2.detach())
    test_eq(stats, stats2)
finally:
    # Remove the checkpoint so the test leaves no artifact behind.
    if os.path.exists(file):
        os.remove(file)