Skip to content

Commit

Permalink
fix import errors and multiframework inits
Browse files Browse the repository at this point in the history
  • Loading branch information
Sara Adkins authored and markurtz committed Sep 14, 2023
1 parent f04ca6f commit bc73e15
Show file tree
Hide file tree
Showing 19 changed files with 56 additions and 65 deletions.
10 changes: 1 addition & 9 deletions src/sparseml/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .data import *
from .event import *
from .framework import *
from .model import *
from .modifier import *
from .optimizer import *
from .recipe import *
from .session import *
from .state import *
from .session import *
2 changes: 1 addition & 1 deletion src/sparseml/core/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *
from .base import ModifiableData
2 changes: 1 addition & 1 deletion src/sparseml/core/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __new__(
if cls is MultiFrameworkObject:
raise TypeError("MultiFrameworkObject cannot be instantiated directly")

instance = super(MultiFrameworkObject, cls).__new__(cls, **kwargs)
instance = super(MultiFrameworkObject, cls).__new__(cls)

package = instance.__class__.__module__.rsplit(".", 1)[0]
class_name = instance.__class__.__name__
Expand Down
2 changes: 1 addition & 1 deletion src/sparseml/core/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *
from .base import ModifiableModel, ModelParameterizedLayer
3 changes: 3 additions & 0 deletions src/sparseml/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class ModelParameterizedLayer(Generic[LT, PT]):
class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject):
model: MT = None

def __init__(self, framework=None, model=None):
self.model = model

def get_layers_params(
self, targets: Union[str, List[str]]
) -> Dict[str, ModelParameterizedLayer[LT, PT]]:
Expand Down
4 changes: 4 additions & 0 deletions src/sparseml/core/model/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@


class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]):

def __init__(self, framework=None, model=None):
super().__init__(framework=framework, model=model)

def get_layers(self, targets: Union[str, List[str]]) -> Dict[str, Module]:
return get_layers(targets, self.model)

Expand Down
7 changes: 2 additions & 5 deletions src/sparseml/core/modifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *
from .factory import *
from .modifier import *
from .recipe import *
from .stage import *
from .stage import StageModifiers
from .factory import ModifierFactory
12 changes: 4 additions & 8 deletions src/sparseml/core/modifier/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@

from abc import ABC, abstractmethod

from sparseml.core.event import Event
from sparseml.core.state import State


__all__ = ["ModifierInterface"]


Expand All @@ -39,17 +35,17 @@ def calculate_end(self) -> float:
raise NotImplementedError()

@abstractmethod
def pre_initialize_structure(self, state: State, **kwargs):
def pre_initialize_structure(self, state: "State", **kwargs):
raise NotImplementedError()

@abstractmethod
def initialize(self, state: State, **kwargs):
def initialize(self, state: "State", **kwargs):
raise NotImplementedError()

@abstractmethod
def finalize(self, state: State, **kwargs):
def finalize(self, state: "State", **kwargs):
raise NotImplementedError()

@abstractmethod
def update_event(self, state: State, event: Event, **kwargs):
def update_event(self, state: "State", event: "Event", **kwargs):
raise NotImplementedError()
5 changes: 1 addition & 4 deletions src/sparseml/core/modifier/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from sparseml.core.framework import Framework
from sparseml.core.modifier.modifier import Modifier


__all__ = ["ModifierFactory"]

Expand All @@ -25,5 +22,5 @@ def refresh():
raise NotImplementedError()

@staticmethod
def create(type_: str, framework: Framework, **kwargs) -> Modifier:
def create(type_: str, framework: "Framework", **kwargs) -> "Modifier":
raise NotImplementedError()
22 changes: 10 additions & 12 deletions src/sparseml/core/modifier/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from sparseml.core.event import Event, EventType
from sparseml.core.framework import MultiFrameworkObject
from sparseml.core.modifier.base import ModifierInterface
from sparseml.core.state import State


__all__ = ["Modifier"]

Expand Down Expand Up @@ -49,11 +47,11 @@ def calculate_start(self) -> float:
def calculate_end(self) -> float:
return self.end if self.end is not None else -1

