Implementation of powerpropagation as a modifier #1685

Merged · merged 48 commits on Nov 20, 2023

Commits (48 total; the file changes shown below reflect 39 of them)
d7dc46a
update get_named_layers_and_params_by_regex in src/sparseml/pytorch/u…
ohaijen May 18, 2023
86703b9
Merge branch 'main' of github.com:neuralmagic/sparseml into main
ohaijen May 18, 2023
c74827a
first pass at top-kast pruner - decoupling forward and backward sparsity
ohaijen Jul 3, 2023
e56774a
get top-kast pruner working
ohaijen Jul 12, 2023
36f4c77
make style and quality
ohaijen Jul 12, 2023
ec4ff60
small cleanup
ohaijen Jul 17, 2023
5959b07
Merge branch 'topkast' of github.com:neuralmagic/sparseml into topkast
ohaijen Jul 17, 2023
51d7453
Merge branch 'main' of github.com:neuralmagic/sparseml into main
ohaijen Jul 17, 2023
0b006d3
first pass at powerpropagated resnet
ohaijen Jul 17, 2023
216aa60
first pass for powerpropagation as a modifier
ohaijen Jul 24, 2023
719db6a
forgot recipe_templage
ohaijen Jul 25, 2023
4a6c863
create actual powerpropagated subclasses
ohaijen Aug 2, 2023
ffc3eac
Merge branch 'main' of github.com:neuralmagic/sparseml into main
ohaijen Aug 2, 2023
516a8f6
make pytorch.torchvision.train.py work with torch.distributed.launch
ohaijen Aug 2, 2023
c0051af
undo accidental add
ohaijen Aug 2, 2023
5075bde
Merge remote-tracking branch 'origin/torchvision_ddp_fix' into topkast
ohaijen Aug 2, 2023
b999330
some fixes to the modifier
ohaijen Aug 2, 2023
8405384
setup powerpropagation tests
ohaijen Aug 7, 2023
955ce1e
remove fist attempt at implementation
ohaijen Aug 7, 2023
4aa3d22
get powerpropagation working as a wrapper
ohaijen Aug 9, 2023
56cc96f
update testing recipe
ohaijen Aug 9, 2023
7f5ea36
ensure we don't set the mask twice and convert L2 loss to weight decay
ohaijen Aug 9, 2023
4b6e7b8
revert train.py to main version
ohaijen Aug 10, 2023
24a7bf7
clean, working version of powerpropagation
ohaijen Aug 10, 2023
102c266
Merge branch 'powerprop' of github.com:neuralmagic/sparseml into powe…
ohaijen Aug 10, 2023
97bea93
make style/quality
ohaijen Aug 10, 2023
dbfafee
undo changes that shouldn't be in the PR
ohaijen Aug 10, 2023
1bfab60
Merge branch 'main' into powerprop
ohaijen Aug 10, 2023
d8671bf
undo accidental change?
ohaijen Aug 10, 2023
c70e63e
Merge branch 'main' into powerprop
ohaijen Aug 10, 2023
24f3d89
make quality and add a weight decay test
ohaijen Aug 14, 2023
e11d5aa
revert changes to train.py
ohaijen Aug 14, 2023
6b072f5
Merge branch 'main' into topkast
ohaijen Aug 14, 2023
39fac2e
Merge branch 'main' of github.com:neuralmagic/sparseml into main
ohaijen Sep 4, 2023
4ec08bc
Merge branch 'main' of github.com:neuralmagic/sparseml into main
ohaijen Sep 11, 2023
000de7d
Merge branch 'main' into topkast
ohaijen Sep 11, 2023
4065a87
Merge branch 'main' into powerprop
ohaijen Sep 11, 2023
24c9ea7
Merge branch 'topkast' into powerprop
ohaijen Sep 11, 2023
a934051
Revert "Merge branch 'topkast' into powerprop"
ohaijen Sep 11, 2023
830e01e
Merge branch 'main' into powerprop
anmarques Oct 24, 2023
b963683
fix quality errors
ohaijen Oct 25, 2023
3dadd5b
Merge branch 'powerprop' of github.com:neuralmagic/sparseml into powe…
ohaijen Oct 25, 2023
0903d4e
Merge branch 'main' into powerprop
ohaijen Nov 6, 2023
f246989
Merge branch 'main' into powerprop
abhinavnmagic Nov 16, 2023
4e96424
Merge branch 'main' into powerprop
ohaijen Nov 20, 2023
db8896e
Merge branch 'main' into powerprop
ohaijen Nov 20, 2023
6b77f62
Merge branch 'main' into powerprop
ohaijen Nov 20, 2023
547889f
Merge branch 'main' into powerprop
ohaijen Nov 20, 2023
1 change: 1 addition & 0 deletions src/sparseml/pytorch/recipe_template/main.py
@@ -31,6 +31,7 @@
MFACPruningModifier,
MovementPruningModifier,
OBSPruningModifier,
PowerpropagationModifier,
)
from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (
QuantizationModifier,
1 change: 1 addition & 0 deletions src/sparseml/pytorch/sparsification/pruning/__init__.py
@@ -22,6 +22,7 @@
from .mask_creator import *
from .mask_params import *
from .modifier_as import *
from .modifier_powerpropagation import *
from .modifier_pruning_acdc import *
from .modifier_pruning_base import *
from .modifier_pruning_constant import *
327 changes: 327 additions & 0 deletions src/sparseml/pytorch/sparsification/pruning/modifier_powerpropagation.py
@@ -0,0 +1,327 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Modifier for models through powerproagation.

"""


import logging
from typing import List, Optional, Union

import torch
from torch.nn import Conv2d, Linear, Module, Parameter
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer

from sparseml.optim import ModifierProp
from sparseml.pytorch.sparsification.modifier import (
PyTorchModifierYAML,
ScheduledModifier,
)
from sparseml.pytorch.utils import (
NamedLayerParam,
get_named_layers_and_params_by_regex,
get_prunable_layers,
tensors_to_device,
)
from sparseml.pytorch.utils.logger import LoggerManager
from sparseml.utils import ALL_PRUNABLE_TOKEN, ALL_TOKEN, validate_str_iterable


__all__ = [
"PowerpropagationModifier",
"PowerpropagationWrapper",
]


_LOGGER = logging.getLogger(__name__)


class PowerpropagationWrapper(Module):
def __init__(self, layer: Module, alpha: float = 1.0):
super(PowerpropagationWrapper, self).__init__()

if not isinstance(layer, Conv2d) and not isinstance(layer, Linear):
raise ValueError("Powerpropagation only works with Linear and Conv layers")

self.layer = layer
# First set alpha to 1, then update it to the correct
# value. This avoids replicating the code that updates
# the layer weights.
self.register_buffer("alpha", torch.tensor(1.0, requires_grad=False))
self.set_alpha(alpha)

def forward(self, x):
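# reparameterized weight used below: w * |w|**(alpha - 1) == sign(w) * |w|**alpha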
weight = self.layer.weight * pow(abs(self.layer.weight), self.alpha - 1)

if isinstance(self.layer, Conv2d):
return F.conv2d(
x,
weight,
self.layer.bias,
self.layer.stride,
self.layer.padding,
self.layer.dilation,
self.layer.groups,
)
elif isinstance(self.layer, Linear):
return F.linear(x, weight, self.layer.bias)
else:
raise ValueError(
"Powerpropagation only works with Linear and Conv2d layers"
)

def set_alpha(self, new_alpha):
with torch.no_grad():
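# Rescale the stored weight so that the effective weight sign(w) * |w|**alpha
# is unchanged: the new stored value is sign(w_old) * |w_old|**(old_alpha / new_alpha).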

self.layer.weight *= pow(abs(self.layer.weight), self.alpha / new_alpha - 1)
# If there were any zeros in the weights, these may now be nan,
# depending on the old and new values of alpha.
self.layer.weight.data = torch.nan_to_num(self.layer.weight)
self.alpha = torch.tensor(float(new_alpha))
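# A minimal sketch of the invariant the wrapper maintains: it stores
# v = sign(w) * |w|**(1 / alpha), so the effective weight sign(v) * |v|**alpha used in
# forward() always equals the weight the layer was originally wrapped with, and
# set_alpha() rescales v without changing the function the layer computes.
# For example (using the torch and Linear imports above):
#
#   torch.manual_seed(0)
#   base = Linear(4, 3, bias=False)
#   reference = base.weight.detach().clone()
#   wrapped = PowerpropagationWrapper(base, alpha=2.0)
#
#   # the effective weight equals the original weight
#   effective = wrapped.layer.weight * abs(wrapped.layer.weight) ** (wrapped.alpha - 1)
#   assert torch.allclose(effective, reference, atol=1e-6)
#
#   # changing alpha rescales the stored weight but not the layer's output
#   x = torch.randn(2, 4)
#   before = wrapped(x)
#   wrapped.set_alpha(3.0)
#   assert torch.allclose(before, wrapped(x), atol=1e-5)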


@PyTorchModifierYAML()
class PowerpropagationModifier(ScheduledModifier):
"""
Applies powerpropagation to the selected Conv2d and Linear layers: each wrapped
layer runs its forward pass with the reparameterized weight sign(w) * |w|^alpha,
so gradients scale with weight magnitude and small weights tend to stay small,
encouraging sparsity during training.

| Sample yaml:
| !PowerpropagationModifier
| start_epoch: 0.0
| end_epoch: 100
| alpha: 2.0
| params: __ALL_PRUNABLE__
| strict: True

:param start_epoch: The epoch to start the modifier at
:param alpha: the power to which weights are raised before the standard forward
pass, preserving the original sign of each weight. Non-integer values are allowed.
:param params: A list of full parameter names or regex patterns of names to apply
pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
and Linear layers' weights. If a sparsity to param mapping is defined by
final_sparsity, then params should be set to []
:param strict: if True, will raise an error if any of the given param names or
regex patterns do not match parameters in the module. Default True
:param end_epoch: The epoch at which the architecture changes will be reversed,
converting the network back to a normal architecture. Note that if this is not
set, or is set to a point after training finishes, the architecture changes
remain part of the model, making it largely incompatible with other frameworks.
"""

def __init__(
self,
start_epoch: Union[int, float],
end_epoch: Union[int, float],
params: Union[str, List[str]],
alpha: float = 1.0,
strict: bool = True,
):
super(PowerpropagationModifier, self).__init__(
start_epoch=start_epoch, end_epoch=end_epoch, end_comparator=-1
)

self._alpha = alpha
self._strict = strict
self._params = validate_str_iterable(
params, "{} for params".format(self.__class__.__name__)
)
self._propagated_layers = {}
self._powerpropagation_enabled = False

self._validate_params()

def initialize(
self,
module: Module,
epoch: float = 0,
loggers: Optional[LoggerManager] = None,
**kwargs,
):
"""
Grab the layers and params to apply powerpropagation to if the epoch is in range.

:param module: the PyTorch model/module to modify
:param epoch: The epoch to initialize the modifier and module at.
Defaults to 0 (start of the training process)
:param loggers: Optional list of loggers to log the modification process to
:param kwargs: Optional kwargs to support specific arguments
for individual modifiers.
"""
super().initialize(module, epoch, loggers, **kwargs)
self._powerpropagated_layers = self._create_named_layers_and_params(module)

@ModifierProp()
def alpha(self) -> Optional[float]:
"""
:return: alpha (the power to which weights are raised during the forward pass)
"""
return self._alpha

@alpha.setter
def alpha(self, value: float):
"""
:param value: alpha (the power to which weights are raised during the
forward pass)
"""
self._alpha = value

@ModifierProp()
def params(self) -> Union[str, List[str], None]:
"""
:return: A list of full parameter names or regex patterns of names to apply
pruning to. Regex patterns must be specified with the prefix 're:'. __ALL__
will match to all parameters. __ALL_PRUNABLE__ will match to all ConvNd
and Linear layers' weights
"""
return self._params

def update(
self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
):
"""
If start_pending(), converts layers to powerpropagated layers
If end_pending(), undoes the conversion

:param module: module to modify
:param optimizer: optimizer to modify
:param epoch: current epoch and progress within the current epoch
:param steps_per_epoch: number of steps taken within each epoch
(calculate batch number using this and epoch)
"""
super().update(module, optimizer, epoch, steps_per_epoch)
self._check_powerpropagation_update(module, epoch, steps_per_epoch)

def _check_powerpropagation_update(
self, module: Module, epoch: float, steps_per_epoch: int
):
if self.start_pending(epoch, steps_per_epoch):
self._enable_module_powerpropagation(module)
if self.end_pending(epoch, steps_per_epoch):
self._disable_module_powerpropagation(module)

# TODO: Make this do something useful
self._log_powerpropagation(module, epoch, steps_per_epoch)

def _enable_module_powerpropagation(self, module: Module):
_LOGGER.debug("state dict keys before powerpropagation: %s", list(module.state_dict().keys()))
for name, layer, param in self._powerpropagated_layers:
self._enable_powerprop(module, name, layer, param)
print("\n\n\n", module.state_dict().keys())
self._powerpropagation_enabled = True

def _disable_module_powerpropagation(self, module: Module):
if not self._powerpropagation_enabled:
return
for name, layer in self._propagated_layers.items():
self._undo_enable_powerprop(module, name, layer)
print("\n\n\n", module.state_dict().keys())
self._powerpropagation_enabled = False

# from https://pytorch.org/docs/stable/_modules/torch/ao/quantization/fuse_modules.html#fuse_modules # noqa: E501
# Generalization of setattr
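# e.g. _set_module(model, "layer1.0.conv1", wrapper) replaces model.layer1[0].conv1
# with wrapper; getattr on nn.Module resolves the numeric child names used by
# Sequential containers, so dotted names from named_modules() work directly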
def _set_module(self, model, submodule_key, module):
tokens = submodule_key.split(".")
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
cur_mod = getattr(cur_mod, s)

setattr(cur_mod, tokens[-1], module)

def _enable_powerprop(
self, model: Module, name: str, layer: Module, param: Parameter
):
if isinstance(layer, Conv2d) or isinstance(layer, Linear):
powerpropagated_layer = PowerpropagationWrapper(layer, self._alpha)
if param.is_cuda:
powerpropagated_layer = powerpropagated_layer.to(param.device)
self._propagated_layers[name] = powerpropagated_layer
self._set_module(model, name, powerpropagated_layer)
else:
raise RuntimeError(f"don't know how do do powerpropagation for {layer}")
return

def _undo_enable_powerprop(self, model: Module, name: str, layer: Module):
if isinstance(layer, PowerpropagationWrapper):
# Setting alpha to 1 automatically updates the inner layer
# weights to the correct non-exponentiated values.
layer.set_alpha(1)
self._set_module(model, name, layer.layer)
else:
raise RuntimeError(f"don't know how to undo powerpropagation for {layer}")
return

def _validate_params(self):
self.validate_schedule()

def _log_powerpropagation(
self,
module: Module,
epoch: float,
steps_per_epoch: int,
):
"""
Log the current powerpropagation alpha value.

:param module: module being modified
:param epoch: current epoch and progress within the current epoch
:param steps_per_epoch: number of steps taken within each epoch
(calculate batch number using this and epoch)
"""

def _log(tag, value):
self.log_scalar(
tag=tag,
value=value,
epoch=epoch,
steps_per_epoch=steps_per_epoch,
)

_log(
tag="PowerpropagationModifier/alpha",
value=self._alpha,
)

def _create_named_layers_and_params(self, module: Module) -> List[NamedLayerParam]:
if self._check_params_match(ALL_TOKEN):
param_names = ["re:.*"]
elif self._check_params_match(ALL_PRUNABLE_TOKEN):
param_names = [
name + ".weight" for (name, _) in get_prunable_layers(module)
]
else:
param_names = self._params

chosen = get_named_layers_and_params_by_regex(
module,
param_names,
params_strict=self._strict,
)
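# each entry is a NamedLayerParam; keep the layer name, layer module, and parameter
# (dropping the parameter name) to match how _enable_module_powerpropagation
# unpacks them as (name, layer, param)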
return [(x[0], x[1], x[3]) for x in chosen]

def _check_params_match(self, token: Union[str, List[str]]):
if isinstance(token, str):
return token in self._params or token == self._params

if isinstance(self._params, str):
return self._params in token

return len(set(token).intersection(set(self._params))) > 0
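
For context, here is a minimal sketch of how the new modifier could be applied through the standard SparseML recipe flow with ScheduledModifierManager, assuming the same integration pattern used by the existing pruning modifiers. The recipe string mirrors the sample YAML in the class docstring; the model, optimizer settings, epoch range, recipe file name, and steps_per_epoch are placeholders.

import torch
from torchvision.models import resnet50

from sparseml.pytorch.optim import ScheduledModifierManager

# recipe mirroring the sample YAML in the PowerpropagationModifier docstring
recipe = """
modifiers:
    - !PowerpropagationModifier
        start_epoch: 0.0
        end_epoch: 100.0
        alpha: 2.0
        params: __ALL_PRUNABLE__
        strict: True
"""

model = resnet50()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# write the recipe to disk and let the manager wrap the optimizer so the
# modifier's scheduled updates run during training
with open("powerprop_recipe.yaml", "w") as recipe_file:
    recipe_file.write(recipe)
manager = ScheduledModifierManager.from_yaml("powerprop_recipe.yaml")
optimizer = manager.modify(model, optimizer, steps_per_epoch=1000)

# ... standard training loop between start_epoch and end_epoch ...

manager.finalize(model)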