Skip to content

Commit

Permalink
resolved #303, #306
Browse files Browse the repository at this point in the history
  • Loading branch information
hvgazula committed Mar 22, 2024
1 parent da5e0bc commit 865a1a2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
44 changes: 28 additions & 16 deletions nobrainer/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pprint import pprint

from .attention_unet import attention_unet
from .attention_unet_with_inception import attention_unet_with_inception
from .autoencoder import autoencoder
Expand All @@ -10,6 +12,22 @@
from .unet import unet
from .unetr import unetr

__all__ = ["get", "list_available_models"]

_models = {
"highresnet": highresnet,
"meshnet": meshnet,
"unet": unet,
"autoencoder": autoencoder,
"progressivegan": progressivegan,
"progressiveae": progressiveae,
"dcgan": dcgan,
"attention_unet": attention_unet,
"attention_unet_with_inception": attention_unet_with_inception,
"unetr": unetr,
"variational_meshnet": variational_meshnet,
}


def get(name):
"""Return callable that creates a particular `tf.keras.Model`.
Expand All @@ -25,24 +43,18 @@ def get(name):
if not isinstance(name, str):
raise ValueError("Model name must be a string.")

models = {
"highresnet": highresnet,
"meshnet": meshnet,
"unet": unet,
"autoencoder": autoencoder,
"progressivegan": progressivegan,
"progressiveae": progressiveae,
"dcgan": dcgan,
"attention_unet": attention_unet,
"attention_unet_with_inception": attention_unet_with_inception,
"unetr": unetr,
"variational_meshnet": variational_meshnet,
}

try:
return models[name.lower()]
return _models[name.lower()]
except KeyError:
avail = ", ".join(models.keys())
avail = ", ".join(_models.keys())
raise ValueError(
"Unknown model: '{}'. Available models are {}.".format(name, avail)
)


def available_models():
return list(_models)

def list_available_models():
pprint(available_models())

20 changes: 16 additions & 4 deletions nobrainer/processing/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .base import BaseEstimator
from .. import losses, metrics
from ..models import available_models, list_available_models

logging.getLogger().setLevel(logging.INFO)

Expand All @@ -23,6 +24,14 @@ def __init__(
self.base_model = base_model.__name__
else:
self.base_model = base_model

if self.base_model not in available_models():
raise ValueError(
"Unknown model: '{}'. Available models are {}.".format(
self.base_model, available_models()
)
)

self.model_ = None
self.model_args = model_args or {}
self.block_shape_ = None
Expand Down Expand Up @@ -72,7 +81,7 @@ def _compile():
metrics=metrics,
)

if self.model is None:
if self.model_ is None:
mod = importlib.import_module("..models", "nobrainer.processing")
base_model = getattr(mod, self.base_model)
if batch_size % self.strategy.num_replicas_in_sync:
Expand All @@ -97,9 +106,9 @@ def _compile():
epochs=epochs,
steps_per_epoch=dataset_train.get_steps_per_epoch(),
validation_data=dataset_validate.dataset if dataset_validate else None,
validation_steps=dataset_validate.get_steps_per_epoch()
if dataset_validate
else None,
validation_steps=(
dataset_validate.get_steps_per_epoch() if dataset_validate else None
),
callbacks=callbacks,
verbose=verbose,
)
Expand All @@ -119,3 +128,6 @@ def predict(self, x, batch_size=1, normalizer=None):
batch_size=batch_size,
normalizer=normalizer,
)

def list_available_models(self):
list_available_models()

0 comments on commit 865a1a2

Please sign in to comment.