diff --git a/tests/sparseml/modifiers/pruning/wanda/test_base.py b/tests/sparseml/modifiers/pruning/wanda/test_base.py
index 8dcb682020d..f3244ce0f25 100644
--- a/tests/sparseml/modifiers/pruning/wanda/test_base.py
+++ b/tests/sparseml/modifiers/pruning/wanda/test_base.py
@@ -13,27 +13,38 @@
 # limitations under the License.
 
 
+import unittest
+
+import pytest
+
 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
 from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier
 from tests.sparseml.modifiers.conf import setup_modifier_factory
+from tests.testing_utils import requires_torch
+
+
+@pytest.mark.unit
+@requires_torch
+class TestWandaIsRegistered(unittest.TestCase):
+    def setUp(self):
+        self.kwargs = dict(
+            sparsity=0.5,
+            targets="__ALL_PRUNABLE__",
+        )
+        setup_modifier_factory()
 
+    def test_wanda_is_registered(self):
+        type_ = ModifierFactory.create(
+            type_="WandaPruningModifier",
+            framework=Framework.general,
+            allow_experimental=False,
+            allow_registered=True,
+            **self.kwargs,
+        )
 
-def test_wanda_is_registered():
-
-    kwargs = dict(
-        sparsity=0.5,
-        targets="__ALL_PRUNABLE__",
-    )
-    setup_modifier_factory()
-    type_ = ModifierFactory.create(
-        type_="WandaPruningModifier",
-        framework=Framework.general,
-        allow_experimental=False,
-        allow_registered=True,
-        **kwargs,
-    )
-
-    assert isinstance(
-        type_, WandaPruningModifier
-    ), "PyTorch ConstantPruningModifier not registered"
+        self.assertIsInstance(
+            type_,
+            WandaPruningModifier,
+            "PyTorch ConstantPruningModifier not registered",
+        )
diff --git a/tests/sparseml/modifiers/pruning/wanda/test_base_new.py b/tests/sparseml/modifiers/pruning/wanda/test_base_new.py
deleted file mode 100644
index f3244ce0f25..00000000000
--- a/tests/sparseml/modifiers/pruning/wanda/test_base_new.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import unittest
-
-import pytest
-
-from sparseml.core.factory import ModifierFactory
-from sparseml.core.framework import Framework
-from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier
-from tests.sparseml.modifiers.conf import setup_modifier_factory
-from tests.testing_utils import requires_torch
-
-
-@pytest.mark.unit
-@requires_torch
-class TestWandaIsRegistered(unittest.TestCase):
-    def setUp(self):
-        self.kwargs = dict(
-            sparsity=0.5,
-            targets="__ALL_PRUNABLE__",
-        )
-        setup_modifier_factory()
-
-    def test_wanda_is_registered(self):
-        type_ = ModifierFactory.create(
-            type_="WandaPruningModifier",
-            framework=Framework.general,
-            allow_experimental=False,
-            allow_registered=True,
-            **self.kwargs,
-        )
-
-        self.assertIsInstance(
-            type_,
-            WandaPruningModifier,
-            "PyTorch ConstantPruningModifier not registered",
-        )
diff --git a/tests/sparseml/modifiers/quantization/test_base.py b/tests/sparseml/modifiers/quantization/test_base.py
index cd5fab0e755..00400ef16c2 100644
--- a/tests/sparseml/modifiers/quantization/test_base.py
+++ b/tests/sparseml/modifiers/quantization/test_base.py
@@ -12,72 +12,89 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest
+
+import pytest
+
 from sparseml.core.event import Event
 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
 from sparseml.modifiers.quantization import QuantizationModifier
 from tests.sparseml.modifiers.conf import setup_modifier_factory
+from tests.testing_utils import requires_torch
 
 
-def test_quantization_registered():
-    setup_modifier_factory()
+@requires_torch
+@pytest.mark.unit
+class TestQuantizationRegistered(unittest.TestCase):
+    def setUp(self):
+        setup_modifier_factory()
+        self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0)
 
-    kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0)
-    quant_obj = ModifierFactory.create(
-        type_="QuantizationModifier",
-        framework=Framework.general,
-        allow_experimental=False,
-        allow_registered=True,
-        **kwargs,
-    )
+    def test_quantization_registered(self):
+        quant_obj = ModifierFactory.create(
+            type_="QuantizationModifier",
+            framework=Framework.general,
+            allow_experimental=False,
+            allow_registered=True,
+            **self.kwargs,
+        )
 
-    assert isinstance(quant_obj, QuantizationModifier)
+        self.assertIsInstance(quant_obj, QuantizationModifier)
 
 
-def test_end_epochs():
-    start = 0.0
-    scheme = dict(
-        input_activations=dict(num_bits=8, symmetric=True),
-        weights=dict(num_bits=6, symmetric=False),
-    )
+@requires_torch
+@pytest.mark.unit
+class TestEndEpochs(unittest.TestCase):
+    def setUp(self):
+        self.start = 0.0
+        self.scheme = dict(
+            input_activations=dict(num_bits=8, symmetric=True),
+            weights=dict(num_bits=6, symmetric=False),
+        )
 
-    disable_quant_epoch, freeze_bn_epoch = None, None
-    obj_modifier = QuantizationModifier(
-        start=start,
-        scheme=scheme,
-        disable_quantization_observer_epoch=disable_quant_epoch,
-        freeze_bn_stats_epoch=freeze_bn_epoch,
-    )
+    def test_end_epochs(self):
+        disable_quant_epoch, freeze_bn_epoch = None, None
+        obj_modifier = QuantizationModifier(
+            start=self.start,
+            scheme=self.scheme,
+            disable_quantization_observer_epoch=disable_quant_epoch,
+            freeze_bn_stats_epoch=freeze_bn_epoch,
+        )
 
