diff --git a/integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml b/integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml index 26d9d414359..ca186150c4f 100644 --- a/integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml +++ b/integrations/huggingface-transformers/tutorials/text-generation/example_alternating_recipe.yaml @@ -5,7 +5,6 @@ initial_sparsity_stage: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: False percdamp: 0.01 mask_structure: "0:0" targets: [ @@ -24,7 +23,6 @@ next_sparsity_stage: sparsity: 0.7 block_size: 128 sequential_update: False - quantize: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index f6e504e7b05..4960f71bae7 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union -from sparseml.core.factory import ModifierFactory +from sparseml.core import Modifier +from sparseml.core.model.base import ModifiableModel from sparseml.core.state import State -from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier __all__ = ["SparseGPTModifier"] -_LOGGER = logging.getLogger(__name__) - -class SparseGPTModifier(WandaPruningModifier): +class SparseGPTModifier(Modifier): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -41,84 +38,91 @@ class SparseGPTModifier(WandaPruningModifier): - on_finalize - LayerCompressor.revert_layer_wrappers() - :param block_size: Used to determine number of columns to compress in one pass - :param quantize: Whether or not to quantize weights during SparseGPT. Set to - True to quantize using an existing quantization modifier, or pass in the - configuration for a quantization modifier if one does not already exist - in the recipe :param sparsity: Sparsity to compress model to + :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed + Layerwise Sparsity (OWL), more information can be found + in the paper https://arxiv.org/pdf/2310.05175 + :param owl_m: Number of outliers to use for OWL + :param owl_lmbda: Lambda value to use for OWL + :param mask_structure: String to define the structure of the mask to apply. + Must be of the form N:M where N, M are integers that define a custom block + shape. Defaults to 0:0 which represents an unstructured mask. + :param sequential_update: Whether or not to update weights sequentially by layer, + True saves on GPU memory + :param targets: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model + :param block_size: Used to determine number of columns to compress in one pass :param dampening_frac: Amount of dampening to apply to H, as a fraction of the diagonal norm + :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask + during when applying sparsegpt, this becomes useful when starting from a + previously pruned model, defaults to False. """ - block_size: int = 128 - quantize: Union[bool, Dict] = False sparsity: Union[float, List[float]] = 0.0 + sparsity_profile: Optional[str] = None + owl_m: Optional[int] = None + owl_lmbda: Optional[float] = None + mask_structure: str = "0:0" + sequential_update: Optional[bool] = False + targets: Union[str, List[str], None] = None + block_size: int = 128 dampening_frac: Optional[float] = 0.01 - quantization_modifier_: Any = None + preserve_sparsity_mask: bool = False + prunen_: Optional[int] = None + prunem_: Optional[int] = None + compressible_layers_: Optional[List] = None def on_initialize_structure(self, state: State, **kwargs): """ - Check the model's quantization state matches that expected by this modifier, - adding a default quantization scheme if needed + Initialize the structure of the model for compression. + This modifier does not modifiy the model structure, so this method + is a no-op. + + :param state: session state storing input model and calibration data + """ + return True + + def compressible_layers(self) -> Dict: + """ + Retrieves the modules corresponding to a list of + compressible layer names + + :precondition: self.model is set and is a `ModifiableModel` + :precondition: The `ModifiableModel` implements a `get_layers` + method + :return: dictionary of modules to compress + """ + if not isinstance(self.model, ModifiableModel): + raise ValueError( + "`self.model` must be a ModifiableModel to use " + f"the {self.__class__.__qualname__} modifier but got " + f"{type(self.model)} instead" + ) + + return self.model.get_layers(self.targets) + + def _validate_layerwise_sparsity(self): + if isinstance(self.sparsity, float): + # single sparsity will be applied to all layers + return + + target_layers = list(self.compressible_layers_.keys()) + + if len(target_layers) != len(self.sparsity): + raise ValueError( + "Number of layer targets must match the number of " + f"sparsities. Got {len(target_layers)} layers and " + f"{len(self.sparsity)} sparsities" + ) + + def on_finalize(self, state: State, **kwargs): + """ + Nothing to do on finalize, on this level. + Quantization Modifier if any will be finalized in the subclass :param state: session state storing input model and calibration data + :param kwargs: additional arguments + :return: True """ - quantization_already_active = state.model.qat_active() - if isinstance(self.quantize, bool): - if not self.quantize and quantization_already_active: - _LOGGER.warning( - "SparseGPT quantization is set to False, but a " - "quantization modifier is already active on the model " - "resetting quantize to True" - ) - self.quantize = True - elif self.quantize and not quantization_already_active: - _LOGGER.warning( - "SparseGPT quantization is set to True without an " - "active quantization modifier. Creating a default " - "8-bit quantization modifier" - ) - default_quant_config = {"QuantizationModifier": {}} - self._build_quant_modifier_from_dict( - default_quant_config, state.framework - ) - return # use existing quantization modifier if there is one - else: - if not isinstance(self.quantize, Dict): - raise ValueError( - "SparseGPTModifier.quantize accepts only a single " - "quantization modifier or a boolean. Found " - f"type {type(self.quantize)}" - ) - if len(self.quantize) != 1: - raise ValueError( - "SparseGPTModifier.quantize accepts only a single " - "quantization modifier or a boolean. Found " - f"{len(self.quantize)} modifiers" - ) - if quantization_already_active: - _LOGGER.warning( - "Attempting to initialize quantization for SparseGPT " - "but a quantization modifier has already been applied. " - "The quantization configuration defined under the " - "SparseGPT modifier will be ignored." - ) - self.quantize = True - return - self._build_quant_modifier_from_dict(self.quantize, state.framework) - self.quantize = True - - if self.quantization_modifier_: - self.quantization_modifier_.on_initialize_structure(state, **kwargs) - - def _build_quant_modifier_from_dict(self, quant_config, framework): - modifier_type = list(quant_config.keys())[0] - modifier_args = quant_config[modifier_type] - self.quantization_modifier_ = ModifierFactory.create( - modifier_type, - framework=framework, - allow_registered=True, - allow_experimental=True, - **modifier_args, - ) + return True diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index de1eef74189..b2a15e67cd2 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -13,13 +13,19 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import torch +from tqdm import tqdm from sparseml.core.model import ModifiableModel from sparseml.core.state import State from sparseml.modifiers.obcq.base import SparseGPTModifier from sparseml.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper -from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch +from sparseml.modifiers.utils.layer_compressor import LayerCompressor +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward +from sparseml.utils.pytorch.module import get_prunable_layers __all__ = ["SparseGPTModifierPyTorch"] @@ -27,7 +33,7 @@ _LOGGER = logging.getLogger(__name__) -class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier): +class SparseGPTModifierPyTorch(SparseGPTModifier): """ Pytorch implementation of SparseGPT @@ -40,14 +46,23 @@ class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier): - run_calibration_forward() - LayerCompressor.compress() - LayerCompressor.post_compress() - - on_finalize - - LayerCompressor.revert_layer_wrappers() + - LayerCompressor.revert_layer_wrappers() + + | Sample yaml: + | test_stage: + | obcq_modifiers: + | SparseGPTModifier: + | sparsity: 0.5 + | mask_structure: "2:4" + | sequential_update: True + | dampening_frac: 0.001 + | block_size: 128 :param model: Pytorch model to perform OBCQ on, in-place """ model: Optional[ModifiableModel] = None - layer_compressors: List = None + layer_compressors_: Optional[List[Any]] = None def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -57,25 +72,111 @@ def on_initialize(self, state: "State", **kwargs) -> bool: """ if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) - if self.quantization_modifier_: - self.quantization_modifier_.initialize(state, **kwargs) - if not self.quantize and self.sparsity == 0.0: + + if self.sparsity == 0.0: raise ValueError( - "To use the SparseGPTModifier, target sparsity must be > 0.0 or " - "quantization must be enabled." + "To use the SparseGPTModifier, target sparsity must be > 0.0" ) - return super(SparseGPTModifierPyTorch, self).on_initialize(state, **kwargs) + modifiable_model = state.model + calibration_dataloader = state.data.calib - def on_finalize(self, state: "State", **kwargs) -> bool: + if self.targets is None: + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + self.targets = modifiable_model.get_no_split_params() + + self.initialize_compression(modifiable_model, calibration_dataloader) + self.apply_compression(calibration_dataloader) + + return True + + def initialize_compression( + self, + model: ModifiableModel, + dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, + ): """ - disable the quantization observers used by the OBCQ algorithm + Setup for WANDA, initializes the model, device, + and other parameters, also initilializes the + compressible layers of model, and sets the device - :param state: session state storing input model and calibration data + :param model: model to initialize for compression """ - if self.quantization_modifier_: - self.quantization_modifier_.finalize(state, **kwargs) + self.model = model + self.compressible_layers_ = self.compressible_layers() + self.model = self.model.model + self.layer_compressors_ = [] + self._infer_mask_block_size() + + if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl": + _LOGGER.info( + "Inferring layer-wise sparsities from " + f"{len(dataloader)} calibration samples..." + ) + self.sparsity = self._infer_layer_sparsity(dataloader) + self._validate_layerwise_sparsity() + + for idx, (name, layer) in enumerate(self.compressible_layers_.items()): + _LOGGER.info(f"Preparing {name} for compression") + if isinstance(self.sparsity, Dict): + layer_sparsity = self.sparsity[name] + elif isinstance(self.sparsity, List): + layer_sparsity = self.sparsity[idx] + else: # float + layer_sparsity = self.sparsity + args = self._pruning_arguments(layer_sparsity) + comp_cls = self._compression_class() + compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) + if not self.sequential_update: + # add all batch processing hooks before the forward pass + compressor.pre_compress() + self.layer_compressors_.append(compressor) + @torch.no_grad() + def apply_compression( + self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None + ) -> Dict: + """ + Run Wanda on the loaded model, using dataloader as calibration data + + :param dataloader: calibration data for WANDA + """ + class_name = self.__class__.__name__.replace("PyTorch", "") + _LOGGER.info( + f"Running {class_name} calibration with " f"{len(dataloader)} samples..." + ) + if not self.sequential_update: + # in non-sequential mode we run one forward batch for all modules + run_calibration_forward(self.model, dataloader, mask_padding=True) + + num_layers = len(self.compressible_layers_) + for idx, layer_compressor in enumerate(self.layer_compressors_): + layer_sparsity = layer_compressor.args["sparsity"] + _LOGGER.info( + f"\n===== Compressing layer {idx+1}/{num_layers} " + f"to sparsity {layer_sparsity} =====" + ) + + # Prune/quantize using SparseGPT + if self.sequential_update: + # in sequential mode we run one forward pass for each module we + # want to compress, this will be really slow but allows compression in + # earlier layers to affect later layers + layer_compressor.pre_compress() + _LOGGER.info(f"Calibrating {layer_compressor.name}...") + run_calibration_forward(self.model, dataloader, mask_padding=True) + layer_compressor.compress() + layer_compressor.post_compress() + layer_compressor.revert_layer_wrappers() + torch.cuda.empty_cache() + + def on_finalize(self, state: "State", **kwargs) -> bool: + """ + :param state: session state storing input model and calibration data + :return: True if the finalization was successful + """ return super(SparseGPTModifierPyTorch, self).on_finalize(state, **kwargs) def _pruning_arguments(self, sparsity): @@ -91,6 +192,7 @@ def _pruning_arguments(self, sparsity): "prunem": self.prunem_, "blocksize": self.block_size, "percdamp": self.dampening_frac, + "preserve_sparsity_mask": self.preserve_sparsity_mask, } def _compression_class(self): @@ -98,3 +200,96 @@ def _compression_class(self): :return: wrapper class used for root modules of this compression class """ return SparseGptWrapper + + def _infer_mask_block_size(self): + """ + Infer the mask block size from the mask structure. + Parses mask_structure of the form N:M where N, M are integers that + define a custom block shape; and sets prunen_ and prunem_ accordingly. + + :post-condition: prunen_ and prunem_ are set + """ + if self.mask_structure is None: + raise ValueError("mask_structure must be defined") + + self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) + + def _infer_layer_sparsity(self, calibration_dataloader): + acts = _get_activations(self.model, calibration_dataloader) + sparsegpt_groups = {} + for name, layer in self.compressible_layers_.items(): + prunable_layers = get_prunable_layers(layer) + z = [ + m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + for n, m in prunable_layers.items() + ] + sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z]) + + acts = None + del acts + torch.cuda.empty_cache() + + outlier_ratios = {} + for group in sparsegpt_groups: + threshold = torch.mean(sparsegpt_groups[group]) * self.owl_m + outlier_ratios[group] = ( + 100 + * (sparsegpt_groups[group] > threshold).sum().item() + / sparsegpt_groups[group].numel() + ) + outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) + for k in outlier_ratios: + outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * ( + 1 + / (outlier_ratios_arr.max() - outlier_ratios_arr.min()) + * self.owl_lmbda + * 2 + ) + outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) + sparsities = { + k: 1 + - ( + outlier_ratios[k] + - np.mean(outlier_ratios_arr) + + (1 - float(self.sparsity)) + ) + for k in outlier_ratios + } + _LOGGER.info(f"OWL sparsities for sp={self.sparsity} are:") + for k in sparsities: + _LOGGER.info(f"Sparsity for {k}: {sparsities[k]}") + return sparsities + + +@torch.no_grad() +def _get_activations(model, data_loader, nsamples=128): + import functools + + model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + else: + acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + + hooks = [] + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + hooks.append( + mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) + ) + device = next(model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + batch = None + torch.cuda.empty_cache() + + for h in hooks: + h.remove() + + return acts diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py index 68171ac33a1..634dbfac805 100644 --- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py +++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py @@ -25,7 +25,6 @@ import logging import math -from copy import copy import torch import torch.nn as nn @@ -85,6 +84,7 @@ def fasterprune( prunem: int = 0, blocksize: int = 128, percdamp: float = 0.01, + preserve_sparsity_mask: bool = False, ): """ Run pruning and quantization(if applicable) on the layer up to the target @@ -95,7 +95,8 @@ def fasterprune( :param prunem: M for N:M pruning :param blocksize: Number of columns to compress in one pass :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm + diagonal norm + :param preserve_sparsity_mask: Extend or ignore the base sparsity mask """ final_shape = self.layer.weight.shape final_dtype = self.layer.weight.dtype @@ -124,6 +125,13 @@ def fasterprune( Hinv = self.H mask = None + if preserve_sparsity_mask: + # compute existing sparsity mask + mask = torch.where( + W == 0, + torch.tensor(1, dtype=torch.bool), + torch.tensor(0, dtype=torch.bool), + ) # See section 3.4 of https://arxiv.org/abs/2203.07259 for i1 in range(0, self.columns, blocksize): @@ -139,12 +147,32 @@ def fasterprune( if prunen == 0: if mask is not None: mask1 = mask[:, i1:i2] + if int(W1.numel() * sparsity) > mask1.sum(): + # target sparsity is higher than base sparsity, extend mask1 + tmp = ( + (~mask[:, i1:i2]) + * W1**2 + / (torch.diag(Hinv1).reshape((1, -1))) ** 2 + ) + thresh = torch.sort(tmp.flatten())[0][ + int(tmp.numel() * sparsity) + ] + mask1 = tmp <= thresh + else: + raise ValueError( + "The target sparsity is lower than the sparsity " + "of the base model. Please retry " + "after turning preserve_sparsity_mask=False" + ) else: tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] mask1 = tmp <= thresh else: - mask1 = torch.zeros_like(W1) == 1 + if mask is not None: + mask1 = mask[:, i1:i2] + else: + mask1 = torch.zeros_like(W1) == 1 for i in range(count): w = W1[:, i] @@ -155,6 +183,9 @@ def fasterprune( W1[:, i : (i + prunem)] ** 2 / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2 ) + if mask is not None: + tmp = tmp * (~mask[:, i : (i + prunem)]) + mask1.scatter_( 1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True ) @@ -162,66 +193,6 @@ def fasterprune( q = w.clone() q[mask1[:, i]] = 0 - if hasattr(self.layer, "weight_fake_quant"): - scale = self.layer.weight_fake_quant.scale - zero_point = self.layer.weight_fake_quant.zero_point - dtype = self.layer.weight_fake_quant.dtype - qscheme = self.layer.weight_fake_quant.qscheme - if qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]: - q = torch.quantize_per_tensor(q, scale, zero_point, dtype) - else: - q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype) - q = torch.dequantize(q) - elif hasattr(self.layer, "quantization_scheme"): - quant_scheme = self.layer.quantization_scheme - if quant_scheme.weights is not None: - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - from compressed_tensors.quantization import QuantizationStrategy - from compressed_tensors.quantization.lifecycle.forward import ( - fake_quantize, - ) - - strategy = quant_scheme.weights.strategy - - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - self.layer.quantization_scheme.weights, - ) - elif strategy == QuantizationStrategy.CHANNEL: - # TODO: for channelwise why isn't this just a 1d tensor? - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - quant_scheme.weights, - ) - else: # strategy == QuantizationStrategy.GROUP - # TODO: for grouped quantization its always 3d but the last - # dim is always 1. Can we just make it 2d instead and avoid? - scale = scale[:, :, 0] - zero_point = zero_point[:, :, 0] - - # get the group index for the current column - column_idx = i1 + i - input_dim_group = ( - column_idx // quant_scheme.weights.group_size - ) - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(quant_scheme.weights) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, input_dim_group], - zero_point[:, input_dim_group], - altered_qargs, - ) - Q1[:, i] = q Losses1[:, i] = (w - q) ** 2 / d**2 @@ -232,7 +203,12 @@ def fasterprune( W[:, i1:i2] = Q1 Losses += torch.sum(Losses1, 1) / 2 - W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_sparsity_mask: + # respect the sparsity of other groups + # really not needed, but kept for explicitness + W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:]) + else: + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) _LOGGER.info("time %.2f" % (time.time() - tick)) _LOGGER.info("error %.2f" % torch.sum(Losses).item()) diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index 26cb6db5bf7..d59621cc09d 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -37,8 +37,8 @@ class WandaPruningModifier(Modifier): - run_calibration_forward() - LayerCompressor.compress() - LayerCompressor.post_compress() + - LayerCompressor.revert_layer_wrappers() - on_finalize - - LayerCompressor.revert_layer_wrappers() :param sparsity: Sparsity to compress model to :param mask_structure: String to define the structure of the mask to apply. diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py index 8d7e8ff3b76..6203e73f600 100644 --- a/src/sparseml/modifiers/pruning/wanda/pytorch.py +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -44,8 +44,18 @@ class WandaPruningModifierPyTorch(WandaPruningModifier): - run_calibration_forward() - LayerCompressor.compress() - LayerCompressor.post_compress() + - LayerCompressor.revert_layer_wrappers() - on_finalize - - LayerCompressor.revert_layer_wrappers() + + | Sample yaml: + | test_stage: + | wanda_modifiers: + | WandaPruningModifier: + | sparsity: 0.05 + | mask_structure: "2:4" + | sequential_update: True + | targets: __ALL__ + :param model: `ModifiableModel` to perform WANDA on, in-place """ @@ -141,7 +151,7 @@ def apply_compression( f"to sparsity {layer_sparsity} =====" ) - # Prune/quantize using SparseGPT + # Prune/quantize using the layer compressor if self.sequential_update: # in sequential mode we run one forward pass for each module we # want to compress, this will be really slow but allows compression in diff --git a/src/sparseml/modifiers/quantization/gptq/__init__.py b/src/sparseml/modifiers/quantization/gptq/__init__.py new file mode 100644 index 00000000000..9cdf715c135 --- /dev/null +++ b/src/sparseml/modifiers/quantization/gptq/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py new file mode 100644 index 00000000000..b91fa2dad60 --- /dev/null +++ b/src/sparseml/modifiers/quantization/gptq/base.py @@ -0,0 +1,207 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, List, Optional, Union + +from pydantic import Field + +from compressed_tensors.quantization import QuantizationScheme +from sparseml.core import Modifier +from sparseml.core.factory import ModifierFactory +from sparseml.core.model.base import ModifiableModel +from sparseml.core.state import State + + +__all__ = ["GPTQModifier"] + +_LOGGER = logging.getLogger(__name__) + + +class GPTQModifier(Modifier): + """ + Modifier for applying the one-shot OBCQ algorithm to a model + + Lifecycle: + - on_initialize + - initialize_compression() + - compressible_layers() + - LayerCompressor.pre_compress() + - apply_compression() + - run_calibration_forward() + - LayerCompressor.compress() + - LayerCompressor.post_compress() + - on_finalize + - LayerCompressor.revert_layer_wrappers() + + + :param sequential_update: Whether or not to update weights sequentially by layer, + True saves on GPU memory + :param targets: list of layer names to compress during GPTQ, or '__ALL__' + to compress every layer in the model + :param block_size: Used to determine number of columns to compress in one pass + :param quantize: Set to True to quantize using an existing quantization modifier, + or pass in the configuration for a quantization modifier if one does not + already exist in the recipe + :param dampening_frac: Amount of dampening to apply to H, as a fraction of the + diagonal norm + :param config_groups: [Used, if a quantization modifier is not specified], + dictionary specifying quantization schemes to apply to target + modules. Modules not matching a scheme target will NOT be quantized. + :param ignore: [Used, if a quantization modifier is not specified] + optional list of module class names or submodule names to not + quantize even if they match a target in config_groups. Defaults to empty list. + :param disable_quantization_observer_epoch: [Used, if a quantization modifier is + not specified] Epoch to disable updates to the module + quantization observers. At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + """ + + sequential_update: Optional[bool] = False + targets: Union[str, List[str], None] = None + block_size: int = 128 + quantize: Union[bool, Dict] = True + dampening_frac: Optional[float] = 0.01 + config_groups: Optional[Dict[str, QuantizationScheme]] = None + ignore: List[str] = Field(default_factory=list) + disable_quantization_observer_epoch: Optional[float] = None + num_calibration_steps: Optional[int] = None + compressible_layers_: Optional[List] = None + quantization_modifier_: Any = None + + def on_initialize_structure(self, state: State, **kwargs): + """ + Check the model's quantization state matches that expected by this modifier, + adding a default quantization scheme if needed + + :param state: session state storing input model and calibration data + """ + quantization_already_active = state.model.qat_active() + if isinstance(self.quantize, bool): + if not self.quantize and quantization_already_active: + _LOGGER.warning( + "GPTQ quantization is set to False, but a " + "quantization modifier is already active on the model " + "resetting quantize to True" + ) + self.quantize = True + elif self.quantize and not quantization_already_active: + _LOGGER.warning( + "GPTQ quantization is set to True without an " + "active quantization modifier." + ) + self._build_quant_modifier(state.framework) + return # use existing quantization modifier if there is one + else: + if not isinstance(self.quantize, Dict): + raise ValueError( + "GPTQModifier.quantize accepts only a single " + "quantization modifier or a boolean. Found " + f"type {type(self.quantize)}" + ) + if len(self.quantize) != 1: + raise ValueError( + "GPTQModifier.quantize accepts only a single " + "quantization modifier or a boolean. Found " + f"{len(self.quantize)} modifiers" + ) + if quantization_already_active: + _LOGGER.warning( + "Attempting to initialize quantization for GPTQ " + "but a quantization modifier has already been applied. " + "The quantization configuration defined under the " + "GPTQ modifier will be ignored." + ) + self.quantize = True + return + self._build_quant_modifier_from_dict(self.quantize, state.framework) + self.quantize = True + + if self.quantization_modifier_: + self.quantization_modifier_.on_initialize_structure(state, **kwargs) + + def _build_quant_modifier(self, framework): + """ + Build a quantization modifier based on the specified config_groups, + ignore list, and num_calibration_steps. + + :postcondition: self.quantization_modifier_ is set to the built + quantization modifier + :param framework: the framework to build the quantization modifier for + """ + + quantization_args_names = [ + "config_groups", + "num_calibration_steps", + "ignore", + "disable_quantization_observer_epoch", + ] + + quant_args = { + key: getattr(self, key) + for key in quantization_args_names + if getattr(self, key, False) + } + + if "config_groups" not in quant_args: + default_quant_scheme = QuantizationScheme.default_scheme( + targets=self.targets + ) + quant_args["config_groups"] = {"config_group_0": default_quant_scheme} + _LOGGER.info(f"Building quantization modifier with args: {quant_args}") + vllm_quant_config = {"vLLMQuantizationModifier": quant_args} + self._build_quant_modifier_from_dict(vllm_quant_config, framework) + + def compressible_layers(self) -> Dict: + """ + Retrieves the modules corresponding to a list of + compressible layer names + + :precondition: self.model is set and is a `ModifiableModel` + :precondition: The `ModifiableModel` implements a `get_layers` + method + :return: dictionary of modules to compress + """ + if not isinstance(self.model, ModifiableModel): + raise ValueError( + "`self.model` must be a ModifiableModel to use " + f"the {self.__class__.__qualname__} modifier but got " + f"{type(self.model)} instead" + ) + + return self.model.get_layers(self.targets) + + def _build_quant_modifier_from_dict(self, quant_config, framework): + modifier_type = list(quant_config.keys())[0] + modifier_args = quant_config[modifier_type] + self.quantization_modifier_ = ModifierFactory.create( + modifier_type, + framework=framework, + allow_registered=True, + allow_experimental=True, + **modifier_args, + ) + + def on_finalize(self, state: State, **kwargs): + """ + Nothing to do on finalize, on this level. + Quantization Modifier if any will be finalized in the subclass + + :param state: session state storing input model and calibration data + :param kwargs: additional arguments + :return: True + """ + return True diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py new file mode 100644 index 00000000000..4bc3a8ff953 --- /dev/null +++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py @@ -0,0 +1,195 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch + +from sparseml.core.model import ModifiableModel +from sparseml.core.state import State +from sparseml.modifiers.quantization.gptq.base import GPTQModifier +from sparseml.modifiers.utils.layer_compressor import LayerCompressor +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward +from src.sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper + + +__all__ = ["GPTQModifierPyTorch"] + +_LOGGER = logging.getLogger(__name__) + + +class GPTQModifierPyTorch(GPTQModifier): + """ + Pytorch implementation of GPTQ + Lifecycle: + - on_initialize + - initialize_compression() + - compressible_layers() + - LayerCompressor.pre_compress() + - apply_compression() + - run_calibration_forward() + - LayerCompressor.compress() + - LayerCompressor.post_compress() + - LayerCompressor.revert_layer_wrappers() + | Sample yaml: + | test_stage: + | obcq_modifiers: + | GPTQModifier: + | sequential_update: True + | dampening_frac: 0.001 + | block_size: 128 + | config_groups: + | group_0: + | targets: + | - "Linear" + | input_activations: null + | output_activations: null + | weights: + | num_bits: 8 + | type: "int" + | symmetric: true + | strategy: "tensor" + | group_size: 128 + + + :param model: Pytorch model to perform GPTQ on, in place. + """ + + model: Optional[ModifiableModel] = None + layer_compressors_: Optional[List[Any]] = None + + def on_initialize(self, state: "State", **kwargs) -> bool: + """ + Initialize and run the GPTQ algorithm on the current state + + :param state: session state storing input model and calibration data + """ + if not self.initialized_structure_: + self.on_initialize_structure(state, **kwargs) + if self.quantization_modifier_: + self.quantization_modifier_.initialize(state, **kwargs) + if not self.quantize: + raise ValueError("To use the GPTQModifier, quantization must be enabled.") + + modifiable_model = state.model + calibration_dataloader = state.data.calib + + if self.targets is None: + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + self.targets = modifiable_model.get_no_split_params() + + self.initialize_compression(modifiable_model, calibration_dataloader) + self.apply_compression(calibration_dataloader) + + return True + + def initialize_compression( + self, + model: ModifiableModel, + dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, + ): + """ + Setup for GPTQ, initializes the model + and other parameters, also initilializes the + compressible layers of model, and sets the device + + :param model: model to initialize for compression + :param dataloader: calibration data for GPTQ + """ + self.model = model + self.compressible_layers_ = self.compressible_layers() + self.model = self.model.model + self.layer_compressors_ = [] + + for idx, (name, layer) in enumerate(self.compressible_layers_.items()): + _LOGGER.info(f"Preparing {name} for compression") + if isinstance(self.sparsity, Dict): + layer_sparsity = self.sparsity[name] + elif isinstance(self.sparsity, List): + layer_sparsity = self.sparsity[idx] + else: # float + layer_sparsity = self.sparsity + args = self._pruning_arguments(layer_sparsity) + comp_cls = self._compression_class() + compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) + if not self.sequential_update: + # add all batch processing hooks before the forward pass + compressor.pre_compress() + self.layer_compressors_.append(compressor) + + @torch.no_grad() + def apply_compression( + self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None + ) -> Dict: + """ + Run GPTQ on the loaded model, using dataloader as calibration data + + :param dataloader: calibration data for GPTQ + """ + class_name = self.__class__.__name__.replace("PyTorch", "") + _LOGGER.info( + f"Running {class_name} calibration with " f"{len(dataloader)} samples..." + ) + if not self.sequential_update: + # in non-sequential mode we run one forward batch for all modules + run_calibration_forward(self.model, dataloader, mask_padding=True) + + num_layers = len(self.compressible_layers_) + for idx, layer_compressor in enumerate(self.layer_compressors_): + _LOGGER.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") + + # Prune/quantize using GPTQ + if self.sequential_update: + # in sequential mode we run one forward pass for each module we + # want to compress, this will be really slow but allows compression in + # earlier layers to affect later layers + layer_compressor.pre_compress() + _LOGGER.info(f"Calibrating {layer_compressor.name}...") + run_calibration_forward(self.model, dataloader, mask_padding=True) + layer_compressor.compress() + layer_compressor.post_compress() + layer_compressor.revert_layer_wrappers() + torch.cuda.empty_cache() + + def on_finalize(self, state: "State", **kwargs) -> bool: + """ + disable the quantization observers used by the OBCQ algorithm + + :param state: session state storing input model and calibration data + """ + if self.quantization_modifier_: + self.quantization_modifier_.finalize(state, **kwargs) + + return super(GPTQModifierPyTorch, self).on_finalize(state, **kwargs) + + def _pruning_arguments(self): + """ + Gather the parameters needed for root module compression in a dict + + :param sparsity: target sparsity + :return: dict of params for pruning + """ + return { + "blocksize": self.block_size, + "percdamp": self.dampening_frac, + } + + def _compression_class(self): + """ + :return: wrapper class used for root modules of this compression class + """ + return GPTQWrapper diff --git a/src/sparseml/modifiers/quantization/gptq/utils/__init__.py b/src/sparseml/modifiers/quantization/gptq/utils/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/quantization/gptq/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py new file mode 100644 index 00000000000..12d68596ee4 --- /dev/null +++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py @@ -0,0 +1,249 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from sparseml.modifiers.utils import SPARSITY_THRESHOLD +from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper + + +try: + import transformers +except ImportError as err: + transformers = None + transformers_err = err + +import logging +import math +from copy import copy + +import torch +import torch.nn as nn + + +__all__ = ["GPTQWrapper"] + +_LOGGER = logging.getLogger(__name__) + + +class GPTQWrapper(ModuleCompressionWrapper): + """ + Runs GPTQ on a single module that contains no sub-modules + + Lifecycle: + - add_batch + - fasterprune + - free + + :param name: name of module to run compression on + :param layer: module to run compression on + """ + + def __init__(self, name, layer): + super().__init__(name=name, layer=layer) + + # for Hessian calculation + self.register_buffer( + "H", torch.zeros((self.columns, self.columns), device=self.dev) + ) + + def add_batch(self, inp: torch.Tensor, out: torch.Tensor): + """ + Add a batch of layer input and output data to the Hessian calculation + + :param inp: tensor containing layer input + :param out: tensor containing layer output + """ + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()).to(self.dev) + + def fasterprune( + self, + blocksize: int = 128, + percdamp: float = 0.01, + ): + """ + Run pruning and quantization(if applicable) on the layer up to the target + sparsity value. + + :param blocksize: Number of columns to compress in one pass + :param percdamp: Amount of dampening to apply to H, as a fraction of the + diagonal norm + """ + final_shape = self.layer.weight.shape + final_dtype = self.layer.weight.dtype + W = self.layer.weight.data.clone() + from sparseml.pytorch.utils.helpers import tensor_sparsity + + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + dead = torch.diag(self.H) == 0 + self.H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros(self.rows, device=self.dev) + + damp = percdamp * torch.mean(torch.diag(self.H)) + diag = torch.arange(self.columns, device=self.dev) + self.H[diag, diag] += damp + self.H = torch.linalg.cholesky(self.H) + self.H = torch.cholesky_inverse(self.H) + self.H = torch.linalg.cholesky(self.H, upper=True) + Hinv = self.H + + sparsity = tensor_sparsity(W) + mask = ( + torch.where( + W == 0, + torch.tensor(1, dtype=torch.bool), + torch.tensor(0, dtype=torch.bool), + ) + if sparsity >= SPARSITY_THRESHOLD + else None + ) + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if sparsity >= SPARSITY_THRESHOLD: + tmp = ( + (~mask[:, i1:i2]) + * W1**2 + / (torch.diag(Hinv1).reshape((1, -1))) ** 2 + ) + thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] + mask1 = tmp <= thresh + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + if sparsity >= SPARSITY_THRESHOLD: + q[mask1[:, i]] = 0 + + if hasattr(self.layer, "weight_fake_quant"): + scale = self.layer.weight_fake_quant.scale + zero_point = self.layer.weight_fake_quant.zero_point + dtype = self.layer.weight_fake_quant.dtype + qscheme = self.layer.weight_fake_quant.qscheme + if qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]: + q = torch.quantize_per_tensor(q, scale, zero_point, dtype) + else: + q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype) + q = torch.dequantize(q) + elif hasattr(self.layer, "quantization_scheme"): + quant_scheme = self.layer.quantization_scheme + if quant_scheme.weights is not None: + scale = self.layer.weight_scale + zero_point = self.layer.weight_zero_point + from compressed_tensors.quantization import QuantizationStrategy + from compressed_tensors.quantization.lifecycle.forward import ( + fake_quantize, + ) + + strategy = quant_scheme.weights.strategy + + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + self.layer.quantization_scheme.weights, + ) + elif strategy == QuantizationStrategy.CHANNEL: + # TODO: for channelwise why isn't this just a 1d tensor? + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + quant_scheme.weights, + ) + else: # strategy == QuantizationStrategy.GROUP + # TODO: for grouped quantization its always 3d but the last + # dim is always 1. Can we just make it 2d instead and avoid? + scale = scale[:, :, 0] + zero_point = zero_point[:, :, 0] + + # get the group index for the current column + column_idx = i1 + i + input_dim_group = ( + column_idx // quant_scheme.weights.group_size + ) + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(quant_scheme.weights) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, input_dim_group], + zero_point[:, input_dim_group], + altered_qargs, + ) + + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + W[:, i1:i2] = Q1 + Losses += torch.sum(Losses1, 1) / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + _LOGGER.info("time %.2f" % (time.time() - tick)) + _LOGGER.info("error %.2f" % torch.sum(Losses).item()) + + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.reshape(final_shape).to(final_dtype) + + # This is a bit hacky, but FSDP updates only work if we change the weight in + # place, clone() or direct assignment won't work + self.layer.weight -= self.layer.weight + self.layer.weight += W + + def free(self): + """ + Free the Hessian memory after the layer is complete + """ + delattr(self, "H") + super().free() diff --git a/src/sparseml/modifiers/utils/__init__.py b/src/sparseml/modifiers/utils/__init__.py index 0c44f887a47..39d1132f697 100644 --- a/src/sparseml/modifiers/utils/__init__.py +++ b/src/sparseml/modifiers/utils/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# flake8: noqa + +from .constants import * diff --git a/src/sparseml/modifiers/utils/constants.py b/src/sparseml/modifiers/utils/constants.py new file mode 100644 index 00000000000..3801c2e9ea9 --- /dev/null +++ b/src/sparseml/modifiers/utils/constants.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +__all__ = ["SPARSITY_THRESHOLD"] + +SPARSITY_THRESHOLD: float = 0.05 diff --git a/src/sparseml/transformers/sparsification/obcq/README.md b/src/sparseml/transformers/sparsification/obcq/README.md index 28f686f5afd..50ef351c34c 100644 --- a/src/sparseml/transformers/sparsification/obcq/README.md +++ b/src/sparseml/transformers/sparsification/obcq/README.md @@ -214,10 +214,15 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: true - quantize: true percdamp: 0.01 mask_structure: "0:0" targets: ["re:model.layers.\\d*$"] + GPTQModifier: + block_size: 128 + sequential_update: False + percdamp: 0.01 + targets: ["re:model.layers.\\d+$"] + ``` ## How to Adapt a Recipe for a New Model You can modify the above recipe to perform one-shot quantization on other models, for example [Mistral](https://huggingface.co/docs/transformers/main/model_doc/mistral). @@ -260,10 +265,14 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: true - quantize: true percdamp: 0.01 mask_structure: "0:0" targets: ["re:model.layers.\\d*$"] + GPTQModifier: + block_size: 128 + sequential_update: False + percdamp: 0.01 + targets: ["re:model.layers.\\d+$"] ``` Save the recipe to a file named `recipe.yaml`. diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index e6adf24de62..f0a8c501a21 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -26,7 +26,6 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: True percdamp: 0.01 mask_structure: "0:0" targets: [ @@ -55,3 +54,8 @@ test_stage: "model.decoder.layers.22", "model.decoder.layers.23" ] + GPTQModifier: + block_size: 128 + sequential_update: False + percdamp: 0.01 + targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml index da265bf7d27..a6cd783df68 100644 --- a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml @@ -54,7 +54,6 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: True percdamp: 0.01 mask_structure: "0:0" targets: [ @@ -90,4 +89,9 @@ test_stage: "model.layers.29", "model.layers.30", "model.layers.31", - ] \ No newline at end of file + ] + GPTQModifier: + block_size: 128 + sequential_update: False + percdamp: 0.01 + targets: ["re:model.layers.\\d+$"] diff --git a/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml b/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml index 7800c9b9b09..85c8037e566 100644 --- a/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml @@ -21,6 +21,9 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: true - quantize: true percdamp: 0.01 - mask_structure: "0:0" \ No newline at end of file + mask_structure: "0:0" + GPTQModifier: + block_size: 128 + sequential_update: true + percdamp: 0.01 \ No newline at end of file diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 87558f5a625..b673c887c60 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -20,8 +20,9 @@ from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.modifiers.quantization import QuantizationModifier +from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch +from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory from tests.sparseml.pytorch.helpers import LinearNet from tests.testing_utils import requires_torch @@ -45,7 +46,6 @@ def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): kwargs = dict( sparsity=sparsity, block_size=128, - quantize=False, targets=targets, ) modifier = SparseGPTModifierPyTorch(**kwargs) @@ -65,9 +65,7 @@ def setUp(self): def test_successful_layerwise_recipe(self): sparsities = [0.5, 0.2] targets = ["seq.fc1", "seq.fc2"] - kwargs = dict( - sparsity=sparsities, block_size=128, quantize=False, targets=targets - ) + kwargs = dict(sparsity=sparsities, block_size=128, targets=targets) modifier = SparseGPTModifierPyTorch(**kwargs) modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} modifier.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet()) @@ -86,17 +84,19 @@ def setUp(self): setup_modifier_factory() def test_create_default_quant_modifier(self): - kwargs = dict(sparsity=0.5, block_size=128, quantize=True) + kwargs = dict(block_size=128) - modifier = SparseGPTModifierPyTorch(**kwargs) + modifier = GPTQModifierPyTorch(**kwargs) assert modifier.quantization_modifier_ is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - - should_be_default_quant_scheme = modifier.quantization_modifier_.scheme + assert isinstance(modifier.quantization_modifier_, vLLMQuantizationModifier) + default_config_group_name = "config_group_0" + should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ + default_config_group_name + ] self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8) assert not should_be_default_quant_scheme.input_activations.symmetric self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8) @@ -125,56 +125,62 @@ def test_set_quant_if_modifer_already_exists(self): modifier.initialize(testing_harness.get_state()) assert testing_harness.get_state().model.qat_active() - kwargs = dict(sparsity=0.5, block_size=128, quantize=False) - modifier = SparseGPTModifierPyTorch(**kwargs) - assert not modifier.quantize - modifier.on_initialize_structure(testing_harness.get_state()) - - # quantization modifier not owned by SparseGPT - assert modifier.quantization_modifier_ is None + kwargs = dict(block_size=128) + modifier = GPTQModifierPyTorch(**kwargs) + assert not modifier.quantization_modifier_ + modifier.on_initialize_structure(testing_harness.get_state()) # since quantization modifier is already applied, quantization must be set in - # OBCQ + # GPTQ assert modifier.quantize -class TestSetQuantInSparseGPT(unittest.TestCase): +class TestSetQuantInGPTQ(unittest.TestCase): def setUp(self): setup_modifier_factory() self.quant_kwargs = { - "scheme": { - "input_activations": { - "num_bits": 8, - "symmetric": False, - "strategy": "tensor", - "kwargs": {}, - }, - "weights": { - "num_bits": 4, - "symmetric": True, - "strategy": "channel", - "kwargs": {}, - }, + "config_groups": { + "config_group_0": { + "targets": ["Linear"], + "input_activations": { + "num_bits": 8, + "symmetric": False, + "strategy": "tensor", + "kwargs": {}, + }, + "weights": { + "num_bits": 4, + "symmetric": True, + "strategy": "channel", + "kwargs": {}, + }, + } } } - self.quant_config = {"QuantizationModifier": self.quant_kwargs} + self.quant_config = {"vLLMQuantizationModifier": self.quant_kwargs} - def test_set_quant_in_sparsegpt(self): - kwargs = dict(sparsity=0.5, block_size=128, quantize=self.quant_config) + def test_set_quant_in_gptq(self): + kwargs = dict(block_size=128, quantize=self.quant_config) - modifier = SparseGPTModifierPyTorch(**kwargs) + modifier = GPTQModifierPyTorch(**kwargs) assert modifier.quantization_modifier_ is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) + self.assertIsInstance(modifier.quantization_modifier_, vLLMQuantizationModifier) - dict_scheme = dict(modifier.quantization_modifier_.scheme) - self.assertEqual( - dict(dict_scheme["weights"]), self.quant_kwargs["scheme"]["weights"] + dict_scheme = dict(modifier.quantization_modifier_.config_groups) + self._check_config( + dict(dict_scheme["config_group_0"].weights), + self.quant_kwargs["config_groups"]["config_group_0"]["weights"], ) - self.assertEqual( - dict(dict_scheme["input_activations"]), - self.quant_kwargs["scheme"]["input_activations"], + self._check_config( + dict(dict_scheme["config_group_0"].input_activations), + self.quant_kwargs["config_groups"]["config_group_0"]["input_activations"], ) + + def _check_config(self, actual, expected): + self.assertEqual(actual["num_bits"], expected["num_bits"]) + self.assertEqual(actual["symmetric"], expected["symmetric"]) + self.assertEqual(actual["strategy"], expected["strategy"]) diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml index c5a55fa3284..409a168ecfd 100644 --- a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml +++ b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml @@ -25,9 +25,8 @@ test_stage: input_activations: null output_activations: null targets: ["Embedding"] - SparseGPTModifier: - sparsity: 0.0 + GPTQModifier: block_size: 128 sequential_update: False - quantize: True + percdamp: 0.01 targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml index 64a1f87b29d..68bf42e1bc5 100644 --- a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml +++ b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml @@ -12,9 +12,8 @@ test_stage: input_activations: null output_activations: null targets: ["Linear", "Embedding"] - SparseGPTModifier: - sparsity: 0.0 + GPTQModifier: block_size: 128 sequential_update: False - quantize: True + percdamp: 0.01 targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml index 8a94733242a..95edd24628e 100644 --- a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml +++ b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml @@ -31,9 +31,8 @@ test_stage: strategy: "tensor" input_activations: null output_activations: null - SparseGPTModifier: - sparsity: 0.0 + GPTQModifier: block_size: 128 sequential_update: False - quantize: True + percdamp: 0.01 targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml index e095a22912b..375dcfceb6c 100644 --- a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml +++ b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml @@ -28,9 +28,8 @@ test_stage: strategy: "tensor" input_activations: null output_activations: null - SparseGPTModifier: - sparsity: 0.0 + GPTQModifier: block_size: 128 sequential_update: False - quantize: True + percdamp: 0.01 targets: ["re:model.layers.\\d+$"] \ No newline at end of file diff --git a/tests/sparseml/transformers/finetune/test_alternate_recipe.yaml b/tests/sparseml/transformers/finetune/test_alternate_recipe.yaml index 411d6a41fed..877d6eae91e 100644 --- a/tests/sparseml/transformers/finetune/test_alternate_recipe.yaml +++ b/tests/sparseml/transformers/finetune/test_alternate_recipe.yaml @@ -4,7 +4,6 @@ test_oneshot_stage: sparsity: 0.7 block_size: 128 sequential_update: False - quantize: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/mask_structure/tiny_llama_mask_structure_preservation.yaml b/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/mask_structure/tiny_llama_mask_structure_preservation.yaml new file mode 100644 index 00000000000..98aadb22cf0 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/mask_structure/tiny_llama_mask_structure_preservation.yaml @@ -0,0 +1,9 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +initial_pruning_only_recipe: "tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml" +initial_sparsity: 0.5 +recipe_mask_structure: "2:4" +subsequent_prune_and_quant_recipe: "tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml" +final_sparsity: 0.7 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml b/tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml index 19d479e8666..9dde06bc309 100644 --- a/tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml +++ b/tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml @@ -4,7 +4,6 @@ test_stage: sparsity: 0.7 block_size: 128 sequential_update: True - quantize: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml b/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml new file mode 100644 index 00000000000..42538955b5e --- /dev/null +++ b/tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml @@ -0,0 +1,43 @@ +test_stage: + obcq_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.5 + mappings: [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] + ] + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + - model.layers.0.mlp.down_proj + - model.layers.1.mlp.down_proj + - model.layers.2.mlp.down_proj + - model.layers.3.mlp.down_proj + - model.layers.4.mlp.down_proj + - model.layers.5.mlp.down_proj + post_oneshot_calibration: True + scheme_overrides: + Embedding: + input_activations: null + weights: + num_bits: 8 + symmetric: False + SparseGPTModifier: + sparsity: 0.7 + block_size: 128 + sequential_update: False + percdamp: 0.01 + mask_structure: "0:0" + targets: [ + "model.layers.0", + ] + preserve_sparsity_mask: True + GPTQModifier: + sequential_update: False + dampening_frac: 0.01 + targets: [ + "model.layers.0", + ] + block_size: 128 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/recipes/quant.yaml b/tests/sparseml/transformers/obcq/recipes/quant.yaml index d229cba2923..0de93074d63 100644 --- a/tests/sparseml/transformers/obcq/recipes/quant.yaml +++ b/tests/sparseml/transformers/obcq/recipes/quant.yaml @@ -27,7 +27,6 @@ test_stage: sparsity: 0.0 block_size: 128 sequential_update: False - quantize: True percdamp: 0.01 mask_structure: "0:0" targets: [ @@ -37,4 +36,16 @@ test_stage: "model.layers.3", "model.layers.4", "model.layers.5" - ] \ No newline at end of file + ] + GPTQModifier: + block_size: 128 + sequential_update: False + percdamp: 0.01 + targets: [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + "model.layers.4", + "model.layers.5" + ] \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml b/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml index ddaf20b854f..7af58d32815 100644 --- a/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml +++ b/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml @@ -28,9 +28,20 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: True percdamp: 0.01 mask_structure: "0:0" + targets: [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + "model.layers.4", + "model.layers.5" + ] + GPTQModifier: + block_size: 128 + sequential_update: False + percdamp: 0.01 targets: [ "model.layers.0", "model.layers.1", diff --git a/tests/sparseml/transformers/obcq/recipes/sparse.yaml b/tests/sparseml/transformers/obcq/recipes/sparse.yaml index 3b03ff95f7e..70ffc7bf784 100644 --- a/tests/sparseml/transformers/obcq/recipes/sparse.yaml +++ b/tests/sparseml/transformers/obcq/recipes/sparse.yaml @@ -4,7 +4,6 @@ test_stage: sparsity: 0.3 block_size: 128 sequential_update: False - quantize: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml b/tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml new file mode 100644 index 00000000000..5f283b6095a --- /dev/null +++ b/tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml @@ -0,0 +1,11 @@ +test_stage: + obcq_modifiers: + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + percdamp: 0.01 + mask_structure: "2:4" + targets: [ + "model.layers.0", + ] \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml b/tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml index f513b7e0c4f..db88979eaf8 100644 --- a/tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml +++ b/tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml @@ -4,7 +4,6 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py new file mode 100644 index 00000000000..a068c391431 --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py @@ -0,0 +1,148 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from pathlib import Path + +import pytest + +import sparseml +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_torch + + +MASK_STRUCTURE_CONFIGS_DIRECTORY = ( + "tests/sparseml/transformers/obcq/obcq_configs/consec_runs/mask_structure" +) + + +def tensor_follows_mask_structure(tensor, mask: str = "2:4"): + """ + :param tensor: tensor to check + :param mask: mask structure to check for, in the format "n:m" + :return: True if the tensor follows the mask structure, False otherwise. + Note, some weights can incidentally be zero, so we check for + atleast n zeros in each chunk of size m + """ + import torch + + n, m = tuple(map(int, mask.split(":"))) + # Reshape the tensor into chunks of size m + tensor = tensor.view(-1, m) + + # Count the number of zeros in each chunk + zero_counts = (tensor == 0).sum(dim=1) + + # Check if the number of zeros in each chunk atleast n + # Greater than sign is needed as some weights can incidentally + # be zero + return torch.all(zero_counts >= n) + + +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY)) +class TestMaskStructurePreserved(unittest.TestCase): + """ + Tests that the mask structure is preserved across multiple runs of oneshot + initial model is pruned using a mask_structure, and then the pruned model + is further pruned and quantized. + """ + + model = None + initial_pruning_only_recipe = None + initial_sparsity = None + recipe_mask_structure = None + dataset = None + subsequent_prune_and_quant_recipe = None + final_sparsity = None + + def setUp(self) -> None: + import torch + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.output = "./oneshot_output" + self.output_first = Path(self.output) / "test_1" + self.output_second = Path(self.output) / "test_2" + + def test_mask_structure_preserved(self): + """ + Checks that the mask structure is preserved across runs of oneshot + between the initial pruning and the subsequent pruning + quantization + """ + import math + + import torch + + from sparseml.pytorch.model_load.helpers import get_session_model + from sparseml.pytorch.utils.helpers import tensor_sparsity + from sparseml.transformers import oneshot + from sparseml.utils.pytorch import qat_active + + tolerance = 1e-3 + num_calibration_samples = 16 + + oneshot( + model=self.model, + dataset=self.dataset, + num_calibration_samples=num_calibration_samples, + recipe=self.initial_pruning_only_recipe, + output_dir=self.output_first, + oneshot_device=self.device, + clear_sparse_session=False, + ) + first_tiny_model = get_session_model() + targetted_layer = first_tiny_model.model.layers[0].self_attn.k_proj + target_layer_sparsity = tensor_sparsity(targetted_layer.weight) + initial_mask = first_tiny_model.model.layers[0].self_attn.k_proj.weight == 0 + + # sparsity is as expected, i.e close to self.initial_sparsity + assert math.isclose( + target_layer_sparsity.item(), self.initial_sparsity, rel_tol=tolerance + ) + # mask structure is as expected, i.e same as self.recipe_mask_structure + assert tensor_follows_mask_structure(initial_mask, self.recipe_mask_structure) + + sparseml.reset_session() + + oneshot( + model=self.output_first, + dataset=self.dataset, + num_calibration_samples=num_calibration_samples, + recipe=self.subsequent_prune_and_quant_recipe, + output_dir=self.output_second, + oneshot_device=self.device, + clear_sparse_session=False, + ) + + second_tiny_model = get_session_model() + + # model is loaded + assert second_tiny_model is not None + + targetted_layer = second_tiny_model.model.layers[0].self_attn.k_proj.module + target_layer_sparsity = tensor_sparsity(targetted_layer.weight) + + # sparsity is as expected, i.e close to self.final_sparsity + assert math.isclose( + target_layer_sparsity.item(), self.final_sparsity, rel_tol=tolerance + ) + # qat should be active, second recipe has quantization + assert qat_active(second_tiny_model) + + # original mask structure is preserved, additional zeros are + # added on top of the initial mask + final_mask = targetted_layer.weight == 0 + assert torch.all(initial_mask <= final_mask) diff --git a/tests/sparseml/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/sparseml/transformers/oneshot/oneshot_configs/recipes/recipe.yaml index 6157f2ec114..c5bf782d494 100644 --- a/tests/sparseml/transformers/oneshot/oneshot_configs/recipes/recipe.yaml +++ b/tests/sparseml/transformers/oneshot/oneshot_configs/recipes/recipe.yaml @@ -4,7 +4,6 @@ test_stage: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml index d51a0ec420c..39f9d65762d 100644 --- a/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml +++ b/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -10,7 +10,6 @@ recipe: | sparsity: 0.5 block_size: 128 sequential_update: False - quantize: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml b/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml index 2dfc6553563..c6cc1376c15 100644 --- a/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml +++ b/tests/sparseml/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml @@ -11,7 +11,6 @@ recipe: | sparsity: 0.5 block_size: 128 sequential_update: False - quantize: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file