From 4862c9d57a6aab850cf5409c3c43f8efe5c6dd1b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 6 May 2024 13:00:01 -0400 Subject: [PATCH] CLN: Refactor model abstraction, fix #737 #726 (#753) * Change type annotations on models.definition.validate to just be Type, not Type[ModelDefinition] * Rewrite models.base.Model to not subclass LightningModule, and instead to have class variables definition and family, that are *both* used by the from_config method to make a new instance of the family with instances of the definition's attributes (trust me, this makes perfect sense) * Change model decorator to not subclass family, and to instead subclass Model, then make a new instance of the subclass, add that instance to the registry, and then return the instance. This makes it possible to call the 'from_config' method of the instance and get a new class instance. Think of the instance as a Singleton * Rewrite FrameClassificationModel to subclass LightningModule directly, remove from_config method * Rewrite ParametricUMAPModel to subclass LightningModule directly, remove from_config method * Rewrite base.Model class, move/rename methods so that we will get a singleton that can return new lightning.Module instances with the appropriate instances of network + loss + optimizer + metrics from its from_config method * Add load_state_dict_from_path method to FrameClassificationModel and ParametricUMAPModel * Change model_family decorator to check if family_class is a subtype of LightningModule, not vak.base.Model * Rewrite models.base.Model.__init__ to take definition and family attributes, that are used by from_config method * Fix how models.decorator.model makes Model instance -- we don't subclass anymore, but we do change Model instance's __name__,__doc__, and __module__ to match those of the definition * Fix FrameClassificationModel to subclass lightning.LightningModule (not pytorch_lightning.LightningModule :eyeroll:) and to no longer pass network, loss, etc to 
super().__init__ since it's no longer a sub-class of models.base.Model * Fix ParametricUMAPModel to subclass lightning.LightningModule (not lightning.pytorch.LightningModule) and to no longer pass network, loss, etc to super().__init__ since it's no longer a sub-class of models.base.Model * Fix how we add a Model instance to MODEL_REGISTRY * Fix src/vak/models/frame_classification_model.py to set network/loss/optimizer/metrics as attributes on self * Fix src/vak/models/parametric_umap_model.py to set network/loss/optimizer/metrics as attributes on self * Fix how we get MODEL_FAMILY_FROM_NAME dict in models.registry.__getattr__ * Fix classes in tests/test_models/conftest.py so we can use them to run tests * Fix tests in tests/test_models/test_base.py * Add method from_instances to vak.models.base.Model * Rename vak.models.base.Model -> vak.models.factory.ModelFactory * Add tests in tests/test_models/test_factory.py from test_frame_classification_model * Fix unit test in tests/test_models/test_convencoder_umap.py * Fix unit tests in tests/test_models/test_decorator.py * Fix unit tests in tests/test_models/test_tweetynet.py * Fix adding load_from_state_dict method to ParametricUMAPModel * Fix unit tests in tests/test_models/test_frame_classification_model.py * Rename method in tests/test_models/test_convencoder_umap.py * Fix unit tests in tests/test_models/test_ed_tcn.py * Add a unit test from another test_models module to test_factory.py * Add a unit test from another test_models module to test_factory.py * Fix unit tests in tests/test_models/test_registry.py * Remove unused fixture 'monkeypath' in tests/test_models/test_frame_classification_model.py * Fix unit tests in tests/test_models/test_parametric_umap_model.py * BUG: Fix how we check if we need to add empty dicts to model config in src/vak/config/model.py * Rename model_class -> model_factory in src/vak/models/get.py * Clean up docstring in src/vak/models/factory.py * Fix how we parametrize two unit tests in 
tests/test_config/test_model.py * Fix ConvEncoderUMAP configs in tests/data_for_tests/configs to have network.encoder sub-table * Rewrite docstring, fix type annotations, rename vars for clarity in src/vak/models/decorator.py * Revise docstring in src/vak/models/definition.py * Revise type hinting + docstring in src/vak/models/get.py * Revise docstring + comment in src/vak/models/registry.py * Fix unit test in tests/test_models/test_factory.py * Fix ParametricUMAPModel to use a ModuleDict * Fix unit test in tests/test_models/test_convencoder_umap.py * Fix unit test in tests/test_models/test_factory.py * Fix unit test in tests/test_models/test_parametric_umap_model.py * Fix common.tensorboard.events2df to avoid pandas error about re-indexing with duplicate values -- we need to not use the 'epoch' Scalar since it's all zeros --- src/vak/common/tensorboard.py | 2 +- src/vak/config/model.py | 4 +- src/vak/models/__init__.py | 8 +- src/vak/models/decorator.py | 95 ++-- src/vak/models/definition.py | 17 +- src/vak/models/{base.py => factory.py} | 445 +++++++++-------- src/vak/models/frame_classification_model.py | 66 ++- src/vak/models/get.py | 20 +- src/vak/models/parametric_umap_model.py | 75 +-- src/vak/models/registry.py | 60 +-- ...oderUMAP_eval_audio_cbin_annot_notmat.toml | 2 +- ...derUMAP_train_audio_cbin_annot_notmat.toml | 2 +- tests/test_config/test_model.py | 24 +- tests/test_models/conftest.py | 42 +- tests/test_models/test_base.py | 311 ------------ tests/test_models/test_convencoder_umap.py | 9 +- tests/test_models/test_decorator.py | 8 +- tests/test_models/test_ed_tcn.py | 4 +- tests/test_models/test_factory.py | 451 ++++++++++++++++++ .../test_frame_classification_model.py | 195 +++----- .../test_models/test_parametric_umap_model.py | 179 +++---- tests/test_models/test_registry.py | 65 +-- tests/test_models/test_tweetynet.py | 12 +- 23 files changed, 1103 insertions(+), 993 deletions(-) rename src/vak/models/{base.py => factory.py} (63%) delete mode 
100644 tests/test_models/test_base.py create mode 100644 tests/test_models/test_factory.py diff --git a/src/vak/common/tensorboard.py b/src/vak/common/tensorboard.py index 43db0e53e..d73d0a488 100644 --- a/src/vak/common/tensorboard.py +++ b/src/vak/common/tensorboard.py @@ -128,5 +128,5 @@ def events2df( ).set_index("step") if drop_wall_time: dfs[scalar_tag].drop("wall_time", axis=1, inplace=True) - df = pd.concat([v for k, v in dfs.items()], axis=1) + df = pd.concat([v for k, v in dfs.items() if k != "epoch"], axis=1) return df diff --git a/src/vak/config/model.py b/src/vak/config/model.py index aaf920866..7e8f43ed8 100644 --- a/src/vak/config/model.py +++ b/src/vak/config/model.py @@ -78,6 +78,8 @@ def from_config_dict(cls, config_dict: dict): f"Model name not found in registry: {model_name}\n" f"Model names in registry:\n{MODEL_NAMES}" ) + + # NOTE: we are getting model_config here model_config = config_dict[model_name] if not all(key in MODEL_TABLES for key in model_config.keys()): invalid_keys = ( @@ -89,7 +91,7 @@ def from_config_dict(cls, config_dict: dict): ) # for any tables not specified, default to empty dict so we can still use ``**`` operator on it for model_table in MODEL_TABLES: - if model_table not in config_dict: + if model_table not in model_config: model_config[model_table] = {} return cls(name=model_name, **model_config) diff --git a/src/vak/models/__init__.py b/src/vak/models/__init__.py index 604fa408e..50404ccc1 100644 --- a/src/vak/models/__init__.py +++ b/src/vak/models/__init__.py @@ -1,5 +1,5 @@ -from . import base, decorator, definition, registry -from .base import Model +from . 
import decorator, definition, factory, registry +from .factory import ModelFactory from .convencoder_umap import ConvEncoderUMAP from .decorator import model from .ed_tcn import ED_TCN @@ -10,14 +10,14 @@ from .tweetynet import TweetyNet __all__ = [ - "base", + "factory", "ConvEncoderUMAP", "decorator", "definition", "ED_TCN", "FrameClassificationModel", "get", - "Model", + "ModelFactory", "model", "model_family", "ParametricUMAPModel", diff --git a/src/vak/models/decorator.py b/src/vak/models/decorator.py index 5a0fff875..1b3911f1f 100644 --- a/src/vak/models/decorator.py +++ b/src/vak/models/decorator.py @@ -1,22 +1,25 @@ -"""Decorator that makes a model class, +"""Decorator that makes a :class:`vak.models.ModelFactory`, given a definition of the model, -and another class that represents a +and a :class:`lightning.LightningModule` that represents a family of models that the new model belongs to. -The function returns a newly-created subclass -of the class representing the family of models. -The subclass can then be instantiated -and have all model methods. +The function returns a new instance of :class:`vak.models.ModelFactory`, +that can create new instances of the model with its +:meth:`~:class:`vak.models.ModelFactory.from_config` and +:meth:`~:class:`vak.models.ModelFactory.from_instances` methods. 
""" from __future__ import annotations -from typing import Type +from typing import Type, TYPE_CHECKING + +import lightning -from .base import Model from .definition import validate as validate_definition from .registry import register_model +if TYPE_CHECKING: + from .factory import ModelFactory class ModelDefinitionValidationError(Exception): """Exception raised when validating a model @@ -28,16 +31,16 @@ class ModelDefinitionValidationError(Exception): pass -def model(family: Type[Model]): - """Decorator that makes a model class, +def model(family: lightning.pytorch.LightningModule): + """Decorator that makes a :class:`vak.models.ModelFactory`, given a definition of the model, - and another class that represents a + and a :class:`lightning.LightningModule` that represents a family of models that the new model belongs to. - Returns a newly-created subclass - of the class representing the family of models. - The subclass can then be instantiated - and have all model methods. + The function returns a new instance of :class:`vak.models.ModelFactory`, + that can create new instances of the model with its + :meth:`~:class:`vak.models.ModelFactory.from_config` and + :meth:`~:class:`vak.models.ModelFactory.from_instances` methods. Parameters ---------- @@ -46,50 +49,40 @@ def model(family: Type[Model]): A class with all the class variables required by :func:`vak.models.definition.validate`. See docstring of that function for specification. - family : subclass of vak.models.Model + See also :class:`vak.models.definition.ModelDefinition`, + but note that it is not necessary to subclass + :class:`~vak.models.definition.ModelDefinition` to + define a model. + family : lightning.LightningModule The class representing the family of models that the new model will belong to. E.g., :class:`vak.models.FrameClassificationModel`. + Should be a subclass of :class:`lightning.LightningModule` + that was registered with the + :func:`vak.models.registry.model_family` decorator. 
Returns ------- - model : type - A sub-class of ``model_family``, - with attribute ``definition``, + model_factory : vak.models.ModelFactory + An instance of :class:`~vak.models.ModelFactory`, + with attribute ``definition`` and ``family``, that will be used when making - new instances of the model. + new instances of the model by calling the + :meth:`~vak.models.ModelFactory.from_config` method + or the :meth:`~:class:`vak.models.ModelFactory.from_instances` method. """ - def _model(definition: Type): - if not issubclass(family, Model): - raise TypeError( - "The ``family`` argument to the ``vak.models.model`` decorator" - "should be a subclass of ``vak.models.base.Model``," - f"but the type was: {type(family)}, " - "which was not recognized as a subclass " - "of ``vak.models.base.Model``." - ) - - try: - validate_definition(definition) - except ValueError as err: - raise ModelDefinitionValidationError( - f"Validation failed for the following model definition:\n{definition}" - ) from err - except TypeError as err: - raise ModelDefinitionValidationError( - f"Validation failed for the following model definition:\n{definition}" - ) from err - - attributes = dict(family.__dict__) - attributes.update({"definition": definition}) - subclass_name = definition.__name__ - subclass = type(subclass_name, (family,), attributes) - subclass.__module__ = definition.__module__ - - # finally, add model to registry - register_model(subclass) - - return subclass + def _model(definition: Type) -> ModelFactory: + from .factory import ModelFactory # avoid circular import + + model_factory = ModelFactory( + definition, + family + ) + model_factory.__name__ = definition.__name__ + model_factory.__doc__ = definition.__doc__ + model_factory.__module__ = definition.__module__ + register_model(model_factory) + return model_factory return _model diff --git a/src/vak/models/definition.py b/src/vak/models/definition.py index b3742d2a8..a44db3108 100644 --- a/src/vak/models/definition.py +++ 
b/src/vak/models/definition.py @@ -27,10 +27,7 @@ class ModelDefinition: """A class that represents the definition of a neural network model. - Note it is **not** necessary to sub-class this class; - it exists mainly for type-checking purposes. - - A model definition is a class that has the following class variables: + A model definition is any class that has the following class variables: network: torch.nn.Module or dict Neural network. @@ -48,6 +45,12 @@ class ModelDefinition: Used by ``vak.models.base.Model`` and its sub-classes that represent model families. E.g., those classes will do: ``network = self.definition.network(**self.definition.default_config['network'])``. + + Note it is **not** necessary to sub-class this class; + it exists mainly for type-checking purposes. + + For more detail, see :func:`vak.models.decorator.model` + and :class:`vak.models.ModelFactory`. """ network: Union[torch.nn.Module, dict] @@ -67,7 +70,7 @@ class ModelDefinition: } -def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: +def validate(definition: Type) -> Type: """Validate a model definition. A model definition is a class that has the following class variables: @@ -124,8 +127,8 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]: converting it into a sub-class ofhttps://peps.python.org/pep-0416/ ``vak.models.Model``. - It's also used by ``vak.models.Model`` - to validate a definition when initializing + It's also used by :class:`vak.models.ModelFactory`, + to validate a definition before building a new model instance from the definition. """ # need to set this default first diff --git a/src/vak/models/base.py b/src/vak/models/factory.py similarity index 63% rename from src/vak/models/base.py rename to src/vak/models/factory.py index 42542099f..d5ff274ce 100644 --- a/src/vak/models/base.py +++ b/src/vak/models/factory.py @@ -1,59 +1,77 @@ -"""Base class for a model in ``vak``, -that other families of models should subclass. 
-""" +"""Class that represent a model builit into ``vak``.""" from __future__ import annotations import inspect -from typing import Callable, ClassVar +from typing import Callable, Type import lightning import torch -from .definition import ModelDefinition from .definition import validate as validate_definition +from .decorator import ModelDefinitionValidationError -class Model(lightning.pytorch.LightningModule): - """Base class for a model in ``vak``, - that other families of models should subclass. - - This class provides methods for working with - neural network models, e.g. training the model - and generating productions, - and it also converts a - model definition into a model instance. - - It provides the methods for working with neural - network models by subclassing - ``lighting.LightningModule``, and it handles - converting a model definition into a model instance - inside its ``__init__`` method. - Model definitions are declared programmatically - using a ``vak.model.ModelDefinition``; - see the documentation on that class for more detail. +class ModelFactory: + """Class that represent a model builit into ``vak``. + + Attributes + ---------- + definition: vak.models.definition.ModelDefinition + family: lighting.pytorch.LightningModule + + Notes + ----- + This class is used by the :func:`vak.models.decorator.model` + decorator to make a new class representing a model + from a model definition. + As such, this class is not meant to be used directly. + See the docstring of :func:`vak.models.decorator.model` + for more detail. 
""" - definition: ClassVar[ModelDefinition] + def __init__(self, + definition: Type, + family: lightning.pytorch.LightningModule, + ) -> None: + if not issubclass(family, lightning.pytorch.LightningModule): + raise TypeError( + "The ``family`` argument to the ``vak.models.model`` decorator" + "should be a subclass of ``lightning.pytorch.LightningModule``," + f"but the type was: {type(family)}, " + "which was not recognized as a subclass " + "of ``lightning.pytorch.LightningModule``." + ) - def __init__( - self, - network: torch.nn.Module | dict | None = None, - loss: torch.nn.Module | Callable | None = None, - optimizer: torch.optim.Optimizer | None = None, - metrics: dict | None = None, - ): - """Initializes an instance of a model, using its definition. + try: + validate_definition(definition) + except ValueError as err: + raise ModelDefinitionValidationError( + f"Validation failed for the following model definition:\n{definition}" + ) from err + except TypeError as err: + raise ModelDefinitionValidationError( + f"Validation failed for the following model definition:\n{definition}" + ) from err - Takes in instances of the attributes defined by the class variable - ``self.definition``: ``network``, ``loss``, ``optimizer``, and ``metrics``. - If any of those arguments are ``None``, then ``__init__`` - instantiates the corresponding attribute with its defaults. - If any of those arguments are not an instance of the type - defined by ``self.definition``, then a TypeError is raised. + self.definition = definition + self.family = family + + def attributes_from_config(self, config: dict): + """Get attributes for an instance of a model, + given a configuration. + + Given a :class:`dict`, ``config``, return instances + of `network`, `optimizer`, `loss`, and `metrics`. Parameters ---------- + config : dict + A :class:`dict` obtained by calling + :meth:`vak.config.ModelConfig.to_dict()`. 
+ + Returns + ------- network : torch.nn.Module, dict An instance of a ``torch.nn.Module`` that implements a neural network, @@ -61,9 +79,7 @@ def __init__( to a set of such instances. loss : torch.nn.Module, callable An instance of a ``torch.nn.Module`` - that implements a loss function, - or a callable Python function that - computes a scalar loss. + that implements a loss function. optimizer : torch.optim.Optimizer An instance of a ``torch.optim.Optimizer`` class used with ``loss`` to optimize @@ -73,79 +89,57 @@ def __init__( to ``Callable`` functions, used to measure performance of the model. """ - from .decorator import ModelDefinitionValidationError - - super().__init__() - - # check that we are a sub-class of some other class with required class variables - if not hasattr(self, "definition"): - raise ValueError( - "This model does not have a definition." - "Define a model by wrapping a class with the required class variables with " - "a ``vak.models`` decorator, e.g. ``vak.models.windowed_frame_classification_model``" - ) - - try: - validate_definition(self.definition) - except ModelDefinitionValidationError as err: - raise ValueError( - "Creating model instance failed because model definition is invalid." 
- ) from err - - # ---- validate any instances that user passed in - self.validate_init(network, loss, optimizer, metrics) + network_kwargs = config.get( + "network", self.definition.default_config["network"] + ) + if inspect.isclass(self.definition.network): + network = self.definition.network(**network_kwargs) + elif isinstance(self.definition.network, dict): + network = {} + for net_name, net_class in self.definition.network.items(): + net_class_kwargs = network_kwargs.get(net_name, {}) + network[net_name] = net_class(**net_class_kwargs) - if network is None: - net_kwargs = self.definition.default_config.get("network") - if isinstance(self.definition.network, dict): - network = { - network_name: network_class(**net_kwargs[network_name]) - for network_name, network_class in self.definition.network.items() - } - else: - network = self.definition.network(**net_kwargs) - self.network = network + if isinstance(self.definition.network, dict): + params = [ + param + for net_name, net_instance in network.items() + for param in net_instance.parameters() + ] + else: + params = network.parameters() - if loss is None: - if inspect.isclass(self.definition.loss): - loss_kwargs = self.definition.default_config.get("loss") - loss = self.definition.loss(**loss_kwargs) - elif inspect.isfunction(self.definition.loss): - loss = self.definition.loss - self.loss = loss + optimizer_kwargs = config.get( + "optimizer", self.definition.default_config["optimizer"] + ) + optimizer = self.definition.optimizer(params=params, **optimizer_kwargs) - if optimizer is None: - optimizer_kwargs = self.definition.default_config.get("optimizer") - if isinstance(network, dict): - params = [ - param - for net_name, net_instance in network.items() - for param in net_instance.parameters() - ] - else: - params = network.parameters() - optimizer = self.definition.optimizer( - params=params, **optimizer_kwargs + if inspect.isclass(self.definition.loss): + loss_kwargs = config.get( + "loss", 
self.definition.default_config["loss"] ) - self.optimizer = optimizer + loss = self.definition.loss(**loss_kwargs) + else: + loss = self.definition.loss - if metrics is None: - metric_kwargs = self.definition.default_config.get("metrics") - metrics = {} - for metric_name, metric_class in self.definition.metrics.items(): - metric_class_kwargs = metric_kwargs.get(metric_name, {}) - metrics[metric_name] = metric_class(**metric_class_kwargs) - self.metrics = metrics + metrics_config = config.get( + "metrics", self.definition.default_config["metrics"] + ) + metrics = {} + for metric_name, metric_class in self.definition.metrics.items(): + metrics_class_kwargs = metrics_config.get(metric_name, {}) + metrics[metric_name] = metric_class(**metrics_class_kwargs) + + return network, loss, optimizer, metrics - @classmethod def validate_init( - cls, + self, network: torch.nn.Module | dict | None = None, loss: torch.nn.Module | Callable | None = None, optimizer: torch.optim.Optimizer | None = None, metrics: dict | None = None, ): - """Validate arguments to ``vak.models.base.Model.__init__``. + """Validate arguments to ``vak.models.base.Model.init``. Parameters ---------- @@ -177,21 +171,21 @@ def validate_init( it just raises an error if any value is invalid. 
""" if network: - if inspect.isclass(cls.definition.network): - if not isinstance(network, cls.definition.network): + if inspect.isclass(self.definition.network): + if not isinstance(network, self.definition.network): raise TypeError( - f"``network`` should be an instance of {cls.definition.network}" + f"``network`` should be an instance of {self.definition.network}" f"but was of type {type(network)}" ) - elif isinstance(cls.definition.network, dict): + elif isinstance(self.definition.network, dict): if not isinstance(network, dict): raise TypeError( "Expected ``network`` to be a ``dict`` mapping network names " f"to ``torch.nn.Module`` instances, but type was {type(network)}" ) expected_network_dict_keys = list( - cls.definition.network.keys() + self.definition.network.keys() ) network_dict_keys = list(network.keys()) if not all( @@ -222,11 +216,11 @@ def validate_init( for network_name, network_instance in network.items(): if not isinstance( - network_instance, cls.definition.network[network_name] + network_instance, self.definition.network[network_name] ): raise TypeError( f"Network with name '{network_name}' in ``network`` dict " - f"should be an instance of {cls.definition.network[network_name]}" + f"should be an instance of {self.definition.network[network_name]}" f"but was of type {type(network)}" ) else: @@ -235,24 +229,24 @@ def validate_init( ) if loss: - if issubclass(cls.definition.loss, torch.nn.Module): - if not isinstance(loss, cls.definition.loss): + if issubclass(self.definition.loss, torch.nn.Module): + if not isinstance(loss, self.definition.loss): raise TypeError( - f"``loss`` should be an instance of {cls.definition.loss}" + f"``loss`` should be an instance of {self.definition.loss}" f"but was of type {type(loss)}" ) - elif callable(cls.definition.loss): - if loss is not cls.definition.loss: + elif callable(self.definition.loss): + if loss is not self.definition.loss: raise ValueError( - f"``loss`` should be the following callable (probably a 
function): {cls.definition.loss}" + f"``loss`` should be the following callable (probably a function): {self.definition.loss}" ) else: raise TypeError(f"Invalid type for ``loss``: {type(loss)}") if optimizer: - if not isinstance(optimizer, cls.definition.optimizer): + if not isinstance(optimizer, self.definition.optimizer): raise TypeError( - f"``optimizer`` should be an instance of {cls.definition.optimizer}" + f"``optimizer`` should be an instance of {self.definition.optimizer}" f"but was of type {type(optimizer)}" ) @@ -263,68 +257,41 @@ def validate_init( f"to callable metrics, but type of ``metrics`` was {type(metrics)}" ) for metric_name, metric_callable in metrics.items(): - if metric_name not in cls.definition.metrics: + if metric_name not in self.definition.metrics: raise ValueError( f"``metrics`` has name '{metric_name}' but that name " f"is not in the model definition. " - f"Valid metric names are: {', '.join(list(cls.definition.metrics.keys()))}" + f"Valid metric names are: {', '.join(list(self.definition.metrics.keys()))}" ) if not isinstance( - metric_callable, cls.definition.metrics[metric_name] + metric_callable, self.definition.metrics[metric_name] ): raise TypeError( - f"metric '{metric_name}' should be an instance of {cls.definition.metrics[metric_name]}" + f"metric '{metric_name}' should be an instance of {self.definition.metrics[metric_name]}" f"but was of type {type(metric_callable)}" ) - def load_state_dict_from_path(self, ckpt_path): - """Loads a model from the path to a saved checkpoint. - - Loads the checkpoint and then calls - ``self.load_state_dict`` with the ``state_dict`` - in that chekcpoint. - - This method allows loading a state dict into an instance. - It's necessary because `lightning.pytorch.LightningModule.load`` is a - ``classmethod``, so calling that method will trigger - ``LightningModule.__init__`` instead of running - ``vak.models.Model.__init__``. 
- - Parameters - ---------- - ckpt_path : str, pathlib.Path - Path to a checkpoint saved by a model in ``vak``. - This checkpoint has the same key-value pairs as - any other checkpoint saved by a - ``lightning.pytorch.LightningModule``. - - Returns - ------- - None - - This method modifies the model state by loading the ``state_dict``; - it does not return anything. - """ - ckpt = torch.load(ckpt_path) - self.load_state_dict(ckpt["state_dict"]) - - @classmethod - def attributes_from_config(cls, config: dict): - """Get attributes for an instance of a model, - given a configuration. + def validate_instances_or_get_default( + self, + network: torch.nn.Module | dict | None = None, + loss: torch.nn.Module | Callable | None = None, + optimizer: torch.optim.Optimizer | None = None, + metrics: dict | None = None, + ): + """Validate instances of model attributes, using its definition, + or if no instance is passed in for an attribute, + make an instance using the default config. - Given a ``dict``, ``config``, return instances of - class variables + Takes in instances of the attributes defined by the class variable + ``self.definition``: ``network``, ``loss``, ``optimizer``, and ``metrics``. + If any of those arguments are ``None``, then ``__init__`` + instantiates the corresponding attribute with its defaults. + If any of those arguments are not an instance of the type + defined by ``self.definition``, then a TypeError is raised. Parameters ---------- - config : dict - Returned by calling ``vak.config.models.map_from_path`` - or ``vak.config.models.map_from_config_dict``. - - Returns - ------- network : torch.nn.Module, dict An instance of a ``torch.nn.Module`` that implements a neural network, @@ -344,66 +311,122 @@ class variables to ``Callable`` functions, used to measure performance of the model. 
""" - network_kwargs = config.get( - "network", cls.definition.default_config["network"] - ) - if inspect.isclass(cls.definition.network): - network = cls.definition.network(**network_kwargs) - elif isinstance(cls.definition.network, dict): - network = {} - for net_name, net_class in cls.definition.network.items(): - net_class_kwargs = network_kwargs.get(net_name, {}) - network[net_name] = net_class(**net_class_kwargs) + from .decorator import ModelDefinitionValidationError - if isinstance(cls.definition.network, dict): - params = [ - param - for net_name, net_instance in network.items() - for param in net_instance.parameters() - ] - else: - params = network.parameters() + # check that we are a sub-class of some other class with required class variables + if not hasattr(self, "definition"): + raise ValueError( + "This model does not have a definition." + "Define a model by wrapping a class with the required class variables with " + "a ``vak.models`` decorator, e.g. ``vak.models.windowed_frame_classification_model``" + ) - optimizer_kwargs = config.get( - "optimizer", cls.definition.default_config["optimizer"] - ) - optimizer = cls.definition.optimizer(params=params, **optimizer_kwargs) + try: + validate_definition(self.definition) + except ModelDefinitionValidationError as err: + raise ValueError( + "Creating model instance failed because model definition is invalid." 
+ ) from err - if inspect.isclass(cls.definition.loss): - loss_kwargs = config.get( - "loss", cls.definition.default_config["loss"] - ) - loss = cls.definition.loss(**loss_kwargs) - else: - loss = cls.definition.loss + # ---- validate any instances that user passed in + self.validate_init(network, loss, optimizer, metrics) - metrics_config = config.get( - "metrics", cls.definition.default_config["metrics"] - ) - metrics = {} - for metric_name, metric_class in cls.definition.metrics.items(): - metrics_class_kwargs = metrics_config.get(metric_name, {}) - metrics[metric_name] = metric_class(**metrics_class_kwargs) + if network is None: + net_kwargs = self.definition.default_config.get("network") + if isinstance(self.definition.network, dict): + network = { + network_name: network_class(**net_kwargs[network_name]) + for network_name, network_class in self.definition.network.items() + } + else: + network = self.definition.network(**net_kwargs) + if loss is None: + if inspect.isclass(self.definition.loss): + loss_kwargs = self.definition.default_config.get("loss") + loss = self.definition.loss(**loss_kwargs) + elif inspect.isfunction(self.definition.loss): + loss = self.definition.loss + + if optimizer is None: + optimizer_kwargs = self.definition.default_config.get("optimizer") + if isinstance(network, dict): + params = [ + param + for net_name, net_instance in network.items() + for param in net_instance.parameters() + ] + else: + params = network.parameters() + optimizer = self.definition.optimizer( + params=params, **optimizer_kwargs + ) + + if metrics is None: + metric_kwargs = self.definition.default_config.get("metrics") + metrics = {} + for metric_name, metric_class in self.definition.metrics.items(): + metric_class_kwargs = metric_kwargs.get(metric_name, {}) + metrics[metric_name] = metric_class(**metric_class_kwargs) return network, loss, optimizer, metrics - @classmethod - def from_config(cls, config: dict): - """Return an initialized model instance from a 
config ``dict`` + def from_config(self, config: dict, **kwargs): + """Return a new instance of a model, given a config :class:`dict`. Parameters ---------- config : dict - Returned by calling ``vak.config.models.map_from_path`` - or ``vak.config.models.map_from_config_dict``. + The dict obtained by calling :meth:`vak.config.ModelConfig.asdict`. Returns ------- - cls : vak.models.base.Model - An instance of the model with its attributes - initialized using parameters from ``config``. + model : lightning.LightningModule + An instance of the model :attr:`~ModelFactory.family` + with attributes specified by :attr:`~ModelFactory.definition`, + that are initialized using parameters from ``config``. + """ + network, loss, optimizer, metrics = self.attributes_from_config(config) + network, loss, optimizer, metrics = self.validate_instances_or_get_default( + network, loss, optimizer, metrics, + ) + return self.family( + network=network, loss=loss, optimizer=optimizer, metrics=metrics, **kwargs + ) + + def from_instances( + self, + network: torch.nn.Module | dict | None = None, + loss: torch.nn.Module | Callable | None = None, + optimizer: torch.optim.Optimizer | None = None, + metrics: dict | None = None, + **kwargs, + ): + """Return a new instance of a model, given instances of its attributes. + + Parameters + ---------- + network : torch.nn.Module, dict + An instance of a ``torch.nn.Module`` + that implements a neural network, + or a ``dict`` that maps human-readable string names + to a set of such instances. + loss : torch.nn.Module, callable + An instance of a ``torch.nn.Module`` + that implements a loss function, + or a callable Python function that + computes a scalar loss. + optimizer : torch.optim.Optimizer + An instance of a ``torch.optim.Optimizer`` class + used with ``loss`` to optimize + the parameters of ``network``. + metrics : dict + A ``dict`` that maps human-readable string names + to ``Callable`` functions, used to measure + performance of the model.
""" - network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls( - network=network, loss=loss, optimizer=optimizer, metrics=metrics + network, loss, optimizer, metrics = self.validate_instances_or_get_default( + network, loss, optimizer, metrics, + ) + return self.family( + network=network, loss=loss, optimizer=optimizer, metrics=metrics, **kwargs ) diff --git a/src/vak/models/frame_classification_model.py b/src/vak/models/frame_classification_model.py index 0b1777e80..8c2e29d20 100644 --- a/src/vak/models/frame_classification_model.py +++ b/src/vak/models/frame_classification_model.py @@ -6,21 +6,20 @@ from __future__ import annotations import logging -from typing import Callable, ClassVar, Mapping +from typing import Callable, Mapping +import lightning import torch from .. import transforms from ..common import labels -from . import base -from .definition import ModelDefinition from .registry import model_family logger = logging.getLogger(__name__) @model_family -class FrameClassificationModel(base.Model): +class FrameClassificationModel(lightning.LightningModule): """Class that represents a family of neural network models that predicts a label for each frame in a time series, e.g., each time bin in a window from a spectrogram. @@ -86,9 +85,6 @@ class FrameClassificationModel(base.Model): to string labels inside of ``validation_step``, for computing edit distance. """ - - definition: ClassVar[ModelDefinition] - def __init__( self, labelmap: Mapping, @@ -126,9 +122,12 @@ def __init__( post_tfm : callable Post-processing transform applied to predictions. 
""" - super().__init__( - network=network, loss=loss, optimizer=optimizer, metrics=metrics - ) + super().__init__() + + self.network = network + self.loss = loss + self.optimizer = optimizer + self.metrics = metrics self.labelmap = labelmap # replace any multiple character labels in mapping @@ -352,34 +351,33 @@ def predict_step(self, batch: tuple, batch_idx: int): y_pred = self.network(x) return {frames_path: y_pred} - @classmethod - def from_config( - cls, config: dict, labelmap: Mapping, post_tfm: Callable | None = None - ): - """Return an initialized model instance from a config ``dict`` + def load_state_dict_from_path(self, ckpt_path): + """Loads a model from the path to a saved checkpoint. + + Loads the checkpoint and then calls + ``self.load_state_dict`` with the ``state_dict`` + in that checkpoint. + + This method allows loading a state dict into an instance. + It's necessary because ``lightning.pytorch.LightningModule.load`` is a + ``classmethod``, so calling that method will trigger + ``LightningModule.__init__`` instead of running + ``vak.models.Model.__init__``. Parameters ---------- - config : dict - Returned by calling :func:`vak.config.models.map_from_path` - or :func:`vak.config.models.map_from_config_dict`. - post_tfm : callable - Post-processing transformation. - A callable applied to the network output. - Default is None. + ckpt_path : str, pathlib.Path + Path to a checkpoint saved by a model in ``vak``. + This checkpoint has the same key-value pairs as + any other checkpoint saved by a + ``lightning.pytorch.LightningModule``. Returns ------- - cls : vak.models.base.Model - An instance of the model with its attributes - initialized using parameters from ``config``. + None + + This method modifies the model state by loading the ``state_dict``; + it does not return anything.
""" - network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls( - labelmap=labelmap, - network=network, - optimizer=optimizer, - loss=loss, - metrics=metrics, - post_tfm=post_tfm, - ) + ckpt = torch.load(ckpt_path) + self.load_state_dict(ckpt["state_dict"]) diff --git a/src/vak/models/get.py b/src/vak/models/get.py index b6f6a849c..a210d2445 100644 --- a/src/vak/models/get.py +++ b/src/vak/models/get.py @@ -6,6 +6,8 @@ import inspect from typing import Callable +import lightning + from . import registry @@ -16,7 +18,7 @@ def get( num_classes: int | None = None, labelmap: dict | None = None, post_tfm: Callable | None = None, -): +) -> lightning.LightningModule: """Get a model instance, given its name and a configuration as a :class:`dict`. @@ -44,13 +46,13 @@ def get( Returns ------- - model : vak.models.Model - Instance of a sub-class of the base Model class, - e.g. a TweetyNet instance. + model : lightning.LightningModule + Instance of :class:`lightning.LightningModule`, + one of the model families. """ # we do this dynamically so we always get all registered models try: - model_class = registry.MODEL_REGISTRY[name] + model_factory = registry.MODEL_REGISTRY[name] except KeyError as e: raise ValueError( f"Invalid model name: '{name}'.\n" @@ -63,7 +65,7 @@ def get( # still need to special case model logic here net_init_params = list( inspect.signature( - model_class.definition.network.__init__ + model_factory.definition.network.__init__ ).parameters.keys() ) if ("num_input_channels" in net_init_params) and ( @@ -82,13 +84,13 @@ def get( f"unable to determine network init arguments for model.
Currently all models " f"in this family must have networks with parameters ``num_input_channels`` and ``num_freqbins``" ) - model = model_class.from_config( + model = model_factory.from_config( config=config, labelmap=labelmap, post_tfm=post_tfm ) elif model_family == "ParametricUMAPModel": encoder_init_params = list( inspect.signature( - model_class.definition.network["encoder"].__init__ + model_factory.definition.network["encoder"].__init__ ).parameters.keys() ) if "input_shape" in encoder_init_params: @@ -97,7 +99,7 @@ def get( else: config["network"]["encoder"] = dict(input_shape=input_shape) - model = model_class.from_config(config=config) + model = model_factory.from_config(config=config) else: raise ValueError( f"Value for ``model_family`` not recognized: {model_family}" diff --git a/src/vak/models/parametric_umap_model.py b/src/vak/models/parametric_umap_model.py index 5abaf7bbf..b9a31474a 100644 --- a/src/vak/models/parametric_umap_model.py +++ b/src/vak/models/parametric_umap_model.py @@ -15,13 +15,12 @@ import torch import torch.utils.data -from . import base from .definition import ModelDefinition from .registry import model_family @model_family -class ParametricUMAPModel(base.Model): +class ParametricUMAPModel(lightning.LightningModule): """Parametric UMAP model, as described in [1]_. 
Notes @@ -43,28 +42,29 @@ class ParametricUMAPModel(base.Model): def __init__( self, - network: dict | None = None, - loss: torch.nn.Module | Callable | None = None, - optimizer: torch.optim.Optimizer | None = None, - metrics: dict[str:Type] | None = None, + network: dict, + loss: torch.nn.Module | Callable, + optimizer: torch.optim.Optimizer, + metrics: dict[str:Type], ): - super().__init__( - network=network, loss=loss, optimizer=optimizer, metrics=metrics + super().__init__() + self.network = torch.nn.ModuleDict( + network ) - self.encoder = network["encoder"] - self.decoder = network.get("decoder", None) + self.loss = loss + self.optimizer = optimizer + self.metrics = metrics def configure_optimizers(self): return self.optimizer def training_step(self, batch, batch_idx): (edges_to_exp, edges_from_exp) = batch - embedding_to, embedding_from = self.encoder( - edges_to_exp - ), self.encoder(edges_from_exp) + embedding_to = self.network['encoder'](edges_to_exp) + embedding_from = self.network['encoder'](edges_from_exp) - if self.decoder is not None: - reconstruction = self.decoder(embedding_to) + if 'decoder' in self.network: + reconstruction = self.network['decoder'](embedding_to) before_encoding = edges_to_exp else: reconstruction = None @@ -83,12 +83,11 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): (edges_to_exp, edges_from_exp) = batch - embedding_to, embedding_from = self.encoder( - edges_to_exp - ), self.encoder(edges_from_exp) + embedding_to = self.network['encoder'](edges_to_exp) + embedding_from = self.network['encoder'](edges_from_exp) - if self.decoder is not None: - reconstruction = self.decoder(embedding_to) + if 'decoder' in self.network is not None: + reconstruction = self.network['decoder'](embedding_to) before_encoding = edges_to_exp else: reconstruction = None @@ -104,26 +103,36 @@ def validation_step(self, batch, batch_idx): # note if there's no ``loss_reconstruction``, then ``loss`` == ``loss_umap`` 
self.log("val_loss", loss, on_step=True) - @classmethod - def from_config(cls, config: dict): - """Return an initialized model instance from a config ``dict`` + def load_state_dict_from_path(self, ckpt_path): + """Loads a model from the path to a saved checkpoint. + + Loads the checkpoint and then calls + ``self.load_state_dict`` with the ``state_dict`` + in that checkpoint. + + This method allows loading a state dict into an instance. + It's necessary because ``lightning.pytorch.LightningModule.load`` is a + ``classmethod``, so calling that method will trigger + ``LightningModule.__init__`` instead of running + ``vak.models.Model.__init__``. Parameters ---------- - config : dict - Returned by calling :func:`vak.config.models.map_from_path` - or :func:`vak.config.models.map_from_config_dict`. + ckpt_path : str, pathlib.Path + Path to a checkpoint saved by a model in ``vak``. + This checkpoint has the same key-value pairs as + any other checkpoint saved by a + ``lightning.pytorch.LightningModule``. Returns ------- - cls : vak.models.base.Model - An instance of the model with its attributes - initialized using parameters from ``config``. + None + + This method modifies the model state by loading the ``state_dict``; + it does not return anything.
""" - network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls( - network=network, optimizer=optimizer, loss=loss, metrics=metrics - ) + ckpt = torch.load(ckpt_path) + self.load_state_dict(ckpt["state_dict"]) class ParametricUMAPDatamodule(lightning.pytorch.LightningDataModule): diff --git a/src/vak/models/registry.py b/src/vak/models/registry.py index b187b2480..f4fd0472b 100644 --- a/src/vak/models/registry.py +++ b/src/vak/models/registry.py @@ -7,21 +7,24 @@ from __future__ import annotations import inspect -from typing import Any, Type +from typing import Any, Type, TYPE_CHECKING -from .base import Model +import lightning + +if TYPE_CHECKING: + from .factory import ModelFactory MODEL_FAMILY_REGISTRY = {} def model_family(family_class: Type) -> None: - """Decorator that adds a class to the registry of model families.""" - if family_class not in Model.__subclasses__(): + """Decorator that adds a :class:`lightning.LightningModule` class to the registry of model families.""" + if not issubclass(family_class, lightning.LightningModule): raise TypeError( "The ``family_class`` provided to the `vak.models.model_family` decorator" - "must be a subclass of `vak.models.base.Model`, " + "must be a subclass of `lightning.LightningModule`, " f"but the class specified is not: {family_class}. " - f"Subclasses of `vak.models.base.Model` are: {Model.__subclasses__()}" + f"Subclasses of `lightning.LightningModule` are: {lightning.LightningModule.__subclasses__()}" ) model_family_name = family_class.__name__ @@ -40,48 +43,49 @@ def model_family(family_class: Type) -> None: MODEL_REGISTRY = {} -def register_model(model_class: Type) -> Type: - """Decorator that registers a model in the model registry. +def register_model(model: ModelFactory) -> ModelFactory: + """Function that registers a model in the model registry. This function is called by :func:`vak.models.decorator.model`, - that creates a model class from a model definition. 
- So you will not usually need to use this decorator directly, + that creates an instance of a :class:`vak.models.ModelFactory`, + given a :class:`vak.models.definition.ModelDefinition` + and a :class:`lightning.LightningModule` class that has been + registered as a model family with :func:`model_family`. + + So you will not usually need to use this function directly, and should prefer to use :func:`vak.models.decorator.model` instead. """ model_family_classes = list(MODEL_FAMILY_REGISTRY.values()) - model_parent_class = inspect.getmro(model_class)[1] - if model_parent_class not in model_family_classes: + model_family = model.family + if model_family not in model_family_classes: raise TypeError( - "The parent class of ``model_class`` passed to the ``model`` decorator " - f"is not recognized as a model family. Class was: {model_class} and " - f"parent is {model_parent_class}, as determined with " - f"``inspect.getmro(model_class)[1]``. " - f"Please specify a class that is a sub-class of a model family. " + "The family of `model` passed to the `register_model` decorator " + f"is not recognized as a model family. Class was '{model}' and " + f"its family is '{model_family}'. " + f"Please specify a valid model family. 
" f"Valid model family classes are: {model_family_classes}" ) - model_name = model_class.__name__ + model_name = model.__name__ if model_name in MODEL_REGISTRY: raise ValueError( f"Attempted to register a model family with the name '{model_name}', " f"but this name is already in the registry.\n" ) - MODEL_REGISTRY[model_name] = model_class - # need to return class after we register it or we replace it with None - # when this function is used as a decorator - return model_class + MODEL_REGISTRY[model_name] = model + # need to return class after we register it, + # or we would replace it with None when this function is used as a decorator + return model def __getattr__(name: str) -> Any: """Module-level __getattr__ function that we use to dynamically determine models.""" if name == "MODEL_FAMILY_FROM_NAME": - model_name_family_name_map = {} - for model_name, model_class in MODEL_REGISTRY.items(): - model_parent_class = inspect.getmro(model_class)[1] - family_name = model_parent_class.__name__ - model_name_family_name_map[model_name] = family_name - return model_name_family_name_map + return { + model_name: model.family.__name__ + for model_name, model in MODEL_REGISTRY.items() + } elif name == "MODEL_NAMES": return list(MODEL_REGISTRY.keys()) else: diff --git a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml index 32940a679..68aa38334 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_eval_audio_cbin_annot_notmat.toml @@ -20,7 +20,7 @@ num_workers = 16 output_dir = "./tests/data_for_tests/generated/results/eval/audio_cbin_annot_notmat/ConvEncoderUMAP" -[vak.eval.model.ConvEncoderUMAP.network] +[vak.eval.model.ConvEncoderUMAP.network.encoder] conv1_filters = 8 conv2_filters = 16 conv_kernel_size = 3 diff --git 
a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml index f188b650c..9028f214b 100644 --- a/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml +++ b/tests/data_for_tests/configs/ConvEncoderUMAP_train_audio_cbin_annot_notmat.toml @@ -24,7 +24,7 @@ num_workers = 16 root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/ConvEncoderUMAP" -[vak.train.model.ConvEncoderUMAP.network] +[vak.train.model.ConvEncoderUMAP.network.encoder] conv1_filters = 8 conv2_filters = 16 conv_kernel_size = 3 diff --git a/tests/test_config/test_model.py b/tests/test_config/test_model.py index 3ba45fd48..e081d0899 100644 --- a/tests/test_config/test_model.py +++ b/tests/test_config/test_model.py @@ -64,7 +64,17 @@ def test_init(self, config_dict): }, { "ConvEncoderUMAP": { - "optimizer": 1e-3 + "optimizer": {'lr': 1e-3}, + } + }, + { + "ConvEncoderUMAP": { + "network": { + "encoder": { + "conv1_filters": 8, + } + }, + "optimizer": {'lr': 1e-3}, } } ] @@ -127,7 +137,17 @@ def test_from_config_dict_real_config(self, a_generated_config_dict): }, { "ConvEncoderUMAP": { - "optimizer": 1e-3 + "optimizer": {'lr': 1e-3}, + } + }, + { + "ConvEncoderUMAP": { + "network": { + "encoder": { + "conv1_filters": 8, + } + }, + "optimizer": {'lr': 1e-3}, } } ] diff --git a/tests/test_models/conftest.py b/tests/test_models/conftest.py index 3b4c9033d..991878c8b 100644 --- a/tests/test_models/conftest.py +++ b/tests/test_models/conftest.py @@ -1,3 +1,4 @@ +import lightning import torch import vak.models.registry @@ -76,14 +77,16 @@ def __call__(self, y: torch.Tensor, y_pred: torch.Tensor): # ---- mock model families --------------------------------------------------------------------------------------------- -class UnregisteredMockModelFamily(vak.models.Model): +class UnregisteredMockModelFamily(lightning.LightningModule): """A model 
family defined only for tests. Used to test :func:`vak.models.registry.model_family`. """ def __init__(self, network, optimizer, loss, metrics): - super().__init__( - network=network, loss=loss, optimizer=optimizer, metrics=metrics - ) + super().__init__() + self.network=network + self.loss=loss + self.optimizer=optimizer + self.metrics=metrics def training_step(self, *args, **kwargs): pass @@ -91,26 +94,29 @@ def training_step(self, *args, **kwargs): def validation_step(self, *args, **kwargs): pass - @classmethod - def from_config(cls, config: dict): - """Return an initialized model instance from a config ``dict``.""" - network, loss, optimizer, metrics = cls.attributes_from_config(config) - return cls( - network=network, - optimizer=optimizer, - loss=loss, - metrics=metrics, - ) - # Make a "copy" of UnregisteredModelFamily that we *do* register # so we can use it to test `vak.models.decorator.model` and other functions # that require a registered ModelFamily. # Used when testing :func:`vak.models.decorator.model` -- we need a model in the registry to test # and we don't want to have to deal with the idiosyncrasies of actual model families -MockModelFamily = type('MockModelFamily', - UnregisteredMockModelFamily.__bases__, - dict(UnregisteredMockModelFamily.__dict__)) +class MockModelFamily(lightning.LightningModule): + """A model family defined only for tests. + Used to test :func:`vak.models.registry.model_family`. 
+ """ + def __init__(self, network, optimizer, loss, metrics): + super().__init__() + self.network=network + self.loss=loss + self.optimizer=optimizer + self.metrics=metrics + + def training_step(self, *args, **kwargs): + pass + + def validation_step(self, *args, **kwargs): + pass + vak.models.registry.model_family(MockModelFamily) diff --git a/tests/test_models/test_base.py b/tests/test_models/test_base.py deleted file mode 100644 index ec9d37ccf..000000000 --- a/tests/test_models/test_base.py +++ /dev/null @@ -1,311 +0,0 @@ -import copy -import inspect - -import pytest -import torch - -import vak - -from .conftest import ( - MockAcc, - MockDecoder, - MockEncoder, - MockEncoderDecoderModel, - MockModel, - MockNetwork, - other_loss_func, - other_metrics_dict, - OtherNetwork, - OtherOptimizer, -) - -from .test_definition import ( - InvalidMetricsDictKeyModelDefinition, - TweetyNetDefinition, -) - - -MODEL_DEFINITION_CLASS_VARS = ( - 'network', - 'loss', - 'optimizer', - 'metrics', - 'default_config' -) - -mock_net_instance = MockNetwork() - - -TEST_INIT_ARGVALS = [ - (MockModel, None), - (MockModel, {'network': mock_net_instance}), - (MockModel, {'loss': torch.nn.CrossEntropyLoss()},), - (MockModel, - { - 'network': mock_net_instance, - 'optimizer': torch.optim.SGD(lr=0.003, params=mock_net_instance.parameters()) - } - ), - (MockModel, - {'metrics': - { - 'acc': MockAcc(), - } - }), - (MockEncoderDecoderModel, None), -] - -TEST_INIT_RAISES_ARGVALS = [ - (MockModel, dict(network=OtherNetwork()), TypeError), - (MockModel, dict(loss=other_loss_func), TypeError), - (MockModel, dict(optimizer=OtherOptimizer), TypeError), - (MockModel, dict(metrics=other_metrics_dict), ValueError), - (MockEncoderDecoderModel, - # first value is wrong - dict(network={'MockEncoder': OtherNetwork(), 'MockDecoder': MockDecoder()}), - TypeError), - (MockEncoderDecoderModel, - # missng key, MockEncoder - dict(network={'MockDecoder': MockDecoder()}), - ValueError), - (MockEncoderDecoderModel, - 
# extra key, MockRecoder - dict(network={'MockEncoder': MockEncoder(), 'MockDecoder': MockDecoder(), 'MockRecoder': MockNetwork()}), - ValueError), -] - - -class TestModel: - - @pytest.mark.parametrize( - 'definition, kwargs', - TEST_INIT_ARGVALS, - ) - def test_init(self, - definition, - kwargs, - monkeypatch): - """Test Model.__init__ works as expected""" - # monkeypatch a definition so we can test __init__ - definition = vak.models.definition.validate(definition) - monkeypatch.setattr( - vak.models.base.Model, 'definition', definition, raising=False - ) - - # actually instantiate model - if kwargs: - model = vak.models.base.Model(**kwargs) - else: - model = vak.models.base.Model() - - # now test that attributes are what we expect - assert isinstance(model, vak.models.base.Model) - for attr in ('network', 'loss', 'optimizer', 'metrics'): - assert hasattr(model, attr) - model_attr = getattr(model, attr) - definition_attr = getattr(definition, attr) - if inspect.isclass(definition_attr): - assert isinstance(model_attr, definition_attr) - elif isinstance(definition_attr, dict): - assert isinstance(model_attr, dict) - for definition_key, definition_val in definition_attr.items(): - assert definition_key in model_attr - model_val = model_attr[definition_key] - if inspect.isclass(definition_val): - assert isinstance(model_val, definition_val) - else: - assert callable(definition_val) - assert model_val is definition_val - else: - # must be a function - assert callable(model_attr) - assert model_attr is definition_attr - - def test_init_no_definition_raises(self): - """Test that initializing a Model instance without a definition raises a ValueError.""" - with pytest.raises(ValueError): - vak.models.base.Model() - - def test_init_invalid_definition_raises(self, monkeypatch): - """Test that initializing a Model instance with an invalid definition raises a ValueError.""" - monkeypatch.setattr( - vak.models.base.Model, 'definition', InvalidMetricsDictKeyModelDefinition, 
raising=False - ) - with pytest.raises(TypeError): - vak.models.base.Model() - - @pytest.mark.parametrize( - 'definition, kwargs, expected_exception', - TEST_INIT_RAISES_ARGVALS - ) - def test_init_raises(self, definition, kwargs, expected_exception, monkeypatch): - """Test that init raises errors as expected given input arguments. - - Note that this should happen from ``__init__`` calling ``validate_init``, - so here we test that this is happening inside ``__init__``. - Next method tests ``validate_init`` directly. - """ - # monkeypatch a definition so we can test __init__ - monkeypatch.setattr( - # we just always use TweetyNetDefinition here since we just want to test that a mismatch raises - vak.models.base.Model, 'definition', definition, raising=False - ) - with pytest.raises(expected_exception): - vak.models.base.Model(**kwargs) - - @pytest.mark.parametrize( - 'definition, kwargs, expected_exception', - TEST_INIT_RAISES_ARGVALS - ) - def test_validate_init_raises(self, definition, kwargs, expected_exception, monkeypatch): - """Test that ``validate_init`` raises errors as expected""" - # monkeypatch a definition so we can test __init__ - monkeypatch.setattr( - # we just always use TweetyNetDefinition here since we just want to test that a mismatch raises - vak.models.base.Model, 'definition', definition, raising=False - ) - with pytest.raises(expected_exception): - vak.models.base.Model.validate_init(**kwargs) - - MODEL_DEFINITION_MAP = { - 'TweetyNet': TweetyNetDefinition, - } - - @pytest.mark.parametrize( - 'model_name', - [ - 'TweetyNet', - ] - ) - def test_load_state_dict_from_path(self, - model_name, - # our fixtures - specific_config_toml_path, - # pytest fixtures - monkeypatch, - device - ): - """Smoke test that makes sure ``load_state_dict_from_path`` runs without failure. - - We use actual model definitions here so we can test with real checkpoints. 
- """ - definition = self.MODEL_DEFINITION_MAP[model_name] - train_toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - train_cfg = vak.config.Config.from_toml_path(train_toml_path) - - # stuff we need just to be able to instantiate network - labelmap = vak.common.labels.to_map(train_cfg.prep.labelset, map_unlabeled=True) - item_transform = vak.transforms.defaults.get_default_transform( - model_name, - "train", - transform_kwargs={}, - ) - train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( - dataset_path=train_cfg.train.dataset.path, - split="train", - window_size=train_cfg.train.dataset.params['window_size'], - item_transform=item_transform, - ) - input_shape = train_dataset.shape - num_input_channels = input_shape[-3] - num_freqbins = input_shape[-2] - - monkeypatch.setattr( - vak.models.base.Model, 'definition', definition, raising=False - ) - # network is the one thing that has required args - # and we also need to use its config from the toml file - cfg = vak.config.Config.from_toml_path(train_toml_path) - model_config = cfg.train.model.asdict() - network = definition.network(num_classes=len(labelmap), - num_input_channels=num_input_channels, - num_freqbins=num_freqbins, - **model_config['network']) - model = vak.models.base.Model(network=network) - model.to(device) - - eval_toml_path = specific_config_toml_path('eval', model_name, audio_format='cbin', annot_format='notmat') - eval_cfg = vak.config.Config.from_toml_path(eval_toml_path) - checkpoint_path = eval_cfg.eval.checkpoint_path - - # ---- actually test method - sd_before = copy.deepcopy(model.state_dict()) - sd_before = { - k: v.to(device) for k, v in sd_before.items() - } - ckpt = torch.load(checkpoint_path) - sd_to_be_loaded = ckpt['state_dict'] - sd_to_be_loaded = { - k: v.to(device) for k, v in sd_to_be_loaded.items() - } - - model.load_state_dict_from_path(checkpoint_path) - - assert not all([ - 
torch.all(torch.eq(val, before_val)) - for val, before_val in zip(model.state_dict().values(), sd_before.values())] - ) - assert all([ - torch.all(torch.eq(val, before_val)) - for val, before_val in zip(model.state_dict().values(), sd_to_be_loaded.values())] - ) - - @pytest.mark.parametrize( - 'definition, config', - [ - (MockModel, {'network': {'n_classes': 10}}), - (MockModel, {'loss': {'reduction': 'sum'}}), - (MockModel, { - 'network': {'n_classes': 10}, - 'optimizer': {'lr': 0.003}} - ), - (MockModel, {'metrics': {'acc': {'average': 'micro'}}}), - (MockEncoderDecoderModel, { - 'network': { - 'MockEncoder': {'input_size': 5}, - 'MockDecoder': {'output_size': 5}, - }, - 'optimizer': {'lr': 0.003}} - ) - ] - ) - def test_from_config(self, - definition, - config, - monkeypatch, - ): - monkeypatch.setattr( - vak.models.base.Model, 'definition', definition, raising=False - ) - - model = vak.models.base.Model.from_config(config) - - assert isinstance(model, vak.models.base.Model) - - if 'network' in config: - if inspect.isclass(definition.network): - for network_kwarg, network_kwargval in config['network'].items(): - assert hasattr(model.network, network_kwarg) - assert getattr(model.network, network_kwarg) == network_kwargval - elif isinstance(definition.network, dict): - for net_name, net_kwargs in config['network'].items(): - for network_kwarg, network_kwargval in net_kwargs.items(): - assert hasattr(model.network[net_name], network_kwarg) - assert getattr(model.network[net_name], network_kwarg) == network_kwargval - - if 'loss' in config: - for loss_kwarg, loss_kwargval in config['loss'].items(): - assert hasattr(model.loss, loss_kwarg) - assert getattr(model.loss, loss_kwarg) == loss_kwargval - - if 'optimizer' in config: - for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items(): - assert optimizer_kwarg in model.optimizer.param_groups[0] - assert model.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval - - if 'metrics' in config: - 
for metric_name, metric_kwargs in config['metrics'].items(): - assert metric_name in model.metrics - for metric_kwarg, metric_kwargval in metric_kwargs.items(): - assert hasattr(model.metrics[metric_name], metric_kwarg) - assert getattr(model.metrics[metric_name], metric_kwarg) == metric_kwargval diff --git a/tests/test_models/test_convencoder_umap.py b/tests/test_models/test_convencoder_umap.py index 29c63643c..82beb88ff 100644 --- a/tests/test_models/test_convencoder_umap.py +++ b/tests/test_models/test_convencoder_umap.py @@ -1,4 +1,5 @@ import pytest +import torch import vak @@ -11,18 +12,18 @@ class TestConvEncoderUMAP: (1, 64, 64), ] ) - def test_init(self, input_shape): + def test_from_instances(self, input_shape): network = { 'encoder': vak.models.ConvEncoderUMAP.definition.network['encoder'](input_shape=input_shape) } - model = vak.models.ConvEncoderUMAP(network=network) - assert isinstance(model, vak.models.ConvEncoderUMAP) + model = vak.models.ConvEncoderUMAP.from_instances(network=network) + assert isinstance(model, vak.models.ParametricUMAPModel) for attr in ('network', 'loss', 'optimizer'): assert hasattr(model, attr) attr_from_definition = getattr(vak.models.convencoder_umap.ConvEncoderUMAP.definition, attr) if isinstance(attr_from_definition, dict): attr_from_model = getattr(model, attr) - assert isinstance(attr_from_model, dict) + assert isinstance(attr_from_model, (dict, torch.nn.ModuleDict)) assert attr_from_model.keys() == attr_from_definition.keys() for net_name, net_instance in attr_from_model.items(): assert isinstance(net_instance, attr_from_definition[net_name]) diff --git a/tests/test_models/test_decorator.py b/tests/test_models/test_decorator.py index bb06b9246..5333121f3 100644 --- a/tests/test_models/test_decorator.py +++ b/tests/test_models/test_decorator.py @@ -35,11 +35,13 @@ ) def test_model(definition, family, expected_name): """Test that :func:`vak.models.decorator.model` decorator - returns a subclass of the specified model 
family, + returns a new instance of ModelFactory, and has the expected name""" model_class = vak.models.decorator.model(family)(definition) - assert issubclass(model_class, family) + + assert isinstance(model_class, vak.models.factory.ModelFactory) assert model_class.__name__ == expected_name + # need to delete model from registry so other tests don't fail del vak.models.registry.MODEL_REGISTRY[model_class.__name__] @@ -61,4 +63,4 @@ def test_model(definition, family, expected_name): ) def test_model_raises(definition): with pytest.raises(vak.models.decorator.ModelDefinitionValidationError): - vak.models.decorator.model(vak.models.base.Model)(definition) + vak.models.decorator.model(MockModelFamily)(definition) diff --git a/tests/test_models/test_ed_tcn.py b/tests/test_models/test_ed_tcn.py index 9e81c0d9a..d520fe7a8 100644 --- a/tests/test_models/test_ed_tcn.py +++ b/tests/test_models/test_ed_tcn.py @@ -15,8 +15,8 @@ def test_init(self, labelmap, input_shape): num_input_channels = input_shape[-3] num_freqbins = input_shape[-2] network = vak.models.ED_TCN.definition.network(len(labelmap), num_input_channels, num_freqbins) - model = vak.models.ED_TCN(labelmap=labelmap, network=network) - assert isinstance(model, vak.models.ED_TCN) + model = vak.models.ED_TCN.from_instances(network=network, labelmap=labelmap) + assert isinstance(model, vak.models.FrameClassificationModel) for attr in ('network', 'loss', 'optimizer'): assert hasattr(model, attr) assert isinstance(getattr(model, attr), diff --git a/tests/test_models/test_factory.py b/tests/test_models/test_factory.py new file mode 100644 index 000000000..169b3ed8b --- /dev/null +++ b/tests/test_models/test_factory.py @@ -0,0 +1,451 @@ +import inspect +import itertools + +import pytest +import torch + +import vak + +from .conftest import ( + MockAcc, + MockDecoder, + MockEncoder, + MockEncoderDecoderModel, + MockModel, + MockModelFamily, + MockNetwork, + other_loss_func, + other_metrics_dict, + OtherNetwork, + 
OtherOptimizer, +) + +from .test_definition import ( + InvalidMetricsDictKeyModelDefinition, + TweetyNetDefinition, +) +from .test_tweetynet import LABELMAPS, INPUT_SHAPES + +MODEL_DEFINITION_CLASS_VARS = ( + 'network', + 'loss', + 'optimizer', + 'metrics', + 'default_config' +) + +mock_net_instance = MockNetwork() + +TEST_VALIDATE_RAISES_ARGVALS = [ + (MockModel, dict(network=OtherNetwork()), TypeError), + (MockModel, dict(loss=other_loss_func), TypeError), + (MockModel, dict(optimizer=OtherOptimizer), TypeError), + (MockModel, dict(metrics=other_metrics_dict), ValueError), + (MockEncoderDecoderModel, + # first value is wrong + dict(network={'MockEncoder': OtherNetwork(), 'MockDecoder': MockDecoder()}), + TypeError), + (MockEncoderDecoderModel, + # missng key, MockEncoder + dict(network={'MockDecoder': MockDecoder()}), + ValueError), + (MockEncoderDecoderModel, + # extra key, MockRecoder + dict(network={'MockEncoder': MockEncoder(), 'MockDecoder': MockDecoder(), 'MockRecoder': MockNetwork()}), + ValueError), +] + +# pytest.mark.parametrize vals for test_init_with_definition +MODEL_DEFS = ( + TweetyNetDefinition, +) + +TEST_WITH_FRAME_CLASSIFICATION_ARGVALS = itertools.product(LABELMAPS, INPUT_SHAPES, MODEL_DEFS) + +MOCK_INPUT_SHAPE = torch.Size([1, 128, 44]) + + +class ConvEncoderUMAPDefinition: + network = {"encoder": vak.nets.ConvEncoder} + loss = vak.nn.UmapLoss + optimizer = torch.optim.AdamW + metrics = { + "acc": vak.metrics.Accuracy, + "levenshtein": vak.metrics.Levenshtein, + "character_error_rate": vak.metrics.CharacterErrorRate, + "loss": torch.nn.CrossEntropyLoss, + } + default_config = { + "optimizer": {"lr": 1e-3}, + } + + +class TestModelFactory: + def test_init_no_definition_raises(self): + """Test that initializing a Model instance without a definition or family raises a ValueError.""" + with pytest.raises(TypeError): + vak.models.factory.ModelFactory() + + def test_init_invalid_definition_raises(self): + """Test that initializing a Model instance 
with an invalid definition raises a ValueError.""" + with pytest.raises(vak.models.decorator.ModelDefinitionValidationError): + vak.models.factory.ModelFactory( + definition=InvalidMetricsDictKeyModelDefinition, + family=MockModelFamily, + ) + + @pytest.mark.parametrize( + 'definition, kwargs', + [ + (MockModel, None), + (MockModel, {'network': mock_net_instance}), + (MockModel, {'loss': torch.nn.CrossEntropyLoss()},), + (MockModel, + { + 'network': mock_net_instance, + 'optimizer': torch.optim.SGD(lr=0.003, params=mock_net_instance.parameters()) + } + ), + (MockModel, + {'metrics': + { + 'acc': MockAcc(), + } + }), + (MockEncoderDecoderModel, None), + ] + ) + def test_validate_instances_or_get_default(self, definition, kwargs): + model = vak.models.factory.ModelFactory( + definition, + MockModelFamily, + ) + # actually instantiate model + if kwargs: + (network, + loss, + optimizer, + metrics + ) = model.validate_instances_or_get_default(**kwargs) + else: + (network, + loss, + optimizer, + metrics + ) = model.validate_instances_or_get_default() + + model_attrs = { + 'network': network, + 'loss': loss, + 'optimizer': optimizer, + 'metrics': metrics, + } + for attr in ('network', 'loss', 'optimizer', 'metrics'): + model_attr = model_attrs[attr] + definition_attr = getattr(definition, attr) + if inspect.isclass(definition_attr): + assert isinstance(model_attr, definition_attr) + elif isinstance(definition_attr, dict): + assert isinstance(model_attr, dict) + for definition_key, definition_val in definition_attr.items(): + assert definition_key in model_attr + model_val = model_attr[definition_key] + if inspect.isclass(definition_val): + assert isinstance(model_val, definition_val) + else: + assert callable(definition_val) + assert model_val is definition_val + else: + # must be a function + assert callable(model_attr) + assert model_attr is definition_attr + + @pytest.mark.parametrize( + 'definition, kwargs, expected_exception', + TEST_VALIDATE_RAISES_ARGVALS + ) + def 
test_validate_instances_or_get_default_raises(self, definition, kwargs, expected_exception): + """Test that :meth:`validate_instances_or_get_default` raises errors as expected given input arguments. + + Note that this should happen from ``validate_instances_or_get_default`` calling ``validate_init``, + so here we test that this is happening inside ``validate_instances_or_get_default``. + Next method tests ``validate_init`` directly. + """ + model = vak.models.factory.ModelFactory( + definition, + MockModelFamily, + ) + with pytest.raises(expected_exception): + model.validate_instances_or_get_default(**kwargs) + + @pytest.mark.parametrize( + 'definition, kwargs, expected_exception', + TEST_VALIDATE_RAISES_ARGVALS + ) + def test_validate_init_raises(self, definition, kwargs, expected_exception): + """Test that ``validate_init`` raises errors as expected""" + model = vak.models.factory.ModelFactory( + definition=definition, + family=MockModelFamily + ) + with pytest.raises(expected_exception): + model.validate_init(**kwargs) + + MODEL_DEFINITION_MAP = { + 'TweetyNet': TweetyNetDefinition, + } + + @pytest.mark.parametrize( + 'definition, config', + [ + (MockModel, {'network': {'n_classes': 10}}), + (MockModel, {'loss': {'reduction': 'sum'}}), + (MockModel, { + 'network': {'n_classes': 10}, + 'optimizer': {'lr': 0.003}} + ), + (MockModel, {'metrics': {'acc': {'average': 'micro'}}}), + (MockEncoderDecoderModel, { + 'network': { + 'MockEncoder': {'input_size': 5}, + 'MockDecoder': {'output_size': 5}, + }, + 'optimizer': {'lr': 0.003}} + ) + ] + ) + def test_from_config(self, + definition, + config, + ): + model = vak.models.factory.ModelFactory( + definition=definition, + family=MockModelFamily + ) + new_model_instance = model.from_config(config) + + assert isinstance(new_model_instance, MockModelFamily) + + if 'network' in config: + if inspect.isclass(definition.network): + for network_kwarg, network_kwargval in config['network'].items(): + assert 
hasattr(new_model_instance.network, network_kwarg) + assert getattr(new_model_instance.network, network_kwarg) == network_kwargval + elif isinstance(definition.network, dict): + for net_name, net_kwargs in config['network'].items(): + for network_kwarg, network_kwargval in net_kwargs.items(): + assert hasattr(new_model_instance.network[net_name], network_kwarg) + assert getattr(new_model_instance.network[net_name], network_kwarg) == network_kwargval + + if 'loss' in config: + for loss_kwarg, loss_kwargval in config['loss'].items(): + assert hasattr(new_model_instance.loss, loss_kwarg) + assert getattr(new_model_instance.loss, loss_kwarg) == loss_kwargval + + if 'optimizer' in config: + for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items(): + assert optimizer_kwarg in new_model_instance.optimizer.param_groups[0] + assert new_model_instance.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval + + if 'metrics' in config: + for metric_name, metric_kwargs in config['metrics'].items(): + assert metric_name in new_model_instance.metrics + for metric_kwarg, metric_kwargval in metric_kwargs.items(): + assert hasattr(new_model_instance.metrics[metric_name], metric_kwarg) + assert getattr(new_model_instance.metrics[metric_name], metric_kwarg) == metric_kwargval + + @pytest.mark.parametrize( + 'labelmap, input_shape, definition', + TEST_WITH_FRAME_CLASSIFICATION_ARGVALS + ) + def test_from_config_frame_classification(self, labelmap, input_shape, definition): + model_factory = vak.models.factory.ModelFactory( + definition, + vak.models.FrameClassificationModel, + ) + num_input_channels, num_freqbins = input_shape[0], input_shape[1] + # network has required args that need to be determined dynamically + network = definition.network(len(labelmap), num_input_channels, num_freqbins) + model = model_factory.from_instances(network=network, labelmap=labelmap) + + # now test that attributes are what we expect + assert isinstance(model, 
vak.models.FrameClassificationModel) + for attr in ('network', 'loss', 'optimizer', 'metrics'): + assert hasattr(model, attr) + model_attr = getattr(model, attr) + definition_attr = getattr(definition, attr) + if inspect.isclass(definition_attr): + assert isinstance(model_attr, definition_attr) + elif isinstance(definition_attr, dict): + assert isinstance(model_attr, dict) + for definition_key, definition_val in definition_attr.items(): + assert definition_key in model_attr + model_val = model_attr[definition_key] + if inspect.isclass(definition_val): + assert isinstance(model_val, definition_val) + else: + assert callable(definition_val) + assert model_val is definition_val + else: + # must be a function + assert callable(model_attr) + assert model_attr is definition_attr + + @pytest.mark.parametrize( + 'definition', + [ + TweetyNetDefinition, + ] + ) + def test_from_config_with_frame_classification(self, definition, specific_config_toml_path): + model_name = definition.__name__.replace('Definition', '') + toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') + cfg = vak.config.Config.from_toml_path(toml_path) + + # stuff we need just to be able to instantiate network + labelmap = vak.common.labels.to_map(cfg.prep.labelset, map_unlabeled=True) + + model_factory = vak.models.factory.ModelFactory( + definition, + vak.models.FrameClassificationModel, + ) + + config = cfg.train.model.asdict() + num_input_channels, num_freqbins = MOCK_INPUT_SHAPE[0], MOCK_INPUT_SHAPE[1] + + config["network"].update( + num_classes=len(labelmap), + num_input_channels=num_input_channels, + num_freqbins=num_freqbins + ) + + model = model_factory.from_config(config=config, labelmap=labelmap) + assert isinstance(model, vak.models.FrameClassificationModel) + + # below, we can only test the config kwargs that actually end up as attributes + # so we use `if hasattr` before checking + if 'network' in config: + if 
inspect.isclass(definition.network): + for network_kwarg, network_kwargval in config['network'].items(): + if hasattr(model.network, network_kwarg): + assert getattr(model.network, network_kwarg) == network_kwargval + elif isinstance(definition.network, dict): + for net_name, net_kwargs in config['network'].items(): + for network_kwarg, network_kwargval in net_kwargs.items(): + if hasattr(model.network[net_name], network_kwarg): + assert getattr(model.network[net_name], network_kwarg) == network_kwargval + + if 'loss' in config: + for loss_kwarg, loss_kwargval in config['loss'].items(): + if hasattr(model.loss, loss_kwarg): + assert getattr(model.loss, loss_kwarg) == loss_kwargval + + if 'optimizer' in config: + for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items(): + if optimizer_kwarg in model.optimizer.param_groups[0]: + assert model.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval + + if 'metrics' in config: + for metric_name, metric_kwargs in config['metrics'].items(): + assert metric_name in model.metrics + for metric_kwarg, metric_kwargval in metric_kwargs.items(): + if hasattr(model.metrics[metric_name], metric_kwarg): + assert getattr(model.metrics[metric_name], metric_kwarg) == metric_kwargval + + @pytest.mark.parametrize( + 'input_shape, definition', + [ + ((1, 128, 128), ConvEncoderUMAPDefinition), + ] + ) + def test_from_instances_parametric_umap( + self, + input_shape, + definition, + ): + network = {'encoder': vak.nets.ConvEncoder(input_shape)} + + model_factory = vak.models.ModelFactory( + definition, + vak.models.ParametricUMAPModel, + ) + model = model_factory.from_instances(network=network) + + # now test that attributes are what we expect + assert isinstance(model, vak.models.ParametricUMAPModel) + for attr in ('network', 'loss', 'optimizer', 'metrics'): + assert hasattr(model, attr) + model_attr = getattr(model, attr) + definition_attr = getattr(definition, attr) + if inspect.isclass(definition_attr): + assert 
isinstance(model_attr, definition_attr) + elif isinstance(definition_attr, dict): + assert isinstance(model_attr, (dict, torch.nn.ModuleDict)) + for definition_key, definition_val in definition_attr.items(): + assert definition_key in model_attr + model_val = model_attr[definition_key] + if inspect.isclass(definition_val): + assert isinstance(model_val, definition_val) + else: + assert callable(definition_val) + assert model_val is definition_val + else: + # must be a function + assert callable(model_attr) + assert model_attr is definition_attr + + @pytest.mark.parametrize( + 'input_shape, definition', + [ + ((1, 128, 128), ConvEncoderUMAPDefinition), + ] + ) + def test_from_config_with_parametric_umap( + self, + input_shape, + definition, + specific_config_toml_path, + ): + model_name = definition.__name__.replace('Definition', '') + toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') + cfg = vak.config.Config.from_toml_path(toml_path) + + model_factory = vak.models.ModelFactory( + definition, + vak.models.ParametricUMAPModel, + ) + + config = cfg.train.model.asdict() + config["network"]["encoder"]["input_shape"] = input_shape + + model = model_factory.from_config(config=config) + assert isinstance(model, vak.models.ParametricUMAPModel) + + if 'network' in config: + if inspect.isclass(definition.network): + for network_kwarg, network_kwargval in config['network'].items(): + assert hasattr(model.network, network_kwarg) + assert getattr(model.network, network_kwarg) == network_kwargval + elif isinstance(definition.network, dict): + for net_name, net_kwargs in config['network'].items(): + for network_kwarg, network_kwargval in net_kwargs.items(): + network = model.network[net_name] + if hasattr(network, network_kwarg): + assert getattr(network, network_kwarg) == network_kwargval + + if 'loss' in config: + for loss_kwarg, loss_kwargval in config['loss'].items(): + assert hasattr(model.loss, loss_kwarg) + assert 
getattr(model.loss, loss_kwarg) == loss_kwargval + + if 'optimizer' in config: + for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items(): + assert optimizer_kwarg in model.optimizer.param_groups[0] + assert model.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval + + if 'metrics' in config: + for metric_name, metric_kwargs in config['metrics'].items(): + assert metric_name in model.metrics + for metric_kwarg, metric_kwargval in metric_kwargs.items(): + assert hasattr(model.metrics[metric_name], metric_kwarg) + assert getattr(model.metrics[metric_name], metric_kwarg) == metric_kwargval diff --git a/tests/test_models/test_frame_classification_model.py b/tests/test_models/test_frame_classification_model.py index c66dbcb47..3c1363496 100644 --- a/tests/test_models/test_frame_classification_model.py +++ b/tests/test_models/test_frame_classification_model.py @@ -1,138 +1,91 @@ -import inspect -import itertools +import copy +from .test_definition import TweetyNetDefinition + import pytest import torch import vak.models -from .test_definition import ( - TweetyNetDefinition, -) -from .test_tweetynet import LABELMAPS, INPUT_SHAPES - - -# pytest.mark.parametrize vals for test_init_with_definition -MODEL_DEFS = ( - TweetyNetDefinition, -) - -TEST_INIT_ARGVALS = itertools.product(LABELMAPS, INPUT_SHAPES, MODEL_DEFS) - - class TestFrameClassificationModel: - @pytest.mark.parametrize( - 'labelmap, input_shape, definition', - TEST_INIT_ARGVALS - ) - def test_init(self, - labelmap, - input_shape, - definition, - monkeypatch): - """Test FrameClassificationModel.__init__ works as expected""" - # monkeypatch a definition so we can test __init__ - definition = vak.models.definition.validate(definition) - monkeypatch.setattr( - vak.models.FrameClassificationModel, - 'definition', - definition, - raising=False - ) - num_input_channels, num_freqbins = input_shape[0], input_shape[1] - # network has required args that need to be determined dynamically - network 
= definition.network(len(labelmap), num_input_channels, num_freqbins) - model = vak.models.FrameClassificationModel(labelmap=labelmap, network=network) - - # now test that attributes are what we expect - assert isinstance(model, vak.models.FrameClassificationModel) - for attr in ('network', 'loss', 'optimizer', 'metrics'): - assert hasattr(model, attr) - model_attr = getattr(model, attr) - definition_attr = getattr(definition, attr) - if inspect.isclass(definition_attr): - assert isinstance(model_attr, definition_attr) - elif isinstance(definition_attr, dict): - assert isinstance(model_attr, dict) - for definition_key, definition_val in definition_attr.items(): - assert definition_key in model_attr - model_val = model_attr[definition_key] - if inspect.isclass(definition_val): - assert isinstance(model_val, definition_val) - else: - assert callable(definition_val) - assert model_val is definition_val - else: - # must be a function - assert callable(model_attr) - assert model_attr is definition_attr - - MOCK_INPUT_SHAPE = torch.Size([1, 128, 44]) + MODEL_DEFINITION_MAP = { + 'TweetyNet': TweetyNetDefinition, + } @pytest.mark.parametrize( - 'definition', + 'model_name', [ - TweetyNetDefinition, + 'TweetyNet', ] ) - def test_from_config(self, - definition, - # our fixtures - specific_config_toml_path, - # pytest fixtures - monkeypatch, - ): - definition = vak.models.definition.validate(definition) - model_name = definition.__name__.replace('Definition', '') - toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - cfg = vak.config.Config.from_toml_path(toml_path) + def test_load_state_dict_from_path(self, + model_name, + specific_config_toml_path, + device + ): + """Smoke test that makes sure ``load_state_dict_from_path`` runs without failure. + + We use actual model definitions here so we can test with real checkpoints. 
+ """ + definition = self.MODEL_DEFINITION_MAP[model_name] + train_toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') + train_cfg = vak.config.Config.from_toml_path(train_toml_path) # stuff we need just to be able to instantiate network - labelmap = vak.common.labels.to_map(cfg.prep.labelset, map_unlabeled=True) - - monkeypatch.setattr( - vak.models.FrameClassificationModel, 'definition', definition, raising=False + labelmap = vak.common.labels.to_map(train_cfg.prep.labelset, map_unlabeled=True) + item_transform = vak.transforms.defaults.get_default_transform( + model_name, + "train", + transform_kwargs={}, ) - - config = cfg.train.model.asdict() - num_input_channels, num_freqbins = self.MOCK_INPUT_SHAPE[0], self.MOCK_INPUT_SHAPE[1] - - config["network"].update( - num_classes=len(labelmap), - num_input_channels=num_input_channels, - num_freqbins=num_freqbins + train_dataset = vak.datasets.frame_classification.WindowDataset.from_dataset_path( + dataset_path=train_cfg.train.dataset.path, + split="train", + window_size=train_cfg.train.dataset.params['window_size'], + item_transform=item_transform, + ) + input_shape = train_dataset.shape + num_input_channels = input_shape[-3] + num_freqbins = input_shape[-2] + + # network is the one thing that has required args + # and we also need to use its config from the toml file + cfg = vak.config.Config.from_toml_path(train_toml_path) + model_config = cfg.train.model.asdict() + network = definition.network(num_classes=len(labelmap), + num_input_channels=num_input_channels, + num_freqbins=num_freqbins, + **model_config['network']) + model_factory = vak.models.factory.ModelFactory( + definition, + vak.models.FrameClassificationModel, + ) + model = model_factory.from_instances(network=network, labelmap=labelmap) + model.to(device) + + eval_toml_path = specific_config_toml_path('eval', model_name, audio_format='cbin', annot_format='notmat') + eval_cfg = 
vak.config.Config.from_toml_path(eval_toml_path) + checkpoint_path = eval_cfg.eval.checkpoint_path + + # ---- actually test method + sd_before = copy.deepcopy(model.state_dict()) + sd_before = { + k: v.to(device) for k, v in sd_before.items() + } + ckpt = torch.load(checkpoint_path) + sd_to_be_loaded = ckpt['state_dict'] + sd_to_be_loaded = { + k: v.to(device) for k, v in sd_to_be_loaded.items() + } + + model.load_state_dict_from_path(checkpoint_path) + + assert not all([ + torch.all(torch.eq(val, before_val)) + for val, before_val in zip(model.state_dict().values(), sd_before.values())] + ) + assert all([ + torch.all(torch.eq(val, before_val)) + for val, before_val in zip(model.state_dict().values(), sd_to_be_loaded.values())] ) - - model = vak.models.FrameClassificationModel.from_config(config=config, labelmap=labelmap) - assert isinstance(model, vak.models.FrameClassificationModel) - - # below, we can only test the config kwargs that actually end up as attributes - # so we use `if hasattr` before checking - if 'network' in config: - if inspect.isclass(definition.network): - for network_kwarg, network_kwargval in config['network'].items(): - if hasattr(model.network, network_kwarg): - assert getattr(model.network, network_kwarg) == network_kwargval - elif isinstance(definition.network, dict): - for net_name, net_kwargs in config['network'].items(): - for network_kwarg, network_kwargval in net_kwargs.items(): - if hasattr(model.network[net_name], network_kwarg): - assert getattr(model.network[net_name], network_kwarg) == network_kwargval - - if 'loss' in config: - for loss_kwarg, loss_kwargval in config['loss'].items(): - if hasattr(model.loss, loss_kwarg): - assert getattr(model.loss, loss_kwarg) == loss_kwargval - - if 'optimizer' in config: - for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items(): - if optimizer_kwarg in model.optimizer.param_groups[0]: - assert model.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval - - if 
'metrics' in config: - for metric_name, metric_kwargs in config['metrics'].items(): - assert metric_name in model.metrics - for metric_kwarg, metric_kwargval in metric_kwargs.items(): - if hasattr(model.metrics[metric_name], metric_kwarg): - assert getattr(model.metrics[metric_name], metric_kwarg) == metric_kwargval diff --git a/tests/test_models/test_parametric_umap_model.py b/tests/test_models/test_parametric_umap_model.py index eba4f77d1..74bada3f6 100644 --- a/tests/test_models/test_parametric_umap_model.py +++ b/tests/test_models/test_parametric_umap_model.py @@ -1,130 +1,89 @@ -import inspect +import copy import pytest import torch import vak.models +from .test_factory import ConvEncoderUMAPDefinition -class ConvEncoderUMAPDefinition: - network = {"encoder": vak.nets.ConvEncoder} - loss = vak.nn.UmapLoss - optimizer = torch.optim.AdamW - metrics = { - "acc": vak.metrics.Accuracy, - "levenshtein": vak.metrics.Levenshtein, - "character_error_rate": vak.metrics.CharacterErrorRate, - "loss": torch.nn.CrossEntropyLoss, - } - default_config = { - "optimizer": {"lr": 1e-3}, - } class TestParametricUMAPModel: + MODEL_DEFINITION_MAP = { + 'ConvEncoderUMAP': ConvEncoderUMAPDefinition, + } + @pytest.mark.parametrize( - 'input_shape, definition', + 'model_name', [ - ((1, 128, 128), ConvEncoderUMAPDefinition), + 'ConvEncoderUMAP', ] ) - def test_init( - self, - input_shape, - definition, - monkeypatch, - ): - """Test ParametricUMAPModel.__init__ works as expected""" - # monkeypatch a definition so we can test __init__ - definition = vak.models.definition.validate(definition) - monkeypatch.setattr( - vak.models.ParametricUMAPModel, - 'definition', - definition, - raising=False + def test_load_state_dict_from_path(self, + model_name, + specific_config_toml_path, + device + ): + """Smoke test that makes sure ``load_state_dict_from_path`` runs without failure. + + We use actual model definitions here so we can test with real checkpoints. 
+ """ + definition = self.MODEL_DEFINITION_MAP[model_name] + train_toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') + train_cfg = vak.config.Config.from_toml_path(train_toml_path) + + # stuff we need just to be able to instantiate network + item_transform = vak.transforms.defaults.get_default_transform( + model_name, + "train", + transform_kwargs={}, + ) + train_dataset = vak.datasets.parametric_umap.ParametricUMAPDataset.from_dataset_path( + dataset_path=train_cfg.train.dataset.path, + split="train", + transform=item_transform, ) - network = {'encoder': vak.nets.ConvEncoder(input_shape)} - model = vak.models.ParametricUMAPModel(network=network) - - # now test that attributes are what we expect - assert isinstance(model, vak.models.ParametricUMAPModel) - for attr in ('network', 'loss', 'optimizer', 'metrics'): - assert hasattr(model, attr) - model_attr = getattr(model, attr) - definition_attr = getattr(definition, attr) - if inspect.isclass(definition_attr): - assert isinstance(model_attr, definition_attr) - elif isinstance(definition_attr, dict): - assert isinstance(model_attr, dict) - for definition_key, definition_val in definition_attr.items(): - assert definition_key in model_attr - model_val = model_attr[definition_key] - if inspect.isclass(definition_val): - assert isinstance(model_val, definition_val) - else: - assert callable(definition_val) - assert model_val is definition_val - else: - # must be a function - assert callable(model_attr) - assert model_attr is definition_attr - @pytest.mark.xfail - @pytest.mark.parametrize( - 'input_shape, definition', - [ - ((1, 128, 128), ConvEncoderUMAPDefinition), - ] - ) - def test_from_config( - self, - input_shape, + # network is the one thing that has required args + # and we also need to use its config from the toml file + cfg = vak.config.Config.from_toml_path(train_toml_path) + model_config = cfg.train.model.asdict() + network = { + 'encoder': 
definition.network['encoder']( + input_shape=train_dataset.shape, + **model_config['network']['encoder'] + ) + } + model_factory = vak.models.factory.ModelFactory( definition, - specific_config_toml_path, - monkeypatch, - ): - definition = vak.models.definition.validate(definition) - model_name = definition.__name__.replace('Definition', '') - toml_path = specific_config_toml_path('train', model_name, audio_format='cbin', annot_format='notmat') - cfg = vak.config.Config.from_toml_path(toml_path) - - monkeypatch.setattr( - vak.models.ParametricUMAPModel, 'definition', definition, raising=False + vak.models.ParametricUMAPModel, ) - - config = cfg.train.model.asdict() - config["network"].update( - encoder=dict(input_shape=input_shape) + model = model_factory.from_instances(network=network) + model.to(device) + eval_toml_path = specific_config_toml_path('eval', model_name, audio_format='cbin', annot_format='notmat') + eval_cfg = vak.config.Config.from_toml_path(eval_toml_path) + checkpoint_path = eval_cfg.eval.checkpoint_path + + # ---- actually test method + sd_before = copy.deepcopy(model.state_dict()) + sd_before = { + k: v.to(device) for k, v in sd_before.items() + } + ckpt = torch.load(checkpoint_path) + sd_to_be_loaded = ckpt['state_dict'] + sd_to_be_loaded = { + k: v.to(device) for k, v in sd_to_be_loaded.items() + } + + model.load_state_dict_from_path(checkpoint_path) + + assert not all([ + torch.all(torch.eq(val, before_val)) + for val, before_val in zip(model.state_dict().values(), sd_before.values())] + ) + assert all([ + torch.all(torch.eq(val, before_val)) + for val, before_val in zip(model.state_dict().values(), sd_to_be_loaded.values())] ) - - model = vak.models.ParametricUMAPModel.from_config(config=config) - assert isinstance(model, vak.models.ParametricUMAPModel) - - if 'network' in config: - if inspect.isclass(definition.network): - for network_kwarg, network_kwargval in config['network'].items(): - assert hasattr(model.network, network_kwarg) - 
assert getattr(model.network, network_kwarg) == network_kwargval - elif isinstance(definition.network, dict): - for net_name, net_kwargs in config['network'].items(): - for network_kwarg, network_kwargval in net_kwargs.items(): - assert hasattr(model.network[net_name], network_kwarg) - assert getattr(model.network[net_name], network_kwarg) == network_kwargval - - if 'loss' in config: - for loss_kwarg, loss_kwargval in config['loss'].items(): - assert hasattr(model.loss, loss_kwarg) - assert getattr(model.loss, loss_kwarg) == loss_kwargval - - if 'optimizer' in config: - for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items(): - assert optimizer_kwarg in model.optimizer.param_groups[0] - assert model.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval - - if 'metrics' in config: - for metric_name, metric_kwargs in config['metrics'].items(): - assert metric_name in model.metrics - for metric_kwarg, metric_kwargval in metric_kwargs.items(): - assert hasattr(model.metrics[metric_name], metric_kwarg) - assert getattr(model.metrics[metric_name], metric_kwarg) == metric_kwargval diff --git a/tests/test_models/test_registry.py b/tests/test_models/test_registry.py index af8f8cfc6..b2d597cc2 100644 --- a/tests/test_models/test_registry.py +++ b/tests/test_models/test_registry.py @@ -39,17 +39,18 @@ def test_register_model(family, definition): """Test that :func:`vak.models.registry.register_model` adds a model to the registry""" # to set up, we repeat what :func:`vak.models.decorator.model` does - attributes = dict(family.__dict__) - attributes.update({"definition": definition}) - subclass_name = definition.__name__ - subclass = type(subclass_name, (family,), attributes) - subclass.__module__ = definition.__module__ + model_name = definition.__name__ + model_factory = vak.models.factory.ModelFactory( + definition, + family + ) + model_factory.__name__ = model_name - assert subclass_name not in vak.models.registry.MODEL_REGISTRY - 
vak.models.registry.register_model(subclass) - assert subclass_name in vak.models.registry.MODEL_REGISTRY - assert vak.models.registry.MODEL_REGISTRY[subclass_name] == subclass - del vak.models.registry.MODEL_REGISTRY[subclass_name] # so this test doesn't fail for the second case + assert model_name not in vak.models.registry.MODEL_REGISTRY + vak.models.registry.register_model(model_factory) + assert model_name in vak.models.registry.MODEL_REGISTRY + assert vak.models.registry.MODEL_REGISTRY[model_name] == model_factory + del vak.models.registry.MODEL_REGISTRY[model_name] # so this test doesn't fail for the second case def test_register_model_raises_family(): @@ -57,14 +58,13 @@ def test_register_model_raises_family(): raises an error if parent class is not in model_family_classes""" # to set up, we repeat what :func:`vak.models.decorator.model` does, # but notice that we use an unregistered model family - attributes = dict(UnregisteredMockModelFamily.__dict__) - attributes.update({"definition": MockModel}) - subclass_name = MockModel.__name__ - subclass = type(subclass_name, (UnregisteredMockModelFamily,), attributes) - subclass.__module__ = MockModel.__module__ + model_factory = vak.models.ModelFactory( + MockModel, + UnregisteredMockModelFamily, + ) with pytest.raises(TypeError): - vak.models.registry.register_model(subclass) + vak.models.registry.register_model(model_factory) @pytest.mark.parametrize( @@ -77,31 +77,34 @@ def test_register_model_raises_registered(family, definition): """Test that :func:`vak.models.registry.register_model` raises an error if a class is already registered""" # to set up, we repeat what :func:`vak.models.decorator.model` does - attributes = dict(family.__dict__) - attributes.update({"definition": definition}) + model_factory = vak.models.ModelFactory( + definition, + family + ) + # NOTE we replace 'Definition' with an empty string # so that the name clashes with an existing model name - subclass_name = 
definition.__name__.replace('Definition', '') - subclass = type(subclass_name, (family,), attributes) - subclass.__module__ = definition.__module__ + model_factory.__name__ = definition.__name__.replace('Definition', '') with pytest.raises(ValueError): - vak.models.registry.register_model(subclass) + vak.models.registry.register_model(model_factory) def test___get_attr__MODEL_FAMILY_FROM_NAME(): assert hasattr(vak.models.registry, 'MODEL_FAMILY_FROM_NAME') - attr = getattr(vak.models.registry, 'MODEL_FAMILY_FROM_NAME') - assert isinstance(attr, dict) - for model_name, model_class in vak.models.registry.MODEL_REGISTRY.items(): - model_parent_class = inspect.getmro(model_class)[1] - family_name = model_parent_class.__name__ - assert attr[model_name] == family_name + + model_family_from_name_dict = getattr(vak.models.registry, 'MODEL_FAMILY_FROM_NAME') + assert isinstance(model_family_from_name_dict, dict) + + for model_name, model_factory in vak.models.registry.MODEL_REGISTRY.items(): + model_family = model_factory.family + family_name = model_family.__name__ + assert model_family_from_name_dict[model_name] == family_name def test___get_attr__MODEL_NAMES(): assert hasattr(vak.models.registry, 'MODEL_NAMES') - attr = getattr(vak.models.registry, 'MODEL_NAMES') - assert isinstance(attr, list) + model_names_list = getattr(vak.models.registry, 'MODEL_NAMES') + assert isinstance(model_names_list, list) for model_name in vak.models.registry.MODEL_REGISTRY.keys(): - assert model_name in attr + assert model_name in model_names_list diff --git a/tests/test_models/test_tweetynet.py b/tests/test_models/test_tweetynet.py index ce219ea9a..c0e8f8378 100644 --- a/tests/test_models/test_tweetynet.py +++ b/tests/test_models/test_tweetynet.py @@ -1,7 +1,6 @@ import itertools import pytest -import lightning import vak @@ -31,13 +30,6 @@ class TestTweetyNet: - def test_model_is_decorated(self): - assert issubclass(vak.models.TweetyNet, - vak.models.FrameClassificationModel) - assert 
issubclass(vak.models.TweetyNet, - vak.models.base.Model) - assert issubclass(vak.models.TweetyNet, - lightning.pytorch.LightningModule) @pytest.mark.parametrize( 'labelmap, input_shape', @@ -48,8 +40,8 @@ def test_init(self, labelmap, input_shape): num_input_channels = input_shape[-3] num_freqbins = input_shape[-2] network = vak.models.TweetyNet.definition.network(len(labelmap), num_input_channels, num_freqbins) - model = vak.models.TweetyNet(labelmap=labelmap, network=network) - assert isinstance(model, vak.models.TweetyNet) + model = vak.models.TweetyNet.from_instances(network=network, labelmap=labelmap) + assert isinstance(model, vak.models.FrameClassificationModel) for attr in ('network', 'loss', 'optimizer'): assert hasattr(model, attr) assert isinstance(getattr(model, attr),