From a3934c27c1f576734b74e25614054ee38af4cce1 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 1 Sep 2023 12:27:43 -0400 Subject: [PATCH 01/27] Initial start implementation --- src/sparseml/core/__init__.py | 24 +++ src/sparseml/core/data/__init__.py | 15 ++ src/sparseml/core/data/base.py | 36 ++++ src/sparseml/core/data/pytorch.py | 137 ++++++++++++++ src/sparseml/core/event.py | 59 ++++++ src/sparseml/core/framework.py | 113 +++++++++++ src/sparseml/core/model/__init__.py | 15 ++ src/sparseml/core/model/base.py | 47 +++++ src/sparseml/core/model/pytorch.py | 50 +++++ src/sparseml/core/modifier/__init__.py | 19 ++ src/sparseml/core/modifier/base.py | 43 +++++ src/sparseml/core/modifier/factory.py | 29 +++ src/sparseml/core/modifier/modifier.py | 137 ++++++++++++++ src/sparseml/core/modifier/stage.py | 36 ++++ src/sparseml/core/optimizer/__init__.py | 15 ++ src/sparseml/core/optimizer/base.py | 52 ++++++ src/sparseml/core/optimizer/pytorch.py | 59 ++++++ src/sparseml/core/recipe/__init__.py | 19 ++ src/sparseml/core/recipe/args.py | 97 ++++++++++ src/sparseml/core/recipe/metadata.py | 80 ++++++++ src/sparseml/core/recipe/modifier.py | 49 +++++ src/sparseml/core/recipe/recipe.py | 91 +++++++++ src/sparseml/core/recipe/stage.py | 99 ++++++++++ src/sparseml/core/session.py | 233 +++++++++++++++++++++++ src/sparseml/core/state.py | 60 ++++++ src/sparseml/integrations/__init__.py | 13 ++ src/sparseml/modifiers/__init__.py | 13 ++ src/sparseml/tools/__init__.py | 13 ++ src/sparseml/utils/pytorch/__init__.py | 15 ++ src/sparseml/utils/pytorch/module.py | 238 ++++++++++++++++++++++++ 30 files changed, 1906 insertions(+) create mode 100644 src/sparseml/core/__init__.py create mode 100644 src/sparseml/core/data/__init__.py create mode 100644 src/sparseml/core/data/base.py create mode 100644 src/sparseml/core/data/pytorch.py create mode 100644 src/sparseml/core/event.py create mode 100644 src/sparseml/core/framework.py create mode 100644 src/sparseml/core/model/__init__.py create mode 100644 src/sparseml/core/model/base.py create mode 100644 src/sparseml/core/model/pytorch.py create mode 100644 src/sparseml/core/modifier/__init__.py create mode 100644 src/sparseml/core/modifier/base.py create mode 100644 src/sparseml/core/modifier/factory.py create mode 100644 src/sparseml/core/modifier/modifier.py create mode 100644 src/sparseml/core/modifier/stage.py create mode 100644 src/sparseml/core/optimizer/__init__.py create mode 100644 src/sparseml/core/optimizer/base.py create mode 100644 src/sparseml/core/optimizer/pytorch.py create mode 100644 src/sparseml/core/recipe/__init__.py create mode 100644 src/sparseml/core/recipe/args.py create mode 100644 src/sparseml/core/recipe/metadata.py create mode 100644 src/sparseml/core/recipe/modifier.py create mode 100644 src/sparseml/core/recipe/recipe.py create mode 100644 src/sparseml/core/recipe/stage.py create mode 100644 src/sparseml/core/session.py create mode 100644 src/sparseml/core/state.py create mode 100644 src/sparseml/integrations/__init__.py create mode 100644 src/sparseml/modifiers/__init__.py create mode 100644 src/sparseml/tools/__init__.py create mode 100644 src/sparseml/utils/pytorch/__init__.py create mode 100644 src/sparseml/utils/pytorch/module.py diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py new file mode 100644 index 00000000000..9f98c81ee1f --- /dev/null +++ b/src/sparseml/core/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .data import * +from .model import * +from .modifier import * +from .optimizer import * +from .recipe import * + +from .event import * +from .framework import * +from .session import * +from .state import * diff --git a/src/sparseml/core/data/__init__.py b/src/sparseml/core/data/__init__.py new file mode 100644 index 00000000000..87930811c41 --- /dev/null +++ b/src/sparseml/core/data/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import * diff --git a/src/sparseml/core/data/base.py b/src/sparseml/core/data/base.py new file mode 100644 index 00000000000..ebbe38a93ed --- /dev/null +++ b/src/sparseml/core/data/base.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel +from typing import TypeVar, Generic, Union, Any, List + +from sparseml.core.framework import MultiFrameworkObject + +__all__ = ["ModifiableData"] + +DT = TypeVar("DT") # Dataset Type + + +class ModifiableData(Generic[DT], MultiFrameworkObject, BaseModel): + data: DT = None + num_samples: int = None + + def get_num_batches(self) -> int: + raise NotImplementedError() + + def set_batch_size(self, batch_size: int): + raise NotImplementedError() + + def get_batch_size(self) -> int: + raise NotImplementedError() diff --git a/src/sparseml/core/data/pytorch.py b/src/sparseml/core/data/pytorch.py new file mode 100644 index 00000000000..afbf9893a0a --- /dev/null +++ b/src/sparseml/core/data/pytorch.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from sparseml.core.data.base import ModifiableData +from torch.utils.data import DataLoader +from typing import Mapping, Sequence + +__all__ = ["ModifiableDataPyTorch", "DynamicBatchSizeDataLoader"] + + +class DynamicBatchSizeDataLoader: + def __init__(self, data_loader: DataLoader): + self.data_loader = data_loader + self.current_batch_size = data_loader.batch_size + + def __iter__(self): + if self.current_batch_size == self.data_loader.batch_size: + yield from self.data_loader + elif self.current_batch_size < self.data_loader.batch_size: + yield from self._data_split_iter() + else: + yield from self._data_merge_iter() + + def set_batch_size(self, batch_size: int): + self.current_batch_size = batch_size + + def get_batch_size(self) -> int: + return self.current_batch_size + + def _data_split_iter(self): + if self.current_batch_size >= self.data_loader.batch_size: + raise ValueError( + "Current batch size must be less than the original batch size" + ) + + for batch in self.data_loader: + num_splits = self.data_loader.batch_size // self.current_batch_size + for i in range(num_splits): + start_idx = i * self.current_batch_size + end_idx = (i + 1) * self.current_batch_size + yield DynamicBatchSizeDataLoader.split_batch(batch, start_idx, end_idx) + + def _data_merge_iter(self): + if self.current_batch_size <= self.data_loader.batch_size: + raise ValueError( + "Current batch size must be greater than the original batch size" + ) + + buffer = [] + buffer_size = 0 + for batch in self.data_loader: + buffer.append(batch) + buffer_size += len(batch) + while buffer_size >= self.current_batch_size: + merged = DynamicBatchSizeDataLoader.merge_batches(buffer) + yield DynamicBatchSizeDataLoader.split_batch( + merged, 0, self.current_batch_size + ) + buffer = [ + DynamicBatchSizeDataLoader.split_batch( + merged, self.current_batch_size, buffer_size + ) + ] + buffer_size -= self.current_batch_size + + @staticmethod + def split_batch(batch, start_idx, end_idx): + """ + Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided indices. + """ + if isinstance(batch, torch.Tensor): + return batch[start_idx:end_idx] + elif isinstance(batch, Mapping): + return { + key: DynamicBatchSizeDataLoader.split_batch(value, start_idx, end_idx) + for key, value in batch.items() + } + elif isinstance(batch, Sequence): + return [ + DynamicBatchSizeDataLoader.split_batch(item, start_idx, end_idx) + for item in batch + ] + else: + raise TypeError(f"Unsupported batch type: {type(batch)}") + + @staticmethod + def merge_batches(batches): + """ + Merges a sequence of batches into a single batch. 
+ """ + sample_batch = batches[0] + if isinstance(sample_batch, torch.Tensor): + return torch.cat(batches, dim=0) + elif isinstance(sample_batch, Mapping): + return { + key: DynamicBatchSizeDataLoader.merge_batches( + [batch[key] for batch in batches] + ) + for key in sample_batch.keys() + } + elif isinstance(sample_batch, Sequence): + return [ + DynamicBatchSizeDataLoader.merge_batches( + [batch[i] for batch in batches] + ) + for i in range(len(sample_batch)) + ] + else: + raise TypeError(f"Unsupported batch type: {type(sample_batch)}") + + +class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]): + def __init__(self, data_loader: DataLoader): + super().__init__() + self.data = DynamicBatchSizeDataLoader(data_loader) + + def get_num_batches(self) -> int: + return self.num_samples // self.data.get_batch_size() + + def set_batch_size(self, batch_size: int): + self.data.set_batch_size(batch_size) + + def get_batch_size(self) -> int: + return self.data.get_batch_size() diff --git a/src/sparseml/core/event.py b/src/sparseml/core/event.py new file mode 100644 index 00000000000..886b13d336e --- /dev/null +++ b/src/sparseml/core/event.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from dataclasses import dataclass + + +__all__ = ["EventType", "Event"] + + +class EventType(Enum): + # training lifecycle + PRE_INIT = "pre_init" + INITIALIZE = "initialize" + FINALIZE = "finalize" + + # step lifecycle + BATCH_START = "batch_start" + LOSS_CALCULATED = "loss_calculated" + OPTIM_PRE_STEP = "optim_pre_step" + OPTIM_POST_STEP = "optim_post_step" + BATCH_END = "batch_end" + + +@dataclass +class Event: + type_: EventType = EventType.PRE_INIT + + epoch_based: bool = None + steps_per_epoch: int = None + batches_per_step: int = None + + global_step: int = None + global_batch: int = None + epoch: int = None + epoch_step: int = None + epoch_batch: int = None + + def current_index(self) -> float: + if not self.epoch_based: + return self.global_step + + epoch = self.epoch + (self.epoch_step / self.steps_per_epoch) + + if epoch - self.epoch > 1.0: + raise ValueError("too many steps per epoch for epoch based event") + + return epoch diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py new file mode 100644 index 00000000000..bb1776bfdbe --- /dev/null +++ b/src/sparseml/core/framework.py @@ -0,0 +1,113 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from enum import Enum +import importlib + +__all__ = ["Framework", "MultiFrameworkObject"] + + +class Framework(Enum): + general = "general" + pytorch = "pytorch" + tensorflow = "tensorflow" + onnx = "onnx" + keras = "keras" + jax = "jax" + + @classmethod + def from_str(cls, framework: str) -> "Framework": + framework = framework.lower().strip() + if framework == "general": + return cls.general + if framework == "pytorch": + return cls.pytorch + if framework == "tensorflow": + return cls.tensorflow + if framework == "onnx": + return cls.onnx + if framework == "keras": + return cls.keras + if framework == "jax": + return cls.jax + raise ValueError(f"Unknown framework: {framework}") + + def __str__(self): + return self.value + + def formatted(self) -> str: + if self == self.general: + return "General" + if self == self.pytorch: + return "PyTorch" + if self == self.tensorflow: + return "TensorFlow" + if self == self.onnx: + return "ONNX" + if self == self.keras: + return "Keras" + if self == self.jax: + return "JAX" + raise ValueError(f"Unknown framework: {self}") + + def class_name(self) -> str: + return self.formatted() if self != self.general else "" + + +class MultiFrameworkObject: + def __new__( + cls, + framework: Framework = None, + enable_experimental: bool = False, + **kwargs, + ): + if cls is MultiFrameworkObject: + raise TypeError("MultiFrameworkObject cannot be instantiated directly") + + instance = super(MultiFrameworkObject, cls).__new__(cls, **kwargs) + + package = instance.__class__.__module__.rsplit(".", 1)[0] + class_name = instance.__class__.__name__ + + if framework is None or framework == Framework.general: + return instance + + if enable_experimental: + # check under the experimental package first + try: + return MultiFrameworkObject.load_framework_class( + f"{package}.experimental.{str(framework)}", + f"{class_name}{framework.class_name()}", + )(**kwargs) + except ImportError: + pass + + # next check under the main package for the framework version + try: + return MultiFrameworkObject.load_framework_class( + f"{package}.{str(framework)}", f"{class_name}{framework.class_name()}" + )(**kwargs) + except ImportError: + pass + + # fall back on the class that was requested and + # fail later if it doesn't support that framework + return instance + + @staticmethod + def load_framework_class(package: str, class_name: str): + module = importlib.import_module(package) + + return getattr(module, class_name) diff --git a/src/sparseml/core/model/__init__.py b/src/sparseml/core/model/__init__.py new file mode 100644 index 00000000000..7df43946035 --- /dev/null +++ b/src/sparseml/core/model/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .base import ModifiableModel diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py new file mode 100644 index 00000000000..3131aa529a9 --- /dev/null +++ b/src/sparseml/core/model/base.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic import BaseModel +from typing import TypeVar, Generic, Union, List, Dict + +from sparseml.core.framework import MultiFrameworkObject + +__all__ = ["ModifiableModel"] + + +MT = TypeVar("MT") +LT = TypeVar("LT") +PT = TypeVar("PT") + + +class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject, BaseModel): + model: MT = None + + def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, LT]: + raise NotImplementedError() + + def get_layer(self, target: str) -> LT: + raise NotImplementedError() + + def set_layer(self, target: str, layer: LT): + raise NotImplementedError() + + def get_params(self, targets: Union[str, List[str]]) -> Dict[str, PT]: + raise NotImplementedError() + + def get_param(self, target: str) -> PT: + raise NotImplementedError() + + def set_param(self, target: str, param: PT): + raise NotImplementedError() diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py new file mode 100644 index 00000000000..225b30fa711 --- /dev/null +++ b/src/sparseml/core/model/pytorch.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union, List, Dict, Tuple +from torch.nn import Module, Parameter + + +from sparseml.core.model.base import ModifiableModel +from sparseml.utils.pytorch import ( + get_layers, + get_layer, + set_layer, + get_params, + get_param, + set_param, +) + + +__all__ = ["ModifiableModelPyTorch"] + + +class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): + def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, Module]: + return get_layers(targets, self.model) + + def get_layer(self, target: str) -> Tuple[str, Module]: + return get_layer(target, self.model) + + def set_layer(self, target: str, layer: Module) -> Module: + return set_layer(target, layer, self.model) + + def get_params(self, targets: Union[str, List[str]]) -> Dict[str, Parameter]: + return get_params(targets, self.model) + + def get_param(self, target: str) -> Tuple[str, Parameter]: + return get_param(target, self.model) + + def set_param(self, target: str, param: Parameter): + return set_param(target, param, self.model) diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py new file mode 100644 index 00000000000..6405fb2b97d --- /dev/null +++ b/src/sparseml/core/modifier/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import * +from .factory import * +from .modifier import * +from .recipe import * +from .stage import * diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py new file mode 100644 index 00000000000..d9341dbc3ed --- /dev/null +++ b/src/sparseml/core/modifier/base.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from abc import ABC, abstractmethod + +from sparseml.core.event import Event +from sparseml.core.state import State + + +__all__ = ["ModifierInterface"] + + +class ModifierInterface(ABC): + def __init__(self, **kwargs): + pass + + @abstractmethod + def pre_initialize_structure(self, state: State, **kwargs): + pass + + @abstractmethod + def initialize(self, state: State, **kwargs): + pass + + @abstractmethod + def finalize(self, state: State, **kwargs): + pass + + @abstractmethod + def update_event(self, state: State, event: Event, **kwargs): + pass diff --git a/src/sparseml/core/modifier/factory.py b/src/sparseml/core/modifier/factory.py new file mode 100644 index 00000000000..018e2089f58 --- /dev/null +++ b/src/sparseml/core/modifier/factory.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparseml.core.framework import Framework +from sparseml.core.modifier.modifier import Modifier + + +__all__ = ["ModifierFactory"] + + +class ModifierFactory: + @staticmethod + def refresh(): + raise NotImplementedError() + + @staticmethod + def create(type_: str, framework: Framework, **kwargs) -> Modifier: + raise NotImplementedError() diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py new file mode 100644 index 00000000000..0f635bfa0de --- /dev/null +++ b/src/sparseml/core/modifier/modifier.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from pydantic import BaseModel +from typing import Optional + +from abc import abstractmethod +from sparseml.core.framework import MultiFrameworkObject +from sparseml.core.modifier.base import ModifierInterface +from sparseml.core.event import Event, EventType +from sparseml.core.state import State + + +__all__ = ["Modifier"] + + +class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): + index: int = None + group: str = None + start: float + end: Optional[float] = None + update: Optional[float] = None + + _initialized: bool = False + _finalized: bool = False + _started: bool = False + _ended: bool = False + + def initialize(self, state: State, **kwargs): + if self._initialized: + return + + if self._finalized: + raise RuntimeError("cannot initialize a finalized modifier") + + initialized = self.on_initialize(**kwargs) + + if not isinstance(initialized, bool): + raise ValueError( + "on_initialize must return a boolean value; " + "True for success, False for not initialized" + ) + + self._initialized = initialized + + def finalize(self, state: State, **kwargs): + if self._finalized: + return + + if not self._initialized: + raise RuntimeError("cannot finalize an uninitialized modifier") + + finalized = self.on_finalize(**kwargs) + + if not isinstance(finalized, bool): + raise ValueError( + "on_finalize must return a boolean value; " + "True for success, False for not finalized" + ) + + self._finalized = finalized + + def update_event(self, state: State, event: Event, **kwargs): + if not self._initialized: + raise RuntimeError("cannot update an uninitialized modifier") + + if self._finalized: + raise RuntimeError("cannot update a finalized modifier") + + # handle starting the modifier if needed + if ( + event.type_ == EventType.BATCH_START + and not self._started + and self.should_start(event) + ): + self.on_start(state, event, **kwargs) + self._started = True + self.on_update(state, event, **kwargs) + + return + + # handle ending the modifier if needed + if ( + event.type_ == EventType.BATCH_END + and not self._ended + and self.should_end(event) + ): + self.on_end(state, event, **kwargs) + self._ended = True + self.on_update(state, event, **kwargs) + + return + + if self._started and not self._ended: + self.on_update(state, event, **kwargs) + + def should_start(self, event: Event): + current = event.current_index() + + return self.start <= current and (self.end is None or current < self.end) + + def should_end(self, event: Event): + current = event.current_index() + + return self.end is not None and current >= self.end + + @abstractmethod + def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + raise NotImplementedError() + + @abstractmethod + def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + raise NotImplementedError() + + @abstractmethod + def on_start(self, state: State, event: Event, **kwargs): + raise NotImplementedError() + + @abstractmethod + def on_update(self, state: State, event: Event, **kwargs): + raise NotImplementedError() + + @abstractmethod + def on_end(self, state: State, event: Event, **kwargs): + raise NotImplementedError() diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py new file mode 100644 index 00000000000..bce7ac3b152 --- /dev/null +++ b/src/sparseml/core/modifier/stage.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pydantic import BaseModel, Field
+from typing import List
+
+
+from sparseml.core.modifier.base import ModifierInterface
+from sparseml.core.modifier.modifier import Modifier
+
+
+class StageModifiers(ModifierInterface, BaseModel):
+    modifiers: List[Modifier] = Field(default_factory=list)
+    index: int = None
+    group: str = None
+
+    def pre_initialize_structure(self, **kwargs):
+        raise NotImplementedError()
+
+    def initialize(self, **kwargs):
+        raise NotImplementedError()
+
+    def finalize(self, **kwargs):
+        raise NotImplementedError()
+
+    def update_event(self, **kwargs):
+        raise NotImplementedError()
diff --git a/src/sparseml/core/optimizer/__init__.py b/src/sparseml/core/optimizer/__init__.py
new file mode 100644
index 00000000000..6ded41b5440
--- /dev/null
+++ b/src/sparseml/core/optimizer/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import ModifiableOptimizer
diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py
new file mode 100644
index 00000000000..1c26096e220
--- /dev/null
+++ b/src/sparseml/core/optimizer/base.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from pydantic import BaseModel +from typing import TypeVar, Generic, Union, List, Any + +from sparseml.core.framework import MultiFrameworkObject + +__all__ = ["ModifiableOptimizer"] + + +OT = TypeVar("OT") +PGT = TypeVar("PGT") + + +class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject, BaseModel): + optimizer: OT = None + + def get_param_groups(self) -> List[PGT]: + raise NotImplementedError() + + def set_param_groups(self, param_groups: List[PGT]): + raise NotImplementedError() + + def get_learning_rate( + self, group_index: Union[int, None] = None + ) -> Union[float, List[float]]: + raise NotImplementedError() + + def set_learning_rate(self, lr: float, group_index: Union[int, None] = None): + raise NotImplementedError() + + def get_attribute( + self, name: str, group_index: Union[int, None] = None + ) -> Union[Any, List[Any]]: + raise NotImplementedError() + + def set_attribute( + self, name: str, value: Any, group_index: Union[int, None] = None + ): + raise NotImplementedError() diff --git a/src/sparseml/core/optimizer/pytorch.py b/src/sparseml/core/optimizer/pytorch.py new file mode 100644 index 00000000000..8498f37c730 --- /dev/null +++ b/src/sparseml/core/optimizer/pytorch.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union, List, Any, Dict + +from sparseml.core.optimizer.base import ModifiableOptimizer + +from torch.optim import Optimizer + +__all__ = ["ModifiableOptimizerPyTorch"] + + +class ModifiableOptimizerPyTorch(ModifiableOptimizer[Optimizer, Dict[str, Any]]): + def get_param_groups(self) -> List[Dict[str, Any]]: + return self.optimizer.param_groups + + def set_param_groups(self, param_groups: List[Dict[str, Any]]): + self.optimizer.param_groups = param_groups + + def get_learning_rate( + self, group_idx: Union[int, None] = None + ) -> Union[float, List[float]]: + if group_idx is not None: + return self.optimizer.param_groups[group_idx]["lr"] + return [group["lr"] for group in self.optimizer.param_groups] + + def set_learning_rate(self, lr: float, group_idx: Union[int, None] = None): + if group_idx is not None: + self.optimizer.param_groups[group_idx]["lr"] = lr + else: + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + + def get_attribute( + self, attr_name: str, group_idx: Union[int, None] = None + ) -> Union[Any, List[Any]]: + if group_idx is not None: + return self.optimizer.param_groups[group_idx].get(attr_name, None) + return [group.get(attr_name, None) for group in self.optimizer.param_groups] + + def set_attribute( + self, attr_name: str, value: Any, group_idx: Union[int, None] = None + ): + if group_idx is not None: + self.optimizer.param_groups[group_idx][attr_name] = value + else: + for param_group in self.optimizer.param_groups: + param_group[attr_name] = value diff --git a/src/sparseml/core/recipe/__init__.py b/src/sparseml/core/recipe/__init__.py new file mode 100644 index 00000000000..9bf403c2829 --- /dev/null +++ b/src/sparseml/core/recipe/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .args import * +from .metadata import * +from .modifier import * +from .recipe import * +from .stage import * diff --git a/src/sparseml/core/recipe/args.py b/src/sparseml/core/recipe/args.py new file mode 100644 index 00000000000..a6ab6f128ad --- /dev/null +++ b/src/sparseml/core/recipe/args.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from typing import Dict, Any, Optional +import math + + +__all__ = ["RecipeArgs"] + + +class RecipeArgs(Dict[str, Any]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._evaluated: "Optional[RecipeArgs]" = None + + def combine(self, other: "RecipeArgs") -> "RecipeArgs": + combined = RecipeArgs() + combined.update(self) + + if other: + combined.update(other) + + return combined + + def evaluate(self, parent: "RecipeArgs" = None) -> "RecipeArgs": + self._evaluated = RecipeArgs.eval_args(self.combine(parent)) + + return self._evaluated + + def evaluate_ext(self, target: Dict[str, Any]) -> Dict[str, Any]: + args = RecipeArgs.eval_args(self) + resolved = {} + + for key, value in target.items(): + resolved[key] = RecipeArgs.eval_obj(value, args) + + return resolved + + @staticmethod + def eval_str(target: str, args: Dict[str, Any] = None) -> str: + if "eval(" not in target: + return target + + pattern = re.compile(r"eval\(([^()]*)\)") + match = pattern.search(target) + + if not match: + raise ValueError(f"invalid eval string {target}") + + inner_expr = match.group(1) + result = eval(inner_expr, {"math": math}, args if args else {}) + new_target = target.replace(match.group(0), str(result)) + + return RecipeArgs.eval_str(new_target, args) + + @staticmethod + def eval_args(args: Dict[str, Any]) -> "RecipeArgs": + resolved = args.copy() + + while True: + for key, value in resolved.items(): + if isinstance(value, str): + resolved[key] = RecipeArgs.eval_str(value, resolved) + else: + resolved[key] = value + + if args == resolved: + break + else: + args = resolved.copy() + + return RecipeArgs(resolved) + + @staticmethod + def eval_obj(target: Any, args: Dict[str, Any] = None) -> Any: + if isinstance(target, str): + return RecipeArgs.eval_str(target, args) + elif isinstance(target, dict): + return { + key: RecipeArgs.eval_obj(value, args) for key, value in target.items() + } + elif isinstance(target, list): + return [RecipeArgs.eval_obj(item, args) for item in target] + + return target diff --git a/src/sparseml/core/recipe/metadata.py b/src/sparseml/core/recipe/metadata.py new file mode 100644 index 00000000000..73cbce3c547 --- /dev/null +++ b/src/sparseml/core/recipe/metadata.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Any, List +from pydantic import BaseModel, Field + + +__all__ = [ + "NMVersions", + "DatasetMetaData", + "ParamMetaData", + "LayerMetaData", + "ModelMetaData", + "RecipeMetaData", +] + + +class NMVersions(BaseModel): + sparsezoo_version: str = None + sparsezoo_hash: str = None + sparseml_version: str = None + sparseml_hash: str = None + sparsify_version: str = None + sparsify_hash: str = None + + +class DatasetMetaData(BaseModel): + name: str = None + version: str = None + hash: str = None + shape: List[int] = Field(default_factory=list) + num_classes: int = None + num_train_samples: int = None + num_val_samples: int = None + num_test_samples: int = None + + +class ParamMetaData(BaseModel): + name: str = None + shape: List[int] = None + weight_hash: str = None + + +class LayerMetaData(BaseModel): + name: str = None + type: str = None + index: int = None + attributes: Dict[str, Any] = None + input_shapes: List[List[int]] = None + output_shapes: List[List[int]] = None + params: Dict[str, ParamMetaData] = None + + +class ModelMetaData(BaseModel): + architecture: str = None + sub_architecture: str = None + input_shapes: List[List[int]] = None + output_shapes: List[List[int]] = None + layers: List[LayerMetaData] = Field(default_factory=list) + + +class RecipeMetaData(BaseModel): + domain: str = None + task: str = None + versions: NMVersions = Field(default_factory=NMVersions) + requirements: List[str] = None + tags: List[str] = None + target_dataset: DatasetMetaData = None + target_model: ModelMetaData = None diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py new file mode 100644 index 00000000000..ec2b84ed8ff --- /dev/null +++ b/src/sparseml/core/recipe/modifier.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Any +from pydantic import BaseModel, root_validator + +from sparseml.core.recipe.args import RecipeArgs +from sparseml.core.modifier import Modifier, ModifierFactory +from sparseml.core.framework import Framework + + +__all__ = ["RecipeModifier"] + + +class RecipeModifier(BaseModel): + type: str + group: str = None + args: Dict[str, Any] = None + _args: Dict[str, Any] = None + + def evaluate(self, parent_args: RecipeArgs): + if self.args: + self._args = parent_args.evaluate_ext(self.args) + else: + self._args = dict() + + def create_modifier(self, framework: Framework) -> Modifier: + return ModifierFactory.create(self.type, framework, **self._args) + + @root_validator(pre=True) + def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + assert len(values) == 1, "multiple key pairs found for modifier" + modifier_type, args = list(values.items())[0] + + return {"type": modifier_type, "args": args} + + def dict(self, *args, **kwargs) -> Dict[str, Any]: + return {self.type: self.args} diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py new file mode 100644 index 00000000000..b4b34f57fcc --- /dev/null +++ b/src/sparseml/core/recipe/recipe.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Dict, Any, List, Tuple
+from pydantic import BaseModel, Field, root_validator
+
+from sparseml.core.recipe.args import RecipeArgs
+from sparseml.core.recipe.stage import RecipeStage
+from sparseml.core.recipe.metadata import RecipeMetaData
+from sparseml.core.framework import Framework
+from sparseml.core.modifier import StageModifiers
+
+
+__all__ = ["Recipe"]
+
+
+class Recipe(BaseModel):
+    version: str = None
+    args: RecipeArgs = None
+    stages: List[RecipeStage] = Field(default_factory=list)
+    metadata: RecipeMetaData = None
+    _args_evaluated: RecipeArgs = None
+
+    def evaluate(self):
+        self._args_evaluated = self.args.evaluate()
+        for stage in self.stages:
+            stage.evaluate(self.args)
+
+    def create_modifiers(self, framework: Framework) -> List[StageModifiers]:
+        self.evaluate()
+        modifiers = []
+
+        for index, stage in enumerate(self.stages):
+            stage_modifiers = stage.create_modifiers(framework)
+            stage_modifiers.index = index
+            stage_modifiers.group = stage.group
+            modifiers.append(stage_modifiers)
+
+        return modifiers
+
+    @root_validator(pre=True)
+    def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        modifiers = RecipeStage._combine_modifiers(values)
+        stages = [{"modifiers": modifiers, "group": "default"}] if modifiers else []
+        add_stages, remove_keys = Recipe._combine_stages(values)
+        stages.extend(add_stages)
+
+        for key in remove_keys:
+            del values[key]
+
+        values["stages"] = stages
+
+        return values
+
+    def dict(self, *args, **kwargs) -> Dict[str, Any]:
+        dict_ = super().dict(*args, **kwargs)
+
+        for stage in dict_["stages"]:
+            name = f"{stage['group']}_stage"
+            del stage["group"]
+            dict_[name] = stage["args"]
+
+        del dict_["stages"]
+
+        return dict_
+
+    @staticmethod
+    def _combine_stages(
+        values: Dict[str, Any]
+    ) -> Tuple[List[Dict[str, Any]], List[str]]:
+        stages = []
+        keys = []
+
+        for key, value in list(values.items()):
+            if key.endswith("_stage"):
+                keys.append(key)
+                value["group"] = key.rsplit("_stage", 1)[0]
+                stages.append(value)
+
+        return stages, keys
diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py
new file mode 100644
index 00000000000..5bd9f003add
--- /dev/null
+++ b/src/sparseml/core/recipe/stage.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Any, List
+from pydantic import BaseModel, Field, root_validator
+
+from sparseml.core.recipe.args import RecipeArgs
+from sparseml.core.recipe.modifier import RecipeModifier
+from sparseml.core.framework import Framework
+from sparseml.core.modifier import StageModifiers
+
+
+__all__ = ["RecipeStage"]
+
+
+class RecipeStage(BaseModel):
+    group: str = None
+    args: RecipeArgs = None
+    enabled: bool = True
+    modifiers: List[RecipeModifier] = Field(default_factory=list)
+    _args_evaluated: RecipeArgs = None
+
+    def evaluate(self, parent_args: RecipeArgs):
+        merged_args = self.args.combine(parent_args)
+        self._args_evaluated = merged_args.evaluate()
+        for modifier in self.modifiers:
+            modifier.evaluate(merged_args)
+
+    def create_modifiers(
+        self, framework: Framework, parent_args: RecipeArgs = None
+    ) -> StageModifiers:
+        if parent_args is not None:
+            self.evaluate(parent_args)
+
+        stage_modifiers = StageModifiers()
+        for index, recipe_modifier in enumerate(self.modifiers):
+            modifier = recipe_modifier.create_modifier(framework)
+            modifier.group = self.group
+            modifier.index = index
+            stage_modifiers.modifiers.append(modifier)
+
+        return stage_modifiers
+
+    @root_validator(pre=True)
+    def remap_modifiers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        remove_keys = [
+            key for key in values if key == "modifiers" or key.endswith("_modifiers")
+        ]
+        modifiers = RecipeStage._combine_modifiers(values)
+        for key in remove_keys:
+            del values[key]
+        values["modifiers"] = modifiers
+
+        return values
+
+    def dict(self, *args, **kwargs) -> Dict[str, Any]:
+        dict_ = super().dict(*args, **kwargs)
+        modifier_groups = dict()
+
+        for modifier in dict_["modifiers"]:
+            group = modifier["group"]
+            del modifier["group"]
+            if group not in modifier_groups:
+                modifier_groups[group] = []
+            modifier_groups[group].append(modifier)
+
+        for group, modifiers in modifier_groups.items():
+            name = f"{group}_modifiers" if group != "default" else "modifiers"
+            dict_[name] = modifiers
+
+        del dict_["modifiers"]
+
+        return dict_
+
+    @staticmethod
+    def _combine_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]:
+        modifiers = []
+
+        for key, value in list(values.items()):
+            if key.endswith("_modifiers") or key == "modifiers":
+                group = (
+                    key.rsplit("_modifiers", 1)[0]
+                    if key.endswith("_modifiers")
+                    else "default"
+                )
+                for modifier in value:
+                    modifier["group"] = group
+                    modifiers.append(modifier)
+
+        return modifiers
diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py
new file mode 100644
index 00000000000..840cc347a17
--- /dev/null
+++ b/src/sparseml/core/session.py
@@ -0,0 +1,233 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import threading +from contextlib import contextmanager +from typing import Callable, Any, Union, List, Dict, Tuple +from dataclasses import dataclass + +from sparseml.core.state import State +from sparseml.core.event import EventType, Event +from sparseml.core.recipe import Recipe +from sparseml.core.framework import Framework +from sparseml.core.modifier import StageModifiers + + +__all__ = [ + "SparseSession", + "create_session", + "active_session", + "apply_structure", + "init", + "finalize", + "apply", + "callbacks", +] + + +@dataclass +class _CallbackContainer: + id_: int + callback: Callable + deregister: Callable + event_type: EventType + kwargs: dict + + +class SparseSession: + def __init__(self): + self._state: State = State() + self._modifiers: List[StageModifiers] = [] + + @property + def state(self) -> State: + return self._state + + @property + def modifiers(self) -> List[StageModifiers]: + return self._modifiers + + def last_event(self) -> Event: + return self._state.last_event + + def pre_initialize_structure( + self, + model: Any, + recipe: Union[Recipe, List[Recipe]], + framework: Framework = None, + **kwargs, + ) -> Any: + self.state.update_framework(framework) + self.state.update_model(model) + self.state.update_recipe(recipe) + + self._check_compile_recipe() + + if self._modifiers: + for modifier in self._modifiers: + modifier.pre_initialize_structure(state=self.state, **kwargs) + + return self.state.model.model + + def initialize( + self, + framework: Framework = None, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe_stage: str = None, + recipe_args: Dict[str, Any] = None, + model: Any = None, + optimizer: Any = None, + attach_optim_callbacks: bool = True, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + calib_data: Any = None, + copy_data: bool = True, + start: float = None, + steps_per_epoch: int = None, + batches_per_step: int = None, + **kwargs, + ) -> Tuple[Any, Any]: + self.state.update_framework(framework) + self.state.update_recipe(recipe, recipe_stage, recipe_args) + self.state.update_model(model) + self.state.update_optimizer(optimizer, attach_optim_callbacks) + self.state.update_data(train_data, val_data, test_data, calib_data, copy_data) + self.state.update_start(start, steps_per_epoch, batches_per_step) + + self._check_compile_recipe() + + if self._modifiers: + for modifier in self._modifiers: + modifier.initialize(state=self.state, **kwargs) + + model_return = None + optim_return = None + + if model: + model_return = self.state.model.model + if optimizer: + optim_return = self.state.optimizer.optimizer + + return model_return, optim_return + + def finalize(self, **kwargs): + pass + + def apply(self, **kwargs): + self.initialize(**kwargs) + self.finalize(**kwargs) + + def apply_structure( + self, model: Any, recipe: Union[Recipe, List[Recipe]], **kwargs + ): + pass + + def event(self, event_type: EventType, **kwargs): + pass + + def reset(self): + if self._state: + del self._state + self._state = State() + + if self._recipe_modifier: + del self._recipe_modifier + self._recipe_modifier = None + + def _check_compile_recipe(self): + if not self.state.should_recompile_recipe(): + return + + # clear out the modifiers to reinitialize from newly compiled recipe + if self._modifiers: + for modifier in self._modifiers: + if modifier.initialized: + modifier.finalize(self.state) + del self._modifiers + + self.state.recompile_recipe() + self._modifiers = self.state.compiled_recipe.create_modifiers( + 
self.state.framework + ) + + +_global_session = SparseSession() +_local_storage = threading.local() +_local_storage.session = _global_session + + +@contextmanager +def create_session() -> SparseSession: + global _local_storage + orig_session = getattr(_local_storage, "session", None) + new_session = SparseSession() + _local_storage.session = new_session + try: + yield new_session + finally: + _local_storage.session = orig_session + + +def active_session() -> SparseSession: + global _local_storage + return getattr(_local_storage, "session", _global_session) + + +def apply_structure(**kwargs): + active_session().apply_structure(**kwargs) + + +def init(**kwargs): + active_session().initialize(**kwargs) + + +def finalize(**kwargs): + active_session().finalize(**kwargs) + + +def apply(**kwargs): + init(**kwargs) + finalize(**kwargs) + + +class LifecycleCallbacks: + @classmethod + def event(cls, event_type: EventType, **kwargs) -> Any: + if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + raise ValueError( + f"Cannot invoke {event_type} event. " + f"Use the corresponding method instead." + ) + + return active_session().event(event_type, **kwargs) + + @classmethod + def batch_start(cls, **kwargs) -> Any: + return cls.event(EventType.BATCH_START, **kwargs) + + @classmethod + def batch_end(cls, **kwargs) -> Any: + return cls.event(EventType.BATCH_END, **kwargs) + + @classmethod + def optim_stepped(cls, **kwargs) -> Any: + return cls.event(EventType.OPTIM_POST_STEP, **kwargs) + + @classmethod + def loss_calculated(cls, **kwargs) -> Any: + return cls.event(EventType.LOSS_CALCULATED, **kwargs) + + +callbacks = LifecycleCallbacks diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py new file mode 100644 index 00000000000..7ffca5c69f6 --- /dev/null +++ b/src/sparseml/core/state.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import List +from pydantic import Field + +from sparseml.core.event import Event +from sparseml.core.data import ModifiableData +from sparseml.core.model import ModifiableModel +from sparseml.core.optimizer import ModifiableOptimizer +from sparseml.core.recipe import Recipe +from sparseml.core.framework import Framework + + +__all__ = ["State", "Data", "Hardware"] + + +@dataclass +class Data: + train: ModifiableData = None + val: ModifiableData = None + test: ModifiableData = None + calib: ModifiableData = None + + +@dataclass +class Hardware: + device: str = None + devices: List[str] = None + rank: int = None + world_size: int = None + local_rank: int = None + local_world_size: int = None + distributed: bool = None + distributed_strategy: str = None + + +@dataclass +class State: + compiled_recipe: Recipe = None + recipes: List[Recipe] = Field(default_factory=list) + loggers = Field(default_factory=list) + framework: Framework = None + model: ModifiableModel = None + optimizer: ModifiableOptimizer = None + data = Data() + hardware = Hardware() + last_event: Event = Event() diff --git a/src/sparseml/integrations/__init__.py b/src/sparseml/integrations/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/integrations/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/tools/__init__.py b/src/sparseml/tools/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/tools/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/utils/pytorch/__init__.py b/src/sparseml/utils/pytorch/__init__.py new file mode 100644 index 00000000000..880ecd996e3 --- /dev/null +++ b/src/sparseml/utils/pytorch/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .module import * diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py new file mode 100644 index 00000000000..53f17655ecd --- /dev/null +++ b/src/sparseml/utils/pytorch/module.py @@ -0,0 +1,238 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Utility / helper functions +""" + +import re +from typing import Dict, List, Tuple, Union + +import torch +from packaging import version +from torch.nn import Linear, Module, Parameter +from torch.nn.modules.conv import _ConvNd + + +try: + quant_err = None + from torch.nn.qat import Conv2d as QATConv2d + from torch.nn.qat import Linear as QATLinear + from torch.quantization import QuantWrapper +except Exception as _err: + quant_err = _err + QuantWrapper = None + QATLinear = None + QATConv2d = None + +try: + from torch.nn.qat import Conv3d as QATConv3d +except Exception as _err: + quant_conv3d_err = _err + QATConv3d = None + + +try: + from transformers.modeling_utils import Conv1D as TransformerConv1D +except Exception as _err: + gpt_conv1d_err = _err + TransformerConv1D = None + + +__all__ = [ + "match_targets", + "get_default_params", + "match_layers_params", + "get_layers", + "get_layer", + "set_layer", + "get_params", + "get_param", + "set_param", + "get_terminal_layers", + "get_prunable_layers", + "get_quantizable_layers", +] + + +_PARSED_TORCH_VERSION = version.parse(torch.__version__) + + +ALL_TARGET = "__ALL__" +ALL_PRUNABLE_TARGET = "__ALL_PRUNABLE__" +ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__" + + +def match_targets(name: str, targets: Union[str, List[str]]) -> Tuple[bool, int]: + if isinstance(targets, str): + targets = [targets] + + for index, target in enumerate(targets): + if target[:3] == "re:": + pattern = target[3:] + if re.match(pattern, name): + return True, index + elif name == target: + return True, index + + return False, -1 + + +def get_default_params(layers: Dict[str, Module]) -> Dict[str, Parameter]: + params = {} + for name, layer in layers.items(): + for param_name, param in layer.named_parameters(): + if param_name == "weight": + params[name] = param + break + return params + + +def match_layers_params( + targets: Union[str, List[str]], module: Module, params: bool = False +) -> Dict[str, Union[Module, Parameter]]: + if targets == ALL_TARGET: + values = get_terminal_layers(module) + + return values if not params else get_default_params(values) + + if targets == ALL_PRUNABLE_TARGET: + values = get_prunable_layers(module) + + return values if not params else get_default_params(values) + + if targets == ALL_QUANTIZABLE_TARGET: + values = get_quantizable_layers(module) + + return values if not params else get_default_params(values) + + if isinstance(targets, str): + targets = [targets] + + resolved = {} + targets_found = [False for _ in range(len(targets))] + + for name, layer in module.named_modules(): + match, match_index = match_targets(name, targets) + if match and not params: + targets_found[match_index] = True + resolved[name] = layer + + for param_name, param in layer.named_parameters(): + if "." 
in param_name: # skip parameters of nested layers + continue + + param_match, param_match_index = match_targets( + f"{name}.{param_name}", targets + ) + if param_match: + targets_found[param_match_index] = True + resolved[f"{name}"] = layer if not params else param + + missed = [target for found, target in zip(targets_found, targets) if not found] + if len(missed) > 0: + raise ValueError(f"Could not find targets {missed} in module {module}") + + return resolved + + +def get_layers(targets: Union[str, List[str]], module: Module) -> Dict[str, Module]: + return match_layers_params(targets, module) + + +def get_layer(target: str, module: Module) -> Tuple[str, Module]: + layers = get_layers(target, module) + if len(layers) != 1: + raise ValueError(f"Expected 1 layer for target {target}, found {len(layers)}") + name, layer = next(iter(layers.items())) + + return name, layer + + +def set_layer(target: str, layer: Module, module: Module) -> Module: + parent_target = ".".join(target.split(".")[:-1]) + parent_layer = get_layer(parent_target, module)[1] + old_layer = getattr(parent_layer, target.split(".")[-1]) + setattr(parent_layer, target.split(".")[-1], layer) + + return old_layer + + +def get_params(targets: Union[str, List[str]], module: Module) -> Dict[str, Parameter]: + return match_layers_params(targets, module, params=True) + + +def get_param(target: str, module: Module) -> Tuple[str, Parameter]: + params = get_params(target, module) + if len(params) != 1: + raise ValueError( + f"Expected 1 parameter for target {target}, found {len(params)}" + ) + name, param = next(iter(params.items())) + + return name, param + + +def set_param(target: str, param: Parameter, module: Module) -> Parameter: + layer_name, param_name = target.rsplit(".", 1) + layer = get_layer(layer_name, module)[1] + old_param = getattr(layer, param_name) + setattr(layer, param_name, param) + + return old_param + + +def get_terminal_layers(module: Module) -> Dict[str, Module]: + terminal = {} + + for name, layer in module.named_modules(): + if len(list(layer.named_modules())) > 1: + continue + + terminal[name] = layer + + return terminal + + +def get_prunable_layers(module: Module) -> Dict[str, Module]: + prunable = {} + + for name, layer in module.named_modules(): + if ( + isinstance(layer, Linear) + or isinstance(layer, _ConvNd) + or (QATLinear and isinstance(layer, QATLinear)) + or (QATConv2d and isinstance(layer, QATConv2d)) + or (QATConv3d and isinstance(layer, QATConv3d)) + or (TransformerConv1D and isinstance(layer, TransformerConv1D)) + ): + prunable[name] = layer + + return prunable + + +def get_quantizable_layers(module: Module) -> Dict[str, Module]: + if QATLinear is None: + raise ImportError( + "PyTorch version is not setup for Quantization. 
" + "Please install a QAT compatible version of PyTorch" + ) + + quantizable = {} + + for name, layer in module.named_modules(): + if isinstance(layer, Linear) or isinstance(layer, _ConvNd): + quantizable[name] = layer + + return quantizable From d953be40bbc33a88f58da5783b1b676c1f02f472 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 5 Sep 2023 09:23:59 -0400 Subject: [PATCH 02/27] add in further completion state for session and events --- src/sparseml/core/__init__.py | 5 +- src/sparseml/core/data/base.py | 4 +- src/sparseml/core/data/pytorch.py | 6 +- src/sparseml/core/event.py | 345 ++++++++++++++++++++++++- src/sparseml/core/framework.py | 3 +- src/sparseml/core/model/base.py | 4 +- src/sparseml/core/model/pytorch.py | 10 +- src/sparseml/core/modifier/modifier.py | 16 +- src/sparseml/core/modifier/stage.py | 29 ++- src/sparseml/core/optimizer/base.py | 4 +- src/sparseml/core/optimizer/pytorch.py | 5 +- src/sparseml/core/recipe/args.py | 6 +- src/sparseml/core/recipe/metadata.py | 3 +- src/sparseml/core/recipe/modifier.py | 28 +- src/sparseml/core/recipe/recipe.py | 16 +- src/sparseml/core/recipe/stage.py | 11 +- src/sparseml/core/session.py | 304 +++++++++++++++++----- src/sparseml/core/state.py | 72 +++++- 18 files changed, 741 insertions(+), 130 deletions(-) diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py index 9f98c81ee1f..fc3f40c71ff 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -13,12 +13,11 @@ # limitations under the License. from .data import * +from .event import * +from .framework import * from .model import * from .modifier import * from .optimizer import * from .recipe import * - -from .event import * -from .framework import * from .session import * from .state import * diff --git a/src/sparseml/core/data/base.py b/src/sparseml/core/data/base.py index ebbe38a93ed..7665d85330e 100644 --- a/src/sparseml/core/data/base.py +++ b/src/sparseml/core/data/base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Generic, TypeVar + from pydantic import BaseModel -from typing import TypeVar, Generic, Union, Any, List from sparseml.core.framework import MultiFrameworkObject + __all__ = ["ModifiableData"] DT = TypeVar("DT") # Dataset Type diff --git a/src/sparseml/core/data/pytorch.py b/src/sparseml/core/data/pytorch.py index afbf9893a0a..65e28718dd3 100644 --- a/src/sparseml/core/data/pytorch.py +++ b/src/sparseml/core/data/pytorch.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Mapping, Sequence + import torch +from torch.utils.data import DataLoader from sparseml.core.data.base import ModifiableData -from torch.utils.data import DataLoader -from typing import Mapping, Sequence + __all__ = ["ModifiableDataPyTorch", "DynamicBatchSizeDataLoader"] diff --git a/src/sparseml/core/event.py b/src/sparseml/core/event.py index 886b13d336e..d2e9ce1d3ca 100644 --- a/src/sparseml/core/event.py +++ b/src/sparseml/core/event.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from enum import Enum +from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum +from typing import List -__all__ = ["EventType", "Event"] +__all__ = [ + "EventType", + "Event", + "EventLifecycle", + "WrappedOptimEventLifecycle", + "CallbacksEventLifecycle", +] class EventType(Enum): @@ -25,35 +33,344 @@ class EventType(Enum): INITIALIZE = "initialize" FINALIZE = "finalize" - # step lifecycle + # batch lifecycle BATCH_START = "batch_start" LOSS_CALCULATED = "loss_calculated" + BATCH_END = "batch_end" + + # step lifecycle OPTIM_PRE_STEP = "optim_pre_step" OPTIM_POST_STEP = "optim_post_step" - BATCH_END = "batch_end" + + def order(self) -> int: + if self == EventType.PRE_INIT: + return 0 + elif self == EventType.INITIALIZE: + return 10 + elif self == EventType.FINALIZE: + return 20 + elif self == EventType.BATCH_START: + return 100 + elif self == EventType.LOSS_CALCULATED: + return 110 + elif self == EventType.OPTIM_PRE_STEP: + return 120 + elif self == EventType.OPTIM_POST_STEP: + return 130 + elif self == EventType.BATCH_END: + return 140 + else: + raise ValueError(f"invalid event type {self}") @dataclass class Event: - type_: EventType = EventType.PRE_INIT + type_: EventType = None epoch_based: bool = None steps_per_epoch: int = None batches_per_step: int = None + invocations_per_step: int = None + + global_step: int = 0 + global_batch: int = 0 + + @property + def epoch(self) -> int: + return self.global_step // self.steps_per_epoch - global_step: int = None - global_batch: int = None - epoch: int = None - epoch_step: int = None - epoch_batch: int = None + @property + def epoch_full(self) -> float: + return self.global_step / float(self.steps_per_epoch) + + @property + def epoch_step(self) -> int: + return self.global_step % self.steps_per_epoch + + @property + def epoch_batch(self) -> int: + batches_per_epoch = ( + self.steps_per_epoch * self.batches_per_step + if self.batches_per_step + else self.steps_per_epoch + ) + + return self.global_batch % batches_per_epoch def current_index(self) -> float: if not self.epoch_based: return self.global_step - epoch = self.epoch + (self.epoch_step / self.steps_per_epoch) - - if epoch - self.epoch > 1.0: + if self.epoch_full - self.epoch > 1.0: raise ValueError("too many steps per epoch for epoch based event") - return epoch + return self.epoch_full + + def new_instance(self, **kwargs) -> "Event": + instance = Event( + type_=self.type_, + epoch_based=self.epoch_based, + steps_per_epoch=self.steps_per_epoch, + batches_per_step=self.batches_per_step, + global_step=self.global_step, + global_batch=self.global_batch, + ) + for key, value in kwargs.items(): + setattr(instance, key, value) + + return instance + + +class EventLifecycle(ABC, Event): + type_first: EventType = None + batches_step_counter: int = 0 + steps_epoch_counter: int = 0 + step_count: int = 0 + batch_count: int = 0 + + def __init__(self, type_first: EventType): + self.type_first = type_first + + def events_from_type(self, type_: EventType) -> List[Event]: + if type_ == EventType.BATCH_START: + return self.batch_start_events() + + if type_ == EventType.LOSS_CALCULATED: + return self.loss_calculated_events() + + if type_ == EventType.OPTIM_PRE_STEP: + return self.optim_pre_step_events() + + if type_ == EventType.OPTIM_POST_STEP: + return self.optim_post_step_events() + + if type_ == EventType.BATCH_END: + return self.batch_end_events() + + raise ValueError(f"invalid event type {type_}") + + @abstractmethod + def batch_start_events(self) -> List[Event]: 
+ raise NotImplementedError() + + @abstractmethod + def loss_calculated_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def optim_pre_step_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def optim_post_step_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def batch_end_events(self) -> List[Event]: + raise NotImplementedError() + + def check_step_batches_count(self, increment: bool) -> bool: + if self.batches_per_step is None or self.batches_per_step < 2: + return True + + compare_batch = self.batches_step_counter + 1 + at_step = compare_batch % self.batches_per_step == 0 + + if increment: + self.batches_step_counter = compare_batch if not at_step else 0 + + return at_step + + def check_step_invocations_count(self, increment: bool) -> bool: + if self.invocations_per_step is None or self.invocations_per_step < 2: + return True + + compare_step = self.step_count + 1 + at_step = compare_step % self.invocations_per_step == 0 + + if increment: + self.step_count = compare_step if not at_step else 0 + + return at_step + + def reset_step_count(self): + self.step_count = 0 + + +class WrappedOptimEventLifecycle(EventLifecycle): + """ + Optimizer is wrapped and no batch or optim callbacks + - batch_start: must not be invoked, auto triggered + from loss calculated if that is called, otherwise from pre_step + - loss_calculated: must be called before batch_end and optim_pre_step + - batch_end: must not be invoked, auto triggered from optim_post_step + - optim_pre_step: must be called before optim_post_step + - optim_post_step: must be called only once after optim_pre_step + """ + + def batch_start_events(self) -> List[Event]: + raise ValueError("batch start should not be invoked when only wrapped optim") + + def loss_calculated_events(self) -> List[Event]: + if self.type_first != EventType.LOSS_CALCULATED: + raise ValueError("loss calculated must be called first for wrapped optim") + + if ( + self.type_ != EventType.OPTIM_POST_STEP + and self.type_ != EventType.LOSS_CALCULATED + ): + raise ValueError( + "loss calculated must be called after batch end or optim post step" + ) + + self.type_ = EventType.LOSS_CALCULATED + self.global_batch += 1 + + if not self.check_step_batches_count(increment=True): + # step won't be called, so batch end must be called + return [ + self.new_instance(type_=EventType.BATCH_START), + self.new_instance(type_=EventType.LOSS_CALCULATED), + self.new_instance(type_=EventType.BATCH_END), + ] + else: + # batch end handled by optim step + return [ + self.new_instance(type_=EventType.BATCH_START), + self.new_instance(type_=EventType.LOSS_CALCULATED), + ] + + def optim_pre_step_events(self) -> List[Event]: + if ( + self.type_first == EventType.OPTIM_PRE_STEP + and self.type_ is not None + and self.type_ != EventType.OPTIM_POST_STEP + ): + raise ValueError("optim pre step must be called after optim post step") + + if ( + self.type_first == EventType.LOSS_CALCULATED + and self.type_ != EventType.LOSS_CALCULATED + ): + raise ValueError("optim pre step must be called after loss calculated") + + self.type_ = EventType.OPTIM_PRE_STEP + + if self.type_first == EventType.OPTIM_PRE_STEP: + self.global_batch += ( + 1 + if self.batches_per_step is None or self.batches_per_step < 2 + else self.batches_per_step + ) + batch_start_events = [self.new_instance(type_=EventType.BATCH_START)] + else: + batch_start_events = [] + + if not self.check_step_invocations_count(increment=False): + return 
batch_start_events + + return batch_start_events + [ + self.new_instance(type_=EventType.OPTIM_PRE_STEP), + ] + + def optim_post_step_events(self) -> List[Event]: + if self.type_ != EventType.OPTIM_PRE_STEP: + raise ValueError("optim post step must be called after optim pre step") + + self.type_ = EventType.OPTIM_POST_STEP + + if not self.check_step_invocations_count(increment=True): + return [ + self.new_instance(type_=EventType.BATCH_END), + ] + + self.global_step += 1 + + return [ + self.new_instance(type_=EventType.OPTIM_POST_STEP), + self.new_instance(type_=EventType.BATCH_END), + ] + + def batch_end_events(self) -> List[Event]: + raise ValueError("batch end should not be invoked when only wrapped optim") + + +class CallbacksEventLifecycle(EventLifecycle): + """ + Optimizer is not wrapped, callbacks must be used + - batch_start: must be called first + - loss_calculated: must be called before batch_end and optim_post_step + - batch_end: must be called before next batch start + - optim_pre_step: must be invoked before optim_post_step + - optim_post_step: must be called only once after optim_pre_step + """ + + def batch_start_events(self) -> List[Event]: + if self.type_first != EventType.BATCH_START: + raise ValueError("batch start must be called first for callbacks") + + if self.type_ is not None and self.type_ != EventType.BATCH_END: + raise ValueError("batch start must be called after batch end") + + self.type_ = EventType.BATCH_START + self.global_batch += 1 + + return [self.new_instance(type_=EventType.BATCH_START)] + + def loss_calculated_events(self) -> List[Event]: + if self.type_ != EventType.BATCH_START: + raise ValueError("loss calculated must be called after batch start") + + self.type_ = EventType.LOSS_CALCULATED + + return [self.new_instance(type_=EventType.LOSS_CALCULATED)] + + def optim_pre_step_events(self) -> List[Event]: + if ( + self.type_ != EventType.BATCH_START + and self.type_ != EventType.LOSS_CALCULATED + ): + raise ValueError( + "optim pre step must be called after batch start or loss calculated" + ) + + self.type_ = EventType.OPTIM_PRE_STEP + + if not self.check_step_invocations_count(increment=False): + return [] + + return [ + self.new_instance(type_=EventType.OPTIM_PRE_STEP), + ] + + def optim_post_step_events(self) -> List[Event]: + if self.type_ != EventType.OPTIM_PRE_STEP: + raise ValueError("optim post step must be called after optim pre step") + + self.type_ = EventType.OPTIM_POST_STEP + + if not self.check_step_invocations_count(increment=True): + return [] + + self.global_step += 1 + + return [ + self.new_instance(type_=EventType.OPTIM_POST_STEP), + ] + + def batch_end_events(self) -> List[Event]: + if ( + self.type_ != EventType.OPTIM_POST_STEP + and self.type_ != EventType.LOSS_CALCULATED + and self.type_ != EventType.BATCH_START + ): + raise ValueError( + "batch end must be called after optim post step or " + "loss calculated or batch start" + ) + + self.type_ = EventType.BATCH_END + + return [ + self.new_instance(type_=EventType.BATCH_END), + ] diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py index bb1776bfdbe..d4f6ddcaebd 100644 --- a/src/sparseml/core/framework.py +++ b/src/sparseml/core/framework.py @@ -13,8 +13,9 @@ # limitations under the License. 
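# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the epoch bookkeeping added to
# Event above, worked through with concrete numbers (100 optimizer steps per
# epoch and 4 gradient-accumulation batches per step are arbitrary choices).
# Note that a later commit in this series turns current_index into a property
# and derives epoch_based from steps_per_epoch.
# ---------------------------------------------------------------------------
from sparseml.core.event import Event

event = Event(
    epoch_based=True,
    steps_per_epoch=100,
    batches_per_step=4,
    global_step=250,
    global_batch=1000,
)
assert event.epoch == 2              # 250 // 100
assert event.epoch_full == 2.5       # 250 / 100
assert event.epoch_step == 50        # 250 % 100
assert event.epoch_batch == 200      # 1000 % (100 * 4)
assert event.current_index() == 2.5  # epoch_full for epoch-based events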
-from enum import Enum import importlib +from enum import Enum + __all__ = ["Framework", "MultiFrameworkObject"] diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 3131aa529a9..db455b18412 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Generic, List, TypeVar, Union + from pydantic import BaseModel -from typing import TypeVar, Generic, Union, List, Dict from sparseml.core.framework import MultiFrameworkObject + __all__ = ["ModifiableModel"] diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index 225b30fa711..fde46468f8e 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, List, Dict, Tuple -from torch.nn import Module, Parameter +from typing import Dict, List, Tuple, Union +from torch.nn import Module, Parameter from sparseml.core.model.base import ModifiableModel from sparseml.utils.pytorch import ( - get_layers, get_layer, - set_layer, - get_params, + get_layers, get_param, + get_params, + set_layer, set_param, ) diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 0f635bfa0de..343b502190f 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -13,13 +13,14 @@ # limitations under the License. -from pydantic import BaseModel +from abc import abstractmethod from typing import Optional -from abc import abstractmethod +from pydantic import BaseModel + +from sparseml.core.event import Event, EventType from sparseml.core.framework import MultiFrameworkObject from sparseml.core.modifier.base import ModifierInterface -from sparseml.core.event import Event, EventType from sparseml.core.state import State @@ -33,11 +34,16 @@ class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): end: Optional[float] = None update: Optional[float] = None + _initialized_structure: bool = False _initialized: bool = False _finalized: bool = False _started: bool = False _ended: bool = False + def pre_initialize_structure(self, state: State, **kwargs): + self.on_initialize_structure(state, **kwargs) + self._initialized_structure = True + def initialize(self, state: State, **kwargs): if self._initialized: return @@ -116,6 +122,10 @@ def should_end(self, event: Event): return self.end is not None and current >= self.end + @abstractmethod + def on_initialize_structure(self, state: State, **kwargs): + raise NotImplementedError() + @abstractmethod def on_initialize(self, state: State, event: Event, **kwargs) -> bool: raise NotImplementedError() diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index bce7ac3b152..11fcc6550bb 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -13,12 +13,13 @@ # limitations under the License. 
-from pydantic import BaseModel, Field from typing import List +from pydantic import BaseModel, Field from sparseml.core.modifier.base import ModifierInterface from sparseml.core.modifier.modifier import Modifier +from sparseml.core.state import Event, State class StageModifiers(ModifierInterface, BaseModel): @@ -26,11 +27,25 @@ class StageModifiers(ModifierInterface, BaseModel): index: int = None group: str = None - def initialize(self, **kwargs): - raise NotImplementedError() + _initialized_structure: bool = False + _initialized: bool = False + _finalized: bool = False + + def pre_initialize_structure(self, state: State, **kwargs): + for modifier in self.modifiers: + modifier.pre_initialize_structure(state, **kwargs) + self._initialized_structure = True + + def initialize(self, state: State, **kwargs): + for modifier in self.modifiers: + modifier.initialize(state, **kwargs) + self._initialized = True - def finalize(self, **kwargs): - raise NotImplementedError() + def finalize(self, state: State, **kwargs): + for modifier in self.modifiers: + modifier.finalize(state, **kwargs) + self._finalized = True - def update_event(self, **kwargs): - raise NotImplementedError() + def update_event(self, state: State, event: Event, **kwargs): + for modifier in self.modifiers: + modifier.update_event(state, event, **kwargs) diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py index 1c26096e220..bb95135c1f9 100644 --- a/src/sparseml/core/optimizer/base.py +++ b/src/sparseml/core/optimizer/base.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Generic, List, TypeVar, Union + from pydantic import BaseModel -from typing import TypeVar, Generic, Union, List, Any from sparseml.core.framework import MultiFrameworkObject + __all__ = ["ModifiableOptimizer"] diff --git a/src/sparseml/core/optimizer/pytorch.py b/src/sparseml/core/optimizer/pytorch.py index 8498f37c730..502c5f8766c 100644 --- a/src/sparseml/core/optimizer/pytorch.py +++ b/src/sparseml/core/optimizer/pytorch.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, List, Any, Dict +from typing import Any, Dict, List, Union + +from torch.optim import Optimizer from sparseml.core.optimizer.base import ModifiableOptimizer -from torch.optim import Optimizer __all__ = ["ModifiableOptimizerPyTorch"] diff --git a/src/sparseml/core/recipe/args.py b/src/sparseml/core/recipe/args.py index a6ab6f128ad..1fc7edfdbca 100644 --- a/src/sparseml/core/recipe/args.py +++ b/src/sparseml/core/recipe/args.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import re -from typing import Dict, Any, Optional import math +import re +from typing import Any, Dict, Optional, Union __all__ = ["RecipeArgs"] @@ -25,7 +25,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._evaluated: "Optional[RecipeArgs]" = None - def combine(self, other: "RecipeArgs") -> "RecipeArgs": + def combine(self, other: Union["RecipeArgs", Dict[str, Any]]) -> "RecipeArgs": combined = RecipeArgs() combined.update(self) diff --git a/src/sparseml/core/recipe/metadata.py b/src/sparseml/core/recipe/metadata.py index 73cbce3c547..65fc907e967 100644 --- a/src/sparseml/core/recipe/metadata.py +++ b/src/sparseml/core/recipe/metadata.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any, List +from typing import Any, Dict, List + from pydantic import BaseModel, Field diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index ec2b84ed8ff..4d16f1b3100 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any, Dict + from pydantic import BaseModel, root_validator -from sparseml.core.recipe.args import RecipeArgs -from sparseml.core.modifier import Modifier, ModifierFactory from sparseml.core.framework import Framework +from sparseml.core.modifier import Modifier, ModifierFactory +from sparseml.core.recipe.args import RecipeArgs __all__ = ["RecipeModifier"] @@ -27,16 +28,23 @@ class RecipeModifier(BaseModel): type: str group: str = None args: Dict[str, Any] = None - _args: Dict[str, Any] = None + _args_evaluated: Dict[str, Any] = None + + def evaluate(self, parent_args: RecipeArgs = None, shift: int = None): + if not self.args: + raise ValueError("args must be set before evaluating") + + comb_args = parent_args or RecipeArgs() + self._args_evaluated = comb_args.evaluate_ext(self.args) + + if shift is not None and "start" in self._args_evaluated: + self._args_evaluated["start"] += shift - def evaluate(self, parent_args: RecipeArgs): - if self.args: - self._args = parent_args.evaluate_ext(self.args) - else: - self._args = dict() + if shift is not None and "end" in self._args_evaluated: + self._args_evaluated["end"] += shift def create_modifier(self, framework: Framework) -> Modifier: - return ModifierFactory.create(self.type, framework, **self._args) + return ModifierFactory.create(self.type, framework, **self._args_evaluated) @root_validator(pre=True) def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index b4b34f57fcc..41014995e75 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
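# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the recipe-argument plumbing
# changed above. combine() now also accepts a plain dict of overrides, which
# is assumed to merge dict-style with the override winning, and
# RecipeModifier.evaluate() applies an optional shift to the resolved
# start/end so stacked stages can be offset. The argument names and the
# modifier type below are arbitrary examples.
# ---------------------------------------------------------------------------
from sparseml.core.recipe.args import RecipeArgs

defaults = RecipeArgs(num_epochs=10, init_sparsity=0.05)
overridden = defaults.combine({"num_epochs": 3})
print(overridden)  # expected: {'num_epochs': 3, 'init_sparsity': 0.05}

# Intended effect of the shift argument (not executed here; it assumes
# evaluate_ext() passes plain numeric args through unchanged):
#   pruning = RecipeModifier(type="MagnitudePruningModifier",
#                            args={"start": 0, "end": 5})
#   pruning.evaluate(overridden, shift=10)
#   -> evaluated args become {"start": 10, "end": 15}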
-from typing import Dict, Any, List, Tuple +from typing import Any, Dict, List, Tuple + from pydantic import BaseModel, Field, root_validator -from sparseml.core.recipe.args import RecipeArgs -from sparseml.core.recipe.stage import RecipeStage -from sparseml.core.recipe.metadata import RecipeMetaData from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers +from sparseml.core.recipe.args import RecipeArgs +from sparseml.core.recipe.metadata import RecipeMetaData +from sparseml.core.recipe.stage import RecipeStage __all__ = ["Recipe"] @@ -32,10 +33,11 @@ class Recipe(BaseModel): metadata: RecipeMetaData = None _args_evaluated: RecipeArgs = None - def evaluate(self): - self._args_evaluated = self.args.evaluate() + def evaluate(self, args: Dict[str, Any] = None, shift: int = None): + args = self.args.combine(args) + self._args_evaluated = args.evaluate() for stage in self.stages: - stage.evaluate(self.args) + stage.evaluate(self._args_evaluated, shift) def create_modifiers(self, framework: Framework) -> List[StageModifiers]: self.evaluate() diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py index 5bd9f003add..fa918132827 100644 --- a/src/sparseml/core/recipe/stage.py +++ b/src/sparseml/core/recipe/stage.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any, List +from typing import Any, Dict, List + from pydantic import BaseModel, Field, root_validator -from sparseml.core.recipe.args import RecipeArgs -from sparseml.core.recipe.modifier import RecipeModifier from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers +from sparseml.core.recipe.args import RecipeArgs +from sparseml.core.recipe.modifier import RecipeModifier __all__ = ["RecipeStage"] @@ -31,11 +32,11 @@ class RecipeStage(BaseModel): modifiers: List[RecipeModifier] = Field(default_factory=list) _args_evaluated: RecipeArgs = None - def evaluate(self, parent_args: RecipeArgs): + def evaluate(self, parent_args: RecipeArgs = None, shift: int = None): merged_args = self.args.combine(parent_args) self._args_evaluated = merged_args.evaluate() for modifier in self.modifiers: - modifier.evaluate(merged_args) + modifier.evaluate(self._args_evaluated, shift) def create_modifiers( self, framework: Framework, parent_args: RecipeArgs = None diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 840cc347a17..d157cb1e714 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -14,22 +14,26 @@ import threading from contextlib import contextmanager -from typing import Callable, Any, Union, List, Dict, Tuple from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Union -from sparseml.core.state import State -from sparseml.core.event import EventType, Event -from sparseml.core.recipe import Recipe +from sparseml.core.event import ( + CallbacksEventLifecycle, + EventType, + WrappedOptimEventLifecycle, +) from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers +from sparseml.core.recipe import Recipe +from sparseml.core.state import ModifiedState, State __all__ = [ "SparseSession", "create_session", "active_session", - "apply_structure", - "init", + "pre_initialize_structure", + "initialize", "finalize", "apply", "callbacks", @@ -49,6 +53,10 @@ class SparseSession: def __init__(self): self._state: State = State() self._modifiers: 
List[StageModifiers] = [] + self._initialized_structure = False + self._initialized = False + self._finalized = False + self._event_called = False @property def state(self) -> State: @@ -58,8 +66,21 @@ def state(self) -> State: def modifiers(self) -> List[StageModifiers]: return self._modifiers - def last_event(self) -> Event: - return self._state.last_event + @property + def initialized_structure(self) -> bool: + return self._initialized_structure + + @property + def initialized(self) -> bool: + return self._initialized + + @property + def finalized(self) -> bool: + return self._finalized + + @property + def event_called(self) -> bool: + return self._event_called def pre_initialize_structure( self, @@ -67,18 +88,27 @@ def pre_initialize_structure( recipe: Union[Recipe, List[Recipe]], framework: Framework = None, **kwargs, - ) -> Any: + ) -> ModifiedState: self.state.update_framework(framework) self.state.update_model(model) self.state.update_recipe(recipe) self._check_compile_recipe() + modifier_data = [] - if self._modifiers: - for modifier in self._modifiers: - modifier.pre_initialize_structure(state=self.state, **kwargs) + for modifier in self._modifiers: + data = modifier.pre_initialize_structure(state=self.state, **kwargs) + if data: + modifier_data.append(data) + + self._initialized_structure = True - return self.state.model.model + return ModifiedState( + model=self.state.model.model, + optimizer=None, + loss=None, + modifier_data=modifier_data, + ) def initialize( self, @@ -98,7 +128,13 @@ def initialize( steps_per_epoch: int = None, batches_per_step: int = None, **kwargs, - ) -> Tuple[Any, Any]: + ) -> ModifiedState: + if self.event_called: + raise ValueError("Cannot initialize after invoking an event") + + if self.finalized: + raise ValueError("Cannot initialize after finalizing") + self.state.update_framework(framework) self.state.update_recipe(recipe, recipe_stage, recipe_args) self.state.update_model(model) @@ -107,44 +143,129 @@ def initialize( self.state.update_start(start, steps_per_epoch, batches_per_step) self._check_compile_recipe() + modifier_data = [] if self._modifiers: for modifier in self._modifiers: - modifier.initialize(state=self.state, **kwargs) + data = modifier.initialize(state=self.state, **kwargs) + if data: + modifier_data.append(data) + + self._initialized = True + + return ModifiedState( + model=self.state.model.model, + optimizer=self.state.optimizer.optimizer, + loss=self.state.loss.loss, + modifier_data=modifier_data, + ) - model_return = None - optim_return = None + def finalize(self, **kwargs) -> ModifiedState: + if not self.initialized: + raise ValueError("Cannot finalize before initializing") - if model: - model_return = self.state.model.model - if optimizer: - optim_return = self.state.optimizer.optimizer + if self.finalized: + raise ValueError("Cannot finalize more than once") - return model_return, optim_return + modifier_data = [] - def finalize(self, **kwargs): - pass + for modifier in self._modifiers: + data = modifier.finalize(state=self.state, **kwargs) + if data: + modifier_data.append(data) + + self._finalized = True + + return ModifiedState( + model=self.state.model.model, + optimizer=self.state.optimizer.optimizer, + loss=self.state.loss.loss, + modifier_data=modifier_data, + ) def apply(self, **kwargs): self.initialize(**kwargs) - self.finalize(**kwargs) - def apply_structure( - self, model: Any, recipe: Union[Recipe, List[Recipe]], **kwargs - ): - pass + return self.finalize(**kwargs) + + def event( + self, event_type: EventType, 
batch_data: Any = None, loss: Any = None, **kwargs + ) -> ModifiedState: + if not self.initialized: + raise ValueError("Cannot invoke event before initializing") + + if self.finalized: + raise ValueError("Cannot invoke event after finalizing") + + if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + raise ValueError( + f"Cannot invoke {event_type} event. " + f"Use the corresponding method instead." + ) - def event(self, event_type: EventType, **kwargs): - pass + if event_type == EventType.LOSS_CALCULATED and loss is None: + raise ValueError("Loss must be provided for loss calculated event") + + if self.state.event_lifecycle is None: + if event_type == EventType.BATCH_START: + # utilizing callbacks pathway, ensure optim is not wrapped + if self.state.optim_wrapped: + raise ValueError( + "Cannot use batch callbacks with wrapped optimizer, " + "set attach_optim_callbacks to False when initializing " + ) + self.state.event_lifecycle = CallbacksEventLifecycle(event_type) + elif self.state.optim_wrapped: + # utilizing wrapped optimizer for callbacks + self.state.event_lifecycle = WrappedOptimEventLifecycle(event_type) + else: + raise ValueError( + "First event must be batch_start or " + "attach_optim_callbacks must be True" + ) + + event = None + modifier_data = [] + for event in self.state.event_lifecycle.events_from_type(event_type): + for modifier in self._modifiers: + data = modifier.update_event( + state=self.state, + event=event, + batch_data=batch_data, + loss=loss, + **kwargs, + ) + if data: + modifier_data.append(data) + + assert event is not None, f"No events generated for event type {event_type}" + self.state.last_event = event + self._event_called = True + + return ModifiedState( + model=self.state.model.model, + optimizer=self.state.optimizer.optimizer, + loss=self.state.loss.loss, + modifier_data=modifier_data, + ) def reset(self): if self._state: del self._state self._state = State() - if self._recipe_modifier: - del self._recipe_modifier - self._recipe_modifier = None + if self._modifiers: + if self.initialized and not self.finalized: + for modifier in self._modifiers: + modifier.finalize(self.state) + + del self._modifiers + + self._modifiers = [] + self._initialized_structure = False + self._initialized = False + self._finalized = False + self._event_called = False def _check_compile_recipe(self): if not self.state.should_recompile_recipe(): @@ -153,7 +274,7 @@ def _check_compile_recipe(self): # clear out the modifiers to reinitialize from newly compiled recipe if self._modifiers: for modifier in self._modifiers: - if modifier.initialized: + if modifier._initialized: modifier.finalize(self.state) del self._modifiers @@ -185,26 +306,89 @@ def active_session() -> SparseSession: return getattr(_local_storage, "session", _global_session) -def apply_structure(**kwargs): - active_session().apply_structure(**kwargs) - - -def init(**kwargs): - active_session().initialize(**kwargs) - - -def finalize(**kwargs): - active_session().finalize(**kwargs) - - -def apply(**kwargs): - init(**kwargs) - finalize(**kwargs) +def pre_initialize_structure(**kwargs): + active_session().pre_initialize_structure(**kwargs) + + +def initialize( + framework: Framework = None, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe_stage: str = None, + recipe_args: Dict[str, Any] = None, + model: Any = None, + optimizer: Any = None, + attach_optim_callbacks: bool = True, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + calib_data: Any = 
None, + copy_data: bool = True, + start: float = None, + steps_per_epoch: int = None, + batches_per_step: int = None, + **kwargs, +) -> ModifiedState: + return active_session().initialize( + framework=framework, + recipe=recipe, + recipe_stage=recipe_stage, + recipe_args=recipe_args, + model=model, + optimizer=optimizer, + attach_optim_callbacks=attach_optim_callbacks, + train_data=train_data, + val_data=val_data, + test_data=test_data, + calib_data=calib_data, + copy_data=copy_data, + start=start, + steps_per_epoch=steps_per_epoch, + batches_per_step=batches_per_step, + **kwargs, + ) + + +def finalize(**kwargs) -> ModifiedState: + return active_session().finalize(**kwargs) + + +def apply( + framework: Framework = None, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe_stage: str = None, + recipe_args: Dict[str, Any] = None, + model: Any = None, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + calib_data: Any = None, + copy_data: bool = True, + start: float = None, + steps_per_epoch: int = None, + batches_per_step: int = None, + **kwargs, +) -> ModifiedState: + return active_session().apply( + framework=framework, + recipe=recipe, + recipe_stage=recipe_stage, + recipe_args=recipe_args, + model=model, + train_data=train_data, + val_data=val_data, + test_data=test_data, + calib_data=calib_data, + copy_data=copy_data, + start=start, + steps_per_epoch=steps_per_epoch, + batches_per_step=batches_per_step, + **kwargs, + ) class LifecycleCallbacks: @classmethod - def event(cls, event_type: EventType, **kwargs) -> Any: + def event(cls, event_type: EventType, **kwargs) -> ModifiedState: if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: raise ValueError( f"Cannot invoke {event_type} event. " @@ -214,20 +398,24 @@ def event(cls, event_type: EventType, **kwargs) -> Any: return active_session().event(event_type, **kwargs) @classmethod - def batch_start(cls, **kwargs) -> Any: - return cls.event(EventType.BATCH_START, **kwargs) + def batch_start(cls, batch_data: Any = None, **kwargs) -> ModifiedState: + return cls.event(EventType.BATCH_START, batch_data=batch_data, **kwargs) @classmethod - def batch_end(cls, **kwargs) -> Any: - return cls.event(EventType.BATCH_END, **kwargs) + def loss_calculated(cls, loss: Any = None, **kwargs) -> ModifiedState: + return cls.event(EventType.LOSS_CALCULATED, loss=loss, **kwargs) @classmethod - def optim_stepped(cls, **kwargs) -> Any: + def optim_pre_step(cls, **kwargs) -> ModifiedState: + return cls.event(EventType.OPTIM_PRE_STEP, **kwargs) + + @classmethod + def optim_stepped(cls, **kwargs) -> ModifiedState: return cls.event(EventType.OPTIM_POST_STEP, **kwargs) @classmethod - def loss_calculated(cls, **kwargs) -> Any: - return cls.event(EventType.LOSS_CALCULATED, **kwargs) + def batch_end(cls, **kwargs) -> ModifiedState: + return cls.event(EventType.BATCH_END, **kwargs) callbacks = LifecycleCallbacks diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index 7ffca5c69f6..b0929d29ef2 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -13,18 +13,19 @@ # limitations under the License. 
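# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the module-level apply() helper
# above, which initializes the active session against a recipe and immediately
# finalizes it (a one-shot application, e.g. for post-training flows). The
# recipe path and data loader are hypothetical placeholders, and several
# State.update_* methods are still stubs in this commit, so this records the
# intended call pattern only.
# ---------------------------------------------------------------------------
from sparseml.core import Framework, apply


def oneshot(model, train_loader, recipe_path="recipe.yaml"):
    modified = apply(
        framework=Framework.pytorch,
        recipe=recipe_path,
        model=model,
        train_data=train_loader,
        start=0.0,
        steps_per_epoch=len(train_loader),
    )
    # ModifiedState hands back the (possibly wrapped) objects plus any
    # per-modifier data returned during initialize/finalize
    return modified.model, modified.modifier_data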
from dataclasses import dataclass -from typing import List +from typing import Any, Dict, List, Tuple, Union + from pydantic import Field -from sparseml.core.event import Event from sparseml.core.data import ModifiableData +from sparseml.core.event import Event, EventLifecycle +from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.core.optimizer import ModifiableOptimizer from sparseml.core.recipe import Recipe -from sparseml.core.framework import Framework -__all__ = ["State", "Data", "Hardware"] +__all__ = ["State", "Data", "Hardware", "ModifiedState"] @dataclass @@ -50,11 +51,70 @@ class Hardware: @dataclass class State: compiled_recipe: Recipe = None - recipes: List[Recipe] = Field(default_factory=list) + recipes: List[Tuple[Recipe, str, Dict[str, Any]]] = Field(default_factory=list) loggers = Field(default_factory=list) framework: Framework = None model: ModifiableModel = None optimizer: ModifiableOptimizer = None + optim_wrapped: bool = None + loss = None + batch_data = None data = Data() hardware = Hardware() - last_event: Event = Event() + event_lifecycle: EventLifecycle = None + last_event: Event = None + + def update_framework(self, framework: Framework): + self.framework = framework if framework else Framework.pytorch + + def update_recipe( + self, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe_stage: str = None, + recipe_args: Dict[str, Any] = None, + ): + pass + + def update_model(self, model: Any): + pass + + def update_optimizer(self, optimizer: Any, attach_callbacks: bool = True): + pass + + def update_data( + self, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + calib_data: Any = None, + copy_data: bool = True, + ): + pass + + def update_start( + self, + start: float = None, + steps_per_epoch: int = None, + batches_per_step: int = None, + ): + pass + + def should_recompile_recipe(self) -> bool: + pass + + def recompile_recipe(self) -> Recipe: + pass + + +@dataclass +class ModifiedState: + model: Any = None + optimizer: Any = None + loss: Any = None + modifier_data: List[Dict[str, Any]] = None + + def __init__(self, model, optimizer, loss, modifier_data): + self.model = model + self.optimizer = optimizer + self.loss = loss + self.modifier_data = modifier_data From 42aef3db736fbcda92e21cf466193937d46db1a1 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 6 Sep 2023 09:22:56 -0400 Subject: [PATCH 03/27] add in recipe helper functions for merging, loading, and running callbacks --- src/sparseml/core/event.py | 88 ++++++++++++++---------- src/sparseml/core/modifier/base.py | 20 ++++-- src/sparseml/core/modifier/modifier.py | 16 ++++- src/sparseml/core/modifier/stage.py | 16 +++++ src/sparseml/core/recipe/base.py | 43 ++++++++++++ src/sparseml/core/recipe/modifier.py | 21 ++++-- src/sparseml/core/recipe/recipe.py | 95 ++++++++++++++++++++++++-- src/sparseml/core/recipe/stage.py | 18 ++++- src/sparseml/core/session.py | 82 ++++++++++++++-------- src/sparseml/core/state.py | 71 ++++++++++++++++--- 10 files changed, 378 insertions(+), 92 deletions(-) create mode 100644 src/sparseml/core/recipe/base.py diff --git a/src/sparseml/core/event.py b/src/sparseml/core/event.py index d2e9ce1d3ca..9875d0bf022 100644 --- a/src/sparseml/core/event.py +++ b/src/sparseml/core/event.py @@ -13,6 +13,7 @@ # limitations under the License. 
from abc import ABC, abstractmethod +from copy import deepcopy from dataclasses import dataclass from enum import Enum from typing import List @@ -67,7 +68,6 @@ def order(self) -> int: class Event: type_: EventType = None - epoch_based: bool = None steps_per_epoch: int = None batches_per_step: int = None invocations_per_step: int = None @@ -75,6 +75,10 @@ class Event: global_step: int = 0 global_batch: int = 0 + @property + def epoch_based(self) -> bool: + return self.steps_per_epoch is not None + @property def epoch(self) -> int: return self.global_step // self.steps_per_epoch @@ -97,6 +101,7 @@ def epoch_batch(self) -> int: return self.global_batch % batches_per_epoch + @property def current_index(self) -> float: if not self.epoch_based: return self.global_step @@ -106,15 +111,26 @@ def current_index(self) -> float: return self.epoch_full - def new_instance(self, **kwargs) -> "Event": - instance = Event( - type_=self.type_, - epoch_based=self.epoch_based, - steps_per_epoch=self.steps_per_epoch, - batches_per_step=self.batches_per_step, - global_step=self.global_step, - global_batch=self.global_batch, + @current_index.setter + def current_index(self, value: float): + if not self.epoch_based: + self.global_step = int(value) + self.global_batch = ( + self.global_step + if self.batches_per_step is None or self.batches_per_step < 2 + else self.global_step * self.batches_per_step + ) + return + + self.global_step = int(value * self.steps_per_epoch) + self.global_batch = ( + self.global_step + if self.batches_per_step is None or self.batches_per_step < 2 + else self.global_step * self.batches_per_step ) + + def new_instance(self, **kwargs) -> "Event": + instance = deepcopy(self) for key, value in kwargs.items(): setattr(instance, key, value) @@ -123,13 +139,16 @@ def new_instance(self, **kwargs) -> "Event": class EventLifecycle(ABC, Event): type_first: EventType = None - batches_step_counter: int = 0 - steps_epoch_counter: int = 0 step_count: int = 0 batch_count: int = 0 - def __init__(self, type_first: EventType): + def __init__(self, type_first: EventType, start: Event): self.type_first = type_first + self.steps_per_epoch = start.steps_per_epoch + self.batches_per_step = start.batches_per_step + self.invocations_per_step = start.invocations_per_step + self.global_step = start.global_step + self.global_batch = start.global_batch def events_from_type(self, type_: EventType) -> List[Event]: if type_ == EventType.BATCH_START: @@ -149,35 +168,15 @@ def events_from_type(self, type_: EventType) -> List[Event]: raise ValueError(f"invalid event type {type_}") - @abstractmethod - def batch_start_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def loss_calculated_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def optim_pre_step_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def optim_post_step_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def batch_end_events(self) -> List[Event]: - raise NotImplementedError() - def check_step_batches_count(self, increment: bool) -> bool: if self.batches_per_step is None or self.batches_per_step < 2: return True - compare_batch = self.batches_step_counter + 1 + compare_batch = self.batch_count + 1 at_step = compare_batch % self.batches_per_step == 0 if increment: - self.batches_step_counter = compare_batch if not at_step else 0 + self.batch_count = compare_batch if not at_step else 0 return at_step @@ -193,8 +192,25 @@ def 
check_step_invocations_count(self, increment: bool) -> bool: return at_step - def reset_step_count(self): - self.step_count = 0 + @abstractmethod + def batch_start_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def loss_calculated_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def optim_pre_step_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def optim_post_step_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def batch_end_events(self) -> List[Event]: + raise NotImplementedError() class WrappedOptimEventLifecycle(EventLifecycle): diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py index d9341dbc3ed..7eb5fd3f667 100644 --- a/src/sparseml/core/modifier/base.py +++ b/src/sparseml/core/modifier/base.py @@ -26,18 +26,30 @@ class ModifierInterface(ABC): def __init__(self, **kwargs): pass + @abstractmethod + def check_initialized(self): + raise NotImplementedError() + + @abstractmethod + def calculate_start(self) -> float: + raise NotImplementedError() + + @abstractmethod + def calculate_end(self) -> float: + raise NotImplementedError() + @abstractmethod def pre_initialize_structure(self, state: State, **kwargs): - pass + raise NotImplementedError() @abstractmethod def initialize(self, state: State, **kwargs): - pass + raise NotImplementedError() @abstractmethod def finalize(self, state: State, **kwargs): - pass + raise NotImplementedError() @abstractmethod def update_event(self, state: State, event: Event, **kwargs): - pass + raise NotImplementedError() diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 343b502190f..4d89b059496 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -30,7 +30,7 @@ class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): index: int = None group: str = None - start: float + start: float = None end: Optional[float] = None update: Optional[float] = None @@ -40,6 +40,16 @@ class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): _started: bool = False _ended: bool = False + def check_initialized(self): + if not self._initialized: + raise RuntimeError("modifier has not been initialized") + + def calculate_start(self) -> float: + return self.start if self.start is not None else -1 + + def calculate_end(self) -> float: + return self.end if self.end is not None else -1 + def pre_initialize_structure(self, state: State, **kwargs): self.on_initialize_structure(state, **kwargs) self._initialized_structure = True @@ -113,12 +123,12 @@ def update_event(self, state: State, event: Event, **kwargs): self.on_update(state, event, **kwargs) def should_start(self, event: Event): - current = event.current_index() + current = event.current_index return self.start <= current and (self.end is None or current < self.end) def should_end(self, event: Event): - current = event.current_index() + current = event.current_index return self.end is not None and current >= self.end diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index 11fcc6550bb..caa2e81c5f6 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -31,6 +31,22 @@ class StageModifiers(ModifierInterface, BaseModel): _initialized: bool = False _finalized: bool = False + def check_initialized(self): + for modifier in self.modifiers: + modifier.check_initialized() + + def 
calculate_start(self) -> float: + return min( + mod.calculate_start() + for mod in self.modifiers + if mod.calculate_start() >= 0 + ) + + def calculate_end(self) -> float: + return max( + mod.calculate_end() for mod in self.modifiers if mod.calculate_end() >= 0 + ) + def pre_initialize_structure(self, state: State, **kwargs): for modifier in self.modifiers: modifier.pre_initialize_structure(state, **kwargs) diff --git a/src/sparseml/core/recipe/base.py b/src/sparseml/core/recipe/base.py new file mode 100644 index 00000000000..7eeb8e4539d --- /dev/null +++ b/src/sparseml/core/recipe/base.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from pydantic import BaseModel, root_validator + +from sparseml.core.framework import Framework +from sparseml.core.modifier import Modifier, ModifierFactory +from sparseml.core.recipe.args import RecipeArgs + + +__all__ = ["RecipeBase"] + + +class RecipeBase(BaseModel, ABC): + @abstractmethod + def calculate_start(self) -> int: + raise NotImplementedError() + + @abstractmethod + def calculate_end(self) -> int: + raise NotImplementedError() + + @abstractmethod + def evaluate(self, args: RecipeArgs = None, shift: int = None): + raise NotImplementedError() + + @abstractmethod + def create_modifier(self, framework: Framework) -> Any: + raise NotImplementedError() diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 4d16f1b3100..23e6ac98887 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -14,27 +14,40 @@ from typing import Any, Dict -from pydantic import BaseModel, root_validator +from pydantic import root_validator from sparseml.core.framework import Framework from sparseml.core.modifier import Modifier, ModifierFactory from sparseml.core.recipe.args import RecipeArgs +from sparseml.core.recipe.base import RecipeBase __all__ = ["RecipeModifier"] -class RecipeModifier(BaseModel): +class RecipeModifier(RecipeBase): type: str group: str = None args: Dict[str, Any] = None _args_evaluated: Dict[str, Any] = None - def evaluate(self, parent_args: RecipeArgs = None, shift: int = None): + def calculate_start(self) -> int: + if not self._args_evaluated: + raise ValueError("args must be evaluated before calculating start") + + return self._args_evaluated.get("start", -1) + + def calculate_end(self) -> int: + if not self._args_evaluated: + raise ValueError("args must be evaluated before calculating start") + + return self._args_evaluated.get("end", -1) + + def evaluate(self, args: RecipeArgs = None, shift: int = None): if not self.args: raise ValueError("args must be set before evaluating") - comb_args = parent_args or RecipeArgs() + comb_args = args or RecipeArgs() self._args_evaluated = comb_args.evaluate_ext(self.args) if shift is not None and "start" in self._args_evaluated: diff --git a/src/sparseml/core/recipe/recipe.py 
b/src/sparseml/core/recipe/recipe.py index 41014995e75..fa5c4d86014 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Tuple +import json +import os +from typing import Any, Dict, List, Tuple, Union -from pydantic import BaseModel, Field, root_validator +import yaml +from pydantic import Field, root_validator from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers from sparseml.core.recipe.args import RecipeArgs +from sparseml.core.recipe.base import RecipeBase from sparseml.core.recipe.metadata import RecipeMetaData from sparseml.core.recipe.stage import RecipeStage @@ -26,21 +30,100 @@ __all__ = ["Recipe"] -class Recipe(BaseModel): +class Recipe(RecipeBase): + @staticmethod + def create_instance(path: str) -> "Recipe": + if not os.path.isfile(path): + # not a local file, load from SparseZoo + raise NotImplementedError() + + with open(path, "r") as file: + content = file.read() + + if path.lower().endswith(".json"): + obj = json.loads(content) + elif path.lower().endswith(".yaml") or path.lower().endswith(".yml"): + obj = yaml.safe_load(content) + else: + try: + obj = json.loads(content) + except json.JSONDecodeError: + try: + obj = yaml.safe_load(content) + except yaml.YAMLError: + raise ValueError(f"Could not parse recipe from path {path}") + + return Recipe.parse_obj(obj) + + @staticmethod + def simplify_recipe( + recipe: "Recipe", stages: List[str], args: Dict[str, Any], shift: int = None + ) -> "Recipe": + simplified = Recipe() + simplified.version = recipe.version + simplified.args = recipe.args + simplified.stages = [ + stage + for stage in recipe.stages + if ((not stages or "default" in stages) and not stage.exclude_default) + or stage.group in stages + ] + simplified.evaluate(args=args, shift=shift) + + return simplified + + @staticmethod + def simplify_combine_recipes( + recipes: List[Union["Recipe", Tuple["Recipe", str, Dict[str, Any]]]] + ) -> "Recipe": + simplified = Recipe() + + for recipe_tuple in recipes: + recipe = ( + recipe_tuple[0] if isinstance(recipe_tuple, tuple) else recipe_tuple + ) + stages = ( + recipe_tuple[1].split(",") if isinstance(recipe_tuple, tuple) else None + ) + args = recipe_tuple[2] if isinstance(recipe_tuple, tuple) else None + recipe_simple = Recipe.simplify_recipe( + recipe=recipe, + stages=stages, + args=args, + shift=simplified.calculate_end(), + ) + simplified.version = recipe_simple.version + simplified.stages.extend(recipe_simple.stages) + + return simplified + version: str = None args: RecipeArgs = None stages: List[RecipeStage] = Field(default_factory=list) metadata: RecipeMetaData = None _args_evaluated: RecipeArgs = None + def calculate_start(self) -> int: + return min( + stage.calculate_start() + for stage in self.stages + if stage.calculate_start() >= 0 + ) + + def calculate_end(self) -> int: + return max( + stage.calculate_end() for stage in self.stages if stage.calculate_end() >= 0 + ) + def evaluate(self, args: Dict[str, Any] = None, shift: int = None): - args = self.args.combine(args) + args = self.args.combine(args) if self.args else RecipeArgs(**(args or {})) self._args_evaluated = args.evaluate() for stage in self.stages: stage.evaluate(self._args_evaluated, shift) def create_modifiers(self, framework: Framework) -> List[StageModifiers]: - self.evaluate() + if self._args_evaluated is None: + 
self.evaluate() modifiers = [] for index, stage in enumerate(self.stages): @@ -49,7 +132,7 @@ def create_modifiers(self, framework: Framework) -> List[StageModifiers]: stage_modifiers.group = stage.group modifiers.append(stage_modifiers) - return stage_modifiers + return modifiers @root_validator(pre=True) def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py index fa918132827..54effcfc042 100644 --- a/src/sparseml/core/recipe/stage.py +++ b/src/sparseml/core/recipe/stage.py @@ -14,24 +14,38 @@ from typing import Any, Dict, List -from pydantic import BaseModel, Field, root_validator +from pydantic import Field, root_validator from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers from sparseml.core.recipe.args import RecipeArgs +from sparseml.core.recipe.base import RecipeBase from sparseml.core.recipe.modifier import RecipeModifier __all__ = ["RecipeStage"] -class RecipeStage(BaseModel): +class RecipeStage(RecipeBase): group: str = None args: RecipeArgs = None enabled: bool = True modifiers: List[RecipeModifier] = Field(default_factory=list) + exclude_default: bool = False _args_evaluated: RecipeArgs = None + def calculate_start(self) -> int: + return min( + mod.calculate_start() + for mod in self.modifiers + if mod.calculate_start() >= 0 + ) + + def calculate_end(self) -> int: + return max( + mod.calculate_end() for mod in self.modifiers if mod.calculate_end() >= 0 + ) + def evaluate(self, parent_args: RecipeArgs = None, shift: int = None): merged_args = self.args.combine(parent_args) self._args_evaluated = merged_args.evaluate() diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index d157cb1e714..14f2c7c85c5 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -206,23 +206,7 @@ def event( if event_type == EventType.LOSS_CALCULATED and loss is None: raise ValueError("Loss must be provided for loss calculated event") - if self.state.event_lifecycle is None: - if event_type == EventType.BATCH_START: - # utilizing callbacks pathway, ensure optim is not wrapped - if self.state.optim_wrapped: - raise ValueError( - "Cannot use batch callbacks with wrapped optimizer, " - "set attach_optim_callbacks to False when initializing " - ) - self.state.event_lifecycle = CallbacksEventLifecycle(event_type) - elif self.state.optim_wrapped: - # utilizing wrapped optimizer for callbacks - self.state.event_lifecycle = WrappedOptimEventLifecycle(event_type) - else: - raise ValueError( - "First event must be batch_start or " - "attach_optim_callbacks must be True" - ) + self._check_setup_lifecycle(event_type) event = None modifier_data = [] @@ -268,20 +252,62 @@ def reset(self): self._event_called = False def _check_compile_recipe(self): - if not self.state.should_recompile_recipe(): + if not self.state.recipe_changed and self._modifiers is not None: + # recipe hasn't changed and modifiers set, no need to recompile return - # clear out the modifiers to reinitialize from newly compiled recipe - if self._modifiers: - for modifier in self._modifiers: - if modifier._initialized: - modifier.finalize(self.state) - del self._modifiers + if self.state.recipes is None: + # no recipes currently, return + return - self.state.recompile_recipe() - self._modifiers = self.state.compiled_recipe.create_modifiers( - self.state.framework - ) + if self.state.recipe_changed: + self.state.recompile_recipe() + + if self._modifiers: + # clear out the 
modifiers to reinitialize from newly compiled recipe + for modifier in self._modifiers: + if modifier._initialized: + modifier.finalize(self.state) + del self._modifiers + + if self.state.recipe_modifier_ready: + self._modifiers = self.state.compiled_recipe.create_modifiers( + self.state.framework + ) + + def _check_setup_lifecycle(self, event_type: EventType): + if self.state.event_lifecycle is not None: + return + + # first event call, setup lifecycle and make sure everything is initialized + if not self.state.recipe_modifier_ready: + raise ValueError( + "Cannot invoke event before recipe, model, and start are set" + ) + + for modifier in self._modifiers: + modifier.check_initialized() + + if event_type == EventType.BATCH_START: + # utilizing callbacks pathway, ensure optim is not wrapped + if self.state.optim_wrapped: + raise ValueError( + "Cannot use batch callbacks with wrapped optimizer, " + "set attach_optim_callbacks to False when initializing " + ) + self.state.event_lifecycle = CallbacksEventLifecycle( + event_type, self.state.start_event + ) + elif self.state.optim_wrapped: + # utilizing wrapped optimizer for callbacks + self.state.event_lifecycle = WrappedOptimEventLifecycle( + event_type, self.state.start_event + ) + else: + raise ValueError( + "First event must be batch_start or " + "attach_optim_callbacks must be True" + ) _global_session = SparseSession() diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index b0929d29ef2..2cd9ca79657 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, List, Tuple, Union @@ -62,7 +63,21 @@ class State: data = Data() hardware = Hardware() event_lifecycle: EventLifecycle = None + start_event: Event = None last_event: Event = None + _recipe_changed: bool = False + + @property + def recipe_changed(self) -> bool: + return self._recipe_changed + + @property + def recipe_modifier_ready(self) -> bool: + return ( + self.compiled_recipe is not None + and self.model is not None + and self.start_event is not None + ) def update_framework(self, framework: Framework): self.framework = framework if framework else Framework.pytorch @@ -73,13 +88,31 @@ def update_recipe( recipe_stage: str = None, recipe_args: Dict[str, Any] = None, ): - pass + if not isinstance(recipe, list): + recipe = [recipe] + + for rec in recipe: + if isinstance(rec, str): + rec = Recipe.create_instance(rec) + + self.recipes.append((rec, recipe_stage, recipe_args)) + + self._recipe_changed = True def update_model(self, model: Any): - pass + if self.framework is None: + raise RuntimeError("framework must be set before updating model") + + self.model = ModifiableModel(framework=self.framework, model=model) def update_optimizer(self, optimizer: Any, attach_callbacks: bool = True): - pass + if self.framework is None: + raise RuntimeError("framework must be set before updating optimizer") + + self.optim_wrapped = attach_callbacks + self.optimizer = ModifiableOptimizer( + framework=self.framework, optimizer=optimizer + ) def update_data( self, @@ -89,7 +122,22 @@ def update_data( calib_data: Any = None, copy_data: bool = True, ): - pass + if self.framework is None: + raise RuntimeError("framework must be set before updating data") + + self.data = ModifiableData(framework=self.framework) + + if train_data is not None: + self.data.train = 
train_data if not copy_data else deepcopy(train_data) + + if val_data is not None: + self.data.val = val_data if not copy_data else deepcopy(val_data) + + if test_data is not None: + self.data.test = test_data if not copy_data else deepcopy(test_data) + + if calib_data is not None: + self.data.calib = calib_data if not copy_data else deepcopy(calib_data) def update_start( self, @@ -97,13 +145,18 @@ def update_start( steps_per_epoch: int = None, batches_per_step: int = None, ): - pass + self.start_event = Event() + self.start_event.steps_per_epoch = steps_per_epoch + self.start_event.batches_per_step = batches_per_step + self.start_event.current_index = start if start is not None else 0 + + def recompile_recipe(self): + self._recipe_changed = False - def should_recompile_recipe(self) -> bool: - pass + if not self.recipes: + raise RuntimeError("No recipes to compile") - def recompile_recipe(self) -> Recipe: - pass + self.compiled_recipe = Recipe.simplify_combine_recipes(self.recipes) @dataclass From 668278435ceb49b126752d894822b930e2c0d077 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 6 Sep 2023 09:43:06 -0400 Subject: [PATCH 04/27] minor fixes for new framework --- src/sparseml/core/recipe/recipe.py | 2 +- src/sparseml/core/session.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index fa5c4d86014..668e00764b2 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -121,7 +121,7 @@ def evaluate(self, args: Dict[str, Any] = None, shift: int = None): for stage in self.stages: stage.evaluate(self._args_evaluated, shift) - def create_modifiers(self, framework: Framework) -> List[StageModifiers]: + def create_modifier(self, framework: Framework) -> List[StageModifiers]: if self._args_evaluated is None: self.evaluate() modifiers = [] diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 14f2c7c85c5..d61dd4c4876 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -271,7 +271,7 @@ def _check_compile_recipe(self): del self._modifiers if self.state.recipe_modifier_ready: - self._modifiers = self.state.compiled_recipe.create_modifiers( + self._modifiers = self.state.compiled_recipe.create_modifier( self.state.framework ) From 5b0f190af0a64e3462d04b79d92972d9dc9f343e Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Thu, 7 Sep 2023 09:56:07 -0400 Subject: [PATCH 05/27] add constant pruning modifier --- src/sparseml/core/session.py | 2 +- .../modifiers/distillation/__init__.py | 13 ++ .../modifiers/distillation/output/__init__.py | 13 ++ src/sparseml/modifiers/pruning/__init__.py | 13 ++ .../modifiers/pruning/constant/__init__.py | 13 ++ .../modifiers/pruning/constant/base.py | 25 +++ .../modifiers/pruning/constant/pytorch.py | 159 ++++++++++++++++++ .../modifiers/pruning/magnitude/__init__.py | 13 ++ 8 files changed, 250 insertions(+), 1 deletion(-) create mode 100644 src/sparseml/modifiers/distillation/__init__.py create mode 100644 src/sparseml/modifiers/distillation/output/__init__.py create mode 100644 src/sparseml/modifiers/pruning/__init__.py create mode 100644 src/sparseml/modifiers/pruning/constant/__init__.py create mode 100644 src/sparseml/modifiers/pruning/constant/base.py create mode 100644 src/sparseml/modifiers/pruning/constant/pytorch.py create mode 100644 src/sparseml/modifiers/pruning/magnitude/__init__.py diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 
d61dd4c4876..5e21bb0f250 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -436,7 +436,7 @@ def optim_pre_step(cls, **kwargs) -> ModifiedState: return cls.event(EventType.OPTIM_PRE_STEP, **kwargs) @classmethod - def optim_stepped(cls, **kwargs) -> ModifiedState: + def optim_post_step(cls, **kwargs) -> ModifiedState: return cls.event(EventType.OPTIM_POST_STEP, **kwargs) @classmethod diff --git a/src/sparseml/modifiers/distillation/__init__.py b/src/sparseml/modifiers/distillation/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/distillation/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/distillation/output/__init__.py b/src/sparseml/modifiers/distillation/output/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/distillation/output/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/__init__.py b/src/sparseml/modifiers/pruning/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/pruning/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/constant/__init__.py b/src/sparseml/modifiers/pruning/constant/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/pruning/constant/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/constant/base.py b/src/sparseml/modifiers/pruning/constant/base.py new file mode 100644 index 00000000000..f18a9e42dc6 --- /dev/null +++ b/src/sparseml/modifiers/pruning/constant/base.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from sparseml.core import Modifier, State + + +class ConstantPruningModifier(Modifier): + targets: Union[str, List[str]] + _epsilon: float = 10e-9 + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier diff --git a/src/sparseml/modifiers/pruning/constant/pytorch.py b/src/sparseml/modifiers/pruning/constant/pytorch.py new file mode 100644 index 00000000000..f376e19fc10 --- /dev/null +++ b/src/sparseml/modifiers/pruning/constant/pytorch.py @@ -0,0 +1,159 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Tuple + +import torch +from torch.nn import Module, Parameter + +from sparseml.core import Event, State, EventType +from sparseml.modifiers.pruning.constant.base import ConstantPruningModifier + + +class ConstantPruningModifierPyTorch(ConstantPruningModifier): + _layers_params: Dict[str, Tuple[Module, str, Parameter]] = None + _forward_hooks = None + _backward_hooks = None + + _save_masks: bool = False + _use_hooks: bool = False + _hooks_set: bool = False + + def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + if "save_masks" in kwargs: + self._save_masks = kwargs["save_masks"] + if "use_hooks" in kwargs: + self._use_hooks = kwargs["use_hooks"] + + if not state.model or not state.start_event: + return False + + self._layers_params = state.model.get_layers_params(self.targets) + self._create_masks() + self._check_create_hooks() + + return True + + def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + self._check_remove_masks() + self._check_remove_hooks() + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self._populate_masks() + + def on_update(self, state: State, event: Event, **kwargs): + if self._use_hooks: + # hooks are used to update, so nothing to do here + return + + if event.type_ == EventType.OPTIM_PRE_STEP: + # zero out the gradients for the pruned params + self._apply_mask_gradients() + elif event.type_ == EventType.OPTIM_POST_STEP: + # apply the masks to the pruned params + self._apply_mask_params() + + def on_end(self, state: State, event: Event, **kwargs): + self._check_remove_hooks() + + def _param_mask_name(self, param_name: str) -> str: + return f"{param_name}_mask" + + def _create_masks(self): + for name, (layer, param_name, param) in self._layers_params.items(): + # check if a mask is already applied to the layer + try: + layer.get_buffer(self._param_mask_name(param_name)) + except AttributeError: + # add the mask buffer to the layer + layer.register_buffer( + self._param_mask_name(param_name), + torch.ones_like(param.data, dtype=torch.bool), + persistent=self._save_masks, + ) + + def _populate_masks(self): + for name, (layer, param_name, param) in self._layers_params.items(): + layer.get_buffer(self._param_mask_name(param_name)).fill_( + param.data.abs() < self._epsilon + ) + + def _apply_mask_params(self): + for name, (layer, param_name, param) in self._layers_params.items(): + mask = layer.get_buffer(self._param_mask_name(param_name)) + param.data = param.data * mask + + def _apply_mask_gradients(self): + for name, (layer, param_name, param) in self._layers_params.items(): + if param.grad is not None: + mask = layer.get_buffer(self._param_mask_name(param_name)) + param.grad = param.grad * mask + + def _check_remove_masks(self): + if self._save_masks: + return + + for name, (layer, param_name, param) in self._layers_params.items(): + try: + layer.unregister_buffer(self._param_mask_name(param_name)) + except AttributeError: + pass + + def _check_create_hooks(self): + if not self._use_hooks or self._hooks_set: + return + + def _register_hooks(layer, param_name, param): + mask_name = self._param_mask_name(param_name) + + def _forward_hook_fn(module, input, output): + mask = module.get_buffer(mask_name) + param.data = param.data * mask + + return output + + def _backward_hook_fn(module, grad_input, grad_output): + mask = module.get_buffer(mask_name) + if grad_input[0] is not None: + grad_input[0] *= mask + return grad_input + + forward_hook = 
layer.register_forward_hook(_forward_hook_fn) + backward_hook = layer.register_backward_hook(_backward_hook_fn) + + return forward_hook, backward_hook + + self._forward_hooks = [] + self._backward_hooks = [] + + for name, (layer, param_name, param) in self._layers_params.items(): + forward, backward = _register_hooks(layer, param_name, param) + self._forward_hooks.append(forward) + self._backward_hooks.append(backward) + + self._hooks_set = True + + def _check_remove_hooks(self): + if self._hooks_set: + return + + for forward, backward in zip(self._forward_hooks, self._backward_hooks): + forward.remove() + backward.remove() + + self._forward_hooks = None + self._backward_hooks = None + self._hooks_set = False diff --git a/src/sparseml/modifiers/pruning/magnitude/__init__.py b/src/sparseml/modifiers/pruning/magnitude/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/pruning/magnitude/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From b8452a55a6f8e5b1f2ecfacdfa7925b639d8ce7d Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Sat, 9 Sep 2023 09:42:34 -0400 Subject: [PATCH 06/27] add magnitude pruning modifier --- src/sparseml/core/data/base.py | 6 +- src/sparseml/core/model/__init__.py | 2 +- src/sparseml/core/model/base.py | 21 +- src/sparseml/core/modifier/modifier.py | 14 +- src/sparseml/core/optimizer/__init__.py | 2 +- src/sparseml/core/optimizer/base.py | 6 +- src/sparseml/modifiers/__init__.py | 3 + src/sparseml/modifiers/pruning/__init__.py | 3 + .../modifiers/pruning/constant/__init__.py | 2 + .../modifiers/pruning/constant/base.py | 3 + .../modifiers/pruning/constant/pytorch.py | 138 +++---------- src/sparseml/modifiers/pruning/helpers.py | 190 ++++++++++++++++++ .../modifiers/pruning/magnitude/__init__.py | 2 + .../modifiers/pruning/magnitude/base.py | 34 ++++ .../modifiers/pruning/magnitude/pytorch.py | 129 ++++++++++++ .../utils/pytorch/pruning/__init__.py | 16 ++ .../utils/pytorch/pruning/layer_mask.py | 177 ++++++++++++++++ src/sparseml/utils/pytorch/pruning/mask.py | 167 +++++++++++++++ 18 files changed, 785 insertions(+), 130 deletions(-) create mode 100644 src/sparseml/modifiers/pruning/helpers.py create mode 100644 src/sparseml/modifiers/pruning/magnitude/base.py create mode 100644 src/sparseml/modifiers/pruning/magnitude/pytorch.py create mode 100644 src/sparseml/utils/pytorch/pruning/__init__.py create mode 100644 src/sparseml/utils/pytorch/pruning/layer_mask.py create mode 100644 src/sparseml/utils/pytorch/pruning/mask.py diff --git a/src/sparseml/core/data/base.py b/src/sparseml/core/data/base.py index 7665d85330e..3d15e48a777 100644 --- a/src/sparseml/core/data/base.py +++ b/src/sparseml/core/data/base.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License.
+from dataclasses import dataclass from typing import Generic, TypeVar -from pydantic import BaseModel - from sparseml.core.framework import MultiFrameworkObject @@ -24,7 +23,8 @@ DT = TypeVar("DT") # Dataset Type -class ModifiableData(Generic[DT], MultiFrameworkObject, BaseModel): +@dataclass +class ModifiableData(Generic[DT], MultiFrameworkObject): data: DT = None num_samples: int = None diff --git a/src/sparseml/core/model/__init__.py b/src/sparseml/core/model/__init__.py index 7df43946035..87930811c41 100644 --- a/src/sparseml/core/model/__init__.py +++ b/src/sparseml/core/model/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import ModifiableModel +from .base import * diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index db455b18412..0a0fdb4084d 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Dict, Generic, List, TypeVar, Union -from pydantic import BaseModel - from sparseml.core.framework import MultiFrameworkObject -__all__ = ["ModifiableModel"] +__all__ = ["ModifiableModel", "ModelParameterizedLayer"] MT = TypeVar("MT") @@ -27,9 +26,23 @@ PT = TypeVar("PT") -class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject, BaseModel): +@dataclass +class ModelParameterizedLayer(Generic[LT, PT]): + layer_name: str + layer: LT + param_name: str + param: PT + + +@dataclass +class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject): model: MT = None + def get_layers_params( + self, targets: Union[str, List[str]] + ) -> Dict[str, ModelParameterizedLayer[LT, PT]]: + raise NotImplementedError() + def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, LT]: raise NotImplementedError() diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 4d89b059496..750fb58a7e2 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -13,7 +13,6 @@ # limitations under the License. 
-from abc import abstractmethod from typing import Optional from pydantic import BaseModel @@ -61,6 +60,9 @@ def initialize(self, state: State, **kwargs): if self._finalized: raise RuntimeError("cannot initialize a finalized modifier") + if state.start_event is None: + return + initialized = self.on_initialize(**kwargs) if not isinstance(initialized, bool): @@ -71,6 +73,10 @@ def initialize(self, state: State, **kwargs): self._initialized = initialized + if self.should_start(state.start_event): + self.on_start(state, state.start_event, **kwargs) + self._started = True + def finalize(self, state: State, **kwargs): if self._finalized: return @@ -132,26 +138,20 @@ def should_end(self, event: Event): return self.end is not None and current >= self.end - @abstractmethod def on_initialize_structure(self, state: State, **kwargs): raise NotImplementedError() - @abstractmethod def on_initialize(self, state: State, event: Event, **kwargs) -> bool: raise NotImplementedError() - @abstractmethod def on_finalize(self, state: State, event: Event, **kwargs) -> bool: raise NotImplementedError() - @abstractmethod def on_start(self, state: State, event: Event, **kwargs): raise NotImplementedError() - @abstractmethod def on_update(self, state: State, event: Event, **kwargs): raise NotImplementedError() - @abstractmethod def on_end(self, state: State, event: Event, **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/optimizer/__init__.py b/src/sparseml/core/optimizer/__init__.py index 6ded41b5440..87930811c41 100644 --- a/src/sparseml/core/optimizer/__init__.py +++ b/src/sparseml/core/optimizer/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import ModifiableOptimizer +from .base import * diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py index bb95135c1f9..09b46208557 100644 --- a/src/sparseml/core/optimizer/base.py +++ b/src/sparseml/core/optimizer/base.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Any, Generic, List, TypeVar, Union -from pydantic import BaseModel - from sparseml.core.framework import MultiFrameworkObject @@ -26,7 +25,8 @@ PGT = TypeVar("PGT") -class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject, BaseModel): +@dataclass +class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject): optimizer: OT = None def get_param_groups(self) -> List[PGT]: diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py index 0c44f887a47..737bb4ed07e 100644 --- a/src/sparseml/modifiers/__init__.py +++ b/src/sparseml/modifiers/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .distillation import * +from .pruning import * diff --git a/src/sparseml/modifiers/pruning/__init__.py b/src/sparseml/modifiers/pruning/__init__.py index 0c44f887a47..522fd000e3f 100644 --- a/src/sparseml/modifiers/pruning/__init__.py +++ b/src/sparseml/modifiers/pruning/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from .constant import * +from .magnitude import * diff --git a/src/sparseml/modifiers/pruning/constant/__init__.py b/src/sparseml/modifiers/pruning/constant/__init__.py index 0c44f887a47..03ee625d7d4 100644 --- a/src/sparseml/modifiers/pruning/constant/__init__.py +++ b/src/sparseml/modifiers/pruning/constant/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .base import ConstantPruningModifier diff --git a/src/sparseml/modifiers/pruning/constant/base.py b/src/sparseml/modifiers/pruning/constant/base.py index f18a9e42dc6..da29d3ec7a0 100644 --- a/src/sparseml/modifiers/pruning/constant/base.py +++ b/src/sparseml/modifiers/pruning/constant/base.py @@ -17,6 +17,9 @@ from sparseml.core import Modifier, State +__all__ = ["ConstantPruningModifier"] + + class ConstantPruningModifier(Modifier): targets: Union[str, List[str]] _epsilon: float = 10e-9 diff --git a/src/sparseml/modifiers/pruning/constant/pytorch.py b/src/sparseml/modifiers/pruning/constant/pytorch.py index f376e19fc10..2fd3e94c22c 100644 --- a/src/sparseml/modifiers/pruning/constant/pytorch.py +++ b/src/sparseml/modifiers/pruning/constant/pytorch.py @@ -12,23 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Tuple +from typing import Dict -import torch -from torch.nn import Module, Parameter - -from sparseml.core import Event, State, EventType +from sparseml.core import Event, EventType, ModelParameterizedLayer, State from sparseml.modifiers.pruning.constant.base import ConstantPruningModifier +from sparseml.utils.pytorch.pruning import LayerParamMasking -class ConstantPruningModifierPyTorch(ConstantPruningModifier): - _layers_params: Dict[str, Tuple[Module, str, Parameter]] = None - _forward_hooks = None - _backward_hooks = None - +class ConstantPruningModifierPyTorch(ConstantPruningModifier, LayerParamMasking): + _parameterized_layers: Dict[str, ModelParameterizedLayer] = None _save_masks: bool = False _use_hooks: bool = False - _hooks_set: bool = False def on_initialize(self, state: State, event: Event, **kwargs) -> bool: if "save_masks" in kwargs: @@ -39,20 +33,31 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: if not state.model or not state.start_event: return False - self._layers_params = state.model.get_layers_params(self.targets) - self._create_masks() - self._check_create_hooks() + self._parameterized_layers = state.model.get_layers_params(self.targets) + + for layer_param_name, parameterized_layer in self._parameterized_layers.items(): + self.add_mask( + layer_param_name, + parameterized_layer, + persistent=self._save_masks, + add_hooks=self._use_hooks, + ) return True def on_finalize(self, state: State, event: Event, **kwargs) -> bool: - self._check_remove_masks() - self._check_remove_hooks() + for layer_param_name, _ in self._parameterized_layers.items(): + self.remove_mask(layer_param_name) return True def on_start(self, state: State, event: Event, **kwargs): - self._populate_masks() + for layer_param_name, parameterized_layer in self._parameterized_layers.items(): + self.update_mask( + layer_param_name, parameterized_layer.param.data.abs() < self._epsilon + ) + + self.enable_masks() def on_update(self, state: State, event: Event, **kwargs): if self._use_hooks: @@ -60,100 +65,11 @@ def on_update(self, state: State, event: Event, **kwargs): 
return if event.type_ == EventType.OPTIM_PRE_STEP: - # zero out the gradients for the pruned params - self._apply_mask_gradients() + for layer_param_name, _ in self._parameterized_layers.items(): + self.apply_mask_gradient(layer_param_name) elif event.type_ == EventType.OPTIM_POST_STEP: - # apply the masks to the pruned params - self._apply_mask_params() + for layer_param_name, _ in self._parameterized_layers.items(): + self.apply_mask_weight(layer_param_name) def on_end(self, state: State, event: Event, **kwargs): - self._check_remove_hooks() - - def _param_mask_name(self, param_name: str) -> str: - return f"{param_name}_mask" - - def _create_masks(self): - for name, (layer, param_name, param) in self._layers_params.items(): - # check if a mask is already applied to the layer - try: - layer.get_buffer(self._param_mask_name(param_name)) - except AttributeError: - # add the mask buffer to the layer - layer.register_buffer( - self._param_mask_name(param_name), - torch.ones_like(param.data, dtype=torch.bool), - persistent=self._save_masks, - ) - - def _populate_masks(self): - for name, (layer, param_name, param) in self._layers_params.items(): - layer.get_buffer(self._param_mask_name(param_name)).fill_( - param.data.abs() < self._epsilon - ) - - def _apply_mask_params(self): - for name, (layer, param_name, param) in self._layers_params.items(): - mask = layer.get_buffer(self._param_mask_name(param_name)) - param.data = param.data * mask - - def _apply_mask_gradients(self): - for name, (layer, param_name, param) in self._layers_params.items(): - if param.grad is not None: - mask = layer.get_buffer(self._param_mask_name(param_name)) - param.grad = param.grad * mask - - def _check_remove_masks(self): - if self._save_masks: - return - - for name, (layer, param_name, param) in self._layers_params.items(): - try: - layer.unregister_buffer(self._param_mask_name(param_name)) - except AttributeError: - pass - - def _check_create_hooks(self): - if not self._use_hooks or self._hooks_set: - return - - def _register_hooks(layer, param_name, param): - mask_name = self._param_mask_name(param_name) - - def _forward_hook_fn(module, input, output): - mask = module.get_buffer(mask_name) - param.data = param.data * mask - - return output - - def _backward_hook_fn(module, grad_input, grad_output): - mask = module.get_buffer(mask_name) - if grad_input[0] is not None: - grad_input[0] *= mask - return grad_input - - forward_hook = layer.register_forward_hook(_forward_hook_fn) - backward_hook = layer.register_backward_hook(_backward_hook_fn) - - return forward_hook, backward_hook - - self._forward_hooks = [] - self._backward_hooks = [] - - for name, (layer, param_name, param) in self._layers_params.items(): - forward, backward = _register_hooks(layer, param_name, param) - self._forward_hooks.append(forward) - self._backward_hooks.append(backward) - - self._hooks_set = True - - def _check_remove_hooks(self): - if self._hooks_set: - return - - for forward, backward in zip(self._forward_hooks, self._backward_hooks): - forward.remove() - backward.remove() - - self._forward_hooks = None - self._backward_hooks = None - self._hooks_set = False + self.disable_masks() diff --git a/src/sparseml/modifiers/pruning/helpers.py b/src/sparseml/modifiers/pruning/helpers.py new file mode 100644 index 00000000000..d29ea2cefc8 --- /dev/null +++ b/src/sparseml/modifiers/pruning/helpers.py @@ -0,0 +1,190 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict + +import numpy as np + +from sparseml.core import Event, State + + +__all__ = [ + "PruningCreateSettings", + "SchedulerCalculationType", + "CreateSchedulerType", + "PruningSchedulerFactory", + "create_custom_scheduler", + "linear_scheduler", + "cubic_scheduler", + "polynomial_decay_scheduler", + "polynomial_scheduler", + "multi_step_scheduler", +] + + +@dataclass +class PruningCreateSettings: + start: float + end: float + update: float + init_sparsity: float + final_sparsity: float + args: Dict[str, Any] + + +SchedulerCalculationType = Callable[[Event, State], float] +CreateSchedulerType = Callable[[PruningCreateSettings], SchedulerCalculationType] + + +class PruningSchedulerFactory: + registry = {} # type: Dict[str, CreateSchedulerType] + + @staticmethod + def register(name: str, func: CreateSchedulerType): + PruningSchedulerFactory.registry[name] = func + + @staticmethod + def register_decorator(name: str): + def inner(func: CreateSchedulerType): + PruningSchedulerFactory.registry[name] = func + return func + + return inner + + @staticmethod + def create_scheduler( + scheduler_type: str, settings: PruningCreateSettings + ) -> SchedulerCalculationType: + if scheduler_type in PruningSchedulerFactory.registry: + return PruningSchedulerFactory.registry[scheduler_type](settings) + elif scheduler_type.startswith("calc(") and scheduler_type.endswith(")"): + return create_custom_scheduler(scheduler_type, settings) + else: + raise ValueError(f"Unknown scheduler type: {scheduler_type}") + + +def create_custom_scheduler( + scheduler_type: str, settings: PruningCreateSettings +) -> SchedulerCalculationType: + pattern = re.compile(r"calc\(([^()]*)\)") + match = pattern.search(scheduler_type) + + if not match: + raise ValueError(f"invalid calc string {scheduler_type}") + + inner_expr = match.group(1) + + def _schedule(event: Event, state: State): + return eval( + inner_expr, + {"math": math}, + { + "start": settings.start, + "end": settings.end, + "update": settings.update, + "init_sparsity": settings.init_sparsity, + "final_sparsity": settings.final_sparsity, + **(settings.args if settings.args else {}), + "index": event.current_index, + }, + ) + + return _schedule + + +@PruningSchedulerFactory.register_decorator("linear") +def linear_scheduler(settings: PruningCreateSettings) -> SchedulerCalculationType: + def _schedule(event: Event, state: State) -> float: + per_complete = (event.current_index - settings.start) / ( + settings.end - settings.start + ) + + return ( + settings.init_sparsity + + (settings.final_sparsity - settings.init_sparsity) * per_complete + ) + + return _schedule + + +@PruningSchedulerFactory.register_decorator("cubic") +def cubic_scheduler(settings: PruningCreateSettings) -> SchedulerCalculationType: + settings.args = {"exponent": 3} + + return polynomial_decay_scheduler(settings) + + 
+@PruningSchedulerFactory.register_decorator("polynomial_decay") +def polynomial_decay_scheduler( + settings: PruningCreateSettings, +) -> SchedulerCalculationType: + args = settings.args if settings.args else {} + exponent = args.get("exponent", 2) + + def _schedule(event: Event, state: State) -> float: + per_complete = (event.current_index - settings.start) / ( + settings.end - settings.start + ) + + if exponent % 2 == 0: + scaled_complete = -1 * np.exp(per_complete - 1) + 1 + else: + scaled_complete = np.exp(per_complete - 1) - 1 + + return ( + settings.init_sparsity + + (settings.final_sparsity - settings.init_sparsity) * scaled_complete + ) + + return _schedule + + +@PruningSchedulerFactory.register_decorator("polynomial") +def polynomial_scheduler(settings: PruningCreateSettings) -> SchedulerCalculationType: + args = settings.args if settings.args else {} + exponent = args.get("exponent", 2) + + def _schedule(event: Event, state: State) -> float: + per_complete = (event.current_index - settings.start) / ( + settings.end - settings.start + ) + scaled_complete = per_complete**exponent + + return ( + settings.init_sparsity + + (settings.final_sparsity - settings.init_sparsity) * scaled_complete + ) + + return _schedule + + +@PruningSchedulerFactory.register_decorator("multi_step") +def multi_step_scheduler(settings: PruningCreateSettings) -> SchedulerCalculationType: + args = settings.args if settings.args else {} + steps = args.get("steps", []) + steps = sorted(steps, key=lambda x: x[0]) + + def _schedule(event: Event, state: State) -> float: + current_sparsity = settings.init_sparsity + + for (index, sparsity) in steps: + if event.current_index >= index: + current_sparsity = sparsity + + return current_sparsity + + return _schedule diff --git a/src/sparseml/modifiers/pruning/magnitude/__init__.py b/src/sparseml/modifiers/pruning/magnitude/__init__.py index 0c44f887a47..78cd427840b 100644 --- a/src/sparseml/modifiers/pruning/magnitude/__init__.py +++ b/src/sparseml/modifiers/pruning/magnitude/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .base import MagnitudePruningModifier diff --git a/src/sparseml/modifiers/pruning/magnitude/base.py b/src/sparseml/modifiers/pruning/magnitude/base.py new file mode 100644 index 00000000000..802de89049b --- /dev/null +++ b/src/sparseml/modifiers/pruning/magnitude/base.py @@ -0,0 +1,34 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Union + +from sparseml.core import Modifier, State + + +__all__ = ["MagnitudePruningModifier"] + + +class MagnitudePruningModifier(Modifier): + targets: Union[str, List[str]] + init_sparsity: float + final_sparsity: float + update_scheduler: str = "cubic" + scheduler_args: Dict[str, Any] = {} + mask_structure: str = "unstructured" + leave_enabled: bool = False + apply_globally: bool = False + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py new file mode 100644 index 00000000000..e8d6ac3bd0a --- /dev/null +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -0,0 +1,129 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from sparseml.core import Event, EventType, ModelParameterizedLayer, State +from sparseml.modifiers.pruning.helpers import ( + PruningCreateSettings, + PruningSchedulerFactory, + SchedulerCalculationType, +) +from sparseml.modifiers.pruning.magnitude.base import MagnitudePruningModifier +from sparseml.utils.pytorch.pruning import ( + LayerParamMasking, + MaskCreatorType, + PruningMaskCreatorArgs, + PruningMaskFactory, +) + + +class MagnitudePruningModifierPyTorch(MagnitudePruningModifier, LayerParamMasking): + _parameterized_layers: Dict[str, ModelParameterizedLayer] = None + _save_masks: bool = False + _use_hooks: bool = False + _scheduler_function: SchedulerCalculationType = None + _mask_creator_function: MaskCreatorType = None + _current_sparsity: float = None + + def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + if "save_masks" in kwargs: + self._save_masks = kwargs["save_masks"] + if "use_hooks" in kwargs: + self._use_hooks = kwargs["use_hooks"] + + if not state.model or not state.start_event: + return False + + self._scheduler_function = PruningSchedulerFactory.create_scheduler( + self.update_scheduler, + PruningCreateSettings( + self.start, + self.end, + self.update, + self.init_sparsity, + self.final_sparsity, + self.scheduler_args, + ), + ) + self._mask_creator_function = PruningMaskFactory.create_mask_creator( + self.mask_structure + ) + + self._parameterized_layers = state.model.get_layers_params(self.targets) + + for layer_param_name, parameterized_layer in self._parameterized_layers.items(): + self.add_mask( + layer_param_name, + parameterized_layer, + persistent=self._save_masks, + add_hooks=self._use_hooks, + ) + + return True + + def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + for layer_param_name, _ in self._parameterized_layers.items(): + self.remove_mask(layer_param_name) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + sparsity = self._scheduler_function(event, state) + self._current_sparsity = sparsity + + for layer_param_name, parameterized_layer in 
self._parameterized_layers.items(): + mask = self._mask_creator_function( + PruningMaskCreatorArgs( + parameter=parameterized_layer.param, + sparsity=sparsity, + scores=parameterized_layer.param.data.abs(), + ) + ) + self.update_mask(layer_param_name, mask) + + self.enable_masks() + + def on_update(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.BATCH_START: + sparsity = self._scheduler_function(event, state) + if sparsity != self._current_sparsity: + self._current_sparsity = sparsity + + for ( + layer_param_name, + parameterized_layer, + ) in self._parameterized_layers.items(): + mask = self._mask_creator_function( + PruningMaskCreatorArgs( + parameter=parameterized_layer.param, + sparsity=sparsity, + scores=parameterized_layer.param.data.abs(), + ) + ) + self.update_mask(layer_param_name, mask) + + if self._use_hooks: + # hooks are used to update, so nothing to do here + return + + if event.type_ == EventType.OPTIM_PRE_STEP: + for layer_param_name, _ in self._parameterized_layers.items(): + self.apply_mask_gradient(layer_param_name) + elif event.type_ == EventType.OPTIM_POST_STEP: + for layer_param_name, _ in self._parameterized_layers.items(): + self.apply_mask_weight(layer_param_name) + + def on_end(self, state: State, event: Event, **kwargs): + self.disable_masks() diff --git a/src/sparseml/utils/pytorch/pruning/__init__.py b/src/sparseml/utils/pytorch/pruning/__init__.py new file mode 100644 index 00000000000..c89c8da17de --- /dev/null +++ b/src/sparseml/utils/pytorch/pruning/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .layer_mask import * +from .mask import * diff --git a/src/sparseml/utils/pytorch/pruning/layer_mask.py b/src/sparseml/utils/pytorch/pruning/layer_mask.py new file mode 100644 index 00000000000..a8c2b16bb21 --- /dev/null +++ b/src/sparseml/utils/pytorch/pruning/layer_mask.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Dict + +import torch +from torch.nn import Module, Parameter +from torch.utils.hooks import RemovableHandle + +from sparseml.core import ModelParameterizedLayer + + +__all__ = ["LayerParamMasking"] + + +def param_mask_name(param_name: str) -> str: + return f"{param_name}_mask" + + +def setup_mask_for_param(param: Parameter, mask: torch.Tensor) -> torch.Tensor: + if mask is None: + raise ValueError("Mask cannot be None") + + if mask.shape != param.data.shape: + raise ValueError( + f"Mask shape {mask.shape} does not match " f"param shape {param.data.shape}" + ) + + if mask.dtype != torch.bool: + raise ValueError(f"Mask must be a boolean tensor") + + return param.data.new_tensor(mask, dtype=torch.bool) + + +@dataclass +class ParameterizedLayerMaskSettings: + persistent: bool = False + use_hooks: bool = False + + +class LayerParamMasking: + def __init__(self): + self._mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {} + self._masked_layer_params: Dict[ + str, ModelParameterizedLayer[Module, Parameter] + ] = {} + self._forward_hooks: Dict[str, RemovableHandle] = {} + self._backward_hooks: Dict[str, RemovableHandle] = {} + self._enabled = False + + def add_mask( + self, + layer_param_name: str, + parameterized_layer: ModelParameterizedLayer[Module, Parameter], + init_mask: torch.Tensor = None, + persistent: bool = False, + add_hooks: bool = False, + ): + if layer_param_name in self._masked_layer_params: + raise ValueError(f"Layer param {layer_param_name} already has a mask") + + mask_name = param_mask_name(parameterized_layer.param_name) + + try: + parameterized_layer.layer.get_buffer(mask_name) + except AttributeError: + # add the mask buffer to the layer + parameterized_layer.layer.register_buffer( + mask_name, + torch.ones_like(parameterized_layer.param.data, dtype=torch.bool), + persistent=persistent, + ) + + if init_mask is not None: + parameterized_layer.layer.get_buffer(mask_name).fill_( + setup_mask_for_param(parameterized_layer.param, init_mask) + ) + + self._masked_layer_params[layer_param_name] = parameterized_layer + self._mask_settings[layer_param_name] = ParameterizedLayerMaskSettings( + persistent=persistent, use_hooks=add_hooks + ) + + if add_hooks: + + def _forward_hook_fn(module, input, output): + if not self._enabled: + return output + + mask = module.get_buffer(mask_name) + parameterized_layer.param.data = parameterized_layer.param.data * mask + + return output + + def _backward_hook_fn(gradients): + if not self._enabled: + return + + mask = parameterized_layer.layer.get_buffer(mask_name) + if gradients[0] is not None: + gradients[0] *= mask + + return gradients + + self._forward_hooks[ + layer_param_name + ] = parameterized_layer.layer.register_forward_hook(_forward_hook_fn) + self._backward_hooks[ + layer_param_name + ] = parameterized_layer.param.register_hook(_backward_hook_fn) + + def update_mask( + self, + layer_param_name: str, + mask: torch.Tensor, + ): + parameterized_layer = self._masked_layer_params[layer_param_name] + mask_name = param_mask_name(parameterized_layer.param_name) + mask_tensor = parameterized_layer.layer.get_buffer(mask_name) + mask_tensor.fill_(setup_mask_for_param(parameterized_layer.param, mask)) + + def remove_mask(self, layer_param_name: str): + mask_settings = self._mask_settings[layer_param_name] + parameterized_layer = self._masked_layer_params[layer_param_name] + + if mask_settings.persistent: + parameterized_layer.layer.unregister_buffer( + 
param_mask_name(parameterized_layer.param_name) + ) + + del self._masked_layer_params[layer_param_name] + del self._mask_settings[layer_param_name] + + if mask_settings.use_hooks: + self._forward_hooks[layer_param_name].remove() + self._backward_hooks[layer_param_name].remove() + + del self._forward_hooks[layer_param_name] + del self._backward_hooks[layer_param_name] + + def apply_mask_weight(self, layer_param_name: str): + if not self._enabled: + return + + parameterized_layer = self._masked_layer_params[layer_param_name] + mask_name = param_mask_name(parameterized_layer.param_name) + mask = parameterized_layer.layer.get_buffer(mask_name) + parameterized_layer.param.data = parameterized_layer.param.data * mask + + def apply_mask_gradient(self, layer_param_name: str): + if not self._enabled: + return + + parameterized_layer = self._masked_layer_params[layer_param_name] + mask_name = param_mask_name(parameterized_layer.param_name) + mask = parameterized_layer.layer.get_buffer(mask_name) + + if parameterized_layer.param.grad is not None: + parameterized_layer.param.grad = parameterized_layer.param.grad * mask + + def enable_masks(self): + self._enabled = True + + def disable_masks(self): + self._enabled = False diff --git a/src/sparseml/utils/pytorch/pruning/mask.py b/src/sparseml/utils/pytorch/pruning/mask.py new file mode 100644 index 00000000000..77bdc006189 --- /dev/null +++ b/src/sparseml/utils/pytorch/pruning/mask.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from dataclasses import dataclass +from typing import Callable, Optional + +import torch +from torch import Tensor +from torch.nn.parameter import Parameter + + +__all__ = [ + "PruningMaskCreatorArgs", + "MaskCreatorType", + "CreateMaskCreatorType", + "PruningMaskFactory", + "unstructured_pruning", + "channel_pruning", + "filter_pruning", + "block_pruning", +] + + +@dataclass +class PruningMaskCreatorArgs: + parameter: Parameter + sparsity: float + scores: Tensor + prev_mask: Optional[Tensor] = None + + +MaskCreatorType = Callable[[PruningMaskCreatorArgs], Tensor] +CreateMaskCreatorType = Callable[[str], MaskCreatorType] + + +class PruningMaskFactory: + registry = {} + + @staticmethod + def register(name: str, func: CreateMaskCreatorType): + PruningMaskFactory.registry[name] = func + + @staticmethod + def register_decorator(name: str): + def inner(func: CreateMaskCreatorType): + PruningMaskFactory.registry[name] = func + return func + + return inner + + @staticmethod + def create_mask_creator(mask_structure: str, **kwargs) -> MaskCreatorType: + for pattern, creator in PruningMaskFactory.registry.items(): + if pattern == mask_structure: + return creator(mask_structure=mask_structure, **kwargs) + + try: + if re.match(pattern, mask_structure): + return creator(mask_structure=mask_structure, **kwargs) + except Exception: + pass + + raise ValueError(f"Invalid mask_structure: {mask_structure}") + + +@PruningMaskFactory.register_decorator("unstructured") +def unstructured_pruning(mask_structure: str): + if mask_structure != "unstructured": + raise ValueError(f"Invalid mask_structure: {mask_structure}") + + def _create_mask(args: PruningMaskCreatorArgs) -> Tensor: + prune_elements = int(args.sparsity * args.scores.numel()) + mask = ( + args.prev_mask + if args.prev_mask is not None + else torch.ones_like(args.parameter.data, dtype=torch.bool) + ) + + if prune_elements > 0: + threshold, _ = torch.topk( + args.scores.view(-1), prune_elements, largest=False + ) + mask = (args.scores > threshold[-1]).to(dtype=torch.bool) + else: + mask = torch.ones_like(mask, dtype=torch.bool) + + return mask + + return _create_mask + + +@PruningMaskFactory.register_decorator("channel") +def channel_pruning(mask_structure: str, aggregate: str = "sum"): + if mask_structure != "channel": + raise ValueError(f"Invalid mask_structure: {mask_structure}") + + def _aggregate(tensor, method="sum"): + return getattr(tensor, method)(dim=(1, 2, 3)) + + def _create_mask(args: PruningMaskCreatorArgs) -> Tensor: + prune_channels = int(args.sparsity * args.scores.size(0)) + aggregated_scores = _aggregate(args.scores, aggregate) + _, top_indices = torch.topk(aggregated_scores, prune_channels, largest=False) + mask = torch.ones_like(args.scores, dtype=torch.bool) + mask[top_indices, :, :, :] = 0 + return mask + + return _create_mask + + +@PruningMaskFactory.register_decorator("filter") +def filter_pruning(mask_structure: str, aggregate: str = "sum"): + if mask_structure != "filter": + raise ValueError(f"Invalid mask_structure: {mask_structure}") + + def _aggregate(tensor, method="sum"): + return getattr(tensor, method)(dim=(0, 2, 3)) + + def _create_mask(args: PruningMaskCreatorArgs) -> Tensor: + prune_filters = int(args.sparsity * args.scores.size(1)) + aggregated_scores = _aggregate(args.scores, aggregate) + _, top_indices = torch.topk(aggregated_scores, prune_filters, largest=False) + mask = torch.ones_like(args.scores, dtype=torch.bool) + mask[:, top_indices, :, :] = 0 + return mask + + return _create_mask + + 
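As a quick illustration of the factory above, the sketch below creates the registered "unstructured" mask creator and applies it to magnitude scores. The parameter shape and sparsity are arbitrary, and the import path reflects where this file lives at this point in the series (it is relocated under sparseml.modifiers in a later commit).

import torch
from torch.nn import Parameter

from sparseml.utils.pytorch.pruning.mask import (
    PruningMaskCreatorArgs,
    PruningMaskFactory,
)

param = Parameter(torch.randn(64, 64))
creator = PruningMaskFactory.create_mask_creator("unstructured")

args = PruningMaskCreatorArgs(
    parameter=param,
    sparsity=0.5,              # fraction of elements to prune
    scores=param.data.abs(),   # magnitude scores
)
mask = creator(args)           # boolean tensor, True where the element is kept
print(mask.float().mean())     # roughly 0.5 with the settings above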
+@PruningMaskFactory.register_decorator("^block_.*") +def block_pruning(mask_structure: str, aggregate: str = "sum"): + pattern = re.compile(r"^block_(.*)") + match = pattern.search(mask_structure) + + if not match: + raise ValueError(f"invalid block mask type {mask_structure}") + + block_dims = list(map(int, match.group(1).split(","))) + + def _aggregate_block(block, method="sum"): + return getattr(block, method)() + + def _create_mask(args: PruningMaskCreatorArgs) -> Tensor: + block_view = args.scores + for dim, size in enumerate(block_dims): + block_view = block_view.unfold(dimension=dim, size=size, step=size) + block_sums = _aggregate_block(block_view, aggregate) + prune_blocks = int(args.sparsity * block_sums.numel()) + threshold, _ = torch.topk(block_sums.view(-1), prune_blocks, largest=False) + mask = (block_sums > threshold[-1]).float().unsqueeze(-1) + for size in block_dims: + mask = mask.repeat_interleave(size, dim=-1) + return mask.to(dtype=torch.bool) + + return _create_mask From f04ca6f32ff743ceb371bd055509837c4e28c43c Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Sun, 10 Sep 2023 17:52:47 -0400 Subject: [PATCH 07/27] knowledge distillation implementation --- src/sparseml/core/event.py | 15 +- src/sparseml/core/session.py | 6 + src/sparseml/core/state.py | 7 + .../modifiers/distillation/output/base.py | 33 ++ .../modifiers/distillation/output/pytorch.py | 149 +++++++ .../modifiers/distillation/utils/__init__.py | 13 + .../distillation/utils/pytorch/__init__.py | 16 + .../distillation/utils/pytorch/kd_factory.py | 401 ++++++++++++++++++ .../distillation/utils/pytorch/kd_wrapper.py | 120 ++++++ .../modifiers/experimental/__init__.py | 13 + .../modifiers/pruning/constant/pytorch.py | 2 +- .../modifiers/pruning/magnitude/pytorch.py | 14 +- .../modifiers/pruning/utils/__init__.py | 13 + .../pruning/utils/pytorch/__init__.py | 16 + .../pruning/utils/pytorch}/layer_mask.py | 0 .../pruning/utils/pytorch/mask_factory.py} | 0 16 files changed, 808 insertions(+), 10 deletions(-) create mode 100644 src/sparseml/modifiers/distillation/output/base.py create mode 100644 src/sparseml/modifiers/distillation/output/pytorch.py create mode 100644 src/sparseml/modifiers/distillation/utils/__init__.py create mode 100644 src/sparseml/modifiers/distillation/utils/pytorch/__init__.py create mode 100644 src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py create mode 100644 src/sparseml/modifiers/distillation/utils/pytorch/kd_wrapper.py create mode 100644 src/sparseml/modifiers/experimental/__init__.py create mode 100644 src/sparseml/modifiers/pruning/utils/__init__.py create mode 100644 src/sparseml/modifiers/pruning/utils/pytorch/__init__.py rename src/sparseml/{utils/pytorch/pruning => modifiers/pruning/utils/pytorch}/layer_mask.py (100%) rename src/sparseml/{utils/pytorch/pruning/mask.py => modifiers/pruning/utils/pytorch/mask_factory.py} (100%) diff --git a/src/sparseml/core/event.py b/src/sparseml/core/event.py index 9875d0bf022..afa21372c51 100644 --- a/src/sparseml/core/event.py +++ b/src/sparseml/core/event.py @@ -16,7 +16,7 @@ from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import List +from typing import List, Optional __all__ = [ @@ -129,6 +129,19 @@ def current_index(self, value: float): else self.global_step * self.batches_per_step ) + def should_update( + self, start: Optional[float], end: Optional[float], update: float + ): + current = self.current_index + + if start is not None and current < start: + return False + + if end is 
not None and current > end: + return False + + return update is None or update <= 0.0 or current % update < 1e-10 + def new_instance(self, **kwargs) -> "Event": instance = deepcopy(self) for key, value in kwargs.items(): diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 5e21bb0f250..505cab7060d 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -117,6 +117,7 @@ def initialize( recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, + teacher_model: Any = None, optimizer: Any = None, attach_optim_callbacks: bool = True, train_data: Any = None, @@ -138,6 +139,7 @@ def initialize( self.state.update_framework(framework) self.state.update_recipe(recipe, recipe_stage, recipe_args) self.state.update_model(model) + self.state.update_teacher_model(teacher_model) self.state.update_optimizer(optimizer, attach_optim_callbacks) self.state.update_data(train_data, val_data, test_data, calib_data, copy_data) self.state.update_start(start, steps_per_epoch, batches_per_step) @@ -342,6 +344,7 @@ def initialize( recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, + teacher_model: Any = None, optimizer: Any = None, attach_optim_callbacks: bool = True, train_data: Any = None, @@ -360,6 +363,7 @@ def initialize( recipe_stage=recipe_stage, recipe_args=recipe_args, model=model, + teacher_model=teacher_model, optimizer=optimizer, attach_optim_callbacks=attach_optim_callbacks, train_data=train_data, @@ -384,6 +388,7 @@ def apply( recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, + teacher_model: Any = None, train_data: Any = None, val_data: Any = None, test_data: Any = None, @@ -400,6 +405,7 @@ def apply( recipe_stage=recipe_stage, recipe_args=recipe_args, model=model, + teacher_model=teacher_model, train_data=train_data, val_data=val_data, test_data=test_data, diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index 2cd9ca79657..a7c4f8c15e4 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -56,6 +56,7 @@ class State: loggers = Field(default_factory=list) framework: Framework = None model: ModifiableModel = None + teacher_model: ModifiableModel = None optimizer: ModifiableOptimizer = None optim_wrapped: bool = None loss = None @@ -105,6 +106,12 @@ def update_model(self, model: Any): self.model = ModifiableModel(framework=self.framework, model=model) + def update_teacher_model(self, model: Any): + if self.framework is None: + raise RuntimeError("framework must be set before updating model") + + self.teacher_model = ModifiableModel(framework=self.framework, model=model) + def update_optimizer(self, optimizer: Any, attach_callbacks: bool = True): if self.framework is None: raise RuntimeError("framework must be set before updating optimizer") diff --git a/src/sparseml/modifiers/distillation/output/base.py b/src/sparseml/modifiers/distillation/output/base.py new file mode 100644 index 00000000000..14770a02811 --- /dev/null +++ b/src/sparseml/modifiers/distillation/output/base.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple, Union + +from sparseml.core import Modifier, State + + +__all__ = ["OutputDistillationModifier"] + + +class OutputDistillationModifier(Modifier): + targets: Union[str, List[Union[str, Tuple[str, str]]]] + projection: str = None + projection_args: Dict[str, Any] = None + transforms: Union[str, List[str]] = "softmax" + transforms_args: Union[Dict[str, Any], List[Dict[str, Any]]] = None + comparison: str = "kl_divergence" + comparison_args: Dict[str, Any] = None + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier diff --git a/src/sparseml/modifiers/distillation/output/pytorch.py b/src/sparseml/modifiers/distillation/output/pytorch.py new file mode 100644 index 00000000000..ca96486a87d --- /dev/null +++ b/src/sparseml/modifiers/distillation/output/pytorch.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
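For context, a sketch of how this modifier might be declared. The target layer names are hypothetical; the transform and comparison strings are the defaults declared on the class and are resolved by the PyTorch implementation that follows. The teacher network itself is supplied separately through the teacher_model argument added to the session in this commit.

from sparseml.modifiers.distillation.output.base import OutputDistillationModifier

# hypothetical targets; a (student, teacher) tuple pairs differently named layers
modifier = OutputDistillationModifier(
    targets=[("model.decoder", "teacher.decoder")],
    transforms="softmax",
    transforms_args={"temperature": 2.0},
    comparison="kl_divergence",
    start=0.0,
    end=2.0,
)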
+ +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch.nn import Module + +from sparseml.core import Event, EventType, State +from sparseml.modifiers.distillation.output.base import OutputDistillationModifier +from sparseml.modifiers.distillation.utils.pytorch import KDFactory, KDModuleWrapper + + +__all__ = ["OutputDistillationModifierPyTorch"] + + +class OutputDistillationModifierPyTorch(OutputDistillationModifier): + _wrappers: Dict[str, KDModuleWrapper] = None + + def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + if ( + state.framework is None + or state.model is None + or state.teacher_model is None + ): + return False + + self._wrappers = {} + + for target in ( + self.targets if isinstance(self.targets, list) else [self.targets] + ): + if isinstance(target, tuple): + model_target, teacher_target = target + else: + model_target, teacher_target = target, target + + model_layers = state.model.get_layers(model_target) + teacher_layers = state.teacher_model.get_layers(teacher_target) + + if len(model_layers) < 1: + raise ValueError(f"no model layers found for target {target}") + + if len(model_layers) != len(teacher_layers): + raise ValueError( + f"model and teacher model layers for target {target} do not match" + ) + + for (key, student_layer), teacher_layer in zip( + model_layers.items(), teacher_layers.values() + ): + wrapper = self._create_wrapper(student_layer, teacher_layer, state) + state.model.set_layer(key, wrapper) + self._wrappers[key] = wrapper + + return True + + def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + for key, wrapper in self._wrappers.items(): + state.model.set_layer(key, wrapper.student_layer) + del wrapper + + return True + + def on_start(self, state: State, event: Event, **kwargs): + for wrapper in self._wrappers.values(): + wrapper.kd_enabled = True + + def on_update(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.LOSS_CALCULATED and event.should_update( + self.start, self.end, self.update + ): + comparisons = [ + wrapper.kd_last_comparison for wrapper in self._wrappers.values() + ] + state.loss = state.loss + torch.Stack(comparisons).mean() + + def on_end(self, state: State, event: Event, **kwargs): + for wrapper in self._wrappers.values(): + wrapper.kd_enabled = False + + def _create_wrapper( + self, student_layer: Module, teacher_layer: Module, state: State + ) -> KDModuleWrapper: + projections = ( + KDFactory.create_projection( + self.projection, + student_layer, + teacher_layer, + state, + **(self.projection_args or {}), + ) + if self.projection + else None + ) + comparison = KDFactory.create_comparison( + self.comparison, + student_layer, + teacher_layer, + state, + **(self.comparison_args or {}), + ) + + transforms = [] + if self.transforms: + tmp_transforms = ( + self.transforms + if isinstance(self.transforms, list) + else [self.transforms] + ) + tmp_transform_args = [ + args + for args in ( + self.transforms_args + if isinstance(self.transforms_args, list) + else [self.transforms_args if self.transforms_args else {}] + ) + for _ in range(len(tmp_transforms)) + ] + + for transform, transform_args in zip(tmp_transforms, tmp_transform_args): + transforms.append( + KDFactory.create_transform( + transform, + student_layer, + teacher_layer, + state, + **transform_args, + ) + ) + + return KDModuleWrapper( + student_layer=student_layer, + teacher_layer=teacher_layer, + projections=projections, + transforms=transforms, + comparison=comparison, + ) diff --git 
a/src/sparseml/modifiers/distillation/utils/__init__.py b/src/sparseml/modifiers/distillation/utils/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/distillation/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py b/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py new file mode 100644 index 00000000000..4fb62d86716 --- /dev/null +++ b/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .kd_factory import * +from .kd_wrapper import * diff --git a/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py b/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py new file mode 100644 index 00000000000..23a321b02ad --- /dev/null +++ b/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py @@ -0,0 +1,401 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from typing import Callable, Dict, List, Sequence, Tuple, Union + +import torch +import torch.nn.functional as TF +from torch import Tensor +from torch.nn import Module + +from sparseml.core import State + + +__all__ = [ + "TensorOrCollectionType", + "ProjectionFuncType", + "CreateProjectionFuncType", + "TransformFuncType", + "CreateTransformFuncType", + "ComparisonFuncType", + "CreateComparisonFuncType", + "KDFactory", + "recursive_apply", + "recursive_combine", + "identity_transform", + "softmax_transform", + "log_softmax_transform", + "normalize_transform", + "l1_comparison", + "l2_comparison", + "inner_product_comparison", + "cosine_similarity_comparison", + "kl_divergence_comparison", + "cross_entropy_comparison", +] + + +TensorOrCollectionType = Union[Tensor, Sequence[Tensor], Dict[str, Tensor]] +ProjectionFuncType = Callable[ + [TensorOrCollectionType, TensorOrCollectionType], TensorOrCollectionType +] +CreateProjectionFuncType = Callable[ + [str, Module, Module, State], Tuple[ProjectionFuncType, ProjectionFuncType] +] +TransformFuncType = Callable[[TensorOrCollectionType], TensorOrCollectionType] +CreateTransformFuncType = Callable[[str, Module, Module, State], TransformFuncType] +ComparisonFuncType = Callable[ + [TensorOrCollectionType, TensorOrCollectionType], TensorOrCollectionType +] +CreateComparisonFuncType = Callable[[str, Module, Module, State], ComparisonFuncType] + + +class KDFactory: + registry_projections: Dict[str, CreateProjectionFuncType] = {} + registry_transforms: Dict[str, CreateTransformFuncType] = {} + registry_comparisons: Dict[str, CreateComparisonFuncType] = {} + + @staticmethod + def register_projection(name: str, func: CreateProjectionFuncType): + KDFactory.registry_projections[name] = func + + @staticmethod + def register_projection_decorator(name: str): + def inner(func: CreateProjectionFuncType): + KDFactory.registry_projections[name] = func + return func + + return inner + + @staticmethod + def create_projection( + name: str, student_layer: Module, teacher_layer: Module, state: State, **kwargs + ) -> Tuple[ProjectionFuncType, ProjectionFuncType]: + for pattern, creator in KDFactory.registry_projections: + match = pattern == name + + if not match: + try: + match = re.match(pattern, name) + except Exception: + pass + + if match: + return creator( + name=name, + student_layer=student_layer, + teacher_layer=teacher_layer, + state=state, + **kwargs, + ) + + raise ValueError(f"Invalid projection name: {name}") + + @staticmethod + def register_transform(name: str, func: CreateTransformFuncType): + KDFactory.registry_transforms[name] = func + + @staticmethod + def register_transform_decorator(name: str): + def inner(func: CreateTransformFuncType): + KDFactory.registry_transforms[name] = func + return func + + return inner + + @staticmethod + def create_transform( + name: str, + student_layer: Module, + teacher_layer: Module, + state: State, + **kwargs, + ) -> TransformFuncType: + + for pattern, creator in KDFactory.registry_transforms.items(): + match = pattern == name + + if not match: + try: + match = re.match(pattern, name) + except Exception: + pass + + if match: + return creator( + name=name, + student_layer=student_layer, + teacher_layer=teacher_layer, + state=state, + **kwargs, + ) + + raise ValueError(f"Invalid transform name: {name}") + + @staticmethod + def register_comparison(name: str, func): + KDFactory.registry_comparisons[name] = func + + @staticmethod + def register_comparison_decorator(name: str): + def inner(func): + 
KDFactory.registry_comparisons[name] = func + return func + + return inner + + @staticmethod + def create_comparison( + name: str, student_layer: Module, teacher_layer: Module, state: State, **kwargs + ) -> ComparisonFuncType: + for pattern, creator in KDFactory.registry_comparisons.items(): + match = pattern == name + + if not match: + try: + match = re.match(pattern, name) + except Exception: + pass + + if match: + return creator( + name=name, + student_layer=student_layer, + teacher_layer=teacher_layer, + state=state, + **kwargs, + ) + + raise ValueError(f"Invalid comparison name: {name}") + + +def recursive_apply( + val: TensorOrCollectionType, + func: Callable[[Tensor], Tensor], +) -> TensorOrCollectionType: + if isinstance(val, Tensor): + return func(val) + + if isinstance(val, Sequence): + return [recursive_apply(item, func) for item in val] + + if isinstance(val, dict): + return {key: recursive_apply(item, func) for key, item in val.items()} + + raise ValueError(f"Unsupported type for recursive_apply: {type(val)}") + + +def recursive_combine( + val_one: TensorOrCollectionType, + val_two: TensorOrCollectionType, + func: Callable[[Tensor, Tensor], Tensor], +): + if isinstance(val_one, Tensor): + return func(val_one, val_two) + + if isinstance(val_one, Sequence): + return [ + recursive_combine(item_one, item_two, func) + for item_one, item_two in zip(val_one, val_two) + ] + + if isinstance(val_one, dict): + return { + key: recursive_combine(val_one[key], val_two[key], func) + for key in val_one.keys() + } + + raise ValueError(f"Unsupported type for recursive_combine: {type(val_one)}") + + +@KDFactory.register_transform_decorator("identity") +def identity_transform(name: str, **kwargs): + if name != "identity": + raise ValueError(f"Invalid transform name: {name}") + + def _create_transform(val: TensorOrCollectionType) -> TensorOrCollectionType: + return val + + return _create_transform + + +@KDFactory.register_transform_decorator("softmax") +def softmax_transform(name: str, temperature: float = 1.0, dim: int = -1, **kwargs): + if name != "softmax": + raise ValueError(f"Invalid transform name: {name}") + + def _softmax(val: Tensor) -> Tensor: + val = val / temperature + + return torch.softmax(val, dim=dim) + + def _create_transform(val: TensorOrCollectionType) -> TensorOrCollectionType: + return recursive_apply(val, _softmax) + + return _create_transform + + +@KDFactory.register_transform_decorator("log_softmax") +def log_softmax_transform(name: str, temperature: float = 1.0, dim: int = -1, **kwargs): + if name != "log_softmax": + raise ValueError(f"Invalid transform name: {name}") + + def _log_softmax(val: Tensor) -> Tensor: + val = val / temperature + + return torch.log_softmax(val, dim=dim) + + def _create_transform(val: TensorOrCollectionType) -> TensorOrCollectionType: + return recursive_apply(val, _log_softmax) + + return _create_transform + + +@KDFactory.register_transform_decorator("normalize") +def normalize_transform( + name: str, + p: float = 1, + dim: int = -1, + eps: float = 1e-12, + mean: bool = False, + std: bool = False, + **kwargs, +): + if name != "normalize": + raise ValueError(f"Invalid transform name: {name}") + + def _normalize(val: Tensor) -> Tensor: + out = TF.normalize(val, p=p, dim=dim, eps=eps) + + if mean: + out = out - out.mean(dim=dim, keepdim=True) + + if std: + out = out / out.std(dim=dim, keepdim=True) + + return out + + def _create_transform(val: TensorOrCollectionType) -> TensorOrCollectionType: + return recursive_apply(val, _normalize) + + return 
_create_transform + + +@KDFactory.register_comparison_decorator("l1_distance") +def l1_comparison(name: str, dim: int = -1, **kwargs): + if name != "l1_distance": + raise ValueError(f"Invalid comparison name: {name}") + + def _l1(val_one: Tensor, val_two: Tensor) -> Tensor: + return torch.sum(torch.abs(val_one - val_two), dim=dim) + + def _create_comparison( + val_one: TensorOrCollectionType, val_two: TensorOrCollectionType + ) -> TensorOrCollectionType: + return recursive_combine(val_one, val_two, _l1) + + return _create_comparison + + +@KDFactory.register_comparison_decorator("l2_distance") +def l2_comparison(name: str, dim: int = -1, **kwargs): + if name != "l2_distance": + raise ValueError(f"Invalid comparison name: {name}") + + def _l2(val_one: Tensor, val_two: Tensor) -> Tensor: + return torch.sum((val_one - val_two) ** 2, dim=dim) + + def _create_comparison( + val_one: TensorOrCollectionType, val_two: TensorOrCollectionType + ) -> TensorOrCollectionType: + return recursive_combine(val_one, val_two, _l2) + + return _create_comparison + + +@KDFactory.register_comparison_decorator("inner_product") +def inner_product_comparison(name: str, dim: int = -1, **kwargs): + if name != "inner_product": + raise ValueError(f"Invalid comparison name: {name}") + + def _inner_product(val_one: Tensor, val_two: Tensor) -> Tensor: + return torch.sum(val_one * val_two, dim=dim) + + def _create_comparison( + val_one: TensorOrCollectionType, val_two: TensorOrCollectionType + ) -> TensorOrCollectionType: + return recursive_combine(val_one, val_two, _inner_product) + + return _create_comparison + + +@KDFactory.register_comparison_decorator("cosine_similarity") +def cosine_similarity_comparison(name: str, dim: int = -1, **kwargs): + if name != "cosine_similarity": + raise ValueError(f"Invalid comparison name: {name}") + + def _cosine_similarity(val_one: Tensor, val_two: Tensor) -> Tensor: + return torch.sum(val_one * val_two, dim=dim) / ( + torch.norm(val_one, dim=dim) * torch.norm(val_two, dim=dim) + ) + + def _create_comparison( + val_one: TensorOrCollectionType, val_two: TensorOrCollectionType + ) -> TensorOrCollectionType: + return recursive_combine(val_one, val_two, _cosine_similarity) + + return _create_comparison + + +@KDFactory.register_comparison_decorator("kl_divergence") +def kl_divergence_comparison( + name: str, dim: int = -1, temperature: float = 1.0, **kwargs +): + if name != "kl_divergence": + raise ValueError(f"Invalid comparison name: {name}") + + def _kl_divergence(val_one: Tensor, val_two: Tensor) -> Tensor: + val_one = val_one / temperature + val_two = val_two / temperature + + return torch.sum(val_one * torch.log(val_one / val_two), dim=dim) + + def _create_comparison( + val_one: TensorOrCollectionType, val_two: TensorOrCollectionType + ) -> TensorOrCollectionType: + return recursive_combine(val_one, val_two, _kl_divergence) + + return _create_comparison + + +@KDFactory.register_comparison_decorator("cross_entropy") +def cross_entropy_comparison( + name: str, temperature: float = 1.0, reduction: str = "none", **kwargs +): + if name != "cross_entropy": + raise ValueError(f"Invalid projection name: {name}") + + def _cross_entropy(val_one: Tensor, val_two: Tensor) -> Tensor: + val_one = val_one / temperature + val_two = val_two / temperature + + return TF.cross_entropy(val_one, val_two, reduction=reduction) + + def _create_projection( + val_one: TensorOrCollectionType, val_two: TensorOrCollectionType + ) -> TensorOrCollectionType: + return recursive_combine(val_one, val_two, _cross_entropy) 
+ + return _create_projection diff --git a/src/sparseml/modifiers/distillation/utils/pytorch/kd_wrapper.py b/src/sparseml/modifiers/distillation/utils/pytorch/kd_wrapper.py new file mode 100644 index 00000000000..17ae6aa144e --- /dev/null +++ b/src/sparseml/modifiers/distillation/utils/pytorch/kd_wrapper.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Sequence, Tuple + +import torch +from torch.nn import Module + +from sparseml.modifiers.distillation.utils.pytorch.kd_factory import ( + ComparisonFuncType, + ProjectionFuncType, + TransformFuncType, + recursive_apply, +) + + +__all__ = ["KDModuleWrapper"] + + +class KDModuleWrapper(Module): + def __init__( + self, + student_layer: Module, + teacher_layer: Module, + projections: Optional[Tuple[ProjectionFuncType, ProjectionFuncType]], + transforms: Optional[List[TransformFuncType]], + comparison: ComparisonFuncType, + ): + super(KDModuleWrapper, self).__init__() + + self.kd_student_layer = student_layer + self.kd_teacher_layer = teacher_layer + self.kd_student_projection = projections[0] if projections is not None else None + self.kd_teacher_projection = projections[1] if projections is not None else None + self.kd_transforms = transforms + self.kd_comparison = comparison + self.kd_enabled = False + self.kd_last_comparison = None + self._init_called = True # make sure this is last property to be set + + # def __getattr__(self, name): + # if name.startswith("kd_"): + # return getattr(self, name) + # + # return getattr(self.student_layer, name) + # + # def __setattr__(self, name, value): + # if "_init_called" not in self.__dict__: + # super().__setattr__(name, value) + # elif name.startswith("kd_"): + # super().__setattr__(name, value) + # elif hasattr(self.student_layer, name): + # setattr(self.student_layer, name, value) + # else: + # super().__setattr__(name, value) + + def forward(self, *args, **kwargs): + if not self.kd_enabled: + return self.kd_student_layer(*args, **kwargs) + + org_output = self.kd_student_layer(*args, **kwargs) + student_output = org_output + + with torch.no_grad(): + teacher_output = self.kd_teacher_layer(*args, **kwargs) + + if self.kd_student_projections is not None: + for projection in self.kd_student_projections: + student_output = projection(student_output, teacher_output) + + if self.kd_teacher_projections is not None: + for projection in self.kd_teacher_projections: + teacher_output = projection(teacher_output, student_output) + + if self.kd_student_transforms is not None: + for transform in self.kd_student_transforms: + student_output = transform(student_output) + + if self.kd_teacher_transforms is not None: + for transform in self.kd_teacher_transforms: + teacher_output = transform(teacher_output) + + comp = self.kd_comparison(student_output, teacher_output) + comp = recursive_apply(comp, lambda x: x.mean()) + comp = ( + comp + if isinstance(comp, float) + else torch.stack( + comp if 
isinstance(comp, Sequence) else list(comp.values()) + ).sum() + ) + + self.kd_last_comparison = comp + + return org_output + + def state_dict(self, destination=None, prefix="", keep_vars=False, **kwargs): + return self.student_layer.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars, **kwargs + ) + + def load_state_dict(self, state_dict, strict=True): + return self.student_layer.load_state_dict(state_dict, strict=strict) + + def _named_members(self, get_members_fn, prefix="", recurse=True): + for name, module in self.student_layer._named_members( + get_members_fn, prefix=prefix, recurse=recurse + ): + yield name, module diff --git a/src/sparseml/modifiers/experimental/__init__.py b/src/sparseml/modifiers/experimental/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/experimental/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/constant/pytorch.py b/src/sparseml/modifiers/pruning/constant/pytorch.py index 2fd3e94c22c..d3ed5840097 100644 --- a/src/sparseml/modifiers/pruning/constant/pytorch.py +++ b/src/sparseml/modifiers/pruning/constant/pytorch.py @@ -16,7 +16,7 @@ from sparseml.core import Event, EventType, ModelParameterizedLayer, State from sparseml.modifiers.pruning.constant.base import ConstantPruningModifier -from sparseml.utils.pytorch.pruning import LayerParamMasking +from sparseml.modifiers.pruning.utils.pytorch import LayerParamMasking class ConstantPruningModifierPyTorch(ConstantPruningModifier, LayerParamMasking): diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index e8d6ac3bd0a..bc61d3267ec 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -21,7 +21,7 @@ SchedulerCalculationType, ) from sparseml.modifiers.pruning.magnitude.base import MagnitudePruningModifier -from sparseml.utils.pytorch.pruning import ( +from sparseml.modifiers.pruning.utils.pytorch import ( LayerParamMasking, MaskCreatorType, PruningMaskCreatorArgs, @@ -38,6 +38,9 @@ class MagnitudePruningModifierPyTorch(MagnitudePruningModifier, LayerParamMaskin _current_sparsity: float = None def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + if self.apply_globally: + raise NotImplementedError("global pruning not implemented yet for PyTorch") + if "save_masks" in kwargs: self._save_masks = kwargs["save_masks"] if "use_hooks" in kwargs: @@ -113,15 +116,10 @@ def on_update(self, state: State, event: Event, **kwargs): ) ) self.update_mask(layer_param_name, mask) - - if self._use_hooks: - # hooks are used to update, so nothing to do here - return - - if event.type_ == EventType.OPTIM_PRE_STEP: + elif event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks: for layer_param_name, _ in self._parameterized_layers.items(): 
self.apply_mask_gradient(layer_param_name) - elif event.type_ == EventType.OPTIM_POST_STEP: + elif event.type_ == EventType.OPTIM_POST_STEP and not self._use_hooks: for layer_param_name, _ in self._parameterized_layers.items(): self.apply_mask_weight(layer_param_name) diff --git a/src/sparseml/modifiers/pruning/utils/__init__.py b/src/sparseml/modifiers/pruning/utils/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/pruning/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py b/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py new file mode 100644 index 00000000000..a7bb161fee9 --- /dev/null +++ b/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .layer_mask import * +from .mask_factory import * diff --git a/src/sparseml/utils/pytorch/pruning/layer_mask.py b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py similarity index 100% rename from src/sparseml/utils/pytorch/pruning/layer_mask.py rename to src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py diff --git a/src/sparseml/utils/pytorch/pruning/mask.py b/src/sparseml/modifiers/pruning/utils/pytorch/mask_factory.py similarity index 100% rename from src/sparseml/utils/pytorch/pruning/mask.py rename to src/sparseml/modifiers/pruning/utils/pytorch/mask_factory.py From c745492070112dca3092417d5fbe58c5124ea7e8 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Thu, 14 Sep 2023 10:31:52 -0400 Subject: [PATCH 08/27] fix import errors and multiframework inits --- src/sparseml/core/__init__.py | 10 +--------- src/sparseml/core/data/__init__.py | 2 +- src/sparseml/core/framework.py | 2 +- src/sparseml/core/model/__init__.py | 2 +- src/sparseml/core/model/base.py | 3 +++ src/sparseml/core/model/pytorch.py | 4 ++++ src/sparseml/core/modifier/__init__.py | 7 ++----- src/sparseml/core/modifier/base.py | 12 ++++-------- src/sparseml/core/modifier/factory.py | 5 +---- src/sparseml/core/modifier/modifier.py | 22 ++++++++++------------ src/sparseml/core/modifier/stage.py | 16 +++++++++------- src/sparseml/core/optimizer/__init__.py | 2 +- src/sparseml/core/optimizer/base.py | 3 +++ src/sparseml/core/optimizer/pytorch.py | 3 +++ src/sparseml/core/recipe/__init__.py | 6 +----- src/sparseml/core/recipe/base.py | 1 - src/sparseml/core/recipe/modifier.py | 4 ++-- src/sparseml/core/session.py | 14 ++++++-------- src/sparseml/core/state.py | 3 +++ 19 files changed, 56 insertions(+), 65 deletions(-) diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py index fc3f40c71ff..bc26ded926c 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -12,12 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .data import * -from .event import * -from .framework import * -from .model import * -from .modifier import * -from .optimizer import * -from .recipe import * -from .session import * -from .state import * +from .session import * \ No newline at end of file diff --git a/src/sparseml/core/data/__init__.py b/src/sparseml/core/data/__init__.py index 87930811c41..1101a7fa8ea 100644 --- a/src/sparseml/core/data/__init__.py +++ b/src/sparseml/core/data/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import * +from .base import ModifiableData diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py index d4f6ddcaebd..c7a67fe85cd 100644 --- a/src/sparseml/core/framework.py +++ b/src/sparseml/core/framework.py @@ -77,7 +77,7 @@ def __new__( if cls is MultiFrameworkObject: raise TypeError("MultiFrameworkObject cannot be instantiated directly") - instance = super(MultiFrameworkObject, cls).__new__(cls, **kwargs) + instance = super(MultiFrameworkObject, cls).__new__(cls) package = instance.__class__.__module__.rsplit(".", 1)[0] class_name = instance.__class__.__name__ diff --git a/src/sparseml/core/model/__init__.py b/src/sparseml/core/model/__init__.py index 87930811c41..7a2c12e5d45 100644 --- a/src/sparseml/core/model/__init__.py +++ b/src/sparseml/core/model/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import * +from .base import ModifiableModel, ModelParameterizedLayer diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 0a0fdb4084d..96f2fb789d6 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -38,6 +38,9 @@ class ModelParameterizedLayer(Generic[LT, PT]): class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject): model: MT = None + def __init__(self, framework=None, model=None): + self.model = model + def get_layers_params( self, targets: Union[str, List[str]] ) -> Dict[str, ModelParameterizedLayer[LT, PT]]: diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index fde46468f8e..b394f3a753f 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -31,6 +31,10 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): + + def __init__(self, framework=None, model=None): + super().__init__(framework=framework, model=model) + def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, Module]: return get_layers(targets, self.model) diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py index 6405fb2b97d..af25a2b2db2 100644 --- a/src/sparseml/core/modifier/__init__.py +++ b/src/sparseml/core/modifier/__init__.py @@ -12,8 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
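To illustrate what these constructor and __new__ fixes enable, a small sketch of building a framework-specific model wrapper. The Framework member name is an assumption for illustration; the key point is that instantiating the generic class with a framework is expected to resolve to the PyTorch subclass.

import torch

from sparseml.core.framework import Framework
from sparseml.core.model import ModifiableModel

net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# MultiFrameworkObject.__new__ swaps in the framework-specific subclass,
# so this should produce a ModifiableModelPyTorch instance
# (assuming Framework.pytorch is the PyTorch enum member)
model = ModifiableModel(framework=Framework.pytorch, model=net)
print(type(model).__name__)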
-from .base import * -from .factory import * -from .modifier import * -from .recipe import * -from .stage import * +from .stage import StageModifiers +from .factory import ModifierFactory \ No newline at end of file diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py index 7eb5fd3f667..cbdf2a002b6 100644 --- a/src/sparseml/core/modifier/base.py +++ b/src/sparseml/core/modifier/base.py @@ -15,10 +15,6 @@ from abc import ABC, abstractmethod -from sparseml.core.event import Event -from sparseml.core.state import State - - __all__ = ["ModifierInterface"] @@ -39,17 +35,17 @@ def calculate_end(self) -> float: raise NotImplementedError() @abstractmethod - def pre_initialize_structure(self, state: State, **kwargs): + def pre_initialize_structure(self, state: "State", **kwargs): raise NotImplementedError() @abstractmethod - def initialize(self, state: State, **kwargs): + def initialize(self, state: "State", **kwargs): raise NotImplementedError() @abstractmethod - def finalize(self, state: State, **kwargs): + def finalize(self, state: "State", **kwargs): raise NotImplementedError() @abstractmethod - def update_event(self, state: State, event: Event, **kwargs): + def update_event(self, state: "State", event: "Event", **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/factory.py b/src/sparseml/core/modifier/factory.py index 018e2089f58..704021fd087 100644 --- a/src/sparseml/core/modifier/factory.py +++ b/src/sparseml/core/modifier/factory.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sparseml.core.framework import Framework -from sparseml.core.modifier.modifier import Modifier - __all__ = ["ModifierFactory"] @@ -25,5 +22,5 @@ def refresh(): raise NotImplementedError() @staticmethod - def create(type_: str, framework: Framework, **kwargs) -> Modifier: + def create(type_: str, framework: "Framework", **kwargs) -> "Modifier": raise NotImplementedError() diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 750fb58a7e2..6c1c6d92df6 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -20,8 +20,6 @@ from sparseml.core.event import Event, EventType from sparseml.core.framework import MultiFrameworkObject from sparseml.core.modifier.base import ModifierInterface -from sparseml.core.state import State - __all__ = ["Modifier"] @@ -49,11 +47,11 @@ def calculate_start(self) -> float: def calculate_end(self) -> float: return self.end if self.end is not None else -1 - def pre_initialize_structure(self, state: State, **kwargs): + def pre_initialize_structure(self, state: "State", **kwargs): self.on_initialize_structure(state, **kwargs) self._initialized_structure = True - def initialize(self, state: State, **kwargs): + def initialize(self, state: "State", **kwargs): if self._initialized: return @@ -77,7 +75,7 @@ def initialize(self, state: State, **kwargs): self.on_start(state, state.start_event, **kwargs) self._started = True - def finalize(self, state: State, **kwargs): + def finalize(self, state: "State", **kwargs): if self._finalized: return @@ -94,7 +92,7 @@ def finalize(self, state: State, **kwargs): self._finalized = finalized - def update_event(self, state: State, event: Event, **kwargs): + def update_event(self, state: "State", event: Event, **kwargs): if not self._initialized: raise RuntimeError("cannot update an uninitialized modifier") @@ -138,20 +136,20 @@ def 
should_end(self, event: Event): return self.end is not None and current >= self.end - def on_initialize_structure(self, state: State, **kwargs): + def on_initialize_structure(self, state: "State", **kwargs): raise NotImplementedError() - def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + def on_initialize(self, state: "State", event: Event, **kwargs) -> bool: raise NotImplementedError() - def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + def on_finalize(self, state: "State", event: Event, **kwargs) -> bool: raise NotImplementedError() - def on_start(self, state: State, event: Event, **kwargs): + def on_start(self, state: "State", event: Event, **kwargs): raise NotImplementedError() - def on_update(self, state: State, event: Event, **kwargs): + def on_update(self, state: "State", event: Event, **kwargs): raise NotImplementedError() - def on_end(self, state: State, event: Event, **kwargs): + def on_end(self, state: "State", event: Event, **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index caa2e81c5f6..c27c4d2db52 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -18,12 +18,14 @@ from pydantic import BaseModel, Field from sparseml.core.modifier.base import ModifierInterface -from sparseml.core.modifier.modifier import Modifier -from sparseml.core.state import Event, State + +__all__ = [ + "StageModifier" +] class StageModifiers(ModifierInterface, BaseModel): - modifiers: List[Modifier] = Field(default_factory=list) + modifiers: List["Modifier"] = Field(default_factory=list) index: int = None group: str = None @@ -47,21 +49,21 @@ def calculate_end(self) -> float: mod.calculate_end() for mod in self.modifiers if mod.calculate_end() >= 0 ) - def pre_initialize_structure(self, state: State, **kwargs): + def pre_initialize_structure(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.pre_initialize_structure(state, **kwargs) self._initialized_structure = True - def initialize(self, state: State, **kwargs): + def initialize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.initialize(state, **kwargs) self._initialized = True - def finalize(self, state: State, **kwargs): + def finalize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.finalize(state, **kwargs) self._finalized = True - def update_event(self, state: State, event: Event, **kwargs): + def update_event(self, state: "State", event: "Event", **kwargs): for modifier in self.modifiers: modifier.update_event(state, event, **kwargs) diff --git a/src/sparseml/core/optimizer/__init__.py b/src/sparseml/core/optimizer/__init__.py index 87930811c41..07e2638ee13 100644 --- a/src/sparseml/core/optimizer/__init__.py +++ b/src/sparseml/core/optimizer/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import * +from .base import ModifiableOptimizer \ No newline at end of file diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py index 09b46208557..0058bf43bf2 100644 --- a/src/sparseml/core/optimizer/base.py +++ b/src/sparseml/core/optimizer/base.py @@ -28,6 +28,9 @@ @dataclass class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject): optimizer: OT = None + + def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): + self.optimizer = optimizer def get_param_groups(self) -> List[PGT]: raise NotImplementedError() diff --git a/src/sparseml/core/optimizer/pytorch.py b/src/sparseml/core/optimizer/pytorch.py index 502c5f8766c..15d7d71b857 100644 --- a/src/sparseml/core/optimizer/pytorch.py +++ b/src/sparseml/core/optimizer/pytorch.py @@ -23,6 +23,9 @@ class ModifiableOptimizerPyTorch(ModifiableOptimizer[Optimizer, Dict[str, Any]]): + def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): + super().__init__(optimizer=optimizer, attach_optim_callbacks=attach_optim_callbacks, framework=framework) + def get_param_groups(self) -> List[Dict[str, Any]]: return self.optimizer.param_groups diff --git a/src/sparseml/core/recipe/__init__.py b/src/sparseml/core/recipe/__init__.py index 9bf403c2829..09223c3bd12 100644 --- a/src/sparseml/core/recipe/__init__.py +++ b/src/sparseml/core/recipe/__init__.py @@ -12,8 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .args import * -from .metadata import * -from .modifier import * -from .recipe import * -from .stage import * +from .recipe import Recipe \ No newline at end of file diff --git a/src/sparseml/core/recipe/base.py b/src/sparseml/core/recipe/base.py index 7eeb8e4539d..b781504406c 100644 --- a/src/sparseml/core/recipe/base.py +++ b/src/sparseml/core/recipe/base.py @@ -18,7 +18,6 @@ from pydantic import BaseModel, root_validator from sparseml.core.framework import Framework -from sparseml.core.modifier import Modifier, ModifierFactory from sparseml.core.recipe.args import RecipeArgs diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 23e6ac98887..43d27224ce6 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -17,7 +17,7 @@ from pydantic import root_validator from sparseml.core.framework import Framework -from sparseml.core.modifier import Modifier, ModifierFactory +from sparseml.core.modifier import ModifierFactory from sparseml.core.recipe.args import RecipeArgs from sparseml.core.recipe.base import RecipeBase @@ -56,7 +56,7 @@ def evaluate(self, args: RecipeArgs = None, shift: int = None): if shift is not None and "end" in self._args_evaluated: self._args_evaluated["end"] += shift - def create_modifier(self, framework: Framework) -> Modifier: + def create_modifier(self, framework: Framework) -> "Modifier": return ModifierFactory.create(self.type, framework, **self._args_evaluated) @root_validator(pre=True) diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 505cab7060d..1ba946b8c70 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -23,8 +23,6 @@ WrappedOptimEventLifecycle, ) from sparseml.core.framework import Framework -from sparseml.core.modifier import StageModifiers -from sparseml.core.recipe import Recipe from sparseml.core.state import ModifiedState, State @@ -52,7 +50,7 @@ class _CallbackContainer: class SparseSession: def 
__init__(self): self._state: State = State() - self._modifiers: List[StageModifiers] = [] + self._modifiers: List["StageModifiers"] = [] self._initialized_structure = False self._initialized = False self._finalized = False @@ -63,7 +61,7 @@ def state(self) -> State: return self._state @property - def modifiers(self) -> List[StageModifiers]: + def modifiers(self) -> List["StageModifiers"]: return self._modifiers @property @@ -85,7 +83,7 @@ def event_called(self) -> bool: def pre_initialize_structure( self, model: Any, - recipe: Union[Recipe, List[Recipe]], + recipe: Union["Recipe", List["Recipe"]], framework: Framework = None, **kwargs, ) -> ModifiedState: @@ -113,7 +111,7 @@ def pre_initialize_structure( def initialize( self, framework: Framework = None, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, @@ -340,7 +338,7 @@ def pre_initialize_structure(**kwargs): def initialize( framework: Framework = None, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, @@ -384,7 +382,7 @@ def finalize(**kwargs) -> ModifiedState: def apply( framework: Framework = None, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index a7c4f8c15e4..12d5f793aee 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -89,6 +89,9 @@ def update_recipe( recipe_stage: str = None, recipe_args: Dict[str, Any] = None, ): + if recipe is None: + return + if not isinstance(recipe, list): recipe = [recipe] From bc73e1505f59093a3002af6ee98f8850a84ae39f Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Thu, 14 Sep 2023 10:31:52 -0400 Subject: [PATCH 09/27] fix import errors and multiframework inits --- src/sparseml/core/__init__.py | 10 +--------- src/sparseml/core/data/__init__.py | 2 +- src/sparseml/core/framework.py | 2 +- src/sparseml/core/model/__init__.py | 2 +- src/sparseml/core/model/base.py | 3 +++ src/sparseml/core/model/pytorch.py | 4 ++++ src/sparseml/core/modifier/__init__.py | 7 ++----- src/sparseml/core/modifier/base.py | 12 ++++-------- src/sparseml/core/modifier/factory.py | 5 +---- src/sparseml/core/modifier/modifier.py | 22 ++++++++++------------ src/sparseml/core/modifier/stage.py | 16 +++++++++------- src/sparseml/core/optimizer/__init__.py | 2 +- src/sparseml/core/optimizer/base.py | 3 +++ src/sparseml/core/optimizer/pytorch.py | 3 +++ src/sparseml/core/recipe/__init__.py | 6 +----- src/sparseml/core/recipe/base.py | 1 - src/sparseml/core/recipe/modifier.py | 4 ++-- src/sparseml/core/session.py | 14 ++++++-------- src/sparseml/core/state.py | 3 +++ 19 files changed, 56 insertions(+), 65 deletions(-) diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py index fc3f40c71ff..bc26ded926c 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -12,12 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .data import * -from .event import * -from .framework import * -from .model import * -from .modifier import * -from .optimizer import * -from .recipe import * -from .session import * -from .state import * +from .session import * \ No newline at end of file diff --git a/src/sparseml/core/data/__init__.py b/src/sparseml/core/data/__init__.py index 87930811c41..1101a7fa8ea 100644 --- a/src/sparseml/core/data/__init__.py +++ b/src/sparseml/core/data/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import * +from .base import ModifiableData diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py index d4f6ddcaebd..c7a67fe85cd 100644 --- a/src/sparseml/core/framework.py +++ b/src/sparseml/core/framework.py @@ -77,7 +77,7 @@ def __new__( if cls is MultiFrameworkObject: raise TypeError("MultiFrameworkObject cannot be instantiated directly") - instance = super(MultiFrameworkObject, cls).__new__(cls, **kwargs) + instance = super(MultiFrameworkObject, cls).__new__(cls) package = instance.__class__.__module__.rsplit(".", 1)[0] class_name = instance.__class__.__name__ diff --git a/src/sparseml/core/model/__init__.py b/src/sparseml/core/model/__init__.py index 87930811c41..7a2c12e5d45 100644 --- a/src/sparseml/core/model/__init__.py +++ b/src/sparseml/core/model/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import * +from .base import ModifiableModel, ModelParameterizedLayer diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 0a0fdb4084d..96f2fb789d6 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -38,6 +38,9 @@ class ModelParameterizedLayer(Generic[LT, PT]): class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject): model: MT = None + def __init__(self, framework=None, model=None): + self.model = model + def get_layers_params( self, targets: Union[str, List[str]] ) -> Dict[str, ModelParameterizedLayer[LT, PT]]: diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index fde46468f8e..b394f3a753f 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -31,6 +31,10 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): + + def __init__(self, framework=None, model=None): + super().__init__(framework=framework, model=model) + def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, Module]: return get_layers(targets, self.model) diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py index 6405fb2b97d..af25a2b2db2 100644 --- a/src/sparseml/core/modifier/__init__.py +++ b/src/sparseml/core/modifier/__init__.py @@ -12,8 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import * -from .factory import * -from .modifier import * -from .recipe import * -from .stage import * +from .stage import StageModifiers +from .factory import ModifierFactory \ No newline at end of file diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py index 7eb5fd3f667..cbdf2a002b6 100644 --- a/src/sparseml/core/modifier/base.py +++ b/src/sparseml/core/modifier/base.py @@ -15,10 +15,6 @@ from abc import ABC, abstractmethod -from sparseml.core.event import Event -from sparseml.core.state import State - - __all__ = ["ModifierInterface"] @@ -39,17 +35,17 @@ def calculate_end(self) -> float: raise NotImplementedError() @abstractmethod - def pre_initialize_structure(self, state: State, **kwargs): + def pre_initialize_structure(self, state: "State", **kwargs): raise NotImplementedError() @abstractmethod - def initialize(self, state: State, **kwargs): + def initialize(self, state: "State", **kwargs): raise NotImplementedError() @abstractmethod - def finalize(self, state: State, **kwargs): + def finalize(self, state: "State", **kwargs): raise NotImplementedError() @abstractmethod - def update_event(self, state: State, event: Event, **kwargs): + def update_event(self, state: "State", event: "Event", **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/factory.py b/src/sparseml/core/modifier/factory.py index 018e2089f58..704021fd087 100644 --- a/src/sparseml/core/modifier/factory.py +++ b/src/sparseml/core/modifier/factory.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sparseml.core.framework import Framework -from sparseml.core.modifier.modifier import Modifier - __all__ = ["ModifierFactory"] @@ -25,5 +22,5 @@ def refresh(): raise NotImplementedError() @staticmethod - def create(type_: str, framework: Framework, **kwargs) -> Modifier: + def create(type_: str, framework: "Framework", **kwargs) -> "Modifier": raise NotImplementedError() diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 750fb58a7e2..6c1c6d92df6 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -20,8 +20,6 @@ from sparseml.core.event import Event, EventType from sparseml.core.framework import MultiFrameworkObject from sparseml.core.modifier.base import ModifierInterface -from sparseml.core.state import State - __all__ = ["Modifier"] @@ -49,11 +47,11 @@ def calculate_start(self) -> float: def calculate_end(self) -> float: return self.end if self.end is not None else -1 - def pre_initialize_structure(self, state: State, **kwargs): + def pre_initialize_structure(self, state: "State", **kwargs): self.on_initialize_structure(state, **kwargs) self._initialized_structure = True - def initialize(self, state: State, **kwargs): + def initialize(self, state: "State", **kwargs): if self._initialized: return @@ -77,7 +75,7 @@ def initialize(self, state: State, **kwargs): self.on_start(state, state.start_event, **kwargs) self._started = True - def finalize(self, state: State, **kwargs): + def finalize(self, state: "State", **kwargs): if self._finalized: return @@ -94,7 +92,7 @@ def finalize(self, state: State, **kwargs): self._finalized = finalized - def update_event(self, state: State, event: Event, **kwargs): + def update_event(self, state: "State", event: Event, **kwargs): if not self._initialized: raise RuntimeError("cannot update an uninitialized modifier") @@ -138,20 +136,20 @@ def 
should_end(self, event: Event): return self.end is not None and current >= self.end - def on_initialize_structure(self, state: State, **kwargs): + def on_initialize_structure(self, state: "State", **kwargs): raise NotImplementedError() - def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + def on_initialize(self, state: "State", event: Event, **kwargs) -> bool: raise NotImplementedError() - def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + def on_finalize(self, state: "State", event: Event, **kwargs) -> bool: raise NotImplementedError() - def on_start(self, state: State, event: Event, **kwargs): + def on_start(self, state: "State", event: Event, **kwargs): raise NotImplementedError() - def on_update(self, state: State, event: Event, **kwargs): + def on_update(self, state: "State", event: Event, **kwargs): raise NotImplementedError() - def on_end(self, state: State, event: Event, **kwargs): + def on_end(self, state: "State", event: Event, **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index caa2e81c5f6..c27c4d2db52 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -18,12 +18,14 @@ from pydantic import BaseModel, Field from sparseml.core.modifier.base import ModifierInterface -from sparseml.core.modifier.modifier import Modifier -from sparseml.core.state import Event, State + +__all__ = [ + "StageModifier" +] class StageModifiers(ModifierInterface, BaseModel): - modifiers: List[Modifier] = Field(default_factory=list) + modifiers: List["Modifier"] = Field(default_factory=list) index: int = None group: str = None @@ -47,21 +49,21 @@ def calculate_end(self) -> float: mod.calculate_end() for mod in self.modifiers if mod.calculate_end() >= 0 ) - def pre_initialize_structure(self, state: State, **kwargs): + def pre_initialize_structure(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.pre_initialize_structure(state, **kwargs) self._initialized_structure = True - def initialize(self, state: State, **kwargs): + def initialize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.initialize(state, **kwargs) self._initialized = True - def finalize(self, state: State, **kwargs): + def finalize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.finalize(state, **kwargs) self._finalized = True - def update_event(self, state: State, event: Event, **kwargs): + def update_event(self, state: "State", event: "Event", **kwargs): for modifier in self.modifiers: modifier.update_event(state, event, **kwargs) diff --git a/src/sparseml/core/optimizer/__init__.py b/src/sparseml/core/optimizer/__init__.py index 87930811c41..07e2638ee13 100644 --- a/src/sparseml/core/optimizer/__init__.py +++ b/src/sparseml/core/optimizer/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import * +from .base import ModifiableOptimizer \ No newline at end of file diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py index 09b46208557..0058bf43bf2 100644 --- a/src/sparseml/core/optimizer/base.py +++ b/src/sparseml/core/optimizer/base.py @@ -28,6 +28,9 @@ @dataclass class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject): optimizer: OT = None + + def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): + self.optimizer = optimizer def get_param_groups(self) -> List[PGT]: raise NotImplementedError() diff --git a/src/sparseml/core/optimizer/pytorch.py b/src/sparseml/core/optimizer/pytorch.py index 502c5f8766c..15d7d71b857 100644 --- a/src/sparseml/core/optimizer/pytorch.py +++ b/src/sparseml/core/optimizer/pytorch.py @@ -23,6 +23,9 @@ class ModifiableOptimizerPyTorch(ModifiableOptimizer[Optimizer, Dict[str, Any]]): + def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): + super().__init__(optimizer=optimizer, attach_optim_callbacks=attach_optim_callbacks, framework=framework) + def get_param_groups(self) -> List[Dict[str, Any]]: return self.optimizer.param_groups diff --git a/src/sparseml/core/recipe/__init__.py b/src/sparseml/core/recipe/__init__.py index 9bf403c2829..09223c3bd12 100644 --- a/src/sparseml/core/recipe/__init__.py +++ b/src/sparseml/core/recipe/__init__.py @@ -12,8 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .args import * -from .metadata import * -from .modifier import * -from .recipe import * -from .stage import * +from .recipe import Recipe \ No newline at end of file diff --git a/src/sparseml/core/recipe/base.py b/src/sparseml/core/recipe/base.py index 7eeb8e4539d..b781504406c 100644 --- a/src/sparseml/core/recipe/base.py +++ b/src/sparseml/core/recipe/base.py @@ -18,7 +18,6 @@ from pydantic import BaseModel, root_validator from sparseml.core.framework import Framework -from sparseml.core.modifier import Modifier, ModifierFactory from sparseml.core.recipe.args import RecipeArgs diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 23e6ac98887..43d27224ce6 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -17,7 +17,7 @@ from pydantic import root_validator from sparseml.core.framework import Framework -from sparseml.core.modifier import Modifier, ModifierFactory +from sparseml.core.modifier import ModifierFactory from sparseml.core.recipe.args import RecipeArgs from sparseml.core.recipe.base import RecipeBase @@ -56,7 +56,7 @@ def evaluate(self, args: RecipeArgs = None, shift: int = None): if shift is not None and "end" in self._args_evaluated: self._args_evaluated["end"] += shift - def create_modifier(self, framework: Framework) -> Modifier: + def create_modifier(self, framework: Framework) -> "Modifier": return ModifierFactory.create(self.type, framework, **self._args_evaluated) @root_validator(pre=True) diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 505cab7060d..1ba946b8c70 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -23,8 +23,6 @@ WrappedOptimEventLifecycle, ) from sparseml.core.framework import Framework -from sparseml.core.modifier import StageModifiers -from sparseml.core.recipe import Recipe from sparseml.core.state import ModifiedState, State @@ -52,7 +50,7 @@ class _CallbackContainer: class SparseSession: def 
__init__(self): self._state: State = State() - self._modifiers: List[StageModifiers] = [] + self._modifiers: List["StageModifiers"] = [] self._initialized_structure = False self._initialized = False self._finalized = False @@ -63,7 +61,7 @@ def state(self) -> State: return self._state @property - def modifiers(self) -> List[StageModifiers]: + def modifiers(self) -> List["StageModifiers"]: return self._modifiers @property @@ -85,7 +83,7 @@ def event_called(self) -> bool: def pre_initialize_structure( self, model: Any, - recipe: Union[Recipe, List[Recipe]], + recipe: Union["Recipe", List["Recipe"]], framework: Framework = None, **kwargs, ) -> ModifiedState: @@ -113,7 +111,7 @@ def pre_initialize_structure( def initialize( self, framework: Framework = None, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, @@ -340,7 +338,7 @@ def pre_initialize_structure(**kwargs): def initialize( framework: Framework = None, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, @@ -384,7 +382,7 @@ def finalize(**kwargs) -> ModifiedState: def apply( framework: Framework = None, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index a7c4f8c15e4..12d5f793aee 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -89,6 +89,9 @@ def update_recipe( recipe_stage: str = None, recipe_args: Dict[str, Any] = None, ): + if recipe is None: + return + if not isinstance(recipe, list): recipe = [recipe] From 5438e057e4895f19cb9828be16f4258013656bf6 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Thu, 14 Sep 2023 15:14:42 -0400 Subject: [PATCH 10/27] initialization --- src/sparseml/core/data/pytorch.py | 2 +- src/sparseml/core/session.py | 6 +++--- src/sparseml/core/state.py | 14 ++++++++------ 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/sparseml/core/data/pytorch.py b/src/sparseml/core/data/pytorch.py index 65e28718dd3..bbc890f1d7f 100644 --- a/src/sparseml/core/data/pytorch.py +++ b/src/sparseml/core/data/pytorch.py @@ -125,7 +125,7 @@ def merge_batches(batches): class ModifiableDataPyTorch(ModifiableData[DynamicBatchSizeDataLoader]): - def __init__(self, data_loader: DataLoader): + def __init__(self, data_loader: DataLoader, framework=None): super().__init__() self.data = DynamicBatchSizeDataLoader(data_loader) diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 1ba946b8c70..20498e3c827 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -156,7 +156,7 @@ def initialize( return ModifiedState( model=self.state.model.model, optimizer=self.state.optimizer.optimizer, - loss=self.state.loss.loss, + loss=self.state.loss, modifier_data=modifier_data, ) @@ -179,7 +179,7 @@ def finalize(self, **kwargs) -> ModifiedState: return ModifiedState( model=self.state.model.model, optimizer=self.state.optimizer.optimizer, - loss=self.state.loss.loss, + loss=self.state.loss, modifier_data=modifier_data, ) @@ -229,7 +229,7 @@ def event( return ModifiedState( model=self.state.model.model, 
optimizer=self.state.optimizer.optimizer, - loss=self.state.loss.loss, + loss=self.state.loss, modifier_data=modifier_data, ) diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index 12d5f793aee..1e4fb257c24 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -135,19 +135,21 @@ def update_data( if self.framework is None: raise RuntimeError("framework must be set before updating data") - self.data = ModifiableData(framework=self.framework) - if train_data is not None: - self.data.train = train_data if not copy_data else deepcopy(train_data) + train_loader = train_data if not copy_data else deepcopy(train_data) + self.train_data = ModifiableData(framework=self.framework, data_loader=train_loader) if val_data is not None: - self.data.val = val_data if not copy_data else deepcopy(val_data) + val_loader = val_data if not copy_data else deepcopy(val_data) + self.val_data = ModifiableData(framework=self.framework, data_loader=val_loader) if test_data is not None: - self.data.test = test_data if not copy_data else deepcopy(test_data) + test_loader = test_data if not copy_data else deepcopy(test_data) + self.test_data = ModifiableData(framework=self.framework, data_loader=test_loader) if calib_data is not None: - self.data.calib = calib_data if not copy_data else deepcopy(calib_data) + calib_loader = calib_data if not copy_data else deepcopy(calib_data) + self.calib_data = ModifiableData(framework=self.framework, data_loader=calib_loader) def update_start( self, From 996c533452df0a27baf3dfa0a31d5e3b909a6c86 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 15 Sep 2023 12:10:34 -0400 Subject: [PATCH 11/27] RecipeModifiers working --- src/sparseml/core/modifier/factory.py | 4 + src/sparseml/core/modifier/stage.py | 2 +- src/sparseml/core/recipe/modifier.py | 29 +- src/sparseml/core/recipe/recipe.py | 22 +- src/sparseml/core/recipe/stage.py | 17 +- src/sparseml/core/state.py | 8 +- test_e2e.ipynb | 460 ++++++++++++++++++++++++++ test_e2e_recipe.yaml | 28 ++ 8 files changed, 538 insertions(+), 32 deletions(-) create mode 100644 test_e2e.ipynb create mode 100644 test_e2e_recipe.yaml diff --git a/src/sparseml/core/modifier/factory.py b/src/sparseml/core/modifier/factory.py index 704021fd087..388810e56a9 100644 --- a/src/sparseml/core/modifier/factory.py +++ b/src/sparseml/core/modifier/factory.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect +#import sparseml.modifiers as modifiers __all__ = ["ModifierFactory"] @@ -24,3 +26,5 @@ def refresh(): @staticmethod def create(type_: str, framework: "Framework", **kwargs) -> "Modifier": raise NotImplementedError() + #for name, obj in inspect.getmembers(modifiers): + # print(name, obj) diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index c27c4d2db52..0bf0476eac5 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -20,7 +20,7 @@ from sparseml.core.modifier.base import ModifierInterface __all__ = [ - "StageModifier" + "StageModifiers" ] diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 43d27224ce6..5c6d4c60f1b 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -29,42 +29,45 @@ class RecipeModifier(RecipeBase): type: str group: str = None args: Dict[str, Any] = None - _args_evaluated: Dict[str, Any] = None + args_evaluated: Dict[str, Any] = None def calculate_start(self) -> int: - if not self._args_evaluated: + if not self.args_evaluated: raise ValueError("args must be evaluated before calculating start") - return self._args_evaluated.get("start", -1) + return self.args_evaluated.get("start", -1) def calculate_end(self) -> int: - if not self._args_evaluated: - raise ValueError("args must be evaluated before calculating start") + if not self.args_evaluated: + raise ValueError("args must be evaluated before calculating end") - return self._args_evaluated.get("end", -1) + return self.args_evaluated.get("end", -1) def evaluate(self, args: RecipeArgs = None, shift: int = None): if not self.args: raise ValueError("args must be set before evaluating") comb_args = args or RecipeArgs() - self._args_evaluated = comb_args.evaluate_ext(self.args) + self.args_evaluated = comb_args.evaluate_ext(self.args) - if shift is not None and "start" in self._args_evaluated: - self._args_evaluated["start"] += shift + if shift is not None and "start" in self.args_evaluated: + self.args_evaluated["start"] += shift - if shift is not None and "end" in self._args_evaluated: - self._args_evaluated["end"] += shift + if shift is not None and "end" in self.args_evaluated: + self.args_evaluated["end"] += shift def create_modifier(self, framework: Framework) -> "Modifier": - return ModifierFactory.create(self.type, framework, **self._args_evaluated) + return ModifierFactory.create(self.type, framework, **self.args_evaluated) @root_validator(pre=True) def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + modifier = {"group": values.pop("group")} assert len(values) == 1, "multiple key pairs found for modifier" modifier_type, args = list(values.items())[0] - return {"type": modifier_type, "args": args} + modifier["type"] = modifier_type + modifier["args"] = args + return modifier def dict(self, *args, **kwargs) -> Dict[str, Any]: return {self.type: self.args} diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index 668e00764b2..d6e5a579222 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -20,7 +20,6 @@ from pydantic import Field, root_validator from sparseml.core.framework import Framework -from sparseml.core.modifier import StageModifiers from sparseml.core.recipe.args import RecipeArgs from sparseml.core.recipe.base import RecipeBase from sparseml.core.recipe.metadata import RecipeMetaData @@ -76,6 +75,10 @@ def simplify_recipe( def 
simplify_combine_recipes( recipes: List[Union["Recipe", Tuple["Recipe", str, Dict[str, Any]]]] ) -> "Recipe": + + if len(recipes) == 1: + return recipes[0] + simplified = Recipe() for recipe_tuple in recipes: @@ -101,7 +104,7 @@ def simplify_combine_recipes( args: RecipeArgs = None stages: List[RecipeStage] = Field(default_factory=list) metadata: RecipeMetaData = None - _args_evaluated: RecipeArgs = None + args_evaluated: RecipeArgs = None def calculate_start(self) -> int: return min( @@ -117,17 +120,17 @@ def calculate_end(self) -> int: def evaluate(self, args: Dict[str, Any] = None, shift: int = None): args = self.args.combine(args) if self.args else RecipeArgs(**(args or {})) - self._args_evaluated = args.evaluate() + self.args_evaluated = args.evaluate() for stage in self.stages: - stage.evaluate(self._args_evaluated, shift) + stage.evaluate(self.args_evaluated, shift) - def create_modifier(self, framework: Framework) -> List[StageModifiers]: - if self._args_evaluated is None: + def create_modifier(self, framework: Framework) -> List["StageModifiers"]: + if self.args_evaluated is None: self.evaluate() modifiers = [] for index, stage in enumerate(self.stages): - stage_modifiers = stage.create_modifiers(framework) + stage_modifiers = stage.create_modifier(framework) stage_modifiers.index = index stage_modifiers.group = stage.group modifiers.append(stage_modifiers) @@ -136,15 +139,14 @@ def create_modifier(self, framework: Framework) -> List[StageModifiers]: @root_validator(pre=True) def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]: - modifiers = RecipeStage._combine_modifiers(values) - stages = [{"modifiers": modifiers, "group": "default"}] if modifiers else [] + stages = [] add_stages, remove_keys = Recipe._combine_stages(values) stages.extend(add_stages) for key in remove_keys: del values[key] - values["stages"] = Recipe._combine_stages(values) + values["stages"] = stages return values diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py index 54effcfc042..e08976d5d16 100644 --- a/src/sparseml/core/recipe/stage.py +++ b/src/sparseml/core/recipe/stage.py @@ -32,7 +32,7 @@ class RecipeStage(RecipeBase): enabled: bool = True modifiers: List[RecipeModifier] = Field(default_factory=list) exclude_default: bool = False - _args_evaluated: RecipeArgs = None + args_evaluated: RecipeArgs = None def calculate_start(self) -> int: return min( @@ -47,12 +47,14 @@ def calculate_end(self) -> int: ) def evaluate(self, parent_args: RecipeArgs = None, shift: int = None): + if self.args is None: + self.args = RecipeArgs({}) merged_args = self.args.combine(parent_args) - self._args_evaluated = merged_args.evaluate() + self.args_evaluated = merged_args.evaluate() for modifier in self.modifiers: - modifier.evaluate(self._args_evaluated, shift) + modifier.evaluate(self.args_evaluated, shift) - def create_modifiers( + def create_modifier( self, framework: Framework, parent_args: RecipeArgs = None ) -> StageModifiers: if parent_args is not None: @@ -99,16 +101,19 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: @staticmethod def _combine_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]: modifiers = [] + keys = [] for key, value in list(values.items()): if key.endswith("_modifiers") or key == "modifiers": + keys.append(key) group = ( key.rsplit("_modifiers", 1)[0] if key.endswith("_modifiers") else "default" ) - for modifier in value: + for mod_key, mod_value in value.items(): + modifier = {mod_key: mod_value} modifier["group"] = group 
modifiers.append(modifier) - return modifiers + return modifiers, keys diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index 1e4fb257c24..d9a3f671236 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -98,8 +98,12 @@ def update_recipe( for rec in recipe: if isinstance(rec, str): rec = Recipe.create_instance(rec) - - self.recipes.append((rec, recipe_stage, recipe_args)) + if not isinstance(self.recipes, List): + self.recipes = [] + if recipe_stage is None and recipe_args is None: + self.recipes.append(rec) + else: + self.recipes.append((rec, recipe_stage, recipe_args)) self._recipe_changed = True diff --git a/test_e2e.ipynb b/test_e2e.ipynb new file mode 100644 index 00000000000..9941d4d775d --- /dev/null +++ b/test_e2e.ipynb @@ -0,0 +1,460 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sparseml.core.session as sml\n", + "from sparseml.core.framework import Framework\n", + "import torchvision\n", + "from torchvision import transforms\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import datasets\n", + "import os\n", + "from torch.optim import Adam" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "sml.create_session()\n", + "session = sml.active_session()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_LABELS = 3\n", + "model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)\n", + "model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)\n", + "optimizer = Adam(model.parameters(), lr=8e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "#beans_dataset = datasets.load_dataset(\"beans\")\n", + "#print(beans_dataset[\"train\"][0][\"image_file_path\"])\n", + "#print(beans_dataset[\"validation\"][0][\"image_file_path\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "train_path = \"/home/sadkins/.cache/huggingface/datasets/downloads/extracted/dbf92bfb2c3766fb3083a51374ad94d8a3690f53cdf0f9113a231c2351c9ff33/train\"\n", + "val_path = \"/home/sadkins/.cache/huggingface/datasets/downloads/extracted/510ede718de2aeaa2f9d88b0d81d88c449beeb7d074ea594bdf25a0e6a9d51d0/validation\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "NUM_LABELS = 3\n", + "BATCH_SIZE = 32\n", + "\n", + "# imagenet transforms\n", + "imagenet_transform = transforms.Compose([\n", + " transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None),\n", + " transforms.CenterCrop(size=(224, 224)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", + "])\n", + "\n", + "# datasets\n", + "train_dataset = torchvision.datasets.ImageFolder(\n", + " root=train_path,\n", + " transform=imagenet_transform\n", + ")\n", + "\n", + "val_dataset = torchvision.datasets.ImageFolder(\n", + " root=val_path,\n", + " transform=imagenet_transform\n", + ")\n", + "\n", + "# dataloaders\n", + "train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16)\n", + "val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16)" + ] + }, + { 
+ "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "recipe = \"test_e2e_recipe.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModifiedState(model=MobileNetV2(\n", + " (features): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n", + " (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (2): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)\n", + " (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (3): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)\n", + " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (4): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)\n", + " (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (5): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(32, 192, 
kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n", + " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (6): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)\n", + " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (7): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)\n", + " (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (8): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (9): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " 
)\n", + " (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (10): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (11): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)\n", + " (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (12): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)\n", + " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (13): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)\n", + " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (14): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)\n", + " (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (15): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n", + " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (16): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n", + " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (17): InvertedResidual(\n", + " (conv): Sequential(\n", + " (0): Conv2dNormActivation(\n", + " (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (1): Conv2dNormActivation(\n", + " (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)\n", + " (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (18): Conv2dNormActivation(\n", + " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU6(inplace=True)\n", + " )\n", + " )\n", + " (classifier): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=1280, out_features=3, bias=True)\n", + " )\n", + "), optimizer=None, loss=None, modifier_data=[])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "session.pre_initialize_structure(\n", + " framework=Framework.pytorch,\n", + " recipe=recipe,\n", + " 
model=model\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "session.initialize(\n", + " framework=Framework.pytorch,\n", + " recipe=recipe,\n", + " model=model,\n", + " teacher_model=None,\n", + " optimizer=optimizer,\n", + " train_data=train_loader,\n", + " val_data=val_loader\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test_e2e_recipe.yaml b/test_e2e_recipe.yaml new file mode 100644 index 00000000000..558990df149 --- /dev/null +++ b/test_e2e_recipe.yaml @@ -0,0 +1,28 @@ +test_stage: + pruning_modifiers: + MagnitudePruningModifier: + init_sparsity: 0.0 + final_sparsity: 0.5 + start_epoch: 1.0 + end_epoch: 10.0 + update_frequency: 0.5 + params: + - 'features.0.0.weight' + - 'features.18.0.weight' + - 're:features.*.conv.*.weight' + - 're:features.*.conv.*.*.weight' + leave_enabled: True +test2_stage: + pruning_modifiers: + MagnitudePruningModifier: + init_sparsity: 0.0 + final_sparsity: 0.5 + start_epoch: 1.0 + end_epoch: 10.0 + update_frequency: 0.5 + params: + - 'features.0.0.weight' + - 'features.18.0.weight' + - 're:features.*.conv.*.weight' + - 're:features.*.conv.*.*.weight' + leave_enabled: True \ No newline at end of file From 9635acb05cd8ad9ec82a11f31dd485d6a3f318aa Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Sun, 17 Sep 2023 10:11:50 -0400 Subject: [PATCH 12/27] fix import errors --- src/sparseml/core/__init__.py | 12 +- src/sparseml/core/data/base.py | 2 +- src/sparseml/core/event.py | 261 +--------------------- src/sparseml/core/framework.py | 49 +---- src/sparseml/core/framework_object.py | 68 ++++++ src/sparseml/core/lifecycle/__init__.py | 16 ++ src/sparseml/core/lifecycle/event.py | 280 ++++++++++++++++++++++++ src/sparseml/core/lifecycle/session.py | 203 +++++++++++++++++ src/sparseml/core/model/__init__.py | 2 +- src/sparseml/core/model/base.py | 2 +- src/sparseml/core/model/pytorch.py | 1 - src/sparseml/core/modifier/__init__.py | 6 +- src/sparseml/core/modifier/base.py | 27 ++- src/sparseml/core/modifier/modifier.py | 38 +++- src/sparseml/core/modifier/stage.py | 35 +-- src/sparseml/core/optimizer/__init__.py | 2 +- src/sparseml/core/optimizer/base.py | 4 +- src/sparseml/core/optimizer/pytorch.py | 6 +- src/sparseml/core/recipe/__init__.py | 8 +- src/sparseml/core/recipe/base.py | 4 +- src/sparseml/core/recipe/container.py | 85 +++++++ src/sparseml/core/recipe/modifier.py | 2 +- src/sparseml/core/recipe/recipe.py | 138 ++++++++---- src/sparseml/core/recipe/stage.py | 89 +++++--- src/sparseml/core/session.py | 268 +++++------------------ src/sparseml/core/state.py | 138 ++++-------- 26 files changed, 1021 insertions(+), 725 deletions(-) create mode 100644 src/sparseml/core/framework_object.py create mode 100644 src/sparseml/core/lifecycle/__init__.py create mode 100644 src/sparseml/core/lifecycle/event.py create mode 100644 src/sparseml/core/lifecycle/session.py create mode 100644 src/sparseml/core/recipe/container.py diff --git a/src/sparseml/core/__init__.py 
b/src/sparseml/core/__init__.py index bc26ded926c..3d69db68239 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -12,4 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .session import * \ No newline at end of file +from .data import * +from .event import * +from .framework import * +from .framework_object import * +from .lifecycle import * +from .model import * +from .modifier import * +from .optimizer import * +from .recipe import * +from .session import * +from .state import * diff --git a/src/sparseml/core/data/base.py b/src/sparseml/core/data/base.py index 3d15e48a777..e8994668635 100644 --- a/src/sparseml/core/data/base.py +++ b/src/sparseml/core/data/base.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import Generic, TypeVar -from sparseml.core.framework import MultiFrameworkObject +from sparseml.core.framework_object import MultiFrameworkObject __all__ = ["ModifiableData"] diff --git a/src/sparseml/core/event.py b/src/sparseml/core/event.py index afa21372c51..9d4046914cc 100644 --- a/src/sparseml/core/event.py +++ b/src/sparseml/core/event.py @@ -12,19 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import List, Optional +from typing import Optional __all__ = [ "EventType", "Event", - "EventLifecycle", - "WrappedOptimEventLifecycle", - "CallbacksEventLifecycle", ] @@ -148,258 +144,3 @@ def new_instance(self, **kwargs) -> "Event": setattr(instance, key, value) return instance - - -class EventLifecycle(ABC, Event): - type_first: EventType = None - step_count: int = 0 - batch_count: int = 0 - - def __init__(self, type_first: EventType, start: Event): - self.type_first = type_first - self.steps_per_epoch = start.steps_per_epoch - self.batches_per_step = start.batches_per_step - self.invocations_per_step = start.invocations_per_step - self.global_step = start.global_step - self.global_batch = start.global_batch - - def events_from_type(self, type_: EventType) -> List[Event]: - if type_ == EventType.BATCH_START: - return self.batch_start_events() - - if type_ == EventType.LOSS_CALCULATED: - return self.loss_calculated_events() - - if type_ == EventType.OPTIM_PRE_STEP: - return self.optim_pre_step_events() - - if type_ == EventType.OPTIM_POST_STEP: - return self.optim_post_step_events() - - if type_ == EventType.BATCH_END: - return self.batch_end_events() - - raise ValueError(f"invalid event type {type_}") - - def check_step_batches_count(self, increment: bool) -> bool: - if self.batches_per_step is None or self.batches_per_step < 2: - return True - - compare_batch = self.batch_count + 1 - at_step = compare_batch % self.batches_per_step == 0 - - if increment: - self.batch_count = compare_batch if not at_step else 0 - - return at_step - - def check_step_invocations_count(self, increment: bool) -> bool: - if self.invocations_per_step is None or self.invocations_per_step < 2: - return True - - compare_step = self.step_count + 1 - at_step = compare_step % self.invocations_per_step == 0 - - if increment: - self.step_count = compare_step if not at_step else 0 - - return at_step - - @abstractmethod - def batch_start_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def loss_calculated_events(self) -> List[Event]: - raise NotImplementedError() - - 
@abstractmethod - def optim_pre_step_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def optim_post_step_events(self) -> List[Event]: - raise NotImplementedError() - - @abstractmethod - def batch_end_events(self) -> List[Event]: - raise NotImplementedError() - - -class WrappedOptimEventLifecycle(EventLifecycle): - """ - Optimizer is wrapped and no batch or optim callbacks - - batch_start: must not be invoked, auto triggered - from loss calculated if that is called, otherwise from pre_step - - loss_calculated: must be called before batch_end and optim_pre_step - - batch_end: must not be invoked, auto triggered from optim_post_step - - optim_pre_step: must be called before optim_post_step - - optim_post_step: must be called only once after optim_pre_step - """ - - def batch_start_events(self) -> List[Event]: - raise ValueError("batch start should not be invoked when only wrapped optim") - - def loss_calculated_events(self) -> List[Event]: - if self.type_first != EventType.LOSS_CALCULATED: - raise ValueError("loss calculated must be called first for wrapped optim") - - if ( - self.type_ != EventType.OPTIM_POST_STEP - and self.type_ != EventType.LOSS_CALCULATED - ): - raise ValueError( - "loss calculated must be called after batch end or optim post step" - ) - - self.type_ = EventType.LOSS_CALCULATED - self.global_batch += 1 - - if not self.check_step_batches_count(increment=True): - # step won't be called, so batch end must be called - return [ - self.new_instance(type_=EventType.BATCH_START), - self.new_instance(type_=EventType.LOSS_CALCULATED), - self.new_instance(type_=EventType.BATCH_END), - ] - else: - # batch end handled by optim step - return [ - self.new_instance(type_=EventType.BATCH_START), - self.new_instance(type_=EventType.LOSS_CALCULATED), - ] - - def optim_pre_step_events(self) -> List[Event]: - if ( - self.type_first == EventType.OPTIM_PRE_STEP - and self.type_ is not None - and self.type_ != EventType.OPTIM_POST_STEP - ): - raise ValueError("optim pre step must be called after optim post step") - - if ( - self.type_first == EventType.LOSS_CALCULATED - and self.type_ != EventType.LOSS_CALCULATED - ): - raise ValueError("optim pre step must be called after loss calculated") - - self.type_ = EventType.OPTIM_PRE_STEP - - if self.type_first == EventType.OPTIM_PRE_STEP: - self.global_batch += ( - 1 - if self.batches_per_step is None or self.batches_per_step < 2 - else self.batches_per_step - ) - batch_start_events = [self.new_instance(type_=EventType.BATCH_START)] - else: - batch_start_events = [] - - if not self.check_step_invocations_count(increment=False): - return batch_start_events - - return batch_start_events + [ - self.new_instance(type_=EventType.OPTIM_PRE_STEP), - ] - - def optim_post_step_events(self) -> List[Event]: - if self.type_ != EventType.OPTIM_PRE_STEP: - raise ValueError("optim post step must be called after optim pre step") - - self.type_ = EventType.OPTIM_POST_STEP - - if not self.check_step_invocations_count(increment=True): - return [ - self.new_instance(type_=EventType.BATCH_END), - ] - - self.global_step += 1 - - return [ - self.new_instance(type_=EventType.OPTIM_POST_STEP), - self.new_instance(type_=EventType.BATCH_END), - ] - - def batch_end_events(self) -> List[Event]: - raise ValueError("batch end should not be invoked when only wrapped optim") - - -class CallbacksEventLifecycle(EventLifecycle): - """ - Optimizer is not wrapped, callbacks must be used - - batch_start: must be called first - - loss_calculated: must be 
called before batch_end and optim_post_step - - batch_end: must be called before next batch start - - optim_pre_step: must be invoked before optim_post_step - - optim_post_step: must be called only once after optim_pre_step - """ - - def batch_start_events(self) -> List[Event]: - if self.type_first != EventType.BATCH_START: - raise ValueError("batch start must be called first for callbacks") - - if self.type_ is not None and self.type_ != EventType.BATCH_END: - raise ValueError("batch start must be called after batch end") - - self.type_ = EventType.BATCH_START - self.global_batch += 1 - - return [self.new_instance(type_=EventType.BATCH_START)] - - def loss_calculated_events(self) -> List[Event]: - if self.type_ != EventType.BATCH_START: - raise ValueError("loss calculated must be called after batch start") - - self.type_ = EventType.LOSS_CALCULATED - - return [self.new_instance(type_=EventType.LOSS_CALCULATED)] - - def optim_pre_step_events(self) -> List[Event]: - if ( - self.type_ != EventType.BATCH_START - and self.type_ != EventType.LOSS_CALCULATED - ): - raise ValueError( - "optim pre step must be called after batch start or loss calculated" - ) - - self.type_ = EventType.OPTIM_PRE_STEP - - if not self.check_step_invocations_count(increment=False): - return [] - - return [ - self.new_instance(type_=EventType.OPTIM_PRE_STEP), - ] - - def optim_post_step_events(self) -> List[Event]: - if self.type_ != EventType.OPTIM_PRE_STEP: - raise ValueError("optim post step must be called after optim pre step") - - self.type_ = EventType.OPTIM_POST_STEP - - if not self.check_step_invocations_count(increment=True): - return [] - - self.global_step += 1 - - return [ - self.new_instance(type_=EventType.OPTIM_POST_STEP), - ] - - def batch_end_events(self) -> List[Event]: - if ( - self.type_ != EventType.OPTIM_POST_STEP - and self.type_ != EventType.LOSS_CALCULATED - and self.type_ != EventType.BATCH_START - ): - raise ValueError( - "batch end must be called after optim post step or " - "loss calculated or batch start" - ) - - self.type_ = EventType.BATCH_END - - return [ - self.new_instance(type_=EventType.BATCH_END), - ] diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py index c7a67fe85cd..4b7130161bf 100644 --- a/src/sparseml/core/framework.py +++ b/src/sparseml/core/framework.py @@ -17,7 +17,7 @@ from enum import Enum -__all__ = ["Framework", "MultiFrameworkObject"] +__all__ = ["Framework"] class Framework(Enum): @@ -65,50 +65,3 @@ def formatted(self) -> str: def class_name(self) -> str: return self.formatted() if self != self.general else "" - - -class MultiFrameworkObject: - def __new__( - cls, - framework: Framework = None, - enable_experimental: bool = False, - **kwargs, - ): - if cls is MultiFrameworkObject: - raise TypeError("MultiFrameworkObject cannot be instantiated directly") - - instance = super(MultiFrameworkObject, cls).__new__(cls) - - package = instance.__class__.__module__.rsplit(".", 1)[0] - class_name = instance.__class__.__name__ - - if framework is None or framework == Framework.general: - return instance - - if enable_experimental: - # check under the experimental package first - try: - return MultiFrameworkObject.load_framework_class( - f"{package}.experimental.{str(framework)}", - f"{class_name}{framework.class_name()}", - )(**kwargs) - except ImportError: - pass - - # next check under the main package for the framework version - try: - return MultiFrameworkObject.load_framework_class( - f"{package}.{str(framework)}", 
f"{class_name}{framework.class_name()}" - )(**kwargs) - except ImportError: - pass - - # fall back on the class that was requested and - # fail later if it doesn't support that framework - return instance - - @staticmethod - def load_framework_class(package: str, class_name: str): - module = importlib.import_module(package) - - return getattr(module, class_name) diff --git a/src/sparseml/core/framework_object.py b/src/sparseml/core/framework_object.py new file mode 100644 index 00000000000..9fa4c9f08e6 --- /dev/null +++ b/src/sparseml/core/framework_object.py @@ -0,0 +1,68 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib + +from sparseml.core.framework import Framework + + +__all__ = ["MultiFrameworkObject"] + + +class MultiFrameworkObject: + def __new__( + cls, + framework: Framework = None, + enable_experimental: bool = False, + **kwargs, + ): + if cls is MultiFrameworkObject: + raise TypeError("MultiFrameworkObject cannot be instantiated directly") + + instance = super(MultiFrameworkObject, cls).__new__(cls, **kwargs) + + package = instance.__class__.__module__.rsplit(".", 1)[0] + class_name = instance.__class__.__name__ + + if framework is None or framework == Framework.general: + return instance + + if enable_experimental: + # check under the experimental package first + try: + return MultiFrameworkObject.load_framework_class( + f"{package}.experimental.{str(framework)}", + f"{class_name}{framework.class_name()}", + )(**kwargs) + except ImportError: + pass + + # next check under the main package for the framework version + try: + return MultiFrameworkObject.load_framework_class( + f"{package}.{str(framework)}", f"{class_name}{framework.class_name()}" + )(**kwargs) + except ImportError: + pass + + # fall back on the class that was requested and + # fail later if it doesn't support that framework + return instance + + @staticmethod + def load_framework_class(package: str, class_name: str): + module = importlib.import_module(package) + + return getattr(module, class_name) diff --git a/src/sparseml/core/lifecycle/__init__.py b/src/sparseml/core/lifecycle/__init__.py new file mode 100644 index 00000000000..581cb06e687 --- /dev/null +++ b/src/sparseml/core/lifecycle/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .event import * +from .session import * diff --git a/src/sparseml/core/lifecycle/event.py b/src/sparseml/core/lifecycle/event.py new file mode 100644 index 00000000000..40937c8ff2f --- /dev/null +++ b/src/sparseml/core/lifecycle/event.py @@ -0,0 +1,280 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List + +from sparseml.core.event import Event, EventType + + +__all__ = [ + "EventLifecycle", + "WrappedOptimEventLifecycle", + "CallbacksEventLifecycle", +] + + +class EventLifecycle(ABC, Event): + type_first: EventType = None + step_count: int = 0 + batch_count: int = 0 + + def __init__(self, type_first: EventType, start: Event): + self.type_first = type_first + self.steps_per_epoch = start.steps_per_epoch + self.batches_per_step = start.batches_per_step + self.invocations_per_step = start.invocations_per_step + self.global_step = start.global_step + self.global_batch = start.global_batch + + def events_from_type(self, type_: EventType) -> List[Event]: + if type_ == EventType.BATCH_START: + return self.batch_start_events() + + if type_ == EventType.LOSS_CALCULATED: + return self.loss_calculated_events() + + if type_ == EventType.OPTIM_PRE_STEP: + return self.optim_pre_step_events() + + if type_ == EventType.OPTIM_POST_STEP: + return self.optim_post_step_events() + + if type_ == EventType.BATCH_END: + return self.batch_end_events() + + raise ValueError(f"invalid event type {type_}") + + def check_step_batches_count(self, increment: bool) -> bool: + if self.batches_per_step is None or self.batches_per_step < 2: + return True + + compare_batch = self.batch_count + 1 + at_step = compare_batch % self.batches_per_step == 0 + + if increment: + self.batch_count = compare_batch if not at_step else 0 + + return at_step + + def check_step_invocations_count(self, increment: bool) -> bool: + if self.invocations_per_step is None or self.invocations_per_step < 2: + return True + + compare_step = self.step_count + 1 + at_step = compare_step % self.invocations_per_step == 0 + + if increment: + self.step_count = compare_step if not at_step else 0 + + return at_step + + @abstractmethod + def batch_start_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def loss_calculated_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def optim_pre_step_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def optim_post_step_events(self) -> List[Event]: + raise NotImplementedError() + + @abstractmethod + def batch_end_events(self) -> List[Event]: + raise NotImplementedError() + + +class WrappedOptimEventLifecycle(EventLifecycle): + """ + Optimizer is wrapped and no batch or optim callbacks + - batch_start: must not be invoked, auto triggered + from loss calculated if that is called, otherwise from pre_step + - loss_calculated: must be called before batch_end and optim_pre_step + - batch_end: must not 
be invoked, auto triggered from optim_post_step + - optim_pre_step: must be called before optim_post_step + - optim_post_step: must be called only once after optim_pre_step + """ + + def batch_start_events(self) -> List[Event]: + raise ValueError("batch start should not be invoked when only wrapped optim") + + def loss_calculated_events(self) -> List[Event]: + if self.type_first != EventType.LOSS_CALCULATED: + raise ValueError("loss calculated must be called first for wrapped optim") + + if ( + self.type_ != EventType.OPTIM_POST_STEP + and self.type_ != EventType.LOSS_CALCULATED + ): + raise ValueError( + "loss calculated must be called after batch end or optim post step" + ) + + self.type_ = EventType.LOSS_CALCULATED + self.global_batch += 1 + + if not self.check_step_batches_count(increment=True): + # step won't be called, so batch end must be called + return [ + self.new_instance(type_=EventType.BATCH_START), + self.new_instance(type_=EventType.LOSS_CALCULATED), + self.new_instance(type_=EventType.BATCH_END), + ] + else: + # batch end handled by optim step + return [ + self.new_instance(type_=EventType.BATCH_START), + self.new_instance(type_=EventType.LOSS_CALCULATED), + ] + + def optim_pre_step_events(self) -> List[Event]: + if ( + self.type_first == EventType.OPTIM_PRE_STEP + and self.type_ is not None + and self.type_ != EventType.OPTIM_POST_STEP + ): + raise ValueError("optim pre step must be called after optim post step") + + if ( + self.type_first == EventType.LOSS_CALCULATED + and self.type_ != EventType.LOSS_CALCULATED + ): + raise ValueError("optim pre step must be called after loss calculated") + + self.type_ = EventType.OPTIM_PRE_STEP + + if self.type_first == EventType.OPTIM_PRE_STEP: + self.global_batch += ( + 1 + if self.batches_per_step is None or self.batches_per_step < 2 + else self.batches_per_step + ) + batch_start_events = [self.new_instance(type_=EventType.BATCH_START)] + else: + batch_start_events = [] + + if not self.check_step_invocations_count(increment=False): + return batch_start_events + + return batch_start_events + [ + self.new_instance(type_=EventType.OPTIM_PRE_STEP), + ] + + def optim_post_step_events(self) -> List[Event]: + if self.type_ != EventType.OPTIM_PRE_STEP: + raise ValueError("optim post step must be called after optim pre step") + + self.type_ = EventType.OPTIM_POST_STEP + + if not self.check_step_invocations_count(increment=True): + return [ + self.new_instance(type_=EventType.BATCH_END), + ] + + self.global_step += 1 + + return [ + self.new_instance(type_=EventType.OPTIM_POST_STEP), + self.new_instance(type_=EventType.BATCH_END), + ] + + def batch_end_events(self) -> List[Event]: + raise ValueError("batch end should not be invoked when only wrapped optim") + + +class CallbacksEventLifecycle(EventLifecycle): + """ + Optimizer is not wrapped, callbacks must be used + - batch_start: must be called first + - loss_calculated: must be called before batch_end and optim_post_step + - batch_end: must be called before next batch start + - optim_pre_step: must be invoked before optim_post_step + - optim_post_step: must be called only once after optim_pre_step + """ + + def batch_start_events(self) -> List[Event]: + if self.type_first != EventType.BATCH_START: + raise ValueError("batch start must be called first for callbacks") + + if self.type_ is not None and self.type_ != EventType.BATCH_END: + raise ValueError("batch start must be called after batch end") + + self.type_ = EventType.BATCH_START + self.global_batch += 1 + + return 
[self.new_instance(type_=EventType.BATCH_START)] + + def loss_calculated_events(self) -> List[Event]: + if self.type_ != EventType.BATCH_START: + raise ValueError("loss calculated must be called after batch start") + + self.type_ = EventType.LOSS_CALCULATED + + return [self.new_instance(type_=EventType.LOSS_CALCULATED)] + + def optim_pre_step_events(self) -> List[Event]: + if ( + self.type_ != EventType.BATCH_START + and self.type_ != EventType.LOSS_CALCULATED + ): + raise ValueError( + "optim pre step must be called after batch start or loss calculated" + ) + + self.type_ = EventType.OPTIM_PRE_STEP + + if not self.check_step_invocations_count(increment=False): + return [] + + return [ + self.new_instance(type_=EventType.OPTIM_PRE_STEP), + ] + + def optim_post_step_events(self) -> List[Event]: + if self.type_ != EventType.OPTIM_PRE_STEP: + raise ValueError("optim post step must be called after optim pre step") + + self.type_ = EventType.OPTIM_POST_STEP + + if not self.check_step_invocations_count(increment=True): + return [] + + self.global_step += 1 + + return [ + self.new_instance(type_=EventType.OPTIM_POST_STEP), + ] + + def batch_end_events(self) -> List[Event]: + if ( + self.type_ != EventType.OPTIM_POST_STEP + and self.type_ != EventType.LOSS_CALCULATED + and self.type_ != EventType.BATCH_START + ): + raise ValueError( + "batch end must be called after optim post step or " + "loss calculated or batch start" + ) + + self.type_ = EventType.BATCH_END + + return [ + self.new_instance(type_=EventType.BATCH_END), + ] diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py new file mode 100644 index 00000000000..b7ff8c6b4ae --- /dev/null +++ b/src/sparseml/core/lifecycle/session.py @@ -0,0 +1,203 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
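The calling-order contract documented on CallbacksEventLifecycle maps onto a training loop roughly as sketched below. This is illustrative only: it assumes a SparseSession that has already been initialized with attach_optim_callbacks=False so the explicit callback pathway is used, plus generic model, optimizer, loss_fn, and dataloader objects.

    from sparseml.core.event import EventType

    def run_epoch(session, dataloader, model, optimizer, loss_fn):
        # Required order per CallbacksEventLifecycle:
        #   BATCH_START -> LOSS_CALCULATED -> OPTIM_PRE_STEP -> OPTIM_POST_STEP -> BATCH_END
        # With invocations_per_step > 1 (gradient accumulation), the OPTIM_PRE_STEP /
        # OPTIM_POST_STEP calls that do not complete the accumulation window are
        # counted but emit no modifier events.
        for batch, target in dataloader:
            session.event(EventType.BATCH_START, batch_data=batch)
            loss = loss_fn(model(batch), target)
            session.event(EventType.LOSS_CALCULATED, loss=loss)
            loss.backward()
            session.event(EventType.OPTIM_PRE_STEP)
            optimizer.step()
            session.event(EventType.OPTIM_POST_STEP)
            optimizer.zero_grad()
            session.event(EventType.BATCH_END)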
+ +from dataclasses import dataclass, field +from typing import Any, List, Optional + +from sparseml.core.event import EventType +from sparseml.core.framework import Framework +from sparseml.core.lifecycle.event import ( + CallbacksEventLifecycle, + EventLifecycle, + WrappedOptimEventLifecycle, +) +from sparseml.core.modifier import ModifierInterface +from sparseml.core.recipe import RecipeContainer +from sparseml.core.state import State + + +__all__ = [ + "SparsificationLifecycle", +] + + +@dataclass +class SparsificationLifecycle: + state: Optional[State] = None + recipe_container: RecipeContainer = RecipeContainer() + modifiers: List[ModifierInterface] = field(default_factory=list) + event_lifecycle: Optional[EventLifecycle] = None + + initialized_structure: bool = False + initialized: bool = False + finalized: bool = False + event_called: bool = False + + def reset(self): + for mod in self.modifiers: + if not mod.initialized or mod.finalized: + continue + + try: + mod.finalize(self.state) + except Exception: + pass + + self.state = None + self.recipe_container = RecipeContainer() + self.modifiers = [] + self.event_lifecycle = None + + self.initialized_structure = False + self.initialized = False + self.finalized = False + self.event_called = False + + def pre_initialize_structure( + self, framework: Framework = None, **kwargs + ) -> List[Any]: + self._check_create_state(framework=framework) + extras = self.state.update(**kwargs) + extras = self.recipe_container.update(**extras) + + self._check_compile_recipe() + mod_data = [] + for mod in self.modifiers: + data = mod.pre_initialize_structure(state=self.state, **extras) + if data is not None: + mod_data.append(data) + + self.initialized_structure = True + + return mod_data + + def initialize(self, framework: Framework = None, **kwargs) -> List[Any]: + self._check_create_state(framework=framework) + extras = self.state.update(**kwargs) + extras = self.recipe_container.update(**extras) + + self._check_compile_recipe() + mod_data = [] + for mod in self.modifiers: + data = mod.initialize(state=self.state, **extras) + if data is not None: + mod_data.append(data) + + self.initialized = True + + return mod_data + + def finalize(self, **kwargs) -> List[Any]: + if not self.initialized: + raise ValueError("Cannot finalize before initializing") + + if self.finalized: + raise ValueError("Cannot finalize more than once") + + mod_data = [] + for mod in self.modifiers: + data = mod.finalize(state=self.state, **kwargs) + if data is not None: + mod_data.append(data) + + self.finalized = True + + return mod_data + + def event(self, event_type: EventType, **kwargs) -> List[Any]: + if not self.initialized: + raise ValueError("Cannot invoke event before initializing") + + if self.finalized: + raise ValueError("Cannot invoke event after finalizing") + + if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: + raise ValueError( + f"Cannot invoke {event_type} event. " + f"Use the corresponding method instead." 
+ ) + + if event_type == EventType.LOSS_CALCULATED and ( + "loss" not in kwargs or kwargs["loss"] is None + ): + raise ValueError("Loss must be provided for loss calculated event") + + self._check_setup_event_lifecycle(event_type) + + event = None + mod_data = [] + for event in self.event_lifecycle.events_from_type(event_type): + if self.state.start_event is None: + self.state.start_event = event + + for mod in self.modifiers: + data = mod.update_event(state=self.state, event=event, **kwargs) + if data is not None: + mod_data.append(data) + + assert ( + event is not None + ), f"Event lifecycle did not return an event for {event_type}" + self.state.last_event = event + self.event_called = True + + return mod_data + + def _check_create_state(self, framework: Framework): + if self.state is not None: + return + + if framework is None: + raise ValueError("framework must be provided to create state") + + self.state = State(framework=framework) + + def _check_compile_recipe(self): + if self.recipe_container.check_compile_recipe(): + self.modifiers = self.recipe_container.compiled_recipe.create_modifier( + self.state.framework + ) + + def _check_setup_event_lifecycle(self, event_type: EventType): + if self.event_lifecycle is not None: + return + + if ( + self.state is None + or self.state.model is None + or self.state.start_event + or self.recipe_container.compiled_recipe is None + ): + raise ValueError( + "Cannot invoke event before recipe, model, and start are set" + ) + + if not self.state.sparsification_ready: + raise ValueError( + "Cannot invoke event before recipe, model, and start are set" + ) + + for mod in self.modifiers: + mod.check_initialized() + + if event_type == EventType.BATCH_START: + self.event_lifecycle = WrappedOptimEventLifecycle( + type_first=EventType.BATCH_START, start=self.state.start_event + ) + elif event_type == EventType.LOSS_CALCULATED: + self.event_lifecycle = CallbacksEventLifecycle( + type_first=EventType.LOSS_CALCULATED, start=self.state.start_event + ) + else: + raise ValueError(f"invalid event type {event_type}") diff --git a/src/sparseml/core/model/__init__.py b/src/sparseml/core/model/__init__.py index 7a2c12e5d45..81ade568d8d 100644 --- a/src/sparseml/core/model/__init__.py +++ b/src/sparseml/core/model/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import ModifiableModel, ModelParameterizedLayer +from .base import ModelParameterizedLayer, ModifiableModel diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 96f2fb789d6..a08c8129c72 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import Dict, Generic, List, TypeVar, Union -from sparseml.core.framework import MultiFrameworkObject +from sparseml.core.framework_object import MultiFrameworkObject __all__ = ["ModifiableModel", "ModelParameterizedLayer"] diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index b394f3a753f..c2d12065faf 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -31,7 +31,6 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): - def __init__(self, framework=None, model=None): super().__init__(framework=framework, model=model) diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py index af25a2b2db2..2da941711e8 100644 --- a/src/sparseml/core/modifier/__init__.py +++ b/src/sparseml/core/modifier/__init__.py @@ -12,5 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .stage import StageModifiers -from .factory import ModifierFactory \ No newline at end of file +from .base import * +from .factory import * +from .modifier import * +from .stage import * diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py index cbdf2a002b6..693bbbdbb42 100644 --- a/src/sparseml/core/modifier/base.py +++ b/src/sparseml/core/modifier/base.py @@ -15,6 +15,10 @@ from abc import ABC, abstractmethod +from sparseml.core.event import Event +from sparseml.core.state import State + + __all__ = ["ModifierInterface"] @@ -22,6 +26,21 @@ class ModifierInterface(ABC): def __init__(self, **kwargs): pass + @property + @abstractmethod + def initialized_structure(self) -> bool: + raise NotImplementedError() + + @property + @abstractmethod + def initialized(self) -> bool: + raise NotImplementedError() + + @property + @abstractmethod + def finalized(self) -> bool: + raise NotImplementedError() + @abstractmethod def check_initialized(self): raise NotImplementedError() @@ -35,17 +54,17 @@ def calculate_end(self) -> float: raise NotImplementedError() @abstractmethod - def pre_initialize_structure(self, state: "State", **kwargs): + def pre_initialize_structure(self, state: State, **kwargs): raise NotImplementedError() @abstractmethod - def initialize(self, state: "State", **kwargs): + def initialize(self, state: State, **kwargs): raise NotImplementedError() @abstractmethod - def finalize(self, state: "State", **kwargs): + def finalize(self, state: State, **kwargs): raise NotImplementedError() @abstractmethod - def update_event(self, state: "State", event: "Event", **kwargs): + def update_event(self, state: State, event: Event, **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 6c1c6d92df6..2d4d80550bb 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -18,13 +18,15 @@ from pydantic import BaseModel from sparseml.core.event import Event, EventType -from sparseml.core.framework import MultiFrameworkObject +from sparseml.core.framework_object import MultiFrameworkObject from sparseml.core.modifier.base import 
ModifierInterface +from sparseml.core.state import State + __all__ = ["Modifier"] -class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): +class Modifier(BaseModel, ModifierInterface, MultiFrameworkObject): index: int = None group: str = None start: float = None @@ -37,6 +39,18 @@ class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): _started: bool = False _ended: bool = False + @property + def initialized_structure(self) -> bool: + return self._initialized_structure + + @property + def initialized(self) -> bool: + return self._initialized + + @property + def finalized(self) -> bool: + return self._finalized + def check_initialized(self): if not self._initialized: raise RuntimeError("modifier has not been initialized") @@ -47,11 +61,11 @@ def calculate_start(self) -> float: def calculate_end(self) -> float: return self.end if self.end is not None else -1 - def pre_initialize_structure(self, state: "State", **kwargs): + def pre_initialize_structure(self, state: State, **kwargs): self.on_initialize_structure(state, **kwargs) self._initialized_structure = True - def initialize(self, state: "State", **kwargs): + def initialize(self, state: State, **kwargs): if self._initialized: return @@ -75,7 +89,7 @@ def initialize(self, state: "State", **kwargs): self.on_start(state, state.start_event, **kwargs) self._started = True - def finalize(self, state: "State", **kwargs): + def finalize(self, state: State, **kwargs): if self._finalized: return @@ -92,7 +106,7 @@ def finalize(self, state: "State", **kwargs): self._finalized = finalized - def update_event(self, state: "State", event: Event, **kwargs): + def update_event(self, state: State, event: Event, **kwargs): if not self._initialized: raise RuntimeError("cannot update an uninitialized modifier") @@ -136,20 +150,20 @@ def should_end(self, event: Event): return self.end is not None and current >= self.end - def on_initialize_structure(self, state: "State", **kwargs): + def on_initialize_structure(self, state: State, **kwargs): raise NotImplementedError() - def on_initialize(self, state: "State", event: Event, **kwargs) -> bool: + def on_initialize(self, state: State, event: Event, **kwargs) -> bool: raise NotImplementedError() - def on_finalize(self, state: "State", event: Event, **kwargs) -> bool: + def on_finalize(self, state: State, event: Event, **kwargs) -> bool: raise NotImplementedError() - def on_start(self, state: "State", event: Event, **kwargs): + def on_start(self, state: State, event: Event, **kwargs): raise NotImplementedError() - def on_update(self, state: "State", event: Event, **kwargs): + def on_update(self, state: State, event: Event, **kwargs): raise NotImplementedError() - def on_end(self, state: "State", event: Event, **kwargs): + def on_end(self, state: State, event: Event, **kwargs): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index c27c4d2db52..1b3de491504 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -17,21 +17,31 @@ from pydantic import BaseModel, Field +from sparseml.core.event import Event from sparseml.core.modifier.base import ModifierInterface +from sparseml.core.modifier.modifier import Modifier +from sparseml.core.state import State -__all__ = [ - "StageModifier" -] + +__all__ = ["StageModifiers"] class StageModifiers(ModifierInterface, BaseModel): - modifiers: List["Modifier"] = Field(default_factory=list) + modifiers: List[Modifier] = Field(default_factory=list) index: 
int = None group: str = None - _initialized_structure: bool = False - _initialized: bool = False - _finalized: bool = False + @property + def initialized_structure(self) -> bool: + return any(mod.initialized_structure for mod in self.modifiers) + + @property + def initialized(self) -> bool: + return all(mod.initialized for mod in self.modifiers) + + @property + def finalized(self) -> bool: + return all(mod.finalized for mod in self.modifiers) def check_initialized(self): for modifier in self.modifiers: @@ -49,21 +59,18 @@ def calculate_end(self) -> float: mod.calculate_end() for mod in self.modifiers if mod.calculate_end() >= 0 ) - def pre_initialize_structure(self, state: "State", **kwargs): + def pre_initialize_structure(self, state: State, **kwargs): for modifier in self.modifiers: modifier.pre_initialize_structure(state, **kwargs) - self._initialized_structure = True - def initialize(self, state: "State", **kwargs): + def initialize(self, state: State, **kwargs): for modifier in self.modifiers: modifier.initialize(state, **kwargs) - self._initialized = True - def finalize(self, state: "State", **kwargs): + def finalize(self, state: State, **kwargs): for modifier in self.modifiers: modifier.finalize(state, **kwargs) - self._finalized = True - def update_event(self, state: "State", event: "Event", **kwargs): + def update_event(self, state: State, event: Event, **kwargs): for modifier in self.modifiers: modifier.update_event(state, event, **kwargs) diff --git a/src/sparseml/core/optimizer/__init__.py b/src/sparseml/core/optimizer/__init__.py index 07e2638ee13..6ded41b5440 100644 --- a/src/sparseml/core/optimizer/__init__.py +++ b/src/sparseml/core/optimizer/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import ModifiableOptimizer \ No newline at end of file +from .base import ModifiableOptimizer diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py index 0058bf43bf2..41a4238ec1e 100644 --- a/src/sparseml/core/optimizer/base.py +++ b/src/sparseml/core/optimizer/base.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import Any, Generic, List, TypeVar, Union -from sparseml.core.framework import MultiFrameworkObject +from sparseml.core.framework_object import MultiFrameworkObject __all__ = ["ModifiableOptimizer"] @@ -28,7 +28,7 @@ @dataclass class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject): optimizer: OT = None - + def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): self.optimizer = optimizer diff --git a/src/sparseml/core/optimizer/pytorch.py b/src/sparseml/core/optimizer/pytorch.py index 15d7d71b857..a06c9656a54 100644 --- a/src/sparseml/core/optimizer/pytorch.py +++ b/src/sparseml/core/optimizer/pytorch.py @@ -24,7 +24,11 @@ class ModifiableOptimizerPyTorch(ModifiableOptimizer[Optimizer, Dict[str, Any]]): def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): - super().__init__(optimizer=optimizer, attach_optim_callbacks=attach_optim_callbacks, framework=framework) + super().__init__( + optimizer=optimizer, + attach_optim_callbacks=attach_optim_callbacks, + framework=framework, + ) def get_param_groups(self) -> List[Dict[str, Any]]: return self.optimizer.param_groups diff --git a/src/sparseml/core/recipe/__init__.py b/src/sparseml/core/recipe/__init__.py index 09223c3bd12..979d6f9d28a 100644 --- a/src/sparseml/core/recipe/__init__.py +++ b/src/sparseml/core/recipe/__init__.py @@ -12,4 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .recipe import Recipe \ No newline at end of file +from .args import * +from .base import * +from .container import * +from .metadata import * +from .modifier import * +from .recipe import * +from .stage import * diff --git a/src/sparseml/core/recipe/base.py b/src/sparseml/core/recipe/base.py index b781504406c..bbc028d7ef1 100644 --- a/src/sparseml/core/recipe/base.py +++ b/src/sparseml/core/recipe/base.py @@ -13,9 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any -from pydantic import BaseModel, root_validator +from pydantic import BaseModel from sparseml.core.framework import Framework from sparseml.core.recipe.args import RecipeArgs diff --git a/src/sparseml/core/recipe/container.py b/src/sparseml/core/recipe/container.py new file mode 100644 index 00000000000..d3721f277c8 --- /dev/null +++ b/src/sparseml/core/recipe/container.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +from sparseml.core.recipe.recipe import Recipe, RecipeTuple + + +__all__ = ["RecipeContainer"] + + +@dataclass +class RecipeContainer: + compiled_recipe: Optional[Recipe] = None + recipes: List[RecipeTuple] = field(default_factory=list) + + def update( + self, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe_stage: Union[str, List[str]] = None, + recipe_args: Union[Dict[str, Any], List[Dict[str, Any]]] = None, + **kwargs, + ) -> Dict: + if recipe is not None: + self.compiled_recipe = None + + if not isinstance(recipe, list): + recipe = [recipe] + + if recipe_stage is None: + recipe_stage = [None] * len(recipe) + elif not isinstance(recipe_stage, list): + recipe_stage = [recipe_stage] * len(recipe) + + if recipe_args is None: + recipe_args = [{}] * len(recipe) + elif not isinstance(recipe_args, list): + recipe_args = [recipe_args] * len(recipe) + + if len(recipe) != len(recipe_stage) or len(recipe) != len(recipe_args): + raise ValueError( + "recipe, recipe_stage, and recipe_args must be the same length" + ) + + for rec, stage, args in zip(recipe, recipe_stage, recipe_args): + if isinstance(rec, str): + rec = Recipe.create_instance(rec) + self.recipes.append(RecipeTuple(rec, stage, args)) + + return kwargs + + def check_compile_recipe(self) -> bool: + if self.compiled_recipe is None and self.recipes: + self.compiled_recipe = Recipe.simplify_combine_recipes(self.recipes) + + return True + + return False diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 43d27224ce6..108206e61cb 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -17,7 +17,7 @@ from pydantic import root_validator from sparseml.core.framework import Framework -from sparseml.core.modifier import ModifierFactory +from sparseml.core.modifier import Modifier, ModifierFactory from sparseml.core.recipe.args import RecipeArgs from sparseml.core.recipe.base import RecipeBase diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index 668e00764b2..9c2d4a7bee2 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -14,6 +14,7 @@ import json import os +from dataclasses import dataclass from typing import Any, Dict, List, Tuple, Union import yaml @@ -27,7 +28,7 @@ from sparseml.core.recipe.stage import RecipeStage -__all__ = ["Recipe"] +__all__ = ["Recipe", "RecipeTuple"] class Recipe(RecipeBase): @@ -57,8 +58,11 @@ def create_instance(path: str) -> "Recipe": @staticmethod def simplify_recipe( - recipe: "Recipe", stages: List[str], args: Dict[str, Any], shift: int = None + recipe: Union["Recipe", "RecipeTuple"], shift: int = None ) -> "Recipe": + stages = recipe.target_stages if isinstance(recipe, RecipeTuple) else [] + args = recipe.override_args if isinstance(recipe, RecipeTuple) else {} + simplified = Recipe() 
simplified.version = recipe.version simplified.args = recipe.args @@ -74,28 +78,19 @@ def simplify_recipe( @staticmethod def simplify_combine_recipes( - recipes: List[Union["Recipe", Tuple["Recipe", str, Dict[str, Any]]]] + recipes: List[Union["Recipe", "RecipeTuple"]] ) -> "Recipe": - simplified = Recipe() + combined = Recipe() - for recipe_tuple in recipes: - recipe = ( - recipe_tuple[0] if isinstance(recipe_tuple, tuple) else recipe_tuple - ) - stages = ( - recipe_tuple[1].split(",") if isinstance(recipe_tuple, tuple) else None - ) - args = recipe_tuple[2] if isinstance(recipe_tuple, tuple) else None - recipe_simple = Recipe.simplify_recipe( + for recipe in recipes: + simplified = Recipe.simplify_recipe( recipe=recipe, - stages=stages, - args=args, - shift=simplified.calculate_end(), + shift=combined.calculate_end(), ) - simplified.version = recipe_simple.version - simplified.stages.extend(recipe_simple.stages) + combined.version = simplified.version + combined.stages.extend(simplified.stages) - return simplified + return combined version: str = None args: RecipeArgs = None @@ -127,7 +122,7 @@ def create_modifier(self, framework: Framework) -> List[StageModifiers]: modifiers = [] for index, stage in enumerate(self.stages): - stage_modifiers = stage.create_modifiers(framework) + stage_modifiers = stage.create_modifier(framework) stage_modifiers.index = index stage_modifiers.group = stage.group modifiers.append(stage_modifiers) @@ -136,41 +131,98 @@ def create_modifier(self, framework: Framework) -> List[StageModifiers]: @root_validator(pre=True) def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]: - modifiers = RecipeStage._combine_modifiers(values) - stages = [{"modifiers": modifiers, "group": "default"}] if modifiers else [] - add_stages, remove_keys = Recipe._combine_stages(values) - stages.extend(add_stages) + stages = [] - for key in remove_keys: - del values[key] + modifiers = RecipeStage.extract_dict_modifiers(values) + if modifiers: + default_stage = {"modifiers": modifiers, "group": "default"} + stages.append(default_stage) - values["stages"] = Recipe._combine_stages(values) + extracted = Recipe.extract_dict_stages(values) + stages.extend(extracted) + values["stages"] = stages return values + @staticmethod + def extract_dict_stages(values: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Accepted stage formats: + - stages: + first_stage: + modifiers: ... + second_stage: + modifiers: ... + + - first_stage: + modifiers: ... + - second_stage: + modifiers: ... + + Accepted modifier formats default stage: + - modifiers: + - ModifierTypeOne + ... + - ModifierTypeTwo + ... + + - first_modifiers: + - ModifierTypeOne + ... + - ModifierTypeTwo + ... 
+ """ + + stages = [] + remove_keys = [] + + default_modifiers = RecipeStage.extract_dict_modifiers(values) + if default_modifiers: + default_stage = {"modifiers": default_modifiers, "group": "default"} + stages.append(default_stage) + + if "stages" in values and values["stages"]: + assert isinstance( + values["stages"], dict + ), f"stages must be a dict, given {values['stages']}" + remove_keys.append("stages") + + for key, value in values["stages"].items(): + assert isinstance(value, dict), f"stage must be a dict, given {value}" + value["group"] = key + stages.append(value) + + for key, value in list(values.items()): + if key.endswith("_stage"): + remove_keys.append(key) + value["group"] = key.rsplit("_stage", 1)[0] + stages.append(value) + + for key in remove_keys: + del values[key] + + return stages + def dict(self, *args, **kwargs) -> Dict[str, Any]: dict_ = super().dict(*args, **kwargs) + stages = {} for stage in dict_["stages"]: - name = f"{stage['group']}_stage" + name = stage["group"] del stage["group"] - dict_[name] = stage["args"] - del dict_["stages"] + if name not in stages: + stages[name] = [] - return dict_ + stages[name].append(stage) - @staticmethod - def _combine_stages( - values: Dict[str, Any] - ) -> Tuple[List[Dict[str, Any]], List[str]]: - stages = [] - keys = [] + dict_["stages"] = stages + + return dict_ - for key, value in list(values.items()): - if key.endswith("_stage"): - keys.append(key) - value["group"] = key.rsplit("_stage", 1)[0] - stages.append(value) - return stages, keys +@dataclass +class RecipeTuple: + recipe: Recipe + target_stages: Union[str, List[str]] + override_args: Dict[str, Any] diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py index 54effcfc042..37f8a5cf9f0 100644 --- a/src/sparseml/core/recipe/stage.py +++ b/src/sparseml/core/recipe/stage.py @@ -52,7 +52,7 @@ def evaluate(self, parent_args: RecipeArgs = None, shift: int = None): for modifier in self.modifiers: modifier.evaluate(self._args_evaluated, shift) - def create_modifiers( + def create_modifier( self, framework: Framework, parent_args: RecipeArgs = None ) -> StageModifiers: if parent_args is not None: @@ -68,47 +68,74 @@ def create_modifiers( @root_validator(pre=True) def remap_modifiers(cls, values: Dict[str, Any]) -> Dict[str, Any]: - modifiers = [] - add_modifiers, remove_keys = RecipeStage._combine_modifiers(values) - modifiers.extend(add_modifiers) - for key in remove_keys: - del values[key] + modifiers = RecipeStage.extract_dict_modifiers(values) values["modifiers"] = modifiers return values - def dict(self, *args, **kwargs) -> Dict[str, Any]: - dict_ = super().dict(*args, **kwargs) - modifier_groups = dict() - - for modifier in dict_["modifiers"]: - group = modifier["group"] - del modifier["group"] - if group not in modifier_groups: - modifier_groups[group] = [] - modifier_groups[group].append(modifier) - - for group, modifiers in modifier_groups.items(): - name = f"{group}_modifiers" if group != "default" else "modifiers" - dict_[name] = modifiers + @staticmethod + def extract_dict_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Accepted formats: + - modifiers: + - ModifierTypeOne + ... + - ModifierTypeTwo + ... + + - first_modifiers: + - ModifierTypeOne + ... + - ModifierTypeTwo + ... 
+ """ - del dict_["modifiers"] + modifiers = [] + remove_keys = [] - return dict_ + if "modifiers" in values and values["modifiers"]: + assert isinstance( + values["modifiers"], list + ), f"modifiers must be a list, given {values['modifiers']}" + remove_keys.append("modifiers") - @staticmethod - def _combine_modifiers(values: Dict[str, Any]) -> List[Dict[str, Any]]: - modifiers = [] + for modifier in values["stages"]: + assert isinstance( + modifier, dict + ), f"stage must be a dict, given {modifier}" + modifier["group"] = "default" + modifiers.append(modifier) for key, value in list(values.items()): - if key.endswith("_modifiers") or key == "modifiers": - group = ( - key.rsplit("_modifiers", 1)[0] - if key.endswith("_modifiers") - else "default" - ) + if key.endswith("_modifiers"): + assert isinstance( + value, list + ), f"modifier must be a list, given {value}" + remove_keys.append(key) + group = key.rsplit("_modifiers", 1)[0] for modifier in value: + assert isinstance( + modifier, dict + ), f"modifier must be a dict, given {modifier}" modifier["group"] = group modifiers.append(modifier) + for key in remove_keys: + del values[key] + return modifiers + + def dict(self, *args, **kwargs) -> Dict[str, Any]: + dict_ = super().dict(*args, **kwargs) + modifiers = {} + + for modifier in dict_["modifiers"]: + group = modifier["group"] + del modifier["group"] + if group not in modifiers: + modifiers[group] = [] + modifiers[group].append(modifier) + + dict_["modifiers"] = modifiers + + return dict_ diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 1ba946b8c70..9ab56ec0fe1 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -17,12 +17,10 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Union -from sparseml.core.event import ( - CallbacksEventLifecycle, - EventType, - WrappedOptimEventLifecycle, -) +from sparseml.core.event import EventType from sparseml.core.framework import Framework +from sparseml.core.lifecycle import SparsificationLifecycle +from sparseml.core.recipe import Recipe from sparseml.core.state import ModifiedState, State @@ -49,69 +47,45 @@ class _CallbackContainer: class SparseSession: def __init__(self): - self._state: State = State() - self._modifiers: List["StageModifiers"] = [] - self._initialized_structure = False - self._initialized = False - self._finalized = False - self._event_called = False + self._lifecycle = SparsificationLifecycle() @property - def state(self) -> State: - return self._state - - @property - def modifiers(self) -> List["StageModifiers"]: - return self._modifiers + def lifecycle(self) -> SparsificationLifecycle: + return self._lifecycle @property - def initialized_structure(self) -> bool: - return self._initialized_structure - - @property - def initialized(self) -> bool: - return self._initialized - - @property - def finalized(self) -> bool: - return self._finalized - - @property - def event_called(self) -> bool: - return self._event_called + def state(self) -> State: + return self._lifecycle.state def pre_initialize_structure( self, model: Any, - recipe: Union["Recipe", List["Recipe"]], + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, + recipe_stage: Union[str, List[str]] = None, + recipe_args: Union[Dict[str, Any], List[Dict[str, Any]]] = None, framework: Framework = None, **kwargs, ) -> ModifiedState: - self.state.update_framework(framework) - self.state.update_model(model) - self.state.update_recipe(recipe) - - self._check_compile_recipe() - 
modifier_data = [] - - for modifier in self._modifiers: - data = modifier.pre_initialize_structure(state=self.state, **kwargs) - if data: - modifier_data.append(data) - - self._initialized_structure = True + mod_data = self._lifecycle.pre_initialize_structure( + model=model, + recipe=recipe, + recipe_stage=recipe_stage, + recipe_args=recipe_args, + framework=framework, + **kwargs, + ) return ModifiedState( - model=self.state.model.model, + model=self.state.model.model if self.state.model else None, optimizer=None, loss=None, - modifier_data=modifier_data, + modifier_data=mod_data, ) def initialize( self, framework: Framework = None, - recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, @@ -128,59 +102,41 @@ def initialize( batches_per_step: int = None, **kwargs, ) -> ModifiedState: - if self.event_called: - raise ValueError("Cannot initialize after invoking an event") - - if self.finalized: - raise ValueError("Cannot initialize after finalizing") - - self.state.update_framework(framework) - self.state.update_recipe(recipe, recipe_stage, recipe_args) - self.state.update_model(model) - self.state.update_teacher_model(teacher_model) - self.state.update_optimizer(optimizer, attach_optim_callbacks) - self.state.update_data(train_data, val_data, test_data, calib_data, copy_data) - self.state.update_start(start, steps_per_epoch, batches_per_step) - - self._check_compile_recipe() - modifier_data = [] - - if self._modifiers: - for modifier in self._modifiers: - data = modifier.initialize(state=self.state, **kwargs) - if data: - modifier_data.append(data) - - self._initialized = True + mod_data = self._lifecycle.initialize( + framework=framework, + recipe=recipe, + recipe_stage=recipe_stage, + recipe_args=recipe_args, + model=model, + teacher_model=teacher_model, + optimizer=optimizer, + attach_optim_callbacks=attach_optim_callbacks, + train_data=train_data, + val_data=val_data, + test_data=test_data, + calib_data=calib_data, + copy_data=copy_data, + start=start, + steps_per_epoch=steps_per_epoch, + batches_per_step=batches_per_step, + **kwargs, + ) return ModifiedState( - model=self.state.model.model, - optimizer=self.state.optimizer.optimizer, - loss=self.state.loss.loss, - modifier_data=modifier_data, + model=self.state.model.model if self.state.model else None, + optimizer=self.state.optimizer.optimizer if self.state.optimizer else None, + loss=self.state.loss.loss if self.state.loss else None, + modifier_data=mod_data, ) def finalize(self, **kwargs) -> ModifiedState: - if not self.initialized: - raise ValueError("Cannot finalize before initializing") - - if self.finalized: - raise ValueError("Cannot finalize more than once") - - modifier_data = [] - - for modifier in self._modifiers: - data = modifier.finalize(state=self.state, **kwargs) - if data: - modifier_data.append(data) - - self._finalized = True + mod_data = self._lifecycle.finalize(**kwargs) return ModifiedState( - model=self.state.model.model, - optimizer=self.state.optimizer.optimizer, - loss=self.state.loss.loss, - modifier_data=modifier_data, + model=self.state.model.model if self.state.model else None, + optimizer=self.state.optimizer.optimizer if self.state.optimizer else None, + loss=self.state.loss.loss if self.state.loss else None, + modifier_data=mod_data, ) def apply(self, **kwargs): @@ -191,123 +147,19 @@ def apply(self, **kwargs): def event( self, event_type: EventType, 
batch_data: Any = None, loss: Any = None, **kwargs ) -> ModifiedState: - if not self.initialized: - raise ValueError("Cannot invoke event before initializing") - - if self.finalized: - raise ValueError("Cannot invoke event after finalizing") - - if event_type in [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE]: - raise ValueError( - f"Cannot invoke {event_type} event. " - f"Use the corresponding method instead." - ) - - if event_type == EventType.LOSS_CALCULATED and loss is None: - raise ValueError("Loss must be provided for loss calculated event") - - self._check_setup_lifecycle(event_type) - - event = None - modifier_data = [] - for event in self.state.event_lifecycle.events_from_type(event_type): - for modifier in self._modifiers: - data = modifier.update_event( - state=self.state, - event=event, - batch_data=batch_data, - loss=loss, - **kwargs, - ) - if data: - modifier_data.append(data) - - assert event is not None, f"No events generated for event type {event_type}" - self.state.last_event = event - self._event_called = True + mod_data = self._lifecycle.event( + event_type=event_type, batch_data=batch_data, loss=loss, **kwargs + ) return ModifiedState( - model=self.state.model.model, - optimizer=self.state.optimizer.optimizer, - loss=self.state.loss.loss, - modifier_data=modifier_data, + model=self.state.model.model if self.state.model else None, + optimizer=self.state.optimizer.optimizer if self.state.optimizer else None, + loss=self.state.loss.loss if self.state.loss else None, + modifier_data=mod_data, ) def reset(self): - if self._state: - del self._state - self._state = State() - - if self._modifiers: - if self.initialized and not self.finalized: - for modifier in self._modifiers: - modifier.finalize(self.state) - - del self._modifiers - - self._modifiers = [] - self._initialized_structure = False - self._initialized = False - self._finalized = False - self._event_called = False - - def _check_compile_recipe(self): - if not self.state.recipe_changed and self._modifiers is not None: - # recipe hasn't changed and modifiers set, no need to recompile - return - - if self.state.recipes is None: - # no recipes currently, return - return - - if self.state.recipe_changed: - self.state.recompile_recipe() - - if self._modifiers: - # clear out the modifiers to reinitialize from newly compiled recipe - for modifier in self._modifiers: - if modifier._initialized: - modifier.finalize(self.state) - del self._modifiers - - if self.state.recipe_modifier_ready: - self._modifiers = self.state.compiled_recipe.create_modifier( - self.state.framework - ) - - def _check_setup_lifecycle(self, event_type: EventType): - if self.state.event_lifecycle is not None: - return - - # first event call, setup lifecycle and make sure everything is initialized - if not self.state.recipe_modifier_ready: - raise ValueError( - "Cannot invoke event before recipe, model, and start are set" - ) - - for modifier in self._modifiers: - modifier.check_initialized() - - if event_type == EventType.BATCH_START: - # utilizing callbacks pathway, ensure optim is not wrapped - if self.state.optim_wrapped: - raise ValueError( - "Cannot use batch callbacks with wrapped optimizer, " - "set attach_optim_callbacks to False when initializing " - ) - self.state.event_lifecycle = CallbacksEventLifecycle( - event_type, self.state.start_event - ) - elif self.state.optim_wrapped: - # utilizing wrapped optimizer for callbacks - self.state.event_lifecycle = WrappedOptimEventLifecycle( - event_type, self.state.start_event - ) - 
else: - raise ValueError( - "First event must be batch_start or " - "attach_optim_callbacks must be True" - ) + self._lifecycle.reset() _global_session = SparseSession() @@ -338,7 +190,7 @@ def pre_initialize_structure(**kwargs): def initialize( framework: Framework = None, - recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, @@ -382,7 +234,7 @@ def finalize(**kwargs) -> ModifiedState: def apply( framework: Framework = None, - recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, + recipe: Union[str, List[str], Recipe, List[Recipe]] = None, recipe_stage: str = None, recipe_args: Dict[str, Any] = None, model: Any = None, diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index 12d5f793aee..51db318282f 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -14,16 +14,15 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List from pydantic import Field from sparseml.core.data import ModifiableData -from sparseml.core.event import Event, EventLifecycle +from sparseml.core.event import Event from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.core.optimizer import ModifiableOptimizer -from sparseml.core.recipe import Recipe __all__ = ["State", "Data", "Hardware", "ModifiedState"] @@ -51,122 +50,81 @@ class Hardware: @dataclass class State: - compiled_recipe: Recipe = None - recipes: List[Tuple[Recipe, str, Dict[str, Any]]] = Field(default_factory=list) - loggers = Field(default_factory=list) - framework: Framework = None + framework: Framework model: ModifiableModel = None teacher_model: ModifiableModel = None optimizer: ModifiableOptimizer = None optim_wrapped: bool = None - loss = None - batch_data = None + loss: Any = None + batch_data: Any = None data = Data() hardware = Hardware() - event_lifecycle: EventLifecycle = None start_event: Event = None last_event: Event = None - _recipe_changed: bool = False - - @property - def recipe_changed(self) -> bool: - return self._recipe_changed + loggers = Field(default_factory=list) @property - def recipe_modifier_ready(self) -> bool: + def sparsification_ready(self) -> bool: return ( - self.compiled_recipe is not None - and self.model is not None - and self.start_event is not None - ) - - def update_framework(self, framework: Framework): - self.framework = framework if framework else Framework.pytorch - - def update_recipe( - self, - recipe: Union[str, List[str], Recipe, List[Recipe]] = None, - recipe_stage: str = None, - recipe_args: Dict[str, Any] = None, - ): - if recipe is None: - return - - if not isinstance(recipe, list): - recipe = [recipe] - - for rec in recipe: - if isinstance(rec, str): - rec = Recipe.create_instance(rec) - - self.recipes.append((rec, recipe_stage, recipe_args)) - - self._recipe_changed = True - - def update_model(self, model: Any): - if self.framework is None: - raise RuntimeError("framework must be set before updating model") - - self.model = ModifiableModel(framework=self.framework, model=model) - - def update_teacher_model(self, model: Any): - if self.framework is None: - raise RuntimeError("framework must be set before updating model") - - self.teacher_model = ModifiableModel(framework=self.framework, model=model) - - def update_optimizer(self, optimizer: Any, 
attach_callbacks: bool = True): - if self.framework is None: - raise RuntimeError("framework must be set before updating optimizer") - - self.optim_wrapped = attach_callbacks - self.optimizer = ModifiableOptimizer( - framework=self.framework, optimizer=optimizer + self.model is not None + and self.optimizer is not None + and self.loss is not None + and self.batch_data is not None ) - def update_data( + def update( self, + model: Any = None, + teacher_model: Any = None, + optimizer: Any = None, + attach_optim_callbacks: bool = True, train_data: Any = None, val_data: Any = None, test_data: Any = None, calib_data: Any = None, copy_data: bool = True, - ): - if self.framework is None: - raise RuntimeError("framework must be set before updating data") - - self.data = ModifiableData(framework=self.framework) + start: float = None, + steps_per_epoch: int = None, + batches_per_step: int = None, + **kwargs, + ) -> Dict: + if model is not None: + self.model = ModifiableModel(framework=self.framework, model=model) + if teacher_model is not None: + self.teacher_model = ModifiableModel( + framework=self.framework, model=teacher_model + ) + if optimizer is not None: + self.optim_wrapped = attach_optim_callbacks + self.optimizer = ModifiableOptimizer( + framework=self.framework, optimizer=optimizer + ) if train_data is not None: self.data.train = train_data if not copy_data else deepcopy(train_data) - if val_data is not None: self.data.val = val_data if not copy_data else deepcopy(val_data) - if test_data is not None: self.data.test = test_data if not copy_data else deepcopy(test_data) - if calib_data is not None: self.data.calib = calib_data if not copy_data else deepcopy(calib_data) - def update_start( - self, - start: float = None, - steps_per_epoch: int = None, - batches_per_step: int = None, - ): - self.start_event = Event() - self.start_event.steps_per_epoch = steps_per_epoch - self.start_event.batches_per_step = batches_per_step - self.start_event.current_index = start if start is not None else 0 - - def recompile_recipe(self): - self._recipe_changed = False - - if not self.recipes: - raise RuntimeError("No recipes to compile") - - self.compiled_recipe = Recipe.simplify_combine_recipes(self.recipes) + if ( + start is not None + or steps_per_epoch is not None + or batches_per_step is not None + ): + if self.start_event is None: + self.start_event = Event() + + if start is not None: + self.start_event.current_index = start + if steps_per_epoch is not None: + self.start_event.steps_per_epoch = steps_per_epoch + if batches_per_step is not None: + self.start_event.batches_per_step = batches_per_step + + return kwargs @dataclass From 7ecd5c613e31526cedae09e1ea2041916d313a88 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 19 Sep 2023 09:35:09 -0400 Subject: [PATCH 13/27] modifiers loading in stages --- src/sparseml/core/__init__.py | 6 +++- src/sparseml/core/framework.py | 8 ++++- src/sparseml/core/modifier/__init__.py | 3 +- src/sparseml/core/modifier/base.py | 3 -- src/sparseml/core/modifier/factory.py | 8 ++--- src/sparseml/core/modifier/modifier.py | 42 +++++++++++++------------- src/sparseml/core/modifier/stage.py | 12 ++++---- 7 files changed, 45 insertions(+), 37 deletions(-) diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py index bc26ded926c..d687e5180f9 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -12,4 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
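Taken together, the session facade and the new lifecycle reduce the intended top-level flow to something like the sketch below. This is a hypothetical usage example, not code from the patch: the recipe path and override args are placeholders, and it assumes the module-level helpers in sparseml.core.session forward their keyword arguments to SparseSession.initialize unchanged.

    from torch.nn import Linear
    from torch.optim import SGD

    from sparseml.core.framework import Framework
    from sparseml.core.session import finalize, initialize

    model = Linear(8, 8)
    optimizer = SGD(model.parameters(), lr=0.1)

    # recipe may be a path, a Recipe instance, or a list of either; recipe_stage
    # and recipe_args are normalized and broadcast across the recipes by
    # RecipeContainer.update() before compilation.
    initialize(
        framework=Framework.pytorch,
        recipe="recipe.yaml",            # placeholder path
        recipe_args={"num_epochs": 10},  # placeholder override
        model=model,
        optimizer=optimizer,
        attach_optim_callbacks=False,
        start=0.0,
        steps_per_epoch=100,
    )
    # ... per-batch events as in the earlier training-loop sketch ...
    finalize()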
-from .session import * \ No newline at end of file +from .session import * +from .modifier import * +from .state import * +from .event import * +from .model import * \ No newline at end of file diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py index c7a67fe85cd..e85a2438973 100644 --- a/src/sparseml/core/framework.py +++ b/src/sparseml/core/framework.py @@ -15,6 +15,7 @@ import importlib from enum import Enum +from pydantic import ValidationError __all__ = ["Framework", "MultiFrameworkObject"] @@ -100,8 +101,13 @@ def __new__( return MultiFrameworkObject.load_framework_class( f"{package}.{str(framework)}", f"{class_name}{framework.class_name()}" )(**kwargs) - except ImportError: + except ImportError as e: + print(e) pass + except ValidationError as e: + print(e) + print(e.errors()) + # fall back on the class that was requested and # fail later if it doesn't support that framework diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py index af25a2b2db2..a21bf075123 100644 --- a/src/sparseml/core/modifier/__init__.py +++ b/src/sparseml/core/modifier/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .stage import StageModifiers -from .factory import ModifierFactory \ No newline at end of file +from .factory import ModifierFactory +from .modifier import Modifier \ No newline at end of file diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py index cbdf2a002b6..147db8ca557 100644 --- a/src/sparseml/core/modifier/base.py +++ b/src/sparseml/core/modifier/base.py @@ -19,9 +19,6 @@ class ModifierInterface(ABC): - def __init__(self, **kwargs): - pass - @abstractmethod def check_initialized(self): raise NotImplementedError() diff --git a/src/sparseml/core/modifier/factory.py b/src/sparseml/core/modifier/factory.py index 388810e56a9..4a5e3442024 100644 --- a/src/sparseml/core/modifier/factory.py +++ b/src/sparseml/core/modifier/factory.py @@ -13,7 +13,6 @@ # limitations under the License. 
import inspect -#import sparseml.modifiers as modifiers __all__ = ["ModifierFactory"] @@ -25,6 +24,7 @@ def refresh(): @staticmethod def create(type_: str, framework: "Framework", **kwargs) -> "Modifier": - raise NotImplementedError() - #for name, obj in inspect.getmembers(modifiers): - # print(name, obj) + import sparseml.modifiers as modifiers + for name, obj in inspect.getmembers(modifiers): + if name == type_: + return obj(framework=framework, **kwargs) diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 6c1c6d92df6..61022de3d9a 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -31,14 +31,14 @@ class Modifier(ModifierInterface, MultiFrameworkObject, BaseModel): end: Optional[float] = None update: Optional[float] = None - _initialized_structure: bool = False - _initialized: bool = False - _finalized: bool = False - _started: bool = False - _ended: bool = False + initialized_structure_: bool = False + initialized: bool = False + finalized_: bool = False + started_: bool = False + ended_: bool = False def check_initialized(self): - if not self._initialized: + if not self.initialized: raise RuntimeError("modifier has not been initialized") def calculate_start(self) -> float: @@ -49,13 +49,13 @@ def calculate_end(self) -> float: def pre_initialize_structure(self, state: "State", **kwargs): self.on_initialize_structure(state, **kwargs) - self._initialized_structure = True + self.initialized_structure_ = True def initialize(self, state: "State", **kwargs): - if self._initialized: + if self.initialized: return - if self._finalized: + if self.finalized_: raise RuntimeError("cannot initialize a finalized modifier") if state.start_event is None: @@ -69,17 +69,17 @@ def initialize(self, state: "State", **kwargs): "True for success, False for not initialized" ) - self._initialized = initialized + self.initialized = initialized if self.should_start(state.start_event): self.on_start(state, state.start_event, **kwargs) - self._started = True + self.started_ = True def finalize(self, state: "State", **kwargs): - if self._finalized: + if self.finalized_: return - if not self._initialized: + if not self.initialized: raise RuntimeError("cannot finalize an uninitialized modifier") finalized = self.on_finalize(**kwargs) @@ -90,23 +90,23 @@ def finalize(self, state: "State", **kwargs): "True for success, False for not finalized" ) - self._finalized = finalized + self.finalized_ = finalized def update_event(self, state: "State", event: Event, **kwargs): - if not self._initialized: + if not self.initialized: raise RuntimeError("cannot update an uninitialized modifier") - if self._finalized: + if self.finalized_: raise RuntimeError("cannot update a finalized modifier") # handle starting the modifier if needed if ( event.type_ == EventType.BATCH_START - and not self._started + and not self.started_ and self.should_start(event) ): self.on_start(state, event, **kwargs) - self._started = True + self.started_ = True self.on_update(state, event, **kwargs) return @@ -114,16 +114,16 @@ def update_event(self, state: "State", event: Event, **kwargs): # handle ending the modifier if needed if ( event.type_ == EventType.BATCH_END - and not self._ended + and not self.ended_ and self.should_end(event) ): self.on_end(state, event, **kwargs) - self._ended = True + self.ended_ = True self.on_update(state, event, **kwargs) return - if self._started and not self._ended: + if self.started_ and not self.ended_: self.on_update(state, 
event, **kwargs) def should_start(self, event: Event): diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index 0bf0476eac5..23ac3e6cf61 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -29,9 +29,9 @@ class StageModifiers(ModifierInterface, BaseModel): index: int = None group: str = None - _initialized_structure: bool = False - _initialized: bool = False - _finalized: bool = False + initialized_structure_: bool = False + initialized_: bool = False + finalized_: bool = False def check_initialized(self): for modifier in self.modifiers: @@ -52,17 +52,17 @@ def calculate_end(self) -> float: def pre_initialize_structure(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.pre_initialize_structure(state, **kwargs) - self._initialized_structure = True + self.initialized_structure_ = True def initialize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.initialize(state, **kwargs) - self._initialized = True + self.initialized_ = True def finalize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.finalize(state, **kwargs) - self._finalized = True + self.finalized_ = True def update_event(self, state: "State", event: "Event", **kwargs): for modifier in self.modifiers: From 3e2954e667cdd278c84e11ec86d6f330d28984f8 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 19 Sep 2023 09:35:30 -0400 Subject: [PATCH 14/27] adding test files --- test_e2e.ipynb | 359 ++++++------------------------------------- test_e2e.py | 58 +++++++ test_e2e_recipe.yaml | 4 +- 3 files changed, 109 insertions(+), 312 deletions(-) create mode 100644 test_e2e.py diff --git a/test_e2e.ipynb b/test_e2e.ipynb index 9941d4d775d..22f4984336d 100644 --- a/test_e2e.ipynb +++ b/test_e2e.ipynb @@ -43,11 +43,41 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset beans (/home/sadkins/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6e77f370860f472394ab3547decb1fb7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00 1\u001b[0m session\u001b[39m.\u001b[39;49minitialize(\n\u001b[1;32m 2\u001b[0m framework\u001b[39m=\u001b[39;49mFramework\u001b[39m.\u001b[39;49mpytorch,\n\u001b[1;32m 3\u001b[0m recipe\u001b[39m=\u001b[39;49mrecipe,\n\u001b[1;32m 4\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 5\u001b[0m teacher_model\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 6\u001b[0m optimizer\u001b[39m=\u001b[39;49moptimizer,\n\u001b[1;32m 7\u001b[0m train_data\u001b[39m=\u001b[39;49mtrain_loader,\n\u001b[1;32m 8\u001b[0m val_data\u001b[39m=\u001b[39;49mval_loader\n\u001b[1;32m 9\u001b[0m )\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/session.py:145\u001b[0m, in \u001b[0;36mSparseSession.initialize\u001b[0;34m(self, framework, recipe, recipe_stage, recipe_args, model, teacher_model, optimizer, attach_optim_callbacks, train_data, val_data, test_data, calib_data, copy_data, start, steps_per_epoch, batches_per_step, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mupdate_data(train_data, val_data, test_data, calib_data, copy_data)\n\u001b[1;32m 143\u001b[0m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mupdate_start(start, steps_per_epoch, batches_per_step)\n\u001b[0;32m--> 145\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_check_compile_recipe()\n\u001b[1;32m 146\u001b[0m modifier_data \u001b[39m=\u001b[39m []\n\u001b[1;32m 148\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modifiers:\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/session.py:274\u001b[0m, in \u001b[0;36mSparseSession._check_compile_recipe\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[39mdel\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modifiers\n\u001b[1;32m 273\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mrecipe_modifier_ready:\n\u001b[0;32m--> 274\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modifiers \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstate\u001b[39m.\u001b[39;49mcompiled_recipe\u001b[39m.\u001b[39;49mcreate_modifier(\n\u001b[1;32m 275\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstate\u001b[39m.\u001b[39;49mframework\n\u001b[1;32m 276\u001b[0m )\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/recipe/recipe.py:133\u001b[0m, in \u001b[0;36mRecipe.create_modifier\u001b[0;34m(self, framework)\u001b[0m\n\u001b[1;32m 130\u001b[0m modifiers \u001b[39m=\u001b[39m []\n\u001b[1;32m 132\u001b[0m \u001b[39mfor\u001b[39;00m index, stage \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstages):\n\u001b[0;32m--> 133\u001b[0m stage_modifiers \u001b[39m=\u001b[39m stage\u001b[39m.\u001b[39;49mcreate_modifier(framework)\n\u001b[1;32m 134\u001b[0m stage_modifiers\u001b[39m.\u001b[39mindex \u001b[39m=\u001b[39m index\n\u001b[1;32m 135\u001b[0m stage_modifiers\u001b[39m.\u001b[39mgroup \u001b[39m=\u001b[39m stage\u001b[39m.\u001b[39mgroup\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/recipe/stage.py:66\u001b[0m, in \u001b[0;36mRecipeStage.create_modifier\u001b[0;34m(self, framework, parent_args)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[39mfor\u001b[39;00m index, modifier \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodifiers):\n\u001b[1;32m 65\u001b[0m modifier \u001b[39m=\u001b[39m modifier\u001b[39m.\u001b[39mcreate_modifier(framework)\n\u001b[0;32m---> 66\u001b[0m modifier\u001b[39m.\u001b[39;49mgroup \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgroup\n\u001b[1;32m 67\u001b[0m modifier\u001b[39m.\u001b[39mindex \u001b[39m=\u001b[39m index\n\u001b[1;32m 69\u001b[0m \u001b[39mreturn\u001b[39;00m stage_modifiers\n", + "File \u001b[0;32m~/sparseml/.venv/lib/python3.8/site-packages/pydantic/main.py:405\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__setattr__\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: __fields_set__" + ] } ], - "source": [ - "session.pre_initialize_structure(\n", - " framework=Framework.pytorch,\n", - " recipe=recipe,\n", - " model=model\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "session.initialize(\n", " framework=Framework.pytorch,\n", diff --git a/test_e2e.py b/test_e2e.py new file mode 100644 index 00000000000..3a6cdd2249e --- /dev/null +++ b/test_e2e.py @@ -0,0 +1,58 @@ +import sparseml.core.session as sml +from sparseml.core.framework import Framework +import torchvision +from torchvision import 
transforms +import torch +from torch.utils.data import DataLoader +import datasets +import os +from torch.optim import Adam + +sml.create_session() +session = sml.active_session() + +NUM_LABELS = 3 +model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT) +model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS) +optimizer = Adam(model.parameters(), lr=8e-3) + +train_path = "/home/sadkins/.cache/huggingface/datasets/downloads/extracted/dbf92bfb2c3766fb3083a51374ad94d8a3690f53cdf0f9113a231c2351c9ff33/train" +val_path = "/home/sadkins/.cache/huggingface/datasets/downloads/extracted/510ede718de2aeaa2f9d88b0d81d88c449beeb7d074ea594bdf25a0e6a9d51d0/validation" + +NUM_LABELS = 3 +BATCH_SIZE = 32 + +# imagenet transforms +imagenet_transform = transforms.Compose([ + transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None), + transforms.CenterCrop(size=(224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +# datasets +train_dataset = torchvision.datasets.ImageFolder( + root=train_path, + transform=imagenet_transform +) + +val_dataset = torchvision.datasets.ImageFolder( + root=val_path, + transform=imagenet_transform +) + +# dataloaders +train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16) +val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16) + +recipe = "test_e2e_recipe.yaml" + +session.initialize( + framework=Framework.pytorch, + recipe=recipe, + model=model, + teacher_model=None, + optimizer=optimizer, + train_data=train_loader, + val_data=val_loader +) \ No newline at end of file diff --git a/test_e2e_recipe.yaml b/test_e2e_recipe.yaml index 558990df149..9618859f751 100644 --- a/test_e2e_recipe.yaml +++ b/test_e2e_recipe.yaml @@ -6,7 +6,7 @@ test_stage: start_epoch: 1.0 end_epoch: 10.0 update_frequency: 0.5 - params: + targets: - 'features.0.0.weight' - 'features.18.0.weight' - 're:features.*.conv.*.weight' @@ -20,7 +20,7 @@ test2_stage: start_epoch: 1.0 end_epoch: 10.0 update_frequency: 0.5 - params: + targets: - 'features.0.0.weight' - 'features.18.0.weight' - 're:features.*.conv.*.weight' From 6b83b027e434c54258118e740b5a44d712a22e89 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Tue, 19 Sep 2023 13:42:52 -0400 Subject: [PATCH 15/27] modifier factory implementation --- src/sparseml/core/__init__.py | 1 + src/sparseml/core/factory.py | 133 ++++++++++++++++++ src/sparseml/core/modifier/__init__.py | 1 - src/sparseml/core/modifier/factory.py | 26 ---- src/sparseml/core/recipe/modifier.py | 11 +- .../modifiers/distillation/__init__.py | 2 + .../modifiers/distillation/output/__init__.py | 2 + 7 files changed, 147 insertions(+), 29 deletions(-) create mode 100644 src/sparseml/core/factory.py delete mode 100644 src/sparseml/core/modifier/factory.py diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py index 3d69db68239..6e9d92e657c 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -14,6 +14,7 @@ from .data import * from .event import * +from .factory import * from .framework import * from .framework_object import * from .lifecycle import * diff --git a/src/sparseml/core/factory.py b/src/sparseml/core/factory.py new file mode 100644 index 00000000000..1881592f55e --- /dev/null +++ b/src/sparseml/core/factory.py @@ -0,0 +1,133 @@ +# Copyright (c) 2021 - present / 
Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import pkgutil +from typing import Dict, Type + +from sparseml.core.framework import Framework +from sparseml.core.modifier import Modifier + + +__all__ = ["ModifierFactory"] + + +class ModifierFactory: + _MAIN_PACKAGE_PATH = "sparseml.modifiers" + _EXPERIMENTAL_PACKAGE_PATH = "sparseml.modifiers.experimental" + + _loaded: bool = False + _main_registry: Dict[str, Type[Modifier]] = {} + _experimental_registry: Dict[str, Type[Modifier]] = {} + _registered_registry: Dict[str, Type[Modifier]] = {} + _errors: Dict[str, Exception] = {} + + @staticmethod + def refresh(): + ModifierFactory._main_registry = ModifierFactory.load_from_package( + ModifierFactory._MAIN_PACKAGE_PATH + ) + ModifierFactory._experimental_registry = ModifierFactory.load_from_package( + ModifierFactory._EXPERIMENTAL_PACKAGE_PATH + ) + ModifierFactory._loaded = True + + @staticmethod + def load_from_package(package_path: str) -> Dict[str, Type[Modifier]]: + loaded = {} + main_package = importlib.import_module(package_path) + + for importer, modname, is_pkg in pkgutil.walk_packages( + main_package.__path__, package_path + "." + ): + try: + module = importlib.import_module(modname) + + for attribute_name in dir(module): + if not attribute_name.endswith("Modifier"): + continue + + try: + if attribute_name in loaded: + raise ValueError( + f"Attribute {attribute_name} already registered" + ) + + attr = getattr(module, attribute_name) + + if not isinstance(attr, type): + raise ValueError( + f"Attribute {attribute_name} is not a type" + ) + + if not issubclass(attr, Modifier): + raise ValueError( + f"Attribute {attribute_name} is not a Modifier" + ) + + loaded[attribute_name] = attr + except Exception as err: + # TODO: log import error + ModifierFactory._errors[attribute_name] = err + except Exception as module_err: + # TODO: log import error + print(module_err) + + return loaded + + @staticmethod + def create( + type_: str, + framework: Framework, + allow_registered: bool, + allow_experimental: bool, + **kwargs, + ) -> Modifier: + if type_ in ModifierFactory._errors: + raise ModifierFactory._errors[type_] + + if type_ in ModifierFactory._registered_registry: + if allow_registered: + return ModifierFactory._registered_registry[type_]( + framework=framework, **kwargs + ) + else: + # TODO: log warning that modifier was skipped + pass + + if type_ in ModifierFactory._experimental_registry: + if allow_experimental: + return ModifierFactory._experimental_registry[type_]( + framework=framework, **kwargs + ) + else: + # TODO: log warning that modifier was skipped + pass + + if type_ in ModifierFactory._main_registry: + return ModifierFactory._main_registry[type_](framework=framework, **kwargs) + + raise ValueError(f"No modifier of type '{type_}' found.") + + @staticmethod + def register(type_: str, modifier_class: Type[Modifier]): + if not issubclass(modifier_class, Modifier): + raise ValueError( + f"The provided class does not 
subclass the Modifier base class." + ) + if not isinstance(modifier_class, type): + raise ValueError(f"The provided class is not a type.") + + ModifierFactory._registered_registry[type_] = modifier_class diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py index 2da941711e8..b205e585dbf 100644 --- a/src/sparseml/core/modifier/__init__.py +++ b/src/sparseml/core/modifier/__init__.py @@ -13,6 +13,5 @@ # limitations under the License. from .base import * -from .factory import * from .modifier import * from .stage import * diff --git a/src/sparseml/core/modifier/factory.py b/src/sparseml/core/modifier/factory.py deleted file mode 100644 index 704021fd087..00000000000 --- a/src/sparseml/core/modifier/factory.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -__all__ = ["ModifierFactory"] - - -class ModifierFactory: - @staticmethod - def refresh(): - raise NotImplementedError() - - @staticmethod - def create(type_: str, framework: "Framework", **kwargs) -> "Modifier": - raise NotImplementedError() diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 108206e61cb..02fd9616a9c 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -16,8 +16,9 @@ from pydantic import root_validator +from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.core.modifier import Modifier, ModifierFactory +from sparseml.core.modifier import Modifier from sparseml.core.recipe.args import RecipeArgs from sparseml.core.recipe.base import RecipeBase @@ -57,7 +58,13 @@ def evaluate(self, args: RecipeArgs = None, shift: int = None): self._args_evaluated["end"] += shift def create_modifier(self, framework: Framework) -> "Modifier": - return ModifierFactory.create(self.type, framework, **self._args_evaluated) + return ModifierFactory.create( + self.type, + framework=framework, + allow_registered=True, + allow_experimental=True, + **self._args_evaluated, + ) @root_validator(pre=True) def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/sparseml/modifiers/distillation/__init__.py b/src/sparseml/modifiers/distillation/__init__.py index 0c44f887a47..6a7699b2f53 100644 --- a/src/sparseml/modifiers/distillation/__init__.py +++ b/src/sparseml/modifiers/distillation/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from .output import * diff --git a/src/sparseml/modifiers/distillation/output/__init__.py b/src/sparseml/modifiers/distillation/output/__init__.py index 0c44f887a47..87930811c41 100644 --- a/src/sparseml/modifiers/distillation/output/__init__.py +++ b/src/sparseml/modifiers/distillation/output/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .base import * From e85772934cf45abb23f35ce938746436cb04959e Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 19 Sep 2023 17:00:50 -0400 Subject: [PATCH 16/27] running example, but sparsity not working correctly --- src/sparseml/core/__init__.py | 1 + src/sparseml/core/lifecycle/session.py | 12 +- src/sparseml/core/model/pytorch.py | 6 +- src/sparseml/core/modifier/modifier.py | 2 +- src/sparseml/core/recipe/container.py | 5 - src/sparseml/core/recipe/modifier.py | 2 +- src/sparseml/core/recipe/recipe.py | 17 +-- src/sparseml/core/recipe/stage.py | 1 + src/sparseml/core/state.py | 7 +- .../modifiers/distillation/output/pytorch.py | 4 +- .../modifiers/pruning/constant/pytorch.py | 14 +-- .../modifiers/pruning/magnitude/pytorch.py | 40 +++---- .../pruning/utils/pytorch/layer_mask.py | 33 +++--- src/sparseml/utils/pytorch/module.py | 19 ++++ test_e2e.ipynb | 105 ++++++++++++++---- test_e2e.py | 62 ++++++++++- test_e2e_recipe.yaml | 75 +++++++++---- 17 files changed, 294 insertions(+), 111 deletions(-) diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py index aafe1c21ca4..3fd99ea93d6 100644 --- a/src/sparseml/core/__init__.py +++ b/src/sparseml/core/__init__.py @@ -21,3 +21,4 @@ from .modifier import * from .optimizer import * from .recipe import * +from .state import * diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py index ec5a4fd60fb..c7229d23ca6 100644 --- a/src/sparseml/core/lifecycle/session.py +++ b/src/sparseml/core/lifecycle/session.py @@ -176,7 +176,7 @@ def _check_setup_event_lifecycle(self, event_type: EventType): if ( self.state is None or self.state.model is None - or self.state.start_event + or self.state.start_event is None or self.recipe_container.compiled_recipe is None ): raise ValueError( @@ -192,12 +192,20 @@ def _check_setup_event_lifecycle(self, event_type: EventType): mod.check_initialized() if event_type == EventType.BATCH_START: - self.event_lifecycle = WrappedOptimEventLifecycle( + self.event_lifecycle = CallbacksEventLifecycle( type_first=EventType.BATCH_START, start=self.state.start_event ) elif event_type == EventType.LOSS_CALCULATED: self.event_lifecycle = CallbacksEventLifecycle( type_first=EventType.LOSS_CALCULATED, start=self.state.start_event ) + elif event_type == EventType.OPTIM_PRE_STEP: + self.event_lifecycle = CallbacksEventLifecycle( + type_first=EventType.OPTIM_PRE_STEP, start=self.state.start_event + ) + elif event_type == EventType.OPTIM_POST_STEP: + self.event_lifecycle = CallbacksEventLifecycle( + type_first=EventType.OPTIM_POST_STEP, start=self.state.start_event + ) else: raise ValueError(f"invalid event type {event_type}") diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index c2d12065faf..ca17612b8b2 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -16,7 +16,7 @@ from torch.nn import Module, Parameter -from sparseml.core.model.base import ModifiableModel +from sparseml.core.model.base import 
ModelParameterizedLayer, ModifiableModel from sparseml.utils.pytorch import ( get_layer, get_layers, @@ -24,6 +24,7 @@ get_params, set_layer, set_param, + get_layers_params ) @@ -34,6 +35,9 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): def __init__(self, framework=None, model=None): super().__init__(framework=framework, model=model) + def get_layers_params(self, targets: Union[str, List[str]]) -> Dict[str, ModelParameterizedLayer[Module, Parameter]]: + return get_layers_params(targets, self.model) + def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, Module]: return get_layers(targets, self.model) diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 9f04e2b1b5c..5bdecb54b03 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -73,7 +73,7 @@ def initialize(self, state: "State", **kwargs): if state.start_event is None: return - initialized = self.on_initialize(**kwargs) + initialized = self.on_initialize(state=state, event=state.start_event, **kwargs) if not isinstance(initialized, bool): raise ValueError( diff --git a/src/sparseml/core/recipe/container.py b/src/sparseml/core/recipe/container.py index 59b645aa272..fc828796bc6 100644 --- a/src/sparseml/core/recipe/container.py +++ b/src/sparseml/core/recipe/container.py @@ -53,7 +53,6 @@ def update( if not isinstance(recipe, list): recipe = [recipe] - if recipe_stage is None: recipe_stage = [None] * len(recipe) elif not isinstance(recipe_stage, list): @@ -72,10 +71,6 @@ def update( for rec, stage, args in zip(recipe, recipe_stage, recipe_args): if isinstance(rec, str): rec = Recipe.create_instance(rec) - if isinstance(stage, str): - stage = [stage] - elif stage is None: - stage = [] self.recipes.append(RecipeTuple(rec, stage, args)) return kwargs diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index 1d13d7a19c2..5c6d4c60f1b 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -57,7 +57,7 @@ def evaluate(self, args: RecipeArgs = None, shift: int = None): self.args_evaluated["end"] += shift def create_modifier(self, framework: Framework) -> "Modifier": - return ModifierFactory.create(self.type, framework, **self._args_evaluated) + return ModifierFactory.create(self.type, framework, **self.args_evaluated) @root_validator(pre=True) def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index d83441baf61..49bcbc17c31 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -59,19 +59,22 @@ def create_instance(path: str) -> "Recipe": def simplify_recipe( recipe: Union["Recipe", "RecipeTuple"], shift: int = None ) -> "Recipe": - stages = recipe.target_stages if isinstance(recipe, RecipeTuple) else [] + stages = [] + if isinstance(recipe, RecipeTuple): + stage_names = recipe.target_stages + if stage_names is None: + stages = recipe.recipe.stages + else: + for stage in recipe.recipe.stages: + if stage.group in stage_names: + stages.append(stage) args = recipe.override_args if isinstance(recipe, RecipeTuple) else {} version = recipe.version if isinstance(recipe, Recipe) else None simplified = Recipe() simplified.version = version simplified.args = args - simplified.stages = [ - stage - for stage in stages - if ((not stages or "default" in stages) and not stage.exclude_default) - or stage.group in 
stages - ] + simplified.stages = stages simplified.evaluate(args=args, shift=shift) return simplified diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py index 2a36ee8e712..e03cc86cbe4 100644 --- a/src/sparseml/core/recipe/stage.py +++ b/src/sparseml/core/recipe/stage.py @@ -65,6 +65,7 @@ def create_modifier( modifier = modifier.create_modifier(framework) modifier.group = self.group modifier.index = index + stage_modifiers.modifiers.append(modifier) return stage_modifiers diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index 3ab2d848337..a8bf937bf0a 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -23,6 +23,7 @@ from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.core.optimizer import ModifiableOptimizer +from sparseml.core.event import EventType __all__ = ["State", "Data", "Hardware", "ModifiedState"] @@ -68,8 +69,8 @@ def sparsification_ready(self) -> bool: return ( self.model is not None and self.optimizer is not None - and self.loss is not None - and self.batch_data is not None + #and self.loss is not None + #and self.batch_data is not None ) def update( @@ -116,7 +117,7 @@ def update( or batches_per_step is not None ): if self.start_event is None: - self.start_event = Event() + self.start_event = Event(type_=EventType.BATCH_START) if start is not None: self.start_event.current_index = start diff --git a/src/sparseml/modifiers/distillation/output/pytorch.py b/src/sparseml/modifiers/distillation/output/pytorch.py index ca96486a87d..a75079a78b7 100644 --- a/src/sparseml/modifiers/distillation/output/pytorch.py +++ b/src/sparseml/modifiers/distillation/output/pytorch.py @@ -75,7 +75,7 @@ def on_finalize(self, state: State, event: Event, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): for wrapper in self._wrappers.values(): - wrapper.kd_enabled = True + wrapper.kdenabled_ = True def on_update(self, state: State, event: Event, **kwargs): if event.type_ == EventType.LOSS_CALCULATED and event.should_update( @@ -88,7 +88,7 @@ def on_update(self, state: State, event: Event, **kwargs): def on_end(self, state: State, event: Event, **kwargs): for wrapper in self._wrappers.values(): - wrapper.kd_enabled = False + wrapper.kdenabled_ = False def _create_wrapper( self, student_layer: Module, teacher_layer: Module, state: State diff --git a/src/sparseml/modifiers/pruning/constant/pytorch.py b/src/sparseml/modifiers/pruning/constant/pytorch.py index d3ed5840097..b6e2639c96c 100644 --- a/src/sparseml/modifiers/pruning/constant/pytorch.py +++ b/src/sparseml/modifiers/pruning/constant/pytorch.py @@ -20,7 +20,7 @@ class ConstantPruningModifierPyTorch(ConstantPruningModifier, LayerParamMasking): - _parameterized_layers: Dict[str, ModelParameterizedLayer] = None + parameterized_layers_: Dict[str, ModelParameterizedLayer] = None _save_masks: bool = False _use_hooks: bool = False @@ -33,9 +33,9 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: if not state.model or not state.start_event: return False - self._parameterized_layers = state.model.get_layers_params(self.targets) + self.parameterized_layers_ = state.model.get_layers_params(self.targets) - for layer_param_name, parameterized_layer in self._parameterized_layers.items(): + for layer_param_name, parameterized_layer in self.parameterized_layers_.items(): self.add_mask( layer_param_name, parameterized_layer, @@ -46,13 +46,13 @@ def on_initialize(self, state: State, event: 
Event, **kwargs) -> bool: return True def on_finalize(self, state: State, event: Event, **kwargs) -> bool: - for layer_param_name, _ in self._parameterized_layers.items(): + for layer_param_name, _ in self.parameterized_layers_.items(): self.remove_mask(layer_param_name) return True def on_start(self, state: State, event: Event, **kwargs): - for layer_param_name, parameterized_layer in self._parameterized_layers.items(): + for layer_param_name, parameterized_layer in self.parameterized_layers_.items(): self.update_mask( layer_param_name, parameterized_layer.param.data.abs() < self._epsilon ) @@ -65,10 +65,10 @@ def on_update(self, state: State, event: Event, **kwargs): return if event.type_ == EventType.OPTIM_PRE_STEP: - for layer_param_name, _ in self._parameterized_layers.items(): + for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_gradient(layer_param_name) elif event.type_ == EventType.OPTIM_POST_STEP: - for layer_param_name, _ in self._parameterized_layers.items(): + for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_weight(layer_param_name) def on_end(self, state: State, event: Event, **kwargs): diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index bc61d3267ec..01cebba0a78 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -30,12 +30,12 @@ class MagnitudePruningModifierPyTorch(MagnitudePruningModifier, LayerParamMasking): - _parameterized_layers: Dict[str, ModelParameterizedLayer] = None + parameterized_layers_: Dict[str, ModelParameterizedLayer] = None _save_masks: bool = False _use_hooks: bool = False - _scheduler_function: SchedulerCalculationType = None - _mask_creator_function: MaskCreatorType = None - _current_sparsity: float = None + scheduler_function_: SchedulerCalculationType = None + mask_creator_function_: MaskCreatorType = None + current_sparsity_: float = None def on_initialize(self, state: State, event: Event, **kwargs) -> bool: if self.apply_globally: @@ -49,7 +49,7 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: if not state.model or not state.start_event: return False - self._scheduler_function = PruningSchedulerFactory.create_scheduler( + self.scheduler_function_ = PruningSchedulerFactory.create_scheduler( self.update_scheduler, PruningCreateSettings( self.start, @@ -60,13 +60,13 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: self.scheduler_args, ), ) - self._mask_creator_function = PruningMaskFactory.create_mask_creator( + self.mask_creator_function_ = PruningMaskFactory.create_mask_creator( self.mask_structure ) - self._parameterized_layers = state.model.get_layers_params(self.targets) + self.parameterized_layers_ = state.model.get_layers_params(self.targets) - for layer_param_name, parameterized_layer in self._parameterized_layers.items(): + for layer_param_name, parameterized_layer in self.parameterized_layers_.items(): self.add_mask( layer_param_name, parameterized_layer, @@ -77,17 +77,17 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: return True def on_finalize(self, state: State, event: Event, **kwargs) -> bool: - for layer_param_name, _ in self._parameterized_layers.items(): + for layer_param_name, _ in self.parameterized_layers_.items(): self.remove_mask(layer_param_name) return True def on_start(self, state: State, event: Event, **kwargs): - sparsity = 
self._scheduler_function(event, state) - self._current_sparsity = sparsity + sparsity = self.scheduler_function_(event, state) + self.current_sparsity_ = sparsity - for layer_param_name, parameterized_layer in self._parameterized_layers.items(): - mask = self._mask_creator_function( + for layer_param_name, parameterized_layer in self.parameterized_layers_.items(): + mask = self.mask_creator_function_( PruningMaskCreatorArgs( parameter=parameterized_layer.param, sparsity=sparsity, @@ -100,15 +100,15 @@ def on_start(self, state: State, event: Event, **kwargs): def on_update(self, state: State, event: Event, **kwargs): if event.type_ == EventType.BATCH_START: - sparsity = self._scheduler_function(event, state) - if sparsity != self._current_sparsity: - self._current_sparsity = sparsity + sparsity = self.scheduler_function_(event, state) + if sparsity != self.current_sparsity_: + self.current_sparsity_ = sparsity for ( layer_param_name, parameterized_layer, - ) in self._parameterized_layers.items(): - mask = self._mask_creator_function( + ) in self.parameterized_layers_.items(): + mask = self.mask_creator_function_( PruningMaskCreatorArgs( parameter=parameterized_layer.param, sparsity=sparsity, @@ -117,10 +117,10 @@ def on_update(self, state: State, event: Event, **kwargs): ) self.update_mask(layer_param_name, mask) elif event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks: - for layer_param_name, _ in self._parameterized_layers.items(): + for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_gradient(layer_param_name) elif event.type_ == EventType.OPTIM_POST_STEP and not self._use_hooks: - for layer_param_name, _ in self._parameterized_layers.items(): + for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_weight(layer_param_name) def on_end(self, state: State, event: Event, **kwargs): diff --git a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py index a8c2b16bb21..9c1a8eefe39 100644 --- a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py +++ b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py @@ -20,13 +20,15 @@ from torch.utils.hooks import RemovableHandle from sparseml.core import ModelParameterizedLayer +from pydantic import BaseModel __all__ = ["LayerParamMasking"] def param_mask_name(param_name: str) -> str: - return f"{param_name}_mask" + valid_name = param_name.replace(".", "_") + return f"{valid_name}_mask" def setup_mask_for_param(param: Parameter, mask: torch.Tensor) -> torch.Tensor: @@ -50,15 +52,12 @@ class ParameterizedLayerMaskSettings: use_hooks: bool = False -class LayerParamMasking: - def __init__(self): - self._mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {} - self._masked_layer_params: Dict[ - str, ModelParameterizedLayer[Module, Parameter] - ] = {} - self._forward_hooks: Dict[str, RemovableHandle] = {} - self._backward_hooks: Dict[str, RemovableHandle] = {} - self._enabled = False +class LayerParamMasking(BaseModel): + _mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {} + _masked_layer_params: Dict[str, ModelParameterizedLayer[Module, Parameter]] = {} + _forward_hooks: Dict[str, RemovableHandle] = {} + _backward_hooks: Dict[str, RemovableHandle] = {} + enabled_: bool = False def add_mask( self, @@ -96,7 +95,7 @@ def add_mask( if add_hooks: def _forward_hook_fn(module, input, output): - if not self._enabled: + if not self.enabled_: return output mask = module.get_buffer(mask_name) @@ -105,7 +104,7 @@ 
def _forward_hook_fn(module, input, output): return output def _backward_hook_fn(gradients): - if not self._enabled: + if not self.enabled_: return mask = parameterized_layer.layer.get_buffer(mask_name) @@ -129,7 +128,7 @@ def update_mask( parameterized_layer = self._masked_layer_params[layer_param_name] mask_name = param_mask_name(parameterized_layer.param_name) mask_tensor = parameterized_layer.layer.get_buffer(mask_name) - mask_tensor.fill_(setup_mask_for_param(parameterized_layer.param, mask)) + mask_tensor[:] = setup_mask_for_param(parameterized_layer.param, mask) def remove_mask(self, layer_param_name: str): mask_settings = self._mask_settings[layer_param_name] @@ -151,7 +150,7 @@ def remove_mask(self, layer_param_name: str): del self._backward_hooks[layer_param_name] def apply_mask_weight(self, layer_param_name: str): - if not self._enabled: + if not self.enabled_: return parameterized_layer = self._masked_layer_params[layer_param_name] @@ -160,7 +159,7 @@ def apply_mask_weight(self, layer_param_name: str): parameterized_layer.param.data = parameterized_layer.param.data * mask def apply_mask_gradient(self, layer_param_name: str): - if not self._enabled: + if not self.enabled_: return parameterized_layer = self._masked_layer_params[layer_param_name] @@ -171,7 +170,7 @@ def apply_mask_gradient(self, layer_param_name: str): parameterized_layer.param.grad = parameterized_layer.param.grad * mask def enable_masks(self): - self._enabled = True + self.enabled_ = True def disable_masks(self): - self._enabled = False + self.enabled_ = False diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py index 53f17655ecd..fa18a5a8736 100644 --- a/src/sparseml/utils/pytorch/module.py +++ b/src/sparseml/utils/pytorch/module.py @@ -23,6 +23,7 @@ from packaging import version from torch.nn import Linear, Module, Parameter from torch.nn.modules.conv import _ConvNd +from sparseml.core.model.base import ModelParameterizedLayer try: @@ -63,6 +64,7 @@ "get_terminal_layers", "get_prunable_layers", "get_quantizable_layers", + "get_layers_params" ] @@ -236,3 +238,20 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]: quantizable[name] = layer return quantizable + +def get_layers_params(targets: Union[str, List[str]], module: Module) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]: + params = get_params(targets, module) + layers = get_layers(targets, module) + + parameterized_layers = {} + for name, param in params.items(): + param_layer = ModelParameterizedLayer( + layer_name=name, + layer=layers[name], + param_name=name, + param=param + ) + parameterized_layers[name] = param_layer + + return parameterized_layers + \ No newline at end of file diff --git a/test_e2e.ipynb b/test_e2e.ipynb index 22f4984336d..9df55c0976f 100644 --- a/test_e2e.ipynb +++ b/test_e2e.ipynb @@ -14,7 +14,10 @@ "from torch.utils.data import DataLoader\n", "import datasets\n", "import os\n", - "from torch.optim import Adam" + "from torch.optim import Adam\n", + "from tqdm.auto import tqdm\n", + "from torch.nn import CrossEntropyLoss\n", + "from sparseml.core.event import EventType" ] }, { @@ -36,7 +39,8 @@ "NUM_LABELS = 3\n", "model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)\n", "model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)\n", - "optimizer = Adam(model.parameters(), lr=8e-3)" + "optimizer = Adam(model.parameters(), lr=8e-3)\n", + "criterion = CrossEntropyLoss()" ] }, { @@ -54,7 +58,7 @@ { "data": { 
"application/vnd.jupyter.widget-view+json": { - "model_id": "6e77f370860f472394ab3547decb1fb7", + "model_id": "335893b467c8472b9cafb1717d7a7cdb", "version_major": 2, "version_minor": 0 }, @@ -136,34 +140,81 @@ "cell_type": "code", "execution_count": 8, "metadata": {}, + "outputs": [], + "source": [ + "session_data = session.initialize(\n", + " framework=Framework.pytorch,\n", + " recipe=recipe,\n", + " model=model,\n", + " teacher_model=None,\n", + " optimizer=optimizer,\n", + " train_data=train_loader,\n", + " val_data=val_loader,\n", + " start=0.0,\n", + " steps_per_epoch= len(train_loader) / BATCH_SIZE # number of times steps in called per epoch (total_data / batch_size in normal cases)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "__fields_set__", + "ename": "ValueError", + "evalue": "batch start must be called first for callbacks", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m session\u001b[39m.\u001b[39;49minitialize(\n\u001b[1;32m 2\u001b[0m framework\u001b[39m=\u001b[39;49mFramework\u001b[39m.\u001b[39;49mpytorch,\n\u001b[1;32m 3\u001b[0m recipe\u001b[39m=\u001b[39;49mrecipe,\n\u001b[1;32m 4\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 5\u001b[0m teacher_model\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 6\u001b[0m optimizer\u001b[39m=\u001b[39;49moptimizer,\n\u001b[1;32m 7\u001b[0m train_data\u001b[39m=\u001b[39;49mtrain_loader,\n\u001b[1;32m 8\u001b[0m val_data\u001b[39m=\u001b[39;49mval_loader\n\u001b[1;32m 9\u001b[0m )\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/session.py:145\u001b[0m, in \u001b[0;36mSparseSession.initialize\u001b[0;34m(self, framework, recipe, recipe_stage, recipe_args, model, teacher_model, optimizer, attach_optim_callbacks, train_data, val_data, test_data, calib_data, copy_data, start, steps_per_epoch, batches_per_step, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mupdate_data(train_data, val_data, test_data, calib_data, copy_data)\n\u001b[1;32m 143\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mupdate_start(start, steps_per_epoch, batches_per_step)\n\u001b[0;32m--> 145\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_check_compile_recipe()\n\u001b[1;32m 146\u001b[0m modifier_data \u001b[39m=\u001b[39m []\n\u001b[1;32m 148\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modifiers:\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/session.py:274\u001b[0m, in \u001b[0;36mSparseSession._check_compile_recipe\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[39mdel\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modifiers\n\u001b[1;32m 273\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mrecipe_modifier_ready:\n\u001b[0;32m--> 274\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_modifiers \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstate\u001b[39m.\u001b[39;49mcompiled_recipe\u001b[39m.\u001b[39;49mcreate_modifier(\n\u001b[1;32m 275\u001b[0m 
\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstate\u001b[39m.\u001b[39;49mframework\n\u001b[1;32m 276\u001b[0m )\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/recipe/recipe.py:133\u001b[0m, in \u001b[0;36mRecipe.create_modifier\u001b[0;34m(self, framework)\u001b[0m\n\u001b[1;32m 130\u001b[0m modifiers \u001b[39m=\u001b[39m []\n\u001b[1;32m 132\u001b[0m \u001b[39mfor\u001b[39;00m index, stage \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstages):\n\u001b[0;32m--> 133\u001b[0m stage_modifiers \u001b[39m=\u001b[39m stage\u001b[39m.\u001b[39;49mcreate_modifier(framework)\n\u001b[1;32m 134\u001b[0m stage_modifiers\u001b[39m.\u001b[39mindex \u001b[39m=\u001b[39m index\n\u001b[1;32m 135\u001b[0m stage_modifiers\u001b[39m.\u001b[39mgroup \u001b[39m=\u001b[39m stage\u001b[39m.\u001b[39mgroup\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/recipe/stage.py:66\u001b[0m, in \u001b[0;36mRecipeStage.create_modifier\u001b[0;34m(self, framework, parent_args)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[39mfor\u001b[39;00m index, modifier \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodifiers):\n\u001b[1;32m 65\u001b[0m modifier \u001b[39m=\u001b[39m modifier\u001b[39m.\u001b[39mcreate_modifier(framework)\n\u001b[0;32m---> 66\u001b[0m modifier\u001b[39m.\u001b[39;49mgroup \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgroup\n\u001b[1;32m 67\u001b[0m modifier\u001b[39m.\u001b[39mindex \u001b[39m=\u001b[39m index\n\u001b[1;32m 69\u001b[0m \u001b[39mreturn\u001b[39;00m stage_modifiers\n", - "File \u001b[0;32m~/sparseml/.venv/lib/python3.8/site-packages/pydantic/main.py:405\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__setattr__\u001b[0;34m()\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: __fields_set__" + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 16\u001b[0m\n\u001b[1;32m 14\u001b[0m inputs \u001b[39m=\u001b[39m inputs\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 15\u001b[0m labels \u001b[39m=\u001b[39m labels\u001b[39m.\u001b[39mto(device)\n\u001b[0;32m---> 16\u001b[0m session\u001b[39m.\u001b[39;49mevent(event_type\u001b[39m=\u001b[39;49mEventType\u001b[39m.\u001b[39;49mBATCH_START, batch_data\u001b[39m=\u001b[39;49m(\u001b[39minput\u001b[39;49m, labels))\n\u001b[1;32m 17\u001b[0m session\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 19\u001b[0m outputs \u001b[39m=\u001b[39m session\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mmodel(inputs)\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/session.py:150\u001b[0m, in \u001b[0;36mSparseSession.event\u001b[0;34m(self, event_type, batch_data, loss, **kwargs)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mevent\u001b[39m(\n\u001b[1;32m 148\u001b[0m \u001b[39mself\u001b[39m, event_type: EventType, batch_data: Any \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, loss: Any \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs\n\u001b[1;32m 149\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m ModifiedState:\n\u001b[0;32m--> 150\u001b[0m mod_data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_lifecycle\u001b[39m.\u001b[39;49mevent(\n\u001b[1;32m 151\u001b[0m event_type\u001b[39m=\u001b[39;49mevent_type, 
batch_data\u001b[39m=\u001b[39;49mbatch_data, loss\u001b[39m=\u001b[39;49mloss, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m 152\u001b[0m )\n\u001b[1;32m 154\u001b[0m \u001b[39mreturn\u001b[39;00m ModifiedState(\n\u001b[1;32m 155\u001b[0m model\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mmodel \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmodel \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 156\u001b[0m optimizer\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39moptimizer \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moptimizer \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 157\u001b[0m loss\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mloss\u001b[39m.\u001b[39mloss \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mloss \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 158\u001b[0m modifier_data\u001b[39m=\u001b[39mmod_data,\n\u001b[1;32m 159\u001b[0m )\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/lifecycle/session.py:140\u001b[0m, in \u001b[0;36mSparsificationLifecycle.event\u001b[0;34m(self, event_type, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m event \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 139\u001b[0m mod_data \u001b[39m=\u001b[39m []\n\u001b[0;32m--> 140\u001b[0m \u001b[39mfor\u001b[39;00m event \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mevent_lifecycle\u001b[39m.\u001b[39;49mevents_from_type(event_type):\n\u001b[1;32m 141\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mstart_event \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 142\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mstart_event \u001b[39m=\u001b[39m event\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/lifecycle/event.py:43\u001b[0m, in \u001b[0;36mEventLifecycle.events_from_type\u001b[0;34m(self, type_)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mevents_from_type\u001b[39m(\u001b[39mself\u001b[39m, type_: EventType) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[Event]:\n\u001b[1;32m 42\u001b[0m \u001b[39mif\u001b[39;00m type_ \u001b[39m==\u001b[39m EventType\u001b[39m.\u001b[39mBATCH_START:\n\u001b[0;32m---> 43\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbatch_start_events()\n\u001b[1;32m 45\u001b[0m \u001b[39mif\u001b[39;00m type_ \u001b[39m==\u001b[39m EventType\u001b[39m.\u001b[39mLOSS_CALCULATED:\n\u001b[1;32m 46\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mloss_calculated_events()\n", + "File \u001b[0;32m~/sparseml/src/sparseml/core/lifecycle/event.py:214\u001b[0m, in \u001b[0;36mCallbacksEventLifecycle.batch_start_events\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_start_events\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[Event]:\n\u001b[1;32m 213\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtype_first \u001b[39m!=\u001b[39m 
EventType\u001b[39m.\u001b[39mBATCH_START:\n\u001b[0;32m--> 214\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mbatch start must be called first for callbacks\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 216\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtype_ \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtype_ \u001b[39m!=\u001b[39m EventType\u001b[39m.\u001b[39mBATCH_END:\n\u001b[1;32m 217\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mbatch start must be called after batch end\u001b[39m\u001b[39m\"\u001b[39m)\n", + "\u001b[0;31mValueError\u001b[0m: batch start must be called first for callbacks" ] } ], "source": [ - "session.initialize(\n", - " framework=Framework.pytorch,\n", - " recipe=recipe,\n", - " model=model,\n", - " teacher_model=None,\n", - " optimizer=optimizer,\n", - " train_data=train_loader,\n", - " val_data=val_loader\n", - ")" + "running_loss = 0.0\n", + "total_correct = 0\n", + "total_predictions = 0\n", + "\n", + "NUM_EPOCHS = 15\n", + "device = \"cuda:0\"\n", + "\n", + "\n", + "session.state.model.model.to(device)\n", + "\n", + "# loop through batches\n", + "for epoch in range(NUM_EPOCHS):\n", + " for step, (inputs, labels) in enumerate(session.state.data.train):\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " session.event(event_type=EventType.BATCH_START, batch_data=(input, labels))\n", + " session.state.optimizer.optimizer.zero_grad()\n", + "\n", + " outputs = session.state.model.model(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " session.event(event_type=EventType.LOSS_CALCULATED, loss=loss)\n", + "\n", + " session.event(event_type=EventType.OPTIM_PRE_STEP)\n", + " session.state.optimizer.optimizer.step()\n", + " session.event(event_type=EventType.OPTIM_POST_STEP)\n", + "\n", + " running_loss += loss.item()\n", + "\n", + " predictions = outputs.argmax(dim=1)\n", + " total_correct += torch.sum(predictions == labels).item()\n", + " total_predictions += inputs.size(0)\n", + "\n", + " #session.event(event_type=EventType.BATCH_END)\n", + "\n", + " loss = running_loss / (step + 1.0)\n", + " accuracy = total_correct / total_predictions\n", + " print(\"Epoch: {} Loss: {} Accuracy: {}\".format(epoch, loss, accuracy))\n" ] }, { @@ -171,6 +222,16 @@ "execution_count": null, "metadata": {}, "outputs": [], + "source": [ + "from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity\n", + "\n", + "for (name, layer) in get_prunable_layers(session.state.model.model):\n", + " print(f\"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, "source": [] } ], diff --git a/test_e2e.py b/test_e2e.py index 3a6cdd2249e..f1347bac8e3 100644 --- a/test_e2e.py +++ b/test_e2e.py @@ -7,6 +7,9 @@ import datasets import os from torch.optim import Adam +from torch.nn import CrossEntropyLoss +from sparseml.core.event import EventType +from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity sml.create_session() session = sml.active_session() @@ -46,13 +49,66 @@ val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16) recipe = "test_e2e_recipe.yaml" +criterion = CrossEntropyLoss() -session.initialize( + + +#this doubles the stages 
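+# (pre_initialize_structure followed by initialize with the same recipe appears
+# to append the recipe's stages to the session twice, so the call is left
+# commented out below)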
+#session.pre_initialize_structure( +# model=model, +# recipe=recipe, +# framework=Framework.pytorch +#) + +session_data = session.initialize( framework=Framework.pytorch, recipe=recipe, model=model, teacher_model=None, optimizer=optimizer, train_data=train_loader, - val_data=val_loader -) \ No newline at end of file + val_data=val_loader, + start=0.0, + steps_per_epoch= len(train_loader) # number of times steps in called per epoch (total_data / batch_size in normal cases) +) + +running_loss = 0.0 +total_correct = 0 +total_predictions = 0 + +NUM_EPOCHS = 15 +device = "cuda:0" + +session.state.model.model.to(device) + +# loop through batches +for epoch in range(NUM_EPOCHS): + for step, (inputs, labels) in enumerate(session.state.data.train): + inputs = inputs.to(device) + labels = labels.to(device) + session.event(event_type=EventType.BATCH_START, batch_data=(input, labels)) + session.state.optimizer.optimizer.zero_grad() + + outputs = session.state.model.model(inputs) + loss = criterion(outputs, labels) + loss.backward() + session.event(event_type=EventType.LOSS_CALCULATED, loss=loss) + + session.event(event_type=EventType.OPTIM_PRE_STEP) + session.state.optimizer.optimizer.step() + session.event(event_type=EventType.OPTIM_POST_STEP) + + running_loss += loss.item() + + predictions = outputs.argmax(dim=1) + total_correct += torch.sum(predictions == labels).item() + total_predictions += inputs.size(0) + + session.event(event_type=EventType.BATCH_END) + + loss = running_loss / (step + 1.0) + accuracy = total_correct / total_predictions + print("Epoch: {} Loss: {} Accuracy: {}".format(epoch, loss, accuracy)) + +for (name, layer) in get_prunable_layers(session.state.model.model): + print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}") \ No newline at end of file diff --git a/test_e2e_recipe.yaml b/test_e2e_recipe.yaml index 9618859f751..cf14247d0c7 100644 --- a/test_e2e_recipe.yaml +++ b/test_e2e_recipe.yaml @@ -3,26 +3,61 @@ test_stage: MagnitudePruningModifier: init_sparsity: 0.0 final_sparsity: 0.5 - start_epoch: 1.0 - end_epoch: 10.0 + start: 5.0 + end: 10.0 update_frequency: 0.5 targets: - - 'features.0.0.weight' - - 'features.18.0.weight' - - 're:features.*.conv.*.weight' - - 're:features.*.conv.*.*.weight' + - features.0.0.weight + - features.1.conv.0.0.weight + - features.1.conv.1.weight + - features.2.conv.0.0.weight + - features.2.conv.1.0.weight + - features.2.conv.2.weight + - features.3.conv.0.0.weight + - features.3.conv.1.0.weight + - features.3.conv.2.weight + - features.4.conv.0.0.weight + - features.4.conv.1.0.weight + - features.4.conv.2.weight + - features.5.conv.0.0.weight + - features.5.conv.1.0.weight + - features.5.conv.2.weight + - features.6.conv.0.0.weight + - features.6.conv.1.0.weight + - features.6.conv.2.weight + - features.7.conv.0.0.weight + - features.7.conv.1.0.weight + - features.7.conv.2.weight + - features.8.conv.0.0.weight + - features.8.conv.1.0.weight + - features.8.conv.2.weight + - features.9.conv.0.0.weight + - features.9.conv.1.0.weight + - features.9.conv.2.weight + - features.10.conv.0.0.weight + - features.10.conv.1.0.weight + - features.10.conv.2.weight + - features.11.conv.0.0.weight + - features.11.conv.1.0.weight + - features.11.conv.2.weight + - features.12.conv.0.0.weight + - features.12.conv.1.0.weight + - features.12.conv.2.weight + - features.13.conv.0.0.weight + - features.13.conv.1.0.weight + - features.13.conv.2.weight + - features.14.conv.0.0.weight + - features.14.conv.1.0.weight + - features.14.conv.2.weight + - 
features.15.conv.0.0.weight + - features.15.conv.1.0.weight + - features.15.conv.2.weight + - features.16.conv.0.0.weight + - features.16.conv.1.0.weight + - features.16.conv.2.weight + - features.17.conv.0.0.weight + - features.17.conv.1.0.weight + - features.17.conv.2.weight + - features.18.0.weight + - classifier.1.weight leave_enabled: True -test2_stage: - pruning_modifiers: - MagnitudePruningModifier: - init_sparsity: 0.0 - final_sparsity: 0.5 - start_epoch: 1.0 - end_epoch: 10.0 - update_frequency: 0.5 - targets: - - 'features.0.0.weight' - - 'features.18.0.weight' - - 're:features.*.conv.*.weight' - - 're:features.*.conv.*.*.weight' - leave_enabled: True \ No newline at end of file From bc5798d80f262b5106935d5511187692ef50392e Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 20 Sep 2023 12:11:49 -0400 Subject: [PATCH 17/27] fix polynomial scheduler, leave masks enabled on end --- src/sparseml/modifiers/pruning/helpers.py | 5 +---- .../modifiers/pruning/magnitude/pytorch.py | 3 ++- test_e2e.py | 22 +++++++++---------- test_e2e_recipe.yaml | 4 ++-- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/sparseml/modifiers/pruning/helpers.py b/src/sparseml/modifiers/pruning/helpers.py index d29ea2cefc8..d1703dc8d3b 100644 --- a/src/sparseml/modifiers/pruning/helpers.py +++ b/src/sparseml/modifiers/pruning/helpers.py @@ -140,10 +140,7 @@ def _schedule(event: Event, state: State) -> float: settings.end - settings.start ) - if exponent % 2 == 0: - scaled_complete = -1 * np.exp(per_complete - 1) + 1 - else: - scaled_complete = np.exp(per_complete - 1) - 1 + scaled_complete = pow(per_complete - 1, exponent) + 1 return ( settings.init_sparsity diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index 01cebba0a78..ec27b5fdac4 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -124,4 +124,5 @@ def on_update(self, state: State, event: Event, **kwargs): self.apply_mask_weight(layer_param_name) def on_end(self, state: State, event: Event, **kwargs): - self.disable_masks() + if not self.leave_enabled: + self.disable_masks() diff --git a/test_e2e.py b/test_e2e.py index f1347bac8e3..d714c1d549a 100644 --- a/test_e2e.py +++ b/test_e2e.py @@ -15,15 +15,17 @@ session = sml.active_session() NUM_LABELS = 3 +device = "cuda:0" +BATCH_SIZE = 32 + model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT) model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS) +model.to(device) optimizer = Adam(model.parameters(), lr=8e-3) train_path = "/home/sadkins/.cache/huggingface/datasets/downloads/extracted/dbf92bfb2c3766fb3083a51374ad94d8a3690f53cdf0f9113a231c2351c9ff33/train" val_path = "/home/sadkins/.cache/huggingface/datasets/downloads/extracted/510ede718de2aeaa2f9d88b0d81d88c449beeb7d074ea594bdf25a0e6a9d51d0/validation" -NUM_LABELS = 3 -BATCH_SIZE = 32 # imagenet transforms imagenet_transform = transforms.Compose([ @@ -72,22 +74,18 @@ steps_per_epoch= len(train_loader) # number of times steps in called per epoch (total_data / batch_size in normal cases) ) -running_loss = 0.0 -total_correct = 0 -total_predictions = 0 - -NUM_EPOCHS = 15 -device = "cuda:0" - -session.state.model.model.to(device) +NUM_EPOCHS = 2 # loop through batches for epoch in range(NUM_EPOCHS): + running_loss = 0.0 + total_correct = 0 + total_predictions = 0 for step, (inputs, labels) in enumerate(session.state.data.train): 
inputs = inputs.to(device) labels = labels.to(device) - session.event(event_type=EventType.BATCH_START, batch_data=(input, labels)) session.state.optimizer.optimizer.zero_grad() + session.event(event_type=EventType.BATCH_START, batch_data=(input, labels)) outputs = session.state.model.model(inputs) loss = criterion(outputs, labels) @@ -110,5 +108,7 @@ accuracy = total_correct / total_predictions print("Epoch: {} Loss: {} Accuracy: {}".format(epoch, loss, accuracy)) +#session.finalize() + for (name, layer) in get_prunable_layers(session.state.model.model): print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}") \ No newline at end of file diff --git a/test_e2e_recipe.yaml b/test_e2e_recipe.yaml index cf14247d0c7..5a844dc2e48 100644 --- a/test_e2e_recipe.yaml +++ b/test_e2e_recipe.yaml @@ -3,8 +3,8 @@ test_stage: MagnitudePruningModifier: init_sparsity: 0.0 final_sparsity: 0.5 - start: 5.0 - end: 10.0 + start: 0.0 + end: 1.0 update_frequency: 0.5 targets: - features.0.0.weight From a35581ddc11dde4c08f5a308b5af2c32fadce889 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 20 Sep 2023 12:17:07 -0400 Subject: [PATCH 18/27] remove e2e files --- test_e2e.ipynb | 260 ------------------------------------------- test_e2e.py | 114 ------------------- test_e2e_recipe.yaml | 63 ----------- 3 files changed, 437 deletions(-) delete mode 100644 test_e2e.ipynb delete mode 100644 test_e2e.py delete mode 100644 test_e2e_recipe.yaml diff --git a/test_e2e.ipynb b/test_e2e.ipynb deleted file mode 100644 index 9df55c0976f..00000000000 --- a/test_e2e.ipynb +++ /dev/null @@ -1,260 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import sparseml.core.session as sml\n", - "from sparseml.core.framework import Framework\n", - "import torchvision\n", - "from torchvision import transforms\n", - "import torch\n", - "from torch.utils.data import DataLoader\n", - "import datasets\n", - "import os\n", - "from torch.optim import Adam\n", - "from tqdm.auto import tqdm\n", - "from torch.nn import CrossEntropyLoss\n", - "from sparseml.core.event import EventType" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "sml.create_session()\n", - "session = sml.active_session()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "NUM_LABELS = 3\n", - "model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)\n", - "model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)\n", - "optimizer = Adam(model.parameters(), lr=8e-3)\n", - "criterion = CrossEntropyLoss()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Found cached dataset beans (/home/sadkins/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "335893b467c8472b9cafb1717d7a7cdb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00 16\u001b[0m session\u001b[39m.\u001b[39;49mevent(event_type\u001b[39m=\u001b[39;49mEventType\u001b[39m.\u001b[39;49mBATCH_START, batch_data\u001b[39m=\u001b[39;49m(\u001b[39minput\u001b[39;49m, labels))\n\u001b[1;32m 17\u001b[0m 
session\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39mzero_grad()\n\u001b[1;32m 19\u001b[0m outputs \u001b[39m=\u001b[39m session\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mmodel(inputs)\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/session.py:150\u001b[0m, in \u001b[0;36mSparseSession.event\u001b[0;34m(self, event_type, batch_data, loss, **kwargs)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mevent\u001b[39m(\n\u001b[1;32m 148\u001b[0m \u001b[39mself\u001b[39m, event_type: EventType, batch_data: Any \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, loss: Any \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs\n\u001b[1;32m 149\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m ModifiedState:\n\u001b[0;32m--> 150\u001b[0m mod_data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_lifecycle\u001b[39m.\u001b[39;49mevent(\n\u001b[1;32m 151\u001b[0m event_type\u001b[39m=\u001b[39;49mevent_type, batch_data\u001b[39m=\u001b[39;49mbatch_data, loss\u001b[39m=\u001b[39;49mloss, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m 152\u001b[0m )\n\u001b[1;32m 154\u001b[0m \u001b[39mreturn\u001b[39;00m ModifiedState(\n\u001b[1;32m 155\u001b[0m model\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmodel\u001b[39m.\u001b[39mmodel \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mmodel \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 156\u001b[0m optimizer\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moptimizer\u001b[39m.\u001b[39moptimizer \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39moptimizer \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 157\u001b[0m loss\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mloss\u001b[39m.\u001b[39mloss \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mloss \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 158\u001b[0m modifier_data\u001b[39m=\u001b[39mmod_data,\n\u001b[1;32m 159\u001b[0m )\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/lifecycle/session.py:140\u001b[0m, in \u001b[0;36mSparsificationLifecycle.event\u001b[0;34m(self, event_type, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m event \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 139\u001b[0m mod_data \u001b[39m=\u001b[39m []\n\u001b[0;32m--> 140\u001b[0m \u001b[39mfor\u001b[39;00m event \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mevent_lifecycle\u001b[39m.\u001b[39;49mevents_from_type(event_type):\n\u001b[1;32m 141\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mstart_event \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 142\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mstart_event \u001b[39m=\u001b[39m event\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/lifecycle/event.py:43\u001b[0m, in \u001b[0;36mEventLifecycle.events_from_type\u001b[0;34m(self, type_)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mevents_from_type\u001b[39m(\u001b[39mself\u001b[39m, type_: 
EventType) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[Event]:\n\u001b[1;32m 42\u001b[0m \u001b[39mif\u001b[39;00m type_ \u001b[39m==\u001b[39m EventType\u001b[39m.\u001b[39mBATCH_START:\n\u001b[0;32m---> 43\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbatch_start_events()\n\u001b[1;32m 45\u001b[0m \u001b[39mif\u001b[39;00m type_ \u001b[39m==\u001b[39m EventType\u001b[39m.\u001b[39mLOSS_CALCULATED:\n\u001b[1;32m 46\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mloss_calculated_events()\n", - "File \u001b[0;32m~/sparseml/src/sparseml/core/lifecycle/event.py:214\u001b[0m, in \u001b[0;36mCallbacksEventLifecycle.batch_start_events\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbatch_start_events\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m List[Event]:\n\u001b[1;32m 213\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtype_first \u001b[39m!=\u001b[39m EventType\u001b[39m.\u001b[39mBATCH_START:\n\u001b[0;32m--> 214\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mbatch start must be called first for callbacks\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 216\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtype_ \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtype_ \u001b[39m!=\u001b[39m EventType\u001b[39m.\u001b[39mBATCH_END:\n\u001b[1;32m 217\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mbatch start must be called after batch end\u001b[39m\u001b[39m\"\u001b[39m)\n", - "\u001b[0;31mValueError\u001b[0m: batch start must be called first for callbacks" - ] - } - ], - "source": [ - "running_loss = 0.0\n", - "total_correct = 0\n", - "total_predictions = 0\n", - "\n", - "NUM_EPOCHS = 15\n", - "device = \"cuda:0\"\n", - "\n", - "\n", - "session.state.model.model.to(device)\n", - "\n", - "# loop through batches\n", - "for epoch in range(NUM_EPOCHS):\n", - " for step, (inputs, labels) in enumerate(session.state.data.train):\n", - " inputs = inputs.to(device)\n", - " labels = labels.to(device)\n", - " session.event(event_type=EventType.BATCH_START, batch_data=(input, labels))\n", - " session.state.optimizer.optimizer.zero_grad()\n", - "\n", - " outputs = session.state.model.model(inputs)\n", - " loss = criterion(outputs, labels)\n", - " loss.backward()\n", - " session.event(event_type=EventType.LOSS_CALCULATED, loss=loss)\n", - "\n", - " session.event(event_type=EventType.OPTIM_PRE_STEP)\n", - " session.state.optimizer.optimizer.step()\n", - " session.event(event_type=EventType.OPTIM_POST_STEP)\n", - "\n", - " running_loss += loss.item()\n", - "\n", - " predictions = outputs.argmax(dim=1)\n", - " total_correct += torch.sum(predictions == labels).item()\n", - " total_predictions += inputs.size(0)\n", - "\n", - " #session.event(event_type=EventType.BATCH_END)\n", - "\n", - " loss = running_loss / (step + 1.0)\n", - " accuracy = total_correct / total_predictions\n", - " print(\"Epoch: {} Loss: {} Accuracy: {}\".format(epoch, loss, accuracy))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity\n", - "\n", - "for (name, layer) in 
get_prunable_layers(session.state.model.model):\n", - " print(f\"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/test_e2e.py b/test_e2e.py deleted file mode 100644 index d714c1d549a..00000000000 --- a/test_e2e.py +++ /dev/null @@ -1,114 +0,0 @@ -import sparseml.core.session as sml -from sparseml.core.framework import Framework -import torchvision -from torchvision import transforms -import torch -from torch.utils.data import DataLoader -import datasets -import os -from torch.optim import Adam -from torch.nn import CrossEntropyLoss -from sparseml.core.event import EventType -from sparseml.pytorch.utils import get_prunable_layers, tensor_sparsity - -sml.create_session() -session = sml.active_session() - -NUM_LABELS = 3 -device = "cuda:0" -BATCH_SIZE = 32 - -model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT) -model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS) -model.to(device) -optimizer = Adam(model.parameters(), lr=8e-3) - -train_path = "/home/sadkins/.cache/huggingface/datasets/downloads/extracted/dbf92bfb2c3766fb3083a51374ad94d8a3690f53cdf0f9113a231c2351c9ff33/train" -val_path = "/home/sadkins/.cache/huggingface/datasets/downloads/extracted/510ede718de2aeaa2f9d88b0d81d88c449beeb7d074ea594bdf25a0e6a9d51d0/validation" - - -# imagenet transforms -imagenet_transform = transforms.Compose([ - transforms.Resize(size=256, interpolation=transforms.InterpolationMode.BILINEAR, max_size=None, antialias=None), - transforms.CenterCrop(size=(224, 224)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -]) - -# datasets -train_dataset = torchvision.datasets.ImageFolder( - root=train_path, - transform=imagenet_transform -) - -val_dataset = torchvision.datasets.ImageFolder( - root=val_path, - transform=imagenet_transform -) - -# dataloaders -train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16) -val_loader = DataLoader(val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16) - -recipe = "test_e2e_recipe.yaml" -criterion = CrossEntropyLoss() - - - -#this doubles the stages -#session.pre_initialize_structure( -# model=model, -# recipe=recipe, -# framework=Framework.pytorch -#) - -session_data = session.initialize( - framework=Framework.pytorch, - recipe=recipe, - model=model, - teacher_model=None, - optimizer=optimizer, - train_data=train_loader, - val_data=val_loader, - start=0.0, - steps_per_epoch= len(train_loader) # number of times steps in called per epoch (total_data / batch_size in normal cases) -) - -NUM_EPOCHS = 2 - -# loop through batches -for epoch in range(NUM_EPOCHS): - running_loss = 0.0 - total_correct = 0 - total_predictions = 0 - for step, (inputs, labels) in enumerate(session.state.data.train): - inputs = inputs.to(device) - labels = labels.to(device) - session.state.optimizer.optimizer.zero_grad() - session.event(event_type=EventType.BATCH_START, batch_data=(input, 
labels)) - - outputs = session.state.model.model(inputs) - loss = criterion(outputs, labels) - loss.backward() - session.event(event_type=EventType.LOSS_CALCULATED, loss=loss) - - session.event(event_type=EventType.OPTIM_PRE_STEP) - session.state.optimizer.optimizer.step() - session.event(event_type=EventType.OPTIM_POST_STEP) - - running_loss += loss.item() - - predictions = outputs.argmax(dim=1) - total_correct += torch.sum(predictions == labels).item() - total_predictions += inputs.size(0) - - session.event(event_type=EventType.BATCH_END) - - loss = running_loss / (step + 1.0) - accuracy = total_correct / total_predictions - print("Epoch: {} Loss: {} Accuracy: {}".format(epoch, loss, accuracy)) - -#session.finalize() - -for (name, layer) in get_prunable_layers(session.state.model.model): - print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}") \ No newline at end of file diff --git a/test_e2e_recipe.yaml b/test_e2e_recipe.yaml deleted file mode 100644 index 5a844dc2e48..00000000000 --- a/test_e2e_recipe.yaml +++ /dev/null @@ -1,63 +0,0 @@ -test_stage: - pruning_modifiers: - MagnitudePruningModifier: - init_sparsity: 0.0 - final_sparsity: 0.5 - start: 0.0 - end: 1.0 - update_frequency: 0.5 - targets: - - features.0.0.weight - - features.1.conv.0.0.weight - - features.1.conv.1.weight - - features.2.conv.0.0.weight - - features.2.conv.1.0.weight - - features.2.conv.2.weight - - features.3.conv.0.0.weight - - features.3.conv.1.0.weight - - features.3.conv.2.weight - - features.4.conv.0.0.weight - - features.4.conv.1.0.weight - - features.4.conv.2.weight - - features.5.conv.0.0.weight - - features.5.conv.1.0.weight - - features.5.conv.2.weight - - features.6.conv.0.0.weight - - features.6.conv.1.0.weight - - features.6.conv.2.weight - - features.7.conv.0.0.weight - - features.7.conv.1.0.weight - - features.7.conv.2.weight - - features.8.conv.0.0.weight - - features.8.conv.1.0.weight - - features.8.conv.2.weight - - features.9.conv.0.0.weight - - features.9.conv.1.0.weight - - features.9.conv.2.weight - - features.10.conv.0.0.weight - - features.10.conv.1.0.weight - - features.10.conv.2.weight - - features.11.conv.0.0.weight - - features.11.conv.1.0.weight - - features.11.conv.2.weight - - features.12.conv.0.0.weight - - features.12.conv.1.0.weight - - features.12.conv.2.weight - - features.13.conv.0.0.weight - - features.13.conv.1.0.weight - - features.13.conv.2.weight - - features.14.conv.0.0.weight - - features.14.conv.1.0.weight - - features.14.conv.2.weight - - features.15.conv.0.0.weight - - features.15.conv.1.0.weight - - features.15.conv.2.weight - - features.16.conv.0.0.weight - - features.16.conv.1.0.weight - - features.16.conv.2.weight - - features.17.conv.0.0.weight - - features.17.conv.1.0.weight - - features.17.conv.2.weight - - features.18.0.weight - - classifier.1.weight - leave_enabled: True From 71869be8bca1adcc33743599c5802c7b3e564eaf Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Wed, 20 Sep 2023 12:21:34 -0400 Subject: [PATCH 19/27] add on_event for modifier lifecycle and add initial integration for torchvision --- src/sparseml/core/modifier/modifier.py | 5 +++++ .../integrations/torchvision/__init__.py | 0 .../integrations/torchvision/data/__init__.py | 0 .../integrations/torchvision/evaluator.py | 0 .../torchvision/metrics/__init__.py | 0 .../integrations/torchvision/model/__init__.py | 0 .../integrations/torchvision/optim/__init__.py | 0 .../integrations/torchvision/trainer.py | 0 .../modifiers/pruning/magnitude/pytorch.py | 17 +++++++++++++---- 9 files 
changed, 18 insertions(+), 4 deletions(-) create mode 100644 src/sparseml/integrations/torchvision/__init__.py create mode 100644 src/sparseml/integrations/torchvision/data/__init__.py create mode 100644 src/sparseml/integrations/torchvision/evaluator.py create mode 100644 src/sparseml/integrations/torchvision/metrics/__init__.py create mode 100644 src/sparseml/integrations/torchvision/model/__init__.py create mode 100644 src/sparseml/integrations/torchvision/optim/__init__.py create mode 100644 src/sparseml/integrations/torchvision/trainer.py diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 5bdecb54b03..5d39d275ab3 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -111,6 +111,8 @@ def update_event(self, state: "State", event: Event, **kwargs): if self.finalized_: raise RuntimeError("cannot update a finalized modifier") + self.on_event(state, event, **kwargs) + # handle starting the modifier if needed if ( event.type_ == EventType.BATCH_START @@ -165,3 +167,6 @@ def on_update(self, state: "State", event: Event, **kwargs): def on_end(self, state: "State", event: Event, **kwargs): raise NotImplementedError() + + def on_event(self, state: State, event: Event, **kwargs): + pass diff --git a/src/sparseml/integrations/torchvision/__init__.py b/src/sparseml/integrations/torchvision/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/integrations/torchvision/data/__init__.py b/src/sparseml/integrations/torchvision/data/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/integrations/torchvision/evaluator.py b/src/sparseml/integrations/torchvision/evaluator.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/integrations/torchvision/metrics/__init__.py b/src/sparseml/integrations/torchvision/metrics/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/integrations/torchvision/model/__init__.py b/src/sparseml/integrations/torchvision/model/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/integrations/torchvision/optim/__init__.py b/src/sparseml/integrations/torchvision/optim/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/integrations/torchvision/trainer.py b/src/sparseml/integrations/torchvision/trainer.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index ec27b5fdac4..acbbfc9fd20 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -116,13 +116,22 @@ def on_update(self, state: State, event: Event, **kwargs): ) ) self.update_mask(layer_param_name, mask) - elif event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks: + else: + self._update_masks(event) + + def on_end(self, state: State, event: Event, **kwargs): + if not self.leave_enabled: + self.disable_masks() + + def on_event(self, state: State, event: Event, **kwargs): + if event.current_index >= self.end and self.leave_enabled: + self._update_masks(event) + + def _update_masks(self, event: Event): + if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks: for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_gradient(layer_param_name) elif event.type_ == EventType.OPTIM_POST_STEP and not 
self._use_hooks: for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_weight(layer_param_name) - def on_end(self, state: State, event: Event, **kwargs): - if not self.leave_enabled: - self.disable_masks() From 2d04ea02536cc43407144d980afb2bf83de20284 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 20 Sep 2023 12:33:18 -0400 Subject: [PATCH 20/27] leave_enabled fixes --- src/sparseml/core/modifier/modifier.py | 2 +- src/sparseml/modifiers/pruning/magnitude/base.py | 2 +- src/sparseml/modifiers/pruning/magnitude/pytorch.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 5d39d275ab3..a5917d4f00b 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -168,5 +168,5 @@ def on_update(self, state: "State", event: Event, **kwargs): def on_end(self, state: "State", event: Event, **kwargs): raise NotImplementedError() - def on_event(self, state: State, event: Event, **kwargs): + def on_event(self, state: "State", event: Event, **kwargs): pass diff --git a/src/sparseml/modifiers/pruning/magnitude/base.py b/src/sparseml/modifiers/pruning/magnitude/base.py index 802de89049b..6a7c5e5c2b3 100644 --- a/src/sparseml/modifiers/pruning/magnitude/base.py +++ b/src/sparseml/modifiers/pruning/magnitude/base.py @@ -27,7 +27,7 @@ class MagnitudePruningModifier(Modifier): update_scheduler: str = "cubic" scheduler_args: Dict[str, Any] = {} mask_structure: str = "unstructured" - leave_enabled: bool = False + leave_enabled: bool = True apply_globally: bool = False def on_initialize_structure(self, state: State, **kwargs): diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index acbbfc9fd20..be8d11a2a35 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -77,8 +77,9 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: return True def on_finalize(self, state: State, event: Event, **kwargs) -> bool: - for layer_param_name, _ in self.parameterized_layers_.items(): - self.remove_mask(layer_param_name) + if not self.leave_enabled: + for layer_param_name, _ in self.parameterized_layers_.items(): + self.remove_mask(layer_param_name) return True From 7b182e46594f28af74dfa232203504edac5ddb7e Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 20 Sep 2023 15:00:09 -0400 Subject: [PATCH 21/27] fixing evals and finalization --- src/sparseml/core/modifier/modifier.py | 8 ++++---- src/sparseml/core/recipe/args.py | 8 +++++--- src/sparseml/core/recipe/recipe.py | 7 ++++--- src/sparseml/core/state.py | 2 +- src/sparseml/modifiers/distillation/output/pytorch.py | 4 ++-- src/sparseml/modifiers/pruning/constant/pytorch.py | 4 ++-- src/sparseml/modifiers/pruning/magnitude/pytorch.py | 4 ++-- 7 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index a5917d4f00b..82ee423d87a 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -73,7 +73,7 @@ def initialize(self, state: "State", **kwargs): if state.start_event is None: return - initialized = self.on_initialize(state=state, event=state.start_event, **kwargs) + initialized = self.on_initialize(state=state, **kwargs) if not isinstance(initialized, bool): raise ValueError( @@ -94,7 +94,7 @@ def finalize(self, 
state: "State", **kwargs): if not self.initialized_: raise RuntimeError("cannot finalize an uninitialized modifier") - finalized = self.on_finalize(**kwargs) + finalized = self.on_finalize(state=state, **kwargs) if not isinstance(finalized, bool): raise ValueError( @@ -153,10 +153,10 @@ def should_end(self, event: Event): def on_initialize_structure(self, state: "State", **kwargs): raise NotImplementedError() - def on_initialize(self, state: "State", event: Event, **kwargs) -> bool: + def on_initialize(self, state: "State", **kwargs) -> bool: raise NotImplementedError() - def on_finalize(self, state: "State", event: Event, **kwargs) -> bool: + def on_finalize(self, state: "State", **kwargs) -> bool: raise NotImplementedError() def on_start(self, state: "State", event: Event, **kwargs): diff --git a/src/sparseml/core/recipe/args.py b/src/sparseml/core/recipe/args.py index 1fc7edfdbca..559c169d027 100644 --- a/src/sparseml/core/recipe/args.py +++ b/src/sparseml/core/recipe/args.py @@ -49,7 +49,7 @@ def evaluate_ext(self, target: Dict[str, Any]) -> Dict[str, Any]: return resolved @staticmethod - def eval_str(target: str, args: Dict[str, Any] = None) -> str: + def eval_str(target: str, args: Dict[str, Any] = None) -> Union[str,float]: if "eval(" not in target: return target @@ -62,8 +62,10 @@ def eval_str(target: str, args: Dict[str, Any] = None) -> str: inner_expr = match.group(1) result = eval(inner_expr, {"math": math}, args if args else {}) new_target = target.replace(match.group(0), str(result)) - - return RecipeArgs.eval_str(new_target, args) + try: + return float(new_target) + except ValueError: + return RecipeArgs.eval_str(new_target, args) @staticmethod def eval_args(args: Dict[str, Any]) -> "RecipeArgs": diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index 49bcbc17c31..a5bb14a82c5 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -73,7 +73,7 @@ def simplify_recipe( simplified = Recipe() simplified.version = version - simplified.args = args + simplified.args = RecipeArgs(args) simplified.stages = stages simplified.evaluate(args=args, shift=shift) @@ -93,14 +93,15 @@ def simplify_combine_recipes( ) combined.version = simplified.version combined.stages.extend(simplified.stages) + combined.args.combine(simplified.args) return combined version: str = None - args: RecipeArgs = None + args: RecipeArgs = Field(default_factory=RecipeArgs) stages: List[RecipeStage] = Field(default_factory=list) metadata: RecipeMetaData = None - args_evaluated: RecipeArgs = None + args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs) def calculate_start(self) -> int: return min( diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index a8bf937bf0a..df79279aaf1 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -117,7 +117,7 @@ def update( or batches_per_step is not None ): if self.start_event is None: - self.start_event = Event(type_=EventType.BATCH_START) + self.start_event = Event() if start is not None: self.start_event.current_index = start diff --git a/src/sparseml/modifiers/distillation/output/pytorch.py b/src/sparseml/modifiers/distillation/output/pytorch.py index a75079a78b7..da226566542 100644 --- a/src/sparseml/modifiers/distillation/output/pytorch.py +++ b/src/sparseml/modifiers/distillation/output/pytorch.py @@ -28,7 +28,7 @@ class OutputDistillationModifierPyTorch(OutputDistillationModifier): _wrappers: Dict[str, KDModuleWrapper] = None - def on_initialize(self, state: 
State, event: Event, **kwargs) -> bool: + def on_initialize(self, state: State, **kwargs) -> bool: if ( state.framework is None or state.model is None @@ -66,7 +66,7 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: return True - def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + def on_finalize(self, state: State, **kwargs) -> bool: for key, wrapper in self._wrappers.items(): state.model.set_layer(key, wrapper.student_layer) del wrapper diff --git a/src/sparseml/modifiers/pruning/constant/pytorch.py b/src/sparseml/modifiers/pruning/constant/pytorch.py index b6e2639c96c..cb5dccb18c5 100644 --- a/src/sparseml/modifiers/pruning/constant/pytorch.py +++ b/src/sparseml/modifiers/pruning/constant/pytorch.py @@ -24,7 +24,7 @@ class ConstantPruningModifierPyTorch(ConstantPruningModifier, LayerParamMasking) _save_masks: bool = False _use_hooks: bool = False - def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + def on_initialize(self, state: State, **kwargs) -> bool: if "save_masks" in kwargs: self._save_masks = kwargs["save_masks"] if "use_hooks" in kwargs: @@ -45,7 +45,7 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: return True - def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + def on_finalize(self, state: State, **kwargs) -> bool: for layer_param_name, _ in self.parameterized_layers_.items(): self.remove_mask(layer_param_name) diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index be8d11a2a35..10d80a312e8 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -37,7 +37,7 @@ class MagnitudePruningModifierPyTorch(MagnitudePruningModifier, LayerParamMaskin mask_creator_function_: MaskCreatorType = None current_sparsity_: float = None - def on_initialize(self, state: State, event: Event, **kwargs) -> bool: + def on_initialize(self, state: State, **kwargs) -> bool: if self.apply_globally: raise NotImplementedError("global pruning not implemented yet for PyTorch") @@ -76,7 +76,7 @@ def on_initialize(self, state: State, event: Event, **kwargs) -> bool: return True - def on_finalize(self, state: State, event: Event, **kwargs) -> bool: + def on_finalize(self, state: State, **kwargs) -> bool: if not self.leave_enabled: for layer_param_name, _ in self.parameterized_layers_.items(): self.remove_mask(layer_param_name) From 6c2255f1fa2290981f14b75d0094467533f62ded Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 21 Sep 2023 19:41:20 -0400 Subject: [PATCH 22/27] Add test --- tests/sparseml/core/recipe/__init__.py | 13 +++++++ tests/sparseml/core/recipe/test_recipe.py | 43 +++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 tests/sparseml/core/recipe/__init__.py create mode 100644 tests/sparseml/core/recipe/test_recipe.py diff --git a/tests/sparseml/core/recipe/__init__.py b/tests/sparseml/core/recipe/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/core/recipe/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/core/recipe/test_recipe.py b/tests/sparseml/core/recipe/test_recipe.py new file mode 100644 index 00000000000..2e3a9f5c8de --- /dev/null +++ b/tests/sparseml/core/recipe/test_recipe.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile + +import yaml + +from sparseml.core.recipe import Recipe + + +def _valid_recipe(): + return """ + test_stage: + pruning_modifiers: + ConstantPruningModifier: + start: 0 + end: 5 + """ + + +def test_recipe_create_instance_accepts_valid_recipe_string(): + test_recipe = _valid_recipe() + recipe = Recipe.create_instance(test_recipe) + assert recipe is not None, "Recipe could not be created from string" + + +def test_recipe_create_instance_accepts_valid_recipe_file(): + content = yaml.safe_load(_valid_recipe()) + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + yaml.dump(content, f) + recipe = Recipe.create_instance(f.name) + assert recipe is not None, "Recipe could not be created from file" From abeedb73065b2001ff25b0adb713ef30beafdff4 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 21 Sep 2023 19:43:46 -0400 Subject: [PATCH 23/27] Add changes to allow accepting strings --- src/sparseml/core/recipe/recipe.py | 41 ++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index a5bb14a82c5..f9777044c23 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -33,12 +33,25 @@ class Recipe(RecipeBase): @staticmethod def create_instance(path: str) -> "Recipe": + """ + Create a recipe instance from a file, or string + + :param path: The path to the recipe file or + SparseZoo stub or the recipe string, must be a valid + json/yaml file or a valid json/yaml string + """ if not os.path.isfile(path): - # not a local file, load from SparseZoo - raise NotImplementedError() + # not a local file + if path.startswith("zoo:"): + # download from SparseZoo + raise NotImplementedError("Using SparseZoo stubs is not yet supported") + else: + # assume it's a string + obj = _load_json_or_yaml_string(path) + return Recipe.parse_obj(obj) with open(path, "r") as file: - content = file.read() + content = file.read().strip() if path.lower().endswith(".json"): obj = json.loads(content) @@ -46,13 +59,9 @@ def create_instance(path: str) -> "Recipe": obj = yaml.safe_load(content) else: try: - obj = json.loads(content) - except json.JSONDecodeError: - try: - obj = yaml.safe_load(content) - except 
yaml.YAMLError: - raise ValueError(f"Could not parse recipe from path {path}") - + obj = _load_json_or_yaml_string(content) + except ValueError: + raise ValueError(f"Could not parse recipe from path {path}") return Recipe.parse_obj(obj) @staticmethod @@ -233,3 +242,15 @@ class RecipeTuple: recipe: Recipe target_stages: List[str] override_args: Dict[str, Any] + + +def _load_json_or_yaml_string(content: str) -> Dict[str, Any]: + # try loading as json first, then yaml + # if both fail, raise a ValueError + try: + return json.loads(content) + except json.JSONDecodeError: + try: + return yaml.safe_load(content) + except yaml.YAMLError as err: + raise ValueError(f"Could not parse recipe from string {content}") from err From 571d21d289c463a1a7383406992764d7cc4b0cef Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 22 Sep 2023 11:45:51 -0400 Subject: [PATCH 24/27] fix recipe staging issue --- src/sparseml/core/recipe/container.py | 9 ++++++--- src/sparseml/core/session.py | 6 +++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/sparseml/core/recipe/container.py b/src/sparseml/core/recipe/container.py index fc828796bc6..aede9a663d2 100644 --- a/src/sparseml/core/recipe/container.py +++ b/src/sparseml/core/recipe/container.py @@ -44,7 +44,7 @@ class RecipeContainer: def update( self, recipe: Union[str, List[str], Recipe, List[Recipe]] = None, - recipe_stage: Union[str, List[str]] = None, + recipe_stage: Union[str, List[str], List[List[str]]] = None, recipe_args: Union[Dict[str, Any], List[Dict[str, Any]]] = None, **kwargs, ) -> Dict: @@ -55,8 +55,11 @@ def update( recipe = [recipe] if recipe_stage is None: recipe_stage = [None] * len(recipe) - elif not isinstance(recipe_stage, list): - recipe_stage = [recipe_stage] * len(recipe) + else: + if not isinstance(recipe_stage, list): + recipe_stage = [[recipe_stage]] * len(recipe) + if not isinstance(recipe_stage[0], list): + recipe_stage = [recipe_stage] * len(recipe) if recipe_args is None: recipe_args = [{}] * len(recipe) diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index c862ac3f45e..f6b1670e028 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -86,7 +86,7 @@ def initialize( self, framework: Framework = None, recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, - recipe_stage: str = None, + recipe_stage: Union[str, List[str]] = None, recipe_args: Dict[str, Any] = None, model: Any = None, teacher_model: Any = None, @@ -191,7 +191,7 @@ def pre_initialize_structure(**kwargs): def initialize( framework: Framework = None, recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, - recipe_stage: str = None, + recipe_stage: Union[str, List[str]] = None, recipe_args: Dict[str, Any] = None, model: Any = None, teacher_model: Any = None, @@ -235,7 +235,7 @@ def finalize(**kwargs) -> ModifiedState: def apply( framework: Framework = None, recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None, - recipe_stage: str = None, + recipe_stage: Union[str, List[str]] = None, recipe_args: Dict[str, Any] = None, model: Any = None, teacher_model: Any = None, From 952e4ee5f331c9cd09d89f5946934b36cac71ba0 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 22 Sep 2023 12:09:40 -0400 Subject: [PATCH 25/27] style --- src/sparseml/core/framework.py | 1 + src/sparseml/core/model/pytorch.py | 6 ++++-- src/sparseml/core/modifier/base.py | 1 + src/sparseml/core/modifier/modifier.py | 1 + src/sparseml/core/optimizer/base.py | 2 +- src/sparseml/core/recipe/args.py | 2 +- 
src/sparseml/core/state.py | 11 ++++++----- src/sparseml/integrations/torchvision/__init__.py | 13 +++++++++++++ .../integrations/torchvision/data/__init__.py | 13 +++++++++++++ src/sparseml/integrations/torchvision/evaluator.py | 13 +++++++++++++ .../integrations/torchvision/metrics/__init__.py | 13 +++++++++++++ .../integrations/torchvision/model/__init__.py | 13 +++++++++++++ .../integrations/torchvision/optim/__init__.py | 13 +++++++++++++ src/sparseml/integrations/torchvision/trainer.py | 13 +++++++++++++ .../modifiers/pruning/magnitude/pytorch.py | 1 - .../modifiers/pruning/utils/pytorch/layer_mask.py | 2 +- src/sparseml/utils/pytorch/module.py | 14 +++++++------- 17 files changed, 114 insertions(+), 18 deletions(-) diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py index 5e9aaf93d7d..9dce24613d3 100644 --- a/src/sparseml/core/framework.py +++ b/src/sparseml/core/framework.py @@ -15,6 +15,7 @@ import importlib from enum import Enum + from pydantic import ValidationError diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index ca17612b8b2..2044454e797 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -20,11 +20,11 @@ from sparseml.utils.pytorch import ( get_layer, get_layers, + get_layers_params, get_param, get_params, set_layer, set_param, - get_layers_params ) @@ -35,7 +35,9 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): def __init__(self, framework=None, model=None): super().__init__(framework=framework, model=model) - def get_layers_params(self, targets: Union[str, List[str]]) -> Dict[str, ModelParameterizedLayer[Module, Parameter]]: + def get_layers_params( + self, targets: Union[str, List[str]] + ) -> Dict[str, ModelParameterizedLayer[Module, Parameter]]: return get_layers_params(targets, self.model) def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, Module]: diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py index ad6197fc303..da07485d0e2 100644 --- a/src/sparseml/core/modifier/base.py +++ b/src/sparseml/core/modifier/base.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod + __all__ = ["ModifierInterface"] diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index 82ee423d87a..5e3a7979b61 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -21,6 +21,7 @@ from sparseml.core.framework_object import MultiFrameworkObject from sparseml.core.modifier.base import ModifierInterface + __all__ = ["Modifier"] diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py index cebef996259..59cd78b963b 100644 --- a/src/sparseml/core/optimizer/base.py +++ b/src/sparseml/core/optimizer/base.py @@ -28,7 +28,7 @@ @dataclass class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject): optimizer: OT = None - + def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None): self.optimizer = optimizer diff --git a/src/sparseml/core/recipe/args.py b/src/sparseml/core/recipe/args.py index 559c169d027..c5aa7d2c300 100644 --- a/src/sparseml/core/recipe/args.py +++ b/src/sparseml/core/recipe/args.py @@ -49,7 +49,7 @@ def evaluate_ext(self, target: Dict[str, Any]) -> Dict[str, Any]: return resolved @staticmethod - def eval_str(target: str, args: Dict[str, Any] = None) -> Union[str,float]: + def eval_str(target: str, args: Dict[str, Any] = None) -> Union[str, float]: if "eval(" not in 
target: return target diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index df79279aaf1..b700e9c422d 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -19,11 +19,10 @@ from pydantic import Field from sparseml.core.data import ModifiableData -from sparseml.core.event import Event +from sparseml.core.event import Event, EventType from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.core.optimizer import ModifiableOptimizer -from sparseml.core.event import EventType __all__ = ["State", "Data", "Hardware", "ModifiedState"] @@ -69,8 +68,8 @@ def sparsification_ready(self) -> bool: return ( self.model is not None and self.optimizer is not None - #and self.loss is not None - #and self.batch_data is not None + # and self.loss is not None + # and self.batch_data is not None ) def update( @@ -109,7 +108,9 @@ def update( self.data.test = test_data if not copy_data else deepcopy(test_data) if calib_data is not None: calib_loader = calib_data if not copy_data else deepcopy(calib_data) - self.calib_data = ModifiableData(framework=self.framework, data_loader=calib_loader) + self.calib_data = ModifiableData( + framework=self.framework, data_loader=calib_loader + ) if ( start is not None diff --git a/src/sparseml/integrations/torchvision/__init__.py b/src/sparseml/integrations/torchvision/__init__.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/__init__.py +++ b/src/sparseml/integrations/torchvision/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/integrations/torchvision/data/__init__.py b/src/sparseml/integrations/torchvision/data/__init__.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/data/__init__.py +++ b/src/sparseml/integrations/torchvision/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/integrations/torchvision/evaluator.py b/src/sparseml/integrations/torchvision/evaluator.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/evaluator.py +++ b/src/sparseml/integrations/torchvision/evaluator.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/integrations/torchvision/metrics/__init__.py b/src/sparseml/integrations/torchvision/metrics/__init__.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/metrics/__init__.py +++ b/src/sparseml/integrations/torchvision/metrics/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/integrations/torchvision/model/__init__.py b/src/sparseml/integrations/torchvision/model/__init__.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/model/__init__.py +++ b/src/sparseml/integrations/torchvision/model/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/integrations/torchvision/optim/__init__.py b/src/sparseml/integrations/torchvision/optim/__init__.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/optim/__init__.py +++ b/src/sparseml/integrations/torchvision/optim/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
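Note on the RecipeArgs.eval_str change in PATCH 21 above (touched again by the style pass at this point in the series): once an embedded expression is fully evaluated it now collapses to a float, while plain strings pass through unchanged. A minimal sketch of that behaviour, inferred from the diff rather than from any documentation; the eval(...) recipe syntax is taken from the diff, and the example argument name num_epochs is an assumption for illustration only:

    from sparseml.core.recipe.args import RecipeArgs  # module path as it appears in this series

    # An expression is resolved against the supplied args and returned as a float
    RecipeArgs.eval_str("eval(num_epochs * 0.5)", {"num_epochs": 10})  # -> 5.0

    # Targets and other plain strings are returned untouched
    RecipeArgs.eval_str("features.0.0.weight")  # -> "features.0.0.weight"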
diff --git a/src/sparseml/integrations/torchvision/trainer.py b/src/sparseml/integrations/torchvision/trainer.py index e69de29bb2d..0c44f887a47 100644 --- a/src/sparseml/integrations/torchvision/trainer.py +++ b/src/sparseml/integrations/torchvision/trainer.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/pruning/magnitude/pytorch.py b/src/sparseml/modifiers/pruning/magnitude/pytorch.py index 10d80a312e8..c0df01ddc70 100644 --- a/src/sparseml/modifiers/pruning/magnitude/pytorch.py +++ b/src/sparseml/modifiers/pruning/magnitude/pytorch.py @@ -135,4 +135,3 @@ def _update_masks(self, event: Event): elif event.type_ == EventType.OPTIM_POST_STEP and not self._use_hooks: for layer_param_name, _ in self.parameterized_layers_.items(): self.apply_mask_weight(layer_param_name) - diff --git a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py index 9c1a8eefe39..33756a07596 100644 --- a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py +++ b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py @@ -16,11 +16,11 @@ from typing import Dict import torch +from pydantic import BaseModel from torch.nn import Module, Parameter from torch.utils.hooks import RemovableHandle from sparseml.core import ModelParameterizedLayer -from pydantic import BaseModel __all__ = ["LayerParamMasking"] diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py index fa18a5a8736..05fae4a174a 100644 --- a/src/sparseml/utils/pytorch/module.py +++ b/src/sparseml/utils/pytorch/module.py @@ -23,6 +23,7 @@ from packaging import version from torch.nn import Linear, Module, Parameter from torch.nn.modules.conv import _ConvNd + from sparseml.core.model.base import ModelParameterizedLayer @@ -64,7 +65,7 @@ "get_terminal_layers", "get_prunable_layers", "get_quantizable_layers", - "get_layers_params" + "get_layers_params", ] @@ -239,19 +240,18 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]: return quantizable -def get_layers_params(targets: Union[str, List[str]], module: Module) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]: + +def get_layers_params( + targets: Union[str, List[str]], module: Module +) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]: params = get_params(targets, module) layers = get_layers(targets, module) parameterized_layers = {} for name, param in params.items(): param_layer = ModelParameterizedLayer( - layer_name=name, - layer=layers[name], - param_name=name, - param=param + layer_name=name, layer=layers[name], param_name=name, param=param ) parameterized_layers[name] = param_layer return parameterized_layers - \ No newline at end of file From ed8e0ba22d077287fafa190420f1f5cd0973f559 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 22 Sep 2023 14:27:29 -0400 Subject: [PATCH 26/27] style fixes --- src/sparseml/core/__init__.py | 2 ++ 
 src/sparseml/core/data/__init__.py            |  2 ++
 src/sparseml/core/data/pytorch.py             |  3 ++-
 src/sparseml/core/factory.py                  |  4 ++--
 src/sparseml/core/framework.py                |  3 ---
 src/sparseml/core/lifecycle/__init__.py       |  2 ++
 src/sparseml/core/lifecycle/session.py        |  6 +----
 src/sparseml/core/model/__init__.py           |  2 ++
 src/sparseml/core/modifier/__init__.py        |  2 ++
 src/sparseml/core/modifier/base.py            | 11 +++++----
 src/sparseml/core/modifier/modifier.py        | 23 ++++++++++---------
 src/sparseml/core/optimizer/__init__.py       |  2 ++
 src/sparseml/core/optimizer/base.py           |  3 ---
 src/sparseml/core/recipe/__init__.py          |  2 ++
 src/sparseml/core/recipe/recipe.py            |  3 ++-
 src/sparseml/core/state.py                    |  2 +-
 src/sparseml/modifiers/__init__.py            |  2 ++
 .../modifiers/distillation/__init__.py        |  2 ++
 .../modifiers/distillation/output/__init__.py |  2 ++
 .../modifiers/distillation/output/pytorch.py  |  2 +-
 .../distillation/utils/pytorch/__init__.py    |  2 ++
 .../distillation/utils/pytorch/kd_factory.py  |  2 +-
 src/sparseml/modifiers/pruning/__init__.py    |  2 ++
 .../modifiers/pruning/constant/__init__.py    |  2 ++
 src/sparseml/modifiers/pruning/helpers.py     |  2 --
 .../modifiers/pruning/magnitude/__init__.py   |  2 ++
 .../pruning/utils/pytorch/__init__.py         |  2 ++
 .../pruning/utils/pytorch/layer_mask.py       |  2 +-
 src/sparseml/utils/pytorch/__init__.py        |  2 ++
 .../utils/pytorch/pruning/__init__.py         |  2 ++
 30 files changed, 64 insertions(+), 36 deletions(-)

diff --git a/src/sparseml/core/__init__.py b/src/sparseml/core/__init__.py
index d206c4f8cc3..3a6fc72cb57 100644
--- a/src/sparseml/core/__init__.py
+++ b/src/sparseml/core/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .data import *
 from .event import *
 from .factory import *
diff --git a/src/sparseml/core/data/__init__.py b/src/sparseml/core/data/__init__.py
index 1101a7fa8ea..01f2aba2015 100644
--- a/src/sparseml/core/data/__init__.py
+++ b/src/sparseml/core/data/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import ModifiableData
diff --git a/src/sparseml/core/data/pytorch.py b/src/sparseml/core/data/pytorch.py
index bbc890f1d7f..c40a5f0cbd8 100644
--- a/src/sparseml/core/data/pytorch.py
+++ b/src/sparseml/core/data/pytorch.py
@@ -81,7 +81,8 @@ def _data_merge_iter(self):
     @staticmethod
     def split_batch(batch, start_idx, end_idx):
         """
-        Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided indices.
+        Splits a batch based on its type (Tensor, Mapping, Sequence) and the provided
+        indices.
         """
         if isinstance(batch, torch.Tensor):
             return batch[start_idx:end_idx]
diff --git a/src/sparseml/core/factory.py b/src/sparseml/core/factory.py
index ab71cb8f85d..acb5d042759 100644
--- a/src/sparseml/core/factory.py
+++ b/src/sparseml/core/factory.py
@@ -123,9 +123,9 @@ def create(
     def register(type_: str, modifier_class: Type[Modifier]):
         if not issubclass(modifier_class, Modifier):
             raise ValueError(
-                f"The provided class does not subclass the Modifier base class."
+                "The provided class does not subclass the Modifier base class."
             )
         if not isinstance(modifier_class, type):
-            raise ValueError(f"The provided class is not a type.")
+            raise ValueError("The provided class is not a type.")
 
         ModifierFactory._registered_registry[type_] = modifier_class
diff --git a/src/sparseml/core/framework.py b/src/sparseml/core/framework.py
index 9dce24613d3..fe3119a108b 100644
--- a/src/sparseml/core/framework.py
+++ b/src/sparseml/core/framework.py
@@ -13,11 +13,8 @@
 # limitations under the License.
 
 
-import importlib
 from enum import Enum
 
-from pydantic import ValidationError
-
 
 __all__ = ["Framework"]
diff --git a/src/sparseml/core/lifecycle/__init__.py b/src/sparseml/core/lifecycle/__init__.py
index 581cb06e687..908c4ccfb3b 100644
--- a/src/sparseml/core/lifecycle/__init__.py
+++ b/src/sparseml/core/lifecycle/__init__.py
@@ -12,5 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .event import *
 from .session import *
diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py
index c7229d23ca6..80f535b3c16 100644
--- a/src/sparseml/core/lifecycle/session.py
+++ b/src/sparseml/core/lifecycle/session.py
@@ -17,11 +17,7 @@
 
 from sparseml.core.event import EventType
 from sparseml.core.framework import Framework
-from sparseml.core.lifecycle.event import (
-    CallbacksEventLifecycle,
-    EventLifecycle,
-    WrappedOptimEventLifecycle,
-)
+from sparseml.core.lifecycle.event import CallbacksEventLifecycle, EventLifecycle
 from sparseml.core.modifier import ModifierInterface
 from sparseml.core.recipe import RecipeContainer
 from sparseml.core.state import State
diff --git a/src/sparseml/core/model/__init__.py b/src/sparseml/core/model/__init__.py
index 81ade568d8d..1e9d47367ad 100644
--- a/src/sparseml/core/model/__init__.py
+++ b/src/sparseml/core/model/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import ModelParameterizedLayer, ModifiableModel
diff --git a/src/sparseml/core/modifier/__init__.py b/src/sparseml/core/modifier/__init__.py
index b205e585dbf..1b1a9f03d25 100644
--- a/src/sparseml/core/modifier/__init__.py
+++ b/src/sparseml/core/modifier/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import *
 from .modifier import *
 from .stage import *
diff --git a/src/sparseml/core/modifier/base.py b/src/sparseml/core/modifier/base.py
index da07485d0e2..109b3b933f1 100644
--- a/src/sparseml/core/modifier/base.py
+++ b/src/sparseml/core/modifier/base.py
@@ -15,6 +15,9 @@
 
 from abc import ABC, abstractmethod
 
+from sparseml.core.event import Event
+from sparseml.core.state import State
+
 
 __all__ = ["ModifierInterface"]
 
@@ -48,17 +51,17 @@ def calculate_end(self) -> float:
         raise NotImplementedError()
 
     @abstractmethod
-    def pre_initialize_structure(self, state: "State", **kwargs):
+    def pre_initialize_structure(self, state: State, **kwargs):
         raise NotImplementedError()
 
     @abstractmethod
-    def initialize(self, state: "State", **kwargs):
+    def initialize(self, state: State, **kwargs):
         raise NotImplementedError()
 
     @abstractmethod
-    def finalize(self, state: "State", **kwargs):
+    def finalize(self, state: State, **kwargs):
         raise NotImplementedError()
 
     @abstractmethod
-    def update_event(self, state: "State", event: "Event", **kwargs):
+    def update_event(self, state: State, event: Event, **kwargs):
         raise NotImplementedError()
diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py
index 5e3a7979b61..3df60db6247 100644
--- a/src/sparseml/core/modifier/modifier.py
+++ b/src/sparseml/core/modifier/modifier.py
@@ -20,6 +20,7 @@
 from sparseml.core.event import Event, EventType
 from sparseml.core.framework_object import MultiFrameworkObject
 from sparseml.core.modifier.base import ModifierInterface
+from sparseml.core.state import State
 
 
 __all__ = ["Modifier"]
@@ -60,11 +61,11 @@ def calculate_start(self) -> float:
     def calculate_end(self) -> float:
         return self.end if self.end is not None else -1
 
-    def pre_initialize_structure(self, state: "State", **kwargs):
+    def pre_initialize_structure(self, state: State, **kwargs):
         self.on_initialize_structure(state, **kwargs)
         self.initialized_structure_ = True
 
-    def initialize(self, state: "State", **kwargs):
+    def initialize(self, state: State, **kwargs):
         if self.initialized_:
             return
 
@@ -88,7 +89,7 @@ def initialize(self, state: "State", **kwargs):
             self.on_start(state, state.start_event, **kwargs)
             self.started_ = True
 
-    def finalize(self, state: "State", **kwargs):
+    def finalize(self, state: State, **kwargs):
         if self.finalized_:
             return
 
@@ -105,7 +106,7 @@ def finalize(self, state: "State", **kwargs):
 
         self.finalized_ = finalized
 
-    def update_event(self, state: "State", event: Event, **kwargs):
+    def update_event(self, state: State, event: Event, **kwargs):
         if not self.initialized_:
             raise RuntimeError("cannot update an uninitialized modifier")
 
@@ -151,23 +152,23 @@ def should_end(self, event: Event):
 
         return self.end is not None and current >= self.end
 
-    def on_initialize_structure(self, state: "State", **kwargs):
+    def on_initialize_structure(self, state: State, **kwargs):
         raise NotImplementedError()
 
-    def on_initialize(self, state: "State", **kwargs) -> bool:
+    def on_initialize(self, state: State, **kwargs) -> bool:
         raise NotImplementedError()
 
-    def on_finalize(self, state: "State", **kwargs) -> bool:
+    def on_finalize(self, state: State, **kwargs) -> bool:
         raise NotImplementedError()
 
-    def on_start(self, state: "State", event: Event, **kwargs):
+    def on_start(self, state: State, event: Event, **kwargs):
         raise NotImplementedError()
 
-    def on_update(self, state: "State", event: Event, **kwargs):
+    def on_update(self, state: State, event: Event, **kwargs):
         raise NotImplementedError()
 
-    def on_end(self, state: "State", event: Event, **kwargs):
+    def on_end(self, state: State, event: Event, **kwargs):
         raise NotImplementedError()
 
-    def on_event(self, state: "State", event: Event, **kwargs):
+    def on_event(self, state: State, event: Event, **kwargs):
         pass
diff --git a/src/sparseml/core/optimizer/__init__.py b/src/sparseml/core/optimizer/__init__.py
index 6ded41b5440..a1ddc131b96 100644
--- a/src/sparseml/core/optimizer/__init__.py
+++ b/src/sparseml/core/optimizer/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import ModifiableOptimizer
diff --git a/src/sparseml/core/optimizer/base.py b/src/sparseml/core/optimizer/base.py
index 59cd78b963b..41a4238ec1e 100644
--- a/src/sparseml/core/optimizer/base.py
+++ b/src/sparseml/core/optimizer/base.py
@@ -29,9 +29,6 @@ class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject):
     optimizer: OT = None
 
-    def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None):
-        self.optimizer = optimizer
-
     def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None):
         self.optimizer = optimizer
diff --git a/src/sparseml/core/recipe/__init__.py b/src/sparseml/core/recipe/__init__.py
index 979d6f9d28a..4da363104e9 100644
--- a/src/sparseml/core/recipe/__init__.py
+++ b/src/sparseml/core/recipe/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .args import *
 from .base import *
 from .container import *
diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py
index f9777044c23..43416f07096 100644
--- a/src/sparseml/core/recipe/recipe.py
+++ b/src/sparseml/core/recipe/recipe.py
@@ -15,12 +15,13 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union
 
 import yaml
 from pydantic import Field, root_validator
 
 from sparseml.core.framework import Framework
+from sparseml.core.modifier import StageModifiers
 from sparseml.core.recipe.args import RecipeArgs
 from sparseml.core.recipe.base import RecipeBase
 from sparseml.core.recipe.metadata import RecipeMetaData
diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py
index b700e9c422d..1376ae50b2c 100644
--- a/src/sparseml/core/state.py
+++ b/src/sparseml/core/state.py
@@ -19,7 +19,7 @@
 from pydantic import Field
 
 from sparseml.core.data import ModifiableData
-from sparseml.core.event import Event, EventType
+from sparseml.core.event import Event
 from sparseml.core.framework import Framework
 from sparseml.core.model import ModifiableModel
 from sparseml.core.optimizer import ModifiableOptimizer
diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py
index 737bb4ed07e..de33872de9b 100644
--- a/src/sparseml/modifiers/__init__.py
+++ b/src/sparseml/modifiers/__init__.py
@@ -12,5 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .distillation import *
 from .pruning import *
diff --git a/src/sparseml/modifiers/distillation/__init__.py b/src/sparseml/modifiers/distillation/__init__.py
index 6a7699b2f53..694a5120f3c 100644
--- a/src/sparseml/modifiers/distillation/__init__.py
+++ b/src/sparseml/modifiers/distillation/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .output import *
diff --git a/src/sparseml/modifiers/distillation/output/__init__.py b/src/sparseml/modifiers/distillation/output/__init__.py
index 87930811c41..9cdf715c135 100644
--- a/src/sparseml/modifiers/distillation/output/__init__.py
+++ b/src/sparseml/modifiers/distillation/output/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import *
diff --git a/src/sparseml/modifiers/distillation/output/pytorch.py b/src/sparseml/modifiers/distillation/output/pytorch.py
index da226566542..8d4c6325efa 100644
--- a/src/sparseml/modifiers/distillation/output/pytorch.py
+++ b/src/sparseml/modifiers/distillation/output/pytorch.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Tuple, Union
+from typing import Dict
 
 import torch
 from torch.nn import Module
diff --git a/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py b/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py
index 4fb62d86716..8337aa66702 100644
--- a/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py
+++ b/src/sparseml/modifiers/distillation/utils/pytorch/__init__.py
@@ -12,5 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .kd_factory import *
 from .kd_wrapper import *
diff --git a/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py b/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py
index 23a321b02ad..40b81aba7e0 100644
--- a/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py
+++ b/src/sparseml/modifiers/distillation/utils/pytorch/kd_factory.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import re
-from typing import Callable, Dict, List, Sequence, Tuple, Union
+from typing import Callable, Dict, Sequence, Tuple, Union
 
 import torch
 import torch.nn.functional as TF
diff --git a/src/sparseml/modifiers/pruning/__init__.py b/src/sparseml/modifiers/pruning/__init__.py
index 522fd000e3f..cff59b0e9be 100644
--- a/src/sparseml/modifiers/pruning/__init__.py
+++ b/src/sparseml/modifiers/pruning/__init__.py
@@ -12,5 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .constant import *
 from .magnitude import *
diff --git a/src/sparseml/modifiers/pruning/constant/__init__.py b/src/sparseml/modifiers/pruning/constant/__init__.py
index 03ee625d7d4..d283e9eb9c1 100644
--- a/src/sparseml/modifiers/pruning/constant/__init__.py
+++ b/src/sparseml/modifiers/pruning/constant/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import ConstantPruningModifier
diff --git a/src/sparseml/modifiers/pruning/helpers.py b/src/sparseml/modifiers/pruning/helpers.py
index d1703dc8d3b..7f3727c94aa 100644
--- a/src/sparseml/modifiers/pruning/helpers.py
+++ b/src/sparseml/modifiers/pruning/helpers.py
@@ -17,8 +17,6 @@
 
 from dataclasses import dataclass
 from typing import Any, Callable, Dict
 
-import numpy as np
-
 from sparseml.core import Event, State
 
diff --git a/src/sparseml/modifiers/pruning/magnitude/__init__.py b/src/sparseml/modifiers/pruning/magnitude/__init__.py
index 78cd427840b..3d8279274de 100644
--- a/src/sparseml/modifiers/pruning/magnitude/__init__.py
+++ b/src/sparseml/modifiers/pruning/magnitude/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .base import MagnitudePruningModifier
diff --git a/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py b/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py
index a7bb161fee9..091133f12c8 100644
--- a/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py
+++ b/src/sparseml/modifiers/pruning/utils/pytorch/__init__.py
@@ -12,5 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .layer_mask import *
 from .mask_factory import *
diff --git a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py
index 33756a07596..b68a2494a4e 100644
--- a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py
+++ b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py
@@ -41,7 +41,7 @@ def setup_mask_for_param(param: Parameter, mask: torch.Tensor) -> torch.Tensor:
         )
 
     if mask.dtype != torch.bool:
-        raise ValueError(f"Mask must be a boolean tensor")
+        raise ValueError("Mask must be a boolean tensor")
 
     return param.data.new_tensor(mask, dtype=torch.bool)
diff --git a/src/sparseml/utils/pytorch/__init__.py b/src/sparseml/utils/pytorch/__init__.py
index 880ecd996e3..10c86104af1 100644
--- a/src/sparseml/utils/pytorch/__init__.py
+++ b/src/sparseml/utils/pytorch/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .module import *
diff --git a/src/sparseml/utils/pytorch/pruning/__init__.py b/src/sparseml/utils/pytorch/pruning/__init__.py
index c89c8da17de..e2da92c4d14 100644
--- a/src/sparseml/utils/pytorch/pruning/__init__.py
+++ b/src/sparseml/utils/pytorch/pruning/__init__.py
@@ -12,5 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# flake8: noqa
+
 from .layer_mask import *
 from .mask import *

From bfd7f84d396a3ccd47082907f8c47f215c14fcf0 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 26 Sep 2023 16:03:24 -0400
Subject: [PATCH 27/27] bug fixes that came up during obcq implementation

---
 src/sparseml/core/modifier/modifier.py |  3 +++
 src/sparseml/core/state.py             | 11 ++++++-----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py
index 3df60db6247..262d37fda3a 100644
--- a/src/sparseml/core/modifier/modifier.py
+++ b/src/sparseml/core/modifier/modifier.py
@@ -143,6 +143,9 @@ def update_event(self, state: State, event: Event, **kwargs):
         self.on_update(state, event, **kwargs)
 
     def should_start(self, event: Event):
+        if not self.start:
+            return False
+
         current = event.current_index
 
         return self.start <= current and (self.end is None or current < self.end)
diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py
index 1376ae50b2c..351dffba540 100644
--- a/src/sparseml/core/state.py
+++ b/src/sparseml/core/state.py
@@ -19,7 +19,7 @@
 from pydantic import Field
 
 from sparseml.core.data import ModifiableData
-from sparseml.core.event import Event
+from sparseml.core.event import Event, EventType
 from sparseml.core.framework import Framework
 from sparseml.core.model import ModifiableModel
 from sparseml.core.optimizer import ModifiableOptimizer
@@ -107,10 +107,11 @@ def update(
         if test_data is not None:
             self.data.test = test_data if not copy_data else deepcopy(test_data)
         if calib_data is not None:
-            calib_loader = calib_data if not copy_data else deepcopy(calib_data)
-            self.calib_data = ModifiableData(
-                framework=self.framework, data_loader=calib_loader
-            )
+            self.data.calib = calib_data if not copy_data else deepcopy(calib_data)
+
+        if "device" in kwargs:
+            self.hardware.device = kwargs["device"]
+            self.model.model.to(self.hardware.device)
 
         if (
             start is not None