From 770e74cfa1a8bbe24366a3dd2523f757d3c8da48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:31:56 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nobrainer/ext/lab2im/__init__.py | 7 +- nobrainer/ext/lab2im/edit_tensors.py | 202 ++- nobrainer/ext/lab2im/edit_volumes.py | 1844 ++++++++++++++++------- nobrainer/ext/lab2im/image_generator.py | 157 +- nobrainer/ext/lab2im/lab2im_model.py | 124 +- nobrainer/ext/lab2im/layers.py | 1249 +++++++++++---- nobrainer/ext/lab2im/utils.py | 868 +++++++---- nobrainer/models/__init__.py | 2 +- nobrainer/models/lab2im_model.py | 124 +- nobrainer/processing/image_generator.py | 157 +- nobrainer/processing/segmentation.py | 4 +- nobrainer/tfrecord.py | 8 +- 12 files changed, 3346 insertions(+), 1400 deletions(-) diff --git a/nobrainer/ext/lab2im/__init__.py b/nobrainer/ext/lab2im/__init__.py index d3fd52d0..f26d7db9 100644 --- a/nobrainer/ext/lab2im/__init__.py +++ b/nobrainer/ext/lab2im/__init__.py @@ -1,6 +1 @@ -from . import edit_tensors -from . import edit_volumes -from . import image_generator -from . import lab2im_model -from . import layers -from . import utils +from . import edit_tensors, edit_volumes, image_generator, lab2im_model, layers, utils diff --git a/nobrainer/ext/lab2im/edit_tensors.py b/nobrainer/ext/lab2im/edit_tensors.py index 035a104b..23139e85 100644 --- a/nobrainer/ext/lab2im/edit_tensors.py +++ b/nobrainer/ext/lab2im/edit_tensors.py @@ -22,12 +22,6 @@ """ - -# python imports -import numpy as np -import tensorflow as tf -import keras.layers as KL -import keras.backend as K from itertools import combinations # project imports @@ -36,9 +30,17 @@ # third-party imports import ext.neuron.layers as nrn_layers from ext.neuron.utils import volshape_to_meshgrid +import keras.backend as K +import keras.layers as KL + +# python imports +import numpy as np +import tensorflow as tf -def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None): +def blurring_sigma_for_downsampling( + current_res, downsample_res, mult_coef=None, thickness=None +): """Compute standard deviations of 1d gaussian masks for image blurring before downsampling. :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor. :param current_res: resolution of the volume before downsampling. 
@@ -68,17 +70,32 @@ def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, # reformat data resolution at which we blur if thickness is not None: - down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness]) + down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))( + [downsample_res, thickness] + ) else: down_res = downsample_res # get std deviation for blurring kernels if mult_coef is None: - sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x, tf.convert_to_tensor(current_res, dtype='float32')), - 0.5, 0.75 * x / tf.convert_to_tensor(current_res, dtype='float32')))(down_res) + sigma = KL.Lambda( + lambda x: tf.where( + tf.math.equal( + x, tf.convert_to_tensor(current_res, dtype="float32") + ), + 0.5, + 0.75 * x / tf.convert_to_tensor(current_res, dtype="float32"), + ) + )(down_res) else: - sigma = KL.Lambda(lambda x: mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32'))(down_res) - sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([down_res, sigma]) + sigma = KL.Lambda( + lambda x: mult_coef + * x + / tf.convert_to_tensor(current_res, dtype="float32") + )(down_res) + sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.0), 0.0, x[1]))( + [down_res, sigma] + ) return sigma @@ -95,9 +112,13 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): """ # convert sigma into a tensor if not tf.is_tensor(sigma): - sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32') + sigma_tens = tf.convert_to_tensor( + utils.reformat_to_list(sigma), dtype="float32" + ) else: - assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor' + assert ( + max_sigma is not None + ), "max_sigma must be provided when sigma is given as a tensor" sigma_tens = sigma shape = sigma_tens.get_shape().as_list() @@ -118,7 +139,9 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): # randomise the burring std dev and/or split it between dimensions if blur_range is not None: if blur_range != 1: - sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range) + sigma_tens = sigma_tens * tf.random.uniform( + tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range + ) # get size of blurring kernels windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1 @@ -129,16 +152,23 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): kernels = list() comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) - for (i, wsize) in enumerate(windowsize): + for i, wsize in enumerate(windowsize): if wsize > 1: # build meshgrid and replicate it along batch dim if dynamic blurring - locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2 + locations = tf.cast(tf.range(0, wsize), "float32") - (wsize - 1) / 2 if batchsize is not None: - locations = tf.tile(tf.expand_dims(locations, axis=0), - tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')], - axis=0)) + locations = tf.tile( + tf.expand_dims(locations, axis=0), + tf.concat( + [ + batchsize, + tf.ones(tf.shape(tf.shape(locations)), dtype="int32"), + ], + axis=0, + ), + ) comb[i] += 1 # compute gaussians @@ -156,13 +186,23 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): else: # build meshgrid - mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')] - diff = tf.stack([mesh[f] - (windowsize[f] - 
1) / 2 for f in range(len(windowsize))], axis=-1)
+        mesh = [
+            tf.cast(f, "float32")
+            for f in volshape_to_meshgrid(windowsize, indexing="ij")
+        ]
+        diff = tf.stack(
+            [mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1
+        )
 
         # replicate meshgrid to batch size and reshape sigma_tens
         if batchsize is not None:
-            diff = tf.tile(tf.expand_dims(diff, axis=0),
-                           tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0))
+            diff = tf.tile(
+                tf.expand_dims(diff, axis=0),
+                tf.concat(
+                    [batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype="int32")],
+                    axis=0,
+                ),
+            )
             for i in range(n_dims):
                 sigma_tens = tf.expand_dims(sigma_tens, axis=1)
         else:
@@ -171,8 +211,14 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
 
     # compute gaussians
     sigma_is_0 = tf.equal(sigma_tens, 0)
-    exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2)
-    norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens))
+    exp_term = -K.square(diff) / (
+        2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens) ** 2
+    )
+    norms = exp_term - tf.math.log(
+        tf.where(
+            sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens
+        )
+    )
     kernels = K.sum(norms, -1)
     kernels = tf.exp(kernels)
     kernels /= tf.reduce_sum(kernels)
@@ -184,8 +230,8 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True):
 
 def sobel_kernels(n_dims):
     """Returns sobel kernels to compute spatial derivative on image of n dimensions."""
 
-    in_dir = tf.convert_to_tensor([1, 0, -1], dtype='float32')
-    orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype='float32')
+    in_dir = tf.convert_to_tensor([1, 0, -1], dtype="float32")
+    orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype="float32")
     comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1])
 
     list_kernels = list()
@@ -216,31 +262,49 @@ def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None):
 
     # convert dist_threshold into a tensor
     if not tf.is_tensor(dist_threshold):
-        dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32')
+        dist_threshold_tens = tf.convert_to_tensor(
+            utils.reformat_to_list(dist_threshold), dtype="float32"
+        )
     else:
-        assert max_dist_threshold is not None, 'max_sigma must be provided when dist_threshold is given as a tensor'
-        dist_threshold_tens = tf.cast(dist_threshold, 'float32')
+        assert (
+            max_dist_threshold is not None
+        ), "max_dist_threshold must be provided when dist_threshold is given as a tensor"
+        dist_threshold_tens = tf.cast(dist_threshold, "float32")
     shape = dist_threshold_tens.get_shape().as_list()
 
     # get batchsize
-    batchsize = None if shape[0] is not None else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0]
+    batchsize = (
+        None
+        if shape[0] is not None
+        else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0]
+    )
 
     # set max_dist_threshold into an array
-    if max_dist_threshold is None:  # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch)
+    if (
+        max_dist_threshold is None
+    ):  # dist_threshold is fixed (i.e. 
dist_threshold will not change at each mini-batch) max_dist_threshold = dist_threshold # get size of blurring kernels - windowsize = np.array([max_dist_threshold * 2 + 1]*n_dims, dtype='int32') + windowsize = np.array([max_dist_threshold * 2 + 1] * n_dims, dtype="int32") # build tensor representing the distance from the centre - mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')] - dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1) + mesh = [ + tf.cast(f, "float32") for f in volshape_to_meshgrid(windowsize, indexing="ij") + ] + dist = tf.stack( + [mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1 + ) dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1)) # replicate distance to batch size and reshape sigma_tens if batchsize is not None: - dist = tf.tile(tf.expand_dims(dist, axis=0), - tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0)) + dist = tf.tile( + tf.expand_dims(dist, axis=0), + tf.concat( + [batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype="int32")], axis=0 + ), + ) for i in range(n_dims - 1): dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1) else: @@ -248,18 +312,24 @@ def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None): dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0) # build final kernel by thresholding distance tensor - kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist)) + kernel = tf.where( + tf.less_equal(dist, dist_threshold_tens), + tf.ones_like(dist), + tf.zeros_like(dist), + ) kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) return kernel -def resample_tensor(tensor, - resample_shape, - interp_method='linear', - subsample_res=None, - volume_res=None, - build_reliability_map=False): +def resample_tensor( + tensor, + resample_shape, + interp_method="linear", + subsample_res=None, + volume_res=None, + build_reliability_map=False, +): """This function resamples a volume to resample_shape. It does not apply any pre-filtering. A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which @@ -286,22 +356,35 @@ def resample_tensor(tensor, downsample_shape = tensor_shape # will be modified if we actually downsample if subsample_res is not None: - assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.' - assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \ - 'had {0}, and {1}'.format(len(subsample_res), len(volume_res)) + assert ( + volume_res is not None + ), "volume_res must be given when providing a subsampling resolution." 
+ assert len(subsample_res) == len(volume_res), ( + "subsample_res and volume_res must have the same length, " + "had {0}, and {1}".format(len(subsample_res), len(volume_res)) + ) if subsample_res != volume_res: # get shape at which we downsample - downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)] + downsample_shape = [ + int(tensor_shape[i] * volume_res[i] / subsample_res[i]) + for i in range(n_dims) + ] # downsample volume tensor._keras_shape = tuple(tensor.get_shape().as_list()) - tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor) + tensor = nrn_layers.Resize(size=downsample_shape, interp_method="nearest")( + tensor + ) # resample image at target resolution - if resample_shape != downsample_shape: # if we didn't downsample downsample_shape = tensor_shape + if ( + resample_shape != downsample_shape + ): # if we didn't downsample downsample_shape = tensor_shape tensor._keras_shape = tuple(tensor.get_shape().as_list()) - tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor) + tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)( + tensor + ) # compute reliability maps if necessary and return results if build_reliability_map: @@ -320,13 +403,20 @@ def resample_tensor(tensor, loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1)) tmp_reliability_map = np.zeros(resample_shape[i]) tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor) - tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor) + tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + ( + loc_float - loc_floor + ) shape = [1, 1, 1] shape[i] = resample_shape[i] - reliability_map = reliability_map * np.reshape(tmp_reliability_map, shape) + reliability_map = reliability_map * np.reshape( + tmp_reliability_map, shape + ) shape = KL.Lambda(lambda x: tf.shape(x))(tensor) - mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'), - shape=x))(shape) + mask = KL.Lambda( + lambda x: tf.reshape( + tf.convert_to_tensor(reliability_map, dtype="float32"), shape=x + ) + )(shape) # otherwise just return an all-one tensor else: diff --git a/nobrainer/ext/lab2im/edit_volumes.py b/nobrainer/ext/lab2im/edit_volumes.py index 1afeb34f..3afdc389 100644 --- a/nobrainer/ext/lab2im/edit_volumes.py +++ b/nobrainer/ext/lab2im/edit_volumes.py @@ -69,31 +69,40 @@ License. 
""" +import csv # python imports import os -import csv import shutil -import numpy as np -import tensorflow as tf -import keras.layers as KL -from keras.models import Model -from scipy.ndimage.filters import convolve -from scipy.ndimage import label as scipy_label -from scipy.interpolate import RegularGridInterpolator -from scipy.ndimage.morphology import distance_transform_edt, binary_fill_holes -from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter # project imports from ext.lab2im import utils -from ext.lab2im.layers import GaussianBlur, ConvertLabels from ext.lab2im.edit_tensors import blurring_sigma_for_downsampling - +from ext.lab2im.layers import ConvertLabels, GaussianBlur +import keras.layers as KL +from keras.models import Model +import numpy as np +from scipy.interpolate import RegularGridInterpolator +from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter +from scipy.ndimage import label as scipy_label +from scipy.ndimage.filters import convolve +from scipy.ndimage.morphology import binary_fill_holes, distance_transform_edt +import tensorflow as tf # ---------------------------------------------------- edit volume ----------------------------------------------------- -def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0, - return_mask=False, return_copy=True): + +def mask_volume( + volume, + mask=None, + threshold=0.1, + dilate=0, + erode=0, + fill_holes=False, + masking_value=0, + return_mask=False, + return_copy=True, +): """Mask a volume, either with a given mask, or by keeping only the values above a threshold. :param volume: a numpy array, possibly with several channels :param mask: (optional) a numpy array to mask volume with. @@ -119,8 +128,11 @@ def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes= if mask is None: mask = new_volume >= threshold else: - assert list(mask.shape[:n_dims]) == vol_shape[:n_dims], 'mask should have shape {0}, or {1}, had {2}'.format( - vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape)) + assert ( + list(mask.shape[:n_dims]) == vol_shape[:n_dims] + ), "mask should have shape {0}, or {1}, had {2}".format( + vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape) + ) mask = mask > 0 if dilate > 0: dilate_struct = utils.build_binary_structure(dilate, n_dims) @@ -137,7 +149,9 @@ def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes= if mask_to_apply.shape == new_volume.shape: new_volume[np.logical_not(mask_to_apply)] = masking_value else: - new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = masking_value + new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = ( + masking_value + ) if return_mask: return new_volume, mask_to_apply @@ -145,7 +159,14 @@ def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes= return new_volume -def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2, max_percentile=98, use_positive_only=False): +def rescale_volume( + volume, + new_min=0, + new_max=255, + min_percentile=2, + max_percentile=98, + use_positive_only=False, +): """This function linearly rescales a volume between new_min and new_max. :param volume: a numpy array :param new_min: (optional) minimum value for the rescaled image. 
@@ -160,23 +181,42 @@ def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2, max_percent # select only positive intensities new_volume = volume.copy() - intensities = new_volume[new_volume > 0] if use_positive_only else new_volume.flatten() + intensities = ( + new_volume[new_volume > 0] if use_positive_only else new_volume.flatten() + ) # define min and max intensities in original image for normalisation - robust_min = np.min(intensities) if min_percentile == 0 else np.percentile(intensities, min_percentile) - robust_max = np.max(intensities) if max_percentile == 100 else np.percentile(intensities, max_percentile) + robust_min = ( + np.min(intensities) + if min_percentile == 0 + else np.percentile(intensities, min_percentile) + ) + robust_max = ( + np.max(intensities) + if max_percentile == 100 + else np.percentile(intensities, max_percentile) + ) # trim values outside range new_volume = np.clip(new_volume, robust_min, robust_max) # rescale image if robust_min != robust_max: - return new_min + (new_volume - robust_min) / (robust_max - robust_min) * (new_max - new_min) + return new_min + (new_volume - robust_min) / (robust_max - robust_min) * ( + new_max - new_min + ) else: # avoid dividing by zero return np.zeros_like(new_volume) -def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'): +def crop_volume( + volume, + cropping_margin=None, + cropping_shape=None, + aff=None, + return_crop_idx=False, + mode="center", +): """Crop volume by a given margin, or to a given shape. :param volume: 2d or 3d numpy array (possibly with multiple channels) :param cropping_margin: (optional) margin by which to crop the volume. The cropping margin is applied on both sides. @@ -192,10 +232,12 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret True (in that order). 
""" - assert (cropping_margin is not None) | (cropping_shape is not None), \ - 'cropping_margin or cropping_shape should be provided' - assert not ((cropping_margin is not None) & (cropping_shape is not None)), \ - 'only one of cropping_margin or cropping_shape should be provided' + assert (cropping_margin is not None) | ( + cropping_shape is not None + ), "cropping_margin or cropping_shape should be provided" + assert not ( + (cropping_margin is not None) & (cropping_shape is not None) + ), "only one of cropping_margin or cropping_shape should be provided" # get info new_volume = volume.copy() @@ -206,27 +248,49 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret if cropping_margin is not None: cropping_margin = utils.reformat_to_list(cropping_margin, length=n_dims) do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin) - min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)] - max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)] + min_crop_idx = [ + cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims) + ] + max_crop_idx = [ + vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] + for i in range(n_dims) + ] else: cropping_shape = utils.reformat_to_list(cropping_shape, length=n_dims) - if mode == 'center': - min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0) - max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)], - np.array(vol_shape)[:n_dims]) - elif mode == 'random': - crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0) + if mode == "center": + min_crop_idx = np.maximum( + [int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0 + ) + max_crop_idx = np.minimum( + [min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)], + np.array(vol_shape)[:n_dims], + ) + elif mode == "random": + crop_max_val = np.maximum( + np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0 + ) min_crop_idx = np.random.randint(0, high=crop_max_val + 1) - max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims]) + max_crop_idx = np.minimum( + min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims] + ) else: - raise ValueError('mode should be either "center" or "random", had %s' % mode) + raise ValueError( + 'mode should be either "center" or "random", had %s' % mode + ) crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)]) # crop volume if n_dims == 2: - new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...] + new_volume = new_volume[ + crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ... + ] elif n_dims == 3: - new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...] 
+ new_volume = new_volume[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] # sort outputs output = [new_volume] @@ -238,15 +302,17 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret return output[0] if len(output) == 1 else tuple(output) -def crop_volume_around_region(volume, - mask=None, - masking_labels=None, - threshold=0.1, - margin=0, - cropping_shape=None, - cropping_shape_div_by=None, - aff=None, - overflow='strict'): +def crop_volume_around_region( + volume, + mask=None, + masking_labels=None, + threshold=0.1, + margin=0, + cropping_shape=None, + cropping_shape_div_by=None, + aff=None, + overflow="strict", +): """Crop a volume around a specific region. This region is defined by a mask obtained by either: 1) directly specifying it as input (see mask) @@ -285,11 +351,15 @@ def crop_volume_around_region(volume, and the updated affine matrix if aff is not None. """ - assert not ((margin > 0) & (cropping_shape is not None)), "margin and cropping_shape can't be given together." - assert not ((margin > 0) & (cropping_shape_div_by is not None)), \ - "margin and cropping_shape_div_by can't be given together." - assert not ((cropping_shape_div_by is not None) & (cropping_shape is not None)), \ - "cropping_shape_div_by and cropping_shape can't be given together." + assert not ( + (margin > 0) & (cropping_shape is not None) + ), "margin and cropping_shape can't be given together." + assert not ( + (margin > 0) & (cropping_shape_div_by is not None) + ), "margin and cropping_shape_div_by can't be given together." + assert not ( + (cropping_shape_div_by is not None) & (cropping_shape is not None) + ), "cropping_shape_div_by and cropping_shape can't be given together." new_vol = volume.copy() n_dims, n_channels = utils.get_dims(new_vol.shape) @@ -298,7 +368,9 @@ def crop_volume_around_region(volume, # mask ROIs for cropping if mask is None: if masking_labels is not None: - _, mask = mask_label_map(new_vol, masking_values=masking_labels, return_mask=True) + _, mask = mask_label_map( + new_vol, masking_values=masking_labels, return_mask=True + ) else: mask = new_vol > threshold @@ -315,25 +387,35 @@ def crop_volume_around_region(volume, if margin: cropping_shape = intermediate_vol_shape + 2 * margin elif cropping_shape is not None: - cropping_shape = np.array(utils.reformat_to_list(cropping_shape, length=n_dims)) + cropping_shape = np.array( + utils.reformat_to_list(cropping_shape, length=n_dims) + ) elif cropping_shape_div_by is not None: - cropping_shape = [utils.find_closest_number_divisible_by_m(s, cropping_shape_div_by, answer_type='higher') - for s in intermediate_vol_shape] - - min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + cropping_shape = [ + utils.find_closest_number_divisible_by_m( + s, cropping_shape_div_by, answer_type="higher" + ) + for s in intermediate_vol_shape + ] + + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) min_overflow = np.abs(np.minimum(min_idx, 0)) max_overflow = np.maximum(max_idx - vol_shape, 0) - if 'strict' in overflow: + if "strict" in overflow: min_overflow = np.zeros_like(min_overflow) max_overflow = np.zeros_like(min_overflow) - if overflow == 'shift-strict': + if overflow == "shift-strict": min_idx -= max_overflow 
max_idx += min_overflow - if overflow == 'shift-padding': + if overflow == "shift-padding": for ii in range(n_dims): # no need to do anything if both min/max_overflow are 0 (no padding/shifting required at all) # or if both are positive, because in this case we don't shift at all and we pad directly @@ -343,7 +425,9 @@ def crop_volume_around_region(volume, max_idx[ii] = max_idx_new min_overflow[ii] = 0 else: - min_overflow[ii] = min_overflow[ii] - (vol_shape[ii] - max_idx[ii]) + min_overflow[ii] = min_overflow[ii] - ( + vol_shape[ii] - max_idx[ii] + ) max_idx[ii] = vol_shape[ii] elif (min_overflow[ii] == 0) & (max_overflow[ii] > 0): min_idx_new = min_idx[ii] - max_overflow[ii] @@ -360,17 +444,28 @@ def crop_volume_around_region(volume, cropping = np.concatenate([min_idx, max_idx]) if np.any(cropping[:3] > 0) or np.any(cropping[3:] != vol_shape): if n_dims == 3: - new_vol = new_vol[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] + new_vol = new_vol[ + cropping[0] : cropping[3], + cropping[1] : cropping[4], + cropping[2] : cropping[5], + ..., + ] elif n_dims == 2: - new_vol = new_vol[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] + new_vol = new_vol[ + cropping[0] : cropping[2], cropping[1] : cropping[3], ... + ] else: - raise ValueError('cannot crop volumes with more than 3 dimensions') + raise ValueError("cannot crop volumes with more than 3 dimensions") # pad volume if necessary if np.any(min_overflow > 0) | np.any(max_overflow > 0): - pad_margins = tuple([(min_overflow[i], max_overflow[i]) for i in range(n_dims)]) - pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins - new_vol = np.pad(new_vol, pad_margins, mode='constant', constant_values=0) + pad_margins = tuple( + [(min_overflow[i], max_overflow[i]) for i in range(n_dims)] + ) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + ) + new_vol = np.pad(new_vol, pad_margins, mode="constant", constant_values=0) # if there's nothing to crop around, we return the input as is else: @@ -408,11 +503,18 @@ def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=Tr # crop image if n_dims == 2: - new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...] + new_volume = new_volume[ + crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ... + ] elif n_dims == 3: - new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...] 
+ new_volume = new_volume[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] else: - raise Exception('cannot crop volumes with more than 3 dimensions') + raise Exception("cannot crop volumes with more than 3 dimensions") if aff is not None: aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ crop_idx[:3] @@ -436,21 +538,38 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= new_volume = volume.copy() vol_shape = new_volume.shape n_dims, n_channels = utils.get_dims(vol_shape) - padding_shape = utils.reformat_to_list(padding_shape, length=n_dims, dtype='int') + padding_shape = utils.reformat_to_list(padding_shape, length=n_dims, dtype="int") # check if need to pad - if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')): + if np.any( + np.array(padding_shape, dtype="int32") + > np.array(vol_shape[:n_dims], dtype="int32") + ): # get padding margins - min_margins = np.maximum(np.int32(np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0) - max_margins = np.maximum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0) - pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])]) + min_margins = np.maximum( + np.int32( + np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2) + ), + 0, + ) + max_margins = np.maximum( + np.int32( + np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2) + ), + 0, + ) + pad_idx = np.concatenate( + [min_margins, min_margins + np.array(vol_shape[:n_dims])] + ) pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) if n_channels > 1: pad_margins = tuple(list(pad_margins) + [(0, 0)]) # pad volume - new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value) + new_volume = np.pad( + new_volume, pad_margins, mode="constant", constant_values=padding_value + ) if aff is not None: if n_dims == 2: @@ -482,26 +601,29 @@ def flip_volume(volume, axis=None, direction=None, aff=None, return_copy=True): """ new_volume = volume.copy() if return_copy else volume - assert (axis is not None) | ((aff is not None) & (direction is not None)), \ - 'please provide either axis, or an affine matrix with a direction' + assert (axis is not None) | ( + (aff is not None) & (direction is not None) + ), "please provide either axis, or an affine matrix with a direction" # get flipping axis from aff if axis not provided if (axis is None) & (aff is not None): volume_axes = get_ras_axes(aff) - if direction == 'rl': + if direction == "rl": axis = volume_axes[0] - elif direction == 'ap': + elif direction == "ap": axis = volume_axes[1] - elif direction == 'si': + elif direction == "si": axis = volume_axes[2] else: - raise ValueError("direction should be 'rl', 'ap', or 'si', had %s" % direction) + raise ValueError( + "direction should be 'rl', 'ap', or 'si', had %s" % direction + ) # flip volume return np.flip(new_volume, axis=axis) -def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True): +def resample_volume(volume, aff, new_vox_size, interpolation="linear", blur=True): """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS :param volume: a numpy array :param aff: affine matrix of the volume @@ -525,9 +647,11 @@ def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True y = np.arange(0, volume_filt.shape[1]) z = np.arange(0, 
volume_filt.shape[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) + my_interpolating_function = RegularGridInterpolator( + (x, y, z), volume_filt, method=interpolation + ) - start = - (factor - 1) / (2 * factor) + start = -(factor - 1) / (2 * factor) step = 1.0 / factor stop = start + step * np.ceil(volume_filt.shape * factor) @@ -541,7 +665,7 @@ def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1 zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1 - xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=True) + xig, yig, zig = np.meshgrid(xi, yi, zi, indexing="ij", sparse=True) volume2 = my_interpolating_function((xig, yig, zig)) aff2 = aff.copy() @@ -552,7 +676,7 @@ def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True return volume2, aff2 -def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='linear'): +def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation="linear"): """This function reslices a floating image to the space of a reference image :param vol_ref: a numpy array with the reference volume :param aff_ref: affine matrix of the reference volume @@ -568,14 +692,15 @@ def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='line yf = np.arange(0, vol_flo.shape[1]) zf = np.arange(0, vol_flo.shape[2]) - my_interpolating_function = RegularGridInterpolator((xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, - method=interpolation) + my_interpolating_function = RegularGridInterpolator( + (xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, method=interpolation + ) xr = np.arange(0, vol_ref.shape[0]) yr = np.arange(0, vol_ref.shape[1]) zr = np.arange(0, vol_ref.shape[2]) - xrg, yrg, zrg = np.meshgrid(xr, yr, zr, indexing='ij', sparse=False) + xrg, yrg, zrg = np.meshgrid(xr, yr, zr, indexing="ij", sparse=False) n = xrg.size xrg = xrg.reshape([n]) yrg = yrg.reshape([n]) @@ -583,7 +708,9 @@ def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='line bottom = np.ones_like(xrg) coords = np.stack([xrg, yrg, zrg, bottom]) coords_new = np.matmul(T, coords)[:-1, :] - result = my_interpolating_function((coords_new[0, :], coords_new[1, :], coords_new[2, :])) + result = my_interpolating_function( + (coords_new[0, :], coords_new[1, :], coords_new[2, :]) + ) return result.reshape(vol_ref.shape) @@ -606,7 +733,9 @@ def get_ras_axes(aff, n_dims=3): return img_ras_axes -def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True): +def align_volume_to_ref( + volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True +): """This function aligns a volume to a reference orientation (axis and direction) specified by an affine matrix. 
:param volume: a numpy array :param aff: affine matrix of the floating volume @@ -638,14 +767,17 @@ def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None if ras_axes_flo[i] != ras_axes_ref[i]: new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i]) swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i]) - ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx] + ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ( + ras_axes_flo[i], + ras_axes_flo[swapped_axis_idx], + ) # align directions dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0) for i in range(n_dims): if dot_products[i] < 0: new_volume = np.flip(new_volume, axis=i) - aff_flo[:, i] = - aff_flo[:, i] + aff_flo[:, i] = -aff_flo[:, i] aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1) if return_aff: @@ -666,17 +798,21 @@ def blur_volume(volume, sigma, mask=None): # initialisation new_volume = volume.copy() n_dims, _ = utils.get_dims(new_volume.shape) - sigma = utils.reformat_to_list(sigma, length=n_dims, dtype='float') + sigma = utils.reformat_to_list(sigma, length=n_dims, dtype="float") # blur image - new_volume = gaussian_filter(new_volume, sigma=sigma, mode='nearest') # nearest refers to edge padding + new_volume = gaussian_filter( + new_volume, sigma=sigma, mode="nearest" + ) # nearest refers to edge padding # correct edge effect if mask is not None if mask is not None: - assert new_volume.shape == mask.shape, 'volume and mask should have the same dimensions: ' \ - 'got {0} and {1}'.format(new_volume.shape, mask.shape) + assert new_volume.shape == mask.shape, ( + "volume and mask should have the same dimensions: " + "got {0} and {1}".format(new_volume.shape, mask.shape) + ) mask = (mask > 0) * 1.0 - blurred_mask = gaussian_filter(mask, sigma=sigma, mode='nearest') + blurred_mask = gaussian_filter(mask, sigma=sigma, mode="nearest") new_volume = new_volume / (blurred_mask + 1e-6) new_volume[mask == 0] = 0 @@ -685,8 +821,15 @@ def blur_volume(volume, sigma, mask=None): # --------------------------------------------------- edit label map --------------------------------------------------- -def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, use_nearest_label=False, - remove_zero=False, smooth=False): + +def correct_label_map( + labels, + list_incorrect_labels, + list_correct_labels=None, + use_nearest_label=False, + remove_zero=False, + smooth=False, +): """This function corrects specified label values in a label map by either a list of given values, or by the nearest label. :param labels: a 2d or 3d label map @@ -703,27 +846,39 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u :return: corrected label map """ - assert (list_correct_labels is not None) | use_nearest_label, \ - 'please provide a list of correct labels, or set use_nearest_label to True.' - assert (list_correct_labels is None) | (not use_nearest_label), \ - 'cannot provide a list of correct values and set use_nearest_label to True' + assert ( + list_correct_labels is not None + ) | use_nearest_label, ( + "please provide a list of correct labels, or set use_nearest_label to True." 
+ ) + assert (list_correct_labels is None) | ( + not use_nearest_label + ), "cannot provide a list of correct values and set use_nearest_label to True" # initialisation new_labels = labels.copy() - list_incorrect_labels = utils.reformat_to_list(utils.load_array_if_path(list_incorrect_labels)) + list_incorrect_labels = utils.reformat_to_list( + utils.load_array_if_path(list_incorrect_labels) + ) volume_labels = np.unique(labels) n_dims, _ = utils.get_dims(labels.shape) # use list of correct values if list_correct_labels is not None: - list_correct_labels = utils.reformat_to_list(utils.load_array_if_path(list_correct_labels)) + list_correct_labels = utils.reformat_to_list( + utils.load_array_if_path(list_correct_labels) + ) # loop over label values - for incorrect_label, correct_label in zip(list_incorrect_labels, list_correct_labels): + for incorrect_label, correct_label in zip( + list_incorrect_labels, list_correct_labels + ): if incorrect_label in volume_labels: # only one possible value to replace with - if isinstance(correct_label, (int, float, np.int64, np.int32, np.int16, np.int8)): + if isinstance( + correct_label, (int, float, np.int64, np.int32, np.int16, np.int8) + ): incorrect_voxels = np.where(labels == incorrect_label) new_labels[incorrect_voxels] = correct_label @@ -732,8 +887,12 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u # make sure at least one correct label is present if not any([lab in volume_labels for lab in correct_label]): - print('no correct values found in volume, please adjust: ' - 'incorrect: {}, correct: {}'.format(incorrect_label, correct_label)) + print( + "no correct values found in volume, please adjust: " + "incorrect: {}, correct: {}".format( + incorrect_label, correct_label + ) + ) # crop around incorrect label until we find incorrect labels correct_label_not_found = True @@ -741,21 +900,34 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u tmp_labels = None crop = None while correct_label_not_found: - tmp_labels, crop = crop_volume_around_region(labels, - masking_labels=incorrect_label, - margin=10 * margin_mult) - correct_label_not_found = not any([lab in np.unique(tmp_labels) for lab in correct_label]) + tmp_labels, crop = crop_volume_around_region( + labels, + masking_labels=incorrect_label, + margin=10 * margin_mult, + ) + correct_label_not_found = not any( + [lab in np.unique(tmp_labels) for lab in correct_label] + ) margin_mult += 1 # calculate distance maps for all new label candidates incorrect_voxels = np.where(tmp_labels == incorrect_label) - distance_map_list = [distance_transform_edt(tmp_labels != lab) for lab in correct_label] - distances_correct = np.stack([dist[incorrect_voxels] for dist in distance_map_list]) + distance_map_list = [ + distance_transform_edt(tmp_labels != lab) + for lab in correct_label + ] + distances_correct = np.stack( + [dist[incorrect_voxels] for dist in distance_map_list] + ) # select nearest values and use them to correct label map idx_correct_lab = np.argmin(distances_correct, axis=0) - incorrect_voxels = tuple([incorrect_voxels[i] + crop[i] for i in range(n_dims)]) - new_labels[incorrect_voxels] = np.array(correct_label)[idx_correct_lab] + incorrect_voxels = tuple( + [incorrect_voxels[i] + crop[i] for i in range(n_dims)] + ) + new_labels[incorrect_voxels] = np.array(correct_label)[ + idx_correct_lab + ] # use nearest label else: @@ -766,21 +938,27 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u # loop around 
regions components, n_components = scipy_label(labels == incorrect_label) - loop_info = utils.LoopInfo(n_components + 1, 100, 'correcting') + loop_info = utils.LoopInfo(n_components + 1, 100, "correcting") for i in range(1, n_components + 1): loop_info.update(i) # crop each region - _, crop = crop_volume_around_region(components, masking_labels=i, margin=1) + _, crop = crop_volume_around_region( + components, masking_labels=i, margin=1 + ) tmp_labels = crop_volume_with_idx(labels, crop) tmp_new_labels = crop_volume_with_idx(new_labels, crop) # list all possible correct labels correct_labels = np.unique(tmp_labels) for il in list_incorrect_labels: - correct_labels = np.delete(correct_labels, np.where(correct_labels == il)) + correct_labels = np.delete( + correct_labels, np.where(correct_labels == il) + ) if remove_zero: - correct_labels = np.delete(correct_labels, np.where(correct_labels == 0)) + correct_labels = np.delete( + correct_labels, np.where(correct_labels == 0) + ) # replace incorrect voxels by new value incorrect_voxels = np.where(tmp_labels == incorrect_label) @@ -788,18 +966,31 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u tmp_new_labels[incorrect_voxels] = -1 else: if len(correct_labels) == 1: - idx_correct_lab = np.zeros(len(incorrect_voxels[0]), dtype='int32') + idx_correct_lab = np.zeros( + len(incorrect_voxels[0]), dtype="int32" + ) else: - distance_map_list = [distance_transform_edt(tmp_labels != lab) for lab in correct_labels] - distances_correct = np.stack([dist[incorrect_voxels] for dist in distance_map_list]) + distance_map_list = [ + distance_transform_edt(tmp_labels != lab) + for lab in correct_labels + ] + distances_correct = np.stack( + [dist[incorrect_voxels] for dist in distance_map_list] + ) idx_correct_lab = np.argmin(distances_correct, axis=0) - tmp_new_labels[incorrect_voxels] = np.array(correct_labels)[idx_correct_lab] + tmp_new_labels[incorrect_voxels] = np.array(correct_labels)[ + idx_correct_lab + ] # paste back if n_dims == 2: - new_labels[crop[0]:crop[2], crop[1]:crop[3], ...] = tmp_new_labels + new_labels[crop[0] : crop[2], crop[1] : crop[3], ...] = ( + tmp_new_labels + ) else: - new_labels[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] = tmp_new_labels + new_labels[ + crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ... 
+ ] = tmp_new_labels # smoothing if smooth: @@ -843,18 +1034,20 @@ def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): """ # get info labels_shape = labels.shape - unique_labels = np.unique(labels).astype('int32') + unique_labels = np.unique(labels).astype("int32") if labels_list is None: labels_list = unique_labels new_labels = mask_new_labels = None else: labels_to_keep = [lab for lab in unique_labels if lab not in labels_list] - new_labels, mask_new_labels = mask_label_map(labels, labels_to_keep, return_mask=True) + new_labels, mask_new_labels = mask_label_map( + labels, labels_to_keep, return_mask=True + ) # loop through label values count = np.zeros(labels_shape) - labels_smoothed = np.zeros(labels_shape, dtype='int') - loop_info = utils.LoopInfo(len(labels_list), print_progress, 'smoothing') + labels_smoothed = np.zeros(labels_shape, dtype="int") + loop_info = utils.LoopInfo(len(labels_list), print_progress, "smoothing") for la, label in enumerate(labels_list): if print_progress: loop_info.update(la) @@ -867,7 +1060,7 @@ def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): idx = n_neighbours > count count[idx] = n_neighbours[idx] labels_smoothed[idx] = label - labels_smoothed = labels_smoothed.astype('int32') + labels_smoothed = labels_smoothed.astype("int32") if new_labels is None: new_labels = labels_smoothed @@ -877,7 +1070,14 @@ def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): return new_labels -def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, model=None, return_model=False): +def erode_label_map( + labels, + labels_to_erode, + erosion_factors=1.0, + gpu=False, + model=None, + return_model=False, +): """Erode a given set of label values within a label map. 
:param labels: a 2d or 3d label map :param labels_to_erode: list of label values to erode @@ -893,17 +1093,23 @@ def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, mode # reformat labels_to_erode and erode new_labels = labels.copy() labels_to_erode = utils.reformat_to_list(labels_to_erode) - erosion_factors = utils.reformat_to_list(erosion_factors, length=len(labels_to_erode)) + erosion_factors = utils.reformat_to_list( + erosion_factors, length=len(labels_to_erode) + ) labels_shape = list(new_labels.shape) n_dims, _ = utils.get_dims(labels_shape) # loop over labels to erode for label_to_erode, erosion_factor in zip(labels_to_erode, erosion_factors): - assert erosion_factor > 0, 'all erosion factors should be strictly positive, had {}'.format(erosion_factor) + assert ( + erosion_factor > 0 + ), "all erosion factors should be strictly positive, had {}".format( + erosion_factor + ) # get mask of current label value - mask = (new_labels == label_to_erode) + mask = new_labels == label_to_erode # erode as usual if erosion factor is int if int(erosion_factor) == erosion_factor: @@ -914,12 +1120,14 @@ def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, mode else: if gpu: if model is None: - mask_in = KL.Input(shape=labels_shape + [1], dtype='float32') + mask_in = KL.Input(shape=labels_shape + [1], dtype="float32") blurred_mask = GaussianBlur([1] * 3)(mask_in) model = Model(inputs=mask_in, outputs=blurred_mask) - eroded_mask = model.predict(utils.add_axis(np.float32(mask), axis=[0, -1])) + eroded_mask = model.predict( + utils.add_axis(np.float32(mask), axis=[0, -1]) + ) else: - eroded_mask = blur_volume(np.array(mask, dtype='float32'), 1) + eroded_mask = blur_volume(np.array(mask, dtype="float32"), 1) eroded_mask = np.squeeze(eroded_mask) > erosion_factor # crop label map and mask around values to change @@ -930,16 +1138,28 @@ def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, mode # calculate distance maps for all labels in cropped_labels labels_list = np.unique(cropped_labels) labels_list = labels_list[labels_list != label_to_erode] - list_dist_maps = [distance_transform_edt(np.logical_not(cropped_labels == la)) for la in labels_list] - candidate_distances = np.stack([dist[cropped_lab_mask] for dist in list_dist_maps]) + list_dist_maps = [ + distance_transform_edt(np.logical_not(cropped_labels == la)) + for la in labels_list + ] + candidate_distances = np.stack( + [dist[cropped_lab_mask] for dist in list_dist_maps] + ) # select nearest value and put cropped labels back to full label map idx_correct_lab = np.argmin(candidate_distances, axis=0) cropped_labels[cropped_lab_mask] = np.array(labels_list)[idx_correct_lab] if n_dims == 2: - new_labels[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] = cropped_labels + new_labels[cropping[0] : cropping[2], cropping[1] : cropping[3], ...] = ( + cropped_labels + ) elif n_dims == 3: - new_labels[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] = cropped_labels + new_labels[ + cropping[0] : cropping[3], + cropping[1] : cropping[4], + cropping[2] : cropping[5], + ..., + ] = cropped_labels if return_model: return new_labels, model @@ -953,10 +1173,16 @@ def get_largest_connected_component(mask, structure=None): :param structure: numpy array defining the connectivity. 
""" components, n_components = scipy_label(mask, structure) - return components == np.argmax(np.bincount(components.flat)[1:]) + 1 if n_components > 0 else mask.copy() + return ( + components == np.argmax(np.bincount(components.flat)[1:]) + 1 + if n_components > 0 + else mask.copy() + ) -def compute_hard_volumes(labels, voxel_volume=1., label_list=None, skip_background=True): +def compute_hard_volumes( + labels, voxel_volume=1.0, label_list=None, skip_background=True +): """Compute hard volumes in a label map. :param labels: a label map :param voxel_volume: (optional) volume of voxel. Default is 1 (i.e. returned volumes are voxel counts). @@ -969,7 +1195,7 @@ def compute_hard_volumes(labels, voxel_volume=1., label_list=None, skip_backgrou """ # initialisation - subject_label_list = utils.reformat_to_list(np.unique(labels), dtype='int') + subject_label_list = utils.reformat_to_list(np.unique(labels), dtype="int") if label_list is None: label_list = subject_label_list else: @@ -996,7 +1222,8 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): for these labels only. Default is None, where all positive values are considered. :param crop_margin: (optional) margin with which to crop the input label maps around the labels for which we want to compute the distance maps. - :return: a distance map with positive values inside the considered regions, and negative values outside.""" + :return: a distance map with positive values inside the considered regions, and negative values outside. + """ n_dims, _ = utils.get_dims(labels.shape) @@ -1010,7 +1237,7 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): # mask label map around specify values if masking_labels is not None: masking_labels = utils.reformat_to_list(masking_labels) - mask = np.zeros(tmp_labels.shape, dtype='bool') + mask = np.zeros(tmp_labels.shape, dtype="bool") for masking_label in masking_labels: mask = mask | tmp_labels == masking_label else: @@ -1020,17 +1247,22 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): # compute distances dist_in = distance_transform_edt(mask) dist_in = np.where(mask, dist_in - 0.5, dist_in) - dist_out = - distance_transform_edt(not_mask) + dist_out = -distance_transform_edt(not_mask) dist_out = np.where(not_mask, dist_out + 0.5, dist_out) tmp_dist = dist_in + dist_out # put back in original matrix if we cropped if crop_idx is not None: - dist = np.min(tmp_dist) * np.ones(labels.shape, dtype='float32') + dist = np.min(tmp_dist) * np.ones(labels.shape, dtype="float32") if n_dims == 3: - dist[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...] = tmp_dist + dist[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] = tmp_dist elif n_dims == 2: - dist[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...] = tmp_dist + dist[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ...] 
= tmp_dist else: dist = tmp_dist @@ -1039,8 +1271,20 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): # ------------------------------------------------- edit volumes in dir ------------------------------------------------ -def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, - masking_value=0, write_mask=False, mask_result_dir=None, recompute=True): + +def mask_images_in_dir( + image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + dilate=0, + erode=0, + fill_holes=False, + masking_value=0, + write_mask=False, + mask_result_dir=None, + recompute=True, +): """Mask all volumes in a folder, either with masks in a specified folder, or by keeping only the intensity values above a specified threshold. :param image_dir: path of directory with images to mask @@ -1072,7 +1316,7 @@ def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dila path_masks = [None] * len(path_images) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'masking', True) + loop_info = utils.LoopInfo(len(path_images), 10, "masking", True) for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): loop_info.update(idx) @@ -1084,22 +1328,41 @@ def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dila mask = utils.load_volume(path_mask) else: mask = None - im = mask_volume(im, mask, threshold, dilate, erode, fill_holes, masking_value, write_mask) + im = mask_volume( + im, + mask, + threshold, + dilate, + erode, + fill_holes, + masking_value, + write_mask, + ) # write mask if necessary if write_mask: - assert mask_result_dir is not None, 'if write_mask is True, mask_result_dir has to be specified as well' - mask_result_path = os.path.join(mask_result_dir, os.path.basename(path_image)) + assert ( + mask_result_dir is not None + ), "if write_mask is True, mask_result_dir has to be specified as well" + mask_result_path = os.path.join( + mask_result_dir, os.path.basename(path_image) + ) utils.save_volume(im[1], aff, h, mask_result_path) utils.save_volume(im[0], aff, h, path_result) else: utils.save_volume(im, aff, h, path_result) -def rescale_images_in_dir(image_dir, result_dir, - new_min=0, new_max=255, - min_percentile=2, max_percentile=98, use_positive_only=True, - recompute=True): +def rescale_images_in_dir( + image_dir, + result_dir, + new_min=0, + new_max=255, + min_percentile=2, + max_percentile=98, + use_positive_only=True, + recompute=True, +): """This function linearly rescales all volumes in image_dir between new_min and new_max. 
:param image_dir: path of directory with images to rescale :param result_dir: path of directory where rescaled images will be writen @@ -1118,18 +1381,22 @@ def rescale_images_in_dir(image_dir, result_dir, # loop over images path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'rescaling', True) + loop_info = utils.LoopInfo(len(path_images), 10, "rescaling", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) path_result = os.path.join(result_dir, os.path.basename(path_image)) if (not os.path.isfile(path_result)) | recompute: im, aff, h = utils.load_volume(path_image, im_only=False) - im = rescale_volume(im, new_min, new_max, min_percentile, max_percentile, use_positive_only) + im = rescale_volume( + im, new_min, new_max, min_percentile, max_percentile, use_positive_only + ) utils.save_volume(im, aff, h, path_result) -def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_shape=None, recompute=True): +def crop_images_in_dir( + image_dir, result_dir, cropping_margin=None, cropping_shape=None, recompute=True +): """Crop all volumes in a folder by a given margin, or to a given shape. :param image_dir: path of directory with images to rescale :param result_dir: path of directory where cropped images will be writen @@ -1145,7 +1412,7 @@ def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_sha # loop over images and masks path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) @@ -1157,13 +1424,15 @@ def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_sha utils.save_volume(volume, aff, h, path_result) -def crop_images_around_region_in_dir(image_dir, - result_dir, - mask_dir=None, - threshold=0.1, - masking_labels=None, - crop_margin=5, - recompute=True): +def crop_images_around_region_in_dir( + image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + masking_labels=None, + crop_margin=5, + recompute=True, +): """Crop all volumes in a folder around a region, which is defined for each volume by a mask obtained by either 1) directly providing it as input 2) thresholding the input volume @@ -1189,7 +1458,7 @@ def crop_images_around_region_in_dir(image_dir, path_masks = [None] * len(path_images) # loop over images and masks - loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): loop_info.update(idx) @@ -1201,11 +1470,15 @@ def crop_images_around_region_in_dir(image_dir, mask = utils.load_volume(path_mask) else: mask = None - volume, cropping, aff = crop_volume_around_region(volume, mask, threshold, masking_labels, crop_margin, aff) + volume, cropping, aff = crop_volume_around_region( + volume, mask, threshold, masking_labels, crop_margin, aff + ) utils.save_volume(volume, aff, h, path_result) -def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, recompute=True): +def pad_images_in_dir( + image_dir, result_dir, max_shape=None, padding_value=0, recompute=True +): """Pads all the volumes in a folder to the same shape (either provided or computed). 
:param image_dir: path of directory with images to pad
 :param result_dir: path of directory where padded images will be written
@@ -1227,11 +1500,13 @@ def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, re
         max_shape, aff, _, _, h, _ = utils.get_volume_info(path_images[0])
         for path_image in path_images[1:]:
             image_shape, aff, _, _, h, _ = utils.get_volume_info(path_image)
-            max_shape = tuple(np.maximum(np.asarray(max_shape), np.asarray(image_shape)))
+            max_shape = tuple(
+                np.maximum(np.asarray(max_shape), np.asarray(image_shape))
+            )
     max_shape = np.array(max_shape)
 
     # loop over label maps
-    loop_info = utils.LoopInfo(len(path_images), 10, 'padding', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "padding", True)
     for idx, path_image in enumerate(path_images):
         loop_info.update(idx)
 
@@ -1245,7 +1520,9 @@ def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, re
     return max_shape
 
 
-def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompute=True):
+def flip_images_in_dir(
+    image_dir, result_dir, axis=None, direction=None, recompute=True
+):
     """Flip all images in a directory along a specified axis. If unknown, this axis can be replaced by an anatomical
     direction.
     :param image_dir: path of directory with images to flip
@@ -1260,7 +1537,7 @@ def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompu
 
     # loop over images
     path_images = utils.list_images_in_folder(image_dir)
-    loop_info = utils.LoopInfo(len(path_images), 10, 'flipping', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "flipping", True)
     for idx, path_image in enumerate(path_images):
         loop_info.update(idx)
 
@@ -1272,7 +1549,9 @@ def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompu
             utils.save_volume(im, aff, h, path_result)
 
 
-def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, recompute=True):
+def align_images_in_dir(
+    image_dir, result_dir, aff_ref=None, path_ref=None, recompute=True
+):
     """This function aligns all images in image_dir to a reference orientation (axes and directions). This reference
     orientation can be directly provided as an affine matrix, or can be specified by a reference volume. If neither is
     provided, the reference orientation is assumed to be an identity matrix.
@@ -1291,9 +1570,14 @@ def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, reco
 
     # read reference affine matrix
     if path_ref is not None:
-        assert aff_ref is None, 'cannot provide aff_ref and path_ref together.'
+        assert aff_ref is None, "cannot provide aff_ref and path_ref together."
basename = os.path.basename(path_ref)
-        if ('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename):
+        if (
+            (".nii.gz" in basename)
+            | (".nii" in basename)
+            | (".mgz" in basename)
+            | (".npz" in basename)
+        ):
             _, aff_ref, _ = utils.load_volume(path_ref, im_only=False)
             path_refs = [None] * len(path_images)
         else:
@@ -1306,7 +1590,7 @@ def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, reco
         path_refs = [None] * len(path_images)
 
     # loop over images
-    loop_info = utils.LoopInfo(len(path_images), 10, 'aligning', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "aligning", True)
     for idx, (path_image, path_ref) in enumerate(zip(path_images, path_refs)):
         loop_info.update(idx)
 
@@ -1331,7 +1615,7 @@ def correct_nans_images_in_dir(image_dir, result_dir, recompute=True):
 
     # loop over images
    path_images = utils.list_images_in_folder(image_dir)
-    loop_info = utils.LoopInfo(len(path_images), 10, 'correcting', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "correcting", True)
     for idx, path_image in enumerate(path_images):
         loop_info.update(idx)
 
@@ -1343,7 +1627,9 @@ def correct_nans_images_in_dir(image_dir, result_dir, recompute=True):
             utils.save_volume(im, aff, h, path_result)
 
 
-def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, recompute=True):
+def blur_images_in_dir(
+    image_dir, result_dir, sigma, mask_dir=None, gpu=False, recompute=True
+):
     """This function blurs all the images in image_dir with kernels of the specified std deviations.
     :param image_dir: path of directory with images to blur
     :param result_dir: path of directory where blurred images will be written
@@ -1368,17 +1654,21 @@ def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, r
 
     # loop over images
     previous_model_input_shape = None
     model = None
-    loop_info = utils.LoopInfo(len(path_images), 10, 'blurring', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "blurring", True)
     for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)):
         loop_info.update(idx)
 
         # load image
         path_result = os.path.join(result_dir, os.path.basename(path_image))
         if (not os.path.isfile(path_result)) | recompute:
-            im, im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_image, return_volume=True)
+            im, im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(
+                path_image, return_volume=True
+            )
             if path_mask is not None:
                 mask = utils.load_volume(path_mask)
-                assert mask.shape == im.shape, 'mask and image should have the same shape'
+                assert (
+                    mask.shape == im.shape
+                ), "mask and image should have the same shape"
             else:
                 mask = None
 
@@ -1391,13 +1681,17 @@ def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, r
                 if mask is None:
                     image = GaussianBlur(sigma=sigma)(inputs[0])
                 else:
-                    inputs.append(KL.Input(shape=im_shape + [1], dtype='float32'))
+                    inputs.append(KL.Input(shape=im_shape + [1], dtype="float32"))
                     image = GaussianBlur(sigma=sigma, use_mask=True)(inputs)
                 model = Model(inputs=inputs, outputs=image)
             if mask is None:
                 im = np.squeeze(model.predict(utils.add_axis(im, axis=[0, -1])))
             else:
-                im = np.squeeze(model.predict([utils.add_axis(im, [0, -1]), utils.add_axis(mask, [0, -1])]))
+                im = np.squeeze(
+                    model.predict(
+                        [utils.add_axis(im, [0, -1]), utils.add_axis(mask, [0, -1])]
+                    )
+                )
         else:
             im = blur_volume(im, sigma, mask=mask)
         utils.save_volume(im, aff, h, path_result)
@@ -1414,7 +1708,9 @@ def create_mutlimodal_images(list_channel_dir, result_dir, recompute=True):
 
     # create result 
dir utils.mkdir(result_dir) - assert isinstance(list_channel_dir, (list, tuple)), 'list_channel_dir should be a list or a tuple' + assert isinstance( + list_channel_dir, (list, tuple) + ), "list_channel_dir should be a list or a tuple" # gather path of all images for all channels list_channel_paths = [utils.list_images_in_folder(d) for d in list_channel_dir] @@ -1422,27 +1718,33 @@ def create_mutlimodal_images(list_channel_dir, result_dir, recompute=True): n_channels = len(list_channel_dir) for channel_paths in list_channel_paths: if len(channel_paths) != n_images: - raise ValueError('all directories should have the same number of files') + raise ValueError("all directories should have the same number of files") # loop over images - loop_info = utils.LoopInfo(n_images, 10, 'processing', True) + loop_info = utils.LoopInfo(n_images, 10, "processing", True) for idx in range(n_images): loop_info.update(idx) # stack all channels and save multichannel image - path_result = os.path.join(result_dir, os.path.basename(list_channel_paths[0][idx])) + path_result = os.path.join( + result_dir, os.path.basename(list_channel_paths[0][idx]) + ) if (not os.path.isfile(path_result)) | recompute: list_channels = list() tmp_aff = None tmp_h = None for channel_idx in range(n_channels): - tmp_channel, tmp_aff, tmp_h = utils.load_volume(list_channel_paths[channel_idx][idx], im_only=False) + tmp_channel, tmp_aff, tmp_h = utils.load_volume( + list_channel_paths[channel_idx][idx], im_only=False + ) list_channels.append(tmp_channel) im = np.stack(list_channels, axis=-1) utils.save_volume(im, tmp_aff, tmp_h, path_result) -def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir=None, recompute=True): +def convert_images_in_dir_to_nifty( + image_dir, result_dir, aff=None, ref_aff_dir=None, recompute=True +): """Converts all images in image_dir to nifty format. 
:param image_dir: path of directory with images to convert
 :param result_dir: path of directory where converted images will be written
@@ -1464,14 +1766,19 @@ def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir=
         path_ref_images = [None] * len(path_images)
 
     # loop over images
-    loop_info = utils.LoopInfo(len(path_images), 10, 'converting', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "converting", True)
     for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)):
         loop_info.update(idx)
 
         # convert images to nifty format
-        path_result = os.path.join(result_dir, os.path.basename(utils.strip_extension(path_image))) + '.nii.gz'
+        path_result = (
+            os.path.join(
+                result_dir, os.path.basename(utils.strip_extension(path_image))
+            )
+            + ".nii.gz"
+        )
         if (not os.path.isfile(path_result)) | recompute:
-            if utils.get_image_extension(path_image) == 'nii.gz':
+            if utils.get_image_extension(path_image) == "nii.gz":
                 shutil.copy2(path_image, path_result)
             else:
                 im, tmp_aff, h = utils.load_volume(path_image, im_only=False)
@@ -1482,15 +1789,17 @@ def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir=
                 utils.save_volume(im, tmp_aff, h, path_result)
 
 
-def mri_convert_images_in_dir(image_dir,
-                              result_dir,
-                              interpolation=None,
-                              reference_dir=None,
-                              same_reference=False,
-                              voxsize=None,
-                              path_freesurfer='/usr/local/freesurfer',
-                              mri_convert_path='/usr/local/freesurfer/bin/mri_convert',
-                              recompute=True):
+def mri_convert_images_in_dir(
+    image_dir,
+    result_dir,
+    interpolation=None,
+    reference_dir=None,
+    same_reference=False,
+    voxsize=None,
+    path_freesurfer="/usr/local/freesurfer",
+    mri_convert_path="/usr/local/freesurfer/bin/mri_convert",
+    recompute=True,
+):
     """This function launches mri_convert on all images contained in image_dir, and writes the results in result_dir.
     The interpolation type can be specified (e.g. 'nearest'), as well as a folder containing references for resampling.
     reference_dir can be the path of a single *image* if same_reference=True.
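For reference, a minimal usage sketch of the two converters above (not part of the patch; the directory paths are placeholders, and the import assumes the vendored package resolves as nobrainer.ext.lab2im.edit_volumes):

    from nobrainer.ext.lab2im import edit_volumes

    # convert a folder of .mgz volumes to compressed NIfTI, keeping affines and headers
    edit_volumes.convert_images_in_dir_to_nifty("data/mgz", "data/nii")

    # resample the converted images onto a 1 mm isotropic grid with FreeSurfer's mri_convert
    edit_volumes.mri_convert_images_in_dir(
        "data/nii",
        "data/nii_1mm",
        interpolation="nearest",  # 'nearest' is the interpolation named in the docstring
        voxsize=[1.0, 1.0, 1.0],
    )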
@@ -1512,9 +1821,9 @@ def mri_convert_images_in_dir(image_dir, utils.mkdir(result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - mri_convert = mri_convert_path + ' ' + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = mri_convert_path + " " # list images path_images = utils.list_images_in_folder(image_dir) @@ -1523,36 +1832,42 @@ def mri_convert_images_in_dir(image_dir, path_references = [reference_dir] * len(path_images) else: path_references = utils.list_images_in_folder(reference_dir) - assert len(path_references) == len(path_images), 'different number of files in image_dir and reference_dir' + assert len(path_references) == len( + path_images + ), "different number of files in image_dir and reference_dir" else: path_references = [None] * len(path_images) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'converting', True) - for idx, (path_image, path_reference) in enumerate(zip(path_images, path_references)): + loop_info = utils.LoopInfo(len(path_images), 10, "converting", True) + for idx, (path_image, path_reference) in enumerate( + zip(path_images, path_references) + ): loop_info.update(idx) # convert image path_result = os.path.join(result_dir, os.path.basename(path_image)) if (not os.path.isfile(path_result)) | recompute: - cmd = mri_convert + path_image + ' ' + path_result + ' -odt float' + cmd = mri_convert + path_image + " " + path_result + " -odt float" if interpolation is not None: - cmd += ' -rt ' + interpolation + cmd += " -rt " + interpolation if reference_dir is not None: - cmd += ' -rl ' + path_reference + cmd += " -rl " + path_reference if voxsize is not None: - voxsize = utils.reformat_to_list(voxsize, dtype='float') - cmd += ' --voxsize ' + ' '.join([str(np.around(v, 3)) for v in voxsize]) + voxsize = utils.reformat_to_list(voxsize, dtype="float") + cmd += " --voxsize " + " ".join([str(np.around(v, 3)) for v in voxsize]) os.system(cmd) -def samseg_images_in_dir(image_dir, - result_dir, - atlas_dir=None, - threads=4, - path_freesurfer='/usr/local/freesurfer', - keep_segm_only=True, - recompute=True): +def samseg_images_in_dir( + image_dir, + result_dir, + atlas_dir=None, + threads=4, + path_freesurfer="/usr/local/freesurfer", + keep_segm_only=True, + recompute=True, +): """This function launches samseg for all images contained in image_dir and writes the results in result_dir. If keep_segm_only=True, the result segmentation is copied in result_dir and SAMSEG's intermediate result dir is deleted. 
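As an illustration of the SAMSEG wrapper just described, a hypothetical call (placeholder paths; requires a local FreeSurfer installation that provides run_samseg):

    from nobrainer.ext.lab2im import edit_volumes

    # segment every image in data/t1w, keeping only the final seg.mgz per subject
    edit_volumes.samseg_images_in_dir(
        "data/t1w",
        "data/samseg",
        threads=8,
        path_freesurfer="/usr/local/freesurfer",  # adjust to the local install
        keep_segm_only=True,
    )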
@@ -1570,29 +1885,42 @@ def samseg_images_in_dir(image_dir,
     utils.mkdir(result_dir)
 
     # set up FreeSurfer
-    os.environ['FREESURFER_HOME'] = path_freesurfer
-    os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
-    path_samseg = os.path.join(path_freesurfer, 'bin', 'run_samseg')
+    os.environ["FREESURFER_HOME"] = path_freesurfer
+    os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh"))
+    path_samseg = os.path.join(path_freesurfer, "bin", "run_samseg")
 
     # loop over images
     path_images = utils.list_images_in_folder(image_dir)
-    loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "processing", True)
     for idx, path_image in enumerate(path_images):
         loop_info.update(idx)
 
         # build path_result
-        path_im_result_dir = os.path.join(result_dir, utils.strip_extension(os.path.basename(path_image)))
-        path_samseg_result = os.path.join(path_im_result_dir, 'seg.mgz')
+        path_im_result_dir = os.path.join(
+            result_dir, utils.strip_extension(os.path.basename(path_image))
+        )
+        path_samseg_result = os.path.join(path_im_result_dir, "seg.mgz")
         if keep_segm_only:
-            path_result = os.path.join(result_dir, utils.strip_extension(os.path.basename(path_image)) + '_seg.mgz')
+            path_result = os.path.join(
+                result_dir,
+                utils.strip_extension(os.path.basename(path_image)) + "_seg.mgz",
+            )
         else:
             path_result = path_samseg_result
 
         # run samseg
         if (not os.path.isfile(path_result)) | recompute:
-            cmd = utils.mkcmd(path_samseg, '-i', path_image, '-o', path_im_result_dir, '--threads', threads)
+            cmd = utils.mkcmd(
+                path_samseg,
+                "-i",
+                path_image,
+                "-o",
+                path_im_result_dir,
+                "--threads",
+                threads,
+            )
             if atlas_dir is not None:
-                cmd = utils.mkcmd(cmd, '-a', atlas_dir)
+                cmd = utils.mkcmd(cmd, "-a", atlas_dir)
             os.system(cmd)
 
         # move segmentation to result_dir if necessary
@@ -1603,18 +1931,20 @@ def samseg_images_in_dir(image_dir,
                 shutil.rmtree(path_im_result_dir)
 
 
-def niftyreg_images_in_dir(image_dir,
-                           reference_dir,
-                           nifty_reg_function='reg_resample',
-                           input_transformation_dir=None,
-                           result_dir=None,
-                           result_transformation_dir=None,
-                           interpolation=None,
-                           same_floating=False,
-                           same_reference=False,
-                           same_transformation=False,
-                           path_nifty_reg='/home/benjamin/Softwares/niftyreg-gpu/build/reg-apps',
-                           recompute=True):
+def niftyreg_images_in_dir(
+    image_dir,
+    reference_dir,
+    nifty_reg_function="reg_resample",
+    input_transformation_dir=None,
+    result_dir=None,
+    result_transformation_dir=None,
+    interpolation=None,
+    same_floating=False,
+    same_reference=False,
+    same_transformation=False,
+    path_nifty_reg="/home/benjamin/Softwares/niftyreg-gpu/build/reg-apps",
+    recompute=True,
+):
     """This function launches one of the niftyreg functions (reg_aladin, reg_f3d, reg_resample) on all images
     contained in image_dir.
     :param image_dir: path of directory with images to register. 
Can also be a single image, in that case set @@ -1650,54 +1980,70 @@ def niftyreg_images_in_dir(image_dir, path_images = utils.list_images_in_folder(image_dir) path_references = utils.list_images_in_folder(reference_dir) if same_reference: - path_references = utils.reformat_to_list(path_references, length=len(path_images)) + path_references = utils.reformat_to_list( + path_references, length=len(path_images) + ) if same_floating: path_images = utils.reformat_to_list(path_images, length=len(path_references)) - assert len(path_references) == len(path_images), 'different number of files in image_dir and reference_dir' + assert len(path_references) == len( + path_images + ), "different number of files in image_dir and reference_dir" # list input transformations if input_transformation_dir is not None: if same_transformation: - path_input_transfs = utils.reformat_to_list(input_transformation_dir, length=len(path_images)) + path_input_transfs = utils.reformat_to_list( + input_transformation_dir, length=len(path_images) + ) else: path_input_transfs = utils.list_files(input_transformation_dir) - assert len(path_input_transfs) == len(path_images), 'different number of transformations and images' + assert len(path_input_transfs) == len( + path_images + ), "different number of transformations and images" else: path_input_transfs = [None] * len(path_images) # define flag input trans if input_transformation_dir is not None: - if nifty_reg_function == 'reg_aladin': - flag_input_trans = '-inaff' - elif nifty_reg_function == 'reg_f3d': - flag_input_trans = '-aff' - elif nifty_reg_function == 'reg_resample': - flag_input_trans = '-trans' + if nifty_reg_function == "reg_aladin": + flag_input_trans = "-inaff" + elif nifty_reg_function == "reg_f3d": + flag_input_trans = "-aff" + elif nifty_reg_function == "reg_resample": + flag_input_trans = "-trans" else: - raise Exception('nifty_reg_function can only be "reg_aladin", "reg_f3d", or "reg_resample"') + raise Exception( + 'nifty_reg_function can only be "reg_aladin", "reg_f3d", or "reg_resample"' + ) else: flag_input_trans = None # define flag result transformation if result_transformation_dir is not None: - if nifty_reg_function == 'reg_aladin': - flag_result_trans = '-aff' - elif nifty_reg_function == 'reg_f3d': - flag_result_trans = '-cpp' + if nifty_reg_function == "reg_aladin": + flag_result_trans = "-aff" + elif nifty_reg_function == "reg_f3d": + flag_result_trans = "-cpp" else: - raise Exception('result_transformation_dir can only be used with "reg_aladin" or "reg_f3d"') + raise Exception( + 'result_transformation_dir can only be used with "reg_aladin" or "reg_f3d"' + ) else: flag_result_trans = None # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) - for idx, (path_image, path_ref, path_input_trans) in enumerate(zip(path_images, - path_references, - path_input_transfs)): + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) + for idx, (path_image, path_ref, path_input_trans) in enumerate( + zip(path_images, path_references, path_input_transfs) + ): loop_info.update(idx) # define path registered image - name = os.path.basename(path_ref) if same_floating else os.path.basename(path_image) + name = ( + os.path.basename(path_ref) + if same_floating + else os.path.basename(path_image) + ) if result_dir is not None: path_result = os.path.join(result_dir, name) result_already_computed = os.path.isfile(path_result) @@ -1707,8 +2053,10 @@ def niftyreg_images_in_dir(image_dir, # define path resulting 
transformation
         if result_transformation_dir is not None:
-            if nifty_reg_function == 'reg_aladin':
-                path_result_trans = os.path.join(result_transformation_dir, utils.strip_extension(name) + '.txt')
+            if nifty_reg_function == "reg_aladin":
+                path_result_trans = os.path.join(
+                    result_transformation_dir, utils.strip_extension(name) + ".txt"
+                )
                 result_trans_already_computed = os.path.isfile(path_result_trans)
             else:
                 path_result_trans = os.path.join(result_transformation_dir, name)
@@ -1717,30 +2065,36 @@ def niftyreg_images_in_dir(image_dir,
             path_result_trans = None
             result_trans_already_computed = True
 
-        if (not result_already_computed) | (not result_trans_already_computed) | recompute:
+        if (
+            (not result_already_computed)
+            | (not result_trans_already_computed)
+            | recompute
+        ):
 
             # build main command
-            cmd = utils.mkcmd(nifty_reg, '-ref', path_ref, '-flo', path_image, '-pad 0')
+            cmd = utils.mkcmd(nifty_reg, "-ref", path_ref, "-flo", path_image, "-pad 0")
 
             # add options
             if path_result is not None:
-                cmd = utils.mkcmd(cmd, '-res', path_result)
+                cmd = utils.mkcmd(cmd, "-res", path_result)
             if flag_input_trans is not None:
                 cmd = utils.mkcmd(cmd, flag_input_trans, path_input_trans)
             if flag_result_trans is not None:
                 cmd = utils.mkcmd(cmd, flag_result_trans, path_result_trans)
             if interpolation is not None:
-                cmd = utils.mkcmd(cmd, '-inter', interpolation)
+                cmd = utils.mkcmd(cmd, "-inter", interpolation)
 
             # execute
             os.system(cmd)
 
 
-def upsample_anisotropic_images(image_dir,
-                                resample_image_result_dir,
-                                resample_like_dir,
-                                path_freesurfer='/usr/local/freesurfer/',
-                                recompute=True):
+def upsample_anisotropic_images(
+    image_dir,
+    resample_image_result_dir,
+    resample_like_dir,
+    path_freesurfer="/usr/local/freesurfer/",
+    recompute=True,
+):
     """This function takes as input a set of LR images and resamples them to HR with respect to reference images. 
:param image_dir: path of directory with input images (only uni-modal images supported)
 :param resample_image_result_dir: path of directory where resampled images will be written
@@ -1753,66 +2107,97 @@ def upsample_anisotropic_images(image_dir,
     utils.mkdir(resample_image_result_dir)
 
     # set up FreeSurfer
-    os.environ['FREESURFER_HOME'] = path_freesurfer
-    os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
-    mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert')
+    os.environ["FREESURFER_HOME"] = path_freesurfer
+    os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh"))
+    mri_convert = os.path.join(path_freesurfer, "bin/mri_convert")
 
     # list images and labels
     path_images = utils.list_images_in_folder(image_dir)
     path_ref_images = utils.list_images_in_folder(resample_like_dir)
-    assert len(path_images) == len(path_ref_images), \
-        'the folders containing the images and their references are not the same size'
+    assert len(path_images) == len(
+        path_ref_images
+    ), "the folders containing the images and their references are not the same size"
 
     # loop over images
-    loop_info = utils.LoopInfo(len(path_images), 10, 'upsampling', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "upsampling", True)
     for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)):
         loop_info.update(idx)
 
         # upsample image
-        _, _, n_dims, _, _, image_res = utils.get_volume_info(path_image, return_volume=False)
-        path_im_upsampled = os.path.join(resample_image_result_dir, os.path.basename(path_image))
+        _, _, n_dims, _, _, image_res = utils.get_volume_info(
+            path_image, return_volume=False
+        )
+        path_im_upsampled = os.path.join(
+            resample_image_result_dir, os.path.basename(path_image)
+        )
         if (not os.path.isfile(path_im_upsampled)) | recompute:
-            cmd = utils.mkcmd(mri_convert, path_image, path_im_upsampled, '-rl', path_ref, '-odt float')
+            cmd = utils.mkcmd(
+                mri_convert,
+                path_image,
+                path_im_upsampled,
+                "-rl",
+                path_ref,
+                "-odt float",
+            )
             os.system(cmd)
 
-        path_dist_map = os.path.join(resample_image_result_dir, 'dist_map_' + os.path.basename(path_image))
+        path_dist_map = os.path.join(
+            resample_image_result_dir, "dist_map_" + os.path.basename(path_image)
+        )
         if (not os.path.isfile(path_dist_map)) | recompute:
             im, aff, h = utils.load_volume(path_image, im_only=False)
-            dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing='ij')
-            tmp_dir = utils.strip_extension(path_im_upsampled) + '_meshes'
+            dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing="ij")
+            tmp_dir = utils.strip_extension(path_im_upsampled) + "_meshes"
             utils.mkdir(tmp_dir)
             path_meshes_up = list()
-            for (i, maps) in enumerate(dist_map):
-                path_mesh = os.path.join(tmp_dir, '%s_' % i + os.path.basename(path_image))
-                path_mesh_up = os.path.join(tmp_dir, 'up_%s_' % i + os.path.basename(path_image))
+            for i, maps in enumerate(dist_map):
+                path_mesh = os.path.join(
+                    tmp_dir, "%s_" % i + os.path.basename(path_image)
+                )
+                path_mesh_up = os.path.join(
+                    tmp_dir, "up_%s_" % i + os.path.basename(path_image)
+                )
                 utils.save_volume(maps, aff, h, path_mesh)
-                cmd = utils.mkcmd(mri_convert, path_mesh, path_mesh_up, '-rl', path_im_upsampled, '-odt float')
+                cmd = utils.mkcmd(
+                    mri_convert,
+                    path_mesh,
+                    path_mesh_up,
+                    "-rl",
+                    path_im_upsampled,
+                    "-odt float",
+                )
                 os.system(cmd)
                 path_meshes_up.append(path_mesh_up)
             mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False)
-            mesh_up = np.stack([mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1)
+            mesh_up = np.stack(
+                [mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1
+            )
             shutil.rmtree(tmp_dir)
             floor = np.floor(mesh_up)
             ceil = np.ceil(mesh_up)
             f_dist = mesh_up - floor
             c_dist = ceil - mesh_up
-            dist = np.minimum(f_dist, c_dist) * utils.add_axis(image_res, axis=[0] * n_dims)
-            dist = np.sqrt(np.sum(dist ** 2, axis=-1))
+            dist = np.minimum(f_dist, c_dist) * utils.add_axis(
+                image_res, axis=[0] * n_dims
+            )
+            dist = np.sqrt(np.sum(dist**2, axis=-1))
             utils.save_volume(dist, aff, h, path_dist_map)
 
 
-def simulate_upsampled_anisotropic_images(image_dir,
-                                          downsample_image_result_dir,
-                                          resample_image_result_dir,
-                                          data_res,
-                                          labels_dir=None,
-                                          downsample_labels_result_dir=None,
-                                          slice_thickness=None,
-                                          build_dist_map=False,
-                                          path_freesurfer='/usr/local/freesurfer/',
-                                          gpu=True,
-                                          recompute=True):
+def simulate_upsampled_anisotropic_images(
+    image_dir,
+    downsample_image_result_dir,
+    resample_image_result_dir,
+    data_res,
+    labels_dir=None,
+    downsample_labels_result_dir=None,
+    slice_thickness=None,
+    build_dist_map=False,
+    path_freesurfer="/usr/local/freesurfer/",
+    gpu=True,
+    recompute=True,
+):
     """This function takes as input a set of HR images and creates two datasets with them:
     1) a set of LR images obtained by downsampling the HR images with nearest neighbour interpolation,
     2) a set of HR images obtained by resampling the LR images to native HR with linear interpolation.
@@ -1835,39 +2220,58 @@ def simulate_upsampled_anisotropic_images(image_dir,
     utils.mkdir(resample_image_result_dir)
     utils.mkdir(downsample_image_result_dir)
     if labels_dir is not None:
-        assert downsample_labels_result_dir is not None, \
-            'downsample_labels_result_dir should not be None if labels_dir is specified'
+        assert (
+            downsample_labels_result_dir is not None
+        ), "downsample_labels_result_dir should not be None if labels_dir is specified"
         utils.mkdir(downsample_labels_result_dir)
 
     # set up FreeSurfer
-    os.environ['FREESURFER_HOME'] = path_freesurfer
-    os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh'))
-    mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert')
+    os.environ["FREESURFER_HOME"] = path_freesurfer
+    os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh"))
+    mri_convert = os.path.join(path_freesurfer, "bin/mri_convert")
 
     # list images and labels
     path_images = utils.list_images_in_folder(image_dir)
-    path_labels = [None] * len(path_images) if labels_dir is None else utils.list_images_in_folder(labels_dir)
+    path_labels = (
+        [None] * len(path_images)
+        if labels_dir is None
+        else utils.list_images_in_folder(labels_dir)
+    )
 
     # initialisation
-    _, _, n_dims, _, _, image_res = utils.get_volume_info(path_images[0], return_volume=False, aff_ref=np.eye(4))
-    data_res = np.squeeze(utils.reformat_to_n_channels_array(data_res, n_dims, n_channels=1))
+    _, _, n_dims, _, _, image_res = utils.get_volume_info(
+        path_images[0], return_volume=False, aff_ref=np.eye(4)
+    )
+    data_res = np.squeeze(
+        utils.reformat_to_n_channels_array(data_res, n_dims, n_channels=1)
+    )
     slice_thickness = utils.reformat_to_list(slice_thickness, length=n_dims)
 
     # loop over images
     previous_model_input_shape = None
     model = None
-    loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True)
+    loop_info = utils.LoopInfo(len(path_images), 10, "processing", True)
     for idx, (path_image, path_labels) in enumerate(zip(path_images, path_labels)):
         loop_info.update(idx)
 
         # downsample image
-        path_im_downsampled = os.path.join(downsample_image_result_dir, os.path.basename(path_image))
+        path_im_downsampled = 
os.path.join( + downsample_image_result_dir, os.path.basename(path_image) + ) if (not os.path.isfile(path_im_downsampled)) | recompute: - im, _, aff, n_dims, _, h, image_res = utils.get_volume_info(path_image, return_volume=True) - im, aff_aligned = align_volume_to_ref(im, aff, aff_ref=np.eye(4), return_aff=True, n_dims=n_dims) + im, _, aff, n_dims, _, h, image_res = utils.get_volume_info( + path_image, return_volume=True + ) + im, aff_aligned = align_volume_to_ref( + im, aff, aff_ref=np.eye(4), return_aff=True, n_dims=n_dims + ) im_shape = list(im.shape[:n_dims]) - sigma = blurring_sigma_for_downsampling(image_res, data_res, thickness=slice_thickness) - sigma = [0 if data_res[i] == image_res[i] else sigma[i] for i in range(n_dims)] + sigma = blurring_sigma_for_downsampling( + image_res, data_res, thickness=slice_thickness + ) + sigma = [ + 0 if data_res[i] == image_res[i] else sigma[i] for i in range(n_dims) + ] # blur image if gpu: @@ -1882,57 +2286,100 @@ def simulate_upsampled_anisotropic_images(image_dir, utils.save_volume(im, aff_aligned, h, path_im_downsampled) # downsample blurred image - voxsize = ' '.join([str(r) for r in data_res]) - cmd = utils.mkcmd(mri_convert, path_im_downsampled, path_im_downsampled, '--voxsize', voxsize, - '-odt float -rt nearest') + voxsize = " ".join([str(r) for r in data_res]) + cmd = utils.mkcmd( + mri_convert, + path_im_downsampled, + path_im_downsampled, + "--voxsize", + voxsize, + "-odt float -rt nearest", + ) os.system(cmd) # downsample labels if necessary if path_labels is not None: - path_lab_downsampled = os.path.join(downsample_labels_result_dir, os.path.basename(path_labels)) + path_lab_downsampled = os.path.join( + downsample_labels_result_dir, os.path.basename(path_labels) + ) if (not os.path.isfile(path_lab_downsampled)) | recompute: - cmd = utils.mkcmd(mri_convert, path_labels, path_lab_downsampled, '-rl', path_im_downsampled, - '-odt float -rt nearest') + cmd = utils.mkcmd( + mri_convert, + path_labels, + path_lab_downsampled, + "-rl", + path_im_downsampled, + "-odt float -rt nearest", + ) os.system(cmd) # upsample image - path_im_upsampled = os.path.join(resample_image_result_dir, os.path.basename(path_image)) + path_im_upsampled = os.path.join( + resample_image_result_dir, os.path.basename(path_image) + ) if (not os.path.isfile(path_im_upsampled)) | recompute: - cmd = utils.mkcmd(mri_convert, path_im_downsampled, path_im_upsampled, '-rl', path_image, '-odt float') + cmd = utils.mkcmd( + mri_convert, + path_im_downsampled, + path_im_upsampled, + "-rl", + path_image, + "-odt float", + ) os.system(cmd) if build_dist_map: - path_dist_map = os.path.join(resample_image_result_dir, 'dist_map_' + os.path.basename(path_image)) + path_dist_map = os.path.join( + resample_image_result_dir, "dist_map_" + os.path.basename(path_image) + ) if (not os.path.isfile(path_dist_map)) | recompute: im, aff, h = utils.load_volume(path_im_downsampled, im_only=False) - dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing='ij') - tmp_dir = utils.strip_extension(path_im_downsampled) + '_meshes' + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing="ij") + tmp_dir = utils.strip_extension(path_im_downsampled) + "_meshes" utils.mkdir(tmp_dir) path_meshes_up = list() - for (i, d_map) in enumerate(dist_map): - path_mesh = os.path.join(tmp_dir, '%s_' % i + os.path.basename(path_image)) - path_mesh_up = os.path.join(tmp_dir, 'up_%s_' % i + os.path.basename(path_image)) + for i, d_map in enumerate(dist_map): + path_mesh = os.path.join( + 
tmp_dir, "%s_" % i + os.path.basename(path_image) + ) + path_mesh_up = os.path.join( + tmp_dir, "up_%s_" % i + os.path.basename(path_image) + ) utils.save_volume(d_map, aff, h, path_mesh) - cmd = utils.mkcmd(mri_convert, path_mesh, path_mesh_up, '-rl', path_image, '-odt float') + cmd = utils.mkcmd( + mri_convert, + path_mesh, + path_mesh_up, + "-rl", + path_image, + "-odt float", + ) os.system(cmd) path_meshes_up.append(path_mesh_up) mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) - mesh_up = np.stack([mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1) + mesh_up = np.stack( + [mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1 + ) shutil.rmtree(tmp_dir) floor = np.floor(mesh_up) ceil = np.ceil(mesh_up) f_dist = mesh_up - floor c_dist = ceil - mesh_up - dist = np.minimum(f_dist, c_dist) * utils.add_axis(data_res, axis=[0] * n_dims) - dist = np.sqrt(np.sum(dist ** 2, axis=-1)) + dist = np.minimum(f_dist, c_dist) * utils.add_axis( + data_res, axis=[0] * n_dims + ) + dist = np.sqrt(np.sum(dist**2, axis=-1)) utils.save_volume(dist, aff, h, path_dist_map) -def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_channels=10, verbose=True): +def check_images_in_dir( + image_dir, check_values=False, keep_unique=True, max_channels=10, verbose=True +): """Check if all volumes within the same folder share the same characteristics: shape, affine matrix, resolution. Also have option to check if all volumes have the same intensity values (useful for label maps). - :return four lists, each containing the different values detected for a specific parameter among those to check.""" + :return four lists, each containing the different values detected for a specific parameter among those to check. + """ # define information to check list_shape = list() @@ -1946,13 +2393,17 @@ def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_cha # loop through files path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'checking', verbose) if verbose else None + loop_info = ( + utils.LoopInfo(len(path_images), 10, "checking", verbose) if verbose else None + ) for idx, path_image in enumerate(path_images): if loop_info is not None: loop_info.update(idx) # get info - im, shape, aff, n_dims, _, h, res = utils.get_volume_info(path_image, True, np.eye(4), max_channels) + im, shape, aff, n_dims, _, h, res = utils.get_volume_info( + path_image, True, np.eye(4), max_channels + ) axes = get_ras_axes(aff, n_dims=n_dims).tolist() aff[:, np.arange(n_dims)] = aff[:, axes] aff = (np.int32(np.round(np.array(aff[:3, :3]), 2) * 100) / 100).tolist() @@ -1977,8 +2428,17 @@ def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_cha # ----------------------------------------------- edit label maps in dir ----------------------------------------------- -def correct_labels_in_dir(labels_dir, results_dir, incorrect_labels, correct_labels=None, - use_nearest_label=False, remove_zero=False, smooth=False, recompute=True): + +def correct_labels_in_dir( + labels_dir, + results_dir, + incorrect_labels, + correct_labels=None, + use_nearest_label=False, + remove_zero=False, + smooth=False, + recompute=True, +): """This function corrects label values for all label maps in a folder with either - a list a given values, - or with the nearest label value. 
@@ -2002,19 +2462,33 @@ def correct_labels_in_dir(labels_dir, results_dir, incorrect_labels, correct_lab
 
     # prepare data files
     path_labels = utils.list_images_in_folder(labels_dir)
-    loop_info = utils.LoopInfo(len(path_labels), 10, 'correcting', True)
+    loop_info = utils.LoopInfo(len(path_labels), 10, "correcting", True)
     for idx, path_label in enumerate(path_labels):
         loop_info.update(idx)
 
         # correct labels
         path_result = os.path.join(results_dir, os.path.basename(path_label))
         if (not os.path.isfile(path_result)) | recompute:
-            im, aff, h = utils.load_volume(path_label, im_only=False, dtype='int32')
-            im = correct_label_map(im, incorrect_labels, correct_labels, use_nearest_label, remove_zero, smooth)
+            im, aff, h = utils.load_volume(path_label, im_only=False, dtype="int32")
+            im = correct_label_map(
+                im,
+                incorrect_labels,
+                correct_labels,
+                use_nearest_label,
+                remove_zero,
+                smooth,
+            )
             utils.save_volume(im, aff, h, path_result)
 
 
-def mask_labels_in_dir(labels_dir, result_dir, values_to_keep, masking_value=0, mask_result_dir=None, recompute=True):
+def mask_labels_in_dir(
+    labels_dir,
+    result_dir,
+    values_to_keep,
+    masking_value=0,
+    mask_result_dir=None,
+    recompute=True,
+):
     """This function masks all label maps in a folder by keeping a set of given label values.
     :param labels_dir: path of directory with input label maps
     :param result_dir: path of directory where masked label maps will be written
@@ -2034,30 +2508,42 @@ def mask_labels_in_dir(labels_dir, result_dir, values_to_keep, masking_value=0,
 
     # loop over labels
     path_labels = utils.list_images_in_folder(labels_dir)
-    loop_info = utils.LoopInfo(len(path_labels), 10, 'masking', True)
+    loop_info = utils.LoopInfo(len(path_labels), 10, "masking", True)
     for idx, path_label in enumerate(path_labels):
         loop_info.update(idx)
 
         # mask labels
         path_result = os.path.join(result_dir, os.path.basename(path_label))
         if mask_result_dir is not None:
-            path_result_mask = os.path.join(mask_result_dir, os.path.basename(path_label))
+            path_result_mask = os.path.join(
+                mask_result_dir, os.path.basename(path_label)
+            )
         else:
-            path_result_mask = ''
-        if (not os.path.isfile(path_result)) | \
-                (mask_result_dir is not None) & (not os.path.isfile(path_result_mask)) | \
-                recompute:
+            path_result_mask = ""
+        if (
+            (not os.path.isfile(path_result))
+            | (mask_result_dir is not None) & (not os.path.isfile(path_result_mask))
+            | recompute
+        ):
             lab, aff, h = utils.load_volume(path_label, im_only=False)
             if mask_result_dir is not None:
-                labels, mask = mask_label_map(lab, values_to_keep, masking_value, return_mask=True)
-                path_result_mask = os.path.join(mask_result_dir, os.path.basename(path_label))
+                labels, mask = mask_label_map(
+                    lab, values_to_keep, masking_value, return_mask=True
+                )
+                path_result_mask = os.path.join(
+                    mask_result_dir, os.path.basename(path_label)
+                )
                 utils.save_volume(mask, aff, h, path_result_mask)
             else:
-                labels = mask_label_map(lab, values_to_keep, masking_value, return_mask=False)
+                labels = mask_label_map(
+                    lab, values_to_keep, masking_value, return_mask=False
+                )
             utils.save_volume(labels, aff, h, path_result)
 
 
-def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, connectivity=1, recompute=True):
+def smooth_labels_in_dir(
+    labels_dir, result_dir, gpu=False, labels_list=None, connectivity=1, recompute=True
+):
     """Smooth all label maps in a folder by replacing each voxel by the value of its most numerous neighbours. 
:param labels_dir: path of directory with input label maps
 :param result_dir: path of directory where smoothed label maps will be written
@@ -2083,28 +2569,40 @@ def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, co
         smoothing_model = None
 
         # loop over label maps
-        loop_info = utils.LoopInfo(len(path_labels), 10, 'smoothing', True)
+        loop_info = utils.LoopInfo(len(path_labels), 10, "smoothing", True)
         for idx, path_label in enumerate(path_labels):
             loop_info.update(idx)
 
             # smooth label map
             path_result = os.path.join(result_dir, os.path.basename(path_label))
             if (not os.path.isfile(path_result)) | recompute:
-                labels, label_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_label, return_volume=True)
+                labels, label_shape, aff, n_dims, _, h, _ = utils.get_volume_info(
+                    path_label, return_volume=True
+                )
                 if label_shape != previous_model_input_shape:
                     previous_model_input_shape = label_shape
-                    smoothing_model = smoothing_gpu_model(label_shape, labels_list, connectivity)
-                unique_labels = np.unique(labels).astype('int32')
+                    smoothing_model = smoothing_gpu_model(
+                        label_shape, labels_list, connectivity
+                    )
+                unique_labels = np.unique(labels).astype("int32")
                 if labels_list is None:
                     smoothed_labels = smoothing_model.predict(utils.add_axis(labels))
                 else:
-                    labels_to_keep = [lab for lab in unique_labels if lab not in labels_list]
-                    new_labels, mask_new_labels = mask_label_map(labels, labels_to_keep, return_mask=True)
-                    smoothed_labels = np.squeeze(smoothing_model.predict(utils.add_axis(labels)))
-                    smoothed_labels = np.where(mask_new_labels, new_labels, smoothed_labels)
+                    labels_to_keep = [
+                        lab for lab in unique_labels if lab not in labels_list
+                    ]
+                    new_labels, mask_new_labels = mask_label_map(
+                        labels, labels_to_keep, return_mask=True
+                    )
+                    smoothed_labels = np.squeeze(
+                        smoothing_model.predict(utils.add_axis(labels))
+                    )
+                    smoothed_labels = np.where(
+                        mask_new_labels, new_labels, smoothed_labels
+                    )
                 mask_new_zeros = (labels > 0) & (smoothed_labels == 0)
                 smoothed_labels[mask_new_zeros] = labels[mask_new_zeros]
-                utils.save_volume(smoothed_labels, aff, h, path_result, dtype='int32')
+                utils.save_volume(smoothed_labels, aff, h, path_result, dtype="int32")
 
     else:
         # build kernel
@@ -2112,7 +2610,7 @@ def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, co
         kernel = utils.build_binary_structure(connectivity, n_dims, shape=n_dims)
 
         # loop over label maps
-        loop_info = utils.LoopInfo(len(path_labels), 10, 'smoothing', True)
+        loop_info = utils.LoopInfo(len(path_labels), 10, "smoothing", True)
         for idx, path in enumerate(path_labels):
             loop_info.update(idx)
 
@@ -2121,7 +2619,7 @@ def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, co
             if (not os.path.isfile(path_result)) | recompute:
                 volume, aff, h = utils.load_volume(path, im_only=False)
                 new_volume = smooth_label_map(volume, kernel, labels_list)
-                utils.save_volume(new_volume, aff, h, path_result, dtype='int32')
+                utils.save_volume(new_volume, aff, h, path_result, dtype="int32")
 
 
 def smoothing_gpu_model(label_shape, label_list, connectivity=1):
@@ -2135,18 +2633,26 @@ def smoothing_gpu_model(label_shape, label_list, connectivity=1):
 
     # convert labels so values are in [0, ..., N-1] and use one hot encoding
     n_labels = label_list.shape[0]
-    labels_in = KL.Input(shape=label_shape, name='lab_input', dtype='int32')
+    labels_in = KL.Input(shape=label_shape, name="lab_input", dtype="int32")
     labels = ConvertLabels(label_list)(labels_in)
-    labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels)
+    labels = KL.Lambda(
+        lambda x: tf.one_hot(tf.cast(x, dtype="int32"), depth=n_labels, axis=-1)
+    )(labels)
 
     # count neighbouring voxels
     n_dims, _ = utils.get_dims(label_shape)
-    k = utils.add_axis(utils.build_binary_structure(connectivity, n_dims, shape=n_dims), axis=[-1, -1])
-    kernel = KL.Lambda(lambda x: tf.convert_to_tensor(k, dtype='float32'))([])
+    k = utils.add_axis(
+        utils.build_binary_structure(connectivity, n_dims, shape=n_dims), axis=[-1, -1]
+    )
+    kernel = KL.Lambda(lambda x: tf.convert_to_tensor(k, dtype="float32"))([])
     split = KL.Lambda(lambda x: tf.split(x, [1] * n_labels, axis=-1))(labels)
-    labels = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding='SAME'))([split[0], kernel])
+    labels = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding="SAME"))(
+        [split[0], kernel]
+    )
     for i in range(1, n_labels):
-        tmp = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding='SAME'))([split[i], kernel])
+        tmp = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding="SAME"))(
+            [split[i], kernel]
+        )
         labels = KL.Lambda(lambda x: tf.concat([x[0], x[1]], -1))([labels, tmp])
 
     # take the argmax and convert labels to original values
@@ -2155,7 +2661,14 @@ def smoothing_gpu_model(label_shape, label_list, connectivity=1):
     return Model(inputs=labels_in, outputs=labels)
 
 
-def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors=1., gpu=False, recompute=True):
+def erode_labels_in_dir(
+    labels_dir,
+    result_dir,
+    labels_to_erode,
+    erosion_factors=1.0,
+    gpu=False,
+    recompute=True,
+):
     """Erode a given set of label values for all label maps in a folder.
     :param labels_dir: path of directory with input label maps
     :param result_dir: path of directory where eroded label maps will be written
@@ -2173,7 +2686,7 @@ def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors
 
     # loop over label maps
     model = None
     path_labels = utils.list_images_in_folder(labels_dir)
-    loop_info = utils.LoopInfo(len(path_labels), 5, 'eroding', True)
+    loop_info = utils.LoopInfo(len(path_labels), 5, "eroding", True)
     for idx, path_label in enumerate(path_labels):
         loop_info.update(idx)
 
@@ -2181,16 +2694,20 @@ def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors
         labels, aff, h = utils.load_volume(path_label, im_only=False)
         path_result = os.path.join(result_dir, os.path.basename(path_label))
         if (not os.path.isfile(path_result)) | recompute:
-            labels, model = erode_label_map(labels, labels_to_erode, erosion_factors, gpu, model, return_model=True)
+            labels, model = erode_label_map(
+                labels, labels_to_erode, erosion_factors, gpu, model, return_model=True
+            )
             utils.save_volume(labels, aff, h, path_result)
 
 
-def upsample_labels_in_dir(labels_dir,
-                           target_res,
-                           result_dir,
-                           path_label_list=None,
-                           path_freesurfer='/usr/local/freesurfer/',
-                           recompute=True):
+def upsample_labels_in_dir(
+    labels_dir,
+    target_res,
+    result_dir,
+    path_label_list=None,
+    path_freesurfer="/usr/local/freesurfer/",
+    recompute=True,
+):
     """This function upsamples all label maps within a folder. Importantly, each label map is converted into probability
     maps for all label values, and all these maps are upsampled separately. The upsampled label maps are recovered by
     taking the argmax of the label values probability maps. 
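A hypothetical call illustrating the one-hot/argmax strategy described above (placeholder paths; FreeSurfer must be available, since each per-label probability map is resampled with mri_convert):

    from nobrainer.ext.lab2im import edit_volumes

    # upsample label maps to 0.5 mm isotropic: one soft map per label value is
    # upsampled separately, and the result is the voxel-wise argmax across maps
    edit_volumes.upsample_labels_in_dir(
        "data/labels_1mm",
        target_res=0.5,
        result_dir="data/labels_05mm",
    )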
@@ -2207,25 +2724,29 @@ def upsample_labels_in_dir(labels_dir, utils.mkdir(result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") # list label maps path_labels = utils.list_images_in_folder(labels_dir) - labels_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_labels[0], max_channels=3) + labels_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_labels[0], max_channels=3 + ) # build command target_res = utils.reformat_to_list(target_res, length=n_dims) - post_cmd = '-voxsize ' + ' '.join([str(r) for r in target_res]) + ' -odt float' + post_cmd = "-voxsize " + " ".join([str(r) for r in target_res]) + " -odt float" # load label list and corresponding LUT to make sure that labels go from 0 to N-1 - label_list, _ = utils.get_list_labels(path_label_list, labels_dir=path_labels, FS_sort=False) - new_label_list = np.arange(len(label_list), dtype='int32') + label_list, _ = utils.get_list_labels( + path_label_list, labels_dir=path_labels, FS_sort=False + ) + new_label_list = np.arange(len(label_list), dtype="int32") lut = utils.get_mapping_lut(label_list) # loop over label maps - loop_info = utils.LoopInfo(len(path_labels), 5, 'upsampling', True) + loop_info = utils.LoopInfo(len(path_labels), 5, "upsampling", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) path_result = os.path.join(result_dir, os.path.basename(path_label)) @@ -2233,44 +2754,56 @@ def upsample_labels_in_dir(labels_dir, # load volume labels, aff, h = utils.load_volume(path_label, im_only=False) - labels = lut[labels.astype('int')] + labels = lut[labels.astype("int")] # create individual folders for label map basefilename = utils.strip_extension(os.path.basename(path_label)) indiv_label_dir = os.path.join(result_dir, basefilename) - upsample_indiv_label_dir = os.path.join(result_dir, basefilename + '_upsampled') + upsample_indiv_label_dir = os.path.join( + result_dir, basefilename + "_upsampled" + ) utils.mkdir(indiv_label_dir) utils.mkdir(upsample_indiv_label_dir) # loop over label values for label in new_label_list: - path_mask = os.path.join(indiv_label_dir, str(label) + '.nii.gz') - path_mask_upsampled = os.path.join(upsample_indiv_label_dir, str(label) + '.nii.gz') + path_mask = os.path.join(indiv_label_dir, str(label) + ".nii.gz") + path_mask_upsampled = os.path.join( + upsample_indiv_label_dir, str(label) + ".nii.gz" + ) if not os.path.isfile(path_mask): mask = (labels == label) * 1.0 utils.save_volume(mask, aff, h, path_mask) if not os.path.isfile(path_mask_upsampled): - cmd = utils.mkcmd(mri_convert, path_mask, path_mask_upsampled, post_cmd) + cmd = utils.mkcmd( + mri_convert, path_mask, path_mask_upsampled, post_cmd + ) os.system(cmd) # compute argmax of upsampled probability maps (upload them one at a time) - probmax, aff, h = utils.load_volume(os.path.join(upsample_indiv_label_dir, '0.nii.gz'), im_only=False) - labels = np.zeros(probmax.shape, dtype='int') + probmax, aff, h = utils.load_volume( + os.path.join(upsample_indiv_label_dir, "0.nii.gz"), im_only=False + ) + labels = np.zeros(probmax.shape, dtype="int") for label in new_label_list: - prob = utils.load_volume(os.path.join(upsample_indiv_label_dir, str(label) + '.nii.gz')) + prob = 
utils.load_volume( + os.path.join(upsample_indiv_label_dir, str(label) + ".nii.gz") + ) idx = prob > probmax labels[idx] = label probmax[idx] = prob[idx] - utils.save_volume(label_list[labels], aff, h, path_result, dtype='int32') - - -def compute_hard_volumes_in_dir(labels_dir, - voxel_volume=None, - path_label_list=None, - skip_background=True, - path_numpy_result=None, - path_csv_result=None, - FS_sort=False): + utils.save_volume(label_list[labels], aff, h, path_result, dtype="int32") + + +def compute_hard_volumes_in_dir( + labels_dir, + voxel_volume=None, + path_label_list=None, + skip_background=True, + path_numpy_result=None, + path_csv_result=None, + FS_sort=False, +): """Compute hard volumes of structures for all label maps in a folder. :param labels_dir: path of directory with input label maps :param voxel_volume: (optional) volume of the voxels. If None, it will be directly inferred from the file header. @@ -2299,10 +2832,10 @@ def compute_hard_volumes_in_dir(labels_dir, # create csv volume file if necessary if path_csv_result is not None: if skip_background: - cvs_header = [['subject'] + [str(lab) for lab in label_list[1:]]] + cvs_header = [["subject"] + [str(lab) for lab in label_list[1:]]] else: - cvs_header = [['subject'] + [str(lab) for lab in label_list]] - with open(path_csv_result, 'w') as csvFile: + cvs_header = [["subject"] + [str(lab) for lab in label_list]] + with open(path_csv_result, "w") as csvFile: writer = csv.writer(csvFile) writer.writerows(cvs_header) csvFile.close() @@ -2313,22 +2846,28 @@ def compute_hard_volumes_in_dir(labels_dir, volumes = np.zeros((label_list.shape[0] - 1, len(path_labels))) else: volumes = np.zeros((label_list.shape[0], len(path_labels))) - loop_info = utils.LoopInfo(len(path_labels), 10, 'processing', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "processing", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # load segmentation, and compute unique labels - labels, _, _, _, _, _, subject_res = utils.get_volume_info(path_label, return_volume=True) + labels, _, _, _, _, _, subject_res = utils.get_volume_info( + path_label, return_volume=True + ) if voxel_volume is None: voxel_volume = float(np.prod(subject_res)) - subject_volumes = compute_hard_volumes(labels, voxel_volume, label_list, skip_background) + subject_volumes = compute_hard_volumes( + labels, voxel_volume, label_list, skip_background + ) volumes[:, idx] = subject_volumes # write volumes if path_csv_result is not None: subject_volumes = np.around(volumes[:, idx], 3) - row = [utils.strip_suffix(os.path.basename(path_label))] + [str(vol) for vol in subject_volumes] - with open(path_csv_result, 'a') as csvFile: + row = [utils.strip_suffix(os.path.basename(path_label))] + [ + str(vol) for vol in subject_volumes + ] + with open(path_csv_result, "a") as csvFile: writer = csv.writer(csvFile) writer.writerow(row) csvFile.close() @@ -2340,12 +2879,14 @@ def compute_hard_volumes_in_dir(labels_dir, return volumes -def build_atlas(labels_dir, - label_list, - align_centre_of_mass=False, - margin=15, - shape=None, - path_atlas=None): +def build_atlas( + labels_dir, + label_list, + align_centre_of_mass=False, + margin=15, + shape=None, + path_atlas=None, +): """This function builds a binary atlas (defined by label values > 0) from several label maps. :param labels_dir: path of directory with input label maps :param label_list: list of all labels in the label maps. 
If there is more than 1 value here, the different channels @@ -2364,28 +2905,36 @@ def build_atlas(labels_dir, utils.mkdir(os.path.dirname(path_atlas)) # read list labels and create lut - label_list = np.array(utils.reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + label_list = np.array( + utils.reformat_to_list(label_list, load_as_numpy=True, dtype="int") + ) lut = utils.get_mapping_lut(label_list) n_labels = len(label_list) # create empty atlas - im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_labels[0], aff_ref=np.eye(4)) + im_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_labels[0], aff_ref=np.eye(4) + ) if align_centre_of_mass: shape = [margin * 2] * n_dims else: - shape = utils.reformat_to_list(shape, length=n_dims) if shape is not None else im_shape + shape = ( + utils.reformat_to_list(shape, length=n_dims) + if shape is not None + else im_shape + ) shape = shape + [n_labels] if n_labels > 1 else shape atlas = np.zeros(shape) # loop over label maps - loop_info = utils.LoopInfo(n_label_maps, 10, 'processing', True) + loop_info = utils.LoopInfo(n_label_maps, 10, "processing", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # load label map and build mask - lab = utils.load_volume(path_label, dtype='int32', aff_ref=np.eye(4)) + lab = utils.load_volume(path_label, dtype="int32", aff_ref=np.eye(4)) lab = correct_label_map(lab, [31, 63, 72], [4, 43, 0]) - lab = lut[lab.astype('int')] + lab = lut[lab.astype("int")] lab = pad_volume(lab, shape[:n_dims]) lab = crop_volume(lab, cropping_shape=shape[:n_dims]) indices = np.where(lab > 0) @@ -2395,10 +2944,18 @@ def build_atlas(labels_dir, # crop label map around centre of mass if align_centre_of_mass: - centre_of_mass = np.array([np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])], dtype='int32') + centre_of_mass = np.array( + [np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])], + dtype="int32", + ) min_crop = centre_of_mass - margin max_crop = centre_of_mass + margin - atlas += lab[min_crop[0]:max_crop[0], min_crop[1]:max_crop[1], min_crop[2]:max_crop[2], ...] + atlas += lab[ + min_crop[0] : max_crop[0], + min_crop[1] : max_crop[1], + min_crop[2] : max_crop[2], + ..., + ] # otherwise just add the one-hot labels else: atlas += lab @@ -2414,6 +2971,7 @@ def build_atlas(labels_dir, # ---------------------------------------------------- edit dataset ---------------------------------------------------- + def check_images_and_labels(image_dir, labels_dir, verbose=True): """Check if corresponding images and labels have the same affine matrices and shapes. Labels are matched to images by sorting order. 
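A quick sanity-check sketch (placeholder paths): since pairs are matched by sorted filename, the two folders must list corresponding files in the same order.

    from nobrainer.ext.lab2im import edit_volumes

    # print every affine or shape mismatch between corresponding images and label maps
    edit_volumes.check_images_and_labels("data/images", "data/labels", verbose=True)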
@@ -2425,10 +2983,14 @@ def check_images_and_labels(image_dir, labels_dir, verbose=True): # list images and labels path_images = utils.list_images_in_folder(image_dir) path_labels = utils.list_images_in_folder(labels_dir) - assert len(path_images) == len(path_labels), 'different number of files in image_dir and labels_dir' + assert len(path_images) == len( + path_labels + ), "different number of files in image_dir and labels_dir" # loop over images and labels - loop_info = utils.LoopInfo(len(path_images), 10, 'checking', verbose) if verbose else None + loop_info = ( + utils.LoopInfo(len(path_images), 10, "checking", verbose) if verbose else None + ) for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): if loop_info is not None: loop_info.update(idx) @@ -2441,20 +3003,22 @@ def check_images_and_labels(image_dir, labels_dir, verbose=True): # check matching affine and shape if aff_lab_list != aff_im_list: - print('aff mismatch :\n' + path_image) + print("aff mismatch :\n" + path_image) print(aff_im_list) print(path_label) print(aff_lab_list) - print('') + print("") if lab.shape != im.shape: - print('shape mismatch :\n' + path_image) + print("shape mismatch :\n" + path_image) print(im.shape) - print('\n' + path_label) + print("\n" + path_label) print(lab.shape) - print('') + print("") -def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_result_dir=None, margin=5): +def crop_dataset_to_minimum_size( + labels_dir, result_dir, image_dir=None, image_result_dir=None, margin=5 +): """Crop all label maps in a directory to the minimum possible common size, with a margin. This is achieved by cropping each label map individually to the minimum size, and by padding all the cropped maps to the same size (taken to be the maximum size of the cropped maps). 
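An illustrative call with placeholder paths; passing image_dir/image_result_dir crops each image with the same indices as its label map so the pairs stay aligned:

    from nobrainer.ext.lab2im import edit_volumes

    # crop every label map to its foreground bounding box, then pad all results to
    # the largest cropped shape plus a 5-voxel margin on each side
    edit_volumes.crop_dataset_to_minimum_size(
        "data/labels",
        "data/labels_cropped",
        image_dir="data/images",
        image_result_dir="data/images_cropped",
        margin=5,
    )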
@@ -2469,7 +3033,9 @@ def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_r # create result dir utils.mkdir(result_dir) if image_dir is not None: - assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" utils.mkdir(image_result_dir) # list labels and images @@ -2481,27 +3047,36 @@ def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_r _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) # loop over label maps for cropping - print('\ncropping labels to individual minimum size') + print("\ncropping labels to individual minimum size") maximum_size = np.zeros(n_dims) - loop_info = utils.LoopInfo(len(path_labels), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "cropping", True) for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): loop_info.update(idx) # crop label maps and update maximum size of cropped map label, aff, h = utils.load_volume(path_label, im_only=False) label, cropping, aff = crop_volume_around_region(label, aff=aff) - utils.save_volume(label, aff, h, os.path.join(result_dir, os.path.basename(path_label))) - maximum_size = np.maximum(maximum_size, np.array(label.shape) + margin * 2) # *2 to add margin on each side + utils.save_volume( + label, aff, h, os.path.join(result_dir, os.path.basename(path_label)) + ) + maximum_size = np.maximum( + maximum_size, np.array(label.shape) + margin * 2 + ) # *2 to add margin on each side # crop images if required if path_image is not None: image, aff_im, h_im = utils.load_volume(path_image, im_only=False) image, aff_im = crop_volume_with_idx(image, cropping, aff=aff_im) - utils.save_volume(image, aff_im, h_im, os.path.join(image_result_dir, os.path.basename(path_image))) + utils.save_volume( + image, + aff_im, + h_im, + os.path.join(image_result_dir, os.path.basename(path_image)), + ) # loop over label maps for padding - print('\npadding labels to same size') - loop_info = utils.LoopInfo(len(path_labels), 10, 'padding', True) + print("\npadding labels to same size") + loop_info = utils.LoopInfo(len(path_labels), 10, "padding", True) for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): loop_info.update(idx) @@ -2519,29 +3094,47 @@ def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_r utils.save_volume(image, aff, h, path_result) -def crop_dataset_around_region_of_same_size(labels_dir, - result_dir, - image_dir=None, - image_result_dir=None, - margin=0, - recompute=True): +def crop_dataset_around_region_of_same_size( + labels_dir, + result_dir, + image_dir=None, + image_result_dir=None, + margin=0, + recompute=True, +): # create result dir utils.mkdir(result_dir) if image_dir is not None: - assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" utils.mkdir(image_result_dir) # list labels and images path_labels = utils.list_images_in_folder(labels_dir) - path_images = utils.list_images_in_folder(image_dir) if image_dir is not None else [None] * len(path_labels) + path_images = ( + utils.list_images_in_folder(image_dir) + if image_dir is not None + else [None] * len(path_labels) + ) _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) - recompute_labels = any([not 
os.path.isfile(os.path.join(result_dir, os.path.basename(path))) - for path in path_labels]) + recompute_labels = any( + [ + not os.path.isfile(os.path.join(result_dir, os.path.basename(path))) + for path in path_labels + ] + ) if (image_dir is not None) & (not recompute_labels): - recompute_labels = any([not os.path.isfile(os.path.join(image_result_dir, os.path.basename(path))) - for path in path_images]) + recompute_labels = any( + [ + not os.path.isfile( + os.path.join(image_result_dir, os.path.basename(path)) + ) + for path in path_images + ] + ) # get minimum patch shape so that no labels are left out when doing the cropping later on max_crop_shape = np.zeros(n_dims) @@ -2549,11 +3142,17 @@ def crop_dataset_around_region_of_same_size(labels_dir, for path_label in path_labels: label, aff, _ = utils.load_volume(path_label, im_only=False) label = align_volume_to_ref(label, aff, aff_ref=np.eye(4)) - label = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + label = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) _, cropping = crop_volume_around_region(label) - max_crop_shape = np.maximum(cropping[n_dims:] - cropping[:n_dims], max_crop_shape) - max_crop_shape += np.array(utils.reformat_to_list(margin, length=n_dims, dtype='int')) - print('max_crop_shape: ', max_crop_shape) + max_crop_shape = np.maximum( + cropping[n_dims:] - cropping[:n_dims], max_crop_shape + ) + max_crop_shape += np.array( + utils.reformat_to_list(margin, length=n_dims, dtype="int") + ) + print("max_crop_shape: ", max_crop_shape) # crop shapes (possibly with padding if images are smaller than crop shape) for path_label, path_image in zip(path_labels, path_images): @@ -2561,10 +3160,18 @@ def crop_dataset_around_region_of_same_size(labels_dir, path_label_result = os.path.join(result_dir, os.path.basename(path_label)) path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) - if (not os.path.isfile(path_image_result)) | (not os.path.isfile(path_label_result)) | recompute: + if ( + (not os.path.isfile(path_image_result)) + | (not os.path.isfile(path_label_result)) + | recompute + ): # load labels - label, aff, h_la = utils.load_volume(path_label, im_only=False, dtype='int32') - label, aff_new = align_volume_to_ref(label, aff, aff_ref=np.eye(4), return_aff=True) + label, aff, h_la = utils.load_volume( + path_label, im_only=False, dtype="int32" + ) + label, aff_new = align_volume_to_ref( + label, aff, aff_ref=np.eye(4), return_aff=True + ) vol_shape = np.array(label.shape[:n_dims]) if path_image is not None: image, _, h_im = utils.load_volume(path_image, im_only=False) @@ -2573,27 +3180,39 @@ def crop_dataset_around_region_of_same_size(labels_dir, image = h_im = None # mask labels - mask = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + mask = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) label[np.logical_not(mask)] = 0 # find cropping indices indices = np.nonzero(mask) min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) - max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) + max_idx = np.minimum( + np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape + ) # expand/retract (depending on the desired shape) the cropping region around the centre intermediate_vol_shape = max_idx - min_idx - min_idx = min_idx - np.int32(np.ceil((max_crop_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + 
np.int32(np.floor((max_crop_shape - intermediate_vol_shape) / 2)) + min_idx = min_idx - np.int32( + np.ceil((max_crop_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((max_crop_shape - intermediate_vol_shape) / 2) + ) # check if we need to pad the output to the desired shape min_padding = np.abs(np.minimum(min_idx, 0)) max_padding = np.maximum(max_idx - vol_shape, 0) if np.any(min_padding > 0) | np.any(max_padding > 0): - pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) else: pad_margins = None - cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) # crop volume label = crop_volume_with_idx(label, cropping, n_dims=n_dims) @@ -2602,11 +3221,17 @@ def crop_dataset_around_region_of_same_size(labels_dir, # pad volume if necessary if pad_margins is not None: - label = np.pad(label, pad_margins, mode='constant', constant_values=0) + label = np.pad(label, pad_margins, mode="constant", constant_values=0) if path_image is not None: _, n_channels = utils.get_dims(image.shape) - pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins - image = np.pad(image, pad_margins, mode='constant', constant_values=0) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) + if n_channels > 1 + else pad_margins + ) + image = np.pad( + image, pad_margins, mode="constant", constant_values=0 + ) # update aff if n_dims == 2: @@ -2614,15 +3239,24 @@ def crop_dataset_around_region_of_same_size(labels_dir, aff_new[0:3, -1] = aff_new[0:3, -1] + aff_new[:3, :3] @ min_idx # write labels - label, aff_final = align_volume_to_ref(label, aff_new, aff_ref=aff, return_aff=True) - utils.save_volume(label, aff_final, h_la, path_label_result, dtype='int32') + label, aff_final = align_volume_to_ref( + label, aff_new, aff_ref=aff, return_aff=True + ) + utils.save_volume(label, aff_final, h_la, path_label_result, dtype="int32") if path_image is not None: image = align_volume_to_ref(image, aff_new, aff_ref=aff) utils.save_volume(image, aff_final, h_im, path_image_result) -def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_result_dir, margin=0, - cropping_shape_div_by=None, recompute=True): +def crop_dataset_around_region( + image_dir, + labels_dir, + image_result_dir, + labels_result_dir, + margin=0, + cropping_shape_div_by=None, + recompute=True, +): # create result dir utils.mkdir(image_result_dir) @@ -2634,42 +3268,65 @@ def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_r _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_labels[0]) # loop over images and labels - loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): loop_info.update(idx) - path_label_result = os.path.join(labels_result_dir, os.path.basename(path_label)) + path_label_result = os.path.join( + labels_result_dir, os.path.basename(path_label) + ) path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) - if (not os.path.isfile(path_label_result)) | (not os.path.isfile(path_image_result)) | recompute: + if ( + (not os.path.isfile(path_label_result)) + | (not os.path.isfile(path_image_result)) + | recompute + ): image, aff, 
h_im = utils.load_volume(path_image, im_only=False) label, _, h_lab = utils.load_volume(path_label, im_only=False) - mask = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + mask = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) label[np.logical_not(mask)] = 0 vol_shape = np.array(label.shape[:n_dims]) # find cropping indices indices = np.nonzero(mask) min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) - max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) + max_idx = np.minimum( + np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape + ) # expand/retract (depending on the desired shape) the cropping region around the centre intermediate_vol_shape = max_idx - min_idx - cropping_shape = np.array([utils.find_closest_number_divisible_by_m(s, cropping_shape_div_by, - answer_type='higher') - for s in intermediate_vol_shape]) - min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + cropping_shape = np.array( + [ + utils.find_closest_number_divisible_by_m( + s, cropping_shape_div_by, answer_type="higher" + ) + for s in intermediate_vol_shape + ] + ) + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) # check if we need to pad the output to the desired shape min_padding = np.abs(np.minimum(min_idx, 0)) max_padding = np.maximum(max_idx - vol_shape, 0) if np.any(min_padding > 0) | np.any(max_padding > 0): - pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) else: pad_margins = None - cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) # crop volume label = crop_volume_with_idx(label, cropping, n_dims=n_dims) @@ -2677,9 +3334,13 @@ def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_r # pad volume if necessary if pad_margins is not None: - label = np.pad(label, pad_margins, mode='constant', constant_values=0) - pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins - image = np.pad(image, pad_margins, mode='constant', constant_values=0) + label = np.pad(label, pad_margins, mode="constant", constant_values=0) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) + if n_channels > 1 + else pad_margins + ) + image = np.pad(image, pad_margins, mode="constant", constant_values=0) # update aff if n_dims == 2: @@ -2688,16 +3349,18 @@ def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_r # write results utils.save_volume(image, aff, h_im, path_image_result) - utils.save_volume(label, aff, h_lab, path_label_result, dtype='int32') - - -def subdivide_dataset_to_patches(patch_shape, - image_dir=None, - image_result_dir=None, - labels_dir=None, - labels_result_dir=None, - full_background=True, - remove_after_dividing=False): + utils.save_volume(label, aff, h_lab, path_label_result, dtype="int32") + + +def subdivide_dataset_to_patches( + patch_shape, + image_dir=None, + image_result_dir=None, + labels_dir=None, + labels_result_dir=None, + full_background=True, + remove_after_dividing=False, +): """This function 
subdivides images and/or label maps into several smaller patches of specified shape. :param patch_shape: shape of patches to create. Can either be an int, a sequence, or a 1d numpy array. :param image_dir: (optional) path of directory with input images @@ -2711,16 +3374,21 @@ def subdivide_dataset_to_patches(patch_shape, """ # create result dir and list images and label maps - assert (image_dir is not None) | (labels_dir is not None), \ - 'at least one of image_dir or labels_dir should not be None.' + assert (image_dir is not None) | ( + labels_dir is not None + ), "at least one of image_dir or labels_dir should not be None." if image_dir is not None: - assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" utils.mkdir(image_result_dir) path_images = utils.list_images_in_folder(image_dir) else: path_images = None if labels_dir is not None: - assert labels_result_dir is not None, 'labels_result_dir should not be None if labels_dir is specified' + assert ( + labels_result_dir is not None + ), "labels_result_dir should not be None if labels_dir is specified" utils.mkdir(labels_result_dir) path_labels = utils.list_images_in_folder(labels_dir) else: @@ -2735,17 +3403,21 @@ def subdivide_dataset_to_patches(patch_shape, n_dims, _ = utils.get_dims(patch_shape) # loop over images and labels - loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): loop_info.update(idx) # load image and labels if path_image is not None: - im, aff_im, h_im = utils.load_volume(path_image, im_only=False, squeeze=False) + im, aff_im, h_im = utils.load_volume( + path_image, im_only=False, squeeze=False + ) else: im = aff_im = h_im = None if path_label is not None: - lab, aff_lab, h_lab = utils.load_volume(path_label, im_only=False, squeeze=True) + lab, aff_lab, h_lab = utils.load_volume( + path_label, im_only=False, squeeze=True + ) else: lab = aff_lab = h_lab = None @@ -2756,21 +3428,26 @@ def subdivide_dataset_to_patches(patch_shape, shape = lab.shape # crop image and label map to size divisible by patch_shape - new_size = np.array([utils.find_closest_number_divisible_by_m(shape[i], patch_shape[i]) for i in range(n_dims)]) - crop = np.round((np.array(shape[:n_dims]) - new_size) / 2).astype('int') + new_size = np.array( + [ + utils.find_closest_number_divisible_by_m(shape[i], patch_shape[i]) + for i in range(n_dims) + ] + ) + crop = np.round((np.array(shape[:n_dims]) - new_size) / 2).astype("int") crop = np.concatenate((crop, crop + new_size), axis=0) if (im is not None) & (n_dims == 2): - im = im[crop[0]:crop[2], crop[1]:crop[3], ...] + im = im[crop[0] : crop[2], crop[1] : crop[3], ...] elif (im is not None) & (n_dims == 3): - im = im[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] + im = im[crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ...] if (lab is not None) & (n_dims == 2): - lab = lab[crop[0]:crop[2], crop[1]:crop[3], ...] + lab = lab[crop[0] : crop[2], crop[1] : crop[3], ...] elif (lab is not None) & (n_dims == 3): - lab = lab[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] + lab = lab[crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ...] 
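# Illustrative aside, not part of the patch: assuming the helper rounds down to
# the nearest multiple here, an axis of length 120 with patch_shape 32 gives
# new_size = 96 and crop = round((120 - 96) / 2) = 12, so that axis is sliced
# to [12:108] before being tiled into 96 / 32 = 3 patches further below.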
# loop over patches n_im = 0 - n_crop = (new_size / patch_shape).astype('int') + n_crop = (new_size / patch_shape).astype("int") for i in range(n_crop[0]): i *= patch_shape[0] for j in range(n_crop[1]): @@ -2780,11 +3457,15 @@ def subdivide_dataset_to_patches(patch_shape, # crop volumes if lab is not None: - temp_la = lab[i:i + patch_shape[0], j:j + patch_shape[1], ...] + temp_la = lab[ + i : i + patch_shape[0], j : j + patch_shape[1], ... + ] else: temp_la = None if im is not None: - temp_im = im[i:i + patch_shape[0], j:j + patch_shape[1], ...] + temp_im = im[ + i : i + patch_shape[0], j : j + patch_shape[1], ... + ] else: temp_im = None @@ -2792,14 +3473,45 @@ def subdivide_dataset_to_patches(patch_shape, if temp_la is not None: if full_background | (not (temp_la == 0).all()): n_im += 1 - utils.save_volume(temp_la, aff_lab, h_lab, os.path.join(labels_result_dir, - os.path.basename(path_label.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_la, + aff_lab, + h_lab, + os.path.join( + labels_result_dir, + os.path.basename( + path_label.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) if temp_im is not None: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) else: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace(".nii.gz", "_%d.nii.gz" % n_im) + ), + ), + ) elif n_dims == 3: for k in range(n_crop[2]): @@ -2807,11 +3519,21 @@ def subdivide_dataset_to_patches(patch_shape, # crop volumes if lab is not None: - temp_la = lab[i:i + patch_shape[0], j:j + patch_shape[1], k:k + patch_shape[2], ...] + temp_la = lab[ + i : i + patch_shape[0], + j : j + patch_shape[1], + k : k + patch_shape[2], + ..., + ] else: temp_la = None if im is not None: - temp_im = im[i:i + patch_shape[0], j:j + patch_shape[1], k:k + patch_shape[2], ...] 
+ temp_im = im[ + i : i + patch_shape[0], + j : j + patch_shape[1], + k : k + patch_shape[2], + ..., + ] else: temp_im = None @@ -2819,15 +3541,47 @@ def subdivide_dataset_to_patches(patch_shape, if temp_la is not None: if full_background | (not (temp_la == 0).all()): n_im += 1 - utils.save_volume(temp_la, aff_lab, h_lab, os.path.join(labels_result_dir, - os.path.basename(path_label.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_la, + aff_lab, + h_lab, + os.path.join( + labels_result_dir, + os.path.basename( + path_label.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) if temp_im is not None: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', - '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) else: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) if remove_after_dividing: if path_image is not None: diff --git a/nobrainer/ext/lab2im/image_generator.py b/nobrainer/ext/lab2im/image_generator.py index d48886a7..1c015d58 100644 --- a/nobrainer/ext/lab2im/image_generator.py +++ b/nobrainer/ext/lab2im/image_generator.py @@ -13,34 +13,34 @@ License. """ +# project imports +from ext.lab2im import edit_volumes, utils +from ext.lab2im.lab2im_model import lab2im_model # python imports import numpy as np import numpy.random as npr -# project imports -from ext.lab2im import utils -from ext.lab2im import edit_volumes -from ext.lab2im.lab2im_model import lab2im_model - class ImageGenerator: - def __init__(self, - labels_dir, - generation_labels=None, - output_labels=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - output_div_by_n=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - blur_range=1.15): + def __init__( + self, + labels_dir, + generation_labels=None, + output_labels=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + blur_range=1.15, + ): """ This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels maps, and a python generator that supplies the input data for this model. 
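For orientation, a minimal usage sketch of this class (the import path and the labels directory are assumptions for illustration; all other constructor arguments keep the defaults shown above):

    # hypothetical usage; labels_dir must point at a folder of label maps
    from nobrainer.ext.lab2im.image_generator import ImageGenerator

    generator = ImageGenerator(labels_dir="/path/to/label_maps")
    # returns a synthetic image and the matching resampled label map (numpy)
    image, labels = generator.generate_image()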
@@ -115,8 +115,9 @@ def __init__(self,
         self.labels_paths = utils.list_images_in_folder(labels_dir)
 
         # generation parameters
-        self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \
+        self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = (
             utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4))
+        )
         self.n_channels = n_channels
         if generation_labels is not None:
             self.generation_labels = utils.load_array_if_path(generation_labels)
@@ -135,11 +136,13 @@ def __init__(self,
         self.prior_distributions = prior_distributions
         if generation_classes is not None:
             self.generation_classes = utils.load_array_if_path(generation_classes)
-            assert self.generation_classes.shape == self.generation_labels.shape, \
-                'if provided, generation labels should have the same shape as generation_labels'
+            assert (
+                self.generation_classes.shape == self.generation_labels.shape
+            ), "if provided, generation_classes should have the same shape as generation_labels"
             unique_classes = np.unique(self.generation_classes)
-            assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \
-                'generation_classes should a linear range between 0 and its maximum value.'
+            assert np.array_equal(
+                unique_classes, np.arange(np.max(unique_classes) + 1)
+            ), "generation_classes should be a linear range between 0 and its maximum value."
         else:
             self.generation_classes = np.arange(self.generation_labels.shape[0])
         self.prior_means = utils.load_array_if_path(prior_means)
@@ -153,22 +156,26 @@ def __init__(self,
         self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model()
 
         # build generator for model inputs
-        self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels))
+        self.model_inputs_generator = self._build_model_inputs(
+            len(self.generation_labels)
+        )
 
         # build brain generator
         self.image_generator = self._build_image_generator()
 
     def _build_lab2im_model(self):
         # build_model
-        lab_to_im_model = lab2im_model(labels_shape=self.labels_shape,
-                                       n_channels=self.n_channels,
-                                       generation_labels=self.generation_labels,
-                                       output_labels=self.output_labels,
-                                       atlas_res=self.atlas_res,
-                                       target_res=self.target_res,
-                                       output_shape=self.output_shape,
-                                       output_div_by_n=self.output_div_by_n,
-                                       blur_range=self.blur_range)
+        lab_to_im_model = lab2im_model(
+            labels_shape=self.labels_shape,
+            n_channels=self.n_channels,
+            generation_labels=self.generation_labels,
+            output_labels=self.output_labels,
+            atlas_res=self.atlas_res,
+            target_res=self.target_res,
+            output_shape=self.output_shape,
+            output_div_by_n=self.output_div_by_n,
+            blur_range=self.blur_range,
+        )
         out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:]
         return lab_to_im_model, out_shape
@@ -185,10 +192,16 @@ def generate_image(self):
         list_images = list()
         list_labels = list()
         for i in range(self.batchsize):
-            list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff,
-                                                                n_dims=self.n_dims))
-            list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff,
-                                                                n_dims=self.n_dims))
+            list_images.append(
+                edit_volumes.align_volume_to_ref(
+                    image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims
+                )
+            )
+            list_labels.append(
+                edit_volumes.align_volume_to_ref(
+                    labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims
+                )
+            )
         image = np.stack(list_images, axis=0)
         labels = np.stack(list_labels, axis=0)
         return np.squeeze(image), np.squeeze(labels)
@@ -212,7 +225,9 @@ def _build_model_inputs(self, n_labels):
         for idx in indices:
 
             # load
label in identity space, and add them to inputs - y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4)) + y = utils.load_volume( + self.labels_paths[idx], dtype="int", aff_ref=np.eye(4) + ) list_label_maps.append(utils.add_axis(y, axis=[0, -1])) # add means and standard deviations to inputs @@ -222,35 +237,61 @@ def _build_model_inputs(self, n_labels): # retrieve channel specific stats if necessary if isinstance(self.prior_means, np.ndarray): - if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel: + if ( + self.prior_means.shape[0] > 2 + ) & self.use_specific_stats_for_channel: if self.prior_means.shape[0] / 2 != self.n_channels: - raise ValueError("the number of blocks in prior_means does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_means = self.prior_means[ + 2 * channel : 2 * channel + 2, : + ] else: tmp_prior_means = self.prior_means else: tmp_prior_means = self.prior_means if isinstance(self.prior_stds, np.ndarray): - if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel: + if ( + self.prior_stds.shape[0] > 2 + ) & self.use_specific_stats_for_channel: if self.prior_stds.shape[0] / 2 != self.n_channels: - raise ValueError("the number of blocks in prior_stds does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." 
+ ) + tmp_prior_stds = self.prior_stds[ + 2 * channel : 2 * channel + 2, : + ] else: tmp_prior_stds = self.prior_stds else: tmp_prior_stds = self.prior_stds # draw means and std devs from priors - tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, - self.prior_distributions, 125., 100., - positive_only=True) - tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, - self.prior_distributions, 15., 10., - positive_only=True) - tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1]) - tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1]) + tmp_classes_means = utils.draw_value_from_distribution( + tmp_prior_means, + n_labels, + self.prior_distributions, + 125.0, + 100.0, + positive_only=True, + ) + tmp_classes_stds = utils.draw_value_from_distribution( + tmp_prior_stds, + n_labels, + self.prior_distributions, + 15.0, + 10.0, + positive_only=True, + ) + tmp_means = utils.add_axis( + tmp_classes_means[self.generation_classes], axis=[0, -1] + ) + tmp_stds = utils.add_axis( + tmp_classes_stds[self.generation_classes], axis=[0, -1] + ) means = np.concatenate([means, tmp_means], axis=-1) stds = np.concatenate([stds, tmp_stds], axis=-1) list_means.append(means) @@ -258,7 +299,9 @@ def _build_model_inputs(self, n_labels): # build list of inputs of augmentation model list_inputs = [list_label_maps, list_means, list_stds] - if self.batchsize > 1: # concatenate individual input types if batchsize > 1 + if ( + self.batchsize > 1 + ): # concatenate individual input types if batchsize > 1 list_inputs = [np.concatenate(item, 0) for item in list_inputs] else: list_inputs = [item[0] for item in list_inputs] diff --git a/nobrainer/ext/lab2im/lab2im_model.py b/nobrainer/ext/lab2im/lab2im_model.py index 743626cf..b20e5274 100644 --- a/nobrainer/ext/lab2im/lab2im_model.py +++ b/nobrainer/ext/lab2im/lab2im_model.py @@ -13,27 +13,27 @@ License. """ +# project imports +from ext.lab2im import layers, utils +from ext.lab2im.edit_tensors import blurring_sigma_for_downsampling, resample_tensor +import keras.layers as KL +from keras.models import Model # python imports import numpy as np -import keras.layers as KL -from keras.models import Model -# project imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling - - -def lab2im_model(labels_shape, - n_channels, - generation_labels, - output_labels, - atlas_res, - target_res, - output_shape=None, - output_div_by_n=None, - blur_range=1.15): + +def lab2im_model( + labels_shape, + n_channels, + generation_labels, + output_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + blur_range=1.15, +): """ This function builds a keras/tensorflow model to generate images from provided label maps. The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. 
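The conditional GMM sampling at the core of this model amounts to a per-voxel lookup of class statistics followed by a Gaussian draw; a toy numpy sketch of the idea (labels and parameters are invented, and this is not the TensorFlow implementation used by the layers below):

    import numpy as np

    labels = np.random.randint(0, 3, size=(8, 8, 8))  # toy label map, 3 classes
    means = np.array([0.0, 110.0, 200.0])             # one Gaussian mean per label
    stds = np.array([5.0, 15.0, 10.0])                # one std deviation per label
    # conditioned on the label map, each voxel draws from its class's Gaussian
    image = means[labels] + stds[labels] * np.random.randn(*labels.shape)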
@@ -74,18 +74,30 @@ def lab2im_model(labels_shape, labels_shape = utils.reformat_to_list(labels_shape) n_dims, _ = utils.get_dims(labels_shape) atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0] - target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + target_res = ( + atlas_res + if (target_res is None) + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) # get shapes - crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) # define model inputs - labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32') - means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') - stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input') + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="stds_input" + ) # deform labels - labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input) + labels = layers.RandomSpatialDeformation(inter_method="nearest")(labels_input) # cropping if crop_shape != labels_shape: @@ -94,15 +106,19 @@ def lab2im_model(labels_shape, # build synthetic image labels._keras_shape = tuple(labels.get_shape().as_list()) - image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) # apply bias field image._keras_shape = tuple(image.get_shape().as_list()) - image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image) + image = layers.BiasFieldCorruption(0.3, 0.025, same_bias_for_all_channels=False)( + image + ) # intensity augmentation image._keras_shape = tuple(image.get_shape().as_list()) - image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image) + image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=0.2)(image) # blur image sigma = blurring_sigma_for_downsampling(atlas_res, target_res) @@ -111,15 +127,19 @@ def lab2im_model(labels_shape, # resample to target res if crop_shape != output_shape: - image = resample_tensor(image, output_shape, interp_method='linear') - labels = resample_tensor(labels, output_shape, interp_method='nearest') + image = resample_tensor(image, output_shape, interp_method="linear") + labels = resample_tensor(labels, output_shape, interp_method="nearest") # reset unwanted labels to zero - labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) # build model (dummy layer enables to keep the labels when plugging this model to other models) - image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) - brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels]) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) + brain_model = Model( + inputs=[labels_input, means_input, stds_input], outputs=[image, labels] + ) return brain_model @@ -136,26 +156,39 @@ def 
get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ # output shape specified, need to get cropping shape, and resample shape if necessary if output_shape is not None: - output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int") # make sure that output shape is smaller or equal to label shape if resample_factor is not None: - output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) + for i in range(n_dims) + ] else: - output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(labels_shape[i], output_shape[i]) for i in range(n_dims) + ] # make sure output shape is divisible by output_div_by_n if output_div_by_n is not None: - tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) - for s in output_shape] + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] if output_shape != tmp_shape: - print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, - tmp_shape)) + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + output_shape, output_div_by_n, tmp_shape + ) + ) output_shape = tmp_shape # get cropping and resample shape if resample_factor is not None: - cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] else: cropping_shape = output_shape @@ -163,12 +196,19 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ else: cropping_shape = labels_shape if resample_factor is not None: - output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)] + output_shape = [ + int(np.around(cropping_shape[i] * resample_factor[i], 0)) + for i in range(n_dims) + ] else: output_shape = cropping_shape # make sure output shape is divisible by output_div_by_n if output_div_by_n is not None: - output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer') - for s in output_shape] + output_shape = [ + utils.find_closest_number_divisible_by_m( + s, output_div_by_n, answer_type="closer" + ) + for s in output_shape + ] return cropping_shape, output_shape diff --git a/nobrainer/ext/lab2im/layers.py b/nobrainer/ext/lab2im/layers.py index 96cbda30..914b7b69 100644 --- a/nobrainer/ext/lab2im/layers.py +++ b/nobrainer/ext/lab2im/layers.py @@ -34,22 +34,21 @@ License. """ - -# python imports -import keras -import numpy as np -import tensorflow as tf -import keras.backend as K -from keras.layers import Layer - # project imports -from ext.lab2im import utils from ext.lab2im import edit_tensors as l2i_et +from ext.lab2im import utils # third-party imports from ext.neuron import utils as nrn_utils import ext.neuron.layers as nrn_layers +# python imports +import keras +import keras.backend as K +from keras.layers import Layer +import numpy as np +import tensorflow as tf + class RandomSpatialDeformation(Layer): """This layer spatially deforms one or several tensors with a combination of affine and elastic transformations. 
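The affine half of that combination can be pictured with plain numpy: sample a rotation and scaling within the default bounds quoted in the hunk below and apply them to a grid of voxel coordinates (a 2D toy sketch with invented grid size and seed, not the layer's TensorFlow code path):

    import numpy as np

    rng = np.random.default_rng(0)
    angle = np.radians(rng.uniform(-10, 10))      # rotation_bounds=10 (degrees)
    scale = 1 + rng.uniform(-0.15, 0.15, size=2)  # scaling_bounds=0.15
    affine = np.array([[np.cos(angle), -np.sin(angle)],
                       [np.sin(angle),  np.cos(angle)]]) * scale
    grid = np.stack(np.meshgrid(np.arange(4), np.arange(4), indexing="ij"), -1)
    warped = grid.astype(float) @ affine.T  # coordinates at which to resample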
@@ -85,17 +84,19 @@ class RandomSpatialDeformation(Layer): :param prob_deform: (optional) probability to apply spatial deformation """ - def __init__(self, - scaling_bounds=0.15, - rotation_bounds=10, - shearing_bounds=0.02, - translation_bounds=False, - enable_90_rotations=False, - nonlin_std=4., - nonlin_scale=.0625, - inter_method='linear', - prob_deform=1, - **kwargs): + def __init__( + self, + scaling_bounds=0.15, + rotation_bounds=10, + shearing_bounds=0.02, + translation_bounds=False, + enable_90_rotations=False, + nonlin_std=4.0, + nonlin_scale=0.0625, + inter_method="linear", + prob_deform=1, + **kwargs + ): # shape attributes self.n_inputs = 1 @@ -113,9 +114,13 @@ def __init__(self, self.nonlin_scale = nonlin_scale # boolean attributes - self.apply_affine_trans = (self.scaling_bounds is not False) | (self.rotation_bounds is not False) | \ - (self.shearing_bounds is not False) | (self.translation_bounds is not False) | \ - self.enable_90_rotations + self.apply_affine_trans = ( + (self.scaling_bounds is not False) + | (self.rotation_bounds is not False) + | (self.shearing_bounds is not False) + | (self.translation_bounds is not False) + | self.enable_90_rotations + ) self.apply_elastic_trans = self.nonlin_std > 0 self.prob_deform = prob_deform @@ -148,12 +153,15 @@ def build(self, input_shape): self.n_dims = len(self.inshape) - 1 if self.apply_elastic_trans: - self.small_shape = utils.get_resample_shape(self.inshape[:self.n_dims], - self.nonlin_scale, self.n_dims) + self.small_shape = utils.get_resample_shape( + self.inshape[: self.n_dims], self.nonlin_scale, self.n_dims + ) else: self.small_shape = None - self.inter_method = utils.reformat_to_list(self.inter_method, length=self.n_inputs, dtype='str') + self.inter_method = utils.reformat_to_list( + self.inter_method, length=self.n_inputs, dtype="str" + ) self.built = True super(RandomSpatialDeformation, self).build(input_shape) @@ -164,7 +172,7 @@ def call(self, inputs, **kwargs): if self.n_inputs < 2: inputs = [inputs] types = [v.dtype for v in inputs] - inputs = [tf.cast(v, dtype='float32') for v in inputs] + inputs = [tf.cast(v, dtype="float32") for v in inputs] batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] # initialise list of transforms to operate @@ -172,39 +180,61 @@ def call(self, inputs, **kwargs): # add affine deformation to inputs list if self.apply_affine_trans: - affine_trans = utils.sample_affine_transform(batchsize, - self.n_dims, - self.rotation_bounds, - self.scaling_bounds, - self.shearing_bounds, - self.translation_bounds, - self.enable_90_rotations) + affine_trans = utils.sample_affine_transform( + batchsize, + self.n_dims, + self.rotation_bounds, + self.scaling_bounds, + self.shearing_bounds, + self.translation_bounds, + self.enable_90_rotations, + ) list_trans.append(affine_trans) # prepare non-linear deformation field and add it to inputs list if self.apply_elastic_trans: # sample small field from normal distribution of specified std dev - trans_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_shape, dtype='int32')], axis=0) + trans_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.small_shape, dtype="int32")], + axis=0, + ) trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std) elastic_trans = tf.random.normal(trans_shape, stddev=trans_std) # reshape this field to half size (for smoother SVF), integrate it, and reshape to full image size - resize_shape = [max(int(self.inshape[i] / 2), self.small_shape[i]) for i in range(self.n_dims)] - elastic_trans = 
nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans) + resize_shape = [ + max(int(self.inshape[i] / 2), self.small_shape[i]) + for i in range(self.n_dims) + ] + elastic_trans = nrn_layers.Resize( + size=resize_shape, interp_method="linear" + )(elastic_trans) elastic_trans = nrn_layers.VecInt()(elastic_trans) - elastic_trans = nrn_layers.Resize(size=self.inshape[:self.n_dims], interp_method='linear')(elastic_trans) + elastic_trans = nrn_layers.Resize( + size=self.inshape[: self.n_dims], interp_method="linear" + )(elastic_trans) list_trans.append(elastic_trans) # apply deformations and return tensors with correct dtype if self.apply_affine_trans | self.apply_elastic_trans: if self.prob_deform == 1: - inputs = [nrn_layers.SpatialTransformer(m)([v] + list_trans) for (m, v) in - zip(self.inter_method, inputs)] + inputs = [ + nrn_layers.SpatialTransformer(m)([v] + list_trans) + for (m, v) in zip(self.inter_method, inputs) + ] else: - rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_deform)) - inputs = [K.switch(rand_trans, nrn_layers.SpatialTransformer(m)([v] + list_trans), v) - for (m, v) in zip(self.inter_method, inputs)] + rand_trans = tf.squeeze( + K.less(tf.random.uniform([1], 0, 1), self.prob_deform) + ) + inputs = [ + K.switch( + rand_trans, + nrn_layers.SpatialTransformer(m)([v] + list_trans), + v, + ) + for (m, v) in zip(self.inter_method, inputs) + ] if self.n_inputs < 2: return tf.cast(inputs[0], types[0]) else: @@ -244,7 +274,9 @@ def build(self, input_shape): inputshape = [input_shape] else: inputshape = input_shape - self.crop_max_val = np.array(np.array(inputshape[0][1:self.n_dims + 1])) - np.array(self.crop_shape) + self.crop_max_val = np.array( + np.array(inputshape[0][1 : self.n_dims + 1]) + ) - np.array(self.crop_shape) self.list_n_channels = [i[-1] for i in inputshape] self.built = True super(RandomCrop, self).build(input_shape) @@ -258,19 +290,24 @@ def call(self, inputs, **kwargs): # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location else: types = [v.dtype for v in inputs] - inputs = tf.concat([tf.cast(v, 'float32') for v in inputs], axis=-1) + inputs = tf.concat([tf.cast(v, "float32") for v in inputs], axis=-1) inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32) inputs = tf.split(inputs, self.list_n_channels, axis=-1) return [tf.cast(v, t) for (t, v) in zip(types, inputs)] def _single_slice(self, vol): - crop_idx = tf.cast(tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), 'float32'), dtype='int32') - crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype='int32')], axis=0) - crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype='int32') + crop_idx = tf.cast( + tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), "float32"), + dtype="int32", + ) + crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype="int32")], axis=0) + crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype="int32") return tf.slice(vol, begin=crop_idx, size=crop_size) def compute_output_shape(self, input_shape): - output_shape = [tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels] + output_shape = [ + tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels + ] return output_shape if self.several_inputs else output_shape[0] @@ -329,7 +366,15 @@ class RandomFlip(Layer): This doesn't concern the image input, as its values are not swapped. 
""" - def __init__(self, axis=None, swap_labels=False, label_list=None, n_neutral_labels=None, prob=0.5, **kwargs): + def __init__( + self, + axis=None, + swap_labels=False, + label_list=None, + n_neutral_labels=None, + prob=0.5, + **kwargs + ): # shape attributes self.several_inputs = True @@ -368,22 +413,35 @@ def build(self, input_shape): inputshape = input_shape self.n_dims = len(inputshape[0][1:-1]) self.list_n_channels = [i[-1] for i in inputshape] - self.swap_labels = utils.reformat_to_list(self.swap_labels, length=len(inputshape)) - self.flip_axes = np.arange(self.n_dims).tolist() if self.axis is None else self.axis + self.swap_labels = utils.reformat_to_list( + self.swap_labels, length=len(inputshape) + ) + self.flip_axes = ( + np.arange(self.n_dims).tolist() if self.axis is None else self.axis + ) # create label list with swapped labels if any(self.swap_labels): - assert (self.label_list is not None) & (self.n_neutral_labels is not None), \ - 'please provide a label_list, and n_neutral_labels when swapping the values of at least one input' + assert (self.label_list is not None) & ( + self.n_neutral_labels is not None + ), "please provide a label_list, and n_neutral_labels when swapping the values of at least one input" n_labels = len(self.label_list) if self.n_neutral_labels == n_labels: self.swap_labels = [False] * len(self.swap_labels) else: - rl_split = np.split(self.label_list, [self.n_neutral_labels, - self.n_neutral_labels + int((n_labels-self.n_neutral_labels)/2)]) - label_list_swap = np.concatenate((rl_split[0], rl_split[2], rl_split[1])) + rl_split = np.split( + self.label_list, + [ + self.n_neutral_labels, + self.n_neutral_labels + + int((n_labels - self.n_neutral_labels) / 2), + ], + ) + label_list_swap = np.concatenate( + (rl_split[0], rl_split[2], rl_split[1]) + ) swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap) - self.swap_lut = tf.convert_to_tensor(swap_lut, dtype='int32') + self.swap_lut = tf.convert_to_tensor(swap_lut, dtype="int32") self.built = True super(RandomFlip, self).build(input_shape) @@ -396,20 +454,29 @@ def call(self, inputs, **kwargs): # store whether to flip along each specified dimension batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] - size = tf.concat([batchsize, len(self.flip_axes) * tf.ones(1, dtype='int32')], axis=0) + size = tf.concat( + [batchsize, len(self.flip_axes) * tf.ones(1, dtype="int32")], axis=0 + ) rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob) # swap right/left labels if we apply an odd number of flips - odd = tf.math.floormod(tf.reduce_sum(tf.cast(rand_flip, 'int32'), -1, keepdims=True), 2) != 0 + odd = ( + tf.math.floormod( + tf.reduce_sum(tf.cast(rand_flip, "int32"), -1, keepdims=True), 2 + ) + != 0 + ) swapped_inputs = list() for i in range(len(inputs)): if self.swap_labels[i]: - swapped_inputs.append(tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i])) + swapped_inputs.append( + tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i]) + ) else: swapped_inputs.append(inputs[i]) # flip inputs and convert them back to their original type - inputs = tf.concat([tf.cast(v, 'float32') for v in swapped_inputs], axis=-1) + inputs = tf.concat([tf.cast(v, "float32") for v in swapped_inputs], axis=-1) inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32) inputs = tf.split(inputs, self.list_n_channels, axis=-1) @@ -424,7 +491,11 @@ def _single_swap(self, inputs): @staticmethod def _single_flip(inputs): flip_axis = tf.where(inputs[1]) - return 
K.switch(tf.equal(tf.size(flip_axis), 0), inputs[0], tf.reverse(inputs[0], axis=flip_axis[..., 0])) + return K.switch( + tf.equal(tf.size(flip_axis), 0), + inputs[0], + tf.reverse(inputs[0], axis=flip_axis[..., 0]), + ) class SampleConditionalGMM(Layer): @@ -462,17 +533,31 @@ def get_config(self): def build(self, input_shape): # check n_labels and n_channels - assert len(input_shape) == 3, 'should have three inputs: labels, means, std devs (in that order).' + assert ( + len(input_shape) == 3 + ), "should have three inputs: labels, means, std devs (in that order)." self.n_channels = input_shape[1][-1] self.n_labels = len(self.generation_labels) - assert self.n_labels == input_shape[1][1], 'means should have the same number of values as generation_labels' - assert self.n_labels == input_shape[2][1], 'stds should have the same number of values as generation_labels' + assert ( + self.n_labels == input_shape[1][1] + ), "means should have the same number of values as generation_labels" + assert ( + self.n_labels == input_shape[2][1] + ), "stds should have the same number of values as generation_labels" # scatter parameters (to build mean/std lut) self.max_label = np.max(self.generation_labels) + 1 - indices = np.concatenate([self.generation_labels + self.max_label * i for i in range(self.n_channels)], axis=-1) - self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype='int32') - self.indices = tf.convert_to_tensor(utils.add_axis(indices, axis=[0, -1]), dtype='int32') + indices = np.concatenate( + [ + self.generation_labels + self.max_label * i + for i in range(self.n_channels) + ], + axis=-1, + ) + self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype="int32") + self.indices = tf.convert_to_tensor( + utils.add_axis(indices, axis=[0, -1]), dtype="int32" + ) self.built = True super(SampleConditionalGMM, self).build(input_shape) @@ -481,24 +566,56 @@ def call(self, inputs, **kwargs): # reformat labels and scatter indices batch = tf.split(tf.shape(inputs[0]), [1, -1])[0] - tmp_indices = tf.tile(self.indices, tf.concat([batch, tf.convert_to_tensor([1, 1], dtype='int32')], axis=0)) - labels = tf.concat([tf.cast(inputs[0], dtype='int32') + self.max_label * i for i in range(self.n_channels)], -1) + tmp_indices = tf.tile( + self.indices, + tf.concat([batch, tf.convert_to_tensor([1, 1], dtype="int32")], axis=0), + ) + labels = tf.concat( + [ + tf.cast(inputs[0], dtype="int32") + self.max_label * i + for i in range(self.n_channels) + ], + -1, + ) # build mean map means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1) - tile_shape = tf.concat([batch, tf.convert_to_tensor([1, ], dtype='int32')], axis=0) - means = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape) - means_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32) + tile_shape = tf.concat( + [ + batch, + tf.convert_to_tensor( + [ + 1, + ], + dtype="int32", + ), + ], + axis=0, + ) + means = tf.tile( + tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape + ) + means_map = tf.map_fn( + lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32 + ) # same for stds stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1) - stds = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape) - stds_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32) + stds = tf.tile( + tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape + ) + 
stds_map = tf.map_fn(
+            lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32
+        )
 
         return stds_map * tf.random.normal(tf.shape(labels)) + means_map
 
     def compute_output_shape(self, input_shape):
-        return input_shape[0] if (self.n_channels == 1) else tuple(list(input_shape[0][:-1]) + [self.n_channels])
+        return (
+            input_shape[0]
+            if (self.n_channels == 1)
+            else tuple(list(input_shape[0][:-1]) + [self.n_channels])
+        )
 
 
 class SampleResolution(Layer):
@@ -528,14 +645,16 @@ class SampleResolution(Layer):
 
     """
 
-    def __init__(self,
-                 min_resolution,
-                 max_res_iso=None,
-                 max_res_aniso=None,
-                 prob_iso=0.1,
-                 prob_min=0.05,
-                 return_thickness=True,
-                 **kwargs):
+    def __init__(
+        self,
+        min_resolution,
+        max_res_iso=None,
+        max_res_aniso=None,
+        prob_iso=0.1,
+        prob_min=0.05,
+        return_thickness=True,
+        **kwargs
+    ):
 
         self.min_res = min_resolution
         self.max_res_iso_input = max_res_iso
@@ -563,34 +682,43 @@ def get_config(self):
     def build(self, input_shape):
 
         # check maximum resolutions
-        assert ((self.max_res_iso_input is not None) | (self.max_res_aniso_input is not None)), \
-            'at least one of maximum isotropic or anisotropic resolutions must be provided, received none'
+        assert (self.max_res_iso_input is not None) | (
+            self.max_res_aniso_input is not None
+        ), "at least one of maximum isotropic or anisotropic resolutions must be provided, received none"
 
         # reformat resolutions as numpy arrays
         self.min_res = np.array(self.min_res)
         if self.max_res_iso_input is not None:
             self.max_res_iso = np.array(self.max_res_iso_input)
-            assert len(self.min_res) == len(self.max_res_iso), \
-                'min and isotropic max resolution must have the same length, ' \
-                'had {0} and {1}'.format(self.min_res, self.max_res_iso)
+            assert len(self.min_res) == len(self.max_res_iso), (
+                "min and isotropic max resolution must have the same length, "
+                "had {0} and {1}".format(self.min_res, self.max_res_iso)
+            )
             if np.array_equal(self.min_res, self.max_res_iso):
                 self.max_res_iso = None
         if self.max_res_aniso_input is not None:
             self.max_res_aniso = np.array(self.max_res_aniso_input)
-            assert len(self.min_res) == len(self.max_res_aniso), \
-                'min and anisotropic max resolution must have the same length, ' \
-                'had {} and {}'.format(self.min_res, self.max_res_aniso)
+            assert len(self.min_res) == len(self.max_res_aniso), (
+                "min and anisotropic max resolution must have the same length, "
+                "had {} and {}".format(self.min_res, self.max_res_aniso)
+            )
             if np.array_equal(self.min_res, self.max_res_aniso):
                 self.max_res_aniso = None
 
         # check prob iso
-        if (self.max_res_iso is not None) & (self.max_res_aniso is not None) & (self.prob_iso == 0):
-            raise Exception('prob iso is 0 while sampling either isotropic and anisotropic resolutions is enabled')
+        if (
+            (self.max_res_iso is not None)
+            & (self.max_res_aniso is not None)
+            & (self.prob_iso == 0)
+        ):
+            raise Exception(
+                "prob_iso is 0 while sampling both isotropic and anisotropic resolutions is enabled"
+            )
 
         if input_shape:
             self.add_batchsize = True
 
-        self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')
+        self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype="float32")
 
         self.built = True
         super(SampleResolution, self).build(input_shape)
@@ -599,17 +727,36 @@ def call(self, inputs, **kwargs):
 
         if not self.add_batchsize:
             shape = [self.n_dims]
-            dim = tf.random.uniform(shape=(1, 1), minval=0, maxval=self.n_dims, dtype='int32')
-            mask = tf.tensor_scatter_nd_update(tf.zeros([self.n_dims], dtype='bool'), dim,
-                                               tf.convert_to_tensor([True], dtype='bool'))
+            dim = 
tf.random.uniform( + shape=(1, 1), minval=0, maxval=self.n_dims, dtype="int32" + ) + mask = tf.tensor_scatter_nd_update( + tf.zeros([self.n_dims], dtype="bool"), + dim, + tf.convert_to_tensor([True], dtype="bool"), + ) else: batch = tf.split(tf.shape(inputs), [1, -1])[0] - tile_shape = tf.concat([batch, tf.convert_to_tensor([1], dtype='int32')], axis=0) - self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape) - - shape = tf.concat([batch, tf.convert_to_tensor([self.n_dims], dtype='int32')], axis=0) - indices = tf.stack([tf.range(0, batch[0]), tf.random.uniform(batch, 0, self.n_dims, dtype='int32')], 1) - mask = tf.tensor_scatter_nd_update(tf.zeros(shape, dtype='bool'), indices, tf.ones(batch, dtype='bool')) + tile_shape = tf.concat( + [batch, tf.convert_to_tensor([1], dtype="int32")], axis=0 + ) + self.min_res_tens = tf.tile( + tf.expand_dims(self.min_res_tens, 0), tile_shape + ) + + shape = tf.concat( + [batch, tf.convert_to_tensor([self.n_dims], dtype="int32")], axis=0 + ) + indices = tf.stack( + [ + tf.range(0, batch[0]), + tf.random.uniform(batch, 0, self.n_dims, dtype="int32"), + ], + 1, + ) + mask = tf.tensor_scatter_nd_update( + tf.zeros(shape, dtype="bool"), indices, tf.ones(batch, dtype="bool") + ) # return min resolution as tensor if min=max if (self.max_res_iso is None) & (self.max_res_aniso is None): @@ -617,37 +764,60 @@ def call(self, inputs, **kwargs): # sample isotropic resolution only elif (self.max_res_iso is not None) & (self.max_res_aniso is None): - new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), - self.min_res_tens, - new_resolution_iso) + new_resolution_iso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_iso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution_iso, + ) # sample anisotropic resolution only elif (self.max_res_iso is None) & (self.max_res_aniso is not None): - new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), - self.min_res_tens, - tf.where(mask, new_resolution_aniso, self.min_res_tens)) + new_resolution_aniso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_aniso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + tf.where(mask, new_resolution_aniso, self.min_res_tens), + ) # sample either anisotropic or isotropic resolution else: - new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) - new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), - new_resolution_iso, - tf.where(mask, new_resolution_aniso, self.min_res_tens)) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), - self.min_res_tens, - new_resolution) + new_resolution_iso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_iso + ) + new_resolution_aniso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_aniso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), + new_resolution_iso, + tf.where(mask, 
new_resolution_aniso, self.min_res_tens),
+            )
+            new_resolution = K.switch(
+                tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)),
+                self.min_res_tens,
+                new_resolution,
+            )

         if self.return_thickness:
-            return [new_resolution, tf.random.uniform(tf.shape(self.min_res_tens), self.min_res_tens, new_resolution)]
+            return [
+                new_resolution,
+                tf.random.uniform(
+                    tf.shape(self.min_res_tens), self.min_res_tens, new_resolution
+                ),
+            ]
         else:
             return new_resolution

     def compute_output_shape(self, input_shape):
         if self.return_thickness:
-            return [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2
+            return (
+                [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2
+            )
         else:
             return (None, self.n_dims) if self.add_batchsize else self.n_dims

@@ -684,7 +854,9 @@ class GaussianBlur(Layer):
     def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs):
         self.sigma = utils.reformat_to_list(sigma)
-        assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0'
+        assert np.all(
+            np.array(self.sigma) >= 0
+        ), "sigma should be greater than or equal to 0"
         self.use_mask = use_mask

         self.n_dims = None
@@ -707,7 +879,9 @@ def build(self, input_shape):
         # get shapes
         if self.use_mask:
-            assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True'
+            assert (
+                len(input_shape) == 2
+            ), "please provide a mask as second layer input when use_mask=True"
             self.n_dims = len(input_shape[0]) - 2
             self.n_channels = input_shape[0][-1]
         else:
@@ -724,7 +898,7 @@ def build(self, input_shape):
             self.kernels = None

         # prepare convolution
-        self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims)
+        self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims)

         self.built = True
         super(GaussianBlur, self).build(input_shape)

@@ -733,34 +907,76 @@ def call(self, inputs, **kwargs):
         if self.use_mask:
             image = inputs[0]
-            mask = tf.cast(inputs[1], 'bool')
+            mask = tf.cast(inputs[1], "bool")
         else:
             image = inputs
             mask = None

         # redefine the kernels at each new step when blur_range is activated
         if self.blur_range is not None:
-            self.kernels = l2i_et.gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable)
+            self.kernels = l2i_et.gaussian_kernel(
+                self.sigma, blur_range=self.blur_range, separable=self.separable
+            )

         if self.separable:
             for k in self.kernels:
                 if k is not None:
-                    image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME')
-                                       for n in range(self.n_channels)], -1)
+                    image = tf.concat(
+                        [
+                            self.convnd(
+                                tf.expand_dims(image[..., n], -1),
+                                k,
+                                self.stride,
+                                "SAME",
+                            )
+                            for n in range(self.n_channels)
+                        ],
+                        -1,
+                    )
                     if self.use_mask:
-                        maskb = tf.cast(mask, 'float32')
-                        maskb = 
tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + maskb = tf.cast(mask, "float32") + maskb = tf.concat( + [ + self.convnd( + tf.expand_dims(maskb[..., n], -1), + self.kernels, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) image = image / (maskb + K.epsilon()) image = tf.where(mask, image, tf.zeros_like(image)) @@ -798,10 +1014,12 @@ def get_config(self): return config def build(self, input_shape): - assert len(input_shape) == 2, 'sigma should be provided as an input tensor for dynamic blurring' + assert ( + len(input_shape) == 2 + ), "sigma should be provided as an input tensor for dynamic blurring" self.n_dims = len(input_shape[0]) - 2 self.n_channels = input_shape[0][-1] - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims) self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5 self.built = True @@ -810,7 +1028,9 @@ def build(self, input_shape): def call(self, inputs, **kwargs): image = inputs[0] sigma = inputs[-1] - kernels = l2i_et.gaussian_kernel(sigma, self.max_sigma, self.blur_range, self.separable) + kernels = l2i_et.gaussian_kernel( + sigma, self.max_sigma, self.blur_range, self.separable + ) if self.separable: for kernel in kernels: image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32) @@ -823,11 +1043,21 @@ def _single_blur(self, inputs): split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1) blurred_channel = list() for channel in split_channels: - blurred = self.convnd(tf.expand_dims(channel, 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + blurred = self.convnd( + tf.expand_dims(channel, 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ) blurred_channel.append(tf.squeeze(blurred, axis=0)) output = tf.concat(blurred_channel, -1) else: - output = self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + output = self.convnd( + tf.expand_dims(inputs[0], 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ) output = tf.squeeze(output, axis=0) return output @@ -872,8 +1102,16 @@ class MimicAcquisition(Layer): Note that the provided res must have higher values than min_low_res. 
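    A minimal wiring sketch (the shapes, resolutions and variable names below are
    illustrative only and are not taken from this repository):

        import keras.layers as KL
        from keras.models import Model

        image = KL.Input((160, 160, 160, 1))  # volume living at volume_res
        acq_res = KL.Input((3,))  # resolution to mimic, e.g. from SampleResolution
        out = MimicAcquisition(volume_res=[1.0, 1.0, 1.0],  # hypothetical values
                               min_subsample_res=[1.0, 1.0, 1.0],
                               resample_shape=[160, 160, 160])([image, acq_res])
        model = Model(inputs=[image, acq_res], outputs=out)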
""" - def __init__(self, volume_res, min_subsample_res, resample_shape, build_dist_map=False, - noise_std=0, prob_noise=0.95, **kwargs): + def __init__( + self, + volume_res, + min_subsample_res, + resample_shape, + build_dist_map=False, + noise_std=0, + prob_noise=0.95, + **kwargs + ): # resolutions and dimensions self.volume_res = volume_res @@ -915,11 +1153,17 @@ def build(self, input_shape): self.inshape = input_shape[0][1:] self.n_channels = input_shape[0][-1] self.add_batchsize = False if (input_shape[1][0] is None) else True - down_tensor_shape = np.int32(np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res) + down_tensor_shape = np.int32( + np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res + ) # build interpolation meshgrids - self.down_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0) - self.up_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0) + self.down_grid = tf.expand_dims( + tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0 + ) + self.up_grid = tf.expand_dims( + tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0 + ) self.built = True super(MimicAcquisition, self).build(input_shape) @@ -927,42 +1171,79 @@ def build(self, input_shape): def call(self, inputs, **kwargs): # sort inputs - assert len(inputs) == 2, 'inputs must have two items, the tensor to resample, and the downsampling resolution' + assert ( + len(inputs) == 2 + ), "inputs must have two items, the tensor to resample, and the downsampling resolution" vol = inputs[0] - subsample_res = tf.cast(inputs[1], dtype='float32') + subsample_res = tf.cast(inputs[1], dtype="float32") vol = K.reshape(vol, [-1, *self.inshape]) # necessary for multi_gpu models batchsize = tf.split(tf.shape(vol), [1, -1])[0] - tile_shape = tf.concat([batchsize, tf.ones([1], dtype='int32')], 0) + tile_shape = tf.concat([batchsize, tf.ones([1], dtype="int32")], 0) # get downsampling and upsampling factors if self.add_batchsize: subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape) - down_shape = tf.cast(tf.convert_to_tensor(np.array(self.inshape[:-1]) * self.volume_res, dtype='float32') / - subsample_res, dtype='int32') - down_zoom_factor = tf.cast(down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype='float32') - up_zoom_factor = tf.cast(tf.convert_to_tensor(self.resample_shape, dtype='int32') / down_shape, dtype='float32') + down_shape = tf.cast( + tf.convert_to_tensor( + np.array(self.inshape[:-1]) * self.volume_res, dtype="float32" + ) + / subsample_res, + dtype="int32", + ) + down_zoom_factor = tf.cast( + down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype="float32" + ) + up_zoom_factor = tf.cast( + tf.convert_to_tensor(self.resample_shape, dtype="int32") / down_shape, + dtype="float32", + ) # downsample - down_loc = tf.tile(self.down_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], 0)) - down_loc = tf.cast(down_loc, 'float32') / l2i_et.expand_dims(down_zoom_factor, axis=[1] * self.n_dims) - inshape_tens = tf.tile(tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape) + down_loc = tf.tile( + self.down_grid, + tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype="int32")], 0), + ) + down_loc = tf.cast(down_loc, "float32") / l2i_et.expand_dims( + down_zoom_factor, axis=[1] * self.n_dims + ) + inshape_tens = tf.tile( + tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape + ) inshape_tens = 
l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims) - down_loc = K.clip(down_loc, 0., tf.cast(inshape_tens, 'float32')) + down_loc = K.clip(down_loc, 0.0, tf.cast(inshape_tens, "float32")) vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32) # add noise with predefined probability if self.noise_std > 0: - sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32'), - self.n_channels * tf.ones([1], dtype='int32')], 0) - noise = tf.random.normal(tf.shape(vol), stddev=tf.random.uniform(sample_shape, maxval=self.noise_std)) + sample_shape = tf.concat( + [ + batchsize, + tf.ones([self.n_dims], dtype="int32"), + self.n_channels * tf.ones([1], dtype="int32"), + ], + 0, + ) + noise = tf.random.normal( + tf.shape(vol), + stddev=tf.random.uniform(sample_shape, maxval=self.noise_std), + ) if self.prob_noise == 1: vol += noise else: - vol = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), vol + noise, vol) + vol = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + vol + noise, + vol, + ) # upsample - up_loc = tf.tile(self.up_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], axis=0)) - up_loc = tf.cast(up_loc, 'float32') / l2i_et.expand_dims(up_zoom_factor, axis=[1] * self.n_dims) + up_loc = tf.tile( + self.up_grid, + tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype="int32")], axis=0), + ) + up_loc = tf.cast(up_loc, "float32") / l2i_et.expand_dims( + up_zoom_factor, axis=[1] * self.n_dims + ) vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32) # return upsampled volume @@ -981,18 +1262,22 @@ def call(self, inputs, **kwargs): c_dist = ceil - up_loc # keep minimum 1d distances, and compute 3d distance to nearest grid point - dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims(subsample_res, axis=[1] * self.n_dims) - dist = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True)) + dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims( + subsample_res, axis=[1] * self.n_dims + ) + dist = tf.math.sqrt( + tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True) + ) return [vol, dist] @staticmethod def _single_down_interpn(inputs): - return nrn_utils.interpn(inputs[0], inputs[1], interp_method='nearest') + return nrn_utils.interpn(inputs[0], inputs[1], interp_method="nearest") @staticmethod def _single_up_interpn(inputs): - return nrn_utils.interpn(inputs[0], inputs[1], interp_method='linear') + return nrn_utils.interpn(inputs[0], inputs[1], interp_method="linear") def compute_output_shape(self, input_shape): output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]]) @@ -1015,7 +1300,14 @@ class BiasFieldCorruption(Layer): :param prob: probability to apply this bias field corruption. 
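    A short usage sketch (the input shape and parameter values are illustrative
    only, matching the defaults above):

        import keras.layers as KL

        image = KL.Input((160, 160, 160, 1))  # hypothetical input shape
        corrupted = BiasFieldCorruption(bias_field_std=0.5, bias_scale=0.025,
                                        same_bias_for_all_channels=False,
                                        prob=0.95)(image)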
""" - def __init__(self, bias_field_std=.5, bias_scale=.025, same_bias_for_all_channels=False, prob=0.95, **kwargs): + def __init__( + self, + bias_field_std=0.5, + bias_scale=0.025, + same_bias_for_all_channels=False, + prob=0.95, + **kwargs + ): # input shape self.several_inputs = False @@ -1056,7 +1348,9 @@ def build(self, input_shape): # sampling shapes self.std_shape = [1] * (self.n_dims + 1) - self.small_bias_shape = utils.get_resample_shape(self.inshape[0][1:self.n_dims + 1], self.bias_scale, 1) + self.small_bias_shape = utils.get_resample_shape( + self.inshape[0][1 : self.n_dims + 1], self.bias_scale, 1 + ) if not self.same_bias_for_all_channels: self.std_shape[-1] = self.n_channels self.small_bias_shape[-1] = self.n_channels @@ -1073,14 +1367,24 @@ def call(self, inputs, **kwargs): # sampling shapes batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0] - std_shape = tf.concat([batchsize, tf.convert_to_tensor(self.std_shape, dtype='int32')], 0) - bias_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype='int32')], axis=0) + std_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.std_shape, dtype="int32")], 0 + ) + bias_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype="int32")], + axis=0, + ) # sample small bias field - bias_field = tf.random.normal(bias_shape, stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std)) + bias_field = tf.random.normal( + bias_shape, + stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std), + ) # resize bias field and take exponential - bias_field = nrn_layers.Resize(size=self.inshape[0][1:self.n_dims + 1], interp_method='linear')(bias_field) + bias_field = nrn_layers.Resize( + size=self.inshape[0][1 : self.n_dims + 1], interp_method="linear" + )(bias_field) bias_field = tf.math.exp(bias_field) # apply bias field with predefined probability @@ -1089,9 +1393,14 @@ def call(self, inputs, **kwargs): else: rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob)) if self.several_inputs: - return [K.switch(rand_trans, tf.math.multiply(bias_field, v), v) for v in inputs] + return [ + K.switch(rand_trans, tf.math.multiply(bias_field, v), v) + for v in inputs + ] else: - return K.switch(rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0]) + return K.switch( + rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0] + ) else: return inputs @@ -1125,8 +1434,19 @@ class IntensityAugmentation(Layer): :param prob_gamma: probability to apply gamma augmentation """ - def __init__(self, noise_std=0, clip=0, normalise=True, norm_perc=0, gamma_std=0, contrast_inversion=False, - separate_channels=True, prob_noise=0.95, prob_gamma=1, **kwargs): + def __init__( + self, + noise_std=0, + clip=0, + normalise=True, + norm_perc=0, + gamma_std=0, + contrast_inversion=False, + separate_channels=True, + prob_noise=0.95, + prob_gamma=1, + **kwargs + ): # shape attributes self.n_dims = None @@ -1166,17 +1486,29 @@ def build(self, input_shape): self.n_dims = len(input_shape) - 2 self.n_channels = input_shape[-1] self.flatten_shape = np.prod(np.array(input_shape[1:-1])) - self.flatten_shape = self.flatten_shape * self.n_channels if not self.separate_channels else self.flatten_shape - self.expand_minmax_dim = self.n_dims if self.separate_channels else self.n_dims + 1 - self.one = tf.ones([1], dtype='int32') + self.flatten_shape = ( + self.flatten_shape * self.n_channels + if not self.separate_channels + else self.flatten_shape + ) + self.expand_minmax_dim = ( + 
self.n_dims if self.separate_channels else self.n_dims + 1 + ) + self.one = tf.ones([1], dtype="int32") if self.clip: self.clip_values = utils.reformat_to_list(self.clip) - self.clip_values = self.clip_values if len(self.clip_values) == 2 else [0, self.clip_values[0]] + self.clip_values = ( + self.clip_values + if len(self.clip_values) == 2 + else [0, self.clip_values[0]] + ) else: self.clip_values = None if self.norm_perc: self.perc = utils.reformat_to_list(self.norm_perc) - self.perc = self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + self.perc = ( + self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + ) else: self.perc = None @@ -1188,7 +1520,9 @@ def call(self, inputs, **kwargs): # prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately) batchsize = tf.split(tf.shape(inputs), [1, -1])[0] if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion: - sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32')], 0) + sample_shape = tf.concat( + [batchsize, tf.ones([self.n_dims], dtype="int32")], 0 + ) if self.separate_channels: sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0) else: @@ -1202,13 +1536,21 @@ def call(self, inputs, **kwargs): if self.separate_channels: noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev) else: - noise = tf.random.normal(tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev) - noise = tf.tile(noise, tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels])) + noise = tf.random.normal( + tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev + ) + noise = tf.tile( + noise, + tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels]), + ) if self.prob_noise == 1: inputs = inputs + noise else: - inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), - inputs + noise, inputs) + inputs = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + inputs + noise, + inputs, + ) # clip images to given values if self.clip_values is not None: @@ -1219,12 +1561,23 @@ def call(self, inputs, **kwargs): # define robust min and max by sorting values and taking percentile if self.perc is not None: if self.separate_channels: - shape = tf.concat([batchsize, self.flatten_shape * self.one, self.n_channels * self.one], 0) + shape = tf.concat( + [ + batchsize, + self.flatten_shape * self.one, + self.n_channels * self.one, + ], + 0, + ) else: shape = tf.concat([batchsize, self.flatten_shape * self.one], 0) intensities = tf.sort(tf.reshape(inputs, shape), axis=1) m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...] - M = intensities[:, min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), ...] 
+ M = intensities[ + :, + min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), + ..., + ] # simple min and max else: m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) @@ -1241,8 +1594,11 @@ def call(self, inputs, **kwargs): if self.prob_gamma == 1: inputs = tf.math.pow(inputs, tf.math.exp(gamma)) else: - inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), - tf.math.pow(inputs, tf.math.exp(gamma)), inputs) + inputs = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), + tf.math.pow(inputs, tf.math.exp(gamma)), + inputs, + ) # apply random contrast inversion if self.contrast_inversion: @@ -1250,8 +1606,12 @@ def call(self, inputs, **kwargs): split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1) split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1) inverted_channel = list() - for (channel, invert) in zip(split_channels, split_rand_invert): - inverted_channel.append(tf.map_fn(self._single_invert, [channel, invert], dtype=channel.dtype)) + for channel, invert in zip(split_channels, split_rand_invert): + inverted_channel.append( + tf.map_fn( + self._single_invert, [channel, invert], dtype=channel.dtype + ) + ) inputs = tf.concat(inverted_channel, -1) return inputs @@ -1279,13 +1639,15 @@ class DiceLoss(Layer): probabilities sum to 1 at each voxel location). Default is True. """ - def __init__(self, - class_weights=None, - boundary_weights=0, - boundary_dist=3, - skip_background=True, - enable_checks=True, - **kwargs): + def __init__( + self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs + ): self.class_weights = class_weights self.dynamic_weighting = False @@ -1310,13 +1672,17 @@ def get_config(self): def build(self, input_shape): # get shape - assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' - assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + assert ( + len(input_shape) == 2 + ), "DiceLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." 
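        # the two checks above guarantee that gt and pred share one static shape,
        # from which the spatial axes and the pooling layer used for boundary
        # weighting are derived below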
inshape = input_shape[0][1:] n_dims = len(inshape[:-1]) n_labels = inshape[-1] self.spatial_axes = list(range(1, n_dims + 1)) - self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.avg_pooling_layer = getattr(keras.layers, "AvgPool%dD" % n_dims) self.skip_background = False if n_labels == 1 else self.skip_background # build tensor with class weights @@ -1324,8 +1690,10 @@ def build(self, input_shape): if self.class_weights == -1: self.dynamic_weighting = True else: - class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) - class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) self.built = True @@ -1336,9 +1704,27 @@ def call(self, inputs, **kwargs): # make sure tensors are probabilistic gt = inputs[0] pred = inputs[1] - if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps - gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) - pred = K.clip(pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = K.clip( + gt + / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) + pred = K.clip( + pred + / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) # compute dice loss for each label top = 2 * gt * pred @@ -1346,11 +1732,18 @@ def call(self, inputs, **kwargs): # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) if self.boundary_weights: - avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) - boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + avg = self.avg_pooling_layer( + pool_size=2 * self.boundary_dist + 1, strides=1, padding="same" + )(gt) + boundaries = tf.cast(avg > 0.0, "float32") * tf.cast( + avg < (1 / len(self.spatial_axes) - 1e-4), "float32" + ) if self.skip_background: boundaries_channels = tf.unstack(boundaries, axis=-1) - boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundaries = tf.stack( + [tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], + axis=-1, + ) boundary_weights_tensor = 1 + self.boundary_weights * boundaries top *= boundary_weights_tensor bottom *= boundary_weights_tensor @@ -1360,17 +1753,25 @@ def call(self, inputs, **kwargs): # compute loss top = tf.math.reduce_sum(top, self.spatial_axes) bottom = tf.math.reduce_sum(bottom, self.spatial_axes) - dice = (top + tf.keras.backend.epsilon()) / (bottom + tf.keras.backend.epsilon()) + dice = (top + tf.keras.backend.epsilon()) / ( + bottom + tf.keras.backend.epsilon() + ) loss = 1 - dice # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). 
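        # (with class_weights=-1, each label is weighted by the inverse of its
        # volume in gt, so small structures count as much as large ones; the
        # weights are renormalised to sum to 1 before the weighted sum below)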
-        if self.dynamic_weighting:  # the weight of a class is the inverse of its volume in the gt
-            if boundary_weights_tensor is not None:  # we account for the boundary weighting to compute volume
-                self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes)
+        if (
+            self.dynamic_weighting
+        ):  # the weight of a class is the inverse of its volume in the gt
+            if (
+                boundary_weights_tensor is not None
+            ):  # we account for the boundary weighting to compute volume
+                self.class_weights_tens = 1 / tf.reduce_sum(
+                    gt * boundary_weights_tensor, self.spatial_axes
+                )
             else:
                 self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes)
         if self.class_weights_tens is not None:
-            self. class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1)
+            self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1)
             loss = tf.reduce_sum(loss * self.class_weights_tens, -1)

         return tf.math.reduce_mean(loss)
@@ -1399,8 +1800,12 @@ def get_config(self):
         return config

     def build(self, input_shape):
-        assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.'
-        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
+        assert (
+            len(input_shape) == 2
+        ), "WeightedL2Loss expects 2 inputs to compute the L2 loss."
+        assert (
+            input_shape[0] == input_shape[1]
+        ), "the two inputs must have the same shape."
         self.n_labels = input_shape[0][-1]
         self.built = True
         super(WeightedL2Loss, self).build(input_shape)

@@ -1409,7 +1814,9 @@ def call(self, inputs, **kwargs):
         gt = inputs[0]
         pred = inputs[1]
         weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1)
-        return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (K.sum(weights) * self.n_labels)
+        return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (
+            K.sum(weights) * self.n_labels
+        )

     def compute_output_shape(self, input_shape):
         return [[]]

@@ -1433,13 +1840,15 @@ class CrossEntropyLoss(Layer):
     probabilities sum to 1 at each voxel location). Default is True.
     """

-    def __init__(self,
-                 class_weights=None,
-                 boundary_weights=0,
-                 boundary_dist=3,
-                 skip_background=True,
-                 enable_checks=True,
-                 **kwargs):
+    def __init__(
+        self,
+        class_weights=None,
+        boundary_weights=0,
+        boundary_dist=3,
+        skip_background=True,
+        enable_checks=True,
+        **kwargs
+    ):

         self.class_weights = class_weights
         self.dynamic_weighting = False
@@ -1464,13 +1873,17 @@ def get_config(self):
     def build(self, input_shape):

         # get shape
-        assert len(input_shape) == 2, 'CrossEntropy expects 2 inputs to compute the Dice loss.'
-        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
+        assert (
+            len(input_shape) == 2
+        ), "CrossEntropyLoss expects 2 inputs to compute the cross-entropy loss."
+        assert (
+            input_shape[0] == input_shape[1]
+        ), "the two inputs must have the same shape."
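        # note: unlike in DiceLoss, fixed class weights are expanded over every
        # spatial axis below, so that they broadcast against the voxel-wise ce map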
inshape = input_shape[0][1:]
         n_dims = len(inshape[:-1])
         n_labels = inshape[-1]
         self.spatial_axes = list(range(1, n_dims + 1))
-        self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims)
+        self.avg_pooling_layer = getattr(keras.layers, "AvgPool%dD" % n_dims)
         self.skip_background = False if n_labels == 1 else self.skip_background

         # build tensor with class weights
@@ -1478,9 +1891,13 @@ def build(self, input_shape):
             if self.class_weights == -1:
                 self.dynamic_weighting = True
             else:
-                class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels)
-                class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32')
-                self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, [0] * (1 + n_dims))
+                class_weights_tens = utils.reformat_to_list(
+                    self.class_weights, n_labels
+                )
+                class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32")
+                self.class_weights_tens = l2i_et.expand_dims(
+                    class_weights_tens, [0] * (1 + n_dims)
+                )

         self.built = True
         super(CrossEntropyLoss, self).build(input_shape)

@@ -1490,30 +1907,58 @@ def call(self, inputs, **kwargs):
         # make sure tensors are probabilistic
         gt = inputs[0]
         pred = inputs[1]
-        if self.enable_checks:  # disabling is useful to, e.g., use incomplete label maps
-            gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1)
-            pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon())
-            pred = K.clip(pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())  # to avoid log(0)
+        if (
+            self.enable_checks
+        ):  # disabling is useful to, e.g., use incomplete label maps
+            gt = K.clip(
+                gt
+                / (
+                    tf.math.reduce_sum(gt, axis=-1, keepdims=True)
+                    + tf.keras.backend.epsilon()
+                ),
+                0,
+                1,
+            )
+            pred = pred / (
+                tf.math.reduce_sum(pred, axis=-1, keepdims=True)
+                + tf.keras.backend.epsilon()
+            )
+            pred = K.clip(
+                pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()
+            )  # to avoid log(0)

         # compare prediction/target, ce has the same shape as the input tensors
         ce = -gt * tf.math.log(pred)

         # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute the loss)
         if self.boundary_weights:
-            avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt)
-            boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32')
+            avg = self.avg_pooling_layer(
+                pool_size=2 * self.boundary_dist + 1, strides=1, padding="same"
+            )(gt)
+            boundaries = tf.cast(avg > 0.0, "float32") * tf.cast(
+                avg < (1 / len(self.spatial_axes) - 1e-4), "float32"
+            )
             if self.skip_background:
                 boundaries_channels = tf.unstack(boundaries, axis=-1)
-                boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1)
+                boundaries = tf.stack(
+                    [tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:],
+                    axis=-1,
+                )
             boundary_weights_tensor = 1 + self.boundary_weights * boundaries
             ce *= boundary_weights_tensor
         else:
             boundary_weights_tensor = None

         # apply class weighting across labels. By the end of this, ce still has the same shape as the input tensors.
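        # (inverse-volume weighting again; in the boundary-weighted branch below
        # the sum is taken with keepdims=True, so the resulting weights broadcast
        # against the voxel-wise ce map)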
-        if self.dynamic_weighting:  # the weight of a class is the inverse of its volume in the gt
-            if boundary_weights_tensor is not None:  # we account for the boundary weighting to compute volume
-                self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes, True)
+        if (
+            self.dynamic_weighting
+        ):  # the weight of a class is the inverse of its volume in the gt
+            if (
+                boundary_weights_tensor is not None
+            ):  # we account for the boundary weighting to compute volume
+                self.class_weights_tens = 1 / tf.reduce_sum(
+                    gt * boundary_weights_tensor, self.spatial_axes, True
+                )
             else:
                 self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes)
         if self.class_weights_tens is not None:
@@ -1560,8 +2005,12 @@ def get_config(self):
         return config

     def build(self, input_shape):

         # get shape
-        assert len(input_shape) == 2, 'MomentLoss expects 2 inputs to compute the Dice loss.'
-        assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.'
+        assert (
+            len(input_shape) == 2
+        ), "MomentLoss expects 2 inputs to compute the moment loss."
+        assert (
+            input_shape[0] == input_shape[1]
+        ), "the two inputs must have the same shape."
         inshape = input_shape[0][1:]
         n_dims = len(inshape[:-1])
         n_labels = inshape[-1]
@@ -1569,15 +2018,20 @@ def build(self, input_shape):

         # build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan)
         self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1)
-        self.coordinates = tf.cast(l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), 'float32')
+        self.coordinates = tf.cast(
+            l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0),
+            "float32",
+        )

         # build tensor with class weights
         if self.class_weights is not None:
             if self.class_weights == -1:
                 self.dynamic_weighting = True
             else:
-                class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels)
-                class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32')
+                class_weights_tens = utils.reformat_to_list(
+                    self.class_weights, n_labels
+                )
+                class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32")
                 self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0)

         self.built = True
@@ -1588,17 +2042,31 @@ def call(self, inputs, **kwargs):

         # make sure tensors are probabilistic
         gt = inputs[0]  # (B, dim1, dim2, ..., dimN, nchan)
         pred = inputs[1]
-        if self.enable_checks:  # disabling is useful to, e.g., use incomplete label maps
-            gt = gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon())
-            pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon())
+        if (
+            self.enable_checks
+        ):  # disabling is useful to, e.g., use incomplete label maps
+            gt = gt / (
+                tf.math.reduce_sum(gt, axis=-1, keepdims=True)
+                + tf.keras.backend.epsilon()
+            )
+            pred = pred / (
+                tf.math.reduce_sum(pred, axis=-1, keepdims=True)
+                + tf.keras.backend.epsilon()
+            )

         # compute loss
         gt_mean_coordinates = self._mean_coordinates(gt)  # (B, ndim, nchan)
         pred_mean_coordinates = self._mean_coordinates(pred)
-        loss = tf.math.sqrt(tf.reduce_sum(tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1))  # (B, nchan)
+        loss = tf.math.sqrt(
+            tf.reduce_sum(
+                tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1
+            )
+        )  # (B, nchan)

         # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels).
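        # (at this point loss holds, for each batch item and each label, the
        # Euclidean distance in voxels between the centres of mass of gt and pred)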
- if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) if self.class_weights_tens is not None: self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) @@ -1607,9 +2075,15 @@ def call(self, inputs, **kwargs): return tf.math.reduce_mean(loss) def _mean_coordinates(self, tensor): - tensor = l2i_et.expand_dims(tensor, axis=-2) # (B, dim1, dim2, ..., dimN, 1, nchan) - numerator = tf.reduce_sum(tensor * self.coordinates, axis=self.spatial_axes) # (B, ndim, nchan) - denominator = tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + tensor = l2i_et.expand_dims( + tensor, axis=-2 + ) # (B, dim1, dim2, ..., dimN, 1, nchan) + numerator = tf.reduce_sum( + tensor * self.coordinates, axis=self.spatial_axes + ) # (B, ndim, nchan) + denominator = ( + tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + ) return numerator / denominator def compute_output_shape(self, input_shape): @@ -1633,7 +2107,9 @@ class ResetValuesToZero(Layer): """ def __init__(self, values, **kwargs): - assert values is not None, 'please provide correct list of values, received None' + assert ( + values is not None + ), "please provide correct list of values, received None" self.values = utils.reformat_to_list(values) self.values_tens = None self.n_values = len(values) @@ -1652,7 +2128,9 @@ def build(self, input_shape): def call(self, inputs, **kwargs): values = tf.cast(self.values_tens, dtype=inputs.dtype) for i in range(self.n_values): - inputs = tf.where(tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs) + inputs = tf.where( + tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs + ) return inputs @@ -1681,12 +2159,15 @@ def get_config(self): return config def build(self, input_shape): - self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32') + self.lut = tf.convert_to_tensor( + utils.get_mapping_lut(self.source_values, dest=self.dest_values), + dtype="int32", + ) self.built = True super(ConvertLabels, self).build(input_shape) def call(self, inputs, **kwargs): - return tf.gather(self.lut, tf.cast(inputs, dtype='int32')) + return tf.gather(self.lut, tf.cast(inputs, dtype="int32")) class PadAroundCentre(Layer): @@ -1725,19 +2206,32 @@ def build(self, input_shape): shape[-1] = 0 if self.pad_margin is not None: - assert self.pad_shape is None, 'please do not provide a padding shape and margin at the same time.' + assert ( + self.pad_shape is None + ), "please do not provide a padding shape and margin at the same time." # reformat padding margins - pad = np.transpose(np.array([[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] * 2)) - self.pad_margin_tens = tf.convert_to_tensor(pad, dtype='int32') + pad = np.transpose( + np.array( + [[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] + * 2 + ) + ) + self.pad_margin_tens = tf.convert_to_tensor(pad, dtype="int32") elif self.pad_shape is not None: - assert self.pad_margin is None, 'please do not provide a padding shape and margin at the same time.' + assert ( + self.pad_margin is None + ), "please do not provide a padding shape and margin at the same time." 
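            # the target shape below is lower-bounded by the current tensor shape,
            # then split into two margins per axis so the input stays centred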
# pad shape - tensor_shape = tf.cast(tf.convert_to_tensor(shape), 'int32') - self.pad_shape_tens = np.array([0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0]) - self.pad_shape_tens = tf.convert_to_tensor(self.pad_shape_tens, dtype='int32') + tensor_shape = tf.cast(tf.convert_to_tensor(shape), "int32") + self.pad_shape_tens = np.array( + [0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0] + ) + self.pad_shape_tens = tf.convert_to_tensor( + self.pad_shape_tens, dtype="int32" + ) self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens) # padding margin @@ -1746,13 +2240,17 @@ def build(self, input_shape): self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1) else: - raise Exception('please either provide a padding shape or a padding margin.') + raise Exception( + "please either provide a padding shape or a padding margin." + ) self.built = True super(PadAroundCentre, self).build(input_shape) def call(self, inputs, **kwargs): - return tf.pad(inputs, self.pad_margin_tens, mode='CONSTANT', constant_values=self.value) + return tf.pad( + inputs, self.pad_margin_tens, mode="CONSTANT", constant_values=self.value + ) class MaskEdges(Layer): @@ -1796,8 +2294,10 @@ class MaskEdges(Layer): """ def __init__(self, axes, boundaries, prob_mask=1, **kwargs): - self.axes = utils.reformat_to_list(axes, dtype='int') - self.boundaries = utils.reformat_to_n_channels_array(boundaries, n_dims=4, n_channels=len(self.axes)) + self.axes = utils.reformat_to_list(axes, dtype="int") + self.boundaries = utils.reformat_to_n_channels_array( + boundaries, n_dims=4, n_channels=len(self.axes) + ) self.prob_mask = prob_mask self.inputshape = None super(MaskEdges, self).__init__(**kwargs) @@ -1822,26 +2322,42 @@ def call(self, inputs, **kwargs): # select restricting indices axis_boundaries = self.boundaries[i, :] - idx1 = tf.math.round(tf.random.uniform([1], - minval=axis_boundaries[0] * self.inputshape[axis], - maxval=axis_boundaries[1] * self.inputshape[axis])) - idx2 = tf.math.round(tf.random.uniform([1], - minval=axis_boundaries[2] * self.inputshape[axis], - maxval=axis_boundaries[3] * self.inputshape[axis] - 1) - idx1) + idx1 = tf.math.round( + tf.random.uniform( + [1], + minval=axis_boundaries[0] * self.inputshape[axis], + maxval=axis_boundaries[1] * self.inputshape[axis], + ) + ) + idx2 = tf.math.round( + tf.random.uniform( + [1], + minval=axis_boundaries[2] * self.inputshape[axis], + maxval=axis_boundaries[3] * self.inputshape[axis] - 1, + ) + - idx1 + ) idx3 = self.inputshape[axis] - idx1 - idx2 - split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype='int32') + split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype="int32") # update mask split_list = tf.split(inputs, split_idx, axis=axis) - tmp_mask = tf.concat([tf.zeros_like(split_list[0]), - tf.ones_like(split_list[1]), - tf.zeros_like(split_list[2])], axis=axis) + tmp_mask = tf.concat( + [ + tf.zeros_like(split_list[0]), + tf.ones_like(split_list[1]), + tf.zeros_like(split_list[2]), + ], + axis=axis, + ) mask = mask * tmp_mask # mask second_channel - tensor = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), - inputs * mask, - inputs) + tensor = K.switch( + tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), + inputs * mask, + inputs, + ) return [tensor, mask] @@ -1851,11 +2367,15 @@ def compute_output_shape(self, input_shape): class ImageGradients(Layer): - def __init__(self, gradient_type='sobel', return_magnitude=False, 
**kwargs): + def __init__(self, gradient_type="sobel", return_magnitude=False, **kwargs): self.gradient_type = gradient_type - assert (self.gradient_type == 'sobel') | (self.gradient_type == '1-step_diff'), \ - 'gradient_type should be either sobel or 1-step_diff, had %s' % self.gradient_type + assert (self.gradient_type == "sobel") | ( + self.gradient_type == "1-step_diff" + ), ( + "gradient_type should be either sobel or 1-step_diff, had %s" + % self.gradient_type + ) # shape self.n_dims = 0 @@ -1885,10 +2405,10 @@ def build(self, input_shape): self.n_channels = input_shape[-1] # prepare kernel if sobel gradients - if self.gradient_type == 'sobel': + if self.gradient_type == "sobel": self.kernels = l2i_et.sobel_kernels(self.n_dims) self.stride = [1] * (self.n_dims + 2) - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) else: self.kernels = self.convnd = self.stride = None @@ -1902,14 +2422,24 @@ def call(self, inputs, **kwargs): gradients = list() # sobel method - if self.gradient_type == 'sobel': + if self.gradient_type == "sobel": # get sobel gradients in each direction for n in range(self.n_dims): gradient = image # apply 1D kernel in each direction (sobel kernels are separable), instead of applying a nD kernel for k in self.kernels[n]: - gradient = tf.concat([self.convnd(tf.expand_dims(gradient[..., n], -1), k, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + gradient = tf.concat( + [ + self.convnd( + tf.expand_dims(gradient[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) gradients.append(gradient) # 1-step method, only supports 2 and 3D @@ -1926,18 +2456,28 @@ def call(self, inputs, **kwargs): gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :]) # dz else: - raise Exception('ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD' % self.n_dims) + raise Exception( + "ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD" + % self.n_dims + ) # pad with zeros to return tensors of the same shape as input for i in range(self.n_dims): tmp_shape = list(self.shape) tmp_shape[i] = 1 - zeros = tf.zeros(tf.concat([batchsize, tf.convert_to_tensor(tmp_shape, dtype='int32')], 0), image.dtype) + zeros = tf.zeros( + tf.concat( + [batchsize, tf.convert_to_tensor(tmp_shape, dtype="int32")], 0 + ), + image.dtype, + ) gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1) # compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis if self.return_magnitude: - gradients = tf.sqrt(tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1)) + gradients = tf.sqrt( + tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1) + ) else: gradients = tf.concat(gradients, axis=-1) @@ -1965,12 +2505,22 @@ class RandomDilationErosion(Layer): choice to either return the eroded label map or the mask (return_mask=True) """ - def __init__(self, min_factor, max_factor, max_factor_dilate=None, prob=1, operation='random', return_mask=False, - **kwargs): + def __init__( + self, + min_factor, + max_factor, + max_factor_dilate=None, + prob=1, + operation="random", + return_mask=False, + **kwargs + ): self.min_factor = min_factor self.max_factor = max_factor - self.max_factor_dilate = max_factor_dilate if max_factor_dilate is not None else self.max_factor + self.max_factor_dilate = ( + max_factor_dilate if max_factor_dilate is not None else self.max_factor + ) self.prob = prob self.operation = 
operation self.return_mask = return_mask @@ -1998,7 +2548,7 @@ def build(self, input_shape): self.n_channels = self.inshape[-1] # prepare convolution - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) self.built = True super(RandomDilationErosion, self).build(input_shape) @@ -2007,30 +2557,42 @@ def call(self, inputs, **kwargs): # sample probability of applying operation. If random negative is erosion and positive is dilation batchsize = tf.split(tf.shape(inputs), [1, -1])[0] - shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype='int32')], axis=0) - if self.operation == 'dilation': + shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype="int32")], axis=0) + if self.operation == "dilation": prob = tf.random.uniform(shape, 0, 1) - elif self.operation == 'erosion': + elif self.operation == "erosion": prob = tf.random.uniform(shape, -1, 0) - elif self.operation == 'random': + elif self.operation == "random": prob = tf.random.uniform(shape, -1, 1) else: - raise ValueError("operation should either be 'dilation' 'erosion' or 'random', had %s" % self.operation) + raise ValueError( + "operation should either be 'dilation' 'erosion' or 'random', had %s" + % self.operation + ) # build kernel if self.min_factor == self.max_factor: - dist_threshold = self.min_factor * tf.ones(shape, dtype='int32') + dist_threshold = self.min_factor * tf.ones(shape, dtype="int32") else: - if (self.max_factor == self.max_factor_dilate) | (self.operation != 'random'): - dist_threshold = tf.random.uniform(shape, minval=self.min_factor, maxval=self.max_factor, dtype='int32') + if (self.max_factor == self.max_factor_dilate) | ( + self.operation != "random" + ): + dist_threshold = tf.random.uniform( + shape, minval=self.min_factor, maxval=self.max_factor, dtype="int32" + ) else: - dist_threshold = tf.cast(tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), dtype='int32') - kernel = l2i_et.unit_kernel(dist_threshold, self.n_dims, max_dist_threshold=self.max_factor) + dist_threshold = tf.cast( + tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), + dtype="int32", + ) + kernel = l2i_et.unit_kernel( + dist_threshold, self.n_dims, max_dist_threshold=self.max_factor + ) # convolve input mask with kernel according to given probability - mask = tf.cast(tf.cast(inputs, dtype='bool'), dtype='float32') + mask = tf.cast(tf.cast(inputs, dtype="bool"), dtype="float32") mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32) - mask = tf.cast(mask, 'bool') + mask = tf.cast(mask, "bool") if self.return_mask: return mask @@ -2038,22 +2600,61 @@ def call(self, inputs, **kwargs): return inputs * tf.cast(mask, dtype=inputs.dtype) def _sample_factor(self, inputs): - return tf.cast(K.switch(K.less(tf.squeeze(inputs[0]), 0), - tf.random.uniform((1,), self.min_factor, self.max_factor, dtype='int32'), - tf.random.uniform((1,), self.min_factor, self.max_factor_dilate, dtype='int32')), - dtype='float32') + return tf.cast( + K.switch( + K.less(tf.squeeze(inputs[0]), 0), + tf.random.uniform( + (1,), self.min_factor, self.max_factor, dtype="int32" + ), + tf.random.uniform( + (1,), self.min_factor, self.max_factor_dilate, dtype="int32" + ), + ), + dtype="float32", + ) def _single_blur(self, inputs): # dilate... 
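        # (per sample: when the drawn value activates the operation, the binary
        # mask is convolved with the unit kernel and re-thresholded at 0.01, which
        # grows it; the mirrored branch below erodes instead, by dilating the
        # complement 1 - mask and inverting the result)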
- new_mask = K.switch(K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], - [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), - inputs[0]) + new_mask = K.switch( + K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), + tf.cast( + tf.greater( + tf.squeeze( + self.convnd( + tf.expand_dims(inputs[0], 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ), + axis=0, + ), + 0.01, + ), + dtype="float32", + ), + inputs[0], + ) # ...or erode - new_mask = K.switch(K.less(tf.squeeze(inputs[2]), - (1 - self.prob + 0.001)), - 1 - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(1 - new_mask, 0), inputs[1], - [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), - new_mask) + new_mask = K.switch( + K.less(tf.squeeze(inputs[2]), -(1 - self.prob + 0.001)), + 1 + - tf.cast( + tf.greater( + tf.squeeze( + self.convnd( + tf.expand_dims(1 - new_mask, 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ), + axis=0, + ), + 0.01, + ), + dtype="float32", + ), + new_mask, + ) return new_mask def compute_output_shape(self, input_shape): diff --git a/nobrainer/ext/lab2im/utils.py b/nobrainer/ext/lab2im/utils.py index 86c64675..c0f08e95 100644 --- a/nobrainer/ext/lab2im/utils.py +++ b/nobrainer/ext/lab2im/utils.py @@ -55,20 +55,19 @@ License. """ - -import os +from datetime import timedelta import glob import math -import time +import os import pickle -import numpy as np -import nibabel as nib -import tensorflow as tf -import keras.layers as KL +import time + import keras.backend as K -from datetime import timedelta +import keras.layers as KL +import nibabel as nib +import numpy as np from scipy.ndimage.morphology import distance_transform_edt - +import tensorflow as tf # ---------------------------------------------- loading/saving functions ---------------------------------------------- @@ -86,9 +85,11 @@ def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=Non The returned affine matrix is also given in this new space. Must be a numpy array of dimension 4x4. :return: the volume, with corresponding affine matrix and header if im_only is False. 
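    Example (a sketch, the path is hypothetical):
        volume, aff, header = load_volume("/data/subject1/t1.nii.gz",
                                          im_only=False, dtype="float32")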
""" - assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume + assert path_volume.endswith((".nii", ".nii.gz", ".mgz", ".npz")), ( + "Unknown data file: %s" % path_volume + ) - if path_volume.endswith(('.nii', '.nii.gz', '.mgz')): + if path_volume.endswith((".nii", ".nii.gz", ".mgz")): x = nib.load(path_volume) if squeeze: volume = np.squeeze(x.get_fdata()) @@ -97,21 +98,26 @@ def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=Non aff = x.affine header = x.header else: # npz - volume = np.load(path_volume)['vol_data'] + volume = np.load(path_volume)["vol_data"] if squeeze: volume = np.squeeze(volume) aff = np.eye(4) header = nib.Nifti1Header() if dtype is not None: - if 'int' in dtype: + if "int" in dtype: volume = np.round(volume) volume = volume.astype(dtype=dtype) # align image to reference affine matrix if aff_ref is not None: - from ext.lab2im import edit_volumes # the import is done here to avoid import loops + from ext.lab2im import ( # the import is done here to avoid import loops + edit_volumes, + ) + n_dims, _ = get_dims(list(volume.shape), max_channels=10) - volume, aff = edit_volumes.align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims) + volume, aff = edit_volumes.align_volume_to_ref( + volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims + ) if im_only: return volume @@ -134,18 +140,20 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): """ mkdir(os.path.dirname(path)) - if '.npz' in path: + if ".npz" in path: np.savez_compressed(path, vol_data=volume) else: if header is None: header = nib.Nifti1Header() if isinstance(aff, str): - if aff == 'FS': - aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) + if aff == "FS": + aff = np.array( + [[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]] + ) elif aff is None: aff = np.eye(4) if dtype is not None: - if 'int' in dtype: + if "int" in dtype: volume = np.round(volume) volume = volume.astype(dtype=dtype) nifty = nib.Nifti1Image(volume, aff, header) @@ -180,16 +188,19 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels im_shape = im_shape[:n_dims] # get labels res - if '.nii' in path_volume: - data_res = np.array(header['pixdim'][1:n_dims + 1]) - elif '.mgz' in path_volume: - data_res = np.array(header['delta']) # mgz image + if ".nii" in path_volume: + data_res = np.array(header["pixdim"][1 : n_dims + 1]) + elif ".mgz" in path_volume: + data_res = np.array(header["delta"]) # mgz image else: data_res = np.array([1.0] * n_dims) # align to given affine matrix if aff_ref is not None: - from ext.lab2im import edit_volumes # the import is done here to avoid import loops + from ext.lab2im import ( # the import is done here to avoid import loops + edit_volumes, + ) + ras_axes = edit_volumes.get_ras_axes(aff, n_dims=n_dims) ras_axes_ref = edit_volumes.get_ras_axes(aff_ref, n_dims=n_dims) im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims) @@ -206,7 +217,9 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels return im_shape, aff, n_dims, n_channels, header, data_res -def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_sort=False): +def get_list_labels( + label_list=None, labels_dir=None, save_label_list=None, FS_sort=False +): """This function reads or computes a list of all label values used in a set of label maps. 
It can also sort all labels according to FreeSurfer lut. :param label_list: (optional) already computed label_list. Can be a sequence, a 1d numpy array, or the path to @@ -224,32 +237,104 @@ def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_s # load label list if previously computed if label_list is not None: - label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + label_list = np.array( + reformat_to_list(label_list, load_as_numpy=True, dtype="int") + ) # compute label list from all label files elif labels_dir is not None: - print('Compiling list of unique labels') + print("Compiling list of unique labels") # go through all labels files and compute unique list of labels labels_paths = list_images_in_folder(labels_dir) label_list = np.empty(0) - loop_info = LoopInfo(len(labels_paths), 10, 'processing', print_time=True) + loop_info = LoopInfo(len(labels_paths), 10, "processing", print_time=True) for lab_idx, path in enumerate(labels_paths): loop_info.update(lab_idx) - y = load_volume(path, dtype='int32') + y = load_volume(path, dtype="int32") y_unique = np.unique(y) - label_list = np.unique(np.concatenate((label_list, y_unique))).astype('int') + label_list = np.unique(np.concatenate((label_list, y_unique))).astype("int") else: - raise Exception('either label_list, path_label_list or labels_dir should be provided') + raise Exception( + "either label_list, path_label_list or labels_dir should be provided" + ) # sort labels in neutral/left/right according to FS labels n_neutral_labels = 0 if FS_sort: - neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108, - 109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, - 251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, - 502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530, - 531, 532, 533, 534, 535, 536, 537] + neutral_FS_labels = [ + 0, + 14, + 15, + 16, + 21, + 22, + 23, + 24, + 72, + 77, + 80, + 85, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 165, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 251, + 252, + 253, + 254, + 255, + 258, + 259, + 260, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 502, + 506, + 507, + 508, + 509, + 511, + 512, + 514, + 515, + 516, + 517, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + ] neutral = list() left = list() right = list() @@ -257,19 +342,36 @@ def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_s if la in neutral_FS_labels: if la not in neutral: neutral.append(la) - elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \ - (la == 865) | (20100 < la < 20110): + elif ( + (0 < la < 14) + | (16 < la < 21) + | (24 < la < 40) + | (135 < la < 139) + | (1000 <= la <= 1035) + | (la == 865) + | (20100 < la < 20110) + ): if la not in left: left.append(la) - elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \ - (la == 866): + elif ( + (39 < la < 72) + | (162 < la < 165) + | (2000 <= la <= 2035) + | (20000 < la < 20010) + | (la == 139) + | (la == 866) + ): if la not in right: right.append(la) else: - raise Exception('label {} not in our current FS classification, ' - 'please update get_list_labels in utils.py'.format(la)) + raise Exception( + "label {} not in our current FS classification, " + "please update get_list_labels in 
utils.py".format(la) + ) label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)]) - if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)): + if ((len(left) > 0) & (len(right) > 0)) | ( + (len(left) == 0) & (len(right) == 0) + ): n_neutral_labels = len(neutral) else: n_neutral_labels = len(label_list) @@ -288,29 +390,29 @@ def load_array_if_path(var, load_as_numpy=True): """If var is a string and load_as_numpy is True, this function loads the array writen at the path indicated by var. Otherwise it simply returns var as it is.""" if (isinstance(var, str)) & load_as_numpy: - assert os.path.isfile(var), 'No such path: %s' % var + assert os.path.isfile(var), "No such path: %s" % var var = np.load(var) return var def write_pickle(filepath, obj): - """ write a python object with a pickle at a given path""" - with open(filepath, 'wb') as file: + """write a python object with a pickle at a given path""" + with open(filepath, "wb") as file: pickler = pickle.Pickler(file) pickler.dump(obj) def read_pickle(filepath): - """ read a python object with a pickle""" - with open(filepath, 'rb') as file: + """read a python object with a pickle""" + with open(filepath, "rb") as file: unpickler = pickle.Unpickler(file) return unpickler.load() -def write_model_summary(model, filepath='./model_summary.txt', line_length=150): +def write_model_summary(model, filepath="./model_summary.txt", line_length=150): """Write the summary of a keras model at a given path, with a given length for each line""" - with open(filepath, 'w') as fh: - model.summary(print_fn=lambda x: fh.write(x + '\n'), line_length=line_length) + with open(filepath, "w") as fh: + model.summary(print_fn=lambda x: fh.write(x + "\n"), line_length=line_length) # ----------------------------------------------- reformatting functions ----------------------------------------------- @@ -332,7 +434,9 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): if var is None: return None var = load_array_if_path(var, load_as_numpy=load_as_numpy) - if isinstance(var, (int, float, np.int, np.int32, np.int64, np.float, np.float32, np.float64)): + if isinstance( + var, (int, float, np.int, np.int32, np.int64, np.float, np.float32, np.float64) + ): var = [var] elif isinstance(var, tuple): var = list(var) @@ -350,23 +454,29 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): if len(var) == 1: var = var * length elif len(var) != length: - raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, ' - 'had {1}'.format(length, var)) + raise ValueError( + "if var is a list/tuple/numpy array, it should be of length 1 or {0}, " + "had {1}".format(length, var) + ) else: - raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array') + raise TypeError( + "var should be an int, float, tuple, list, numpy array, or path to numpy array" + ) # convert items type if dtype is not None: - if dtype == 'int': + if dtype == "int": var = [int(v) for v in var] - elif dtype == 'float': + elif dtype == "float": var = [float(v) for v in var] - elif dtype == 'bool': + elif dtype == "bool": var = [bool(v) for v in var] - elif dtype == 'str': + elif dtype == "str": var = [str(v) for v in var] else: - raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype)) + raise ValueError( + "dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype) + ) return var @@ -391,9 +501,13 @@ def 
@@ -391,9 +501,13 @@ def reformat_to_n_channels_array(var, n_dims=3, n_channels=1):
         if np.squeeze(var).shape == (n_dims,):
             var = np.tile(var.reshape((1, n_dims)), (n_channels, 1))
         elif var.shape != (n_channels, n_dims):
-            raise ValueError('if array, var should be {0} or {1}'.format((1, n_dims), (n_channels, n_dims)))
+            raise ValueError(
+                "if array, var should be {0} or {1}".format(
+                    (1, n_dims), (n_channels, n_dims)
+                )
+            )
     else:
-        raise TypeError('var should be int, float, list, tuple or ndarray')
+        raise TypeError("var should be int, float, list, tuple or ndarray")
     return np.round(var, 3)
@@ -403,24 +517,32 @@ def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True):
     """List all files with extension nii, nii.gz, mgz, or npz within a folder."""
     basename = os.path.basename(path_dir)
-    if include_single_image & \
-            (('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename)):
-        assert os.path.isfile(path_dir), 'file %s does not exist' % path_dir
+    if include_single_image & (
+        (".nii.gz" in basename)
+        | (".nii" in basename)
+        | (".mgz" in basename)
+        | (".npz" in basename)
+    ):
+        assert os.path.isfile(path_dir), "file %s does not exist" % path_dir
         list_images = [path_dir]
     else:
         if os.path.isdir(path_dir):
-            list_images = sorted(glob.glob(os.path.join(path_dir, '*nii.gz')) +
-                                 glob.glob(os.path.join(path_dir, '*nii')) +
-                                 glob.glob(os.path.join(path_dir, '*.mgz')) +
-                                 glob.glob(os.path.join(path_dir, '*.npz')))
+            list_images = sorted(
+                glob.glob(os.path.join(path_dir, "*.nii.gz"))
+                + glob.glob(os.path.join(path_dir, "*.nii"))
+                + glob.glob(os.path.join(path_dir, "*.mgz"))
+                + glob.glob(os.path.join(path_dir, "*.npz"))
+            )
         else:
-            raise Exception('Folder does not exist: %s' % path_dir)
+            raise Exception("Folder does not exist: %s" % path_dir)
         if check_if_empty:
-            assert len(list_images) > 0, 'no .nii, .nii.gz, .mgz or .npz image could be found in %s' % path_dir
+            assert len(list_images) > 0, (
+                "no .nii, .nii.gz, .mgz or .npz image could be found in %s" % path_dir
+            )
     return list_images


-def list_files(path_dir, whole_path=True, expr=None, cond_type='or'):
+def list_files(path_dir, whole_path=True, expr=None, cond_type="or"):
     """This function returns a list of files contained in a folder, with optional substring filtering.
     :param path_dir: path of a folder
     :param whole_path: (optional) whether to return whole path or just the filenames.
@@ -430,31 +552,46 @@ def list_files(path_dir, whole_path=True, expr=None, cond_type='or'):
     :return: a list of files
     """
     assert isinstance(whole_path, bool), "whole_path should be bool"
-    assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'"
+    assert cond_type in ["or", "and"], "cond_type should be either 'or' or 'and'"
     if whole_path:
-        files_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir)
-                             if os.path.isfile(os.path.join(path_dir, f))])
+        files_list = sorted(
+            [
+                os.path.join(path_dir, f)
+                for f in os.listdir(path_dir)
+                if os.path.isfile(os.path.join(path_dir, f))
+            ]
+        )
     else:
-        files_list = sorted([f for f in os.listdir(path_dir) if os.path.isfile(os.path.join(path_dir, f))])
+        files_list = sorted(
+            [
+                f
+                for f in os.listdir(path_dir)
+                if os.path.isfile(os.path.join(path_dir, f))
+            ]
+        )
     if expr is not None:  # assumed to be either str or list of str
         if isinstance(expr, str):
             expr = [expr]
         elif not isinstance(expr, (list, tuple)):
-            raise Exception("if specified, 'expr' should be a string or list of strings.")
+            raise Exception(
+                "if specified, 'expr' should be a string or list of strings."
+            )
         matched_list_files = list()
         for match in expr:
-            tmp_matched_files_list = sorted([f for f in files_list if match in os.path.basename(f)])
-            if cond_type == 'or':
+            tmp_matched_files_list = sorted(
+                [f for f in files_list if match in os.path.basename(f)]
+            )
+            if cond_type == "or":
                 files_list = [f for f in files_list if f not in tmp_matched_files_list]
                 matched_list_files += tmp_matched_files_list
-            elif cond_type == 'and':
+            elif cond_type == "and":
                 files_list = tmp_matched_files_list
                 matched_list_files = tmp_matched_files_list
         files_list = sorted(matched_list_files)
     return files_list


-def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'):
+def list_subfolders(path_dir, whole_path=True, expr=None, cond_type="or"):
     """This function returns a list of subfolders contained in a folder, with optional substring filtering.
     :param path_dir: path of a folder
     :param whole_path: (optional) whether to return whole path or just the subfolder names.
@@ -464,24 +601,41 @@ def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'):
     :return: a list of subfolders
     """
     assert isinstance(whole_path, bool), "whole_path should be bool"
-    assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'"
+    assert cond_type in ["or", "and"], "cond_type should be either 'or' or 'and'"
     if whole_path:
-        subdirs_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir)
-                               if os.path.isdir(os.path.join(path_dir, f))])
+        subdirs_list = sorted(
+            [
+                os.path.join(path_dir, f)
+                for f in os.listdir(path_dir)
+                if os.path.isdir(os.path.join(path_dir, f))
+            ]
+        )
     else:
-        subdirs_list = sorted([f for f in os.listdir(path_dir) if os.path.isdir(os.path.join(path_dir, f))])
+        subdirs_list = sorted(
+            [
+                f
+                for f in os.listdir(path_dir)
+                if os.path.isdir(os.path.join(path_dir, f))
+            ]
+        )
     if expr is not None:  # assumed to be either str or list of str
         if isinstance(expr, str):
             expr = [expr]
         elif not isinstance(expr, (list, tuple)):
-            raise Exception("if specified, 'expr' should be a string or list of strings.")
+            raise Exception(
+                "if specified, 'expr' should be a string or list of strings."
+            )
         matched_list_subdirs = list()
         for match in expr:
-            tmp_matched_list_subdirs = sorted([f for f in subdirs_list if match in os.path.basename(f)])
-            if cond_type == 'or':
-                subdirs_list = [f for f in subdirs_list if f not in tmp_matched_list_subdirs]
+            tmp_matched_list_subdirs = sorted(
+                [f for f in subdirs_list if match in os.path.basename(f)]
+            )
+            if cond_type == "or":
+                subdirs_list = [
+                    f for f in subdirs_list if f not in tmp_matched_list_subdirs
+                ]
                 matched_list_subdirs += tmp_matched_list_subdirs
-            elif cond_type == 'and':
+            elif cond_type == "and":
                 subdirs_list = tmp_matched_list_subdirs
                 matched_list_subdirs = tmp_matched_list_subdirs
         subdirs_list = sorted(matched_list_subdirs)
@@ -490,53 +644,58 @@ def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'):

 def get_image_extension(path):
     name = os.path.basename(path)
-    if name[-7:] == '.nii.gz':
-        return 'nii.gz'
-    elif name[-4:] == '.mgz':
-        return 'mgz'
-    elif name[-4:] == '.nii':
-        return 'nii'
-    elif name[-4:] == '.npz':
-        return 'npz'
+    if name[-7:] == ".nii.gz":
+        return "nii.gz"
+    elif name[-4:] == ".mgz":
+        return "mgz"
+    elif name[-4:] == ".nii":
+        return "nii"
+    elif name[-4:] == ".npz":
+        return "npz"
+    return None  # fall through: unrecognised extension


 def strip_extension(path):
     """Strip classical image extensions (.nii.gz, .nii, .mgz, .npz) from a filename."""
-    return path.replace('.nii.gz', '').replace('.nii', '').replace('.mgz', '').replace('.npz', '')
+    return (
+        path.replace(".nii.gz", "")
+        .replace(".nii", "")
+        .replace(".mgz", "")
+        .replace(".npz", "")
+    )


 def strip_suffix(path):
     """Strip classical image suffix from a filename."""
-    path = path.replace('_aseg', '')
-    path = path.replace('aseg', '')
-    path = path.replace('.aseg', '')
-    path = path.replace('_aseg_1', '')
-    path = path.replace('_aseg_2', '')
-    path = path.replace('aseg_1_', '')
-    path = path.replace('aseg_2_', '')
-    path = path.replace('_orig', '')
-    path = path.replace('orig', '')
-    path = path.replace('.orig', '')
-    path = path.replace('_norm', '')
-    path = path.replace('norm', '')
-    path = path.replace('.norm', '')
-    path = path.replace('_talairach', '')
-    path = path.replace('GSP_FS_4p5', 'GSP')
-    path = path.replace('.nii_crispSegmentation', '')
-    path = path.replace('_crispSegmentation', '')
-    path = path.replace('_seg', '')
-    path = path.replace('.seg', '')
-    path = path.replace('seg', '')
-    path = path.replace('_seg_1', '')
-    path = path.replace('_seg_2', '')
-    path = path.replace('seg_1_', '')
-    path = path.replace('seg_2_', '')
+    # strip longer suffixes first, so that e.g. "_aseg_1" is removed whole
+    # instead of leaving a dangling "_1" behind
+    path = path.replace("_aseg_1", "")
+    path = path.replace("_aseg_2", "")
+    path = path.replace("aseg_1_", "")
+    path = path.replace("aseg_2_", "")
+    path = path.replace("_aseg", "")
+    path = path.replace(".aseg", "")
+    path = path.replace("aseg", "")
+    path = path.replace("_orig", "")
+    path = path.replace(".orig", "")
+    path = path.replace("orig", "")
+    path = path.replace("_norm", "")
+    path = path.replace(".norm", "")
+    path = path.replace("norm", "")
+    path = path.replace("_talairach", "")
+    path = path.replace("GSP_FS_4p5", "GSP")
+    path = path.replace(".nii_crispSegmentation", "")
+    path = path.replace("_crispSegmentation", "")
+    path = path.replace("_seg_1", "")
+    path = path.replace("_seg_2", "")
+    path = path.replace("seg_1_", "")
+    path = path.replace("seg_2_", "")
+    path = path.replace("_seg", "")
+    path = path.replace(".seg", "")
+    path = path.replace("seg", "")
     return path


 def mkdir(path_dir):
     """Recursively creates the current dir as well as its parent folders if they do not already exist."""
-    if path_dir[-1] == '/':
+    if path_dir[-1] == "/":
         path_dir = path_dir[:-1]
     if not os.path.isdir(path_dir):
         list_dir_to_create = [path_dir]
@@ -549,7 +708,7 @@ def mkdir(path_dir):

 def mkcmd(*args):
     """Creates terminal command with provided inputs.
     Example: mkcmd('mv', 'source', 'dest') will give 'mv source dest'."""
-    return ' '.join([str(arg) for arg in args])
+    return " ".join([str(arg) for arg in args])


 # ---------------------------------------------- shape-related functions -----------------------------------------------
@@ -591,7 +750,8 @@ def get_resample_shape(patch_shape, factor, n_channels=None):

 def add_axis(x, axis=0):
     """Add axis to a numpy array.
-    :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time."""
+    :param x: input array
+    :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.
+    """
     axis = reformat_to_list(axis)
     for ax in axis:
         x = np.expand_dims(x, axis=ax)
@@ -606,7 +766,9 @@ def get_padding_margin(cropping, loss_cropping):
     n_dims = max(len(cropping), len(loss_cropping))
     cropping = reformat_to_list(cropping, length=n_dims)
     loss_cropping = reformat_to_list(loss_cropping, length=n_dims)
-    padding_margin = [int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims)]
+    padding_margin = [
+        int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims)
+    ]
     if len(padding_margin) == 1:
         padding_margin = padding_margin[0]
     else:
@@ -617,7 +779,9 @@ def get_padding_margin(cropping, loss_cropping):

 # -------------------------------------------- build affine matrices/tensors -------------------------------------------


-def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, shearing=None, translation=None):
+def create_affine_transformation_matrix(
+    n_dims, scaling=None, rotation=None, shearing=None, translation=None
+):
     """Create an (n_dims + 1) x (n_dims + 1) affine transformation matrix from the specified values
     :param n_dims: integer, can either be 2 or 3.
     :param scaling: list of n_dims scaling values
@@ -635,14 +799,16 @@ def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, she
         T_scaling[np.arange(n_dims + 1), np.arange(n_dims + 1)] = np.append(scaling, 1)

     if shearing is not None:
-        shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype='bool')
-        shearing_index[np.eye(n_dims + 1, dtype='bool')] = False
+        shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype="bool")
+        shearing_index[np.eye(n_dims + 1, dtype="bool")] = False
         shearing_index[-1, :] = np.zeros((n_dims + 1))
         shearing_index[:, -1] = np.zeros((n_dims + 1))
         T_shearing[shearing_index] = shearing

     if translation is not None:
-        T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype='int')] = translation
+        T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype="int")] = (
+            translation
+        )

     if n_dims == 2:
         if rotation is None:
@@ -650,8 +816,12 @@ def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, she
         else:
             rotation = np.asarray(rotation) * (math.pi / 180)
         T_rot = np.eye(n_dims + 1)
-        T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[0]), np.sin(rotation[0]),
-                                                                 np.sin(rotation[0]) * -1, np.cos(rotation[0])]
+        T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [
+            np.cos(rotation[0]),
+            np.sin(rotation[0]),
+            np.sin(rotation[0]) * -1,
+            np.cos(rotation[0]),
+        ]
         return T_translation @ T_rot @ T_shearing @ T_scaling

     else:
@@ -661,92 +831,138 @@ def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, she
         else:
             rotation = np.asarray(rotation) * (math.pi / 180)
         T_rot1 = np.eye(n_dims + 1)
-        T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [np.cos(rotation[0]), np.sin(rotation[0]),
-                                                                  np.sin(rotation[0]) * -1, np.cos(rotation[0])]
+        T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [
+            np.cos(rotation[0]),
+            np.sin(rotation[0]),
+            np.sin(rotation[0]) * -1,
+            np.cos(rotation[0]),
+        ]
         T_rot2 = np.eye(n_dims + 1)
-        T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [np.cos(rotation[1]), np.sin(rotation[1]) * -1,
-                                                                  np.sin(rotation[1]), np.cos(rotation[1])]
+        T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [
+            np.cos(rotation[1]),
+            np.sin(rotation[1]) * -1,
+            np.sin(rotation[1]),
+            np.cos(rotation[1]),
+        ]
         T_rot3 = np.eye(n_dims + 1)
-        T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[2]), np.sin(rotation[2]),
-                                                                  np.sin(rotation[2]) * -1, np.cos(rotation[2])]
+        T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [
+            np.cos(rotation[2]),
+            np.sin(rotation[2]),
+            np.sin(rotation[2]) * -1,
+            np.cos(rotation[2]),
+        ]
         return T_translation @ T_rot3 @ T_rot2 @ T_rot1 @ T_shearing @ T_scaling
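A quick numeric check of the composition order returned above (translation applied last); the values are illustrative and the example is 2D for brevity:

import numpy as np

# 2D worked example of the composition order used above:
# scale by 2, rotate 90 degrees, then translate by (5, 0), in homogeneous coordinates.
theta = np.pi / 2
T_scaling = np.diag([2.0, 2.0, 1.0])
T_rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
T_translation = np.eye(3)
T_translation[:2, 2] = [5.0, 0.0]

T = T_translation @ T_rot @ T_scaling  # no shearing: identity omitted
point = np.array([1.0, 0.0, 1.0])      # (1, 0) in homogeneous coordinates
print(np.round(T @ point, 3))          # [5. 2. 1.]: scaled to (2,0), rotated to (0,2), shifted to (5,2)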

-def sample_affine_transform(batchsize,
-                            n_dims,
-                            rotation_bounds=False,
-                            scaling_bounds=False,
-                            shearing_bounds=False,
-                            translation_bounds=False,
-                            enable_90_rotations=False):
+def sample_affine_transform(
+    batchsize,
+    n_dims,
+    rotation_bounds=False,
+    scaling_bounds=False,
+    shearing_bounds=False,
+    translation_bounds=False,
+    enable_90_rotations=False,
+):
     """build batchsize x 4 x 4 tensor representing an affine transformation in homogeneous coordinates."""
     if (rotation_bounds is not False) | (enable_90_rotations is not False):
         if n_dims == 2:
             if rotation_bounds is not False:
-                rotation = draw_value_from_distribution(rotation_bounds,
-                                                        size=1,
-                                                        default_range=15.0,
-                                                        return_as_tensor=True,
-                                                        batchsize=batchsize)
+                rotation = draw_value_from_distribution(
+                    rotation_bounds,
+                    size=1,
+                    default_range=15.0,
+                    return_as_tensor=True,
+                    batchsize=batchsize,
+                )
             else:
-                rotation = tf.zeros(tf.concat([batchsize, tf.ones(1, dtype='int32')], axis=0))
+                rotation = tf.zeros(
+                    tf.concat([batchsize, tf.ones(1, dtype="int32")], axis=0)
+                )
         else:  # n_dims = 3
             if rotation_bounds is not False:
-                rotation = draw_value_from_distribution(rotation_bounds,
-                                                        size=n_dims,
-                                                        default_range=15.0,
-                                                        return_as_tensor=True,
-                                                        batchsize=batchsize)
+                rotation = draw_value_from_distribution(
+                    rotation_bounds,
+                    size=n_dims,
+                    default_range=15.0,
+                    return_as_tensor=True,
+                    batchsize=batchsize,
+                )
             else:
-                rotation = tf.zeros(tf.concat([batchsize, 3 * tf.ones(1, dtype='int32')], axis=0))
+                rotation = tf.zeros(
+                    tf.concat([batchsize, 3 * tf.ones(1, dtype="int32")], axis=0)
+                )
         if enable_90_rotations:
-            rotation = tf.cast(tf.random.uniform(tf.shape(rotation), maxval=4, dtype='int32') * 90, 'float32') \
-                       + rotation
+            rotation = (
+                tf.cast(
+                    tf.random.uniform(tf.shape(rotation), maxval=4, dtype="int32") * 90,
+                    "float32",
+                )
+                + rotation
+            )
         T_rot = create_rotation_transform(rotation, n_dims)
     else:
-        T_rot = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0),
-                        tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
+        T_rot = tf.tile(
+            tf.expand_dims(tf.eye(n_dims), axis=0),
+            tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0),
+        )

     if shearing_bounds is not False:
-        shearing = draw_value_from_distribution(shearing_bounds,
-                                                size=n_dims ** 2 - n_dims,
-                                                default_range=.01,
-                                                return_as_tensor=True,
-                                                batchsize=batchsize)
+        shearing = draw_value_from_distribution(
+            shearing_bounds,
+            size=n_dims**2 - n_dims,
+            default_range=0.01,
+            return_as_tensor=True,
+            batchsize=batchsize,
+        )
         T_shearing = create_shearing_transform(shearing, n_dims)
     else:
-        T_shearing = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0),
-                             tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
+        T_shearing = tf.tile(
+            tf.expand_dims(tf.eye(n_dims), axis=0),
+            tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0),
+        )

     if scaling_bounds is not False:
-        scaling = draw_value_from_distribution(scaling_bounds,
-                                               size=n_dims,
-                                               centre=1,
-                                               default_range=.15,
-                                               return_as_tensor=True,
-                                               batchsize=batchsize)
+        scaling = draw_value_from_distribution(
+            scaling_bounds,
+            size=n_dims,
+            centre=1,
+            default_range=0.15,
+            return_as_tensor=True,
+            batchsize=batchsize,
+        )
         T_scaling = tf.linalg.diag(scaling)
     else:
-        T_scaling = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0),
-                            tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
+        T_scaling = tf.tile(
+            tf.expand_dims(tf.eye(n_dims), axis=0),
+            tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0),
+        )

     T = tf.matmul(T_scaling, tf.matmul(T_shearing, T_rot))

     if translation_bounds is not False:
-        translation = draw_value_from_distribution(translation_bounds,
-                                                   size=n_dims,
-                                                   default_range=5,
-                                                   return_as_tensor=True,
-                                                   batchsize=batchsize)
+        translation = draw_value_from_distribution(
+            translation_bounds,
+            size=n_dims,
+            default_range=5,
+            return_as_tensor=True,
+            batchsize=batchsize,
+        )
         T = tf.concat([T, tf.expand_dims(translation, axis=-1)], axis=-1)
     else:
-        T = tf.concat([T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype='int32')], 0))], axis=-1)
+        T = tf.concat(
+            [T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype="int32")], 0))],
+            axis=-1,
+        )

     # build rigid transform
-    T_last_row = tf.expand_dims(tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0)
-    T_last_row = tf.tile(T_last_row, tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
+    T_last_row = tf.expand_dims(
+        tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0
+    )
+    T_last_row = tf.tile(
+        T_last_row, tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0)
+    )
     T = tf.concat([T, T_last_row], axis=1)

     return T
@@ -758,38 +974,93 @@ def create_rotation_transform(rotation, n_dims):
     if n_dims == 3:
         shape = tf.shape(tf.expand_dims(rotation[..., 0], -1))

-        Rx_row0 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([1., 0., 0.]), 0), shape), axis=1)
-        Rx_row1 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.cos(rotation[..., 0]), -1),
-                            tf.expand_dims(-tf.sin(rotation[..., 0]), -1)], axis=-1)
-        Rx_row2 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.sin(rotation[..., 0]), -1),
-                            tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1)
+        Rx_row0 = tf.expand_dims(
+            tf.tile(tf.expand_dims(tf.convert_to_tensor([1.0, 0.0, 0.0]), 0), shape),
+            axis=1,
+        )
+        Rx_row1 = tf.stack(
+            [
+                tf.zeros(shape),
+                tf.expand_dims(tf.cos(rotation[..., 0]), -1),
+                tf.expand_dims(-tf.sin(rotation[..., 0]), -1),
+            ],
+            axis=-1,
+        )
+        Rx_row2 = tf.stack(
+            [
+                tf.zeros(shape),
+                tf.expand_dims(tf.sin(rotation[..., 0]), -1),
+                tf.expand_dims(tf.cos(rotation[..., 0]), -1),
+            ],
+            axis=-1,
+        )
         Rx = tf.concat([Rx_row0, Rx_row1, Rx_row2], axis=1)

-        Ry_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 1]), -1), tf.zeros(shape),
-                            tf.expand_dims(tf.sin(rotation[..., 1]), -1)], axis=-1)
-        Ry_row1 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 1., 0.]), 0), shape), axis=1)
-        Ry_row2 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 1]), -1), tf.zeros(shape),
-                            tf.expand_dims(tf.cos(rotation[..., 1]), -1)], axis=-1)
+        Ry_row0 = tf.stack(
+            [
+                tf.expand_dims(tf.cos(rotation[..., 1]), -1),
+                tf.zeros(shape),
+                tf.expand_dims(tf.sin(rotation[..., 1]), -1),
+            ],
+            axis=-1,
+        )
+        Ry_row1 = tf.expand_dims(
+            tf.tile(tf.expand_dims(tf.convert_to_tensor([0.0, 1.0, 0.0]), 0), shape),
+            axis=1,
+        )
+        Ry_row2 = tf.stack(
+            [
+                tf.expand_dims(-tf.sin(rotation[..., 1]), -1),
+                tf.zeros(shape),
+                tf.expand_dims(tf.cos(rotation[..., 1]), -1),
+            ],
+            axis=-1,
+        )
         Ry = tf.concat([Ry_row0, Ry_row1, Ry_row2], axis=1)

-        Rz_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 2]), -1),
-                            tf.expand_dims(-tf.sin(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1)
-        Rz_row1 = tf.stack([tf.expand_dims(tf.sin(rotation[..., 2]), -1),
-                            tf.expand_dims(tf.cos(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1)
-        Rz_row2 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 0., 1.]), 0), shape), axis=1)
+        Rz_row0 = tf.stack(
+            [
+                tf.expand_dims(tf.cos(rotation[..., 2]), -1),
+                tf.expand_dims(-tf.sin(rotation[..., 2]), -1),
+                tf.zeros(shape),
+            ],
+            axis=-1,
+        )
+        Rz_row1 = tf.stack(
+            [
+                tf.expand_dims(tf.sin(rotation[..., 2]), -1),
+                tf.expand_dims(tf.cos(rotation[..., 2]), -1),
+                tf.zeros(shape),
+            ],
+            axis=-1,
+        )
+        Rz_row2 = tf.expand_dims(
+            tf.tile(tf.expand_dims(tf.convert_to_tensor([0.0, 0.0, 1.0]), 0), shape),
+            axis=1,
+        )
         Rz = tf.concat([Rz_row0, Rz_row1, Rz_row2], axis=1)

         T_rot = tf.matmul(tf.matmul(Rx, Ry), Rz)

     elif n_dims == 2:
-        R_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 0]), -1),
-                           tf.expand_dims(tf.sin(rotation[..., 0]), -1)], axis=-1)
-        R_row1 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 0]), -1),
-                           tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1)
+        R_row0 = tf.stack(
+            [
+                tf.expand_dims(tf.cos(rotation[..., 0]), -1),
+                tf.expand_dims(tf.sin(rotation[..., 0]), -1),
+            ],
+            axis=-1,
+        )
+        R_row1 = tf.stack(
+            [
+                tf.expand_dims(-tf.sin(rotation[..., 0]), -1),
+                tf.expand_dims(tf.cos(rotation[..., 0]), -1),
+            ],
+            axis=-1,
+        )
         T_rot = tf.concat([R_row0, R_row1], axis=1)

     else:
-        raise Exception('only supports 2 or 3D.')
+        raise Exception("only supports 2D or 3D.")

     return T_rot
@@ -798,20 +1069,42 @@ def create_shearing_transform(shearing, n_dims):
     """build shearing transform from 2d/3d shearing coefficients"""
     shape = tf.shape(tf.expand_dims(shearing[..., 0], -1))
     if n_dims == 3:
-        shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1),
-                                  tf.expand_dims(shearing[..., 1], -1)], axis=-1)
-        shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 2], -1), tf.ones(shape),
-                                  tf.expand_dims(shearing[..., 3], -1)], axis=-1)
-        shearing_row2 = tf.stack([tf.expand_dims(shearing[..., 4], -1), tf.expand_dims(shearing[..., 5], -1),
-                                  tf.ones(shape)], axis=-1)
+        shearing_row0 = tf.stack(
+            [
+                tf.ones(shape),
+                tf.expand_dims(shearing[..., 0], -1),
+                tf.expand_dims(shearing[..., 1], -1),
+            ],
+            axis=-1,
+        )
+        shearing_row1 = tf.stack(
+            [
+                tf.expand_dims(shearing[..., 2], -1),
+                tf.ones(shape),
+                tf.expand_dims(shearing[..., 3], -1),
+            ],
+            axis=-1,
+        )
+        shearing_row2 = tf.stack(
+            [
+                tf.expand_dims(shearing[..., 4], -1),
+                tf.expand_dims(shearing[..., 5], -1),
+                tf.ones(shape),
+            ],
+            axis=-1,
+        )
         T_shearing = tf.concat([shearing_row0, shearing_row1, shearing_row2], axis=1)
     elif n_dims == 2:
-        shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1)
-        shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1)
+        shearing_row0 = tf.stack(
+            [tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1
+        )
+        shearing_row1 = tf.stack(
+            [tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1
+        )
         T_shearing = tf.concat([shearing_row0, shearing_row1], axis=1)
     else:
-        raise Exception('only supports 2 or 3D.')
+        raise Exception("only supports 2D or 3D.")
     return T_shearing
@@ -819,16 +1112,18 @@ def infer(x):

 def infer(x):
-    """ Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string """
+    """Try to parse the input to float. If that fails, try boolean; otherwise keep it as a string."""
     try:
         x = float(x)
     except ValueError:
-        if x == 'False':
+        if x == "False":
             x = False
-        elif x == 'True':
+        elif x == "True":
             x = True
         elif not isinstance(x, str):
-            raise TypeError('input should be an int/float/boolean/str, had {}'.format(type(x)))
+            raise TypeError(
+                "input should be an int/float/boolean/str, had {}".format(type(x))
+            )
     return x
@@ -840,7 +1135,7 @@ class LoopInfo:
     processing i/total    remaining time: hh:mm:ss
     """

-    def __init__(self, n_iterations, spacing=10, text='processing', print_time=False):
+    def __init__(self, n_iterations, spacing=10, text="processing", print_time=False):
         """
         :param n_iterations: total number of iterations of the for loop.
         :param spacing: frequency at which the update info will be printed on screen.
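Typical use of LoopInfo mirrors the call in get_list_labels earlier in this patch; the file list below is illustrative and the sketch assumes ext.lab2im.utils is importable:

import time
from ext.lab2im.utils import LoopInfo

# Usage sketch for LoopInfo: instantiate before the loop, update at each iteration.
paths = ["subj_%d.nii.gz" % i for i in range(50)]
loop_info = LoopInfo(len(paths), spacing=10, text="processing", print_time=True)
for idx, path in enumerate(paths):
    loop_info.update(idx)  # prints "processing i/50" every 10 iterations
    time.sleep(0.01)       # stand-in for the real per-file work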
@@ -872,23 +1167,32 @@ def update(self, idx):

         # print text
         if idx == 0:
-            print(self.text + ' 1/{}'.format(self.n_iterations))
+            print(self.text + " 1/{}".format(self.n_iterations))
         elif idx % self.spacing == self.spacing - 1:
-            iteration = str(idx + 1) + '/' + str(self.n_iterations)
+            iteration = str(idx + 1) + "/" + str(self.n_iterations)
             if self.print_time:
                 # estimate remaining time
                 max_duration = np.max(self.iteration_durations)
-                average_duration = np.mean(self.iteration_durations[self.iteration_durations > .01 * max_duration])
+                average_duration = np.mean(
+                    self.iteration_durations[
+                        self.iteration_durations > 0.01 * max_duration
+                    ]
+                )
                 remaining_time = int(average_duration * (self.n_iterations - idx))
                 # print total remaining time only if it is greater than 1s or if it was previously printed
                 if (remaining_time > 1) | self.print_previous_time:
                     eta = str(timedelta(seconds=remaining_time))
-                    print(self.text + ' {:<{x}} remaining time: {}'.format(iteration, eta, x=self.align))
+                    print(
+                        self.text
+                        + " {:<{x}} remaining time: {}".format(
+                            iteration, eta, x=self.align
+                        )
+                    )
                     self.print_previous_time = True
                 else:
-                    print(self.text + ' {}'.format(iteration))
+                    print(self.text + " {}".format(iteration))
             else:
-                print(self.text + ' {}'.format(iteration))
+                print(self.text + " {}".format(iteration))


 def get_mapping_lut(source, dest=None):
@@ -896,18 +1200,20 @@
     If the second list is not given, we assume it is equal to [0, ..., N-1]."""

     # initialise
-    source = np.array(reformat_to_list(source), dtype='int32')
+    source = np.array(reformat_to_list(source), dtype="int32")
     n_labels = source.shape[0]

     # build new label list if necessary
     if dest is None:
-        dest = np.arange(n_labels, dtype='int32')
+        dest = np.arange(n_labels, dtype="int32")
     else:
-        assert len(source) == len(dest), 'label_list and new_label_list should have the same length'
-        dest = np.array(reformat_to_list(dest, dtype='int'))
+        assert len(source) == len(
+            dest
+        ), "source and dest should have the same length"
+        dest = np.array(reformat_to_list(dest, dtype="int"))

     # build look-up table
-    lut = np.zeros(np.max(source) + 1, dtype='int32')
+    lut = np.zeros(np.max(source) + 1, dtype="int32")
     for source, dest in zip(source, dest):
         lut[source] = dest
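The LUT built above turns label remapping into a single numpy indexing operation; a self-contained illustration with toy labels:

import numpy as np

# Self-contained illustration of the look-up-table remapping built above:
# map labels [0, 2, 5] to compact indices [0, 1, 2], then remap a toy label map.
source = np.array([0, 2, 5], dtype="int32")
dest = np.arange(len(source), dtype="int32")

lut = np.zeros(source.max() + 1, dtype="int32")
lut[source] = dest

label_map = np.array([[0, 2], [5, 0]], dtype="int32")
print(lut[label_map])  # [[0 1]
                       #  [2 0]]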
@@ -925,7 +1231,7 @@ def build_training_generator(gen, batchsize):
         yield inputs, target


-def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
+def find_closest_number_divisible_by_m(n, m, answer_type="lower"):
     """Return the closest integer to n that is divisible by m.
     answer_type can either be 'closer', 'lower' (only returns values lower than n), or
     'higher' (only returns values higher than n)."""
     if n % m == 0:
@@ -934,14 +1240,16 @@ def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
         q = int(n / m)
         lower = q * m
         higher = (q + 1) * m
-        if answer_type == 'lower':
+        if answer_type == "lower":
             return lower
-        elif answer_type == 'higher':
+        elif answer_type == "higher":
             return higher
-        elif answer_type == 'closer':
+        elif answer_type == "closer":
             return lower if (n - lower) < (higher - n) else higher
         else:
-            raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type)
+            raise Exception(
+                "answer_type should be lower, higher, or closer, had: %s" % answer_type
+            )


 def build_binary_structure(connectivity, n_dims, shape=None):
@@ -958,14 +1266,16 @@
     return struct


-def draw_value_from_distribution(hyperparameter,
-                                 size=1,
-                                 distribution='uniform',
-                                 centre=0.,
-                                 default_range=10.0,
-                                 positive_only=False,
-                                 return_as_tensor=False,
-                                 batchsize=None):
+def draw_value_from_distribution(
+    hyperparameter,
+    size=1,
+    distribution="uniform",
+    centre=0.0,
+    default_range=10.0,
+    positive_only=False,
+    return_as_tensor=False,
+    batchsize=None,
+):
     """Sample values from a uniform or normal distribution with the given hyperparameters.
     Two hyperparameters are expected in both the uniform and the normal case.
     :param hyperparameter: values of the hyperparameters. Can either be:
@@ -1001,47 +1311,73 @@ def draw_value_from_distribution(hyperparameter,
     hyperparameter = load_array_if_path(hyperparameter, load_as_numpy=True)
     if not isinstance(hyperparameter, np.ndarray):
         if hyperparameter is None:
-            hyperparameter = np.array([[centre - default_range] * size, [centre + default_range] * size])
+            hyperparameter = np.array(
+                [[centre - default_range] * size, [centre + default_range] * size]
+            )
         elif isinstance(hyperparameter, (int, float)):
-            hyperparameter = np.array([[centre - hyperparameter] * size, [centre + hyperparameter] * size])
+            hyperparameter = np.array(
+                [[centre - hyperparameter] * size, [centre + hyperparameter] * size]
+            )
         elif isinstance(hyperparameter, (list, tuple)):
-            assert len(hyperparameter) == 2, 'if list, parameter_range should be of length 2.'
+            assert (
+                len(hyperparameter) == 2
+            ), "if list, parameter_range should be of length 2."
             hyperparameter = np.transpose(np.tile(np.array(hyperparameter), (size, 1)))
         else:
-            raise ValueError('parameter_range should either be None, a number, a sequence, or a numpy array.')
+            raise ValueError(
+                "parameter_range should either be None, a number, a sequence, or a numpy array."
+            )
     elif isinstance(hyperparameter, np.ndarray):
-        assert hyperparameter.shape[0] % 2 == 0, 'number of rows of parameter_range should be divisible by 2'
+        assert (
+            hyperparameter.shape[0] % 2 == 0
+        ), "number of rows of parameter_range should be divisible by 2"
         n_modalities = int(hyperparameter.shape[0] / 2)
         modality_idx = 2 * np.random.randint(n_modalities)
-        hyperparameter = hyperparameter[modality_idx: modality_idx + 2, :]
+        hyperparameter = hyperparameter[modality_idx : modality_idx + 2, :]

     # draw values as tensor
     if return_as_tensor:
-        shape = KL.Lambda(lambda x: tf.convert_to_tensor(hyperparameter.shape[1], 'int32'))([])
+        shape = KL.Lambda(
+            lambda x: tf.convert_to_tensor(hyperparameter.shape[1], "int32")
+        )([])
         if batchsize is not None:
-            shape = KL.Lambda(lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0))([batchsize, shape])
-        if distribution == 'uniform':
-            parameter_value = KL.Lambda(lambda x: tf.random.uniform(shape=x,
-                                                                    minval=hyperparameter[0, :],
-                                                                    maxval=hyperparameter[1, :]))(shape)
-        elif distribution == 'normal':
-            parameter_value = KL.Lambda(lambda x: tf.random.normal(shape=x,
-                                                                   mean=hyperparameter[0, :],
-                                                                   stddev=hyperparameter[1, :]))(shape)
+            shape = KL.Lambda(
+                lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0)
+            )([batchsize, shape])
+        if distribution == "uniform":
+            parameter_value = KL.Lambda(
+                lambda x: tf.random.uniform(
+                    shape=x, minval=hyperparameter[0, :], maxval=hyperparameter[1, :]
+                )
+            )(shape)
+        elif distribution == "normal":
+            parameter_value = KL.Lambda(
+                lambda x: tf.random.normal(
+                    shape=x, mean=hyperparameter[0, :], stddev=hyperparameter[1, :]
+                )
+            )(shape)
         else:
-            raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.")
+            raise ValueError(
+                "Distribution not supported, should be 'uniform' or 'normal'."
+            )
        if positive_only:
            parameter_value = KL.Lambda(lambda x: K.clip(x, 0, None))(parameter_value)

     # draw values as numpy array
     else:
-        if distribution == 'uniform':
-            parameter_value = np.random.uniform(low=hyperparameter[0, :], high=hyperparameter[1, :])
-        elif distribution == 'normal':
-            parameter_value = np.random.normal(loc=hyperparameter[0, :], scale=hyperparameter[1, :])
+        if distribution == "uniform":
+            parameter_value = np.random.uniform(
+                low=hyperparameter[0, :], high=hyperparameter[1, :]
+            )
+        elif distribution == "normal":
+            parameter_value = np.random.normal(
+                loc=hyperparameter[0, :], scale=hyperparameter[1, :]
+            )
         else:
-            raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.")
+            raise ValueError(
+                "Distribution not supported, should be 'uniform' or 'normal'."
+            )
         if positive_only:
             parameter_value[parameter_value < 0] = 0
@@ -1053,5 +1389,5 @@ def build_exp(x, first, last, fix_point):
     # first = f(0), last = f(+inf), fix_point = [x0, f(x0)]
     a = last
     b = first - last
-    c = - (1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last))
+    c = -(1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last))
     return a + b * np.exp(-c * x)
diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py
index 68a797d3..bce8ad95 100644
--- a/nobrainer/models/__init__.py
+++ b/nobrainer/models/__init__.py
@@ -27,7 +27,7 @@
     "attention_unet_with_inception": attention_unet_with_inception,
     "unetr": unetr,
     "variational_meshnet": variational_meshnet,
-    "bayesian_vnet": bayesian_vnet
+    "bayesian_vnet": bayesian_vnet,
 }
diff --git a/nobrainer/models/lab2im_model.py b/nobrainer/models/lab2im_model.py
index 743626cf..b20e5274 100644
--- a/nobrainer/models/lab2im_model.py
+++ b/nobrainer/models/lab2im_model.py
@@ -13,27 +13,27 @@
 License.
 """

+# project imports
+from ext.lab2im import layers, utils
+from ext.lab2im.edit_tensors import blurring_sigma_for_downsampling, resample_tensor
+import keras.layers as KL
+from keras.models import Model

 # python imports
 import numpy as np
-import keras.layers as KL
-from keras.models import Model

-# project imports
-from ext.lab2im import utils
-from ext.lab2im import layers
-from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling
-
-
-def lab2im_model(labels_shape,
-                 n_channels,
-                 generation_labels,
-                 output_labels,
-                 atlas_res,
-                 target_res,
-                 output_shape=None,
-                 output_div_by_n=None,
-                 blur_range=1.15):
+
+def lab2im_model(
+    labels_shape,
+    n_channels,
+    generation_labels,
+    output_labels,
+    atlas_res,
+    target_res,
+    output_shape=None,
+    output_div_by_n=None,
+    blur_range=1.15,
+):
     """
     This function builds a keras/tensorflow model to generate images from provided label maps. The images are
     generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map.
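A hedged sketch of how this builder is typically driven (the shapes, resolutions, and GMM statistics are illustrative values, and the sketch assumes the ext.lab2im package and its TensorFlow/Keras dependencies are importable):

import numpy as np
from ext.lab2im.lab2im_model import lab2im_model

# Usage sketch for lab2im_model (illustrative values, not part of the patch).
generation_labels = np.array([0, 2, 3, 4])  # labels present in the input map
model = lab2im_model(
    labels_shape=[160, 160, 160],
    n_channels=1,
    generation_labels=generation_labels,
    output_labels=generation_labels,  # keep the same labels in the output
    atlas_res=1.0,                    # mm, resolution of the label map
    target_res=1.0,                   # mm, resolution of the synthetic image
)
# inputs: [label map, per-label GMM means, per-label GMM standard deviations]
labels = np.zeros([1, 160, 160, 160, 1], dtype="int32")
means = np.random.uniform(0, 255, (1, len(generation_labels), 1))
stds = np.random.uniform(0, 35, (1, len(generation_labels), 1))
image, out_labels = model.predict([labels, means, stds])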
@@ -74,18 +74,30 @@ def lab2im_model(labels_shape,
     labels_shape = utils.reformat_to_list(labels_shape)
     n_dims, _ = utils.get_dims(labels_shape)
     atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0]
-    target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0]
+    target_res = (
+        atlas_res
+        if (target_res is None)
+        else utils.reformat_to_n_channels_array(target_res, n_dims)[0]
+    )

     # get shapes
-    crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n)
+    crop_shape, output_shape = get_shapes(
+        labels_shape, output_shape, atlas_res, target_res, output_div_by_n
+    )

     # define model inputs
-    labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32')
-    means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input')
-    stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input')
+    labels_input = KL.Input(
+        shape=labels_shape + [1], name="labels_input", dtype="int32"
+    )
+    means_input = KL.Input(
+        shape=list(generation_labels.shape) + [n_channels], name="means_input"
+    )
+    stds_input = KL.Input(
+        shape=list(generation_labels.shape) + [n_channels], name="stds_input"
+    )

     # deform labels
-    labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input)
+    labels = layers.RandomSpatialDeformation(inter_method="nearest")(labels_input)

     # cropping
     if crop_shape != labels_shape:
@@ -94,15 +106,19 @@ def lab2im_model(labels_shape,

     # build synthetic image
     labels._keras_shape = tuple(labels.get_shape().as_list())
-    image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input])
+    image = layers.SampleConditionalGMM(generation_labels)(
+        [labels, means_input, stds_input]
+    )

     # apply bias field
     image._keras_shape = tuple(image.get_shape().as_list())
-    image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image)
+    image = layers.BiasFieldCorruption(0.3, 0.025, same_bias_for_all_channels=False)(
+        image
+    )

     # intensity augmentation
     image._keras_shape = tuple(image.get_shape().as_list())
-    image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image)
+    image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=0.2)(image)

     # blur image
     sigma = blurring_sigma_for_downsampling(atlas_res, target_res)
@@ -111,15 +127,19 @@ def lab2im_model(labels_shape,

     # resample to target res
     if crop_shape != output_shape:
-        image = resample_tensor(image, output_shape, interp_method='linear')
-        labels = resample_tensor(labels, output_shape, interp_method='nearest')
+        image = resample_tensor(image, output_shape, interp_method="linear")
+        labels = resample_tensor(labels, output_shape, interp_method="nearest")

     # reset unwanted labels to zero
-    labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels)
+    labels = layers.ConvertLabels(
+        generation_labels, dest_values=output_labels, name="labels_out"
+    )(labels)

     # build model (the dummy layer makes it possible to keep the labels when plugging this model into other models)
-    image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels])
-    brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels])
+    image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels])
+    brain_model = Model(
+        inputs=[labels_input, means_input, stds_input], outputs=[image, labels]
+    )

     return brain_model
@@ -136,26 +156,39 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_

     # output shape specified, need to get cropping shape, and resample shape if necessary
     if output_shape is not None:
-        output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int')
+        output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int")

         # make sure that output shape is smaller or equal to label shape
         if resample_factor is not None:
-            output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)]
+            output_shape = [
+                min(int(labels_shape[i] * resample_factor[i]), output_shape[i])
+                for i in range(n_dims)
+            ]
         else:
-            output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)]
+            output_shape = [
+                min(labels_shape[i], output_shape[i]) for i in range(n_dims)
+            ]

         # make sure output shape is divisible by output_div_by_n
         if output_div_by_n is not None:
-            tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n)
-                         for s in output_shape]
+            tmp_shape = [
+                utils.find_closest_number_divisible_by_m(s, output_div_by_n)
+                for s in output_shape
+            ]
             if output_shape != tmp_shape:
-                print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n,
-                                                                                     tmp_shape))
+                print(
+                    "output shape {0} not divisible by {1}, changed to {2}".format(
+                        output_shape, output_div_by_n, tmp_shape
+                    )
+                )
                 output_shape = tmp_shape

         # get cropping and resample shape
         if resample_factor is not None:
-            cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)]
+            cropping_shape = [
+                int(np.around(output_shape[i] / resample_factor[i], 0))
+                for i in range(n_dims)
+            ]
         else:
             cropping_shape = output_shape

@@ -163,12 +196,19 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_
     else:
         cropping_shape = labels_shape
         if resample_factor is not None:
-            output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)]
+            output_shape = [
+                int(np.around(cropping_shape[i] * resample_factor[i], 0))
+                for i in range(n_dims)
+            ]
         else:
             output_shape = cropping_shape

         # make sure output shape is divisible by output_div_by_n
         if output_div_by_n is not None:
-            output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer')
-                            for s in output_shape]
+            output_shape = [
+                utils.find_closest_number_divisible_by_m(
+                    s, output_div_by_n, answer_type="closer"
+                )
+                for s in output_shape
+            ]

     return cropping_shape, output_shape
diff --git a/nobrainer/processing/image_generator.py b/nobrainer/processing/image_generator.py
index d48886a7..1c015d58 100644
--- a/nobrainer/processing/image_generator.py
+++ b/nobrainer/processing/image_generator.py
@@ -13,34 +13,34 @@
 License.
""" +# project imports +from ext.lab2im import edit_volumes, utils +from ext.lab2im.lab2im_model import lab2im_model # python imports import numpy as np import numpy.random as npr -# project imports -from ext.lab2im import utils -from ext.lab2im import edit_volumes -from ext.lab2im.lab2im_model import lab2im_model - class ImageGenerator: - def __init__(self, - labels_dir, - generation_labels=None, - output_labels=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - output_div_by_n=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - blur_range=1.15): + def __init__( + self, + labels_dir, + generation_labels=None, + output_labels=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + blur_range=1.15, + ): """ This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels maps, and a python generator that supplies the input data for this model. @@ -115,8 +115,9 @@ def __init__(self, self.labels_paths = utils.list_images_in_folder(labels_dir) # generation parameters - self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = ( utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + ) self.n_channels = n_channels if generation_labels is not None: self.generation_labels = utils.load_array_if_path(generation_labels) @@ -135,11 +136,13 @@ def __init__(self, self.prior_distributions = prior_distributions if generation_classes is not None: self.generation_classes = utils.load_array_if_path(generation_classes) - assert self.generation_classes.shape == self.generation_labels.shape, \ - 'if provided, generation labels should have the same shape as generation_labels' + assert ( + self.generation_classes.shape == self.generation_labels.shape + ), "if provided, generation labels should have the same shape as generation_labels" unique_classes = np.unique(self.generation_classes) - assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ - 'generation_classes should a linear range between 0 and its maximum value.' + assert np.array_equal( + unique_classes, np.arange(np.max(unique_classes) + 1) + ), "generation_classes should a linear range between 0 and its maximum value." 
@@ -153,22 +156,26 @@ def __init__(self,
         self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model()

         # build generator for model inputs
-        self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels))
+        self.model_inputs_generator = self._build_model_inputs(
+            len(self.generation_labels)
+        )

         # build brain generator
         self.image_generator = self._build_image_generator()

     def _build_lab2im_model(self):
         # build_model
-        lab_to_im_model = lab2im_model(labels_shape=self.labels_shape,
-                                       n_channels=self.n_channels,
-                                       generation_labels=self.generation_labels,
-                                       output_labels=self.output_labels,
-                                       atlas_res=self.atlas_res,
-                                       target_res=self.target_res,
-                                       output_shape=self.output_shape,
-                                       output_div_by_n=self.output_div_by_n,
-                                       blur_range=self.blur_range)
+        lab_to_im_model = lab2im_model(
+            labels_shape=self.labels_shape,
+            n_channels=self.n_channels,
+            generation_labels=self.generation_labels,
+            output_labels=self.output_labels,
+            atlas_res=self.atlas_res,
+            target_res=self.target_res,
+            output_shape=self.output_shape,
+            output_div_by_n=self.output_div_by_n,
+            blur_range=self.blur_range,
+        )
         out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:]
         return lab_to_im_model, out_shape
@@ -185,10 +192,16 @@ def generate_image(self):
         list_images = list()
         list_labels = list()
         for i in range(self.batchsize):
-            list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff,
-                                                                n_dims=self.n_dims))
-            list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff,
-                                                                n_dims=self.n_dims))
+            list_images.append(
+                edit_volumes.align_volume_to_ref(
+                    image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims
+                )
+            )
+            list_labels.append(
+                edit_volumes.align_volume_to_ref(
+                    labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims
+                )
+            )
         image = np.stack(list_images, axis=0)
         labels = np.stack(list_labels, axis=0)
         return np.squeeze(image), np.squeeze(labels)
@@ -212,7 +225,9 @@ def _build_model_inputs(self, n_labels):
             for idx in indices:

                 # load the label map in identity space, and add it to the inputs
-                y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4))
+                y = utils.load_volume(
+                    self.labels_paths[idx], dtype="int", aff_ref=np.eye(4)
+                )
                 list_label_maps.append(utils.add_axis(y, axis=[0, -1]))

                 # add means and standard deviations to inputs
@@ -222,35 +237,61 @@ def _build_model_inputs(self, n_labels):

                     # retrieve channel specific stats if necessary
                     if isinstance(self.prior_means, np.ndarray):
-                        if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel:
+                        if (
+                            self.prior_means.shape[0] > 2
+                        ) & self.use_specific_stats_for_channel:
                             if self.prior_means.shape[0] / 2 != self.n_channels:
-                                raise ValueError("the number of blocks in prior_means does not match n_channels. This "
-                                                 "message is printed because use_specific_stats_for_channel is True.")
-                            tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :]
+                                raise ValueError(
+                                    "the number of blocks in prior_means does not match n_channels. This "
+                                    "message is printed because use_specific_stats_for_channel is True."
+                                )
+                            tmp_prior_means = self.prior_means[
+                                2 * channel : 2 * channel + 2, :
+                            ]
                         else:
                             tmp_prior_means = self.prior_means
                     else:
                         tmp_prior_means = self.prior_means
                     if isinstance(self.prior_stds, np.ndarray):
-                        if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel:
+                        if (
+                            self.prior_stds.shape[0] > 2
+                        ) & self.use_specific_stats_for_channel:
                             if self.prior_stds.shape[0] / 2 != self.n_channels:
-                                raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
-                                                 "message is printed because use_specific_stats_for_channel is True.")
-                            tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :]
+                                raise ValueError(
+                                    "the number of blocks in prior_stds does not match n_channels. This "
+                                    "message is printed because use_specific_stats_for_channel is True."
+                                )
+                            tmp_prior_stds = self.prior_stds[
+                                2 * channel : 2 * channel + 2, :
+                            ]
                         else:
                             tmp_prior_stds = self.prior_stds
                     else:
                         tmp_prior_stds = self.prior_stds

                     # draw means and std devs from priors
-                    tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels,
-                                                                           self.prior_distributions, 125., 100.,
-                                                                           positive_only=True)
-                    tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels,
-                                                                          self.prior_distributions, 15., 10.,
-                                                                          positive_only=True)
-                    tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1])
-                    tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1])
+                    tmp_classes_means = utils.draw_value_from_distribution(
+                        tmp_prior_means,
+                        n_labels,
+                        self.prior_distributions,
+                        125.0,
+                        100.0,
+                        positive_only=True,
+                    )
+                    tmp_classes_stds = utils.draw_value_from_distribution(
+                        tmp_prior_stds,
+                        n_labels,
+                        self.prior_distributions,
+                        15.0,
+                        10.0,
+                        positive_only=True,
+                    )
+                    tmp_means = utils.add_axis(
+                        tmp_classes_means[self.generation_classes], axis=[0, -1]
+                    )
+                    tmp_stds = utils.add_axis(
+                        tmp_classes_stds[self.generation_classes], axis=[0, -1]
+                    )
                     means = np.concatenate([means, tmp_means], axis=-1)
                     stds = np.concatenate([stds, tmp_stds], axis=-1)
                 list_means.append(means)
@@ -258,7 +299,9 @@ def _build_model_inputs(self, n_labels):

             # build list of inputs for the augmentation model
             list_inputs = [list_label_maps, list_means, list_stds]
-            if self.batchsize > 1:  # concatenate individual input types if batchsize > 1
+            if (
+                self.batchsize > 1
+            ):  # concatenate individual input types if batchsize > 1
                 list_inputs = [np.concatenate(item, 0) for item in list_inputs]
             else:
                 list_inputs = [item[0] for item in list_inputs]
diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py
index 58d0e03e..11aadab7 100644
--- a/nobrainer/processing/segmentation.py
+++ b/nobrainer/processing/segmentation.py
@@ -50,7 +50,7 @@ def fit(
         metrics=metrics.dice,
         callbacks=None,
         verbose=1,
-        initial_epoch=0
+        initial_epoch=0,
     ):
         """Train a segmentation model"""
         # TODO: check validity of datasets
@@ -116,7 +116,7 @@ def _compile():
             ),
             callbacks=callbacks,
             verbose=verbose,
-            initial_epoch=initial_epoch
+            initial_epoch=initial_epoch,
         )
         return self
diff --git a/nobrainer/tfrecord.py b/nobrainer/tfrecord.py
index 0f3bbd95..a4701fde 100644
--- a/nobrainer/tfrecord.py
+++ b/nobrainer/tfrecord.py
@@ -58,7 +58,9 @@ def write(
     verbose: int, if 1, print progress bar. If 0, print nothing.
""" n_examples = len(features_labels) - shards = np.array_split(features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard)) + shards = np.array_split( + features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard) + ) # Test that the `filename_template` has a `shard` formatting key. try: @@ -77,7 +79,9 @@ def write( # This is the object that returns a protocol buffer string of the feature and label # on each iteration. It is pickle-able, unlike a generator. proto_iterators = [ - _ProtoIterator(s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions) + _ProtoIterator( + s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions + ) for s in shards ] # Set up positional arguments for the core writer function.