Skip to content

Commit

Permalink
[Bug][ConstantPruningModifier] Fix mask de register bug (#1773)
Browse files Browse the repository at this point in the history
* Fix mask de-register logic

* forgot to remove commented out line
  • Loading branch information
rahul-tuli authored Oct 19, 2023
1 parent 54ebc6d commit e11ed3d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ def remove_mask(self, layer_param_name: str):
mask_settings = self._mask_settings[layer_param_name]
parameterized_layer = self._masked_layer_params[layer_param_name]

if mask_settings.persistent:
parameterized_layer.layer.unregister_buffer(
param_mask_name(parameterized_layer.param_name)
if not mask_settings.persistent:
delattr(
parameterized_layer.layer,
param_mask_name(parameterized_layer.param_name),
)

del self._masked_layer_params[layer_param_name]
Expand Down
14 changes: 11 additions & 3 deletions tests/sparseml/modifiers/pruning/constant/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,24 @@ def test_constant_pruning_modifier_e2e(model, optimizer):
modifier.on_update(state, event=Event(type_=EventType.OPTIM_PRE_STEP))
modifier.on_update(state, event=Event(type_=EventType.OPTIM_POST_STEP))
modifier.on_end(state, None)

# copy old mask settings as finalize will remove them
# this is needed to check if a mask was persistent

old_mask_settings = modifier._mask_settings.copy()
modifier.finalize(state)

# check mask is removed
for _, parameterized_layer in modifier.parameterized_layers_.items():
for layer_param_name, parameterized_layer in modifier.parameterized_layers_.items():
mask_name = param_mask_name(parameterized_layer.param_name)

if not old_mask_settings[layer_param_name].persistent:
assert not hasattr(parameterized_layer.layer, mask_name)

# mask name should not be in _mask_settings or
# _masked_layer_params
assert mask_name not in modifier._mask_settings
assert mask_name not in modifier._masked_layer_params
assert layer_param_name not in modifier._mask_settings
assert layer_param_name not in modifier._masked_layer_params

# sparsity should restored by ConstantPruningModifierPyTorch

Expand Down

0 comments on commit e11ed3d

Please sign in to comment.