
Commit 1a2a52b
Change multi_gpu to True by default, pass args through everywhere
ohinds committed Sep 18, 2023
1 parent 07c14f4 commit 1a2a52b
Showing 4 changed files with 12 additions and 10 deletions.
14 changes: 8 additions & 6 deletions nobrainer/processing/base.py
@@ -20,7 +20,7 @@ class BaseEstimator:
     state_variables = []
     model_ = None
 
-    def __init__(self, checkpoint_filepath=None, multi_gpu=False):
+    def __init__(self, checkpoint_filepath=None, multi_gpu=True):
         self.checkpoint_tracker = None
         if checkpoint_filepath:
             from .checkpoint import CheckpointTracker
@@ -54,7 +54,8 @@ def save(self, save_dir):
             pk.dump(model_info, fp)
 
     @classmethod
-    def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
+    def load(cls, model_dir, *args, **kwargs):
+        breakpoint()
         """Loads a trained model from a save directory"""
         model_dir = Path(str(model_dir).rstrip(os.pathsep))
         assert model_dir.exists() and model_dir.is_dir()
@@ -64,7 +65,8 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
         if model_info["classname"] != cls.__name__:
             raise ValueError(f"Model class does not match {cls.__name__}")
         del model_info["classname"]
-        klass = cls(**model_info["__init__"])
+        model_info["__init__"].update(kwargs)
+        klass = cls(*args, **model_info["__init__"])
         del model_info["__init__"]
         for key, value in model_info.items():
             setattr(klass, key, value)
@@ -77,7 +79,7 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
         return klass
 
     @classmethod
-    def init_with_checkpoints(cls, model_name, checkpoint_filepath):
+    def init_with_checkpoints(cls, model_name, checkpoint_filepath, *args, **kwargs):
         """Initialize a model for training, either from the latest
         checkpoint found, or from scratch if no checkpoints are
         found. This is useful for long-running model fits that may be
@@ -96,9 +98,9 @@ def init_with_checkpoints(cls, model_name, checkpoint_filepath):
         from .checkpoint import CheckpointTracker
 
         checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
-        estimator = checkpoint_tracker.load()
+        estimator = checkpoint_tracker.load(*args, **kwargs)
         if not estimator:
-            estimator = cls(model_name)
+            estimator = cls(model_name, *args, **kwargs)
         estimator.checkpoint_tracker = checkpoint_tracker
         checkpoint_tracker.estimator = estimator
         return estimator
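A minimal usage sketch of the new pass-through in load(): keyword arguments supplied at load time are merged into the saved constructor arguments via model_info["__init__"].update(kwargs), so a stored setting such as multi_gpu can be overridden when reloading. The save directory below is hypothetical.

# Hypothetical sketch, not part of the commit: the path and the override are examples.
from nobrainer.processing.segmentation import Segmentation

# A model saved with the new default (multi_gpu=True) can be reloaded for
# single-GPU use, because load() lets kwargs override the stored __init__ args.
model = Segmentation.load("path/to/saved_model", multi_gpu=False)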
4 changes: 2 additions & 2 deletions nobrainer/processing/checkpoint.py
@@ -35,7 +35,7 @@ def save(self, directory):
         logging.info(f"Saving to dir {directory}")
         self.estimator.save(directory)
 
-    def load(self):
+    def load(self, *args, **kwargs):
         """Loads the most-recently created checkpoint from the
         checkpoint directory.
         """
@@ -44,6 +44,6 @@ def load(self):
             return None
 
         latest = max(checkpoints, key=os.path.getctime)
-        self.estimator = self.estimator.load(latest)
+        self.estimator = self.estimator.load(latest, *args, **kwargs)
         logging.info(f"Loaded estimator from {latest}.")
         return self.estimator
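A sketch of the forwarding chain introduced here, with the unet import, checkpoint pattern, and keyword chosen as illustrative assumptions: extra arguments to init_with_checkpoints() now reach CheckpointTracker.load() and BaseEstimator.load() when a checkpoint exists, and the estimator constructor when training starts from scratch.

# Hypothetical sketch; model choice and checkpoint pattern are assumptions.
from nobrainer.models import unet  # assumed import path
from nobrainer.processing.segmentation import Segmentation

model = Segmentation.init_with_checkpoints(
    unet,  # passed through as model_name / base_model
    checkpoint_filepath="checkpoints/epoch_{epoch:03d}",
    multi_gpu=False,  # forwarded via **kwargs on either path
)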
2 changes: 1 addition & 1 deletion nobrainer/processing/generation.py
@@ -21,7 +21,7 @@ def __init__(
         dimensionality=3,
         g_fmap_base=1024,
         d_fmap_base=1024,
-        multi_gpu=False,
+        multi_gpu=True,
     ):
         super().__init__(multi_gpu=multi_gpu)
         self.model_ = None
2 changes: 1 addition & 1 deletion nobrainer/processing/segmentation.py
@@ -15,7 +15,7 @@ class Segmentation(BaseEstimator):
     state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]
 
     def __init__(
-        self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
+        self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=True
     ):
         super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)
 
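Because the default flips to True for these estimators as well, single-GPU training now requires an explicit opt-out at construction time, in contrast to the load-time override shown above; a sketch with assumed model arguments:

# Hypothetical sketch; unet and model_args are illustrative assumptions.
from nobrainer.models import unet
from nobrainer.processing.segmentation import Segmentation

single_gpu_model = Segmentation(unet, model_args=dict(batchnorm=True), multi_gpu=False)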
