Skip to content

Commit

Permalink
UX PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sara Adkins committed Apr 26, 2024
1 parent ca91c4f commit c894305
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/sparseml/transformers/compression/sparsity_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def infer_sparsity_structure() -> str:
return sparsity_structure

@staticmethod
def infer_config_from_model(
def from_pretrained(
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
compress: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def save_pretrained_wrapper(
"calculation of compression statistics set "
"skip_compression_stats=True"
)
sparsity_config = SparsityConfigMetadata.infer_config_from_model(
sparsity_config = SparsityConfigMetadata.from_pretrained(
model, state_dict=state_dict, compress=save_compressed
)

Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/transformers/sparsification/sparse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)
from transformers.file_utils import WEIGHTS_NAME

from compressed_tensors.compressors import infer_compressor_from_model_config
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.quantization import (
QuantizationConfig,
apply_quantization_config,
Expand Down Expand Up @@ -108,7 +108,7 @@ def skip(*args, **kwargs):
)

# determine compression format, if any, from the model config
compressor = infer_compressor_from_model_config(pretrained_model_name_or_path)
compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
quantization_config = QuantizationConfig.from_model_config(
pretrained_model_name_or_path
)
Expand Down

0 comments on commit c894305

Please sign in to comment.