Notebook to train deep learning models or ensembles for segmentation of fluorescent labels in microscopy images.
#@markdown Please run this cell to get started.
%load_ext autoreload
%autoreload 2
try:
    from google.colab import files, drive  # Colab-only helpers for uploads, downloads, and Drive access
except ImportError:
    pass
try:
    import deepflash2
except ImportError:
    !pip install -q deepflash2==0.0.14
import zipfile
import shutil
import imageio
from sklearn.model_selection import KFold, train_test_split
from fastai.vision.all import *
from deepflash2.all import *
from deepflash2.data import _read_msk
from scipy.stats import entropy

Provide Training Data

Required data structure

  • One folder for training images
  • One folder for segmentation masks

_Exemplary structure (see naming conventions):_

  • [folder] images
    • [file] 0001.tif
    • [file] 0002.tif
  • [folder] masks
    • [file] 0001_mask.png
    • [file] 0002_mask.png
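The mask name is always the image stem plus a fixed suffix. A minimal sketch of this convention (the hypothetical expected_mask helper mirrors the label_fn defined in section Check and load data):

from pathlib import Path

def expected_mask(img_path, mask_dir='masks', suffix='_mask.png'):
    # images/0001.tif -> masks/0001_mask.png
    return Path(mask_dir)/f'{img_path.stem}{suffix}'

print(expected_mask(Path('images/0001.tif')))  # masks/0001_mask.png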

Option A: Upload via Google Drive (recommended, Colab only)

  • The folder in your Drive must contain all files in the correct folder structure.
  • See here for how to organize your files in Google Drive.
  • See this Stack Overflow post for how to browse files with the Colab file browser.
try:
    #@markdown Follow the instructions and press Enter after copying and pasting the key.
    drive.mount('/content/drive')
    path = "/content/drive/My Drive/data" #@param {type:"string"}
    path = Path(path)
    print('Path contains the following files and folders: \n', L(os.listdir(path)))
except NameError:
    print("Warning: Connecting to Google Drive only works on Google Colab.")

Option B: Upload via zip file (Colab only)

  • The zip file must contain all images and segmentation masks in the correct folder structure.
  • See here how to zip files on Windows or Mac.
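If you prefer to create the archive programmatically, shutil.make_archive from Python's standard library builds a suitable zip; a minimal sketch, assuming a local folder data that contains the images and masks subfolders:

import shutil

# Creates data.zip with images/ and masks/ at the archive root
shutil.make_archive('data', 'zip', 'data')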
path = Path('data')
try:
    u_dict = files.upload()
    for key in u_dict.keys():
        unzip(path, key)
    print('Path contains the following files and folders: \n', L(os.listdir(path)))
except NameError:
    print("Warning: File upload only works on Google Colab.")

Option C: Provide path (Local installation)

If you're working on your local machine or a server, provide the path to the folder that contains your data (structured as described above).

path = "" #@param {type:"string"}
path = Path(path)
print('Path contains the following files and folders: \n', L(os.listdir(path)))

Option D: Try with sample data (Testing only)

If you don't have any data available yet, try our sample data.

path = Path('sample_data_cFOS')
url = "https://github.com/matjesg/deepflash2/releases/download/model_library/wue1_cFOS_small.zip"
import urllib.request
urllib.request.urlretrieve(url, 'sample_data_cFOS.zip')
unzip(path, 'sample_data_cFOS.zip')

Check and load data

image_folder = "images" #@param {type:"string"}
mask_folder = "masks" #@param {type:"string"}
mask_suffix = "_mask.png" #@param {type:"string"}
#@markdown Number of classes: e.g., 2 for binary segmentation (foreground and background class)
n_classes = 2 #@param {type:"integer"}
#@markdown Check if you are providing instance labels (class-aware and instance-aware)
instance_labels = False #@param {type:"boolean"}

f_names = get_image_files(path/image_folder)
label_fn = lambda o: path/mask_folder/f'{o.stem}{mask_suffix}'
#Check if corresponding masks exist
mask_check = [os.path.isfile(label_fn(x)) for x in f_names]
if len(f_names)==sum(mask_check) and len(f_names)>0:
    print(f'Found {len(f_names)} images and {sum(mask_check)} masks in "{path}".')
else:
    print(f'IMAGE/MASK MISMATCH! Found {len(f_names)} images and {sum(mask_check)} masks in "{path}".')
    print('Please check the steps above.')

Customize mask weights (optional)

  • The default values should work for most data.
  • However, these parameters can significantly affect model performance later on (see the sketch below).
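For intuition, the parameters roughly follow the U-Net weighting scheme of Ronneberger et al. (2015) and Falk et al. (2019): background pixels are down-weighted to foreground_background_ratio, and background pixels in narrow gaps between touching objects are up-weighted by a Gaussian border term. The sketch below illustrates that idea; it is not deepflash2's exact calculate_weights implementation, and the foreground_dist_sigma term (which additionally shapes weights inside objects) is omitted:

import numpy as np
from scipy import ndimage

def illustrative_weights(msk, bws=10, bwf=10, fbr=0.1):
    # Balance term: foreground pixels get weight 1, background is down-weighted to fbr
    w = np.where(msk > 0, 1.0, fbr).astype(np.float32)
    # Border term: bwf * exp(-(d1 + d2)**2 / (2 * bws**2)), where d1 and d2 are
    # the distances to the two nearest objects (Ronneberger et al. 2015, Eq. 2)
    labels, n = ndimage.label(msk > 0)
    if n > 1:
        dists = np.stack([ndimage.distance_transform_edt(labels != i)
                          for i in range(1, n + 1)])
        dists.sort(axis=0)
        border = bwf * np.exp(-(dists[0] + dists[1]) ** 2 / (2 * bws ** 2))
        w[msk == 0] += border[msk == 0]  # only background pixels get the border boost
    return w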
#@markdown Run to set weight parameters
border_weight_sigma=10 #@param {type:"slider", min:1, max:20, step:1}
foreground_dist_sigma=10 #@param {type:"slider", min:1, max:20, step:1}
border_weight_factor=10 #@param {type:"slider", min:1, max:50, step:1}
foreground_background_ratio= 0.1 #@param {type:"slider", min:0.1, max:1, step:0.1}

#@markdown Check if you want to plot the resulting weights of one mask
plot_weights = False #@param {type:"boolean"}
#@markdown Check `reset_to_defaults` to reset your parameters.
reset_to_defaults = False #@param {type:"boolean"}

mw_dict = {'bws': 10 if reset_to_defaults else border_weight_sigma,
           'fds': 10 if reset_to_defaults else foreground_dist_sigma,
           'bwf': 10 if reset_to_defaults else border_weight_factor,
           'fbr': 0.1 if reset_to_defaults else foreground_background_ratio}

#@markdown Select image number
image_number = 0 #@param {type:"slider", min:0, max:100, step:1}
if plot_weights:
    idx = min(image_number, len(f_names)-1)  # clamp to a valid index
    print('Plotting mask for image', f_names[idx].name, '- Please wait.')
    msk = _read_msk(label_fn(f_names[idx]))
    _, w, _ = calculate_weights(msk, n_dims=n_classes, **mw_dict)
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12,12))
    axes[0].imshow(msk)
    axes[0].set_axis_off()
    axes[0].set_title('Mask')
    axes[1].imshow(w)
    axes[1].set_axis_off()
    axes[1].set_title('Weights')

Create mask weights

try:
    mw_dict
except NameError:
    # Fall back to defaults if the weight-parameter cell above was not run
    mw_dict = {'bws': 10, 'fds': 10, 'bwf': 10, 'fbr': 0.1}

ds = RandomTileDataset(f_names, label_fn, n_classes=n_classes, instance_labels=instance_labels, **mw_dict)
#@markdown Run to show data.
#@markdown Use the slider to control the number of displayed images
first_n = 3 #@param {type:"slider", min:1, max:100, step:1}
ds.show_data(max_n = first_n, figsize=(15,15), overlay=False)

Model Definition

Select one of the available model architectures.

model_arch = 'unet_deepflash2' #@param ["unet_deepflash2",  "unet_falk2019", "unet_ronnberger2015"]

Pretrained weights

  • Select 'new' to use an untrained model (no pretrained weights)
  • Or select pretrained model weights from the dropdown menu
pretrained_weights = "wue_cFOS" #@param ["new", "wue_cFOS", "wue_Parv", "wue_GFAP", "wue_GFP", "wue_OPN3"]
pre = pretrained_weights != "new"  # load pretrained weights unless 'new' is selected
n_channels = ds.get_data(max_n=1)[0].shape[-1]
model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, dataset=pretrained_weights, n_classes=ds.c, in_channels=n_channels)
if pretrained_weights=="new": apply_init(model)

Setting model hyperparameters (optional)

  • mixed_precision_training: enables mixed precision training
    • decreases memory usage and speeds up training
    • may affect model accuracy
  • batch_size: the number of samples propagated through the network in one iteration
mixed_precision_training = False #@param {type:"boolean"}
batch_size = 4 #@param {type:"slider", min:2, max:8, step:2}
loss_fn = WeightedSoftmaxCrossEntropy(axis=1)
cbs = [ElasticDeformCallback]
dls = DataLoaders.from_dsets(ds, ds, bs=batch_size)  # same dataset twice; the real splits happen in the training loop below
if torch.cuda.is_available(): dls.cuda(), model.cuda()
learn = Learner(dls, model, wd=0.001, loss_func=loss_fn, cbs=cbs)
if mixed_precision_training: learn.to_fp16()
  • max_lr: The learning rate controls how quickly or slowly the model learns.
    • We found that a maximum learning rate of 5e-4 (i.e., 0.0005) yielded the best results across experiments.
    • learning_rate_finder: Check only if you want to use the Learning Rate Finder on your dataset.
learning_rate_finder = False #@param {type:"boolean"}
if learning_rate_finder:
    lr_min,lr_steep = learn.lr_find()
    print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
max_lr = 5e-4 #@param {type:"number"}

Model Training

Setting training parameters

  • n_models: Number of models to train.
    • If you're experimenting with parameters, try only one model first.
    • Depending on the data, ensembles should comprise 3-5 models.
    • _Note: The number of models affects the train-validation split (see the sketch below)._
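For reference, the training cell below splits the data with scikit-learn's KFold, so each model is validated on a different 1/n_models share of the images (a single model instead uses a random train_test_split with the default 25% validation share). A minimal sketch of the resulting fold sizes:

import numpy as np
from sklearn.model_selection import KFold

# With 20 images and 4 models, each model holds out 5 images for validation
kf = KFold(n_splits=4)
for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(20)), start=1):
    print(f'model{fold}: {len(train_idx)} train / {len(val_idx)} validation images')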
try:
    batch_size
except NameError:
    # Fall back to defaults if the hyperparameter cells above were not run
    batch_size = 4
    mixed_precision_training = False
    loss_fn = WeightedSoftmaxCrossEntropy(axis=1)
try:
    max_lr
except NameError:
    max_lr = 5e-4

metrics = [Dice_f1(), Iou()]
n_models = 1 #@param {type:"slider", min:1, max:5, step:1}
print("Suggested epochs for 1000 iterations:", calc_iterations(len(ds), batch_size, n_models))
  • epochs: One epoch is one full pass of the (augmented) dataset through the model.
    • The number of epochs needs to be adjusted depending on the size and number of images.
    • We found that choosing the number of epochs such that the network parameters are updated about 1000 times (iterations) leads to satisfying results in most cases (see the sketch below).
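calc_iterations presumably performs arithmetic along these lines (an assumption, not deepflash2's exact code; the n_models argument likely accounts for the smaller per-model training split):

import math

def suggested_epochs(n_images, batch_size, target_iterations=1000):
    # One epoch performs roughly n_images // batch_size parameter updates
    iterations_per_epoch = max(n_images // batch_size, 1)
    return math.ceil(target_iterations / iterations_per_epoch)

print(suggested_epochs(n_images=140, batch_size=4))  # -> 29 epochs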
epochs = 30 #@param {type:"slider", min:1, max:200, step:1}

Train models

kf = KFold(n_splits=max(n_models,2))
model_path = path/'models'
model_path.mkdir(parents=True, exist_ok=True)
res, res_mc = {}, {}
fold = 0
for train_idx, val_idx in kf.split(f_names):
    fold += 1
    name = f'model{fold}'
    print('Train', name)
    if n_models==1:
        files_train, files_val = train_test_split(f_names)
    else:
        files_train, files_val = f_names[train_idx], f_names[val_idx]
    print(f'Validation Images: {files_val}')    
    train_ds = RandomTileDataset(files_train, label_fn, n_classes=n_classes, instance_labels=instance_labels, **mw_dict)
    valid_ds = TileDataset(files_val, label_fn, n_classes=n_classes, instance_labels=instance_labels, **mw_dict)
    
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=batch_size)
    dls_valid = DataLoaders.from_dsets(valid_ds, batch_size=batch_size, shuffle=False, drop_last=False)
    model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, 
                           dataset=pretrained_weights, n_classes=ds.c, in_channels=n_channels)
    if pretrained_weights=="new": apply_init(model)
    if torch.cuda.is_available(): dls.cuda(), model.cuda(), dls_valid.cuda()
    
    cbs = [SaveModelCallback(monitor='iou'), ElasticDeformCallback]
    metrics = [Dice_f1(), Iou()]
    learn = Learner(dls, model, metrics = metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)
    if mixed_precision_training: learn.to_fp16()
    learn.fit_one_cycle(epochs, max_lr)
    torch.save(learn.model.state_dict(), model_path/f'{name}.pth', _use_new_zipfile_serialization=False)
    
    smxs, segs, _ = learn.predict_tiles(dl=dls_valid.train)    
    smxs_mc, segs_mc, std = learn.predict_tiles(dl=dls_valid.train, mc_dropout=True, n_times=10)
    
    for i, file in enumerate(files_val):
        res[(name, file)] = smxs[i], segs[i]
        res_mc[(name, file)] = smxs_mc[i], segs_mc[i], std[i]
    
    if n_models==1:
        break

Validate models

Here you can validate your models. To avoid information leakage, each model is evaluated only on its own validation set.
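The uncertainty maps used here come from Monte Carlo dropout (predict_tiles(mc_dropout=True, n_times=10) in the training loop): the model is run several times with dropout kept active, and the per-pixel standard deviation across runs serves as the uncertainty. A rough sketch of the idea for a plain PyTorch model, not deepflash2's internal implementation:

import torch

def mc_dropout_predict(model, x, n_times=10):
    model.eval()
    for m in model.modules():
        # Keep dropout stochastic while batchnorm stays in deterministic eval mode
        if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d)):
            m.train()
    with torch.no_grad():
        preds = torch.stack([torch.softmax(model(x), dim=1) for _ in range(n_times)])
    return preds.mean(0), preds.std(0)  # mean prediction, per-pixel uncertainty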

pred_dir = 'val_preds' #@param {type:"string"}
pred_path = path/pred_dir/'ensemble'
pred_path.mkdir(parents=True, exist_ok=True)
uncertainty_dir = 'val_uncertainties' #@param {type:"string"}
uncertainty_path = path/uncertainty_dir/'ensemble'
uncertainty_path.mkdir(parents=True, exist_ok=True)
result_path = path/'results'
result_path.mkdir(exist_ok=True)

#@markdown Define `filetype` to save the predictions and uncertainties. All common [file formats](https://imageio.readthedocs.io/en/stable/formats.html) are supported.
filetype = 'png' #@param {type:"string"}
res_list = []
for model_number in range(1,n_models+1):
    model_name = f'model{model_number}'
    val_files = [f for mod, f in res.keys() if mod == model_name]
    print(f'Validating {model_name}')
    pred_path = path/pred_dir/model_name
    pred_path.mkdir(parents=True, exist_ok=True)
    uncertainty_path = path/uncertainty_dir/model_name
    uncertainty_path.mkdir(parents=True, exist_ok=True)
    for file in val_files:
        img = ds.get_data(file)[0]
        msk = ds.get_data(file, mask=True)[0]
        pred = res[(model_name,file)][1]
        pred_std = res_mc[(model_name,file)][2][...,0]
        df_tmp = pd.Series({'file' : file.name,
                            'model' : model_name,
                            'iou': iou(msk, pred),
                            'entropy': entropy(pred_std, axis=None)})
        plot_results(img, msk, pred, pred_std, df=df_tmp)
        res_list.append(df_tmp)
        # Multi-class predictions are saved as raw labels; binary masks are scaled to 0/255 for visibility
        imageio.imsave(pred_path/f'{file.stem}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
        # Scale the uncertainties before casting; casting the float std to uint8 first would truncate it to zero
        imageio.imsave(uncertainty_path/f'{file.stem}_uncertainty.{filetype}', (pred_std*255).astype(np.uint8))
df_res = pd.DataFrame(res_list)
df_res.to_csv(result_path/'val_results.csv', index=False)

Download Section

  • The downloaded models are always the last versions trained in section Model Training.
  • To download validation predictions and uncertainties, you first need to execute section Validate models.

Note: If you're connected to Google Drive, the models are automatically saved to your drive.

model_number = "1" #@param ["1", "2", "3", "4", "5"]
model_path = path/'models'/f'model{model_number}.pth'
try:
    files.download(model_path)
except NameError:
    print("Warning: File download only works on Google Colab.")
    print(f"Models are saved at {model_path.parent}")
out_name = 'val_predictions'
shutil.make_archive(path/out_name, 'zip', path/pred_dir)
try:
    files.download(path/f'{out_name}.zip')
except NameError:
    print("Warning: File download only works on Google Colab.")
out_name = 'val_uncertainties'
shutil.make_archive(path/out_name, 'zip', path/uncertainty_dir)
try:
    files.download(path/f'{out_name}.zip')
except NameError:
    print("Warning: File download only works on Google Colab.")
try:
    files.download(result_path/'val_results.csv')
except NameError:
    print("Warning: File download only works on Google Colab.")