Skip to content

Commit

Permalink
resolved #303
Browse files Browse the repository at this point in the history
  • Loading branch information
hvgazula committed Mar 22, 2024
1 parent da5e0bc commit 2caa2e6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
44 changes: 28 additions & 16 deletions nobrainer/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pprint import pprint

from .attention_unet import attention_unet
from .attention_unet_with_inception import attention_unet_with_inception
from .autoencoder import autoencoder
Expand All @@ -10,6 +12,22 @@
from .unet import unet
from .unetr import unetr

__all__ = ["get", "list_available_models"]

_models = {
"highresnet": highresnet,
"meshnet": meshnet,
"unet": unet,
"autoencoder": autoencoder,
"progressivegan": progressivegan,
"progressiveae": progressiveae,
"dcgan": dcgan,
"attention_unet": attention_unet,
"attention_unet_with_inception": attention_unet_with_inception,
"unetr": unetr,
"variational_meshnet": variational_meshnet,
}


def get(name):
"""Return callable that creates a particular `tf.keras.Model`.
Expand All @@ -25,24 +43,18 @@ def get(name):
if not isinstance(name, str):
raise ValueError("Model name must be a string.")

models = {
"highresnet": highresnet,
"meshnet": meshnet,
"unet": unet,
"autoencoder": autoencoder,
"progressivegan": progressivegan,
"progressiveae": progressiveae,
"dcgan": dcgan,
"attention_unet": attention_unet,
"attention_unet_with_inception": attention_unet_with_inception,
"unetr": unetr,
"variational_meshnet": variational_meshnet,
}

try:
return models[name.lower()]
return _models[name.lower()]
except KeyError:
avail = ", ".join(models.keys())
avail = ", ".join(_models.keys())
raise ValueError(
"Unknown model: '{}'. Available models are {}.".format(name, avail)
)


def available_models():
return list(_models)

def list_available_models():
pprint(available_models())

23 changes: 20 additions & 3 deletions nobrainer/processing/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .base import BaseEstimator
from .. import losses, metrics
from ..models import available_models, list_available_models

logging.getLogger().setLevel(logging.INFO)

Expand All @@ -23,12 +24,25 @@ def __init__(
self.base_model = base_model.__name__
else:
self.base_model = base_model

if self.base_model and self.base_model not in available_models():
raise ValueError(
"Unknown model: '{}'. Available models are {}.".format(
self.base_model, available_models()
)
)

self.model_ = None
self.model_args = model_args or {}
self.block_shape_ = None
self.volume_shape_ = None
self.scalar_labels_ = None

def add_model(self, base_model, model_args=None):
"""Add a segmentation model"""
self.base_model = base_model
self.model_args = model_args or {}

def fit(
self,
dataset_train,
Expand Down Expand Up @@ -97,9 +111,9 @@ def _compile():
epochs=epochs,
steps_per_epoch=dataset_train.get_steps_per_epoch(),
validation_data=dataset_validate.dataset if dataset_validate else None,
validation_steps=dataset_validate.get_steps_per_epoch()
if dataset_validate
else None,
validation_steps=(
dataset_validate.get_steps_per_epoch() if dataset_validate else None
),
callbacks=callbacks,
verbose=verbose,
)
Expand All @@ -119,3 +133,6 @@ def predict(self, x, batch_size=1, normalizer=None):
batch_size=batch_size,
normalizer=normalizer,
)
@classmethod
def list_available_models(cls):
list_available_models()

0 comments on commit 2caa2e6

Please sign in to comment.