Skip to content

Commit

Permalink
Merge branch 'main' into add-python3.11-support
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli authored Oct 13, 2023
2 parents 64aa6ad + a2a7095 commit 85a90e3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
15 changes: 10 additions & 5 deletions src/sparseml/pytorch/sparsification/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,21 +357,26 @@ def _match_submodule_name_or_type(
submodule: Module, submodule_name: str, names_or_types: List[str]
) -> Optional[str]:
# match preferences:
# 1. match module type name
# 2. match the submodule prefix (longest first)
# 1. match the submodule prefix (longest first)
# 2. match module type name
submodule_match = ""
for name_or_type in names_or_types:
name_to_compare = submodule_name[:]
if name_to_compare.startswith("module."):
name_to_compare = name_to_compare[7:]
if name_or_type == submodule.__class__.__name__:
# type match, return type name
return name_or_type
if name_to_compare.startswith(name_or_type) and (
len(name_or_type) > len(submodule_match)
):
# match to most specific submodule name
submodule_match = name_or_type

# If didn't find prefix, try to match to match type
if not submodule_match:
for name_or_type in names_or_types:
if name_or_type == submodule.__class__.__name__:
# type match, return type name
return name_or_type

return submodule_match or None # return None if no match


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
QuantizationScheme,
)
from sparseml.pytorch.sparsification.quantization.quantize import (
_match_submodule_name_or_type,
is_qat_helper_module,
is_quantizable_module,
)
Expand Down Expand Up @@ -66,17 +67,16 @@ def _assert_observers_eq(observer_1, observer_2):
_assert_observers_eq(qconfig_1.weight, qconfig_2.weight)


def _test_quantized_module(base_model, modifier, module, name):
def _test_quantized_module(base_model, modifier, module, name, override_key):
# check quant scheme and configs are set
quantization_scheme = getattr(module, "quantization_scheme", None)
qconfig = getattr(module, "qconfig", None)
assert quantization_scheme is not None
assert qconfig is not None

# if this module's scheme is overridden by scheme_overrides, check scheme set correctly
module_type_name = module.__class__.__name__
if module_type_name in modifier.scheme_overrides:
expected_scheme = modifier.scheme_overrides[module_type_name]
if override_key is not None:
expected_scheme = modifier.scheme_overrides[override_key]
assert quantization_scheme == expected_scheme

is_quant_wrapper = isinstance(module, torch_quantization.QuantWrapper)
Expand Down Expand Up @@ -148,7 +148,12 @@ def _test_qat_applied(modifier, model):
_test_qat_wrapped_module(model, name)
elif is_quantizable:
# check each target module is quantized
_test_quantized_module(model, modifier, module, name)
override_key = _match_submodule_name_or_type(
module,
name,
list(modifier.scheme_overrides.keys()),
)
_test_quantized_module(model, modifier, module, name, override_key)
else:
# check all non-target modules are not quantized
assert not hasattr(module, "quantization_scheme")
Expand Down

0 comments on commit 85a90e3

Please sign in to comment.