diff --git a/src/sparseml/experimental/__init__.py b/src/sparseml/experimental/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/experimental/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/experimental/sparsegpt/__init__.py b/src/sparseml/experimental/sparsegpt/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/experimental/sparsegpt/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/experimental/sparsegpt/dispatch.py b/src/sparseml/experimental/sparsegpt/dispatch.py index f9715a05752..4c1c80eeff2 100644 --- a/src/sparseml/experimental/sparsegpt/dispatch.py +++ b/src/sparseml/experimental/sparsegpt/dispatch.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -SUPPORTED_MODELS = ["opt", "mpt", "llama-2"] +SUPPORTED_MODELS = ["opt", "mpt", "llama"] def load_model(args, model_key: str = None, *gargs, **kwargs): model_key = _get_model_key(args) if model_key is None else model_key if model_key == "opt": - from opt import load_model as _load_model + from sparseml.experimental.sparsegpt.opt import load_model as _load_model elif model_key == "mpt": - from mpt import load_model as _load_model - elif model_key == "llama-2": - from llama2 import load_model as _load_model + from sparseml.experimental.sparsegpt.mpt import load_model as _load_model + elif model_key == "llama": + from sparseml.experimental.sparsegpt.llama2 import load_model as _load_model else: raise ValueError(f"Unrecognized model key. 
Supported: {SUPPORTED_MODELS}") return _load_model(args, *gargs, **kwargs) @@ -31,11 +31,11 @@ def load_model(args, model_key: str = None, *gargs, **kwargs): def load_data(args, model_key: str = None, *gargs, **kwargs): model_key = _get_model_key(args) if model_key is None else model_key if model_key == "opt": - from opt import load_data as _load_data + from sparseml.experimental.sparsegpt.opt import load_data as _load_data elif model_key == "mpt": - from mpt import load_data as _load_data - elif model_key == "llama-2": - from llama2 import load_data as _load_data + from sparseml.experimental.sparsegpt.mpt import load_data as _load_data + elif model_key == "llama": + from sparseml.experimental.sparsegpt.llama2 import load_data as _load_data else: raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}") return _load_data(args, *gargs, **kwargs) @@ -46,9 +46,9 @@ def evaluate_perplexity( ): model_key = _get_model_key(args) if model_key is None else model_key if model_key == "opt": - from opt import ppl_eval as _ppl_eval - elif model_key == "llama-2": - from llama2 import ppl_eval as _ppl_eval + from sparseml.experimental.sparsegpt.opt import ppl_eval as _ppl_eval + elif model_key == "llama": + from sparseml.experimental.sparsegpt.llama2 import ppl_eval as _ppl_eval else: raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}") return _ppl_eval(args, model, dataloader, dev, *gargs, **kwargs) @@ -57,11 +57,17 @@ def evaluate_perplexity( def prepare_sparsegpt(model, dataloader, args, model_key: str = None, **kwargs): model_key = _get_model_key(args) if model_key is None else model_key if model_key == "opt": - from opt import prepare_sparsegpt as _prepare_sparsegpt + from sparseml.experimental.sparsegpt.opt import ( + prepare_sparsegpt as _prepare_sparsegpt, + ) elif model_key == "mpt": - from mpt import prepare_sparsegpt as _prepare_sparsegpt - elif model_key == "llama-2": - from llama2 import prepare_sparsegpt as _prepare_sparsegpt + from sparseml.experimental.sparsegpt.mpt import ( + prepare_sparsegpt as _prepare_sparsegpt, + ) + elif model_key == "llama": + from sparseml.experimental.sparsegpt.llama2 import ( + prepare_sparsegpt as _prepare_sparsegpt, + ) else: raise ValueError(f"Unrecognized model key. Supported: {SUPPORTED_MODELS}") return _prepare_sparsegpt(model, dataloader, args, **kwargs) diff --git a/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py b/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py new file mode 100644 index 00000000000..cb821781628 --- /dev/null +++ b/src/sparseml/experimental/sparsegpt/examples/llama2/compare_obcq.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
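For context on the dispatch refactor above, a minimal routing sketch follows. The args object and its fields are placeholders for illustration; in practice the key is inferred by _get_model_key(args) (presumably from the model name/path), and passing model_key explicitly simply bypasses that inference.

from types import SimpleNamespace

from sparseml.experimental.sparsegpt.dispatch import load_data, load_model

# Placeholder args: only fields the loaders are expected to read are sketched here
args = SimpleNamespace(model="facebook/opt-1.3b", dataset="c4", seed=0, nsamples=128)

# model_key="opt" routes to sparseml.experimental.sparsegpt.opt
model = load_model(args, model_key="opt")
calibration_data, testloader, tokenizer = load_data(args, model_key="opt", seqlen=model.seqlen)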
+import torch + +from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model +from sparseml.experimental.sparsegpt.llama2 import load_data +from sparseml.experimental.sparsegpt.main import sequential +from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general +from sparseml.transformers.sparsification.obcq.obcq import one_shot +from sparseml.transformers.sparsification.obcq.utils.helpers import llama_forward + + +dataset = "open_platypus" +model_name = "/home/sadkins/ml-experiments/nlg-text_generation/" +model_name += "llama_chat-llama_7b_chat-base/dense/training" +sparsity = 0.5 +nbits = 8 +smooth_quant = 0 +observer_batches = 128 +nsamples = 128 +data_sequence_length = 2048 +sequential_hessian = 0 +experimental_recipe = "src/sparseml/experimental/sparsegpt/examples/llama2/recipes/" +experimental_recipe += "llama_recipe.yaml" +prod_recipe = "src/sparseml/transformers/sparsification/obcq/example_llama.yaml" +device = "cuda:0" +seed = 0 +prunen = 0 +prunem = 0 +percdamp = 0.01 +blocksize = 128 +ptq_only = 0 + + +class ExperimentalArgs: + model = model_name + dataset = dataset + data_sequence_length = data_sequence_length + sequential_hessian_within_layer = sequential_hessian + recipe = experimental_recipe + sparsity = sparsity + wbits = nbits + observer_batches = observer_batches + nsamples = nsamples + smoothquant = smooth_quant + seed = seed + prunen = prunen + prunem = prunem + percdamp = percdamp + blocksize = blocksize + ptq_only = ptq_only + + +class ProdArgs: + model = model_name + dataset = dataset + nsamples = nsamples + device = device + recipe = prod_recipe + eval = False + save = False + + +def run_experimental_obcq(experimental_args): + model = load_model(experimental_args) + calibration_data, _, _ = load_data(experimental_args, data_sequence_length) + sequential(model, calibration_data, device, experimental_args) + + del calibration_data + return model + + +if __name__ == "__main__": + experimental_args = ExperimentalArgs() + exp_model = run_experimental_obcq(experimental_args) + _, testloader, _ = load_data(experimental_args, data_sequence_length) + exp_perplexity = evaluate_perplexity( + experimental_args, exp_model, testloader, device, max_samples_per_iteration=8 + ) + del testloader + del exp_model + torch.cuda.empty_cache() + + prod_args = ProdArgs() + prod_model = one_shot( + model_path=prod_args.model, + dataset_name=prod_args.dataset, + num_samples=prod_args.nsamples, + device=prod_args.device, + recipe_file=prod_args.recipe, + ) + torch.cuda.empty_cache() + + _, testloader, _ = load_data(experimental_args, data_sequence_length) + prod_perplexity = ppl_eval_general( + llama_forward, prod_model, testloader, device, max_samples_per_iteration=8 + ) + print( + f"Experimental Perplexity: {exp_perplexity}, " + f"Production Perplexity: {prod_perplexity}" + ) diff --git a/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml b/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml new file mode 100644 index 00000000000..41513e49946 --- /dev/null +++ b/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml @@ -0,0 +1,52 @@ +# Quantization variables +observer_freeze_epoch: 1 +bn_freeze_epoch: 1 +qat_start_epoch: 0 + +quantization_modifiers: + - !QuantizationModifier + start_epoch: eval(qat_start_epoch) + disable_quantization_observer_epoch: eval(observer_freeze_epoch) + freeze_bn_stats_epoch: eval(bn_freeze_epoch) + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - 
SiLUActivation + - model.layers.0.mlp.down_proj + - model.layers.1.mlp.down_proj + - model.layers.2.mlp.down_proj + - model.layers.3.mlp.down_proj + - model.layers.4.mlp.down_proj + - model.layers.5.mlp.down_proj + - model.layers.6.mlp.down_proj + - model.layers.7.mlp.down_proj + - model.layers.8.mlp.down_proj + - model.layers.9.mlp.down_proj + - model.layers.10.mlp.down_proj + - model.layers.11.mlp.down_proj + - model.layers.12.mlp.down_proj + - model.layers.13.mlp.down_proj + - model.layers.14.mlp.down_proj + - model.layers.15.mlp.down_proj + - model.layers.16.mlp.down_proj + - model.layers.17.mlp.down_proj + - model.layers.18.mlp.down_proj + - model.layers.19.mlp.down_proj + - model.layers.20.mlp.down_proj + - model.layers.21.mlp.down_proj + - model.layers.22.mlp.down_proj + - model.layers.23.mlp.down_proj + - model.layers.24.mlp.down_proj + - model.layers.25.mlp.down_proj + - model.layers.26.mlp.down_proj + - model.layers.27.mlp.down_proj + - model.layers.28.mlp.down_proj + - model.layers.29.mlp.down_proj + - model.layers.30.mlp.down_proj + - model.layers.31.mlp.down_proj + scheme_overrides: + Embedding: + input_activations: null + weights: + num_bits: 8 + symmetric: False \ No newline at end of file diff --git a/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py b/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py new file mode 100644 index 00000000000..5f773435856 --- /dev/null +++ b/src/sparseml/experimental/sparsegpt/examples/opt/compare_obcq.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
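The 32 down_proj entries in the recipe above cover one per decoder layer of the 7B Llama model used in the example (an assumption about the architecture size); generated programmatically, the ignore list reduces to:

# Sketch only: num_layers = 32 assumes a Llama-2-7B-sized decoder stack
num_layers = 32
ignore = ["LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation"] + [
    f"model.layers.{i}.mlp.down_proj" for i in range(num_layers)
]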
+ +import torch + +from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model +from sparseml.experimental.sparsegpt.main import sequential +from sparseml.experimental.sparsegpt.opt import load_data +from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general +from sparseml.transformers.sparsification.obcq.obcq import one_shot +from sparseml.transformers.sparsification.obcq.utils.helpers import opt_forward + + +dataset = "c4" +model_name = "facebook/opt-1.3b" +sparsity = 0.5 +nbits = 8 +smooth_quant = 0 +observer_batches = 128 +nsamples = 128 +data_sequence_length = 2048 +sequential_hessian = 0 +experimental_recipe = "src/sparseml/experimental/sparsegpt/examples/opt/recipes/" +experimental_recipe += "opt-1.3b-opt_pretrain-pruned50_quantW8A8.md" +prod_recipe = "src/sparseml/transformers/sparsification/obcq/example.yaml" +device = "cuda:0" +seed = 0 +prunen = 0 +prunem = 0 +percdamp = 0.01 +blocksize = 128 +ptq_only = 0 + + +class ExperimentalArgs: + model = model_name + dataset = dataset + data_sequence_length = data_sequence_length + sequential_hessian_within_layer = sequential_hessian + recipe = experimental_recipe + sparsity = sparsity + wbits = nbits + observer_batches = observer_batches + nsamples = nsamples + smoothquant = smooth_quant + seed = seed + prunen = prunen + prunem = prunem + percdamp = percdamp + blocksize = blocksize + ptq_only = ptq_only + + +class ProdArgs: + model = model_name + dataset = dataset + nsamples = nsamples + device = device + recipe = prod_recipe + save = False + + +def run_experimental_obcq(experimental_args): + model = load_model(experimental_args) + calibration_data, _, _ = load_data(experimental_args, data_sequence_length) + sequential(model, calibration_data, device, experimental_args) + + del calibration_data + return model + + +if __name__ == "__main__": + experimental_args = ExperimentalArgs() + exp_model = run_experimental_obcq(experimental_args) + experimental_args.dataset = "wikitext2" + _, testloader, _ = load_data(experimental_args, data_sequence_length) + exp_perplexity = evaluate_perplexity( + experimental_args, exp_model, testloader, device, max_samples_per_iteration=8 + ) + + del testloader + del exp_model + torch.cuda.empty_cache() + + prod_args = ProdArgs() + prod_model = one_shot( + model_path=prod_args.model, + dataset_name=prod_args.dataset, + num_samples=prod_args.nsamples, + device=prod_args.device, + recipe_file=prod_args.recipe, + ) + experimental_args.dataset = "wikitext2" + _, testloader, _ = load_data(experimental_args, data_sequence_length) + prod_perplexity = ppl_eval_general( + opt_forward, prod_model, testloader, device, max_samples_per_iteration=8 + ) + print( + f"Experimental Perplexity: {exp_perplexity}, " + f"Production Perplexity: {prod_perplexity}" + ) diff --git a/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh b/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh index 30c67e8c40e..e476e055d27 100755 --- a/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh +++ b/src/sparseml/experimental/sparsegpt/examples/opt/scripts/prune_quantize.opt.0.sh @@ -2,11 +2,11 @@ export CUDA_VISIBLE_DEVICES=0 -ROOT=$HOME/src/neuralmagic/sparseml/src/sparseml/experimental/sparsegpt +ROOT=$HOME/sparseml/src/sparseml/experimental/sparsegpt DATASET=c4 -RECIPE_DIR=$ROOT/recipes +RECIPE_DIR=$ROOT/examples/opt/recipes RECIPE_NAME=opt-1.3b-opt_pretrain-pruned50_quantW8A8 SRC_MODEL_ORG=facebook diff --git 
a/src/sparseml/experimental/sparsegpt/layer_compressor.py b/src/sparseml/experimental/sparsegpt/layer_compressor.py index 0ec8dee0235..df70ed75b00 100644 --- a/src/sparseml/experimental/sparsegpt/layer_compressor.py +++ b/src/sparseml/experimental/sparsegpt/layer_compressor.py @@ -18,10 +18,13 @@ import torch import torch.nn as nn -from quant import WeightFakeQuantizer +from sparseml.experimental.sparsegpt.quant import WeightFakeQuantizer from sparseml.experimental.sparsegpt.sparsegpt import SparseGPT +DEFAULT_WBITS = 16 + + class BaseCompressor: def __init__(self, model): self.model = model diff --git a/src/sparseml/experimental/sparsegpt/llama.py b/src/sparseml/experimental/sparsegpt/llama.py index 942bd2fe49c..bcc8f18140d 100644 --- a/src/sparseml/experimental/sparsegpt/llama.py +++ b/src/sparseml/experimental/sparsegpt/llama.py @@ -26,7 +26,6 @@ repeat_kv, ) -from layer_compressor import BaseCompressor, LayerCompressor from llmfoundry import ( COMPOSER_MODEL_REGISTRY, build_finetuning_dataloader, @@ -36,7 +35,11 @@ from llmfoundry.utils.builders import build_tokenizer from model_preprocessor import QuantizationModelPreprocessor from omegaconf import OmegaConf as om -from quant import ( +from sparseml.experimental.sparsegpt.layer_compressor import ( + BaseCompressor, + LayerCompressor, +) +from sparseml.experimental.sparsegpt.quant import ( MatMulLeftInput_PV, MatMulLeftInput_QK, MatMulOutput_PV, @@ -45,7 +48,7 @@ MatMulRightInput_QK, QuantizableMatMul, ) -from sequential import SequentialSparseGPT +from sparseml.experimental.sparsegpt.sequential import SequentialSparseGPT class SequentialSparseGPT_LLAMA(SequentialSparseGPT): diff --git a/src/sparseml/experimental/sparsegpt/llama2.py b/src/sparseml/experimental/sparsegpt/llama2.py index a26231fb482..b6ced4e7e9a 100644 --- a/src/sparseml/experimental/sparsegpt/llama2.py +++ b/src/sparseml/experimental/sparsegpt/llama2.py @@ -85,7 +85,8 @@ def load_model(args): model = LlamaForCausalLM.from_pretrained(model, torch_dtype="auto") model.eval() seqlen = model.config.max_position_embeddings - return model, seqlen + model.seqlen = seqlen + return model def load_data(args, seqlen, split=0.1): diff --git a/src/sparseml/experimental/sparsegpt/main.py b/src/sparseml/experimental/sparsegpt/main.py index d64c80c5805..1fe27a8ddcd 100644 --- a/src/sparseml/experimental/sparsegpt/main.py +++ b/src/sparseml/experimental/sparsegpt/main.py @@ -158,7 +158,8 @@ def _save(model, tokenizer, save_path): wandb.init(config=args) print("Load model", flush=True) - model, seqlen = load_model(args) + model = load_model(args) + seqlen = model.seqlen print("Load data", flush=True) dataloader, testloader, tokenizer = load_data(args, None, seqlen) diff --git a/src/sparseml/experimental/sparsegpt/mpt.py b/src/sparseml/experimental/sparsegpt/mpt.py index ee836cb801c..a109b9c9573 100644 --- a/src/sparseml/experimental/sparsegpt/mpt.py +++ b/src/sparseml/experimental/sparsegpt/mpt.py @@ -21,7 +21,6 @@ import torch.nn as nn from einops import rearrange -from layer_compressor import BaseCompressor, LayerCompressor from llmfoundry import ( COMPOSER_MODEL_REGISTRY, build_finetuning_dataloader, @@ -29,9 +28,17 @@ ) from llmfoundry.data.text_data import build_text_dataloader from llmfoundry.utils.builders import build_tokenizer -from model_preprocessor import ModelPreprocessor, QuantizationModelPreprocessor from omegaconf import OmegaConf as om -from quant import ( +from sequential import SequentialSparseGPT +from sparseml.experimental.sparsegpt.layer_compressor import ( + 
BaseCompressor, + LayerCompressor, +) +from sparseml.experimental.sparsegpt.model_preprocessor import ( + ModelPreprocessor, + QuantizationModelPreprocessor, +) +from sparseml.experimental.sparsegpt.quant import ( MatMulLeftInput_PV, MatMulLeftInput_QK, MatMulOutput_PV, @@ -40,7 +47,6 @@ MatMulRightInput_QK, QuantizableMatMul, ) -from sequential import SequentialSparseGPT class SequentialSparseGPT_MPT(SequentialSparseGPT): diff --git a/src/sparseml/experimental/sparsegpt/opt.py b/src/sparseml/experimental/sparsegpt/opt.py index 9b1f0e83aeb..a78429f14a0 100644 --- a/src/sparseml/experimental/sparsegpt/opt.py +++ b/src/sparseml/experimental/sparsegpt/opt.py @@ -163,7 +163,8 @@ def skip(*args, **kwargs): model = OPTForCausalLM.from_pretrained(model, torch_dtype="auto") seqlen = model.config.max_position_embeddings - return model, seqlen + model.seqlen = seqlen + return model def load_data(args, seqlen, split=0.1): diff --git a/src/sparseml/experimental/sparsegpt/sequential.py b/src/sparseml/experimental/sparsegpt/sequential.py index ab0aa69b8f3..685ddfbcdcd 100644 --- a/src/sparseml/experimental/sparsegpt/sequential.py +++ b/src/sparseml/experimental/sparsegpt/sequential.py @@ -28,13 +28,11 @@ def __init__( recipe: Optional[str] = None, model_preprocessors: Optional[List[ModelPreprocessor]] = None, bottom_compressor: Optional[LayerCompressor] = None, - head_compressor: Optional[LayerCompressor] = None, args=None, ): self.model = model self.model_preprocessors = model_preprocessors self.bottom_compressor = bottom_compressor - self.head_compressor = head_compressor self.recipe = recipe self.manager = None self.compressible_layers = self.compressible_layers() @@ -105,10 +103,6 @@ def compress(self, dev: str = "cuda:0", **kwargs): ) accum_kwargs.update(layer_kwargs) - # Step 2: Prune/quantize head - if self.head_compressor is not None: - self.model, extras = self.head_compressor.compress(dev=dev, **accum_kwargs) - return self.model, {} def post_compress(self, dev: str = "cuda:0", **kwargs): diff --git a/src/sparseml/experimental/sparsegpt/utils.py b/src/sparseml/experimental/sparsegpt/utils.py index 191c97d3369..6b5d9f78a93 100644 --- a/src/sparseml/experimental/sparsegpt/utils.py +++ b/src/sparseml/experimental/sparsegpt/utils.py @@ -232,6 +232,7 @@ def ppl_eval_general( ppl = torch.exp(neg_log_likelihood / number_tokens) print(f"Perplexity: {ppl.item():3f}") + return ppl.item() def get_wikitext2(nsamples, seed, seqlen, model): diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py index de33872de9b..adf250cf344 100644 --- a/src/sparseml/modifiers/__init__.py +++ b/src/sparseml/modifiers/__init__.py @@ -15,4 +15,6 @@ # flake8: noqa from .distillation import * +from .obcq import * from .pruning import * +from .quantization import * diff --git a/src/sparseml/modifiers/obcq/__init__.py b/src/sparseml/modifiers/obcq/__init__.py new file mode 100644 index 00000000000..22bafe8ce14 --- /dev/null +++ b/src/sparseml/modifiers/obcq/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * +from .pytorch import * diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py new file mode 100644 index 00000000000..fe66f61e505 --- /dev/null +++ b/src/sparseml/modifiers/obcq/base.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Union + +from sparseml.core import Modifier +from sparseml.core.state import State +from sparseml.utils import ALL_TOKEN + + +__all__ = ["SparseGPTModifier"] + + +class SparseGPTModifier(Modifier): + """ + Modifier for applying the one-shot OBCQ algorithm to a model + + Life-cycle: + - initialze + - compress + - finalize + + :param sparsity: Sparsity to compress model to + :param block_size: Used to determine number of columns to compress in one pass + :param quantize: Whether or not model is quantized (affects layer names) + :param dampening_frac: Amount of dampening to apply to H, as a fraction of the + diagonal norm + :param sequential_update: Whether or not to update weights sequentially by layer, + True saves on GPU memory + :param prunen: N for N:M pruning + :param prunem: M for N:M pruning + :param compress_layers: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model + :param target_ids: list of keys in model output to cache + :param layer_prefix: name of model attribute that contains the list of layers, i.e. + model.decoder for OPT or just model for Llama + """ + + sparsity: float + block_size: int + quantize: bool + dampening_frac: Optional[float] = 0.01 + sequential_update: Optional[bool] = True + prunen: Optional[int] = 0 + prunem: Optional[int] = 0 + compress_layers: Union[str, List[str], None] = ALL_TOKEN + target_ids: Optional[List[str]] = None + layer_prefix: Optional[str] = None + + def on_initialize_structure(self, state: "State", **kwargs): + pass # nothing needed for this modifier diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py new file mode 100644 index 00000000000..3a043df9707 --- /dev/null +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
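As a rough illustration of how the SparseGPTModifier fields documented above fit together, the modifier can be constructed directly; the values below are placeholders rather than recommended settings, and target_ids/layer_prefix follow the docstring's OPT-vs-Llama note as assumptions:

from sparseml.modifiers.obcq import SparseGPTModifier

modifier = SparseGPTModifier(
    sparsity=0.5,
    block_size=128,
    quantize=True,
    dampening_frac=0.01,
    sequential_update=True,
    prunen=0,
    prunem=0,
    target_ids=["attention_mask"],  # model outputs to cache between layers
    layer_prefix=None,              # e.g. "decoder" for OPT, None for Llama
)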
+ + +import logging +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch.nn import Module + +from sparseml.core.model import ModifiableModel +from sparseml.core.state import State +from sparseml.modifiers.obcq.base import SparseGPTModifier +from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs +from sparseml.modifiers.obcq.utils.layer_compressor import LayerCompressor + + +_LOGGER = logging.getLogger(__name__) + + +class SparseGPTModifierPyTorch(SparseGPTModifier): + """ + Pytorch implementation of SparseGPT + + Lifecycle: + - on_initialize + - initialize_obcq + - compressible_layers + - apply_obcq + - compress_bottom + - LayerCompressor.compress + - on_finalize + + :param model: Pytorch model to perform OBCQ on, in-place + """ + + model: Any = None + compressible_layers_: List = None + device_: str = "cuda:0" + finalization_kwargs_: Dict = None + + def compressible_layers(self) -> List[Module]: + """ + Retrieves the modules corresponding to a list of compressible layer names + + :return: list of Pytorch modules to compress + """ + compressible_dict = self.model.get_layers(self.compress_layers) + return [v for _, v in compressible_dict.items()] + + def on_initialize(self, state: "State", **kwargs) -> bool: + """ + Initialize and run the OBCQ algorithm on the current state + + :param state: session state storing input model and calibration data + """ + self.finalization_kwargs_ = {} + modifiable_model = state.model + calibration_dataloader = state.data.calib + device = state.hardware.device + + self.initialize_obcq(modifiable_model, device) + extras = self.apply_obcq(calibration_dataloader) + self.finalization_kwargs_.update(extras) + + return True + + def initialize_obcq( + self, + model: "ModifiableModel", + device: Optional[str] = "cuda:0", + ): + """ + Setup for SparseGPT, initialize the the compressible layers of model, and set + the device + + :param model: PyTorch model to sparsify + :param device: device to run sparsification on, preferably a GPU + """ + self.model = model + self.compressible_layers_ = self.compressible_layers() + self.model = self.model.model + self._set_device(device) + + @torch.no_grad() + def apply_obcq( + self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None + ) -> Dict: + """ + Run OBCQ on the loaded model, using dataloader as calibration data + + :param dataloader: calibration data for OBCQ + :return: compression outputs used for finalization + """ + accum_kwargs = {"dataloader": dataloader} + + # Step 0: Pass the calibration data through the (compressed) bottom part of the + # network, capturing the outputs which will become the inputs to the first + # decoder layer. 
Also return attention_mask as part of kwargs + extras = self.compress_bottom( + dev=self.device_, + target_ids=self.target_ids, + layer_prefix=self.layer_prefix, + **accum_kwargs, + ) + accum_kwargs.update(extras) + + # Step 1: Sequentially prune/quantize decoder layers + inputs = None + num_layers = len(self.compressible_layers_) + for idx, layer in enumerate(self.compressible_layers_): + if "outputs" not in accum_kwargs: + raise RuntimeError( + "The 'outputs' key is expected but not found from the " + "return of the bottom compressor" + ) + inputs = accum_kwargs["outputs"] + _LOGGER.info(f"\n===== Compressing layer {idx}/{num_layers-1} =====") + args = { + "sparsity": self.sparsity, + "prunen": self.prunen, + "prunem": self.prunem, + "blocksize": self.block_size, + "percdamp": self.dampening_frac, + "sequential_update": self.sequential_update, + "quantize": self.quantize, + } + layer_compressor = LayerCompressor(self.model, layer, idx, inputs, args) + + # Prune/quantize using SparseGPT + layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs) + accum_kwargs.update(layer_kwargs) + + return extras + + def on_finalize(self, state: "State", **kwargs) -> bool: + """ + disable the observers used by the OBCQ algorithm and set kv-cache configuration + + :param state: un-used, for matching spec of Modifier base class + """ + use_cache = self.finalization_kwargs_.get("use_cache", False) + self.model.apply(torch.quantization.disable_observer) + self.model.config.use_cache = use_cache + + return True + + def compress_bottom( + self, + dataloader: List = None, + nsamples: int = None, + dev: str = "cuda:0", + target_ids: List[str] = None, + layer_prefix: str = None, + ) -> Dict: + """ + Runs calibration data through the bottom part of the network (everything up + to the first decoder layer) and return the captured outputs + + :param dataloader: calibration data to pass through the model + :nsamples: number of samples to use for calibration, or None to use it all + :dev: device to use + :return: outputs from bottom part of network, attention mask, and kv-cache state + """ + cached_inputs = cache_attention_inputs( + self.model, dataloader, dev, nsamples, target_ids, layer_prefix + ) + + outputs = cached_inputs.pop("inputs") + outputs = [o[0] for o in outputs] + cached_inputs.update({"outputs": outputs}) + return cached_inputs + + def _set_device(self, device: str): + if "cuda" in device and not torch.cuda.is_available(): + self.device_ = "cpu" + else: + self.device_ = device diff --git a/src/sparseml/modifiers/obcq/utils/__init__.py b/src/sparseml/modifiers/obcq/utils/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/modifiers/obcq/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
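Outside of the full session/State flow, the lifecycle described in the docstrings above can be driven directly; modifiable_model (a ModifiableModel wrapping the torch module) and calibration_dataloader are placeholders assumed to be prepared elsewhere:

from sparseml.modifiers.obcq import SparseGPTModifierPyTorch

modifier = SparseGPTModifierPyTorch(
    sparsity=0.5,
    block_size=128,
    quantize=False,
    target_ids=["attention_mask"],  # keys cached by compress_bottom
    layer_prefix=None,
)

# Placeholders: modifiable_model wraps the torch model, calibration_dataloader
# yields tokenized calibration samples
modifier.initialize_obcq(modifiable_model, device="cuda:0")
extras = modifier.apply_obcq(calibration_dataloader)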
diff --git a/src/sparseml/modifiers/obcq/utils/helpers.py b/src/sparseml/modifiers/obcq/utils/helpers.py new file mode 100644 index 00000000000..455e37986d8 --- /dev/null +++ b/src/sparseml/modifiers/obcq/utils/helpers.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from math import ceil + +import torch + + +_LOGGER = logging.getLogger(__name__) + + +class Catcher(torch.nn.Module): + def __init__(self, module, target_keys): + super().__init__() + self.module = module + self.cache = {key: [] for key in target_keys} + self.target_keys = target_keys + self.cache["inputs"] = [] + + def forward(self, *args, **kwargs): + self.cache["inputs"].append(args) + for key in self.target_keys: + self.cache[key].append(kwargs[key]) + raise ValueError + + def get_cache(self): + return self.cache + + +def replace_module(model, old_module, new_module): + for module_name, module in model.named_modules(): + if old_module == module: + break + + current_module = model + module_name = module_name.split(".") + for child_module in module_name[:-1]: + current_module = getattr(current_module, child_module) + setattr(current_module, module_name[-1], new_module) + + +def catch(model, attention_layer, target_keys, data_loader, nsamples): + catcher_module = Catcher(attention_layer, target_keys) + replace_module(model, attention_layer, catcher_module) + device = next(attention_layer.parameters()).device + for input_id, inp in enumerate(data_loader): + if nsamples is not None and input_id == nsamples: + break + try: + if isinstance(inp, tuple): + inp = inp[0] + model(inp.to(device), use_cache=False) + except ValueError: + pass + replace_module(model, catcher_module, attention_layer) + return catcher_module.get_cache() + + +def execute_offloaded_module( + module, + buffer, + dev, + nsamples=None, + overwrite_buffer=True, + cached_inputs=None, + **kwargs, +): + module.to(dev) + if not overwrite_buffer: + new_buffer = [] + for input_index, inp in enumerate(buffer): + if nsamples is not None and input_index == nsamples: + break + if cached_inputs is None: + module_kwargs = kwargs + else: + module_kwargs = { + key: cached_inputs[key][input_index] for key in cached_inputs + } + module_kwargs.update(kwargs) + if isinstance(inp, tuple): + inp = inp[0] + output = module(inp.to(dev), **module_kwargs) + if overwrite_buffer: + buffer[input_index] = output + else: + new_buffer.append(output) + + module.cpu() + torch.cuda.empty_cache() + if overwrite_buffer: + return buffer + else: + del buffer + torch.cuda.empty_cache() + return new_buffer + + +def cache_attention_inputs( + model, dataloader, device, nsamples, target_ids, layer_prefix +): + if layer_prefix: + embed_tokens = getattr(model.model, layer_prefix).embed_tokens + first_layer = getattr(model.model, layer_prefix).layers[0] + else: + embed_tokens = model.model.embed_tokens + first_layer = model.model.layers[0] + embed_tokens.to(device) + first_layer.to(device) + 
cached_inputs = catch( + model, + first_layer, + target_ids, # ["attention_mask"], + dataloader, + nsamples, + ) + embed_tokens.cpu() + first_layer.cpu() + torch.cuda.empty_cache() + return cached_inputs + + +@torch.no_grad() +def ppl_eval_general( + eval_logits, model, dataloader, dev, nsamples=None, max_samples_per_iteration=128 +): + _LOGGER.info("Evaluating perplexity...") + + if nsamples is None: + nsamples = len(dataloader) + + number_iterations = int(ceil(nsamples / max_samples_per_iteration)) + neg_log_likelihood = 0.0 + number_tokens = 0 + for iteration in range(number_iterations): + if iteration < number_iterations - 1: + samples = dataloader[ + iteration + * max_samples_per_iteration : (iteration + 1) + * max_samples_per_iteration + ] + else: + samples = dataloader[iteration * max_samples_per_iteration :] + + logits = eval_logits(model, samples, dev) + + vocabulary_size = logits[0].shape[-1] + logits = [logit[:, :-1, :].view(-1, vocabulary_size) for logit in logits] + logits = torch.concatenate(logits, dim=0).contiguous().to(torch.float32) + + labels = [sample[:, 1:].view(-1) for sample in samples] + labels = torch.concatenate(labels, dim=0).to(dev) + neg_log_likelihood += torch.nn.functional.cross_entropy( + logits, + labels, + reduction="sum", + ) + + number_tokens += labels.numel() + _LOGGER.info(torch.exp(neg_log_likelihood / number_tokens)) + + ppl = torch.exp(neg_log_likelihood / number_tokens) + _LOGGER.info(f"Perplexity: {ppl.item():3f}") + + return ppl.item() diff --git a/src/sparseml/modifiers/obcq/utils/layer_compressor.py b/src/sparseml/modifiers/obcq/utils/layer_compressor.py new file mode 100644 index 00000000000..7a0197d07d2 --- /dev/null +++ b/src/sparseml/modifiers/obcq/utils/layer_compressor.py @@ -0,0 +1,246 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
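ppl_eval_general above expects a forward function that returns per-sample logits; the example scripts in this PR call it as sketched below (model and testloader are placeholders, opt_forward is the helper those scripts already import):

from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
from sparseml.transformers.sparsification.obcq.utils.helpers import opt_forward

# model and testloader are placeholders; testloader is a list of token-id tensors
ppl = ppl_eval_general(
    opt_forward, model, testloader, "cuda:0", max_samples_per_iteration=8
)
print(f"perplexity: {ppl:.3f}")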
+ +import inspect +import logging +from typing import Dict, List + +import torch +import torch.nn as nn +from torch.nn import Module + +from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT +from sparseml.pytorch.utils.helpers import get_dependency_order + + +_LOGGER = logging.getLogger(__name__) + + +class LayerCompressor: + """ + Runs the SparseGPT algorithm on a single layer using calibration data inputs + + Lifecycle: + - compress + - pre_compress_parallel (optional) + - add_batch + - fasterprune + - post_compress + + :param model: model containing the layer we are running compression on + :param layer: layer to run compression on + :param layer_index: index of layer in the model + :param inputs: calibration data to pass through the layer + :param args: additional keyword arguments + """ + + def __init__( + self, model: Module, layer: Module, layer_index: int, inputs: List, args: Dict + ): + self.model = model + self.layer = layer + self.layer_index = layer_index + self.inputs = inputs + self.args = args + + def compressible_modules(self) -> Dict: + """ + Get the list of modules in the layer that can be compressed + + :return: dictionary of compressible modules + """ + quantize = self.args.get("quantize", False) + if quantize: + # The layer names are changed due to quantization modifiers, therefore + # we need a slightly different func to retrieve layers + modules = _find_quant_layers(self.layer) + else: + modules = _find_layers(self.layer) + return modules + + def pre_compress_parallel(self, **kwargs) -> Dict: + """ + Sets up the SparseGPT objects for each compressible module, computes the Hessian + for each using the calibration data. + + :return: SparseGPT objects for each module + """ + subset = self.compressible_modules() + + gpts = {} + for name in subset: + gpts[name] = SparseGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in gpts: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + # Run through the samples in order to compute Hessian matrix for each module + nsamples = len(self.inputs) + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + for h in handles: + h.remove() + + return {"gpts": gpts} + + def compress(self, dev: str = "cuda:0", **kwargs) -> Dict: + """ + Run SparseGPT compression on all compressible modules in the layer + + :param dev: device to run computation on + """ + self.layer.to(dev) + if not self.args["sequential_update"]: + # compute Hessians ahead of time + extras = self.pre_compress_parallel(**kwargs) + gpts = extras["gpts"] + for name in gpts: + _LOGGER.info(f"Compressing {name}...") + sparsity = self.args["sparsity"] + gpts[name].fasterprune( + sparsity, + prunen=self.args["prunen"], + prunem=self.args["prunem"], + percdamp=self.args["percdamp"], + blocksize=self.args["blocksize"], + ) + gpts[name].free() + else: + # Hessians computed layer by layer + self.sequentially_compress(**kwargs) + + extras = self.post_compress(**kwargs) + + return {"outputs": extras["outputs"]} + + def post_compress(self, **kwargs) -> Dict: + """ + Clean up after 
compression + + :return: outputs of the layer + """ + nsamples = len(self.inputs) + outputs = [] + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + for j in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][j] + else: + passed_in_kwargs[arg] = kwargs[arg] + outputs.append(self.layer(self.inputs[j], **passed_in_kwargs)[0]) + + self.inputs = None + torch.cuda.empty_cache() + return {"outputs": outputs} + + def sequentially_compress(self, **kwargs): + """ + Run compression module by module, in dependency order. Unlike in parallel + compression, we compute the Hessians layer by layer instead of computing them + all up front. This saves on memory and means compression in earlier layers + affects the inputs to later layers + """ + subset = self.compressible_modules() + + # filter kwargs that are expected as layer inputs + forward_args_spec = inspect.getfullargspec(self.layer.__class__.forward) + passed_in_args = [arg for arg in forward_args_spec.args if arg in kwargs] + + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): # take the first batch + passed_in_kwargs[arg] = kwargs[arg][0] + else: + passed_in_kwargs[arg] = kwargs[arg] + order = get_dependency_order( + self.layer, subset, self.inputs[0], **passed_in_kwargs + ) + + nsamples = len(self.inputs) + for name in order: # create SparseGPT object for each compressible module + gpts = SparseGPT(subset[name]) + + def add_batch(name): + def tmp(_, inp, out): + gpts.add_batch(inp[0].data, out.data) + + return tmp + + # add SparseGPT hook for current module + handle = subset[name].register_forward_hook(add_batch(name)) + for sample_idx in range(nsamples): + passed_in_kwargs = {} + for arg in passed_in_args: + if isinstance(kwargs[arg], List): + passed_in_kwargs[arg] = kwargs[arg][sample_idx] + else: + passed_in_kwargs[arg] = kwargs[arg] + # run layer, triggering SparseGPT add_batch for current module + self.layer(self.inputs[sample_idx], **passed_in_kwargs) + handle.remove() + + _LOGGER.info(f"Compressing module {name} of layer {self.layer_index}") + gpts.fasterprune( # run SparseGPT algorithm on current module + self.args["sparsity"], + prunen=self.args["prunen"], + prunem=self.args["prunem"], + percdamp=self.args["percdamp"], + blocksize=self.args["blocksize"], + ) + gpts.free() + + +def _find_quant_layers( + module, layers=[torch.nn.qat.Conv2d, torch.nn.qat.Linear], name="" +): + res = {} + # search for QAT versions of layers + for name1, child in module.named_children(): + res.update( + _find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) + return res + + +def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + _find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) + return res diff --git a/src/sparseml/modifiers/obcq/utils/sparsegpt.py b/src/sparseml/modifiers/obcq/utils/sparsegpt.py new file mode 100644 index 00000000000..179e833d1f2 --- /dev/null +++ b/src/sparseml/modifiers/obcq/utils/sparsegpt.py @@ -0,0 +1,213 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import time + +import torch +import torch.nn as nn +import transformers + + +DEBUG = False +_LOGGER = logging.getLogger(__name__) + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class SparseGPT: + """ + Runs SparseGPT on a single module that contains no sub-modules + + Lifecycle: + - add_batch + - fasterprune + - free + + + :param layer: module to run SparseGPT on + """ + + def __init__(self, layer): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + + def add_batch(self, inp: torch.Tensor, out: torch.Tensor): + """ + Add a batch of layer input and output data to the Hessian calculation + + :param inp: tensor containing layer input + :param out: tensor containing layer our + """ + if DEBUG: + self._inp1 = inp + self.out1 = out + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + inp = math.sqrt(2 / self.nsamples) * inp.float() + self.H += inp.matmul(inp.t()) + + def fasterprune( + self, + sparsity: float, + prunen: int = 0, + prunem: int = 0, + blocksize: int = 128, + percdamp: float = 0.01, + ): + """ + Run pruning and quantization(if applicable) on the layer up to the target + sparsity value. 
+ + :param sparsity: target sparsity to reach for layer + :param prunen: N for N:M pruning + :param prunem: M for N:M pruning + :param blocksize: Number of columns to compress in one pass + :param percdamp: Amount of dampening to apply to H, as a fraction of the + diagonal norm + """ + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + Losses = torch.zeros(self.rows, device=self.dev) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + mask = None + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if prunen == 0: + if mask is not None: + mask1 = mask[:, i1:i2] + else: + tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 + thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] + mask1 = tmp <= thresh + else: + mask1 = torch.zeros_like(W1) == 1 + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if prunen != 0 and i % prunem == 0: + tmp = ( + W1[:, i : (i + prunem)] ** 2 + / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2 + ) + mask1.scatter_( + 1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True + ) + + q = w.clone() + q[mask1[:, i]] = 0 + + if hasattr(self.layer, "weight_fake_quant"): + scale = self.layer.weight_fake_quant.scale + zero_point = self.layer.weight_fake_quant.zero_point + dtype = self.layer.weight_fake_quant.dtype + qscheme = self.layer.weight_fake_quant.qscheme + if qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]: + q = torch.quantize_per_tensor(q, scale, zero_point, dtype) + else: + q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype) + q = torch.dequantize(q) + + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + W[:, i1:i2] = Q1 + Losses += torch.sum(Losses1, 1) / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + if DEBUG: + self.layer.weight.data[:, :i2] = W[:, :i2] + self.layer.weight.data[:, i2:] = W[:, i2:] + _LOGGER.debug(torch.sum((self.layer(self._inp1) - self.out1) ** 2)) + _LOGGER.debug(torch.sum(Losses)) + + torch.cuda.synchronize() + _LOGGER.info("time %.2f" % (time.time() - tick)) + _LOGGER.info("error %.2f" % torch.sum(Losses).item()) + + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.layer.weight.data = W.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) + if DEBUG: + _LOGGER.debug(torch.sum((self.layer(self._inp1) - self.out1) ** 2)) + + def free(self): + """ + Free the Hessian memory after the layer is complete + """ + if DEBUG: + self._inp1 = None + self.out1 = None + self.H = None + torch.cuda.empty_cache() diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py new file mode 100644 index 00000000000..22bafe8ce14 --- /dev/null +++ b/src/sparseml/modifiers/quantization/__init__.py @@ 
-0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * +from .pytorch import * diff --git a/src/sparseml/modifiers/quantization/base.py b/src/sparseml/modifiers/quantization/base.py new file mode 100644 index 00000000000..7a229c43a65 --- /dev/null +++ b/src/sparseml/modifiers/quantization/base.py @@ -0,0 +1,130 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional + +from sparseml.core import Modifier, State +from sparseml.modifiers.quantization.utils.quantization_scheme import ( + QuantizationScheme, + QuantizationSchemeLoadable, +) + + +__all__ = ["QuantizationModifier"] + + +class QuantizationModifier(Modifier): + """ + Enables quantization aware training (QAT) for a given module or its submodules + After the start epoch, the specified module(s) forward pass will emulate + quantized execution and the modifier will be enabled until training is completed. + + | Sample yaml: + | QuantizationModifier: + | start: 0.0 + | scheme: + | input_activations: + | num_bits: 8 + | symmetric: False + | weights: + | num_bits: 8 + | symmetric: True + | scheme_overrides: + | feature_extractor: "default" + | classifier: + | input_activations: + | num_bits: 8 + | symmetric: False + | weights: null + | Conv2d: + | input_activations: + | num_bits: 8 + | symmetric: True + | ignore: ["ReLU", "input"] + | disable_quantization_observer_epoch: 2.0 + | freeze_bn_stats_epoch: 3.0 + | model_fuse_fn_name: 'fuse_module' + | strict: True + + :param scheme: Default QuantizationScheme to use when enabling quantization + in a module. May also be a dictionary to be loaded into the QuantizationScheme + class. A string alias may also be used, supported aliases: + ['default', 'deepsparse', 'tensorrt']. + If None, the default scheme (`QuantizationScheme()`) will be used. + Default is None + :param scheme_overrides: optional mapping of module type names or submodule type + names to quantization schemes to override them with. If a scheme is mapped to + 'default', then it will use the scheme set in the modifier scheme property + :param ignore: optional list of module class names or submodule names + to not quantize. Default is None + :param disable_quantization_observer_epoch: Epoch to disable updates to the module + quantization observers. 
At this point, quantized weights and zero points will + not be updated. Leave None to not disable observers during QAT. Default is None + :param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave + None to not stop tracking batch norm stats during QAT. Default is None + :param model_fuse_fn_name: Name of model function to fuse the model in place prior + to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as + 'conv_bv_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. + Default is None + :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed + to the model fusing function + :param num_calibration_steps: Number of steps to run post training calibration for. + When None, the entire calibration_dataloader is used + :param strict: if True, will raise an error if any module types or submodules in + scheme_overrides or ignore are not found in a given module. Default True + """ + + scheme: Optional[QuantizationSchemeLoadable] = None + scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None + ignore: Optional[List[str]] = None + disable_quantization_observer_epoch: Optional[float] = None + freeze_bn_stats_epoch: Optional[float] = None + model_fuse_fn_name: Optional[str] = None + model_fuse_fn_kwargs: Optional[Dict[str, Any]] = None + num_calibration_steps: Optional[int] = None + post_oneshot_calibration: Optional[bool] = False + strict: bool = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.scheme = QuantizationScheme.load(self.scheme) + self.scheme_overrides = _load_quantization_schemes_dict( + self.scheme_overrides, self.scheme + ) + if self.model_fuse_fn_kwargs is None: + self.model_fuse_fn_kwargs = {} + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier + + +class _QuantizationSchemesDict(dict): + # wrapper class for dict to override the __str__ method for yaml serialization + + def __str__(self): + return str({submodule: scheme.dict() for submodule, scheme in self.items()}) + + +def _load_quantization_schemes_dict( + schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]], + default_scheme: QuantizationScheme, +) -> Dict[str, QuantizationScheme]: + if schemes_dict is None: + return {} + return _QuantizationSchemesDict( + { + submodule: QuantizationScheme.load(scheme, default=default_scheme) + for submodule, scheme in schemes_dict.items() + } + ) diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/pytorch.py new file mode 100644 index 00000000000..30e731c4415 --- /dev/null +++ b/src/sparseml/modifiers/quantization/pytorch.py @@ -0,0 +1,167 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
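A Python-side sketch mirroring the YAML sample in the QuantizationModifier docstring above; the scheme dictionaries are illustrative and rely on QuantizationScheme.load accepting dict (or string alias) inputs as described there:

from sparseml.modifiers.quantization import QuantizationModifier

modifier = QuantizationModifier(
    scheme={
        "input_activations": {"num_bits": 8, "symmetric": False},
        "weights": {"num_bits": 8, "symmetric": True},
    },
    scheme_overrides={"Embedding": {"input_activations": None}},
    ignore=["LlamaRotaryEmbedding", "LlamaRMSNorm"],
    disable_quantization_observer_epoch=2.0,
    freeze_bn_stats_epoch=3.0,
)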
+
+import logging
+from itertools import cycle
+from typing import Any, Callable
+
+import torch
+from torch.nn import Module
+
+from sparseml.core import Event, State
+from sparseml.modifiers.quantization.base import QuantizationModifier
+from sparseml.modifiers.quantization.utils.helpers import (
+    configure_module_bn_wrappers,
+    fuse_module_conv_bn_relus,
+)
+from sparseml.modifiers.quantization.utils.quantize import (
+    convert_module_qat_from_schemes,
+    raise_if_torch_quantization_not_available,
+    set_quantization_schemes,
+)
+from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class QuantizationModifierPyTorch(QuantizationModifier):
+    calibration_dataloader_: Any = None
+    calibration_function_: Any = None
+    qat_enabled_: bool = False
+
+    def on_initialize(self, state: State, **kwargs) -> bool:
+        raise_if_torch_quantization_not_available()
+        if self.end and self.end != -1:
+            raise ValueError(
+                "end_epoch is disabled for QuantizationModifier and can only be set to"
+                " -1 or None. Given {}".format(self.end)
+            )
+
+        self.calibration_dataloader_ = state.data.calib
+        device = state.hardware.device
+        state.model.model.to(device)
+        module = state.model.model
+        self._enable_module_qat(module)
+
+        return True
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        if self.post_oneshot_calibration:
+            state.model.model.to(state.hardware.device)
+            state.model.model.apply(torch.quantization.enable_observer)
+            self._calibrate_if_possible(state.model.model)
+        state.model.model.apply(torch.quantization.disable_observer)
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_update(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        pass
+
+    def _enable_module_qat(self, module: Module):
+        # fuse conv-bn-relu blocks prior to quantization emulation
+        self._fuse(module)
+
+        # add quantization_schemes to target submodules
+        set_quantization_schemes(
+            module,
+            scheme=self.scheme,
+            scheme_overrides=self.scheme_overrides,
+            ignore=self.ignore,
+            strict=self.strict,
+        )
+
+        # fix for freezing batchnorm statistics when not fusing BN with convs.
+        # pytorch only supports freezing batchnorm statistics for fused modules.
+        # this fix wraps BN modules with a new module class that supports
+        # methods related to freezing/unfreezing BN statistics.
+        configure_module_bn_wrappers(module)
+
+        # convert target qconfig layers to QAT modules with FakeQuantize
+        convert_module_qat_from_schemes(module)
+
+        self.qat_enabled_ = True
+
+        self._calibrate_if_possible(module)
+
+    def _fuse(self, module: Module):
+        if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
+            self.model_fuse_fn_kwargs["inplace"] = True
+            fuse_module_conv_bn_relus(module, **self.model_fuse_fn_kwargs)
+        elif self.model_fuse_fn_name != "no_fuse":
+            module_fuse_fn = getattr(module, self.model_fuse_fn_name, None)
+            if module_fuse_fn is None or not callable(module_fuse_fn):
+                raise ValueError(
+                    "Invalid model_fuse_fn_name. "
+                    "Module has no callable function {}".format(self.model_fuse_fn_name)
+                )
+            module_fuse_fn(**self.model_fuse_fn_kwargs)
+
+    def _calibrate_if_possible(self, module: Module):
+        if self.num_calibration_steps == 0 and self.calibration_dataloader_:
+            _LOGGER.warning(
+                f"num_calibration_steps is {self.num_calibration_steps}. "
+ f"Calibration data loader will not be used." + ) + elif self.num_calibration_steps and not self.calibration_dataloader_: + raise ValueError( + f"num_calibration_steps is {self.num_calibration_steps}. " + "Calibration data loader is not set. Pass a " + "calibration_data_loader with initialize(...) method." + ) + + elif not self.calibration_dataloader_: + return + + self._calibrate(module) + + def _calibrate(self, module: Module): + _LOGGER.info("Running quantization calibration using calibration_dataloader") + + module_training = module.training + module.eval() + + forward_fn: Callable = ( + self.calibration_function_ + if self.calibration_function_ + else tensors_module_forward + ) + + model_device = next(module.parameters()).device + _dataloader = ( + self.calibration_dataloader_ + if self.num_calibration_steps is None + else cycle(self.calibration_dataloader_) + ) + + for batch_idx, batch in enumerate(_dataloader): + if self.num_calibration_steps and batch_idx >= self.num_calibration_steps: + break + batch = tensors_to_device(batch, model_device) + with torch.no_grad(): + forward_fn(batch, module=module) + + if module_training: + module.train() + else: + module.apply(torch.quantization.disable_observer) diff --git a/src/sparseml/modifiers/quantization/utils/constants.py b/src/sparseml/modifiers/quantization/utils/constants.py new file mode 100644 index 00000000000..08673fbfec8 --- /dev/null +++ b/src/sparseml/modifiers/quantization/utils/constants.py @@ -0,0 +1,89 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
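As a point of reference for the calibration flow above, a small self-contained sketch of the same control flow, with plain Python lists standing in for the dataloader and the forward call (illustrative only):

from itertools import cycle

def calibration_batches(loader, num_steps=None):
    # mirror of the loop logic: cycle the loader only when a step count is given
    source = loader if num_steps is None else cycle(loader)
    for idx, batch in enumerate(source):
        if num_steps and idx >= num_steps:
            break
        yield batch  # stand-in for forward_fn(batch, module=module)

assert list(calibration_batches([1, 2, 3])) == [1, 2, 3]                   # one full pass
assert list(calibration_batches([1, 2], num_steps=5)) == [1, 2, 1, 2, 1]   # cycled to 5 steps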
+ +""" +Constants related to sparseml pytorch quantization flows +""" + + +__all__ = [ + "FUSED_MODULE_NAMES", + "NON_QUANTIZABLE_MODULE_NAMES", +] + + +""" +Quantization Modifier quantizes all 'leaf' level modules by default +this list contains modules that are very unlikely to be desired for quantization +and will not have a QuantizationScheme ever attached to them by the modifier +QuantizationSchemes may be manually attached with the .quantization_scheme +property and the modifier will then pick up the module for quantization +""" +NON_QUANTIZABLE_MODULE_NAMES = { + # no-ops + "Module", + "Identity", + "Flatten", + "Unflatten", + "DataParallel", + "ModuleList", + "Sequential", + # losses + "L1Loss", + "NLLLoss", + "KLDivLoss", + "MSELoss", + "BCELoss", + "BCEWithLogitsLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "CosineEmbeddingLoss", + "CTCLoss", + "HingeEmbeddingLoss", + "MarginRankingLoss", + "MultiLabelMarginLoss", + "MultiLabelSoftMarginLoss", + "MultiMarginLoss", + "SmoothL1Loss", + "GaussianNLLLoss", + "HuberLoss", + "SoftMarginLoss", + "CrossEntropyLoss", + "TripletMarginLoss", + "AdaptiveLogSoftmaxWithLoss", + "TripletMarginWithDistanceLoss", + # dropouts + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +} + + +FUSED_MODULE_NAMES = { + # Conv based layers + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", + # Linear Layers + "LinearReLU", +} diff --git a/src/sparseml/modifiers/quantization/utils/helpers.py b/src/sparseml/modifiers/quantization/utils/helpers.py new file mode 100644 index 00000000000..318769e22ad --- /dev/null +++ b/src/sparseml/modifiers/quantization/utils/helpers.py @@ -0,0 +1,907 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
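These sets are consulted by class name rather than by class object; a brief illustrative check (hypothetical usage, not part of the module above):

import torch

from sparseml.modifiers.quantization.utils.constants import (
    FUSED_MODULE_NAMES,
    NON_QUANTIZABLE_MODULE_NAMES,
)

for module in (torch.nn.Dropout(), torch.nn.Identity(), torch.nn.Linear(8, 8)):
    name = module.__class__.__name__
    print(
        f"{name}: excluded={name in NON_QUANTIZABLE_MODULE_NAMES}, "
        f"fused={name in FUSED_MODULE_NAMES}"
    )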
+ +""" +Helper functions for performing quantization aware training with PyTorch +""" + +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.intrinsic as nni +from packaging import version +from torch import quantization as torch_quantization +from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU + +from sparseml.modifiers.quantization.utils.quantization_scheme import ( + QuantizationArgs, + QuantizationScheme, + get_observer, +) +from sparseml.pytorch.nn import ReLU as ReLU_nm +from sparseml.pytorch.utils import get_layer + + +_PARSED_TORCH_VERSION = version.parse(torch.__version__) + +__all__ = [ + "QATWrapper", + "configure_module_bn_wrappers", + "configure_module_default_qconfigs", + "configure_module_qat_wrappers", + "add_quant_dequant", + "remove_activation_qat_by_layer_name", + "get_qat_qconfig", + "freeze_bn_stats", + "fuse_module_conv_bn_relus", + "prepare_embeddings_qat", + "QConfigProperties", + "LINEAR_ACTIVATION_NAMES", + "CONV_ACTIVATION_NAMES", +] + +LINEAR_ACTIVATION_NAMES = ["Linear", "LinearReLU"] +CONV_ACTIVATION_NAMES = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvBn1d", + "ConvBn2d", + "ConvBn3d", + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", + "ConvBnReLU1d", + "ConvBnReLU2d", + "ConvBnReLU3d", +] + +_QUANTIZABLE_MODULE_TYPES = ( + { + # Conv based layers + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + # Linear Layers + torch.nn.Linear, + nni.LinearReLU, + } + if nni # nni will always import if torch.quantization is available + else None +) + +_FUSED_MODULE_TYPES = ( + ( + # Conv based layers + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.ConvBnReLU1d, + nni.ConvBnReLU2d, + nni.ConvBnReLU3d, + # Linear Layers + nni.LinearReLU, + ) + if nni # nni will always import if torch.quantization is available + else tuple() +) + + +@dataclass +class QConfigProperties: + """ + Dataclass that stores properties needed to define qconfig objects. + Default values set here. + + :param symmetric_activations: if True, activations will have a symmetric + quantization range with a pre-specified zero point + (0 if activation_dtype=torch.qint8, 128 if activation_dtype=torch.quint8). + Default is False. + :param symmetric_weights: if True, weights will have a symmetric + quantization range with a pre-specified zero point + (0 if weight_dtype=torch.qint8, 128 if weight_dtype=torch.quint8). + Default is True. + :param reduce_range: if True, the quantization range will be reduced by one bit. + This may prevent overflow issues with model execution on certain hardware. + Default is False. + :param activation_qconfig_kwargs: Additional kwargs for quantization of + activations. + :param weight_qconfig_kwargs: Additional kwargs for quantization of + weights. + :param activation_dtype: quantized activation data type. + Default is torch.quint8. + :param weight_dtype: quantized weights data type. + Default is torch.qint8. + :param activation_bits: number of bits for activations. Default is 8. + :param weight_bits: number of bits for weights. Default is 8. + :param activation_strategy: "tensor" to quantize over the whole activation tensor, + or "channel" to quantize per channel. 
Default is "tensor" + :param weight_strategy: "tensor" to quantize over the whole weight tensor, or + "channel" to quantize per channel. Default is "tensor" + :param tensorrt: if True sets quantization configuration for compatibility with + explict quantization as supported by TensorRT 8.2. + """ + + _symmetric_activations: bool = False + _symmetric_weights: bool = True + reduce_range: bool = False + activation_dtype: torch.dtype = torch.quint8 + weight_dtype: torch.dtype = torch.qint8 + activation_bits: int = 8 + weight_bits: int = 8 + activation_strategy: str = "tensor" + weight_strategy: str = "tensor" + activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) + weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) + tensorrt: bool = False + + @property + def symmetric_activations(self) -> bool: + # always use symmetric activations in tensorrt mode + return self.tensorrt or self._symmetric_activations + + @symmetric_activations.setter + def symmetric_activations(self, value: bool): + self._symmetric_activations = value + + @property + def symmetric_weights(self) -> bool: + return self.tensorrt or self._symmetric_weights + + @symmetric_weights.setter + def symmetric_weights(self, value: bool): + self._symmetric_weights = value + + +class QATWrapper(Module): + """ + Wraps inputs and outputs of a Module or function with QuantStubs for + Quantization-Aware-Training (QAT) + + :param forward_fn: function to be wrapped, should generally accept and return + torch Tensor(s) + :param num_inputs: number of inputs of the forward function to add a QuantStub + to. Will wrap the first num_inputs ordered inputs of the function. Default + is 1 + :param kwarg_input_names: list of names of key word arguments to the forward pass + that should be wrapped with a fake quantize operation. Defaults to empty + :param num_outputs: number of outputs of the forward function to add a QuantStub + to. Will wrap the first num_inputs ordered outputs of the function. Default + is 1. Will also add a DeQuantStub for FP32 conversion if + torch.quantization.convert is invoked + :param input_qconfigs: QConfig to use for calibrating the input QuantStubs. Can + be a single QConfig that will be copied to each QuantStub or a list of one + QConfig for each input. Instead of a QConfig objects, the string 'asymmetric' + or 'symmetric' may be used to use default UINT8 asymmetric and symmetric + quantization respectively + :param output_qconfigs: QConfig to use for calibrating the output QuantStubs. Can + be a single QConfig that will be copied to each QuantStub or a list of one + QConfig for each output. Instead of a QConfig objects, the string 'asymmetric' + or 'symmetric' may be used to use default UINT8 asymmetric and symmetric + quantization respectively + :param qproperties: properties used to define QConfig. may also be a quantization + scheme + """ + + @staticmethod + def from_module( + module: Module, + qproperties: Union[QConfigProperties, QuantizationScheme], + ) -> "QATWrapper": + """ + :param module: torch Module to create a QATWrapper for + :return: QATWrapper object created using the given Module as the forward + function. 
Will attempt to find any other named parameter of the QATWrapper + constructor from the attributes of the given Module + """ + qat_wrapper_kwargs = ( + module.qat_wrapper_kwargs or {} + if hasattr(module, "qat_wrapper_kwargs") + else {} + ) + + # Remove qconfig from wrapped layer to avoid duplicate quantization + module.qconfig = None + return QATWrapper( + forward_fn=module, qproperties=qproperties, **qat_wrapper_kwargs + ) + + def __init__( + self, + forward_fn: Callable[[Any], Any], + qproperties: Union[QConfigProperties, QuantizationScheme], + num_inputs: int = 1, + kwarg_input_names: List[str] = None, + num_outputs: int = 1, + input_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + output_qconfigs: Union[ + "torch.quantization.QConfig", str, List["torch.quantization.QConfig"] + ] = "asymmetric", + ): + super().__init__() + + if torch_quantization is None: + raise RuntimeError( + "Unable to import package torch.quantization. " + "Try upgrading your PyTorch version to >= 1.7.0." + ) + + if not callable(forward_fn): + raise ValueError( + "forward_fn of QATWrapper must be callable. " + f"Received {type(forward_fn)}" + ) + + self.kwarg_input_names = kwarg_input_names or [] + num_input_quant_stubs = num_inputs + len(self.kwarg_input_names) + + self.forward_fn = forward_fn + # Add weight qconfig to forward_fn (in case it has weights) + qconfig_ = ( + get_qat_qconfig(qproperties) + if isinstance(qproperties, QConfigProperties) + else qproperties.get_qconfig() # QuantizationScheme + ) + qconfig = torch_quantization.QConfig( + activation=torch.nn.Identity, + weight=qconfig_.weight, + ) + self.forward_fn.qconfig = qconfig + + self.input_qconfigs = self._load_qconfigs( + name="input_qconfigs", + expected_len=num_input_quant_stubs, + qconfigs=input_qconfigs, + qproperties=qproperties, + ) + self.output_qconfigs = self._load_qconfigs( + name="output_qconfigs", + expected_len=num_outputs, + qconfigs=output_qconfigs, + qproperties=qproperties, + ) + + self.input_quant_stubs = torch.nn.ModuleList( + [torch_quantization.QuantStub() for _ in range(num_input_quant_stubs)] + ) + self.output_quant_stubs = torch.nn.ModuleList( + [torch_quantization.QuantStub() for _ in range(num_outputs)] + ) + self.output_dequant_stubs = torch.nn.ModuleList( + [torch_quantization.DeQuantStub() for _ in range(num_outputs)] + ) + + def forward(self, *args, **kwargs) -> Any: + """ + :param args: arguments to forward function; the first num_inputs of these args + will be wrapped by a QuantStub + :param kwargs: key word arguments to pass to the wrapped forward function + :return: outputs of the forward function with a QuantStub applied to the first + num_outputs outputs + """ + + if any(kwarg not in kwargs for kwarg in self.kwarg_input_names): + raise ValueError( + f"QATWrapper expected kwargs {self.kwarg_input_names} to be included " + f"in forward function kwargs. Found {list(kwargs.keys())}. 
Missing "
+                f"{[kwarg for kwarg in self.kwarg_input_names if kwarg not in kwargs]}"
+            )
+
+        qat_args = []
+
+        # fake quantize positional arguments
+        num_args_stubs = len(self.input_quant_stubs) - len(self.kwarg_input_names)
+        for idx, arg in enumerate(args):
+            if idx < num_args_stubs:
+                arg = self.input_quant_stubs[idx](arg)
+            qat_args.append(arg)
+
+        # fake quantize key word arguments
+        for idx, kwarg in enumerate(self.kwarg_input_names):
+            kwargs[kwarg] = self.input_quant_stubs[num_args_stubs + idx](kwargs[kwarg])
+
+        # wrapped forward pass
+        outputs = self.forward_fn(*qat_args, **kwargs)
+
+        if len(self.output_quant_stubs) == 0:
+            # no output wrapping
+            return outputs
+
+        if isinstance(outputs, torch.Tensor):
+            if len(self.output_quant_stubs) > 1:
+                raise ValueError(
+                    f"QATWrapper expected {len(self.output_quant_stubs)} outputs in "
+                    "forward pass. Found one output"
+                )
+            # output is a single Tensor
+            qat_output = self.output_quant_stubs[0](outputs)
+            return self.output_dequant_stubs[0](qat_output)
+
+        qat_outputs = []
+
+        for idx, output in enumerate(outputs):
+            if idx < len(self.output_quant_stubs):
+                output = self.output_quant_stubs[idx](output)
+                output = self.output_dequant_stubs[idx](output)
+            qat_outputs.append(output)
+
+        return qat_outputs
+
+    def configure_qconfig(self):
+        """
+        Sets the qconfigs of the quant stubs to the pre-initialized QConfigs
+        """
+        for quant_stub, qconfig in zip(self.input_quant_stubs, self.input_qconfigs):
+            quant_stub.qconfig = qconfig
+            if hasattr(qconfig, "quantization_scheme"):
+                quant_stub.quantization_scheme = qconfig.quantization_scheme
+
+        for quant_stub, qconfig in zip(self.output_quant_stubs, self.output_qconfigs):
+            quant_stub.qconfig = qconfig
+            if hasattr(qconfig, "quantization_scheme"):
+                quant_stub.quantization_scheme = qconfig.quantization_scheme
+
+    @staticmethod
+    def _load_qconfigs(
+        name: str,
+        expected_len: int,
+        qconfigs: Union["QConfig", str, List["QConfig"]],  # noqa: F821
+        qproperties: QConfigProperties,
+    ):
+        if not isinstance(qconfigs, (str, torch_quantization.QConfig, List)):
+            raise ValueError(
+                f"QATWrapper {name} must be a string, torch.quantization.QConfig, "
+                f"or a List of them. Received a {type(qconfigs)}"
+            )
+
+        if isinstance(qconfigs, (str, torch_quantization.QConfig)):
+            qconfigs = [deepcopy(qconfigs) for _ in range(expected_len)]
+
+        if len(qconfigs) != expected_len:
+            raise ValueError(
+                f"QATWrapper {name} should have exactly one qconfig or one for every "
+                f"argument ({expected_len}). Given {len(qconfigs)}"
+            )
+
+        valid_qconfig_strs = ["asymmetric", "symmetric"]
+        for idx, qconfig in enumerate(qconfigs):
+            if not isinstance(qconfig, str):
+                continue
+
+            if qconfig not in valid_qconfig_strs:
+                raise ValueError(
+                    "QATWrapper qconfig names can either be "
+                    "torch.quantization.QConfig objects or a string "
+                    f"in {valid_qconfig_strs} that will be converted to a QConfig.
" + f"Found string with value {qconfig} in {name}" + ) + + qconfig_idx = None + if isinstance(qproperties, QConfigProperties): + qproperties_idx = deepcopy(qproperties) + qproperties_idx.symmetric_activations = qconfig == "symmetric" + qconfig_idx = get_qat_qconfig(qproperties_idx) + else: + scheme_idx = deepcopy(qproperties) + symmetric = qconfig == "symmetric" + # always use output_activations of scheme because the activations + # of the QuantStub() are the ones tracked + if scheme_idx.output_activations is not None: + scheme_idx.input_activations.symmetric = symmetric + else: + scheme_idx.output_activations = QuantizationArgs( + symmetric=symmetric + ) + qconfig_idx = scheme_idx.get_qconfig() + qconfig_idx.quantization_scheme = scheme_idx + + qconfigs[idx] = qconfig_idx + + return qconfigs + + +def configure_module_bn_wrappers(module: Module): + """ + Wrap any BatchNormalization modules that are not fused with convolutions + with BNWrapper to enable freezing/unfreezing of BN statistics + + :param module: module to potentially wrap the submodules of + """ + # wrap any children of the given module as a QATWrapper if required + if not hasattr(module, "freeze_bn_stats"): + for child_name, child_module in module.named_children(): + if type(child_module) in [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + ]: + setattr(module, child_name, _BNWrapper(child_module)) + # recurse on child module + configure_module_bn_wrappers(child_module) + + +def configure_module_qat_wrappers( + module: Module, + qproperties: QConfigProperties, +): + """ + if any submodule of the given module has the attribute wrap_qat == True, + then it will be replaced by a QATWrapper of it created by QATWrapper.from_module. + Other named kwargs to the QATWrapper constructor must be contained in a dictionary + under an attributed named `qat_wrapper_kwargs` + + :param module: module to potentially wrap the submodules of + :param qproperties: properties used to define QConfig. 
+ """ + # wrap any children of the given module as a QATWrapper if required + for child_name, child_module in module.named_children(): + if hasattr(child_module, "wrap_qat") and child_module.wrap_qat: + setattr( + module, + child_name, + QATWrapper.from_module( + module=child_module, + qproperties=qproperties, + ), + ) + # recurse on child module + configure_module_qat_wrappers( + module=child_module, + qproperties=qproperties, + ) + + +def configure_module_default_qconfigs(module: Module): + """ + if any submodule of the given module has a configure_qconfig function, + configure_qconfig will be called on that submodule to set the qconfig(s) of that + module to its default + + :param module: module to set qconfigs for + """ + for submodule in module.modules(): + if hasattr(submodule, "configure_qconfig") and callable( + getattr(submodule, "configure_qconfig") + ): + submodule.configure_qconfig() + + +def add_quant_dequant( + module: torch.nn.Module, name=None, parent_module=None, layer_class_names=None +): + """ + Wraps all Conv and Linear submodule with a qconfig with a QuantWrapper + :param module: the module to modify + :param name: name of the module to modify; default to None + :param parent_module: parent module containing the module to modify; default to None + :param layer_class_names: list of module class names to be added to the + list of quantizable modules + :return: the modified module + """ + named_children = module.named_children() + is_quantizable = type(module) in _QUANTIZABLE_MODULE_TYPES + if layer_class_names: + is_quantizable = ( + is_quantizable or module.__class__.__name__ in layer_class_names + ) + if is_quantizable and hasattr(module, "qconfig") and module.qconfig: + module = torch_quantization.QuantWrapper(module) + if parent_module is not None and len(list(named_children)) <= 0: + if "." in name: + # unwrap name under parent module, nested through multiple submodules + name_parts = name.split(".") + for name_part in name_parts[:-1]: + parent_module = getattr(parent_module, name_part) + name = name_parts[-1] + + # set parent module child to the newly wrapped module + setattr(parent_module, name, module) + else: + for name, child in named_children: + setattr( + module, + name, + add_quant_dequant(child, layer_class_names=layer_class_names), + ) + return module + + +def remove_activation_qat_by_layer_name(module: Module, layer_class_names: List[str]): + """ + Disables fake quantization of activations for all submodules of the given module + with class name layer_class_names + + :param module: module to remove activation fake quantization for certain layers + :param layer_class_names: list of layer class names that should be affected. + e.x. ["Linear"] + """ + for submodule in module.modules(): + if submodule.__class__.__name__ in layer_class_names and hasattr( + submodule, "qconfig" + ): + submodule.qconfig = torch_quantization.QConfig( + activation=torch.nn.Identity, + weight=submodule.qconfig.weight, + ) + + +def get_qat_qconfig(qproperties: QConfigProperties) -> "torch.quantization.QConfig": + """ + :param qproperties: properties used to define QConfig. 
+ """ + activation_observer = get_observer( + qproperties.symmetric_activations, + qproperties.activation_strategy, + qproperties.activation_dtype, + qproperties.activation_bits, + qproperties.reduce_range, + qproperties.activation_qconfig_kwargs, + ) + + weight_observer = get_observer( + qproperties.symmetric_weights, + qproperties.weight_strategy, + qproperties.weight_dtype, + qproperties.weight_bits, + False, + qproperties.weight_qconfig_kwargs, + ) + + return torch_quantization.QConfig( + activation=activation_observer, + weight=weight_observer, + ) + + +def freeze_bn_stats(module: Module): + if hasattr(module, "freeze_bn_stats"): + module.freeze_bn_stats() + + +def fuse_module_conv_bn_relus( + module: Module, + inplace: bool = True, + override_bn_subclasses_forward: Union[bool, str] = True, +) -> Module: + """ + Performs fusion of Conv2d, BatchNorm2d, and ReLU layers found in the + given module. To be fused, these layers must appear sequentially in + module.named_modules() and be in the same submodule. + Fuses either Conv2d -> BatchNorm2d, Conv2d -> ReLU, or + Conv2d -> BatchNorm2d -> ReLU blocks + + If this function does not fuse the model in the desired way, implement an + in place fusing function for the model. + + :param module: the module to fuse + :param inplace: set True to perform fusions in-place. default is True + :param override_bn_subclasses_forward: if True, modules that are subclasses of + BatchNorm2d will be modified to be BatchNorm2d but with the forward + pass and state variables copied from the subclass. This is so these + BN modules can pass PyTorch type checking when fusing. Can set to + "override-only" and only parameters will be overwritten, not the + forward pass. Default is True + :return: the fused module + """ + if torch_quantization is None: + raise RuntimeError( + "Unable to import package torch.quantization. " + "Try upgrading your PyTorch version." + ) + if not inplace: + module = deepcopy(module) + conv_blocks = [] + current_block = [] + current_block_submodule_name = "" + for name, layer in module.named_modules(): + submodule_name = ".".join(name.split(".")[:-1]) + if ( + len(current_block) == 1 # [Conv2d] + and isinstance(layer, BatchNorm2d) + and submodule_name == current_block_submodule_name + ) or ( + len(current_block) in [1, 2] # [Conv2d] or [Conv2d, BatchNorm2d] + and isinstance(layer, ReLU) + and not isinstance(current_block[-1], ReLU) + and submodule_name == current_block_submodule_name + ): + if isinstance(layer, ReLU_nm): + _set_submodule(module, name, ReLU(inplace=layer.inplace)) + if isinstance(layer, BatchNorm2d) and not type(layer) is BatchNorm2d: + if not override_bn_subclasses_forward: + raise RuntimeError( + "Detected a Conv-BN block that uses a subclass of BatchNorm2d. 
" + "This will cause a type error when fusing with PyTorch, " + "set override_bn_subclasses_forward to True or 'override-only " + "to modify this BN subclass to be a BatchNorm2d object" + ) + # swap BN subclass with overwritten BN class that will pass torch + # type checking + overwritten_bn = _wrap_bn_sub_class( + layer, + override_forward=override_bn_subclasses_forward != "override-only", + ) + _set_submodule(module, name, overwritten_bn), + current_block.append(name) + else: + if current_block: + if len(current_block) > 1: # cannot fuse single module + conv_blocks.append(current_block) + current_block = [] + current_block_submodule_name = "" + if isinstance(layer, Conv2d): + current_block.append(name) + current_block_submodule_name = submodule_name + if len(current_block) > 1: + conv_blocks.append(current_block) + if conv_blocks: + # manually save and move hooks surrounding fused blocks + # into new fused modules due to torch.quantization + # error when a module has more than one hook + block_hooks = _delete_get_block_hooks(module, conv_blocks) + + # run torch fusion + if _PARSED_TORCH_VERSION < version.parse("1.10.0"): + torch_quantization.fuse_modules(module, conv_blocks, inplace=True) + else: + if module.training: + torch.ao.quantization.fuse_modules_qat( + module, conv_blocks, inplace=True + ) + else: + torch.ao.quantization.fuse_modules(module, conv_blocks, inplace=True) + + # add hooks back + _add_fused_block_hooks(module, block_hooks) + + return module + + +def prepare_embeddings_qat( + module: Module, + qproperties: Optional[QConfigProperties] = None, + qconfig: Optional["torch.quantization.QConfig"] = None, +): + """ + adds a fake quantize call to the weights of any Embedding modules in the given + module. The used qconfig will have a heirarchy of + + submodule.qconfig -> qconfig -> qproperties + + :param module: module to run QAT for the embeddings of + :param qconfig: qconfig to generate the fake quantize ops from if qconfig + not set in moduleDefault uses INT8 asymmetric range + :param qproperties: properties used to define QConfig if qconfig not present + """ + if qconfig is None and qproperties is not None: + qproperties.symmetric_weights = False + qconfig = get_qat_qconfig(qproperties) + for submodule in module.modules(): + submodule_qconfig = getattr(submodule, "qconfig", None) + submodule_qconfig = submodule_qconfig or qconfig + if isinstance(submodule, Embedding) and submodule_qconfig is not None: + _prepare_qat_embedding(submodule, submodule_qconfig) + + +def _delete_get_block_hooks( + module: Module, + fuse_blocks: List[List[str]], +) -> List[Tuple[Any, Any]]: + block_hooks = [] + for block in fuse_blocks: + pre_hooks = [] + post_hooks = [] + + for name in block: + # get Module objects in block by their names + m = get_layer(name, module) + + # extract the hooks + pre_hooks.extend(m._forward_pre_hooks.values()) + post_hooks.extend(m._forward_hooks.values()) + + # de-register the hooks from this module + m._forward_pre_hooks.clear() + m._forward_hooks.clear() + + block_hooks.append((pre_hooks, post_hooks)) + + return block_hooks + + +def _add_fused_block_hooks(module: Module, block_hooks: List[Tuple[Any, Any]]): + fused_modules = [ + mod for mod in module.modules() if isinstance(mod, _FUSED_MODULE_TYPES) + ] + + if len(fused_modules) != len(block_hooks): + raise RuntimeError( + f"Number of fused modules ({len(fused_modules)}) after layer fusion in " + f"module {module.__class__.__name__}. does not match expected " + f"({len(block_hooks)}). 
Module may have already been fused or block " + "skipped during torch.quantization.fuse_modules" + ) + + for fused_module, (pre_hooks, post_hooks) in zip(fused_modules, block_hooks): + for pre_hook in pre_hooks: + fused_module.register_forward_pre_hook(pre_hook) + for post_hook in post_hooks: + fused_module.register_forward_hook(post_hook) + + +def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"): + embedding.weight_fake_quant = qconfig.weight() + + def _qat_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = self.weight_fake_quant(self.weight) + if weight.device != input.device: + # torch DataParallel may not pick up overwritten bound method + # send weight to correct device + weight = weight.to(input.device) + + return torch.nn.functional.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + # bind qat forward to embedding + qat_forward_bound = _qat_forward.__get__(embedding, embedding.__class__) + embedding.to(embedding.weight.device) # set weight_fake_quant to correct device + setattr(embedding, "forward", qat_forward_bound) + + +def _set_submodule(root_module: Module, sub_module_path, sub_module: Module): + sub_module.training = root_module.training + current_module = root_module + sub_module_path = sub_module_path.split(".") + for child_module in sub_module_path[:-1]: + current_module = getattr(current_module, child_module) + setattr(current_module, sub_module_path[-1], sub_module) + + +def _wrap_bn_sub_class(bn_subclass, override_forward=True): + batch_norm = BatchNorm2d(bn_subclass.num_features) + batch_norm.__dict__ = bn_subclass.__dict__ + if override_forward: + batch_norm.forward = bn_subclass.forward + del bn_subclass + return batch_norm + + +class _BNWrapper(Module): + """ + Wraps BatchNormalization module to expose methods needed to enable + freezing/unfreezing of statistics + + :param module: BatchNormalization module to be wrapped + """ + + def __init__(self, module: Module): + super().__init__() + self.bn = module + self.freeze_bn = False + + @property + def running_mean(self): + return self.bn.running_mean + + @running_mean.setter + def running_mean(self, value): + self.bn.running_mean = value + + @property + def running_var(self): + return self.bn.running_var + + @running_var.setter + def running_var(self, value): + self.bn.running_var = value + + @property + def weight(self): + return self.bn.weight + + @weight.setter + def weight(self, value): + self.bn.weight = value + + @property + def bias(self): + return self.bn.bias + + @bias.setter + def bias(self, value): + self.bn.bias = value + + @property + def gamma(self): + return self.bn.gamma + + @gamma.setter + def gamma(self, value): + self.bn.gamma = value + + @property + def beta(self): + return self.bn.beta + + @beta.setter + def beta(self, value): + self.bn.beta = value + + @property + def num_batches_tracked(self): + return self.bn.num_batches_tracked + + @num_batches_tracked.setter + def num_batches_tracked(self, value): + self.bn.num_batches_tracked = value + + @property + def eps(self): + return self.bn.eps + + @eps.setter + def eps(self, value): + self.bn.eps = value + + @property + def momentum(self): + return self.bn.momentum + + @momentum.setter + def momentum(self, value): + self.bn.momentum = value + + def forward(self, x): + return self.bn(x) + + def freeze_bn_stats(self): + self.freeze_bn = True + self.bn.training = False + return self + + def reset_running_stats(self): + 
self.bn.reset_running_stats() + + def train(self, mode=True): + if not self.freeze_bn: + self.bn.train(mode) + return self + + def update_bn_stats(self): + self.freeze_bn = False + self.bn.training = True + return self diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py new file mode 100644 index 00000000000..b3ef1807227 --- /dev/null +++ b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py @@ -0,0 +1,382 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Schemas and types to support quantization +""" +from copy import deepcopy +from functools import partial +from typing import Any, Dict, Optional, Union + +import torch +from packaging import version +from pydantic import BaseModel, Field, validator +from torch.nn import Identity + + +try: + from torch import quantization as torch_quantization +except Exception: + torch_quantization = None + + +__all__ = [ + "DictQuantizationArgs", + "DictQuantizationScheme", + "QuantizationArgs", + "QuantizationScheme", + "QuantizationSchemeLoadable", + "compute_range", + "get_observer", +] + + +_PARSED_TORCH_VERSION = version.parse(torch.__version__) +_TORCH_PRE_112 = _PARSED_TORCH_VERSION < version.parse("1.12.0") + + +""" +Type definition aliases for defining QuantizationArgs and QuantizationScheme +as dictionaries for YAML serialization +""" +DictQuantizationArgs = Dict[str, Union[int, bool, Dict[str, Any]]] +DictQuantizationScheme = Dict[str, DictQuantizationArgs] + +""" +Type definition for a type that is valid for loading a QuantizationScheme +using QuantizationScheme.load +""" +QuantizationSchemeLoadable = Union[ + "QuantizationScheme", + DictQuantizationScheme, + str, + None, +] + + +class QuantizationArgs(BaseModel): + """ + Class representing user facing arguments to define quantization Observers of + activations or weights in a network + """ + + num_bits: int = Field( + default=8, description="number of bits to target for quantization" + ) + symmetric: bool = Field( + default=False, + description="set True to use symmetric quantization. Default False", + ) + strategy: str = Field( + default="tensor", + description=( + "scope of the quantization to be applied. 
can be 'tensor' or 'channel'" + ), + ) + kwargs: Dict[str, Any] = Field( + default_factory=dict, + description=( + "optional dict of kwargs to be passed directly to torch quantization " + "Observers constructor excluding quantization range or symmetry" + ), + ) + + @classmethod + def default_activation_args(cls): + """ + :return: default 8 bits asymmetric settings + """ + return cls(num_bits=8, symmetric=False) + + @classmethod + def default_weight_args(cls): + """ + :return: default 8 bits symmetric settings + """ + return cls(num_bits=8, symmetric=True) + + def get_observer(self) -> "torch.quantization.FakeQuantize": + """ + :return: torch quantization FakeQuantize built based on these QuantizationArgs + """ + return get_observer( + symmetric=self.symmetric, + strategy=self.strategy, + dtype=torch.qint8, + bits=self.num_bits, + reduce_range=self.kwargs.get("reduce_range", False), + qconfig_kwargs=self.kwargs, + ) + + @validator("strategy") + def validate_strategy(cls, value): + valid_scopes = ["tensor", "channel"] + if value not in valid_scopes: + raise ValueError(f"`strategy` must be one of {valid_scopes}, got {value}") + return value + + +class QuantizationScheme(BaseModel): + """ + Class composed of QuantizationArgs to build QConfig and QuantWrapper objects for + quantizing models. Provides a simple user interface for defining how inputs, + weights, and outputs should be quantized + """ + + def __init__(self, *args, **kwargs): + # support for loading from yaml str + args = [arg if arg != "null" else None for arg in args] + for key, val in kwargs.items(): + if val == "null": + kwargs[key] = None + super().__init__(*args, **kwargs) + + input_activations: Optional[QuantizationArgs] = Field( + default_factory=QuantizationArgs.default_activation_args, + description=( + "target quantization setting for input activations. Set to None to " + "not quantize input activations. Default is 8 bits asymmetric" + ), + ) + weights: Optional[QuantizationArgs] = Field( + default_factory=QuantizationArgs.default_weight_args, + description=( + "target quantization setting for model weights. Set to None to " + "not quantize weights. Default is 8 bits symmetric" + ), + ) + output_activations: Optional[QuantizationArgs] = Field( + default=None, + description=( + "target quantization setting for output activations. Set to None to " + "not quantize output activations. Default is None" + ), + ) + target_hardware: Optional[str] = Field( + default=None, + description=( + "target deployment runtime/hardware name to be set by default " + "classmethods. Default is None" + ), + ) + + @classmethod + def load( + cls, + scheme: QuantizationSchemeLoadable, + default: Optional["QuantizationScheme"] = None, + ) -> "QuantizationScheme": + """ + :param scheme: QuantizationScheme, dict representation of scheme, + or string alias of a scheme to load. Valid strings: + ['default', 'deepsparse', 'tensorrt'] + :param default: default QuantizationScheme to override 'default' scheme + with + :return: constructed QuantizationScheme object from the given scheme; + if given a dict, returns QuantizationScheme.parse_obj(scheme), string + input will return the defualt QuantizationScheme if set to 'default'. 
+ """ + if isinstance(scheme, cls): + return scheme + elif scheme is None or scheme == "default": + # if no default override, defaults to QuantizationScheme() + return deepcopy(default) or cls() + elif isinstance(scheme, str): + if scheme == "deepsparse": + return cls.deepsparse() + elif scheme == "tensorrt": + return cls.tensorrt() + raise ValueError( + f"Unrecognized QuantizationScheme string alias {scheme}. " + "Valid strings: ['default', 'deepsparse', 'tensorrt']" + ) + elif isinstance(scheme, dict): + # default to dict + scheme = {key: _parse_quantization_arg(arg) for key, arg in scheme.items()} + return cls.parse_obj(scheme) + else: + raise ValueError( + f"Unrecognized type {type(scheme)} for QuantizationScheme.load, " + "expected one of: [QuantizationScheme, Dict, str, None]" + ) + + @classmethod + def deepsparse(cls) -> "QuantizationScheme": + """ + :return: QuantizationScheme for deepsparse targeted deployments - + int8, symmetric weights, asymmetric inputs, no output quantization + """ + return cls( + input_activations=QuantizationArgs(num_bits=8, symmetric=False), + weights=QuantizationArgs(num_bits=8, symmetric=True), + output_activations=None, + target_hardware="deepsparse", + ) + + @classmethod + def tensorrt(cls) -> "QuantizationScheme": + """ + :return: QuantizationScheme for tensorrt targeted deployments - + compatibility with explict quantization as supported by TensorRT 8.2: + int8, symmetric for both weights and inputs, no output quantization + """ + return cls( + input_activations=QuantizationArgs(num_bits=8, symmetric=True), + weights=QuantizationArgs(num_bits=8, symmetric=True), + output_activations=None, + target_hardware="tensorrt", + ) + + def get_qconfig(self) -> "torch.quantization.QConfig": + """ + :return: QConfig for Modules (output activations used, + use QuantWrapper for inputs) + """ + qconfig = _get_qconfig(self.output_activations, self.weights) + # add reference to this quantization scheme for reference + qconfig.quantization_scheme = self + return qconfig + + def get_wrapper_qconfig(self) -> "torch.quantization.QConfig": + """ + :return: QConfig for QuantWrapper objects (input activations used) + """ + qconfig = _get_qconfig(self.input_activations, None) + # add reference to this quantization scheme for reference + qconfig.quantization_scheme = self + return qconfig + + def __str__(self) -> str: + """ + :return: YAML friendly string serialization + """ + dict_repr = self.dict() + dict_repr = { + key: val if val is not None else "null" for key, val in dict_repr.items() + } + return str(dict_repr) + + +def compute_range(dtype: torch.dtype, bits: int): + """ + compute quantization limits depending on data type and number of bits + + :param dtype: data type. + :param bits: number of bits. 
+ :return: minimum limit, maximum limit, whether the range is customized + """ + bits = bits if bits else 8 + is_custom = bits != 8 + if dtype == torch.qint8: + quant_min = -(2 ** (bits - 1)) + quant_max = (2 ** (bits - 1)) - 1 + elif dtype == torch.quint8: + quant_min = 0 + quant_max = (2**bits) - 1 + + return quant_min, quant_max, is_custom + + +def get_observer( + symmetric: bool, + strategy: str, + dtype: torch.dtype, + bits: int, + reduce_range: bool, + qconfig_kwargs: Dict[str, Any], +): + quant_min, quant_max, is_custom_qrange = compute_range(dtype, bits) + + if strategy == "channel": + qscheme = torch.per_channel_symmetric if symmetric else torch.per_channel_affine + observer_cls = torch_quantization.MovingAveragePerChannelMinMaxObserver + observer_kwargs = dict( + ch_axis=0, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + ) + else: # default to tensor strategy + qscheme = torch.per_tensor_symmetric if symmetric else torch.per_tensor_affine + observer_cls = torch_quantization.MovingAverageMinMaxObserver + observer_kwargs = dict( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + ) + """ + in torch 1.9.1, quant_min and quant_max are not passed to observer: + https://github.com/pytorch/pytorch/blob/v1.9.1/torch/quantization/fake_quantize.py#L109 + however in 1.12.0, this is fixed so both are passed to observer: + https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/fake_quantize.py#L132 + + Passing quant_min/quant_max to observer means the observer will have + `self.has_customized_qrange == True` in both 1.9.1 and 1.12.0. + + For whatever reason, both versions calculate zero point for + quint8 differently **if there is a customized_qrange** + 1. customized qrange has zero point of 127 + 2. non-customized has zero point of 128. + source: + https://github.com/pytorch/pytorch/blob/v1.12.1/torch/ao/quantization/observer.py#L293 + + **we want to ensure that the zero point is 128** + see https://github.com/neuralmagic/sparseml/pull/604 + """ + if is_custom_qrange: + # for both versions we need to include the custom min/max values in kwargs + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + if _TORCH_PRE_112: + # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, + # so we patch them in to the constructor of the observer + observer_cls = partial( + observer_cls, quant_min=quant_min, quant_max=quant_max + ) + else: + # if using a non custom qrange, we can rely on default values used by + # the observers + if _TORCH_PRE_112: + # pre 1.12, the observer doesn't get passed the quant_min/quant_max values, + # so we are safe to pass these to FakeQuantize + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + else: + # post 1.12 we cannot pass them to the observer since that will set + # has_customized_qrange. instead we rely on the default values + # being equal to the `quant_min` and `quant_max` here. 
+ pass + + observer_kwargs["observer"] = observer_cls + observer_kwargs.update(qconfig_kwargs or {}) + observer = torch_quantization.FakeQuantize.with_args( + **observer_kwargs, + ) + + return observer + + +def _get_qconfig( + activation_args: Optional[QuantizationArgs], weight_args: Optional[QuantizationArgs] +) -> "torch.quantization.QConfig": + return torch_quantization.QConfig( + activation=activation_args.get_observer() if activation_args else Identity, + weight=weight_args.get_observer() if weight_args else Identity, + ) + + +def _parse_quantization_arg(arg: Any): + if arg == "None": + return None + return arg diff --git a/src/sparseml/modifiers/quantization/utils/quantize.py b/src/sparseml/modifiers/quantization/utils/quantize.py new file mode 100644 index 00000000000..1130287479d --- /dev/null +++ b/src/sparseml/modifiers/quantization/utils/quantize.py @@ -0,0 +1,453 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tooling for applying quantization to pytorch modules via +structured configurations +""" +from typing import Dict, List, Optional + +import torch +from packaging import version +from torch.nn import Identity, Module + +from sparseml.modifiers.quantization.utils.constants import ( + FUSED_MODULE_NAMES, + NON_QUANTIZABLE_MODULE_NAMES, +) +from sparseml.modifiers.quantization.utils.helpers import ( + QATWrapper, + configure_module_default_qconfigs, + prepare_embeddings_qat, +) +from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme +from sparseml.pytorch.utils import get_layer + + +try: + from torch import quantization as torch_quantization + from torch.nn import intrinsic as torch_intrinsic +except Exception: + torch_quantization = None + torch_intrinsic = None + + +__all__ = [ + "convert_module_qat_from_schemes", + "is_qat_helper_module", + "is_quantizable_module", + "set_quantization_schemes", + "set_qconfigs_from_quantization_schemes", + "add_input_activation_quant_wrappers", + "add_output_activation_observers", + "raise_if_torch_quantization_not_available", +] + + +def is_qat_helper_module(module: Module) -> bool: + """ + :param module: module to check + :return: True if module is an instance of a torch QAT helper class + """ + # prefer FakeQuantizeBase which was introduced around torch 1.9 + fake_quantize_class = getattr( + torch_quantization, "FakeQuantizeBase", torch_quantization.FakeQuantize + ) + return isinstance( + module, + ( + fake_quantize_class, + torch_quantization.ObserverBase, + torch_quantization.DeQuantStub, + torch_quantization.QuantStub, + Identity, + ), + ) + + +def is_quantizable_module( + module: Module, + exclude_module_types: Optional[List[str]] = None, +) -> bool: + """ + :param module: module to check + :param exclude_module_types: string names of modules to not include for + quantization. Default None + :return: boolean value if the module is quantizable. 
Module is considered + quantizable if its type is not included in exclude_module_types or + NON_QUANTIZABLE_MODULE_NAMES and + it either has no module children outside of QAT or is a torch qat fused module + """ + # considers any non-excluded "leaf level" (no children) submodule + # to be quantizable as well as torch fused modules + + # add all default excluded module type names + exclude_module_types = set(exclude_module_types or []) + exclude_module_types.update(NON_QUANTIZABLE_MODULE_NAMES) + + module_type_name = module.__class__.__name__ + if module_type_name in exclude_module_types: + return False + + return ( + module_type_name in FUSED_MODULE_NAMES + or all( + # no children (leaf modules) evaluate to all([]) - (True) + is_qat_helper_module(child) + for child in module.children() + ) + or isinstance(module, torch_quantization.QuantWrapper) + ) + + +def set_quantization_schemes( + model: Module, + scheme: QuantizationScheme, + scheme_overrides: Optional[Dict[str, QuantizationScheme]] = None, + ignore: Optional[List[str]] = None, + strict: bool = True, +): + """ + Sets an appropriate `quantization_scheme` to targeted quantizable submodules + + :param model: module to attach QuantizationSchemes to + :param scheme: default scheme to add to a target module unless overwritten + by another scheme + :param scheme_overrides: dictionary of module type names or submodule names + mapped to a quantization scheme to override with. If a submodule matches + to multiple submodule overrides and/or a module type, module type will + take the highest priority followed by the longest matched submodule name + :param ignore: string names of modules type names or submodule names to not include + for quantization. Default None + :param strict: if True, will raise an error if any module types or submodules in + scheme_overrides or ignore are not found in the given module. 
Default True + """ + # default to empty dict + scheme_overrides = scheme_overrides or {} + + if strict: + _validate_set_module_schemes(model, scheme_overrides, ignore) + + # keep mapping of targets for QATWrapper to inject later so module is not modified + # during iteration + wrap_qat_targets = {} # type: Dict[str, QuantizationScheme] + + for submodule_name, submodule in model.named_modules(): + if ignore and _match_submodule_name_or_type(submodule, submodule_name, ignore): + # submodule type or graph section set to ignore, skip + continue + + # override default scheme if necessary + override_key = _match_submodule_name_or_type( + submodule, submodule_name, scheme_overrides + ) + submodule_scheme = ( + scheme if override_key is None else scheme_overrides[override_key] + ) + is_module_type_override = override_key == submodule.__class__.__name__ + + if getattr(submodule, "wrap_qat", False): + # wrap_qat overrides default scheme behavior + wrap_qat_targets[submodule_name] = submodule_scheme + elif is_module_type_override or is_quantizable_module(submodule): + # is base quantizable module or user specifically targeted module type + submodule.quantization_scheme = submodule_scheme + + # inject any targeted QATWrappers + for wraped_module_name, scheme in wrap_qat_targets.items(): + _inject_qat_wrapper(model, wraped_module_name, scheme) + + +def set_qconfigs_from_quantization_schemes(module: Module): + """ + Sets `qconfig` properties to the given module and its submodule + based on any potentially assigned quantization schemes + + :param module: module to set qconfig properties for + """ + for submodule in module.modules(): + if not hasattr(submodule, "quantization_scheme"): + continue + # potentially re-load if scheme is set as dict or str + quantization_scheme = QuantizationScheme.load(submodule.quantization_scheme) + if isinstance(submodule, torch_quantization.QuantWrapper): + submodule.qconfig = quantization_scheme.get_wrapper_qconfig() + submodule.quant.qconfig = submodule.qconfig + else: + submodule.qconfig = quantization_scheme.get_qconfig() + + +def add_input_activation_quant_wrappers(module: Module) -> Module: + """ + Adds QuantWrapper objects to wrap submodules that include quantization + schemes targeting input activations + + :param module: module to add input activation QuantWrappers for + :return: the updated module - necessary in case top level module is wrapped + as in-place modification will not support it + """ + # check if module targets input activation quantization + quantize_activations = ( + hasattr(module, "quantization_scheme") + and (module.quantization_scheme is not None) + and module.quantization_scheme.input_activations is not None + and not isinstance(module, torch.nn.quantized.FloatFunctional) + ) + + if quantize_activations: + # wrap module with a QuantWrapper and assign it the input activation qconfig + quantization_scheme = module.quantization_scheme + module = torch_quantization.QuantWrapper(module) + module.quantization_scheme = quantization_scheme + + # assumes no nested children of a wrapped block need input activation + # does not recurse further in this case + else: + # recurse to module children + for name, child in module.named_children(): + setattr(module, name, add_input_activation_quant_wrappers(child)) + return module + + +def add_output_activation_observers(module: Module): + """ + implementation of torch.quantization add_observers_ that only adds observers + according to attached quantization_scheme properties. 
the existing implementation + (1.9+) includes its own logic for propagating including overriding set qconfigs + for certain activations without the ability to disable this behavior + + :param module: module to add output activation observers to + """ + # adapted from torch/ao/quantization/quantize.py::_add_observer_ + # source: https://github.com/pytorch/pytorch/blob/v1.13.0/torch/ao/quantization/quantize.py#L135 # noqa: E501 + try: + device = next(module.parameters()).device + except StopIteration: + # default to CPU if module has no parameters + device = "cpu" + + def _needs_observer(target_module: Module): + # combines logic from multiple places of original implementation which + # mostly checked for existnace of a qconfig and if the target was a leaf + # module + if not hasattr(target_module, "quantization_scheme") or isinstance( + target_module, torch_quantization.QuantWrapper + ): + # submodule not targeted for quantization, already has attached + # output observer, or is QuantWrapper (quant wrapper delegates to children) + return False + + if hasattr(target_module, "activation_post_process"): + # activation post process is set, only mark for potential overriding + # if it is an identity (this comes up when the property is set for + # later overriding such as FloatFunctional + return isinstance(target_module.activation_post_process, Identity) + + for descendent_module in target_module.modules(): + if descendent_module is target_module: + continue # skip itself + descendent_scheme = getattr(descendent_module, "quantization_scheme", None) + if descendent_scheme is not None and ( + descendent_scheme.output_activations is not None + ): + # a descendent of this module targets output activations, return False + return False + # module has a quantization scheme and no descendents track output activations + return True + + def _observer_forward_hook(self, inp, output): + # reference for output activation observer hook to register + return self.activation_post_process(output) + + def _add_activation_post_process(target_module: Module): + # get output observer + output_observer = submodule.qconfig.activation() + output_observer.to(device) + + # add an activation post process module + target_module.add_module("activation_post_process", output_observer) + + # add hook to call observer after output activation has been returned + handle = target_module.register_forward_hook(_observer_forward_hook) + target_module._forward_hooks.move_to_end(handle.id, last=False) + + for submodule in module.modules(): + if not _needs_observer(submodule): + # submodule not targeted for quantization, already has attached + # output observer, or has a descendent that tracks output activations + continue + + # extract qconfig and observer from qconfig + if not hasattr(submodule, "qconfig"): + # set qconfig from scheme if not already set + set_qconfigs_from_quantization_schemes(submodule) + assert hasattr(submodule, "qconfig") + + # create observer, add as child module, and register hook to call + _add_activation_post_process(submodule) + + +def convert_module_qat_from_schemes(module: Module): + """ + Converts submodules with set quantization_schemes into quantization aware modules + with FakeQuantize modules in the model + + :param module: module to convert to QAT mode + """ + # inject necessary QuantWrappers into the module to apply QAT to + # targeted layer input activations + module = add_input_activation_quant_wrappers(module) + + # set appropriate qconfig properties in submodules + 
set_qconfigs_from_quantization_schemes(module) + + # override any qconfigs set in `configure_qconfigs` function + configure_module_default_qconfigs(module) + + # set modules with proper qconfigs to QAT mode + convert_kwargs = ( + dict(convert_custom_config_dict={}) # do not let torch override any qconfigs + if version.parse(torch.__version__) >= version.parse("1.8.0") + else {} + ) + torch_quantization.convert( + module, + mapping=_get_qat_module_mappings(), + inplace=True, + remove_qconfig=False, + **convert_kwargs, + ) + # re-attach any quantization schemes lost during conversion + _reattach_quantization_schemes(module) + + # add observers for output activations + add_output_activation_observers(module) + + # manual pass to convert relevant Embedding layers + prepare_embeddings_qat(module) + + +def raise_if_torch_quantization_not_available(): + """ + :raises: RuntimeError if the installed torch version does not include + support for quantization aware training + """ + if torch_quantization is None or torch_intrinsic is None: + raise RuntimeError( + "Unable to import package torch.quantization and/or " + "torch.nn.intrinsic. " + "Try upgrading your PyTorch version to use the QuantizationModifier." + ) + + +def _match_submodule_name_or_type( + submodule: Module, submodule_name: str, names_or_types: List[str] +) -> Optional[str]: + # match preferences: + # 1. match module type name + # 2. match the submodule prefix (longest first) + submodule_match = "" + for name_or_type in names_or_types: + name_to_compare = submodule_name[:] + if name_to_compare.startswith("module."): + name_to_compare = name_to_compare[7:] + if name_or_type == submodule.__class__.__name__: + # type match, return type name + return name_or_type + if name_to_compare.startswith(name_or_type) and ( + len(name_or_type) > len(submodule_match) + ): + # match to most specific submodule name + submodule_match = name_or_type + return submodule_match or None # return None if no match + + +def _inject_qat_wrapper( + root_module: Module, + target_submodule_name: str, + quantization_scheme: QuantizationScheme, +): + submodule_name_parts = target_submodule_name.split(".") + parent_name = ".".join(submodule_name_parts[:-1]) + + parent_module = get_layer(parent_name, root_module) + target_module = getattr(parent_module, submodule_name_parts[-1]) + + wrapped_target_module = QATWrapper.from_module(target_module, quantization_scheme) + setattr(parent_module, submodule_name_parts[-1], wrapped_target_module) + + +def _reattach_quantization_schemes(module: Module): + # after torch.prepare_qat is called, quantization scheme properties may be lost + # due to transfer of base module classes to their QAT implementations + # this function uses the reference to the quantization_scheme in the qconfig + # to potentially re-attach the scheme + for submodule in module.modules(): + qconfig = getattr(submodule, "qconfig", None) + if not qconfig or hasattr(submodule, "quantization_scheme"): + # no qconfig, or scheme already set + continue + quantization_scheme = getattr(qconfig, "quantization_scheme", None) + if not quantization_scheme: + continue + submodule.quantization_scheme = quantization_scheme + + +def _get_qat_module_mappings() -> Dict[Module, Module]: + mappings = torch_quantization.quantization_mappings + if not hasattr(mappings, "get_default_qat_module_mappings"): + # legacy + return mappings.get_qat_module_mappings() + # latest + return mappings.get_default_qat_module_mappings() + + +def _validate_set_module_schemes( + model: Module, + 
scheme_overrides: Optional[Dict[str, QuantizationScheme]] = None, + ignore: Optional[List[str]] = None, +): + def _get_unmatched_types_or_names(types_or_names): + unmatched = [] + for type_or_name in types_or_names: + matched = False + for submodule_name, submodule in model.named_modules(): + name_to_compare = submodule_name[:] + if name_to_compare.startswith("module."): + name_to_compare = name_to_compare[7:] + if name_to_compare.startswith(type_or_name) or ( + submodule.__class__.__name__ == type_or_name + ): + matched = True + break + if not matched: + unmatched.append(type_or_name) + return unmatched + + def _build_error_str(property_name, unmatched_values): + return ( + f"{property_name} contains submodule names or module types " + "that do not match to any submodules in the model. " + f"unmatched values: {unmatched_values}" + ) + + unmatched_scheme_overrides = _get_unmatched_types_or_names(scheme_overrides) + if unmatched_scheme_overrides: + raise ValueError( + _build_error_str("scheme_overrides", unmatched_scheme_overrides) + ) + + unmatched_ignore = _get_unmatched_types_or_names(ignore) + if unmatched_ignore: + raise ValueError(_build_error_str("ignore", unmatched_ignore)) diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index 96e85564b90..f2f5eccc7d0 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -103,6 +103,7 @@ "download_framework_model_by_recipe_type", "detach", "adjust_quantization_for_onnx_export", + "get_dependency_order", ] @@ -1215,3 +1216,34 @@ def adjust_quantization_for_onnx_export(module: torch.nn.Module) -> torch.nn.Mod quant.quant_max = 255 # don't update observer since ranges are artificially modified quant.observer_enabled[0] = 0 + + +def get_dependency_order( + layer: Module, subset: Dict, an_input: Tensor, **kwargs +) -> List[str]: + """ + Get a list of module names for a subset of modules in layer, ordered by execution + order, which honors the dependencies in the graph + + :param layer: PyTorch module to calculate dependencies for + :param subset: subset of modules in the layer to include in the ordering + :param an_input: example input to pass through the layer forward pass, used to + determine execution order + + :return: list of module names in execution order + """ + order = [] + + def exe_input(name): + def _exe_input(_, inp, out): + if name in subset: + order.append(name) + + return _exe_input + + # register a hook for each module of interest, will be triggered in execution order + handles = [subset[name].register_forward_hook(exe_input(name)) for name in subset] + layer(an_input, **kwargs) + for h in handles: + h.remove() + return order diff --git a/src/sparseml/transformers/data/__init__.py b/src/sparseml/transformers/data/__init__.py new file mode 100644 index 00000000000..0f81a73351b --- /dev/null +++ b/src/sparseml/transformers/data/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
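Note on the get_dependency_order helper added above: it infers ordering purely from forward hooks, registering one on every module in the given subset, running a single example input through the parent layer, and recording names in the order the hooks fire. A minimal usage sketch follows, assuming this branch is installed; the toy layer and names are illustrative only, not values from this diff:

import torch
from torch.nn import Linear, ReLU, Sequential

from sparseml.pytorch.utils.helpers import get_dependency_order

# toy stand-in for a transformer block; real usage passes a decoder layer
# and the subset of prunable submodules to order
layer = Sequential(Linear(8, 8), ReLU(), Linear(8, 4))
subset = {name: mod for name, mod in layer.named_modules() if isinstance(mod, Linear)}

# hooks fire as each submodule executes, so the returned names honor graph dependencies
order = get_dependency_order(layer, subset, torch.randn(1, 8))
print(order)  # expected: ['0', '2']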
+ +# flake8: noqa +from .base_llm import TransformersDataset +from .c4 import * +from .open_platypus import * +from .ptb import * +from .wikitext import * diff --git a/src/sparseml/transformers/data/base_llm.py b/src/sparseml/transformers/data/base_llm.py new file mode 100644 index 00000000000..61c0b362087 --- /dev/null +++ b/src/sparseml/transformers/data/base_llm.py @@ -0,0 +1,116 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Optional + +import torch +from datasets import load_dataset +from torch.utils.data import Dataset +from transformers import AutoTokenizer + +from sparsezoo.utils.registry import RegistryMixin + + +class TransformersDataset(RegistryMixin, Dataset): + def __init__( + self, + model: str, + seqlen: int, + nsamples: int, + path: str, + name: Optional[str] = None, + seed: int = 0, + split: str = "train", + use_max_tokens: bool = True, + split_percent_to_use: float = 1.0, + shuffle: bool = True, + **kwargs, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + self._nsamples = nsamples + self._seqlen = seqlen + self._use_max_tokens = use_max_tokens + self._split_from_end = False + try: + dataset = load_dataset(path, name, **kwargs, split=split) + except ValueError: + dataset = load_dataset(path, name, **kwargs, split="train") + self._split_from_end = True + + random.seed(seed) + data = list(dataset) + data_to_use = int(split_percent_to_use * len(data)) + self._data = data[-data_to_use:] if self._split_from_end else data[:data_to_use] + if not self._nsamples: + self._nsamples = len(dataset) + if shuffle: + random.shuffle(self._data) + self._data = self._data[: self._nsamples] + + def create_dataloader(self, data, join_on=None): + self.loader = [] + if self._use_max_tokens: + data_idx = 0 + encoder = self.tokenizer(join_on.join(data), return_tensors="pt")[ + "input_ids" + ][0] + while self._nsamples is None or len(self.loader) < self._nsamples: + start_idx = data_idx * self._seqlen + end_idx = start_idx + self._seqlen + if start_idx >= len(encoder): + break + elif end_idx >= len(encoder): + sequence = encoder[start_idx:] + else: + sequence = encoder[start_idx:end_idx] + data_idx += 1 + + tokenized_sample = self._add_end_token(sequence) + tokenized_sample = torch.unsqueeze(tokenized_sample, dim=0) + self.loader.append(tokenized_sample) + if data_idx >= len(data): + break + else: + for sample in data: + tokenized_sample = self.tokenizer( + sample, + truncation=True, + max_length=self._seqlen, + return_tensors="pt", + padding=False, + )["input_ids"][0] + tokenized_sample = self._add_end_token(tokenized_sample) + tokenized_sample = torch.unsqueeze(tokenized_sample, dim=0) + self.loader.append(tokenized_sample) + + def _add_end_token(self, tokenized_sample): + if tokenized_sample[-1] != self.tokenizer.eos_token_id: + if len(tokenized_sample) == self._seqlen: + tokenized_sample[-1] = self.tokenizer.eos_token_id + else: + tokenized_sample = 
torch.concatenate( + ( + tokenized_sample, + torch.tensor((self.tokenizer.eos_token_id,)), + ), + ) + + return tokenized_sample + + def __len__(self): + return len(self.loader) + + def __item__(self, idx): + return self.loader[idx] diff --git a/src/sparseml/transformers/data/c4.py b/src/sparseml/transformers/data/c4.py new file mode 100644 index 00000000000..dbc59151331 --- /dev/null +++ b/src/sparseml/transformers/data/c4.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.nn import Module + +from sparseml.transformers.data.base_llm import TransformersDataset + + +@TransformersDataset.register(name="c4") +class C4(TransformersDataset): + def __init__( + self, + model: Module, + seqlen: int, + nsamples: int, + seed: int = 0, + split: str = "train", + split_percent_to_use: float = 1.0, + ): + kwargs = {"data_files": {split: "en/c4-train.00000-of-01024.json.gz"}} + if split_percent_to_use > 0.2: + split_percent_to_use = 0.2 + super().__init__( + model=model, + seqlen=seqlen, + nsamples=nsamples, + path="allenai/c4", + name="allenai--c4", + seed=seed, + split=split, + split_percent_to_use=split_percent_to_use, + **kwargs, + ) + + processed_data = [sample["text"] for sample in self._data] + self.create_dataloader(processed_data, join_on=" ") diff --git a/src/sparseml/transformers/data/open_platypus.py b/src/sparseml/transformers/data/open_platypus.py new file mode 100644 index 00000000000..4f33a3c42a5 --- /dev/null +++ b/src/sparseml/transformers/data/open_platypus.py @@ -0,0 +1,68 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.nn import Module + +from sparseml.transformers.data.base_llm import TransformersDataset + + +@TransformersDataset.register(name="open_platypus") +class OpenPlatypus(TransformersDataset): + ALPACA_TEMPLATE = { + "prompt_input": "Below is an instruction that describes a task, paired with an " + "input that provides further context. Write a response that appropriately " + "completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n" + "{input}\n\n### Response:\n", + "prompt_no_input": "Below is an instruction that describes a task. 
Write a " + "response that appropriately completes the request.\n\n### Instruction:\n{" + "instruction}\n\n### Response:\n", + } + + def __init__( + self, + model: Module, + seqlen: int, + nsamples: int, + seed: int = 0, + split: str = "train", + split_percent_to_use: float = 1.0, + ): + super().__init__( + model=model, + seqlen=seqlen, + nsamples=nsamples, + path="garage-bAInd/Open-Platypus", + name=None, + seed=seed, + split=split, + use_max_tokens=False, + split_percent_to_use=split_percent_to_use, + ) + + processed_data = [] + for sample in self._data: + if "input" in sample: + processed_sample = self.ALPACA_TEMPLATE["prompt_input"].format( + instruction=sample["instruction"], input=sample["input"] + ) + else: + processed_sample = self.ALPACA_TEMPLATE["prompt_no_input"].format( + instruction=sample["instruction"] + ) + + if "output" in sample: + processed_sample += sample["output"] + processed_data.append(processed_sample) + + self.create_dataloader(processed_data) diff --git a/src/sparseml/transformers/data/ptb.py b/src/sparseml/transformers/data/ptb.py new file mode 100644 index 00000000000..cfb58be5eba --- /dev/null +++ b/src/sparseml/transformers/data/ptb.py @@ -0,0 +1,43 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.nn import Module + +from sparseml.transformers.data.base_llm import TransformersDataset + + +@TransformersDataset.register(name="ptb") +class Ptb(TransformersDataset): + def __init__( + self, + model: Module, + seqlen: int, + nsamples: int, + seed: int = 0, + split: str = "train", + split_percent_to_use: float = 1.0, + ): + super().__init__( + model=model, + seqlen=seqlen, + nsamples=nsamples, + path="ptb_text_only", + name="penn_treebank", + seed=seed, + split=split, + split_percent_to_use=split_percent_to_use, + ) + + processed_data = [sample["sentence"] for sample in self._data] + self.create_dataloader(processed_data, join_on=" ") diff --git a/src/sparseml/transformers/data/wikitext.py b/src/sparseml/transformers/data/wikitext.py new file mode 100644 index 00000000000..82108a78555 --- /dev/null +++ b/src/sparseml/transformers/data/wikitext.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from torch.nn import Module + +from sparseml.transformers.data.base_llm import TransformersDataset + + +@TransformersDataset.register(name="wikitext2") +class WikiText(TransformersDataset): + def __init__( + self, + model: Module, + seqlen: int, + nsamples: int, + seed: int = 0, + split: str = "train", + split_percent_to_use: float = 1.0, + ): + super().__init__( + model=model, + seqlen=seqlen, + nsamples=nsamples, + path="wikitext", + name="wikitext-2-raw-v1", + seed=seed, + split=split, + split_percent_to_use=split_percent_to_use, + shuffle=False, + ) + + join_on = "\n\n" if split == "test" else " " + processed_data = [str(sample["text"]) for sample in self._data] + self.create_dataloader(processed_data, join_on=join_on) diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml new file mode 100644 index 00000000000..5987e220902 --- /dev/null +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -0,0 +1,48 @@ +test_stage: + obcq_modifiers: + QuantizationModifier: + ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"] + post_oneshot_calibration: True + scheme_overrides: + ReLU: + input_activations: null + output_activations: null + LayerNorm: + input_activations: null + output_activations: null + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + quantize: True + percdamp: 0.01 + prunen: 0 + prunem: 0 + compress_layers: [ + "model.decoder.layers.0", + "model.decoder.layers.1", + "model.decoder.layers.2", + "model.decoder.layers.3", + "model.decoder.layers.4", + "model.decoder.layers.5", + "model.decoder.layers.6", + "model.decoder.layers.7", + "model.decoder.layers.8", + "model.decoder.layers.9", + "model.decoder.layers.10", + "model.decoder.layers.11", + "model.decoder.layers.12", + "model.decoder.layers.13", + "model.decoder.layers.14", + "model.decoder.layers.15", + "model.decoder.layers.16", + "model.decoder.layers.17", + "model.decoder.layers.18", + "model.decoder.layers.19", + "model.decoder.layers.20", + "model.decoder.layers.21", + "model.decoder.layers.22", + "model.decoder.layers.23" + ] + target_ids: ["attention_mask"] + layer_prefix: "decoder" \ No newline at end of file diff --git a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml new file mode 100644 index 00000000000..c15b040b598 --- /dev/null +++ b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml @@ -0,0 +1,89 @@ +test_stage: + obcq_modifiers: + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLUActivation + - model.layers.0.mlp.down_proj + - model.layers.1.mlp.down_proj + - model.layers.2.mlp.down_proj + - model.layers.3.mlp.down_proj + - model.layers.4.mlp.down_proj + - model.layers.5.mlp.down_proj + - model.layers.6.mlp.down_proj + - model.layers.7.mlp.down_proj + - model.layers.8.mlp.down_proj + - model.layers.9.mlp.down_proj + - model.layers.10.mlp.down_proj + - model.layers.11.mlp.down_proj + - model.layers.12.mlp.down_proj + - model.layers.13.mlp.down_proj + - model.layers.14.mlp.down_proj + - model.layers.15.mlp.down_proj + - model.layers.16.mlp.down_proj + - model.layers.17.mlp.down_proj + - model.layers.18.mlp.down_proj + - model.layers.19.mlp.down_proj + - model.layers.20.mlp.down_proj + - 
model.layers.21.mlp.down_proj + - model.layers.22.mlp.down_proj + - model.layers.23.mlp.down_proj + - model.layers.24.mlp.down_proj + - model.layers.25.mlp.down_proj + - model.layers.26.mlp.down_proj + - model.layers.27.mlp.down_proj + - model.layers.28.mlp.down_proj + - model.layers.29.mlp.down_proj + - model.layers.30.mlp.down_proj + - model.layers.31.mlp.down_proj + post_oneshot_calibration: True + scheme_overrides: + Embedding: + input_activations: null + weights: + num_bits: 8 + symmetric: False + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + quantize: True + percdamp: 0.01 + prunen: 0 + prunem: 0 + compress_layers: [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + "model.layers.4", + "model.layers.5", + "model.layers.6", + "model.layers.7", + "model.layers.8", + "model.layers.9", + "model.layers.10", + "model.layers.11", + "model.layers.12", + "model.layers.13", + "model.layers.14", + "model.layers.15", + "model.layers.16", + "model.layers.17", + "model.layers.18", + "model.layers.19", + "model.layers.20", + "model.layers.21", + "model.layers.22", + "model.layers.23", + "model.layers.24", + "model.layers.25", + "model.layers.26", + "model.layers.27", + "model.layers.28", + "model.layers.29", + "model.layers.30", + "model.layers.31", + ] + target_ids: ["attention_mask", "position_ids"] \ No newline at end of file diff --git a/src/sparseml/transformers/sparsification/obcq/obcq.py b/src/sparseml/transformers/sparsification/obcq/obcq.py new file mode 100644 index 00000000000..f7ddf595602 --- /dev/null +++ b/src/sparseml/transformers/sparsification/obcq/obcq.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
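Recipes like the two examples above are consumed through a SparseML session. The one_shot entrypoint defined in the next file reduces to roughly the call below; this is a condensed sketch only, with model and calibration_data standing in for an already loaded causal LM and its tokenized calibration samples:

import sparseml.core.session as sml
from sparseml.core.framework import Framework

sml.create_session()
session = sml.active_session()
session.apply(
    framework=Framework.pytorch,
    recipe="example_llama.yaml",  # recipe shown above
    model=model,                  # e.g. a LlamaForCausalLM in eval mode
    calib_data=calibration_data,  # samples from a TransformersDataset loader
    start=0.0,
    device="cuda:0",
    copy_data=False,
)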
+ +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +from torch.nn import Module + +import sparseml.core.session as sml +from sparseml.core.framework import Framework +from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general +from sparseml.optim.helpers import load_recipe_yaml_str +from sparseml.transformers.data import TransformersDataset +from sparseml.transformers.sparsification.obcq.utils.helpers import ( + llama_forward, + opt_forward, +) +from sparseml.transformers.utils.model import SparseCasualLM + + +__all__ = ["one_shot"] + +_LOGGER = logging.getLogger(__name__) +SUPPORTED_DATASETS = ["wikitext2", "ptb", "c4", "open_platypus"] +SUPPORTED_MODELS = ["opt", "llama"] + + +def one_shot( + model_path: str, + dataset_name: str, + num_samples: int = 128, + device: str = "cuda:0", + deploy_dir: Optional[str] = ".", + recipe_file: Optional[str] = None, + eval_data: Optional[str] = None, + do_save: Optional[bool] = False, +) -> Module: + """ + Performs in-place one-shot sparsification/quantization of a model based on the + given recipe and calibration dataset + + :param model_path: Hugging Face stub or local path to the model + :param dataset_name: Dataset to extract calibration data from + :param num_samples: Number of samples to extract from the dataset + :param device: Device (cuda:index or cpu) to use for computation + :param deploy_dir: The output directory to save the model to + :param recipe_file: recipe containing SparseGPT configuration + :param eval_data: dataset to use for perplexity evaluation, or None to skip + :param do_save: whether to save the output model to disk + + :return: PyTorch module with OBCQ applied + """ + + if do_save: + deploy_dir = Path(os.path.join(deploy_dir, "obcq_deployment")) + + if deploy_dir.exists(): + raise RuntimeError(f"deploy_dir={deploy_dir} already exists") + + model_loader_fn = None + forward_fn = None + if "opt" in model_path.lower(): + model_loader_fn = SparseCasualLM.opt_model_from_pretrained + forward_fn = opt_forward + elif "llama" in model_path.lower(): + model_loader_fn = SparseCasualLM.llama_model_from_pretrained + forward_fn = llama_forward + else: + raise ValueError(f"model_path={model_path} should be one of {SUPPORTED_MODELS}") + model = model_loader_fn(model_path) + + if dataset_name not in SUPPORTED_DATASETS: + raise ValueError( + f"dataset_name={dataset_name} should be one of {SUPPORTED_DATASETS}" + ) + dataset = TransformersDataset.load_from_registry( + dataset_name, + model=model_path, + seqlen=model.seqlen, + nsamples=num_samples, + seed=0, + split="train", + ) + calibration_data = dataset.loader + tokenizer = dataset.tokenizer + + sml.create_session() + session = sml.active_session() + session.apply( + framework=Framework.pytorch, + recipe=recipe_file, + model=model, + calib_data=calibration_data, + start=0.0, + device=device, + copy_data=False, + ) + + if do_save: + _save(model, tokenizer, deploy_dir, recipe_file) + if eval_data: + dataset = TransformersDataset.load_from_registry( + eval_data, + model=model_path, + seqlen=model.seqlen, + nsamples=None, + seed=0, + split="test", + split_percent_to_use=0.1 if eval_data == "open_platypus" else 1.0, + ) + test_data = dataset.loader + ppl_eval_general( + forward_fn, model, test_data, device, max_samples_per_iteration=8 + ) + + return model + + +def _save(model, tokenizer, save_path, recipe_path): + model.save_pretrained(save_path) + tokenizer.save_pretrained(save_path) + + _LOGGER.info("Saving output to {}".format(os.path.abspath(save_path))) + recipe_output_path =
os.path.join(save_path, "recipe.yaml") + with open(recipe_output_path, "w") as fp: + fp.write(load_recipe_yaml_str(recipe_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("model", type=str, help="Hugging Face stub of model to load") + parser.add_argument( + "dataset", + type=str, + choices=["wikitext2", "ptb", "c4", "open_platypus"], + help="Name of dataset to extract calibration data from", + ) + parser.add_argument( + "--nsamples", type=int, default=128, help="Number of calibration data samples" + ) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--deploy-dir", type=str, default=".") + parser.add_argument("--recipe", type=str, default=None) + parser.add_argument( + "--eval", type=str, default=None, help="Optional dataset for perplexity eval" + ) + parser.add_argument( + "--save", type=bool, default=False, help="Save output model to disk" + ) + + args = parser.parse_args() + + one_shot( + model_path=args.model, + dataset_name=args.dataset, + deploy_dir=args.deploy_dir, + num_samples=args.nsamples, + device=args.device, + recipe_file=args.recipe, + eval_data=args.eval, + do_save=args.save, + ) diff --git a/src/sparseml/transformers/sparsification/obcq/utils/__init__.py b/src/sparseml/transformers/sparsification/obcq/utils/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/transformers/sparsification/obcq/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/transformers/sparsification/obcq/utils/helpers.py b/src/sparseml/transformers/sparsification/obcq/utils/helpers.py new file mode 100644 index 00000000000..5cce4d21bf2 --- /dev/null +++ b/src/sparseml/transformers/sparsification/obcq/utils/helpers.py @@ -0,0 +1,117 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
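Besides the __main__ command-line interface above, the same flow can be driven programmatically through one_shot. A hedged sketch; the model stub and recipe path are placeholders rather than values from this diff:

from sparseml.transformers.sparsification.obcq.obcq import one_shot

sparsified_model = one_shot(
    model_path="facebook/opt-1.3b",  # any supported OPT or Llama stub
    dataset_name="open_platypus",    # one of SUPPORTED_DATASETS
    num_samples=128,
    device="cuda:0",
    recipe_file="example.yaml",      # e.g. the OPT recipe shown earlier
    eval_data="wikitext2",           # optional perplexity eval after compression
    do_save=True,
)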
+ +from typing import List + +import torch +from torch.nn import Module + +from sparseml.modifiers.obcq.utils.helpers import ( + cache_attention_inputs, + execute_offloaded_module, +) + + +__all__ = ["opt_forward", "llama_forward"] + + +def opt_forward(model: Module, data_loader: List, device: str, nsamples: int = None): + """ + Run a forward pass of OPT, used for perplexity evaluation + + :param model: Pytorch module to run + :param data_loader: data to run through model + :param device: device name to perform computation on + :param nsamples: number of samples of data_loader to run, None to run them all + + :return: logits output of the model + """ + cached_inputs = cache_attention_inputs( + model, data_loader, device, nsamples, ["attention_mask"], "decoder" + ) + buffer = [b[0] for b in cached_inputs.pop("inputs")] + for layer in model.model.decoder.layers: + buffer = execute_offloaded_module( + layer, + buffer, + device, + cached_inputs=cached_inputs, + use_cache=False, + ) + buffer = [b[0] for b in buffer] + + del cached_inputs + torch.cuda.empty_cache() + + if model.model.decoder.final_layer_norm is not None: + buffer = execute_offloaded_module( + model.model.decoder.final_layer_norm, + buffer, + device, + ) + if model.model.decoder.project_out is not None: + buffer = execute_offloaded_module( + model.model.decoder.project_out, + buffer, + device, + ) + logits = execute_offloaded_module( + model.lm_head, + buffer, + device, + ) + + return logits + + +def llama_forward(model: Module, data_loader: List, device: str, nsamples: int = None): + """ + Run a forward pass of Llama, used for perplexity evaluation + + :param model: Pytorch module to run + :param data_loader: data to run through model + :param device: device name to perform computation on + :param nsamples: number of samples of data_loader to run, None to run them all + + :return: logits output of the model + """ + cached_inputs = cache_attention_inputs( + model, data_loader, device, nsamples, ["attention_mask", "position_ids"], None + ) + buffer = [b[0] for b in cached_inputs.pop("inputs")] + for layer in model.model.layers: + buffer = execute_offloaded_module( + layer, + buffer, + device, + cached_inputs=cached_inputs, + use_cache=False, + ) + buffer = [b[0] for b in buffer] + + del cached_inputs + torch.cuda.empty_cache() + + buffer = execute_offloaded_module( + model.model.norm, + buffer, + device, + ) + logits = execute_offloaded_module( + model.lm_head, + buffer, + device, + ) + + return logits diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py index 3ba18d48fe8..3f89f7b4127 100644 --- a/src/sparseml/transformers/utils/model.py +++ b/src/sparseml/transformers/utils/model.py @@ -25,6 +25,8 @@ AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForTokenClassification, + LlamaForCausalLM, + OPTForCausalLM, ) from transformers.file_utils import WEIGHTS_NAME @@ -420,6 +422,47 @@ def _check_tf(model_name_or_path: str): ) +class SparseCasualLM: + """ + Factory class for loading LLMs from the transformers library. 
Currently OPT and + Llama are supported + """ + + @staticmethod + def opt_model_from_pretrained(model_path: str) -> torch.nn.Module: + """ + Load a pretrained OPT model from the specified hugging face path + + :param model_path: hugging face or local path to model + :return: loaded pretrained model + """ + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + model = OPTForCausalLM.from_pretrained(model_path, torch_dtype="auto") + model.eval() + model.seqlen = model.config.max_position_embeddings + return model + + @staticmethod + def llama_model_from_pretrained(model_path: str) -> torch.nn.Module: + """ + Load a pretrained Llama model from the specified hugging face path + + :param model_path: hugging face path to model + :return: loaded pretrained model + """ + model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype="auto") + model.eval() + model.seqlen = model.config.max_position_embeddings + return model + + def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str: """ Get a tokenizer source used for both student and teacher, assuming