def pre_initialize_structure(self, state: State, **kwargs):
def pre_initialize_structure(self, state: "State", **kwargs):
self.on_initialize_structure(state, **kwargs)
self._initialized_structure = True

def initialize(self, state: State, **kwargs):
def initialize(self, state: "State", **kwargs):
if self._initialized:
return

Expand All @@ -77,7 +75,7 @@ def initialize(self, state: State, **kwargs):
self.on_start(state, state.start_event, **kwargs)
self._started = True

def finalize(self, state: State, **kwargs):
def finalize(self, state: "State", **kwargs):
if self._finalized:
return

Expand All @@ -94,7 +92,7 @@ def finalize(self, state: State, **kwargs):

self._finalized = finalized

def update_event(self, state: State, event: Event, **kwargs):
def update_event(self, state: "State", event: Event, **kwargs):
if not self._initialized:
raise RuntimeError("cannot update an uninitialized modifier")

Expand Down Expand Up @@ -138,20 +136,20 @@ def should_end(self, event: Event):

return self.end is not None and current >= self.end

def on_initialize_structure(self, state: State, **kwargs):
def on_initialize_structure(self, state: "State", **kwargs):
raise NotImplementedError()

def on_initialize(self, state: State, event: Event, **kwargs) -> bool:
def on_initialize(self, state: "State", event: Event, **kwargs) -> bool:
raise NotImplementedError()

def on_finalize(self, state: State, event: Event, **kwargs) -> bool:
def on_finalize(self, state: "State", event: Event, **kwargs) -> bool:
raise NotImplementedError()

def on_start(self, state: State, event: Event, **kwargs):
def on_start(self, state: "State", event: Event, **kwargs):
raise NotImplementedError()

def on_update(self, state: State, event: Event, **kwargs):
def on_update(self, state: "State", event: Event, **kwargs):
raise NotImplementedError()

def on_end(self, state: State, event: Event, **kwargs):
def on_end(self, state: "State", event: Event, **kwargs):
raise NotImplementedError()
16 changes: 9 additions & 7 deletions src/sparseml/core/modifier/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from pydantic import BaseModel, Field

from sparseml.core.modifier.base import ModifierInterface
from sparseml.core.modifier.modifier import Modifier
from sparseml.core.state import Event, State

__all__ = [
"StageModifier"
]


class StageModifiers(ModifierInterface, BaseModel):
modifiers: List[Modifier] = Field(default_factory=list)
modifiers: List["Modifier"] = Field(default_factory=list)
index: int = None
group: str = None

Expand All @@ -47,21 +49,21 @@ def calculate_end(self) -> float:
mod.calculate_end() for mod in self.modifiers if mod.calculate_end() >= 0
)

def pre_initialize_structure(self, state: State, **kwargs):
def pre_initialize_structure(self, state: "State", **kwargs):
for modifier in self.modifiers:
modifier.pre_initialize_structure(state, **kwargs)
self._initialized_structure = True

def initialize(self, state: State, **kwargs):
def initialize(self, state: "State", **kwargs):
for modifier in self.modifiers:
modifier.initialize(state, **kwargs)
self._initialized = True

def finalize(self, state: State, **kwargs):
def finalize(self, state: "State", **kwargs):
for modifier in self.modifiers:
modifier.finalize(state, **kwargs)
self._finalized = True

def update_event(self, state: State, event: Event, **kwargs):
def update_event(self, state: "State", event: "Event", **kwargs):
for modifier in self.modifiers:
modifier.update_event(state, event, **kwargs)
2 changes: 1 addition & 1 deletion src/sparseml/core/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *
from .base import ModifiableOptimizer
3 changes: 3 additions & 0 deletions src/sparseml/core/optimizer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
@dataclass
class ModifiableOptimizer(Generic[OT, PGT], MultiFrameworkObject):
optimizer: OT = None

def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None):
self.optimizer = optimizer

def get_param_groups(self) -> List[PGT]:
raise NotImplementedError()
Expand Down
3 changes: 3 additions & 0 deletions src/sparseml/core/optimizer/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@


