From c89430530c256b0ae0e10486964f5cc96dec7b4a Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Fri, 26 Apr 2024 15:16:18 +0000
Subject: [PATCH] UX pr comments

---
 src/sparseml/transformers/compression/sparsity_config.py      | 2 +-
 .../transformers/sparsification/compressed_tensors_utils.py   | 2 +-
 src/sparseml/transformers/sparsification/sparse_model.py      | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/sparseml/transformers/compression/sparsity_config.py b/src/sparseml/transformers/compression/sparsity_config.py
index b04edf333c3..958ddc2b738 100644
--- a/src/sparseml/transformers/compression/sparsity_config.py
+++ b/src/sparseml/transformers/compression/sparsity_config.py
@@ -68,7 +68,7 @@ def infer_sparsity_structure() -> str:
         return sparsity_structure
 
     @staticmethod
-    def infer_config_from_model(
+    def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
diff --git a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py
index ee314bb92d3..b6852535a2c 100644
--- a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py
@@ -140,7 +140,7 @@ def save_pretrained_wrapper(
                     "calculation of compression statistics set "
                     "skip_compression_stats=True"
                 )
-                sparsity_config = SparsityConfigMetadata.infer_config_from_model(
+                sparsity_config = SparsityConfigMetadata.from_pretrained(
                     model, state_dict=state_dict, compress=save_compressed
                 )
diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py
index 2182ee4c33f..d14ec60e9e8 100644
--- a/src/sparseml/transformers/sparsification/sparse_model.py
+++ b/src/sparseml/transformers/sparsification/sparse_model.py
@@ -30,7 +30,7 @@
 )
 from transformers.file_utils import WEIGHTS_NAME
 
-from compressed_tensors.compressors import infer_compressor_from_model_config
+from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.quantization import (
     QuantizationConfig,
     apply_quantization_config,
@@ -108,7 +108,7 @@ def skip(*args, **kwargs):
     )
 
     # determine compression format, if any, from the model config
-    compressor = infer_compressor_from_model_config(pretrained_model_name_or_path)
+    compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
     quantization_config = QuantizationConfig.from_model_config(
         pretrained_model_name_or_path
     )