diff --git a/nobrainer/ext/lab2im/edit_tensors.py b/nobrainer/ext/lab2im/edit_tensors.py index 035a104b..65e72a02 100644 --- a/nobrainer/ext/lab2im/edit_tensors.py +++ b/nobrainer/ext/lab2im/edit_tensors.py @@ -31,11 +31,11 @@ from itertools import combinations # project imports -from ext.lab2im import utils +from nobrainer.ext.lab2im import utils # third-party imports -import ext.neuron.layers as nrn_layers -from ext.neuron.utils import volshape_to_meshgrid +import nobrainer.ext.neuron.layers as nrn_layers +from nobrainer.ext.neuron.utils import volshape_to_meshgrid def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None): diff --git a/nobrainer/ext/lab2im/edit_volumes.py b/nobrainer/ext/lab2im/edit_volumes.py index 1afeb34f..c8e388a5 100644 --- a/nobrainer/ext/lab2im/edit_volumes.py +++ b/nobrainer/ext/lab2im/edit_volumes.py @@ -85,9 +85,9 @@ from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter # project imports -from ext.lab2im import utils -from ext.lab2im.layers import GaussianBlur, ConvertLabels -from ext.lab2im.edit_tensors import blurring_sigma_for_downsampling +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im.layers import GaussianBlur, ConvertLabels +from nobrainer.ext.lab2im.edit_tensors import blurring_sigma_for_downsampling # ---------------------------------------------------- edit volume ----------------------------------------------------- diff --git a/nobrainer/ext/lab2im/image_generator.py b/nobrainer/ext/lab2im/image_generator.py index d48886a7..073442e8 100644 --- a/nobrainer/ext/lab2im/image_generator.py +++ b/nobrainer/ext/lab2im/image_generator.py @@ -19,9 +19,9 @@ import numpy.random as npr # project imports -from ext.lab2im import utils -from ext.lab2im import edit_volumes -from ext.lab2im.lab2im_model import lab2im_model +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import edit_volumes +from nobrainer.ext.lab2im.lab2im_model import lab2im_model class ImageGenerator: diff --git a/nobrainer/ext/lab2im/lab2im_model.py b/nobrainer/ext/lab2im/lab2im_model.py index 743626cf..f32c5b5a 100644 --- a/nobrainer/ext/lab2im/lab2im_model.py +++ b/nobrainer/ext/lab2im/lab2im_model.py @@ -20,9 +20,9 @@ from keras.models import Model # project imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import layers +from nobrainer.ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling def lab2im_model(labels_shape, diff --git a/nobrainer/ext/lab2im/layers.py b/nobrainer/ext/lab2im/layers.py index 96cbda30..e477d607 100644 --- a/nobrainer/ext/lab2im/layers.py +++ b/nobrainer/ext/lab2im/layers.py @@ -43,12 +43,12 @@ from keras.layers import Layer # project imports -from ext.lab2im import utils -from ext.lab2im import edit_tensors as l2i_et +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import edit_tensors as l2i_et # third-party imports -from ext.neuron import utils as nrn_utils -import ext.neuron.layers as nrn_layers +from nobrainer.ext.neuron import utils as nrn_utils +import nobrainer.ext.neuron.layers as nrn_layers class RandomSpatialDeformation(Layer): diff --git a/nobrainer/ext/neuron/__init__.py b/nobrainer/ext/neuron/__init__.py new file mode 100644 index 00000000..2f28f4d4 --- /dev/null +++ b/nobrainer/ext/neuron/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . 
import models +from . import utils diff --git a/nobrainer/ext/neuron/layers.py b/nobrainer/ext/neuron/layers.py new file mode 100644 index 00000000..61b46a78 --- /dev/null +++ b/nobrainer/ext/neuron/layers.py @@ -0,0 +1,435 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +or for the transformation/integration functions: + +Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration +Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu +MICCAI 2018. + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +# third party +import tensorflow as tf +from keras import backend as K +from keras.layers import Layer +from copy import deepcopy + +# local +from nobrainer.ext.neuron.utils import transform, resize, integrate_vec, affine_to_shift, combine_non_linear_and_aff_to_shift + + +class SpatialTransformer(Layer): + """ + N-D Spatial Transformer Tensorflow / Keras Layer + + The Layer can handle both affine and dense transforms. + Both transforms are meant to give a 'shift' from the current position. + Therefore, a dense transform gives displacements (not absolute locations) at each voxel, + and an affine transform gives the *difference* of the affine matrix from + the identity matrix. + + If you find this function useful, please cite: + Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration + Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu + MICCAI 2018. + + Originally, this code was based on voxelmorph code, which + was in turn transformed to be dense with the help of (affine) STN code + via https://github.com/kevinzakka/spatial-transformer-network + + Since then, we've re-written the code to be generalized to any + dimensions, and along the way wrote grid and interpolation functions + """ + + def __init__(self, + interp_method='linear', + indexing='ij', + single_transform=False, + **kwargs): + """ + Parameters: + interp_method: 'linear' or 'nearest' + single_transform: whether a single transform supplied for the whole batch + indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian) + 'xy' indexing will have the first two entries of the flow + (along last axis) flipped compared to 'ij' indexing + """ + self.interp_method = interp_method + self.ndims = None + self.inshape = None + self.single_transform = single_transform + self.is_affine = list() + + assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + self.indexing = indexing + + super(self.__class__, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["interp_method"] = self.interp_method + config["indexing"] = self.indexing + config["single_transform"] = self.single_transform + return config + + def build(self, input_shape): + """ + input_shape should be a list for two inputs: + input1: image. + input2: list of transform Tensors + if affine: + should be an N+1 x N+1 matrix + *or* a N+1*N+1 tensor (which will be reshaped to N x (N+1) and an identity row added) + if not affine: + should be a *vol_shape x N + """ + + if len(input_shape) > 3: + raise Exception('Spatial Transformer must be called on a list of min length 2 and max length 3.' 
+ 'First argument is the image followed by the affine and non linear transforms.') + + # set up number of dimensions + self.ndims = len(input_shape[0]) - 2 + self.inshape = input_shape + trf_shape = [trans_shape[1:] for trans_shape in input_shape[1:]] + + for (i, shape) in enumerate(trf_shape): + + # the transform is an affine iff: + # it's a 1D Tensor [dense transforms need to be at least ndims + 1] + # it's a 2D Tensor and shape == [N+1, N+1]. + self.is_affine.append(len(shape) == 1 or + (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape]))) + + # check sizes + if self.is_affine[i] and len(shape) == 1: + ex = self.ndims * (self.ndims + 1) + if shape[0] != ex: + raise Exception('Expected flattened affine of len %d but got %d' % (ex, shape[0])) + + if not self.is_affine[i]: + if shape[-1] != self.ndims: + raise Exception('Offset flow field size expected: %d, found: %d' % (self.ndims, shape[-1])) + + # confirm built + self.built = True + + def call(self, inputs, **kwargs): + """ + Parameters + inputs: list with several entries: the volume followed by the transforms + """ + + # check shapes + assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len(inputs) + vol = inputs[0] + trf = inputs[1:] + + # necessary for multi_gpu models... + vol = K.reshape(vol, [-1, *self.inshape[0][1:]]) + for i in range(len(trf)): + trf[i] = K.reshape(trf[i], [-1, *self.inshape[i+1][1:]]) + + # reorder transforms, non-linear first and affine second + ind_nonlinear_linear = [i[0] for i in sorted(enumerate(self.is_affine), key=lambda x:x[1])] + self.is_affine = [self.is_affine[i] for i in ind_nonlinear_linear] + self.inshape = [self.inshape[i] for i in ind_nonlinear_linear] + trf = [trf[i] for i in ind_nonlinear_linear] + + # go from affine to deformation field + if len(trf) == 1: + trf = trf[0] + if self.is_affine[0]: + trf = tf.map_fn(lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32) + # combine non-linear and affine to obtain a single deformation field + elif len(trf) == 2: + trf = tf.map_fn(lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32) + + # prepare location shift + if self.indexing == 'xy': # shift the first two dimensions + trf_split = tf.split(trf, trf.shape[-1], axis=-1) + trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]] + trf = tf.concat(trf_lst, -1) + + # map transform across batch + if self.single_transform: + return tf.map_fn(self._single_transform, [vol, trf[0, :]], dtype=tf.float32) + else: + return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32) + + def _single_aff_to_shift(self, trf, volshape): + if len(trf.shape) == 1: # go from vector to matrix + trf = tf.reshape(trf, [self.ndims, self.ndims + 1]) + return affine_to_shift(trf, volshape, shift_center=True) + + def _non_linear_and_aff_to_shift(self, trf, volshape): + if len(trf[1].shape) == 1: # go from vector to matrix + trf[1] = tf.reshape(trf[1], [self.ndims, self.ndims + 1]) + return combine_non_linear_and_aff_to_shift(trf, volshape, shift_center=True) + + def _single_transform(self, inputs): + return transform(inputs[0], inputs[1], interp_method=self.interp_method) + + +class VecInt(Layer): + """ + Vector Integration Layer + + Enables vector integration via several methods + (ode or quadrature for time-dependent vector fields, + scaling and squaring for stationary fields) + + If you find this function useful, please cite: + Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration + Adrian V. 
Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu + MICCAI 2018. + """ + + def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1, + ode_args=None, + odeint_fn=None, **kwargs): + """ + Parameters: + method can be any of the methods in neuron.utils.integrate_vec + indexing can be 'xy' (switches first two dimensions) or 'ij' + int_steps is the number of integration steps + out_time_pt is time point at which to output if using odeint integration + """ + + assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + self.indexing = indexing + self.method = method + self.int_steps = int_steps + self.inshape = None + self.out_time_pt = out_time_pt + self.odeint_fn = odeint_fn # if none then will use a tensorflow function + self.ode_args = ode_args + if ode_args is None: + self.ode_args = {'rtol': 1e-6, 'atol': 1e-12} + super(self.__class__, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["indexing"] = self.indexing + config["method"] = self.method + config["int_steps"] = self.int_steps + config["out_time_pt"] = self.out_time_pt + config["ode_args"] = self.ode_args + config["odeint_fn"] = self.odeint_fn + return config + + def build(self, input_shape): + # confirm built + self.built = True + + trf_shape = input_shape + if isinstance(input_shape[0], (list, tuple)): + trf_shape = input_shape[0] + self.inshape = trf_shape + + if trf_shape[-1] != len(trf_shape) - 2: + raise Exception('transform ndims %d does not match expected ndims %d' % (trf_shape[-1], len(trf_shape) - 2)) + + def call(self, inputs, **kwargs): + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + loc_shift = inputs[0] + + # necessary for multi_gpu models... + loc_shift = K.reshape(loc_shift, [-1, *self.inshape[1:]]) + + # prepare location shift + if self.indexing == 'xy': # shift the first two dimensions + loc_shift_split = tf.split(loc_shift, loc_shift.shape[-1], axis=-1) + loc_shift_lst = [loc_shift_split[1], loc_shift_split[0], *loc_shift_split[2:]] + loc_shift = tf.concat(loc_shift_lst, -1) + + if len(inputs) > 1: + assert self.out_time_pt is None, 'out_time_pt should be None if providing batch_based out_time_pt' + + # map transform across batch + out = tf.map_fn(self._single_int, [loc_shift] + inputs[1:], dtype=tf.float32) + return out + + def _single_int(self, inputs): + + vel = inputs[0] + out_time_pt = self.out_time_pt + if len(inputs) == 2: + out_time_pt = inputs[1] + return integrate_vec(vel, method=self.method, + nb_steps=self.int_steps, + ode_args=self.ode_args, + out_time_pt=out_time_pt, + odeint_fn=self.odeint_fn) + + +class Resize(Layer): + """ + N-D Resize Tensorflow / Keras Layer + Note: this is not re-shaping an existing volume, but resizing, like scipy's "Zoom" + + If you find this function useful, please cite: + Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,Dalca AV, Guttag J, Sabuncu MR + CVPR 2018 + + Since then, we've re-written the code to be generalized to any + dimensions, and along the way wrote grid and interpolation functions + """ + + def __init__(self, + zoom_factor=None, + size=None, + interp_method='linear', + **kwargs): + """ + Parameters: + interp_method: 'linear' or 'nearest' + 'xy' indexing will have the first two entries of the flow + (along last axis) flipped compared to 'ij' indexing + """ + self.zoom_factor = zoom_factor + self.size = list(size) + self.zoom_factor0 = None + self.size0 = None + self.interp_method = interp_method + self.ndims = 
None
+        self.inshape = None
+        super(Resize, self).__init__(**kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config["zoom_factor"] = self.zoom_factor
+        config["size"] = self.size
+        config["interp_method"] = self.interp_method
+        return config
+
+    def build(self, input_shape):
+        """
+        input_shape should be a list containing one input shape:
+        input1: volume
+                should be a *vol_shape x N
+        """
+
+        if isinstance(input_shape[0], (list, tuple)) and len(input_shape) > 1:
+            raise Exception('Resize must be called on a list of length 1.')
+
+        if isinstance(input_shape[0], (list, tuple)):
+            input_shape = input_shape[0]
+
+        # set up number of dimensions
+        self.ndims = len(input_shape) - 2
+        self.inshape = input_shape
+
+        # check zoom_factor
+        if isinstance(self.zoom_factor, float):
+            self.zoom_factor0 = [self.zoom_factor] * self.ndims
+        elif self.zoom_factor is None:
+            self.zoom_factor0 = [0] * self.ndims
+        elif isinstance(self.zoom_factor, (list, tuple)):
+            self.zoom_factor0 = deepcopy(self.zoom_factor)
+            assert len(self.zoom_factor0) == self.ndims, \
+                'zoom factor length {} does not match number of dimensions {}'.format(len(self.zoom_factor), self.ndims)
+        else:
+            raise Exception('zoom_factor should be a float or a list/tuple of floats (or None if size is provided)')
+
+        # check size
+        if isinstance(self.size, int):
+            self.size0 = [self.size] * self.ndims
+        elif self.size is None:
+            self.size0 = [0] * self.ndims
+        elif isinstance(self.size, (list, tuple)):
+            self.size0 = deepcopy(self.size)
+            assert len(self.size0) == self.ndims, \
+                'size length {} does not match number of dimensions {}'.format(len(self.size0), self.ndims)
+        else:
+            raise Exception('size should be an int or a list/tuple of ints (or None if zoom_factor is provided)')
+
+        # confirm built
+        self.built = True
+
+        super(Resize, self).build(input_shape)  # Be sure to call this somewhere!
+
+    def call(self, inputs, **kwargs):
+        """
+        Parameters
+            inputs: volume or list of one volume
+        """
+
+        # check shapes
+        if isinstance(inputs, (list, tuple)):
+            assert len(inputs) == 1, "inputs has to be len 1. found: %d" % len(inputs)
+            vol = inputs[0]
+        else:
+            vol = inputs
+
+        # necessary for multi_gpu models...
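Review note: two observations on the Resize layer above. First, __init__ runs self.size = list(size), so constructing the layer with the default size=None raises a TypeError before build() ever sees it; models.py's single_ae appears to hit exactly this by calling layers.Resize(zoom_factor=zf, name=name). Either pass an explicit size or guard the assignment (e.g. list(size) if size is not None else None). Second, a minimal usage sketch under that constraint; shapes are illustrative, and this assumes a TF/Keras version compatible with the vendored backend calls:

    import numpy as np
    import tensorflow as tf

    from nobrainer.ext.neuron.layers import Resize

    # one 3-D volume with a single channel: (batch, x, y, z, channels)
    vol = tf.constant(np.random.rand(1, 8, 8, 8, 1), dtype=tf.float32)

    # request an explicit output size; the zoom factor is derived in call()
    out = Resize(zoom_factor=None, size=[16, 16, 16], interp_method='linear')(vol)
    print(out.shape)  # expected: (1, 16, 16, 16, 1)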
+ vol = K.reshape(vol, [-1, *self.inshape[1:]]) + + # set value of missing size or zoom_factor + if not any(self.zoom_factor0): + self.zoom_factor0 = [self.size0[i] / self.inshape[i+1] for i in range(self.ndims)] + else: + self.size0 = [int(self.inshape[f+1] * self.zoom_factor0[f]) for f in range(self.ndims)] + + # map transform across batch + return tf.map_fn(self._single_resize, vol, dtype=vol.dtype) + + def compute_output_shape(self, input_shape): + + output_shape = [input_shape[0]] + output_shape += [int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims)] + output_shape += [input_shape[-1]] + return tuple(output_shape) + + def _single_resize(self, inputs): + return resize(inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method) + + +# Zoom naming of resize, to match scipy's naming +Zoom = Resize + + +######################################################### +# "Local" layers -- layers with parameters at each voxel +######################################################### + +class LocalBias(Layer): + """ + Local bias layer: each pixel/voxel has its own bias operation (one parameter) + out[v] = in[v] + b + """ + + def __init__(self, my_initializer='RandomNormal', biasmult=1.0, **kwargs): + self.initializer = my_initializer + self.biasmult = biasmult + self.kernel = None + super(LocalBias, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["my_initializer"] = self.initializer + config["biasmult"] = self.biasmult + return config + + def build(self, input_shape): + # Create a trainable weight variable for this layer. + self.kernel = self.add_weight(name='kernel', + shape=input_shape[1:], + initializer=self.initializer, + trainable=True) + super(LocalBias, self).build(input_shape) # Be sure to call this somewhere! + + def call(self, x, **kwargs): + return x + self.kernel * self.biasmult # weights are difference from input + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/nobrainer/ext/neuron/models.py b/nobrainer/ext/neuron/models.py new file mode 100644 index 00000000..9b5c87ed --- /dev/null +++ b/nobrainer/ext/neuron/models.py @@ -0,0 +1,768 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import sys + +from nobrainer.ext.neuron import layers + +# third party +import numpy as np +import tensorflow as tf +import keras +import keras.layers as KL +from keras.models import Model +import keras.backend as K + + +def unet(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name='unet', + prefix=None, + feat_mult=1, + pool_size=2, + use_logp=True, + padding='same', + dilation_rate_mult=1, + activation='elu', + skip_n_concatenations=0, + use_residuals=False, + final_pred_activation='softmax', + nb_conv_per_level=1, + add_prior_layer=False, + layer_nb_feats=None, + conv_dropout=0, + batch_norm=None, + input_model=None): + """ + unet-style keras model with an overdose of parametrization. 
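Review note: models.py lands here without a usage example, so a minimal construction sketch for the unet entry point may help reviewers; all argument values below are illustrative rather than defaults taken from the diff:

    from nobrainer.ext.neuron.models import unet

    # 3-D U-Net: 16 features at the top level, 3 resolution levels, 4 output labels
    model = unet(nb_features=16,
                 input_shape=(32, 32, 32, 1),  # *vol_shape + (nb_channels,)
                 nb_levels=3,
                 conv_size=3,
                 nb_labels=4)
    model.summary()

Spatial dimensions should be divisible by pool_size ** (nb_levels - 1) so the decoder's upsampled tensors line up with the encoder tensors they are concatenated with.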
+ + Parameters: + nb_features: the number of features at each convolutional level + see below for `feat_mult` and `layer_nb_feats` for modifiers to this number + input_shape: input layer shape, vector of size ndims + 1 (nb_channels) + conv_size: the convolution kernel size + nb_levels: the number of Unet levels (number of downsamples) in the "encoder" + (e.g. 4 would give you 4 levels in encoder, 4 in decoder) + nb_labels: number of output channels + name (default: 'unet'): the name of the network + prefix (default: `name` value): prefix to be added to layer names + feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels. + e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the + second layer, 64 features in the third layer, etc. + pool_size (default: 2): max pooling size (integer or list if specifying per dimension) + skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n + top levels. + use_logp: + padding: + dilation_rate_mult: + activation: + use_residuals: + final_pred_activation: + nb_conv_per_level: + add_prior_layer: + skip_n_concatenations: + layer_nb_feats: list of the number of features for each layer. Automatically used if specified + conv_dropout: dropout probability + batch_norm: + input_model: concatenate the provided input_model to this current model. + Only the first output of input_model is used. + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + layer_nb_feats=layer_nb_feats, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model) + + # get decoder + # use_skip_connections=True makes it a u-net + lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None + dec_model = conv_dec(nb_features, + [], + nb_levels, + conv_size, + nb_labels, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=True, + skip_n_concatenations=skip_n_concatenations, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + final_pred_activation='linear' if add_prior_layer else final_pred_activation, + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + layer_nb_feats=lnf, + conv_dropout=conv_dropout, + input_model=enc_model) + final_model = dec_model + + if add_prior_layer: + final_model = add_prior(dec_model, + [*input_shape[:-1], nb_labels], + name=model_name + '_prior', + use_logp=use_logp, + final_pred_activation=final_pred_activation) + + return final_model + + +def ae(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + enc_size, + name='ae', + feat_mult=1, + pool_size=2, + padding='same', + activation='elu', + use_residuals=False, + nb_conv_per_level=1, + batch_norm=None, + enc_batch_norm=None, + ae_type='conv', # 'dense', or 'conv' + enc_lambda_layers=None, + add_prior_layer=False, + use_logp=True, + conv_dropout=0, + include_mu_shift_layer=False, + single_model=False, # whether to return a single model, or a 
tuple of models that can be stacked. + final_pred_activation='softmax', + do_vae=False, + input_model=None): + """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).""" + + # naming + model_name = name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model) + + # middle AE structure + if single_model: + in_input_shape = None + in_model = enc_model + else: + in_input_shape = enc_model.output.shape.as_list()[1:] + in_model = None + mid_ae_model = single_ae(enc_size, + in_input_shape, + conv_size=conv_size, + name=model_name, + ae_type=ae_type, + input_model=in_model, + batch_norm=enc_batch_norm, + enc_lambda_layers=enc_lambda_layers, + include_mu_shift_layer=include_mu_shift_layer, + do_vae=do_vae) + + # decoder + if single_model: + in_input_shape = None + in_model = mid_ae_model + else: + in_input_shape = mid_ae_model.output.shape.as_list()[1:] + in_model = None + dec_model = conv_dec(nb_features, + in_input_shape, + nb_levels, + conv_size, + nb_labels, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=False, + padding=padding, + activation=activation, + use_residuals=use_residuals, + final_pred_activation='linear', + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + conv_dropout=conv_dropout, + input_model=in_model) + + if add_prior_layer: + dec_model = add_prior(dec_model, + [*input_shape[:-1], nb_labels], + name=model_name, + prefix=model_name + '_prior', + use_logp=use_logp, + final_pred_activation=final_pred_activation) + + if single_model: + return dec_model + else: + return dec_model, mid_ae_model, enc_model + + +def conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + dilation_rate_mult=1, + padding='same', + activation='elu', + layer_nb_feats=None, + use_residuals=False, + nb_conv_per_level=2, + conv_dropout=0, + batch_norm=None, + input_model=None): + """Fully Convolutional Encoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # first layer: input + name = '%s_input' % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=name) + last_tensor = input_tensor + else: + input_tensor = input_model.inputs + last_tensor = input_model.outputs + if isinstance(last_tensor, list): + last_tensor = last_tensor[0] + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, 'Conv%dD' % ndims) + conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'} + maxpool = getattr(KL, 'MaxPooling%dD' % ndims) + + # down arm: + # add nb_levels of conv + ReLu + conv + ReLu. 
Pool after each of first nb_levels - 1 layers + lfidx = 0 # level feature index + for level in range(nb_levels): + lvl_first_tensor = last_tensor + nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int) + conv_kwargs['dilation_rate'] = dilation_rate_mult ** level + + for conv in range(nb_conv_per_level): # does several conv per level, max pooling applied at the end + if layer_nb_feats is not None: # None or List of all the feature numbers + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = '%s_conv_downarm_%d_%d' % (prefix, level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + else: # no activation + last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + + if conv_dropout > 0: + # conv dropout along feature space only + name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + + if use_residuals: + convarm_layer = last_tensor + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + nb_feats_in = lvl_first_tensor.get_shape()[-1] + nb_feats_out = convarm_layer.get_shape()[-1] + add_layer = lvl_first_tensor + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = '%s_expand_down_merge_%d' % (prefix, level) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor) + add_layer = last_tensor + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + + name = '%s_res_down_merge_%d' % (prefix, level) + last_tensor = KL.add([add_layer, convarm_layer], name=name) + + name = '%s_res_down_merge_act_%d' % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_down_%d' % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # max pool if we're not at the last level + if level < (nb_levels - 1): + name = '%s_maxpool_%d' % (prefix, level) + last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +def conv_dec(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + use_skip_connections=False, + skip_n_concatenations=0, + padding='same', + dilation_rate_mult=1, + activation='elu', + use_residuals=False, + final_pred_activation='softmax', + nb_conv_per_level=2, + layer_nb_feats=None, + batch_norm=None, + conv_dropout=0, + input_model=None): + """Fully Convolutional Decoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # if using skip connections, make sure need to use them. 
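Review note: conv_dec below wires skip connections purely by layer name, fetching encoder outputs via input_model.get_layer(conv_name); a short sketch of the naming contract it assumes (prefix, level count, and convs per level must match conv_enc exactly):

    prefix, nb_levels, nb_conv_per_level = 'unet', 3, 2

    for level in range(nb_levels - 1):
        # encoder layer whose output the decoder concatenates at this up-arm level
        name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
        print(name)
    # unet_conv_downarm_1_1
    # unet_conv_downarm_0_1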
+ if use_skip_connections: + assert input_model is not None, "is using skip connections, tensors dictionary is required" + + # first layer: input + input_name = '%s_input' % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + + # vol size info + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + if ndims > 1: + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, 'Conv%dD' % ndims) + conv_kwargs = {'padding': padding, 'activation': activation} + upsample = getattr(KL, 'UpSampling%dD' % ndims) + + # up arm: + # nb_levels - 1 layers of Deconvolution3D + # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu + lfidx = 0 + for level in range(nb_levels - 1): + nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int) + conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level) + + # upsample matching the max pooling layers size + name = '%s_up_%d' % (prefix, nb_levels + level) + last_tensor = upsample(size=pool_size, name=name)(last_tensor) + up_tensor = last_tensor + + # merge layers combining previous layer + if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)): + conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1) + cat_tensor = input_model.get_layer(conv_name).output + name = '%s_merge_%d' % (prefix, nb_levels + level) + last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name) + + # convolution layers + for conv in range(nb_conv_per_level): + if layer_nb_feats is not None: + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + else: + last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + + if conv_dropout > 0: + name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + + # residual block + if use_residuals: + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + add_layer = up_tensor + nb_feats_in = add_layer.get_shape()[-1] + nb_feats_out = last_tensor.get_shape()[-1] + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = '%s_expand_up_merge_%d' % (prefix, level) + add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer) + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + + name = '%s_res_up_merge_%d' % (prefix, level) + last_tensor = KL.add([last_tensor, add_layer], name=name) + + name = '%s_res_up_merge_act_%d' % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_up_%d' % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # Compute likelihood prediction (no activation yet) + name = '%s_likelihood' % prefix + last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) + 
like_tensor = last_tensor + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + if final_pred_activation == 'softmax': + name = '%s_prediction' % prefix + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor) + + # otherwise create a layer that does nothing. + else: + name = '%s_prediction' % prefix + pred_tensor = KL.Activation('linear', name=name)(like_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) + return model + + +def add_prior(input_model, + prior_shape, + name='prior_model', + prefix=None, + use_logp=True, + final_pred_activation='softmax'): + """ + Append post-prior layer to a given model + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # prior input layer + prior_input_name = '%s-input' % prefix + prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name) + prior_tensor_input = prior_tensor + like_tensor = input_model.output + + # operation varies depending on whether we log() prior or not. + if use_logp: + print("Breaking change: use_logp option now requires log input!", file=sys.stderr) + merge_op = KL.add + + else: + # using sigmoid to get the likelihood values between 0 and 1 + # note: they won't add up to 1. + name = '%s_likelihood_sigmoid' % prefix + like_tensor = KL.Activation('sigmoid', name=name)(like_tensor) + merge_op = KL.multiply + + # merge the likelihood and prior layers into posterior layer + name = '%s_posterior' % prefix + post_tensor = merge_op([prior_tensor, like_tensor], name=name) + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + pred_name = '%s_prediction' % prefix + if final_pred_activation == 'softmax': + assert use_logp, 'cannot do softmax when adding prior via P()' + print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor) + + else: + pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor) + + # create the model + model_inputs = [*input_model.inputs, prior_tensor_input] + model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name) + + # compile + return model + + +def single_ae(enc_size, + input_shape, + name='single_ae', + prefix=None, + ae_type='dense', # 'dense', or 'conv' + conv_size=None, + input_model=None, + enc_lambda_layers=None, + batch_norm=True, + padding='same', + activation=None, + include_mu_shift_layer=False, + do_vae=False): + """single-layer Autoencoder (i.e. 
input - encoding - output""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + if enc_lambda_layers is None: + enc_lambda_layers = [] + + # prepare input + input_name = '%s_input' % prefix + if input_model is None: + assert input_shape is not None, 'input_shape of input_model is necessary' + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + input_nb_feats = last_tensor.shape.as_list()[-1] + + # prepare conv type based on input + ndims = len(input_shape) - 1 + if ae_type == 'conv': + convL = getattr(KL, 'Conv%dD' % ndims) + assert conv_size is not None, 'with conv ae, need conv_size' + conv_kwargs = {'padding': padding, 'activation': activation} + enc_size_str = None + + # if want to go through a dense layer in the middle of the U, need to: + # - flatten last layer if not flat + # - do dense encoding and decoding + # - unflatten (reshape spatially) at end + else: # ae_type == 'dense' + if len(input_shape) > 1: + name = '%s_ae_%s_down_flat' % (prefix, ae_type) + last_tensor = KL.Flatten(name=name)(last_tensor) + convL = conv_kwargs = None + assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer" + enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1] + + # recall this layer + pre_enc_layer = last_tensor + + # encoding layer + if ae_type == 'dense': + name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: # convolution + + # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats] + assert len(enc_size) == len(input_shape), \ + "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape)) + + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + + name = '%s_ae_mu_enc_conv' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + name = '%s_ae_mu_enc' % prefix + zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = '%s_ae_mu_enc' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = '%s_ae_mu_enc' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + if include_mu_shift_layer: + # shift + name = '%s_ae_mu_shift' % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = '%s_ae_mu_%s' % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_ae_mu_bn' % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = '%s_ae_mu' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + # if doing variational AE, will need the sigma layer as well. 
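Review note: the do_vae branch below bottoms out in _VAESample.sample_z, which is the standard reparameterization trick; as a standalone numpy sanity check of that single step (shapes arbitrary):

    import numpy as np

    rng = np.random.default_rng(0)
    mu = np.zeros(4)
    log_var = np.log(np.full(4, 0.25))  # i.e. sigma = 0.5

    eps = rng.standard_normal(4)
    z = mu + np.exp(log_var / 2) * eps  # mirrors sample_z: mu + exp(log_var / 2) * eps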
+ if do_vae: + mu_tensor = last_tensor + + # encoding layer + if ae_type == 'dense': + name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + + assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..." + name = '%s_ae_sigma_enc_conv' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + name = '%s_ae_sigma_enc' % prefix + resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) + last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = '%s_ae_sigma_enc' % prefix + last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer) + # cannot use lambda, then mu and sigma will be same layer. + # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = '%s_ae_sigma_enc' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = '%s_ae_sigma_%s' % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_ae_sigma_bn' % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = '%s_ae_sigma' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + logvar_tensor = last_tensor + + # VAE sampling + sampler = _VAESample().sample_z + + name = '%s_ae_sample' % prefix + last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor]) + + if include_mu_shift_layer: + # shift + name = '%s_ae_sample_shift' % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # decoding layer + if ae_type == 'dense': + name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str) + last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor) + + # unflatten if dense method + if len(input_shape) > 1: + name = '%s_ae_%s_dec' % (prefix, ae_type) + last_tensor = KL.Reshape(input_shape, name=name)(last_tensor) + + else: + + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + name = '%s_ae_mu_dec' % prefix + zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + name = '%s_ae_%s_dec' % (prefix, ae_type) + last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_ae_%s_dec' % (prefix, ae_type) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +############################################################################### +# Helper function +############################################################################### + +class _VAESample: + def __init__(self): + pass + + def sample_z(self, args): + mu, log_var = args + shape = 
K.shape(mu) + eps = K.random_normal(shape=shape, mean=0., stddev=1.) + return mu + K.exp(log_var / 2) * eps diff --git a/nobrainer/ext/neuron/utils.py b/nobrainer/ext/neuron/utils.py new file mode 100644 index 00000000..1162b51c --- /dev/null +++ b/nobrainer/ext/neuron/utils.py @@ -0,0 +1,548 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +or for the transformation/interpolation related functions: + +Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration +Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu +MICCAI 2018. + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import itertools +import numpy as np +import tensorflow as tf +import keras.backend as K + + +def interpn(vol, loc, interp_method='linear'): + """ + N-D gridded interpolation in tensorflow + + vol can have more dimensions than loc[i], in which case loc[i] acts as a slice + for the first dimensions + + Parameters: + vol: volume with size vol_shape or [*vol_shape, nb_features] + loc: an N-long list of N-D Tensors (the interpolation locations) for the new grid + each tensor has to have the same size (but not nec. same size as vol) + or a tensor of size [*new_vol_shape, D] + interp_method: interpolation type 'linear' (default) or 'nearest' + + Returns: + new interpolated volume of the same size as the entries in loc + """ + + if isinstance(loc, (list, tuple)): + loc = tf.stack(loc, -1) + nb_dims = loc.shape[-1] + + if len(vol.shape) not in [nb_dims, nb_dims + 1]: + raise Exception("Number of loc Tensors %d does not match volume dimension %d" + % (nb_dims, len(vol.shape[:-1]))) + + if nb_dims > len(vol.shape): + raise Exception("Loc dimension %d does not match volume dimension %d" + % (nb_dims, len(vol.shape))) + + if len(vol.shape) == nb_dims: + vol = K.expand_dims(vol, -1) + + # flatten and float location Tensors + loc = tf.cast(loc, 'float32') + + if isinstance(vol.shape, tf.TensorShape): + volshape = vol.shape.as_list() + else: + volshape = vol.shape + + # interpolate + if interp_method == 'linear': + loc0 = tf.floor(loc) + + # clip values + max_loc = [d - 1 for d in vol.get_shape().as_list()] + clipped_loc = [tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims)] + loc0lst = [tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims)] + + # get other end of point cube + loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)] + locs = [[tf.cast(f, 'int32') for f in loc0lst], [tf.cast(f, 'int32') for f in loc1]] + + # compute the difference between the upper value and the original value + # differences are basically 1 - (pt - floor(pt)) + # because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt)) + diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)] + diff_loc0 = [1 - d for d in diff_loc1] + weights_loc = [diff_loc1, diff_loc0] # note reverse ordering since weights are inverse of diff. + + # go through all the cube corners, indexed by a ND binary vector + # e.g. 
[0, 0] means this "first" corner in a 2-D "cube" + cube_pts = list(itertools.product([0, 1], repeat=nb_dims)) + interp_vol = 0 + + for c in cube_pts: + # get nd values + # note re: indices above volumes via https://github.com/tensorflow/tensorflow/issues/15091 + # It works on GPU because we do not perform index validation checking on GPU -- it's too + # expensive. Instead we fill the output with zero for the corresponding value. The CPU + # version caught the bad index and returned the appropriate error. + subs = [locs[c[d]][d] for d in range(nb_dims)] + + idx = sub2ind(vol.shape[:-1], subs) + vol_val = tf.gather(tf.reshape(vol, [-1, volshape[-1]]), idx) + + # get the weight of this cube_pt based on the distance + # if c[d] is 0 --> want weight = 1 - (pt - floor[pt]) = diff_loc1 + # if c[d] is 1 --> want weight = pt - floor[pt] = diff_loc0 + wts_lst = [weights_loc[c[d]][d] for d in range(nb_dims)] + wt = prod_n(wts_lst) + wt = K.expand_dims(wt, -1) + + # compute final weighted value for each cube corner + interp_vol += wt * vol_val + + else: + assert interp_method == 'nearest' + roundloc = tf.cast(tf.round(loc), 'int32') + + # clip values + max_loc = [tf.cast(d - 1, 'int32') for d in vol.shape] + roundloc = [tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims)] + + # get values + idx = sub2ind(vol.shape[:-1], roundloc) + interp_vol = tf.gather(tf.reshape(vol, [-1, vol.shape[-1]]), idx) + + return interp_vol + + +def resize(vol, zoom_factor, new_shape, interp_method='linear'): + """ + if zoom_factor is a list, it will determine the ndims, in which case vol has to be of length ndims or ndims + 1 + + if zoom_factor is an integer, then vol must be of length ndims + 1 + + new_shape should be a list of length ndims + + """ + + if isinstance(zoom_factor, (list, tuple)): + ndims = len(zoom_factor) + vol_shape = vol.shape[:ndims] + assert len(vol_shape) in (ndims, ndims + 1), \ + "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims) + else: + vol_shape = vol.shape[:-1] + ndims = len(vol_shape) + zoom_factor = [zoom_factor] * ndims + + # get grid for new shape + grid = volshape_to_ndgrid(new_shape) + grid = [tf.cast(f, 'float32') for f in grid] + offset = [grid[f] / zoom_factor[f] - grid[f] for f in range(ndims)] + offset = tf.stack(offset, ndims) + + # transform + return transform(vol, offset, interp_method) + + +zoom = resize + + +def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'): + """ + transform an affine matrix to a dense location shift tensor in tensorflow + + Algorithm: + - get grid and shift grid to be centered at the center of the image (optionally) + - apply affine matrix to each index. + - subtract grid + + Parameters: + affine_matrix: ND+1 x ND+1 or ND x ND+1 matrix (Tensor) + volshape: 1xN Nd Tensor of the size of the volume. + shift_center (optional) + indexing + + Returns: + shift field (Tensor) of size *volshape x N + """ + + if isinstance(volshape, tf.TensorShape): + volshape = volshape.as_list() + + if affine_matrix.dtype != 'float32': + affine_matrix = tf.cast(affine_matrix, 'float32') + + nb_dims = len(volshape) + + if len(affine_matrix.shape) == 1: + if len(affine_matrix) != (nb_dims * (nb_dims + 1)): + raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).' 
+ 'Got len %d' % len(affine_matrix)) + + affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1]) + + if not (affine_matrix.shape[0] in [nb_dims, nb_dims + 1] and affine_matrix.shape[1] == (nb_dims + 1)): + raise Exception('Affine matrix shape should match' + '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + + '%d x %d+1.' % (nb_dims, nb_dims) + + 'Got: ' + str(volshape)) + + # list of volume ndgrid + # N-long list, each entry of shape volshape + mesh = volshape_to_meshgrid(volshape, indexing=indexing) + mesh = [tf.cast(f, 'float32') for f in mesh] + + if shift_center: + mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] + + # add an all-ones entry and transform into a large matrix + flat_mesh = [flatten(f) for f in mesh] + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32')) + mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # 4 x nb_voxels + + # compute locations + loc_matrix = tf.matmul(affine_matrix, mesh_matrix) # N+1 x nb_voxels + loc_matrix = tf.transpose(loc_matrix[:nb_dims, :]) # nb_voxels x N + loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims]) # *volshape x N + + # get shifts and return + return loc - tf.stack(mesh, axis=nb_dims) + + +def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=True, indexing='ij'): + """ + transform an affine matrix to a dense location shift tensor in tensorflow + + Algorithm: + - get grid and shift grid to be centered at the center of the image (optionally) + - apply affine matrix to each index. + - subtract grid + + Parameters: + transform_list: list of non-linear tensor (size of volshape) and affine ND+1 x ND+1 or ND x ND+1 tensor + volshape: 1xN Nd Tensor of the size of the volume. + shift_center (optional) + indexing + + Returns: + shift field (Tensor) of size *volshape x N + """ + + if isinstance(volshape, tf.TensorShape): + volshape = volshape.as_list() + + # convert transforms to floats + for i in range(len(transform_list)): + if transform_list[i].dtype != 'float32': + transform_list[i] = tf.cast(transform_list[i], 'float32') + + nb_dims = len(volshape) + + # transform affine to matrix if given as vector + if len(transform_list[1].shape) == 1: + if len(transform_list[1]) != (nb_dims * (nb_dims + 1)): + raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).' + 'Got len %d' % len(transform_list[1])) + + transform_list[1] = tf.reshape(transform_list[1], [nb_dims, nb_dims + 1]) + + if not (transform_list[1].shape[0] in [nb_dims, nb_dims + 1] and transform_list[1].shape[1] == (nb_dims + 1)): + raise Exception('Affine matrix shape should match' + '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + + '%d x %d+1.' 
% (nb_dims, nb_dims) +
+                        'Got: ' + str(volshape))
+
+    # list of volume ndgrid
+    # N-long list, each entry of shape volshape
+    mesh = volshape_to_meshgrid(volshape, indexing=indexing)
+    mesh = [tf.cast(f, 'float32') for f in mesh]
+
+    if shift_center:
+        mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))]
+
+    # add an all-ones entry and transform into a large matrix
+    # non_linear_mesh = tf.unstack(transform_list[0], axis=3)
+    non_linear_mesh = tf.unstack(transform_list[0], axis=-1)
+    flat_mesh = [flatten(mesh[i] + non_linear_mesh[i]) for i in range(len(mesh))]
+    flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32'))
+    mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1))  # N+1 x nb_voxels
+
+    # compute locations
+    loc_matrix = tf.matmul(transform_list[1], mesh_matrix)  # N+1 x nb_voxels
+    loc_matrix = tf.transpose(loc_matrix[:nb_dims, :])  # nb_voxels x N
+    loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims])  # *volshape x N
+
+    # get shifts and return
+    return loc - tf.stack(mesh, axis=nb_dims)
+
+
+def transform(vol, loc_shift, interp_method='linear', indexing='ij'):
+    """
+    transform interpolation N-D volumes (features) given shifts at each location in tensorflow
+
+    Essentially interpolates volume vol at locations determined by loc_shift.
+    This is a spatial transform in the sense that at location [x] we now have the data from
+    [x + shift], so we've moved data.
+
+    Parameters:
+        vol: volume with size vol_shape or [*vol_shape, nb_features]
+        loc_shift: shift volume [*new_vol_shape, N]
+        interp_method (default:'linear'): 'linear', 'nearest'
+        indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
+            In general, prefer to leave this 'ij'
+
+    Return:
+        new interpolated volumes in the same size as loc_shift[0]
+    """
+
+    # parse shapes
+    if isinstance(loc_shift.shape, tf.TensorShape):
+        volshape = loc_shift.shape[:-1].as_list()
+    else:
+        volshape = loc_shift.shape[:-1]
+    nb_dims = len(volshape)
+
+    # location should be meshed and delta
+    mesh = volshape_to_meshgrid(volshape, indexing=indexing)  # volume mesh
+    loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]
+
+    # test single
+    return interpn(vol, loc, interp_method=interp_method)
+
+
+def integrate_vec(vec, time_dep=False, method='ss', **kwargs):
+    """
+    Integrate (stationary or time-dependent) vector field (N-D Tensor) in tensorflow
+
+    Aside from directly using tensorflow's numerical integration odeint(), also implements
+    "scaling and squaring", and quadrature. Note that the diff. equation given to odeint
+    is the one used in quadrature.
+
+    Parameters:
+        vec: the Tensor field to integrate.
+            If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size),
+            then vector shape (vec_shape) should be
+            [vol_size, vol_ndim] (if stationary)
+            [vol_size, vol_ndim, nb_time_steps] (if time dependent)
+        time_dep: bool whether vector is time dependent
+        method: 'scaling_and_squaring' or 'ss' or 'ode' or 'quadrature'
+
+        if using 'scaling_and_squaring': currently only supports integrating to time point 1.
+        nb_steps: int number of steps. Note that this means the vec field gets broken down to 2**nb_steps.
+            so nb_steps of 0 means integral = vec.
+
+    Returns:
+        int_vec: integral of vector field with same shape as the input
+    """
+
+    if method not in ['ss', 'scaling_and_squaring', 'ode', 'quadrature']:
+        raise ValueError("method has to be 'ss', 'scaling_and_squaring', 'ode' or 'quadrature'.
found: %s" % method) + + if method in ['ss', 'scaling_and_squaring']: + nb_steps = kwargs['nb_steps'] + assert nb_steps >= 0, 'nb_steps should be >= 0, found: %d' % nb_steps + + if time_dep: + svec = K.permute_dimensions(vec, [-1, *range(0, vec.shape[-1] - 1)]) + assert 2 ** nb_steps == svec.shape[0], "2**nb_steps and vector shape don't match" + + svec = svec / (2 ** nb_steps) + for _ in range(nb_steps): + svec = svec[0::2] + tf.map_fn(transform, svec[1::2, :], svec[0::2, :]) + + disp = svec[0, :] + + else: + vec = vec / (2 ** nb_steps) + for _ in range(nb_steps): + vec += transform(vec, vec) + disp = vec + + else: # method == 'quadrature': + nb_steps = kwargs['nb_steps'] + assert nb_steps >= 1, 'nb_steps should be >= 1, found: %d' % nb_steps + + vec = vec / nb_steps + + if time_dep: + disp = vec[..., 0] + for si in range(nb_steps - 1): + disp += transform(vec[..., si + 1], disp) + else: + disp = vec + for _ in range(nb_steps - 1): + disp += transform(vec, disp) + + return disp + + +def volshape_to_ndgrid(volshape, **kwargs): + """ + compute Tensor ndgrid from a volume size + + Parameters: + volshape: the volume size + + Returns: + A list of Tensors + + See Also: + ndgrid + """ + + isint = [float(d).is_integer() for d in volshape] + if not all(isint): + raise ValueError("volshape needs to be a list of integers") + + linvec = [tf.range(0, d) for d in volshape] + return ndgrid(*linvec, **kwargs) + + +def volshape_to_meshgrid(volshape, **kwargs): + """ + compute Tensor meshgrid from a volume size + + Parameters: + volshape: the volume size + + Returns: + A list of Tensors + + See Also: + tf.meshgrid, meshgrid, ndgrid, volshape_to_ndgrid + """ + + isint = [float(d).is_integer() for d in volshape] + if not all(isint): + raise ValueError("volshape needs to be a list of integers") + + linvec = [tf.range(0, d) for d in volshape] + return meshgrid(*linvec, **kwargs) + + +def ndgrid(*args, **kwargs): + """ + broadcast Tensors on an N-D grid with ij indexing + uses meshgrid with ij indexing + + Parameters: + *args: Tensors with rank 1 + **args: "name" (optional) + + Returns: + A list of Tensors + + """ + return meshgrid(*args, indexing='ij', **kwargs) + + +def meshgrid(*args, **kwargs): + """ + + meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically + improves runtime by changing the last step to tiling instead of multiplication. + https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921 + + Broadcasts parameters for evaluation on an N-D grid. + Given N one-dimensional coordinate arrays `*args`, returns a list `outputs` + of N-D coordinate arrays for evaluating expressions on an N-D grid. + Notes: + `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions. + When the `indexing` argument is set to 'xy' (the default), the broadcasting + instructions for the first two dimensions are swapped. + Examples: + Calling `X, Y = meshgrid(x, y)` with the tensors + ```python + x = [1, 2, 3] + y = [4, 5, 6] + X, Y = meshgrid(x, y) + # X = [[1, 2, 3], + # [1, 2, 3], + # [1, 2, 3]] + # Y = [[4, 4, 4], + # [5, 5, 5], + # [6, 6, 6]] + ``` + Args: + *args: `Tensor`s with rank 1. + **kwargs: + - indexing: Either 'xy' or 'ij' (optional, default: 'xy'). + - name: A name for the operation (optional). + Returns: + outputs: A list of N `Tensor`s with rank N. + Raises: + TypeError: When no keyword arguments (kwargs) are passed. + ValueError: When indexing keyword argument is not one of `xy` or `ij`. 
+ """ + + indexing = kwargs.pop("indexing", "xy") + if kwargs: + key = list(kwargs.keys())[0] + raise TypeError("'{}' is an invalid keyword argument " + "for this function".format(key)) + + if indexing not in ("xy", "ij"): + raise ValueError("indexing parameter must be either 'xy' or 'ij'") + + # with ops.name_scope(name, "meshgrid", args) as name: + ndim = len(args) + s0 = (1,) * ndim + + # Prepare reshape by inserting dimensions with size 1 where needed + output = [] + for i, x in enumerate(args): + output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::]))) + # Create parameters for broadcasting each tensor to the full size + shapes = [tf.size(x) for x in args] + sz = [x.get_shape().as_list()[0] for x in args] + + # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype + if indexing == "xy" and ndim > 1: + output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2)) + output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2)) + shapes[0], shapes[1] = shapes[1], shapes[0] + sz[0], sz[1] = sz[1], sz[0] + + for i in range(len(output)): + stack_sz = [*sz[:i], 1, *sz[(i + 1):]] + if indexing == 'xy' and ndim > 1 and i < 2: + stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0] + output[i] = tf.tile(output[i], tf.stack(stack_sz)) + return output + + +def flatten(v): + """flatten Tensor v""" + + return tf.reshape(v, [-1]) + + +def prod_n(lst): + prod = lst[0] + for p in lst[1:]: + prod *= p + return prod + + +def sub2ind(siz, subs): + """assumes column-order major""" + # subs is a list + assert len(siz) == len(subs), 'found inconsistent siz and subs: %d %d' % (len(siz), len(subs)) + + k = np.cumprod(siz[::-1]) + + ndx = subs[-1] + for i, v in enumerate(subs[:-1][::-1]): + ndx = ndx + v * k[i] + + return ndx diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py index 68a797d3..645bf009 100644 --- a/nobrainer/models/__init__.py +++ b/nobrainer/models/__init__.py @@ -12,6 +12,7 @@ from .progressivegan import progressivegan from .unet import unet from .unetr import unetr +from .lab2im_model import lab2im_model __all__ = ["get", "list_available_models"] @@ -27,7 +28,8 @@ "attention_unet_with_inception": attention_unet_with_inception, "unetr": unetr, "variational_meshnet": variational_meshnet, - "bayesian_vnet": bayesian_vnet + "bayesian_vnet": bayesian_vnet, + "synth_generator": lab2im_model } diff --git a/nobrainer/models/lab2im_model.py b/nobrainer/models/lab2im_model.py index 743626cf..f32c5b5a 100644 --- a/nobrainer/models/lab2im_model.py +++ b/nobrainer/models/lab2im_model.py @@ -20,9 +20,9 @@ from keras.models import Model # project imports -from ext.lab2im import utils -from ext.lab2im import layers -from ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import layers +from nobrainer.ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling def lab2im_model(labels_shape,