class ModifiableOptimizerPyTorch(ModifiableOptimizer[Optimizer, Dict[str, Any]]):
def __init__(self, optimizer=None, attach_optim_callbacks=False, framework=None):
super().__init__(optimizer=optimizer, attach_optim_callbacks=attach_optim_callbacks, framework=framework)

def get_param_groups(self) -> List[Dict[str, Any]]:
return self.optimizer.param_groups

Expand Down
6 changes: 1 addition & 5 deletions src/sparseml/core/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .args import *
from .metadata import *
from .modifier import *
from .recipe import *
from .stage import *
from .recipe import Recipe
1 change: 0 additions & 1 deletion src/sparseml/core/recipe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from pydantic import BaseModel, root_validator

from sparseml.core.framework import Framework
from sparseml.core.modifier import Modifier, ModifierFactory
from sparseml.core.recipe.args import RecipeArgs


Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/core/recipe/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pydantic import root_validator

from sparseml.core.framework import Framework
from sparseml.core.modifier import Modifier, ModifierFactory
from sparseml.core.modifier import ModifierFactory
from sparseml.core.recipe.args import RecipeArgs
from sparseml.core.recipe.base import RecipeBase

Expand Down Expand Up @@ -56,7 +56,7 @@ def evaluate(self, args: RecipeArgs = None, shift: int = None):
if shift is not None and "end" in self._args_evaluated:
self._args_evaluated["end"] += shift

def create_modifier(self, framework: Framework) -> Modifier:
def create_modifier(self, framework: Framework) -> "Modifier":
return ModifierFactory.create(self.type, framework, **self._args_evaluated)

@root_validator(pre=True)
Expand Down
14 changes: 6 additions & 8 deletions src/sparseml/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
WrappedOptimEventLifecycle,
)
from sparseml.core.framework import Framework
from sparseml.core.modifier import StageModifiers
from sparseml.core.recipe import Recipe
from sparseml.core.state import ModifiedState, State


Expand Down Expand Up @@ -52,7 +50,7 @@ class _CallbackContainer:
class SparseSession:
def __init__(self):
self._state: State = State()
self._modifiers: List[StageModifiers] = []
self._modifiers: List["StageModifiers"] = []
self._initialized_structure = False
self._initialized = False
self._finalized = False
Expand All @@ -63,7 +61,7 @@ def state(self) -> State:
return self._state

@property
def modifiers(self) -> List[StageModifiers]:
def modifiers(self) -> List["StageModifiers"]:
return self._modifiers

@property
Expand All @@ -85,7 +83,7 @@ def event_called(self) -> bool:
def pre_initialize_structure(
self,
model: Any,
recipe: Union[Recipe, List[Recipe]],
recipe: Union["Recipe", List["Recipe"]],
framework: Framework = None,
**kwargs,
) -> ModifiedState:
Expand Down Expand Up @@ -113,7 +111,7 @@ def pre_initialize_structure(
def initialize(
self,
framework: Framework = None,
recipe: Union[str, List[str], Recipe, List[Recipe]] = None,
recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None,
recipe_stage: str = None,
recipe_args: Dict[str, Any] = None,
model: Any = None,
Expand Down Expand Up @@ -340,7 +338,7 @@ def pre_initialize_structure(**kwargs):

def initialize(
framework: Framework = None,
recipe: Union[str, List[str], Recipe, List[Recipe]] = None,
recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None,
recipe_stage: str = None,
recipe_args: Dict[str, Any] = None,
model: Any = None,
Expand Down Expand Up @@ -384,7 +382,7 @@ def finalize(**kwargs) -> ModifiedState:

def apply(
framework: Framework = None,
recipe: Union[str, List[str], Recipe, List[Recipe]] = None,
recipe: Union[str, List[str], "Recipe", List["Recipe"]] = None,
recipe_stage: str = None,
recipe_args: Dict[str, Any] = None,
model: Any = None,
Expand Down
3 changes: 3 additions & 0 deletions src/sparseml/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def update_recipe(
recipe_stage: str = None,
recipe_args: Dict[str, Any] = None,
):
if recipe is None:
return

if not isinstance(recipe, list):
recipe = [recipe]

Expand Down

0 comments on commit bc73e15

Please sign in to comment.