Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 14, 2024
1 parent 14d11ca commit af67744
Showing 17 changed files with 4,189 additions and 1,959 deletions.
95 changes: 65 additions & 30 deletions nobrainer/ext/SynthSeg/model_inputs.py
@@ -13,7 +13,6 @@
License.
"""


# python imports
import numpy as np
import numpy.random as npr
@@ -22,17 +21,19 @@
from nobrainer.ext.lab2im import utils


def build_model_inputs(path_label_maps,
n_labels,
batchsize=1,
n_channels=1,
subjects_prob=None,
generation_classes=None,
prior_distributions='uniform',
prior_means=None,
prior_stds=None,
use_specific_stats_for_channel=False,
mix_prior_and_random=False):
def build_model_inputs(
path_label_maps,
n_labels,
batchsize=1,
n_channels=1,
subjects_prob=None,
generation_classes=None,
prior_distributions="uniform",
prior_means=None,
prior_stds=None,
use_specific_stats_for_channel=False,
mix_prior_and_random=False,
):
"""
This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the
input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch).
@@ -86,7 +87,9 @@ def build_model_inputs(path_label_maps,
while True:

# randomly pick as many images as batchsize
indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob)
indices = npr.choice(
np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob
)

# initialise input lists
list_label_maps = []
@@ -96,8 +99,10 @@
for idx in indices:

# load input label map
lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4))
if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]):
lab = utils.load_volume(
path_label_maps[idx], dtype="int", aff_ref=np.eye(4)
)
if (npr.uniform() > 0.7) & ("seg_cerebral" in path_label_maps[idx]):
lab[lab == 24] = 0

# add label map to inputs
@@ -112,42 +117,72 @@
if isinstance(prior_means, np.ndarray):
if (prior_means.shape[0] > 2) & use_specific_stats_for_channel:
if prior_means.shape[0] / 2 != n_channels:
raise ValueError("the number of blocks in prior_means does not match n_channels. This "
"message is printed because use_specific_stats_for_channel is True.")
tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :]
raise ValueError(
"the number of blocks in prior_means does not match n_channels. This "
"message is printed because use_specific_stats_for_channel is True."
)
tmp_prior_means = prior_means[2 * channel : 2 * channel + 2, :]
else:
tmp_prior_means = prior_means
else:
tmp_prior_means = prior_means
if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
if (
(prior_means is not None)
& mix_prior_and_random
& (npr.uniform() > 0.5)
):
tmp_prior_means = None
if isinstance(prior_stds, np.ndarray):
if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel:
if prior_stds.shape[0] / 2 != n_channels:
raise ValueError("the number of blocks in prior_stds does not match n_channels. This "
"message is printed because use_specific_stats_for_channel is True.")
tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :]
raise ValueError(
"the number of blocks in prior_stds does not match n_channels. This "
"message is printed because use_specific_stats_for_channel is True."
)
tmp_prior_stds = prior_stds[2 * channel : 2 * channel + 2, :]
else:
tmp_prior_stds = prior_stds
else:
tmp_prior_stds = prior_stds
if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5):
if (
(prior_stds is not None)
& mix_prior_and_random
& (npr.uniform() > 0.5)
):
tmp_prior_stds = None

# draw means and std devs from priors
tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions,
125., 125., positive_only=True)
tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions,
15., 15., positive_only=True)
tmp_classes_means = utils.draw_value_from_distribution(
tmp_prior_means,
n_classes,
prior_distributions,
125.0,
125.0,
positive_only=True,
)
tmp_classes_stds = utils.draw_value_from_distribution(
tmp_prior_stds,
n_classes,
prior_distributions,
15.0,
15.0,
positive_only=True,
)
random_coef = npr.uniform()
if random_coef > 0.95: # reset the background to 0 in 5% of cases
tmp_classes_means[0] = 0
tmp_classes_stds[0] = 0
elif random_coef > 0.7: # reset the background to low Gaussian in 25% of cases
elif (
random_coef > 0.7
): # reset the background to low Gaussian in 25% of cases
tmp_classes_means[0] = npr.uniform(0, 15)
tmp_classes_stds[0] = npr.uniform(0, 5)
tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1])
tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1])
tmp_means = utils.add_axis(
tmp_classes_means[generation_classes], axis=[0, -1]
)
tmp_stds = utils.add_axis(
tmp_classes_stds[generation_classes], axis=[0, -1]
)
means = np.concatenate([means, tmp_means], axis=-1)
stds = np.concatenate([stds, tmp_stds], axis=-1)
list_means.append(means)
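A minimal sketch of how the reformatted build_model_inputs might be used: it is a generator factory whose batches pair label maps with per-channel GMM means and stds drawn from the priors. The example below illustrates how a prior_means array with two hyperparameter rows per channel lines up with the use_specific_stats_for_channel slicing seen above; the concrete values, file paths, n_labels, and the three-way unpacking of the yielded batch are assumptions, inferred only from the signature and checks visible in this diff.

import numpy as np

# Hypothetical prior layout for n_channels = 2 and 3 generation classes.
# With use_specific_stats_for_channel=True, build_model_inputs expects two
# rows of hyperparameters per channel (their exact meaning depends on
# prior_distributions), so a 2-channel prior has 4 rows and the slice
# prior_means[2 * c : 2 * c + 2, :] picks out the block for channel c.
n_channels = 2
prior_means = np.array(
    [
        [25.0, 90.0, 130.0],   # channel 0, first row of hyperparameters
        [45.0, 110.0, 160.0],  # channel 0, second row of hyperparameters
        [20.0, 80.0, 120.0],   # channel 1, first row
        [40.0, 100.0, 150.0],  # channel 1, second row
    ]
)
assert prior_means.shape[0] / 2 == n_channels  # mirrors the check in the diff
for channel in range(n_channels):
    block = prior_means[2 * channel : 2 * channel + 2, :]
    print(f"channel {channel} prior block:\n{block}")

# Consuming the generator requires the nobrainer package and real label maps
# on disk; the paths below are placeholders, and the three-way unpacking of
# the yielded batch is an assumption (the yield is not shown in the truncated
# hunk above).
# from nobrainer.ext.SynthSeg.model_inputs import build_model_inputs
# gen = build_model_inputs(
#     path_label_maps=["subj01_labels.nii.gz", "subj02_labels.nii.gz"],
#     n_labels=3,
#     batchsize=2,
#     n_channels=n_channels,
#     prior_means=prior_means,
#     use_specific_stats_for_channel=True,
# )
# label_maps, means, stds = next(gen)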
7 changes: 1 addition & 6 deletions nobrainer/ext/lab2im/__init__.py
@@ -1,6 +1 @@
from . import edit_tensors
from . import edit_volumes
from . import image_generator
from . import lab2im_model
from . import layers
from . import utils
from . import edit_tensors, edit_volumes, image_generator, lab2im_model, layers, utils