# Code adapted from ttach: https://github.com/qubvel/ttach
# Merger sanity check: append the same batch ten times, then check that the
# merged result keeps the batch shape for each of the 'mean', 'max' and 'std'
# modes.
imgs = TensorImage(torch.randn(4, 1, 356, 356))
for mode in ('mean', 'max', 'std'):
    merger = Merger()
    for _ in range(10):
        merger.append(imgs)
    test_eq(imgs.shape, merger.result(mode).shape)
# Round-trip check for individual transforms: augmenting the batch and then
# de-augmenting the result must reproduce the original batch exactly.
# NOTE(review): the inverse used here is `apply_deaug_mask`, applied to data
# produced by `apply_aug_image`; presumably flips/rotations undo images and
# masks identically — confirm.
# NOTE(review): a 180-degree rotation is its own inverse, so this check cannot
# tell a correct inverse apart from re-applying the forward rotation; a 90- or
# 270-degree case would be stronger.
# NOTE(review): if `apply_aug_image` needs an explicit flip/angle argument to
# change anything (as in upstream ttach), these calls may be identity
# round-trips — TODO confirm against the transform definitions.
for t in (HorizontalFlip(), VerticalFlip(), Rotate90([180])):
    aug = t.apply_aug_image(imgs)
    deaug = t.apply_deaug_mask(aug)
    test_eq(imgs, deaug)
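# Pipeline test: each transform yielded by `Compose` must round-trip the batch, and merging all de-augmented outputs must recover it.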
# Every item yielded by iterating `Compose(tfms)` must augment and then
# de-augment the batch back to exactly the input. The de-augmented outputs
# are collected in a Merger, whose merged result must be close to the input.
tfms = [
    HorizontalFlip(),
    VerticalFlip(),
    Rotate90(angles=[90, 180, 270]),
]
c = Compose(tfms)
m = Merger()
for t in c:
    deaug = t.deaugment_mask(t.augment_image(imgs))
    test_eq(imgs, deaug)
    m.append(deaug)
# NOTE(review): `result()` is called without a mode, so this relies on
# Merger's default merge mode (presumably 'mean') — confirm.
test_close(imgs, m.result())