Skip to content

Commit

Permalink
update test file names
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Apr 9, 2024
1 parent 98e0313 commit d7c3ae0
Show file tree
Hide file tree
Showing 12 changed files with 654 additions and 1,354 deletions.
47 changes: 29 additions & 18 deletions tests/sparseml/modifiers/pruning/wanda/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,38 @@
# limitations under the License.


import unittest

import pytest

from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier
from tests.sparseml.modifiers.conf import setup_modifier_factory
from tests.testing_utils import requires_torch


@pytest.mark.unit
@requires_torch
class TestWandaIsRegistered(unittest.TestCase):
def setUp(self):
self.kwargs = dict(
sparsity=0.5,
targets="__ALL_PRUNABLE__",
)
setup_modifier_factory()

def test_wanda_is_registered(self):
type_ = ModifierFactory.create(
type_="WandaPruningModifier",
framework=Framework.general,
allow_experimental=False,
allow_registered=True,
**self.kwargs,
)

def test_wanda_is_registered():

kwargs = dict(
sparsity=0.5,
targets="__ALL_PRUNABLE__",
)
setup_modifier_factory()
type_ = ModifierFactory.create(
type_="WandaPruningModifier",
framework=Framework.general,
allow_experimental=False,
allow_registered=True,
**kwargs,
)

assert isinstance(
type_, WandaPruningModifier
), "PyTorch ConstantPruningModifier not registered"
self.assertIsInstance(
type_,
WandaPruningModifier,
"PyTorch ConstantPruningModifier not registered",
)
50 changes: 0 additions & 50 deletions tests/sparseml/modifiers/pruning/wanda/test_base_new.py

This file was deleted.

117 changes: 67 additions & 50 deletions tests/sparseml/modifiers/quantization/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,72 +12,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest

from sparseml.core.event import Event
from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
from sparseml.modifiers.quantization import QuantizationModifier
from tests.sparseml.modifiers.conf import setup_modifier_factory
from tests.testing_utils import requires_torch


def test_quantization_registered():
setup_modifier_factory()
@requires_torch
@pytest.mark.unit
class TestQuantizationRegistered(unittest.TestCase):
def setUp(self):
setup_modifier_factory()
self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0)

kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0)
quant_obj = ModifierFactory.create(
type_="QuantizationModifier",
framework=Framework.general,
allow_experimental=False,
allow_registered=True,
**kwargs,
)
def test_quantization_registered(self):
quant_obj = ModifierFactory.create(
type_="QuantizationModifier",
framework=Framework.general,
allow_experimental=False,
allow_registered=True,
**self.kwargs,
)

assert isinstance(quant_obj, QuantizationModifier)
self.assertIsInstance(quant_obj, QuantizationModifier)


def test_end_epochs():
start = 0.0
scheme = dict(
input_activations=dict(num_bits=8, symmetric=True),
weights=dict(num_bits=6, symmetric=False),
)
@requires_torch
@pytest.mark.unit
class TestEndEpochs(unittest.TestCase):
def setUp(self):
self.start = 0.0
self.scheme = dict(
input_activations=dict(num_bits=8, symmetric=True),
weights=dict(num_bits=6, symmetric=False),
)

disable_quant_epoch, freeze_bn_epoch = None, None
obj_modifier = QuantizationModifier(
start=start,
scheme=scheme,
disable_quantization_observer_epoch=disable_quant_epoch,
freeze_bn_stats_epoch=freeze_bn_epoch,
)
def test_end_epochs(self):
disable_quant_epoch, freeze_bn_epoch = None, None
obj_modifier = QuantizationModifier(
start=self.start,
scheme=self.scheme,
disable_quantization_observer_epoch=disable_quant_epoch,
freeze_bn_stats_epoch=freeze_bn_epoch,
)

assert obj_modifier.calculate_disable_observer_epoch() == -1
assert obj_modifier.calculate_freeze_bn_stats_epoch() == -1
assert obj_modifier.calculate_disable_observer_epoch() == -1
assert obj_modifier.calculate_freeze_bn_stats_epoch() == -1

for epoch in range(3):
event = Event(steps_per_epoch=1, global_step=epoch)
assert not obj_modifier.check_should_disable_observer(event)
assert not obj_modifier.check_should_freeze_bn_stats(event)
for epoch in range(3):
event = Event(steps_per_epoch=1, global_step=epoch)
assert not obj_modifier.check_should_disable_observer(event)
assert not obj_modifier.check_should_freeze_bn_stats(event)

disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0
obj_modifier = QuantizationModifier(
start=start,
scheme=scheme,
disable_quantization_observer_epoch=disable_quant_epoch,
freeze_bn_stats_epoch=freeze_bn_epoch,
)
disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0
obj_modifier = QuantizationModifier(
start=self.start,
scheme=self.scheme,
disable_quantization_observer_epoch=disable_quant_epoch,
freeze_bn_stats_epoch=freeze_bn_epoch,
)

assert obj_modifier.calculate_disable_observer_epoch() == disable_quant_epoch
assert obj_modifier.calculate_freeze_bn_stats_epoch() == freeze_bn_epoch
self.assertEqual(
obj_modifier.calculate_disable_observer_epoch(), disable_quant_epoch
)
self.assertEqual(
obj_modifier.calculate_freeze_bn_stats_epoch(), freeze_bn_epoch
)

for epoch in range(4):
event = Event(steps_per_epoch=1, global_step=epoch)
assert not obj_modifier.check_should_disable_observer(event)
assert not obj_modifier.check_should_freeze_bn_stats(event)
for epoch in range(4):
event = Event(steps_per_epoch=1, global_step=epoch)
assert not obj_modifier.check_should_disable_observer(event)
assert not obj_modifier.check_should_freeze_bn_stats(event)

event = Event(steps_per_epoch=1, global_step=4)
assert obj_modifier.check_should_disable_observer(event)
assert not obj_modifier.check_should_freeze_bn_stats(event)

for epoch in range(5, 8):
event = Event(steps_per_epoch=1, global_step=epoch)
event = Event(steps_per_epoch=1, global_step=4)
assert obj_modifier.check_should_disable_observer(event)
assert obj_modifier.check_should_freeze_bn_stats(event)
assert not obj_modifier.check_should_freeze_bn_stats(event)

for epoch in range(5, 8):
event = Event(steps_per_epoch=1, global_step=epoch)
assert obj_modifier.check_should_disable_observer(event)
assert obj_modifier.check_should_freeze_bn_stats(event)
100 changes: 0 additions & 100 deletions tests/sparseml/modifiers/quantization/test_base_new.py

This file was deleted.

Loading

0 comments on commit d7c3ae0

Please sign in to comment.