-    assert obj_modifier.calculate_disable_observer_epoch() == -1
-    assert obj_modifier.calculate_freeze_bn_stats_epoch() == -1
+        assert obj_modifier.calculate_disable_observer_epoch() == -1
+        assert
obj_modifier.calculate_freeze_bn_stats_epoch() == -1 - for epoch in range(3): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) + for epoch in range(3): + event = Event(steps_per_epoch=1, global_step=epoch) + assert not obj_modifier.check_should_disable_observer(event) + assert not obj_modifier.check_should_freeze_bn_stats(event) - disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0 - obj_modifier = QuantizationModifier( - start=start, - scheme=scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - freeze_bn_stats_epoch=freeze_bn_epoch, - ) + disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0 + obj_modifier = QuantizationModifier( + start=self.start, + scheme=self.scheme, + disable_quantization_observer_epoch=disable_quant_epoch, + freeze_bn_stats_epoch=freeze_bn_epoch, + ) - assert obj_modifier.calculate_disable_observer_epoch() == disable_quant_epoch - assert obj_modifier.calculate_freeze_bn_stats_epoch() == freeze_bn_epoch + self.assertEqual( + obj_modifier.calculate_disable_observer_epoch(), disable_quant_epoch + ) + self.assertEqual( + obj_modifier.calculate_freeze_bn_stats_epoch(), freeze_bn_epoch + ) - for epoch in range(4): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) + for epoch in range(4): + event = Event(steps_per_epoch=1, global_step=epoch) + assert not obj_modifier.check_should_disable_observer(event) + assert not obj_modifier.check_should_freeze_bn_stats(event) - event = Event(steps_per_epoch=1, global_step=4) - assert obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) - - for epoch in range(5, 8): - event = Event(steps_per_epoch=1, global_step=epoch) + event = Event(steps_per_epoch=1, global_step=4) assert obj_modifier.check_should_disable_observer(event) - assert obj_modifier.check_should_freeze_bn_stats(event) + assert not obj_modifier.check_should_freeze_bn_stats(event) + + for epoch in range(5, 8): + event = Event(steps_per_epoch=1, global_step=epoch) + assert obj_modifier.check_should_disable_observer(event) + assert obj_modifier.check_should_freeze_bn_stats(event) diff --git a/tests/sparseml/modifiers/quantization/test_base_new.py b/tests/sparseml/modifiers/quantization/test_base_new.py deleted file mode 100644 index 00400ef16c2..00000000000 --- a/tests/sparseml/modifiers/quantization/test_base_new.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import pytest - -from sparseml.core.event import Event -from sparseml.core.factory import ModifierFactory -from sparseml.core.framework import Framework -from sparseml.modifiers.quantization import QuantizationModifier -from tests.sparseml.modifiers.conf import setup_modifier_factory -from tests.testing_utils import requires_torch - - -@requires_torch -@pytest.mark.unit -class TestQuantizationRegistered(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) - - def test_quantization_registered(self): - quant_obj = ModifierFactory.create( - type_="QuantizationModifier", - framework=Framework.general, - allow_experimental=False, - allow_registered=True, - **self.kwargs, - ) - - self.assertIsInstance(quant_obj, QuantizationModifier) - - -@requires_torch -@pytest.mark.unit -class TestEndEpochs(unittest.TestCase): - def setUp(self): - self.start = 0.0 - self.scheme = dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=6, symmetric=False), - ) - - def test_end_epochs(self): - disable_quant_epoch, freeze_bn_epoch = None, None - obj_modifier = QuantizationModifier( - start=self.start, - scheme=self.scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - freeze_bn_stats_epoch=freeze_bn_epoch, - ) - - assert obj_modifier.calculate_disable_observer_epoch() == -1 - assert obj_modifier.calculate_freeze_bn_stats_epoch() == -1 - - for epoch in range(3): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) - - disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0 - obj_modifier = QuantizationModifier( - start=self.start, - scheme=self.scheme, - disable_quantization_observer_epoch=disable_quant_epoch, - freeze_bn_stats_epoch=freeze_bn_epoch, - ) - - self.assertEqual( - obj_modifier.calculate_disable_observer_epoch(), disable_quant_epoch - ) - self.assertEqual( - obj_modifier.calculate_freeze_bn_stats_epoch(), freeze_bn_epoch - ) - - for epoch in range(4): - event = Event(steps_per_epoch=1, global_step=epoch) - assert not obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) - - event = Event(steps_per_epoch=1, global_step=4) - assert obj_modifier.check_should_disable_observer(event) - assert not obj_modifier.check_should_freeze_bn_stats(event) - - for epoch in range(5, 8): - event = Event(steps_per_epoch=1, global_step=epoch) - assert obj_modifier.check_should_disable_observer(event) - assert obj_modifier.check_should_freeze_bn_stats(event) diff --git a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py index 0df15ebfdd2..6fee4ba08e0 100644 --- a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest + import pytest +from parameterized import parameterized from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch @@ -21,135 +24,163 @@ from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory from tests.sparseml.pytorch.helpers import LinearNet +from tests.testing_utils import requires_torch -@pytest.mark.parametrize( - "sparsity,targets", - [ - ([0.5, 0.2], "__ALL__"), # type mismatch - ([0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]), # length mismatch - ([0.3, 0.4], ["re:.*fc1", "re:.*fc2"]), # regex not supported - ], -) -def test_invalid_layerwise_recipes_raise_exceptions(sparsity, targets): - setup_modifier_factory() - model = LinearNet() - - kwargs = dict( - sparsity=sparsity, - block_size=128, - quantize=False, - targets=targets, - ) - modifier = SparseGPTModifierPyTorch(**kwargs) - testing_harness = LifecyleTestingHarness(model=model, start=-1) - - # confirm invalid layerwise recipes fail at initialization - with pytest.raises(ValueError): - modifier.initialize(testing_harness.get_state()) - - -def test_successful_layerwise_recipe(): - setup_modifier_factory() - model = LinearNet() - - sparsities = [0.5, 0.2] - targets = ["seq.fc1", "seq.fc2"] - kwargs = dict(sparsity=sparsities, block_size=128, quantize=False, targets=targets) - modifier = SparseGPTModifierPyTorch(**kwargs) - modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} - modifier.model = ModifiableModel(framework=Framework.pytorch, model=model) - found_compressible_layers = modifier.compressible_layers() - modifier.compressible_layers_ = found_compressible_layers - modifier._validate_layerwise_sparsity() - - # ensure layers names successfully match up with model - assert len(found_compressible_layers) == len(targets) - - -def test_create_default_quant_modifier(): - setup_modifier_factory() - kwargs = dict(sparsity=0.5, block_size=128, quantize=True) +# TODO: unit tests are by default sanity tests/maybe regression if multiple use +# cases/inputs +# Are we covering sufficient input cases? +# Are we ok with each test running on a per commit basis? 
- modifier = SparseGPTModifierPyTorch(**kwargs) - assert modifier.quantization_modifier_ is None - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) - assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) +@pytest.mark.unit +@requires_torch +class TestInvalidLayerwiseRecipesRaiseExceptions(unittest.TestCase): + def setUp(self): + setup_modifier_factory() - should_be_default_quant_scheme = modifier.quantization_modifier_.scheme - assert should_be_default_quant_scheme.input_activations.num_bits == 8 - assert not should_be_default_quant_scheme.input_activations.symmetric - assert should_be_default_quant_scheme.weights.num_bits == 8 - assert should_be_default_quant_scheme.weights.symmetric - - -def test_set_quant_if_modifer_already_exists(): - setup_modifier_factory() - - model = LinearNet() - kwargs = dict( - scheme=dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ), + @parameterized.expand( + [ + [[0.5, 0.2], "__ALL__"], + [[0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]], + [[0.3, 0.4], ["re:.*fc1", "re:.*fc2"]], + ] ) - - modifier = QuantizationModifierPyTorch(**kwargs) - testing_harness = LifecyleTestingHarness(model=model, start=-1) - - assert not testing_harness.get_state().model.qat_active() - modifier.initialize(testing_harness.get_state()) - assert testing_harness.get_state().model.qat_active() - - kwargs = dict(sparsity=0.5, block_size=128, quantize=False) - modifier = SparseGPTModifierPyTorch(**kwargs) - assert not modifier.quantize - modifier.on_initialize_structure(testing_harness.get_state()) - - # quantization modifier not owned by SparseGPT - assert modifier.quantization_modifier_ is None - - # since quantization modifier is already applied, quantization must be set in OBCQ - assert modifier.quantize - - -def test_set_quant_in_sparsegpt(): - setup_modifier_factory() - - quant_kwargs = { - "scheme": { - "input_activations": { - "num_bits": 8, - "symmetric": False, - "strategy": "tensor", - "kwargs": {}, - }, - "weights": { - "num_bits": 4, - "symmetric": True, - "strategy": "channel", - "kwargs": {}, - }, + def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): + setup_modifier_factory() + kwargs = dict( + sparsity=sparsity, + block_size=128, + quantize=False, + targets=targets, + ) + modifier = SparseGPTModifierPyTorch(**kwargs) + testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) + + # confirm invalid layerwise recipes fail at initialization + with self.assertRaises(ValueError): + modifier.initialize(testing_harness.get_state()) + + +@pytest.mark.unit +@requires_torch +class TestSuccessfulLayerwiseRecipe(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + + def test_successful_layerwise_recipe(self): + sparsities = [0.5, 0.2] + targets = ["seq.fc1", "seq.fc2"] + kwargs = dict( + sparsity=sparsities, block_size=128, quantize=False, targets=targets + ) + modifier = SparseGPTModifierPyTorch(**kwargs) + modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} + modifier.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet()) + found_compressible_layers = modifier.compressible_layers() + modifier.compressible_layers_ = found_compressible_layers + modifier._validate_layerwise_sparsity() + + # ensure layers names successfully match up with model + self.assertEqual(len(found_compressible_layers), len(targets)) + + +@pytest.mark.unit 
+@requires_torch +class TestCreateDefaultQuantModifier(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + + def test_create_default_quant_modifier(self): + kwargs = dict(sparsity=0.5, block_size=128, quantize=True) + + modifier = SparseGPTModifierPyTorch(**kwargs) + assert modifier.quantization_modifier_ is None + + testing_harness = LifecyleTestingHarness(model=LinearNet()) + modifier.on_initialize_structure(testing_harness.get_state()) + assert modifier.quantize + assert isinstance(modifier.quantization_modifier_, QuantizationModifier) + + should_be_default_quant_scheme = modifier.quantization_modifier_.scheme + self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8) + assert not should_be_default_quant_scheme.input_activations.symmetric + self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8) + assert should_be_default_quant_scheme.weights.symmetric + + +@pytest.mark.unit +@requires_torch +class TestSetQuantIfModifierAlreadyExists(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + + def test_set_quant_if_modifer_already_exists(self): + model = LinearNet() + kwargs = dict( + scheme=dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=4, symmetric=False), + ), + ) + + modifier = QuantizationModifierPyTorch(**kwargs) + testing_harness = LifecyleTestingHarness(model=model, start=-1) + + assert not testing_harness.get_state().model.qat_active() + modifier.initialize(testing_harness.get_state()) + assert testing_harness.get_state().model.qat_active() + + kwargs = dict(sparsity=0.5, block_size=128, quantize=False) + modifier = SparseGPTModifierPyTorch(**kwargs) + assert not modifier.quantize + modifier.on_initialize_structure(testing_harness.get_state()) + + # quantization modifier not owned by SparseGPT + assert modifier.quantization_modifier_ is None + + # since quantization modifier is already applied, quantization must be set in + # OBCQ + assert modifier.quantize + + +class TestSetQuantInSparseGPT(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + self.quant_kwargs = { + "scheme": { + "input_activations": { + "num_bits": 8, + "symmetric": False, + "strategy": "tensor", + "kwargs": {}, + }, + "weights": { + "num_bits": 4, + "symmetric": True, + "strategy": "channel", + "kwargs": {}, + }, + } } - } - quant_config = {"QuantizationModifier": quant_kwargs} - - kwargs = dict(sparsity=0.5, block_size=128, quantize=quant_config) - - modifier = SparseGPTModifierPyTorch(**kwargs) - assert modifier.quantization_modifier_ is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) - assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - - dict_scheme = dict(modifier.quantization_modifier_.scheme) - assert dict(dict_scheme["weights"]) == quant_kwargs["scheme"]["weights"] - assert ( - dict(dict_scheme["input_activations"]) - == quant_kwargs["scheme"]["input_activations"] - ) + self.quant_config = {"QuantizationModifier": self.quant_kwargs} + + def test_set_quant_in_sparsegpt(self): + kwargs = dict(sparsity=0.5, block_size=128, quantize=self.quant_config) + + modifier = SparseGPTModifierPyTorch(**kwargs) + assert modifier.quantization_modifier_ is None + + testing_harness = LifecyleTestingHarness(model=LinearNet()) + modifier.on_initialize_structure(testing_harness.get_state()) + assert modifier.quantize + assert isinstance(modifier.quantization_modifier_, 
QuantizationModifier) + + dict_scheme = dict(modifier.quantization_modifier_.scheme) + self.assertEqual( + dict(dict_scheme["weights"]), self.quant_kwargs["scheme"]["weights"] + ) + self.assertEqual( + dict(dict_scheme["input_activations"]), + self.quant_kwargs["scheme"]["input_activations"], + ) diff --git a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch_new.py b/tests/sparseml/pytorch/modifiers/obcq/test_pytorch_new.py deleted file mode 100644 index 6fee4ba08e0..00000000000 --- a/tests/sparseml/pytorch/modifiers/obcq/test_pytorch_new.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import pytest - -from parameterized import parameterized -from sparseml.core.framework import Framework -from sparseml.core.model import ModifiableModel -from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.modifiers.quantization import QuantizationModifier -from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch -from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory -from tests.sparseml.pytorch.helpers import LinearNet -from tests.testing_utils import requires_torch - - -# TODO: unit tests are by default sanity tests/maybe regression if multiple use -# cases/inputs -# Are we covering sufficient input cases? -# Are we ok with each test running on a per commit basis? 
- - -@pytest.mark.unit -@requires_torch -class TestInvalidLayerwiseRecipesRaiseExceptions(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - - @parameterized.expand( - [ - [[0.5, 0.2], "__ALL__"], - [[0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]], - [[0.3, 0.4], ["re:.*fc1", "re:.*fc2"]], - ] - ) - def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): - setup_modifier_factory() - kwargs = dict( - sparsity=sparsity, - block_size=128, - quantize=False, - targets=targets, - ) - modifier = SparseGPTModifierPyTorch(**kwargs) - testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) - - # confirm invalid layerwise recipes fail at initialization - with self.assertRaises(ValueError): - modifier.initialize(testing_harness.get_state()) - - -@pytest.mark.unit -@requires_torch -class TestSuccessfulLayerwiseRecipe(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - - def test_successful_layerwise_recipe(self): - sparsities = [0.5, 0.2] - targets = ["seq.fc1", "seq.fc2"] - kwargs = dict( - sparsity=sparsities, block_size=128, quantize=False, targets=targets - ) - modifier = SparseGPTModifierPyTorch(**kwargs) - modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} - modifier.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet()) - found_compressible_layers = modifier.compressible_layers() - modifier.compressible_layers_ = found_compressible_layers - modifier._validate_layerwise_sparsity() - - # ensure layers names successfully match up with model - self.assertEqual(len(found_compressible_layers), len(targets)) - - -@pytest.mark.unit -@requires_torch -class TestCreateDefaultQuantModifier(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - - def test_create_default_quant_modifier(self): - kwargs = dict(sparsity=0.5, block_size=128, quantize=True) - - modifier = SparseGPTModifierPyTorch(**kwargs) - assert modifier.quantization_modifier_ is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) - assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - - should_be_default_quant_scheme = modifier.quantization_modifier_.scheme - self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8) - assert not should_be_default_quant_scheme.input_activations.symmetric - self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8) - assert should_be_default_quant_scheme.weights.symmetric - - -@pytest.mark.unit -@requires_torch -class TestSetQuantIfModifierAlreadyExists(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - - def test_set_quant_if_modifer_already_exists(self): - model = LinearNet() - kwargs = dict( - scheme=dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ), - ) - - modifier = QuantizationModifierPyTorch(**kwargs) - testing_harness = LifecyleTestingHarness(model=model, start=-1) - - assert not testing_harness.get_state().model.qat_active() - modifier.initialize(testing_harness.get_state()) - assert testing_harness.get_state().model.qat_active() - - kwargs = dict(sparsity=0.5, block_size=128, quantize=False) - modifier = SparseGPTModifierPyTorch(**kwargs) - assert not modifier.quantize - modifier.on_initialize_structure(testing_harness.get_state()) - - # quantization modifier not owned by SparseGPT - assert modifier.quantization_modifier_ is None - - # since quantization modifier 
is already applied, quantization must be set in - # OBCQ - assert modifier.quantize - - -class TestSetQuantInSparseGPT(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - self.quant_kwargs = { - "scheme": { - "input_activations": { - "num_bits": 8, - "symmetric": False, - "strategy": "tensor", - "kwargs": {}, - }, - "weights": { - "num_bits": 4, - "symmetric": True, - "strategy": "channel", - "kwargs": {}, - }, - } - } - self.quant_config = {"QuantizationModifier": self.quant_kwargs} - - def test_set_quant_in_sparsegpt(self): - kwargs = dict(sparsity=0.5, block_size=128, quantize=self.quant_config) - - modifier = SparseGPTModifierPyTorch(**kwargs) - assert modifier.quantization_modifier_ is None - - testing_harness = LifecyleTestingHarness(model=LinearNet()) - modifier.on_initialize_structure(testing_harness.get_state()) - assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - - dict_scheme = dict(modifier.quantization_modifier_.scheme) - self.assertEqual( - dict(dict_scheme["weights"]), self.quant_kwargs["scheme"]["weights"] - ) - self.assertEqual( - dict(dict_scheme["input_activations"]), - self.quant_kwargs["scheme"]["input_activations"], - ) diff --git a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py index 2bdca703951..a65959a564a 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch.py @@ -12,28 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + +import pytest from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework from tests.sparseml.modifiers.conf import setup_modifier_factory +from tests.testing_utils import requires_torch + + +@pytest.mark.unit +@requires_torch +class TestWandaPytorchIsRegistered(unittest.TestCase): + def setUp(self): + self.kwargs = dict( + sparsity=0.5, + targets="__ALL_PRUNABLE__", + ) + setup_modifier_factory() + + def test_wanda_pytorch_is_registered(self): + from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch + type_ = ModifierFactory.create( + type_="WandaPruningModifier", + framework=Framework.pytorch, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) -def test_wanda_pytorch_is_registered(): - from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch - - kwargs = dict( - sparsity=0.5, - targets="__ALL_PRUNABLE__", - ) - setup_modifier_factory() - type_ = ModifierFactory.create( - type_="WandaPruningModifier", - framework=Framework.pytorch, - allow_experimental=False, - allow_registered=True, - **kwargs, - ) - - assert isinstance( - type_, WandaPruningModifierPyTorch - ), "PyTorch ConstantPruningModifier not registered" + self.assertIsInstance( + type_, + WandaPruningModifierPyTorch, + "PyTorch ConstantPruningModifier not registered", + ) diff --git a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch_new.py b/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch_new.py deleted file mode 100644 index a65959a564a..00000000000 --- a/tests/sparseml/pytorch/modifiers/pruning/wanda/test_pytorch_new.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import pytest - -from sparseml.core.factory import ModifierFactory -from sparseml.core.framework import Framework -from tests.sparseml.modifiers.conf import setup_modifier_factory -from tests.testing_utils import requires_torch - - -@pytest.mark.unit -@requires_torch -class TestWandaPytorchIsRegistered(unittest.TestCase): - def setUp(self): - self.kwargs = dict( - sparsity=0.5, - targets="__ALL_PRUNABLE__", - ) - setup_modifier_factory() - - def test_wanda_pytorch_is_registered(self): - from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch - - type_ = ModifierFactory.create( - type_="WandaPruningModifier", - framework=Framework.pytorch, - allow_experimental=False, - allow_registered=True, - **self.kwargs, - ) - - self.assertIsInstance( - type_, - WandaPruningModifierPyTorch, - "PyTorch ConstantPruningModifier not registered", - ) diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py index 58075c2cc2c..6b258b884cb 100644 --- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest + import pytest +from parameterized import parameterized from sparseml.core import State from sparseml.core.event import Event, EventType from sparseml.core.factory import ModifierFactory @@ -30,6 +33,94 @@ _test_qat_wrapped_module, _test_quantized_module, ) +from tests.testing_utils import requires_torch + + +@pytest.mark.unit +@requires_torch +class TestQuantizationRegistered(unittest.TestCase): + def setUp(self): + setup_modifier_factory() + self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) + + def test_quantization_registered(self): + quant_obj = ModifierFactory.create( + type_="QuantizationModifier", + framework=Framework.pytorch, + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance(quant_obj, QuantizationModifierPyTorch) + + +@pytest.mark.unit +@requires_torch +class TestQuantizationOneShot(unittest.TestCase): + def setUp(self): + scheme = dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=4, symmetric=False, strategy="channel"), + ) + self.kwargs = dict(scheme=scheme) + + @parameterized.expand([[ConvNet], [LinearNet]]) + def test_quantization_oneshot(self, model_class): + model = model_class() + state = State(framework=Framework.pytorch, start_event=Event()) + state.update(model=model, start=-1) + + modifier = QuantizationModifierPyTorch(**self.kwargs) + + modifier.initialize(state) + + # for one-shot, we set up quantization on initialization + _test_qat_applied(modifier, model) + + # we shouldn't keep updating stats after one-shot + assert modifier.quantization_observer_disabled_ + + test_start_event = Event(type_=EventType.BATCH_START) + test_end_event = Event(type_=EventType.BATCH_END) + assert not modifier.should_start(test_start_event) + assert not modifier.should_end(test_end_event) + + modifier.finalize(state) + assert modifier.finalized + + +@pytest.mark.unit +@requires_torch +class TestQuantizationTraining(unittest.TestCase): + def setUp(self): + self.start_epoch = 2 + + self.kwargs = dict( + start=self.start_epoch, + scheme=dict( + input_activations=dict(num_bits=8, symmetric=True), + weights=dict(num_bits=4, symmetric=False), + ), + ) + + @parameterized.expand([[ConvNet], [LinearNet]]) + def test_quantization_training(self, model_class): + model = model_class() + + modifier = QuantizationModifierPyTorch(**self.kwargs) + + testing_harness = LifecyleTestingHarness(model=model) + modifier.initialize(testing_harness.get_state()) + assert not modifier.qat_enabled_ + + testing_harness.trigger_modifier_for_epochs(modifier, self.start_epoch) + assert not modifier.qat_enabled_ + testing_harness.trigger_modifier_for_epochs(modifier, self.start_epoch + 1) + _test_qat_applied(modifier, model) + + modifier.finalize(testing_harness.get_state()) + assert modifier.quantization_observer_disabled_ def _test_qat_applied(modifier, model): @@ -67,77 +158,3 @@ def _test_qat_applied(modifier, model): # check all non-target modules are not quantized assert not hasattr(module, "quantization_scheme") assert not hasattr(module, "qconfig") - - -def test_quantization_registered(): - setup_modifier_factory() - - kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) - quant_obj = ModifierFactory.create( - type_="QuantizationModifier", - framework=Framework.pytorch, - allow_experimental=False, - allow_registered=True, - **kwargs, - ) - - assert isinstance(quant_obj, QuantizationModifierPyTorch) - - -@pytest.mark.parametrize("model_class", [ConvNet, LinearNet]) -def 
test_quantization_oneshot(model_class): - model = model_class() - state = State(framework=Framework.pytorch, start_event=Event()) - state.update(model=model, start=-1) - - scheme = dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False, strategy="channel"), - ) - kwargs = dict(scheme=scheme) - - modifier = QuantizationModifierPyTorch(**kwargs) - - modifier.initialize(state) - - # for one-shot, we set up quantization on initialization - _test_qat_applied(modifier, model) - - # we shouldn't keep updating stats after one-shot - assert modifier.quantization_observer_disabled_ - - test_start_event = Event(type_=EventType.BATCH_START) - test_end_event = Event(type_=EventType.BATCH_END) - assert not modifier.should_start(test_start_event) - assert not modifier.should_end(test_end_event) - - modifier.finalize(state) - assert modifier.finalized - - -@pytest.mark.parametrize("model_class", [ConvNet, LinearNet]) -def test_quantization_training(model_class): - start_epoch = 2 - - model = model_class() - kwargs = dict( - start=start_epoch, - scheme=dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ), - ) - - modifier = QuantizationModifierPyTorch(**kwargs) - - testing_harness = LifecyleTestingHarness(model=model) - modifier.initialize(testing_harness.get_state()) - assert not modifier.qat_enabled_ - - testing_harness.trigger_modifier_for_epochs(modifier, start_epoch) - assert not modifier.qat_enabled_ - testing_harness.trigger_modifier_for_epochs(modifier, start_epoch + 1) - _test_qat_applied(modifier, model) - - modifier.finalize(testing_harness.get_state()) - assert modifier.quantization_observer_disabled_ diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch_new.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch_new.py deleted file mode 100644 index 6b258b884cb..00000000000 --- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch_new.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import pytest - -from parameterized import parameterized -from sparseml.core import State -from sparseml.core.event import Event, EventType -from sparseml.core.factory import ModifierFactory -from sparseml.core.framework import Framework -from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch -from sparseml.pytorch.sparsification.quantization.quantize import ( - is_qat_helper_module, - is_quantizable_module, -) -from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory -from tests.sparseml.pytorch.helpers import ConvNet, LinearNet -from tests.sparseml.pytorch.sparsification.quantization.test_modifier_quantization import ( # noqa E501 - _match_submodule_name_or_type, - _test_qat_wrapped_module, - _test_quantized_module, -) -from tests.testing_utils import requires_torch - - -@pytest.mark.unit -@requires_torch -class TestQuantizationRegistered(unittest.TestCase): - def setUp(self): - setup_modifier_factory() - self.kwargs = dict(index=0, group="quantization", start=2.0, end=-1.0) - - def test_quantization_registered(self): - quant_obj = ModifierFactory.create( - type_="QuantizationModifier", - framework=Framework.pytorch, - allow_experimental=False, - allow_registered=True, - **self.kwargs, - ) - - self.assertIsInstance(quant_obj, QuantizationModifierPyTorch) - - -@pytest.mark.unit -@requires_torch -class TestQuantizationOneShot(unittest.TestCase): - def setUp(self): - scheme = dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False, strategy="channel"), - ) - self.kwargs = dict(scheme=scheme) - - @parameterized.expand([[ConvNet], [LinearNet]]) - def test_quantization_oneshot(self, model_class): - model = model_class() - state = State(framework=Framework.pytorch, start_event=Event()) - state.update(model=model, start=-1) - - modifier = QuantizationModifierPyTorch(**self.kwargs) - - modifier.initialize(state) - - # for one-shot, we set up quantization on initialization - _test_qat_applied(modifier, model) - - # we shouldn't keep updating stats after one-shot - assert modifier.quantization_observer_disabled_ - - test_start_event = Event(type_=EventType.BATCH_START) - test_end_event = Event(type_=EventType.BATCH_END) - assert not modifier.should_start(test_start_event) - assert not modifier.should_end(test_end_event) - - modifier.finalize(state) - assert modifier.finalized - - -@pytest.mark.unit -@requires_torch -class TestQuantizationTraining(unittest.TestCase): - def setUp(self): - self.start_epoch = 2 - - self.kwargs = dict( - start=self.start_epoch, - scheme=dict( - input_activations=dict(num_bits=8, symmetric=True), - weights=dict(num_bits=4, symmetric=False), - ), - ) - - @parameterized.expand([[ConvNet], [LinearNet]]) - def test_quantization_training(self, model_class): - model = model_class() - - modifier = QuantizationModifierPyTorch(**self.kwargs) - - testing_harness = LifecyleTestingHarness(model=model) - modifier.initialize(testing_harness.get_state()) - assert not modifier.qat_enabled_ - - testing_harness.trigger_modifier_for_epochs(modifier, self.start_epoch) - assert not modifier.qat_enabled_ - testing_harness.trigger_modifier_for_epochs(modifier, self.start_epoch + 1) - _test_qat_applied(modifier, model) - - modifier.finalize(testing_harness.get_state()) - assert modifier.quantization_observer_disabled_ - - -def _test_qat_applied(modifier, model): - assert modifier.qat_enabled_ - - for name, module in model.named_modules(): - if 
is_qat_helper_module(module): - # skip helper modules - continue - - is_target_submodule = not any( - name.startswith(submodule_name) for submodule_name in modifier.ignore - ) - is_included_module_type = any( - module_type_name == module.__class__.__name__ - for module_type_name in modifier.scheme_overrides - ) - is_quantizable = is_included_module_type or is_quantizable_module( - module, - exclude_module_types=modifier.ignore, - ) - - if is_target_submodule and is_quantizable: - if getattr(module, "wrap_qat", False): - _test_qat_wrapped_module(model, name) - elif is_quantizable: - # check each target module is quantized - override_key = _match_submodule_name_or_type( - module, - name, - list(modifier.scheme_overrides.keys()), - ) - _test_quantized_module(model, modifier, module, name, override_key) - else: - # check all non-target modules are not quantized - assert not hasattr(module, "quantization_scheme") - assert not hasattr(module, "qconfig") diff --git a/tests/sparseml/transformers/finetune/data/test_dataset_loading.py b/tests/sparseml/transformers/finetune/data/test_dataset_loading.py index 6493689416f..b43d0ee9cb8 100644 --- a/tests/sparseml/transformers/finetune/data/test_dataset_loading.py +++ b/tests/sparseml/transformers/finetune/data/test_dataset_loading.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + import pytest from datasets import IterableDataset +from parameterized import parameterized from sparseml.transformers.finetune.data import TextGenerationDataset from sparseml.transformers.finetune.data.data_args import DataTrainingArguments from sparseml.transformers.finetune.model_args import ModelArguments @@ -23,219 +26,296 @@ from sparseml.transformers.finetune.training_args import TrainingArguments -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_concatenation_tokenization(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="wikitext", - dataset_config_name="wikitext-2-raw-v1", - concatenate_data=True, - ) - wiki_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[:5%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = wiki_manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert raw_dataset.split == "train[:5%]" - assert raw_dataset.info.config_name == "wikitext-2-raw-v1" - tokenized_dataset = wiki_manager.tokenize_and_process(raw_dataset) - assert "input_ids" in tokenized_dataset.features - assert "labels" in tokenized_dataset.features - for i in range(len(tokenized_dataset)): - assert len(tokenized_dataset[i]["input_ids"]) == wiki_manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_no_padding_tokenization(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False) - op_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[5%:10%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = op_manager.get_raw_dataset() - assert len(raw_dataset) > 0 - ex_item = raw_dataset[0]["text"] - assert "Below is an instruction that describes a task" in ex_item - - assert raw_dataset.split == "train[5%:10%]" - tokenized_dataset = op_manager.tokenize_and_process(raw_dataset) - assert "input_ids" in tokenized_dataset.features - assert "labels" in tokenized_dataset.features - print(tokenized_dataset[0]["input_ids"]) - - for i in range(len(tokenized_dataset)): 
- assert len(tokenized_dataset[i]["input_ids"]) <= op_manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_max_seq_len_clipped(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", max_seq_length=4096) - op_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[80%:]", - tokenizer=tiny_llama_tokenizer, - ) +@pytest.mark.unit +class TestConcentrationTokenization(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="wikitext", + dataset_config_name="wikitext-2-raw-v1", + concatenate_data=True, + ) - assert op_manager.max_seq_length == tiny_llama_tokenizer.model_max_length + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + def test_concatenation_tokenization(self): + wiki_manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[:5%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset = wiki_manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertEqual(raw_dataset.split, "train[:5%]") + self.assertEqual(raw_dataset.info.config_name, "wikitext-2-raw-v1") + tokenized_dataset = wiki_manager.tokenize_and_process(raw_dataset) + assert "input_ids" in tokenized_dataset.features + assert "labels" in tokenized_dataset.features + for i in range(len(tokenized_dataset)): + self.assertEqual( + len(tokenized_dataset[i]["input_ids"]), wiki_manager.max_seq_length + ) -# test loading percentages works as expected size-wise -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_dataset_kwargs_and_percentages(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="wikitext", - raw_kwargs={ - "data_files": {"train": "wikitext-2-raw-v1/train-00000-of-00001.parquet"} - }, - ) - c4_manager_a = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[5%:10%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset_a = c4_manager_a.get_raw_dataset() - c4_manager_b = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[5%:15%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset_b = c4_manager_b.get_raw_dataset() - - assert len(raw_dataset_b) == 2 * len(raw_dataset_a) - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -@pytest.mark.parametrize( - "dataset_key,dataset_config,split,do_concat", - [ - ("ptb", "penn_treebank", "train[:5%]", False), - ("gsm8k", "main", "train[:5%]", True), - ("ultrachat_200k", "default", "train_sft[:2%]", False), - ], -) -def test_datasets(tiny_llama_tokenizer, dataset_key, dataset_config, split, do_concat): - data_args = DataTrainingArguments( - dataset=dataset_key, - dataset_config_name=dataset_config, - concatenate_data=do_concat, - ) - manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split=split, - tokenizer=tiny_llama_tokenizer, +@pytest.mark.unit +class TestNoPaddingTokenization(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="open_platypus", pad_to_max_length=False + ) + + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + @pytest.mark.usefixtures("tiny_llama_tokenizer") + def test_no_padding_tokenization(self): + op_manager = TextGenerationDataset.load_from_registry( 
+ self.data_args.dataset, + data_args=self.data_args, + split="train[5%:10%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset = op_manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + ex_item = raw_dataset[0]["text"] + self.assertIn("Below is an instruction that describes a task", ex_item) + + self.assertEqual(raw_dataset.split, "train[5%:10%]") + tokenized_dataset = op_manager.tokenize_and_process(raw_dataset) + self.assertIn("input_ids", tokenized_dataset.features) + self.assertIn("labels", tokenized_dataset.features) + print(tokenized_dataset[0]["input_ids"]) + + for i in range(len(tokenized_dataset)): + self.assertLessEqual( + len(tokenized_dataset[i]["input_ids"]), op_manager.max_seq_length + ) + + +@pytest.mark.unit +class TestMaxSeqLenClipped(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="open_platypus", max_seq_length=4096 + ) + + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + def test_max_seq_len_clipped(self): + op_manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[80%:]", + tokenizer=self.tiny_llama_tokenizer, + ) + + self.assertEqual( + op_manager.max_seq_length, self.tiny_llama_tokenizer.model_max_length + ) + + +@pytest.mark.unit +class TestDatasetKwargsAndPercent(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="wikitext", + raw_kwargs={ + "data_files": { + "train": "wikitext-2-raw-v1/train-00000-of-00001.parquet" + } + }, + ) + + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + def test_dataset_kwargs_and_percentages(self): + + c4_manager_a = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[5%:10%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset_a = c4_manager_a.get_raw_dataset() + + c4_manager_b = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[5%:15%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset_b = c4_manager_b.get_raw_dataset() + + self.assertEqual(len(raw_dataset_b), 2 * len(raw_dataset_a)) + + +@pytest.mark.unit +class TestDatasets(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + @parameterized.expand( + [ + ["ptb", "penn_treebank", "train[:5%]", False], + ["gsm8k", "main", "train[:5%]", True], + ["ultrachat_200k", "default", "train_sft[:2%]", False], + ] ) - raw_dataset = manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert raw_dataset.split == split - assert raw_dataset.info.config_name == dataset_config + def test_datasets(self, dataset_key, dataset_config, split, do_concat): + data_args = DataTrainingArguments( + dataset=dataset_key, + dataset_config_name=dataset_config, + concatenate_data=do_concat, + ) + manager = TextGenerationDataset.load_from_registry( + data_args.dataset, + data_args=data_args, + split=split, + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset = manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertEqual(raw_dataset.split, split) + self.assertEqual(raw_dataset.info.config_name, dataset_config) - tokenized_dataset = manager.tokenize_and_process(raw_dataset) - assert "input_ids" in 
tokenized_dataset.features - assert "labels" in tokenized_dataset.features - for i in range(len(tokenized_dataset)): - if do_concat: - assert len(tokenized_dataset[i]["input_ids"]) == manager.max_seq_length - else: - assert len(tokenized_dataset[i]["input_ids"]) <= manager.max_seq_length + tokenized_dataset = manager.tokenize_and_process(raw_dataset) + assert "input_ids" in tokenized_dataset.features + assert "labels" in tokenized_dataset.features + for i in range(len(tokenized_dataset)): + if do_concat: + self.assertEqual( + len(tokenized_dataset[i]["input_ids"]), manager.max_seq_length + ) + else: + self.assertLessEqual( + len(tokenized_dataset[i]["input_ids"]), manager.max_seq_length + ) @pytest.mark.skip("Dataset load broken on Hugging Face") -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_evol(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="evolcodealpaca", - dataset_config_name=None, - concatenate_data=False, - ) - evol_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train[:2%]", - tokenizer=tiny_llama_tokenizer, - ) - raw_dataset = evol_manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert raw_dataset.split == "train[:2%]" - - tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset) - assert "input_ids" in tokenized_dataset.features - assert "labels" in tokenized_dataset.features - for i in range(len(tokenized_dataset)): - assert len(tokenized_dataset[i]["input_ids"]) <= evol_manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_dvc_dataloading(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="csv", - dataset_path="dvc://workshop/satellite-data/jan_train.csv", - dvc_data_repository="https://github.com/iterative/dataset-registry.git", - ) - manager = TextGenerationDataset( - text_column="", - data_args=data_args, - split="train", - tokenizer=tiny_llama_tokenizer, - ) +@pytest.mark.unit +class TestEvol(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer - raw_dataset = manager.get_raw_dataset() - assert len(raw_dataset) > 0 - assert isinstance(raw_dataset[0], dict) + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="evolcodealpaca", + dataset_config_name=None, + concatenate_data=False, + ) + def test_evol(self): + evol_manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train[:2%]", + tokenizer=self.tiny_llama_tokenizer, + ) + raw_dataset = evol_manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertEqual(raw_dataset.split, "train[:2%]") -@pytest.mark.usefixtures("tiny_llama_tokenizer") -def test_stream_loading(tiny_llama_tokenizer): - data_args = DataTrainingArguments( - dataset="wikitext", - dataset_config_name="wikitext-2-raw-v1", - concatenate_data=True, - streaming=True, - ) - manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split="train", - tokenizer=tiny_llama_tokenizer, - ) + tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset) + self.assertIn("input_ids", tokenized_dataset.features) + self.assertIn("labels", tokenized_dataset.features) + for i in range(len(tokenized_dataset)): + self.assertLessEqual( + len(tokenized_dataset[i]["input_ids"]), evol_manager.max_seq_length + ) + + +@pytest.mark.unit +class TestDVCLoading(unittest.TestCase): + def 
setUp(self): + self.data_args = DataTrainingArguments( + dataset="csv", + dataset_path="dvc://workshop/satellite-data/jan_train.csv", + dvc_data_repository="https://github.com/iterative/dataset-registry.git", + ) + + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + def test_dvc_dataloading(self): + manager = TextGenerationDataset( + text_column="", + data_args=self.data_args, + split="train", + tokenizer=self.tiny_llama_tokenizer, + ) + + raw_dataset = manager.get_raw_dataset() + self.assertGreater(len(raw_dataset), 0) + self.assertIsInstance(raw_dataset[0], dict) + + +@pytest.mark.unit +class TestStreamLoading(unittest.TestCase): + def setUp(self): + self.data_args = DataTrainingArguments( + dataset="wikitext", + dataset_config_name="wikitext-2-raw-v1", + concatenate_data=True, + streaming=True, + ) + + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer + + def test_stream_loading(self): + manager = TextGenerationDataset.load_from_registry( + self.data_args.dataset, + data_args=self.data_args, + split="train", + tokenizer=self.tiny_llama_tokenizer, + ) + + raw_dataset = manager.get_raw_dataset() + processed = manager.tokenize_and_process(raw_dataset) + self.assertIsInstance(processed, IterableDataset) + with pytest.raises(TypeError): + # in streaming mode we don't know the length of the dataset + _ = len(processed) + + # confirm tokenization of streamed item works correctly + item = next(iter(processed)) + self.assertIn("labels", item) + self.assertEqual(len(item["input_ids"]), manager.max_seq_length) + + +@pytest.mark.unit +class TestSplitLoading(unittest.TestCase): + @pytest.fixture(autouse=True) + def prepare_fixture(self, tiny_llama_tokenizer): + self.tiny_llama_tokenizer = tiny_llama_tokenizer - raw_dataset = manager.get_raw_dataset() - processed = manager.tokenize_and_process(raw_dataset) - assert isinstance(processed, IterableDataset) - with pytest.raises(TypeError): - # in streaming mode we don't know the length of the dataset - _ = len(processed) - - # confirm tokenization of streamed item works correctly - item = next(iter(processed)) - assert "labels" in item - assert len(item["input_ids"]) == manager.max_seq_length - - -@pytest.mark.usefixtures("tiny_llama_tokenizer") -@pytest.mark.parametrize( - "split_def", [("train"), ("train[60%:]"), ({"train": "train[:20%]"}), (None)] -) -def test_split_loading(split_def, tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", splits=split_def) - training_args = TrainingArguments(do_train=True, output_dir="dummy") - model_args = ModelArguments(model=None) - stage_runner = StageRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args, - model=None, + @parameterized.expand( + [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) - stage_runner.populate_datasets(tokenizer=tiny_llama_tokenizer) + def test_split_loading(self, split_def): + data_args = DataTrainingArguments(dataset="open_platypus", splits=split_def) + training_args = TrainingArguments(do_train=True, output_dir="dummy") + model_args = ModelArguments(model=None) + stage_runner = StageRunner( + model_args=model_args, + data_args=data_args, + training_args=training_args, + model=None, + ) + stage_runner.populate_datasets(tokenizer=self.tiny_llama_tokenizer) - train_dataset = stage_runner.get_dataset_split("train") - assert train_dataset is not None 
-    assert isinstance(train_dataset[0], dict)
+        train_dataset = stage_runner.get_dataset_split("train")
+        assert train_dataset is not None
+        self.assertIsInstance(train_dataset[0], dict)
diff --git a/tests/sparseml/transformers/finetune/data/test_dataset_loading_new.py b/tests/sparseml/transformers/finetune/data/test_dataset_loading_new.py
deleted file mode 100644
index 9f3d6df32f2..00000000000
--- a/tests/sparseml/transformers/finetune/data/test_dataset_loading_new.py
+++ /dev/null
@@ -1,321 +0,0 @@
-# test both cases of make dataset splits
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import pytest
-from datasets import IterableDataset
-
-from parameterized import parameterized
-from sparseml.transformers.finetune.data import TextGenerationDataset
-from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
-from sparseml.transformers.finetune.model_args import ModelArguments
-from sparseml.transformers.finetune.runner import StageRunner
-from sparseml.transformers.finetune.training_args import TrainingArguments
-
-
-@pytest.mark.unit
-class TestConcentrationTokenization(unittest.TestCase):
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="wikitext",
-            dataset_config_name="wikitext-2-raw-v1",
-            concatenate_data=True,
-        )
-
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    def test_concatenation_tokenization(self):
-        wiki_manager = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train[:5%]",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-        raw_dataset = wiki_manager.get_raw_dataset()
-        self.assertGreater(len(raw_dataset), 0)
-        self.assertEqual(raw_dataset.split, "train[:5%]")
-        self.assertEqual(raw_dataset.info.config_name, "wikitext-2-raw-v1")
-        tokenized_dataset = wiki_manager.tokenize_and_process(raw_dataset)
-        assert "input_ids" in tokenized_dataset.features
-        assert "labels" in tokenized_dataset.features
-        for i in range(len(tokenized_dataset)):
-            self.assertEqual(
-                len(tokenized_dataset[i]["input_ids"]), wiki_manager.max_seq_length
-            )
-
-
-@pytest.mark.unit
-class TestNoPaddingTokenization(unittest.TestCase):
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="open_platypus", pad_to_max_length=False
-        )
-
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    @pytest.mark.usefixtures("tiny_llama_tokenizer")
-    def test_no_padding_tokenization(self):
-        op_manager = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train[5%:10%]",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-        raw_dataset = op_manager.get_raw_dataset()
-        self.assertGreater(len(raw_dataset), 0)
-        ex_item = raw_dataset[0]["text"]
-        assert "Below is an instruction that describes a task" in ex_item
-
-        self.assertEqual(raw_dataset.split, "train[5%:10%]")
-        tokenized_dataset = op_manager.tokenize_and_process(raw_dataset)
-        assert "input_ids" in tokenized_dataset.features
-        assert "labels" in tokenized_dataset.features
-        print(tokenized_dataset[0]["input_ids"])
-
-        for i in range(len(tokenized_dataset)):
-            self.assertLessEqual(
-                len(tokenized_dataset[i]["input_ids"]), op_manager.max_seq_length
-            )
-
-
-@pytest.mark.unit
-class TestMaxSeqLenClipped(unittest.TestCase):
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="open_platypus", max_seq_length=4096
-        )
-
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    def test_max_seq_len_clipped(self):
-        op_manager = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train[80%:]",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-
-        self.assertEqual(
-            op_manager.max_seq_length, self.tiny_llama_tokenizer.model_max_length
-        )
-
-
-@pytest.mark.unit
-class TestDatasetKwargsAndPercent(unittest.TestCase):
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="wikitext",
-            raw_kwargs={
-                "data_files": {
-                    "train": "wikitext-2-raw-v1/train-00000-of-00001.parquet"
-                }
-            },
-        )
-
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    def test_dataset_kwargs_and_percentages(self):
-
-        c4_manager_a = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train[5%:10%]",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-        raw_dataset_a = c4_manager_a.get_raw_dataset()
-
-        c4_manager_b = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train[5%:15%]",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-        raw_dataset_b = c4_manager_b.get_raw_dataset()
-
-        self.assertEqual(len(raw_dataset_b), 2 * len(raw_dataset_a))
-
-
-@pytest.mark.unit
-class TestDatasets(unittest.TestCase):
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    @parameterized.expand(
-        [
-            ["ptb", "penn_treebank", "train[:5%]", False],
-            ["gsm8k", "main", "train[:5%]", True],
-            ["ultrachat_200k", "default", "train_sft[:2%]", False],
-        ]
-    )
-    def test_datasets(self, dataset_key, dataset_config, split, do_concat):
-        data_args = DataTrainingArguments(
-            dataset=dataset_key,
-            dataset_config_name=dataset_config,
-            concatenate_data=do_concat,
-        )
-        manager = TextGenerationDataset.load_from_registry(
-            data_args.dataset,
-            data_args=data_args,
-            split=split,
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-        raw_dataset = manager.get_raw_dataset()
-        self.assertGreater(len(raw_dataset), 0)
-        self.assertEqual(raw_dataset.split, split)
-        self.assertEqual(raw_dataset.info.config_name, dataset_config)
-
-        tokenized_dataset = manager.tokenize_and_process(raw_dataset)
-        assert "input_ids" in tokenized_dataset.features
-        assert "labels" in tokenized_dataset.features
-        for i in range(len(tokenized_dataset)):
-            if do_concat:
-                self.assertEqual(
-                    len(tokenized_dataset[i]["input_ids"]), manager.max_seq_length
-                )
-            else:
-                self.assertLessEqual(
-                    len(tokenized_dataset[i]["input_ids"]), manager.max_seq_length
-                )
-
-
-@pytest.mark.skip("Dataset load broken on Hugging Face")
-@pytest.mark.unit
-class TestEvol(unittest.TestCase):
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="evolcodealpaca",
-            dataset_config_name=None,
-            concatenate_data=False,
-        )
-
-    def test_evol(self):
-        evol_manager = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train[:2%]",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-        raw_dataset = evol_manager.get_raw_dataset()
-        self.assertGreater(len(raw_dataset), 0)
-        self.assertEqual(raw_dataset.split, "train[:2%]")
-
-        tokenized_dataset = evol_manager.tokenize_and_process(raw_dataset)
-        assert "input_ids" in tokenized_dataset.features
-        assert "labels" in tokenized_dataset.features
-        for i in range(len(tokenized_dataset)):
-            self.assertLessEqual(
-                len(tokenized_dataset[i]["input_ids"]), evol_manager.max_seq_length
-            )
-
-
-@pytest.mark.unit
-class TestDVCLoading(unittest.TestCase):
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="csv",
-            dataset_path="dvc://workshop/satellite-data/jan_train.csv",
-            dvc_data_repository="https://github.com/iterative/dataset-registry.git",
-        )
-
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    def test_dvc_dataloading(self):
-        manager = TextGenerationDataset(
-            text_column="",
-            data_args=self.data_args,
-            split="train",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-
-        raw_dataset = manager.get_raw_dataset()
-        self.assertGreater(len(raw_dataset), 0)
-        self.assertIsInstance(raw_dataset[0], dict)
-
-
-@pytest.mark.unit
-class TestStreamLoading(unittest.TestCase):
-    def setUp(self):
-        self.data_args = DataTrainingArguments(
-            dataset="wikitext",
-            dataset_config_name="wikitext-2-raw-v1",
-            concatenate_data=True,
-            streaming=True,
-        )
-
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    def test_stream_loading(self):
-        manager = TextGenerationDataset.load_from_registry(
-            self.data_args.dataset,
-            data_args=self.data_args,
-            split="train",
-            tokenizer=self.tiny_llama_tokenizer,
-        )
-
-        raw_dataset = manager.get_raw_dataset()
-        processed = manager.tokenize_and_process(raw_dataset)
-        self.assertIsInstance(processed, IterableDataset)
-        with pytest.raises(TypeError):
-            # in streaming mode we don't know the length of the dataset
-            _ = len(processed)
-
-        # confirm tokenization of streamed item works correctly
-        item = next(iter(processed))
-        assert "labels" in item
-        self.assertEqual(len(item["input_ids"]), manager.max_seq_length)
-
-
-@pytest.mark.unit
-class TestSplitLoading(unittest.TestCase):
-    @pytest.fixture(autouse=True)
-    def prepare_fixture(self, tiny_llama_tokenizer):
-        self.tiny_llama_tokenizer = tiny_llama_tokenizer
-
-    @parameterized.expand(
-        [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]]
-    )
-    def test_split_loading(self, split_def):
-        data_args = DataTrainingArguments(dataset="open_platypus", splits=split_def)
-        training_args = TrainingArguments(do_train=True, output_dir="dummy")
-        model_args = ModelArguments(model=None)
-        stage_runner = StageRunner(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            model=None,
-        )
-        stage_runner.populate_datasets(tokenizer=self.tiny_llama_tokenizer)
-
-        train_dataset = stage_runner.get_dataset_split("train")
-        assert train_dataset is not None
-        self.assertIsInstance(train_dataset[0], dict)