From f7cb678f793e3c650c12efc0fe47df63a72f445f Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Mon, 6 May 2024 16:02:29 -0400
Subject: [PATCH] Refactor Quantization Modifier and Reloading (#2246)

* initial commit
* update setup.py
* Update setup.py
* fix setup.py
* move all config to sparsetensors
* cleanup class name and comments
* initial implementation untested
* fixing issues
* add test script
* update perplexity test
* refactor to compressed-tensors
* rename sparsetensors
* update setup
* Sa/model reload (#2250)
* working reload
* sparsegpt
* cleanup
* refactor tests
* only run oneshot once
* all tests passing
* remove unused config
* reset models on each parameterize
* style
* bring back SparsityConfigMetadata
* Update setup.py

Co-authored-by: Rahul Tuli

* add more comparisons, tighten threshold
* use wikitext for perplexity
* update setup
* fix import problem
* fix clearml test
* compressed-tensors are transformers dep
* address PR comments
* can't repeat freeze
* UX pr comments
* quality
* shape consistency
* address PR comments

---------

Co-authored-by: dbogunowicz
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Rahul Tuli
Co-authored-by: George Ohashi
---
 .../finetuning/example_single_gpu_config.yaml | 15 ++
 .../modifiers/obcq/utils/sgpt_wrapper.py | 23 ++
 .../modifiers/quantization_vllm/__init__.py | 17 ++
 .../modifiers/quantization_vllm/base.py | 83 +++++++
 .../modifiers/quantization_vllm/pytorch.py | 141 +++++++++++
 .../compression/sparsity_config.py | 2 +-
 .../compressed_tensors_utils.py | 50 +++-
 .../sparsification/sparse_model.py | 41 +--
 src/sparseml/utils/pytorch/module.py | 3 +
 .../transformers/compression/__init__.py | 13 +
 .../compression/recipes/new_quant_full.yaml | 33 +++
 .../compression/recipes/new_quant_weight.yaml | 20 ++
 .../compression/recipes/old_quant_full.yaml | 39 +++
 .../compression/recipes/old_quant_weight.yaml | 36 +++
 .../compression/test_quantization.py | 233 ++++++++++++++++++
 tests/testing_utils.py | 7 +-
 16 files changed, 732 insertions(+), 24 deletions(-)
 create mode 100644 integrations/huggingface-transformers/finetuning/example_single_gpu_config.yaml
 create mode 100644 src/sparseml/modifiers/quantization_vllm/__init__.py
 create mode 100644 src/sparseml/modifiers/quantization_vllm/base.py
 create mode 100644 src/sparseml/modifiers/quantization_vllm/pytorch.py
 create mode 100644 tests/sparseml/transformers/compression/__init__.py
 create mode 100644 tests/sparseml/transformers/compression/recipes/new_quant_full.yaml
 create mode 100644 tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml
 create mode 100644 tests/sparseml/transformers/compression/recipes/old_quant_full.yaml
 create mode 100644 tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml
 create mode 100644 tests/sparseml/transformers/compression/test_quantization.py

diff --git a/integrations/huggingface-transformers/finetuning/example_single_gpu_config.yaml b/integrations/huggingface-transformers/finetuning/example_single_gpu_config.yaml
new file mode 100644
index 00000000000..d2f7ec8cdc7
--- /dev/null
+++ b/integrations/huggingface-transformers/finetuning/example_single_gpu_config.yaml
@@ -0,0 +1,15 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: 'NO'
+enable_cpu_affinity: false
+gpu_ids: 0
+machine_rank: 0
+main_training_function: main
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No
newline at end of file diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py index a911fa1c0c7..2b439862b4e 100644 --- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py +++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py @@ -171,6 +171,29 @@ def fasterprune( else: q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype) q = torch.dequantize(q) + elif hasattr(self.layer, "quantization_scheme"): + if self.layer.quantization_scheme.weights is not None: + scale = self.layer.weight_scale + zero_point = self.layer.weight_zero_point + from compressed_tensors.quantization.lifecycle.forward import ( + fake_quantize, + ) + + while scale.ndim < 2: + scale = scale.unsqueeze(1) + zero_point = zero_point.unsqueeze(1) + + while q.ndim < 2: + q = q.unsqueeze(1) + q = fake_quantize( + q, + scale[:, i], + zero_point[:, i], + self.layer.quantization_scheme.weights, + ) + + while q.ndim != 1: + q.squeeze() Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d**2 diff --git a/src/sparseml/modifiers/quantization_vllm/__init__.py b/src/sparseml/modifiers/quantization_vllm/__init__.py new file mode 100644 index 00000000000..9cdf715c135 --- /dev/null +++ b/src/sparseml/modifiers/quantization_vllm/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * diff --git a/src/sparseml/modifiers/quantization_vllm/base.py b/src/sparseml/modifiers/quantization_vllm/base.py new file mode 100644 index 00000000000..c8b2522ecee --- /dev/null +++ b/src/sparseml/modifiers/quantization_vllm/base.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional + +from pydantic import Field + +from compressed_tensors.quantization import ( + QuantizationConfig, + QuantizationScheme, + QuantizationStatus, +) +from sparseml.core import Event, Modifier + + +__all__ = ["vLLMQuantizationModifier"] + + +class vLLMQuantizationModifier(Modifier): + """ + Enables post training quantization (PTQ) and quantization aware training (QAT) for a + given module or its submodules. After calibration (PTQ) or the start epoch (QAT), + the specified module(s) forward pass will emulate quantized execution and the + modifier will be enabled until training is completed. 
+ + :param config_groups: dictionary specifying quantization schemes to apply to target + modules. Modules not matching a scheme target will NOT be quantized. + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param disable_quantization_observer_epoch: Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + """ + + config_groups: Dict[str, QuantizationScheme] + ignore: List[str] = Field(default_factory=list) + disable_quantization_observer_epoch: Optional[float] = None + num_calibration_steps: Optional[int] = None + + def create_init_config(self) -> QuantizationConfig: + return QuantizationConfig( + config_groups=self.config_groups, + quantization_status=QuantizationStatus.INITIALIZED, + ignore=self.ignore, + ) + + def calculate_disable_observer_epoch(self) -> float: + """ + Get the epoch at which we want to disable to quantization observer + :return epoch to disable at, or -1 if it is not set + """ + return ( + self.disable_quantization_observer_epoch + if self.disable_quantization_observer_epoch is not None + else -1 + ) + + def check_should_disable_observer(self, event: Event) -> bool: + """ + Given the current index, determine if we should disable the observer + + :param event: Event to get index from + :return: True if observer should be disabled, False otherwise + """ + disable_epoch = self.calculate_disable_observer_epoch() + if disable_epoch == -1: + return False + if event.current_index >= disable_epoch: + return True + return False diff --git a/src/sparseml/modifiers/quantization_vllm/pytorch.py b/src/sparseml/modifiers/quantization_vllm/pytorch.py new file mode 100644 index 00000000000..a6e7f179525 --- /dev/null +++ b/src/sparseml/modifiers/quantization_vllm/pytorch.py @@ -0,0 +1,141 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any + +from torch.nn import Module + +from compressed_tensors.quantization import ( + apply_quantization_config, + freeze_module_quantization, + set_module_for_calibration, +) +from sparseml.core import Event, EventType, State +from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward + + +_LOGGER = logging.getLogger(__name__) + + +class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier): + """ + PyTorch specific implementation of vLLMQuantizationModifier + + Enables post training quantization (PTQ) and quantization aware training (QAT) for a + given module or its submodules. 
After calibration (PTQ) or the start epoch (QAT), + the specified module(s) forward pass will emulate quantized execution and the + modifier will be enabled until training is completed. + + :param config_groups: dictionary specifying quantization schemes to apply to target + modules. Modules not matching a scheme target will NOT be quantized. + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param disable_quantization_observer_epoch: Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + """ + + calibration_dataloader_: Any = None + calibration_function_: Any = None + + def on_initialize_structure(self, state: State, **kwargs): + module = state.model.model + self._apply_modifier_to_model(module) + module.apply(freeze_module_quantization) + + def on_initialize(self, state: State, **kwargs) -> bool: + if self.end and self.end != -1: + raise ValueError( + "end_epoch is disabled for QuantizationModifier and can only be set to" + " -1 or None. Given {}".format(self.end) + ) + + self.calibration_dataloader_ = state.data.calib + module = state.model.model + + # intialize quantization in appropriate modules + self._apply_modifier_to_model(module) + + if self.calculate_start() == -1: # one-shot + module.apply(set_module_for_calibration) + self._calibrate_if_possible(module) + module.apply(freeze_module_quantization) + + return True + + def on_finalize(self, state: State, **kwargs) -> bool: + return True + + def on_start(self, state: State, event: Event, **kwargs): + module = state.model.model + module.apply(set_module_for_calibration) + + def on_update(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.BATCH_START: + if self.check_should_disable_observer(event): + module = state.model.model + module.apply(freeze_module_quantization) + + def on_end(self, state: State, event: Event, **kwargs): + module = state.model.model + module.apply(freeze_module_quantization) + + def on_event(self, state: State, event: Event, **kwargs): + pass + + def _apply_modifier_to_model(self, model: Module): + modifier_as_config = self.create_init_config() + apply_quantization_config(model, modifier_as_config) + + def _calibrate_if_possible(self, module: Module): + if self.num_calibration_steps == 0 and self.calibration_dataloader_: + _LOGGER.warning( + f"num_calibration_steps is {self.num_calibration_steps}." + f"Calibration data loader will not be used." + ) + elif self.num_calibration_steps and not self.calibration_dataloader_: + raise ValueError( + f"num_calibration_steps is {self.num_calibration_steps}. " + "Calibration data loader is not set. Pass a " + "calibration_data_loader with initialize(...) method." + ) + + elif not self.calibration_dataloader_: + return + + self._calibrate(module) + + def _calibrate(self, module: Module): + class_name = self.__class__.__name__.replace("PyTorch", "") + _LOGGER.info( + f"Running {class_name} calibration with " + f"{len(self.calibration_dataloader_)} samples..." 
+ ) + + module_training = module.training + module.eval() + + run_calibration_forward( + module, + self.calibration_dataloader_, + self.num_calibration_steps, + self.calibration_function_, + ) + + if module_training: + module.train() diff --git a/src/sparseml/transformers/compression/sparsity_config.py b/src/sparseml/transformers/compression/sparsity_config.py index b04edf333c3..958ddc2b738 100644 --- a/src/sparseml/transformers/compression/sparsity_config.py +++ b/src/sparseml/transformers/compression/sparsity_config.py @@ -68,7 +68,7 @@ def infer_sparsity_structure() -> str: return sparsity_structure @staticmethod - def infer_config_from_model( + def from_pretrained( model: Module, state_dict: Optional[Dict[str, Tensor]] = None, compress: bool = False, diff --git a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py index ab9a7f5f5fc..b6852535a2c 100644 --- a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py +++ b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py @@ -22,7 +22,14 @@ from transformers import PreTrainedModel from transformers.file_utils import CONFIG_NAME -from compressed_tensors import SPARSITY_CONFIG_NAME, CompressionConfig, ModelCompressor +from compressed_tensors import ( + QUANTIZATION_CONFIG_NAME, + SPARSITY_CONFIG_NAME, + CompressionConfig, + ModelCompressor, + QuantizationConfig, +) +from compressed_tensors.quantization.utils import is_model_quantized from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata from sparseml.utils.pytorch import qat_active @@ -76,16 +83,45 @@ def save_pretrained_wrapper( # state_dict gets passed in as a kwarg for FSDP models state_dict = kwargs.get("state_dict", None) - if qat_active(model): + # check if we are in the old quantization framework + if qat_active(model) and not is_model_quantized(model): _LOGGER.info( - "Compression for quantized models is not yet supported. Save will " - "be run without compression and no sparsity statistics will be " - "calculated." + "Compression for models quantized with QuantizationModifer is not " + "supported. Save will be run without compression and no sparsity " + "statistics will be calculated. To save a quantized model in a " + "compressed state please use vLLMQuantizationModifier instead." ) - return original_save_pretrained.__get__(model, model_class)( + + original_save_pretrained.__get__(model, model_class)( + save_directory, **kwargs + ) + + return + + elif qat_active(model): # quantized in new framework + _LOGGER.info( + "Sparsity compression for quantized models is not yet supported. " + "No sparsity statistics will be calculated and no sparsity config " + "will be saved." 
+ ) + + original_save_pretrained.__get__(model, model_class)( save_directory, **kwargs ) + quant_config = QuantizationConfig.from_pretrained(model) + quant_config_data = quant_config.model_dump(exclude_unset=True) + config_file_path = os.path.join(save_directory, CONFIG_NAME) + + # add the sparsity config to the model's config file + with open(config_file_path, "r") as config_file: + config_data = json.load(config_file) + config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data + with open(config_file_path, "w") as config_file: + json.dump(config_data, config_file, indent=2, sort_keys=True) + + return + if sparsity_config is not None: sparsity_config.global_sparsity = ( SparsityConfigMetadata.infer_global_sparsity( @@ -104,7 +140,7 @@ def save_pretrained_wrapper( "calculation of compression statistics set " "skip_compression_stats=True" ) - sparsity_config = SparsityConfigMetadata.infer_config_from_model( + sparsity_config = SparsityConfigMetadata.from_pretrained( model, state_dict=state_dict, compress=save_compressed ) diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py index a22316ca179..995b349f513 100644 --- a/src/sparseml/transformers/sparsification/sparse_model.py +++ b/src/sparseml/transformers/sparsification/sparse_model.py @@ -30,7 +30,12 @@ ) from transformers.file_utils import WEIGHTS_NAME -from compressed_tensors import ModelCompressor, get_safetensors_folder +from compressed_tensors.compressors import ModelCompressor +from compressed_tensors.quantization import ( + QuantizationConfig, + apply_quantization_config, + load_pretrained_quantization, +) from sparseml.modifiers.quantization.modification import modify_model from sparseml.pytorch.model_load.helpers import ( apply_recipe_structure_to_model, @@ -102,6 +107,9 @@ def skip(*args, **kwargs): # determine compression format, if any, from the model config compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path) + quantization_config = QuantizationConfig.from_model_config( + pretrained_model_name_or_path + ) # temporarily set the log level to error, to ignore printing out long missing # and unexpected key error messages (these are EXPECTED for quantized models) @@ -117,21 +125,26 @@ def skip(*args, **kwargs): # If model is compressed on disk, decompress and load the weights if compressor is not None: - # if we loaded from a HF stub, find the cached model - model_path = get_safetensors_folder( - pretrained_model_name_or_path, cache_dir=kwargs.get("cache_dir", None) + # decompress weights + compressor.overwrite_weights( + model_path=pretrained_model_name_or_path, model=model ) - # decompress weights - compressor.overwrite_weights(model_path=model_path, model=model) - - recipe = resolve_recipe(recipe=recipe, model_path=pretrained_model_name_or_path) - if recipe: - apply_recipe_structure_to_model( - model=model, - model_path=pretrained_model_name_or_path, - recipe_path=recipe, + if quantization_config is not None: + # if we loaded from a HF stub, find the cached model + apply_quantization_config(model, quantization_config) + load_pretrained_quantization(model, pretrained_model_name_or_path) + else: + recipe = resolve_recipe( + recipe=recipe, model_path=pretrained_model_name_or_path ) + if recipe: + apply_recipe_structure_to_model( + model=model, + model_path=pretrained_model_name_or_path, + recipe_path=recipe, + ) + return model @@ -140,8 +153,6 @@ class SparseAutoModel: Factory class for creating sparse models using transformers 
AutoModel classes """ - from sparseml.modifiers.quantization.modification import modify_model - @staticmethod def masked_language_modeling_from_pretrained( model_name_or_path: str, diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py index 2228a533b31..780f1255db1 100644 --- a/src/sparseml/utils/pytorch/module.py +++ b/src/sparseml/utils/pytorch/module.py @@ -25,6 +25,7 @@ from torch.nn import Linear, Module, Parameter from torch.nn.modules.conv import _ConvNd +from compressed_tensors.quantization.utils import is_module_quantized from sparseml.core.model.base import ModelParameterizedLayer from sparseml.utils.fsdp.context import fix_fsdp_module_name, summon_full_params_context @@ -283,6 +284,8 @@ def qat_active(module: Module) -> bool: for _, layer in module.named_modules(): if isinstance(layer, torch.quantization.FakeQuantize): return True + if is_module_quantized(layer): + return True return False diff --git a/tests/sparseml/transformers/compression/__init__.py b/tests/sparseml/transformers/compression/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/transformers/compression/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
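The reload branch added to sparse_model.py above can be exercised end to end with the new modifier. A minimal sketch, using the same TinyLlama stub, open_platypus dataset, recipe path, and oneshot arguments as the tests later in this patch; the output directory is a hypothetical local path:

from sparseml.transformers import SparseAutoModelForCausalLM, oneshot

recipe = "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml"
output_dir = "./tiny_llama_quantized"  # hypothetical local path

# one-shot calibration with the new vLLMQuantizationModifier recipe
model = SparseAutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", device_map="cuda:0"
)
oneshot(
    model=model,
    dataset="open_platypus",
    recipe=recipe,
    output_dir=output_dir,
    num_calibration_samples=512,
    max_seq_length=512,
    pad_to_max_length=False,
)

# reloading applies the serialized quantization config to the model structure
# before the quantized weights are loaded, instead of re-applying a recipe
reloaded = SparseAutoModelForCausalLM.from_pretrained(output_dir)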
diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml new file mode 100644 index 00000000000..c5a55fa3284 --- /dev/null +++ b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml @@ -0,0 +1,33 @@ +test_stage: + quant_modifiers: + vLLMQuantizationModifier: + ignore: ["lm_head", "model.layers.0.mlp.down_proj"] + config_groups: + group_0: + weights: + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + input_activations: + num_bits: 8 + type: "int" + symmetric: false + strategy: "tensor" + output_activations: null + targets: ["Linear"] + group_1: + weights: + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + input_activations: null + output_activations: null + targets: ["Embedding"] + SparseGPTModifier: + sparsity: 0.0 + block_size: 128 + sequential_update: False + quantize: True + targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml new file mode 100644 index 00000000000..64a1f87b29d --- /dev/null +++ b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml @@ -0,0 +1,20 @@ +test_stage: + quant_modifiers: + vLLMQuantizationModifier: + ignore: ["lm_head", "model.layers.0.mlp.down_proj"] + config_groups: + group_0: + weights: + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + input_activations: null + output_activations: null + targets: ["Linear", "Embedding"] + SparseGPTModifier: + sparsity: 0.0 + block_size: 128 + sequential_update: False + quantize: True + targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml new file mode 100644 index 00000000000..8a94733242a --- /dev/null +++ b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml @@ -0,0 +1,39 @@ +test_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - model.layers.0.mlp.down_proj + - lm_head + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + - MatMulLeftInput_QK + - MatMulRightInput_QK + - MatMulOutput_QK + - MatMulLeftInput_PV + - MatMulRightInput_PV + - MatMulOutput_PV + scheme_overrides: + Linear: + weights: + num_bits: 8 + symmetric: true + strategy: "tensor" + input_activations: + num_bits: 8 + symmetric: false + strategy: "tensor" + output_activations: null + Embedding: + weights: + num_bits: 8 + symmetric: true + strategy: "tensor" + input_activations: null + output_activations: null + SparseGPTModifier: + sparsity: 0.0 + block_size: 128 + sequential_update: False + quantize: True + targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml new file mode 100644 index 00000000000..e095a22912b --- /dev/null +++ b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml @@ -0,0 +1,36 @@ +test_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - model.layers.0.mlp.down_proj + - lm_head + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + - MatMulLeftInput_QK + - MatMulRightInput_QK + - MatMulOutput_QK + - MatMulLeftInput_PV + - MatMulRightInput_PV + - MatMulOutput_PV + scheme_overrides: + Linear: + weights: + num_bits: 8 + symmetric: 
true + strategy: "tensor" + input_activations: null + output_activations: null + Embedding: + weights: + num_bits: 8 + symmetric: true + strategy: "tensor" + input_activations: null + output_activations: null + SparseGPTModifier: + sparsity: 0.0 + block_size: 128 + sequential_update: False + quantize: True + targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/test_quantization.py b/tests/sparseml/transformers/compression/test_quantization.py new file mode 100644 index 00000000000..1fc5f1af3c7 --- /dev/null +++ b/tests/sparseml/transformers/compression/test_quantization.py @@ -0,0 +1,233 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import shutil +import tempfile +import unittest + +import torch +from torch.utils.data import DataLoader +from transformers import DefaultDataCollator + +from compressed_tensors.quantization.utils import is_module_quantized +from parameterized import parameterized_class +from sparseml.pytorch.utils import tensors_to_device +from sparseml.transformers import ( + SparseAutoModelForCausalLM, + SparseAutoTokenizer, + oneshot, +) +from sparseml.transformers.finetune.data import TextGenerationDataset +from sparseml.transformers.finetune.data.data_args import DataTrainingArguments +from tests.testing_utils import requires_gpu, requires_torch + + +@requires_torch +@requires_gpu +@parameterized_class( + ("old_recipe", "new_recipe"), + [ + ( + "tests/sparseml/transformers/compression/recipes/old_quant_full.yaml", + "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml", + ), + ( + "tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml", + "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml", + ), + ], +) +class TestQuantizationMatches(unittest.TestCase): + old_recipe = None + new_recipe = None + model_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" + dataset = "open_platypus" + old_output = "tiny_llama_old" + new_output = "tiny_llama_new" + max_seq_length = 512 + num_comparisons = 64 + + @classmethod + def setUpClass(cls): + cls.test_dir = tempfile.mkdtemp() + + cls.model_old = SparseAutoModelForCausalLM.from_pretrained( + cls.model_stub, device_map="cuda:0" + ) + cls._run_oneshot( + cls.model_old, + cls.old_recipe, + cls.dataset, + os.path.join(cls.test_dir, cls.old_output), + ) + + cls.model_new = SparseAutoModelForCausalLM.from_pretrained( + cls.model_stub, device_map="cuda:1" + ) + cls._run_oneshot( + cls.model_new, + cls.new_recipe, + cls.dataset, + os.path.join(cls.test_dir, cls.new_output), + ) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.test_dir) + del cls.model_new + del cls.model_old + torch.cuda.empty_cache() + + @staticmethod + def _run_oneshot(model, recipe, dataset, output_dir): + num_calibration_samples = 512 + max_seq_length = 512 + pad_to_max_length = False + + oneshot( + model=model, + dataset=dataset, + 
overwrite_output_dir=True, + output_dir=output_dir, + max_seq_length=max_seq_length, + num_calibration_samples=num_calibration_samples, + recipe=recipe, + pad_to_max_length=pad_to_max_length, + ) + + def _get_quant_info_old(self, model): + quant_info_weights = {} + quant_info_inputs = {} + for name, module in model.named_modules(): + if hasattr(module, "weight_fake_quant"): + scale = module.weight_fake_quant.scale.item() + zp = module.weight_fake_quant.zero_point.item() + quant_info_weights[name] = (scale, zp) + elif hasattr(module, "quant"): + scale = module.quant.activation_post_process.scale.item() + zp = module.quant.activation_post_process.zero_point.item() + quant_info_inputs[name] = (scale, zp) + + return quant_info_weights, quant_info_inputs + + def _get_quant_info_new(self, model): + quant_info_weights = {} + quant_info_inputs = {} + for name, module in model.named_modules(): + if is_module_quantized(module): + if module.quantization_scheme.weights is not None: + quant_info_weights[name] = ( + module.weight_scale.item(), + module.weight_zero_point.item(), + ) + if module.quantization_scheme.input_activations is not None: + quant_info_inputs[name] = ( + module.input_scale.item(), + module.input_zero_point.item(), + ) + + return quant_info_weights, quant_info_inputs + + def test_quantization_counts(self): + old_quant_weights, old_quant_inputs = self._get_quant_info_old(self.model_old) + new_quant_weights, new_quant_inputs = self._get_quant_info_new(self.model_new) + + assert len(old_quant_weights) == len(new_quant_weights) + assert len(old_quant_inputs) == len(new_quant_inputs) + + def test_quantization_scale_and_zp(self): + old_quant_weights, old_quant_inputs = self._get_quant_info_old(self.model_old) + new_quant_weights, new_quant_inputs = self._get_quant_info_new(self.model_new) + + for name, (o_scale, o_zp) in old_quant_weights.items(): + if name.endswith(".module"): + name = name[:-7] + n_scale, n_zp = new_quant_weights[name] + assert math.isclose(o_scale, n_scale, abs_tol=1e-3, rel_tol=1e-3) + assert o_zp == n_zp + + # allow for error here due to implementation differences + for name, (o_scale, o_zp) in old_quant_inputs.items(): + n_scale, n_zp = new_quant_inputs[name] + assert math.isclose(o_scale, n_scale, abs_tol=1e-2, rel_tol=1e-2) + assert abs(o_zp - n_zp) < 5 + + def test_quantization_reload(self): + model_reloaded = SparseAutoModelForCausalLM.from_pretrained( + os.path.join(self.test_dir, self.new_output) + ) + + og_weights, og_inputs = self._get_quant_info_new(self.model_new) + reloaded_weights, reloaded_inputs = self._get_quant_info_new(model_reloaded) + + for name, (o_scale, o_zp) in og_weights.items(): + n_scale, n_zp = reloaded_weights[name] + assert o_scale == n_scale + assert o_zp == n_zp + + for name, (o_scale, o_zp) in og_inputs.items(): + n_scale, n_zp = reloaded_inputs[name] + assert o_scale == n_scale + assert o_zp == n_zp + + def _get_dataloader(self, data_args, tokenizer): + dataset_manager = TextGenerationDataset.load_from_registry( + data_args.dataset, + data_args=data_args, + split="train", + tokenizer=tokenizer, + ) + calib_dataset = dataset_manager.tokenize_and_process( + dataset_manager.get_raw_dataset() + ) + data_loader = DataLoader( + calib_dataset, + batch_size=1, + collate_fn=DefaultDataCollator(), + sampler=torch.utils.data.RandomSampler(calib_dataset), + ) + + return data_loader + + @torch.no_grad() + def test_perplexity(self): + tokenizer = SparseAutoTokenizer.from_pretrained(self.model_stub) + data_args = DataTrainingArguments( + 
dataset="wikitext", + dataset_config_name="wikitext-2-raw-v1", + max_seq_length=self.max_seq_length, + concatenate_data=True, + ) + dataloader = self._get_dataloader(data_args, tokenizer) + + total_ppl_old = 0.0 + total_ppl_new = 0.0 + total_non_nan = 0 + for idx, sample in enumerate(dataloader): + if idx >= self.num_comparisons: + break + output_new = self.model_new(**tensors_to_device(sample, "cuda:1")) + output_old = self.model_old(**tensors_to_device(sample, "cuda:0")) + if torch.isnan(output_old.loss) and torch.isnan(output_new.loss): + continue + total_ppl_old += torch.exp(output_old.loss).item() + total_ppl_new += torch.exp(output_new.loss).item() + total_non_nan += 1 + + avg_ppl_ratio = (total_ppl_new / total_non_nan) / ( + total_ppl_old / total_non_nan + ) + assert avg_ppl_ratio <= 1.02 diff --git a/tests/testing_utils.py b/tests/testing_utils.py index c42402847af..240d8a76da6 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -37,7 +37,12 @@ def is_torch_available(): def is_gpu_available(): - return False + try: + import torch # noqa: F401 + + return torch.cuda.device_count() > 0 + except ImportError: + return False def requires_torch(test_case):