Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchFX] INT8 Weights Compression Support #2891

Merged
merged 92 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
297fdb4
weights compression init
anzr299 Aug 14, 2024
534e294
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 16, 2024
06ca5a3
compression complete
anzr299 Aug 16, 2024
b4b2603
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 19, 2024
c770d2c
Modify graph builder to include support for embedding op
anzr299 Aug 19, 2024
70b00f9
modify function to set new node meta for new module insertion to fx g…
anzr299 Aug 19, 2024
c7fa7f2
Add weights compression support for torch fx
anzr299 Aug 19, 2024
667b8a5
Add test for torch fx weights compression
anzr299 Aug 19, 2024
dca2374
reorder comments
anzr299 Aug 19, 2024
6f693c9
variable names fix
anzr299 Aug 19, 2024
159a615
Fix messages, use transformation for updating weight
anzr299 Aug 19, 2024
7a896d6
Minor mypy fix
anzr299 Aug 19, 2024
0de1d9b
fix set_weight
anzr299 Aug 19, 2024
f9e5d7c
Update torch_fx_backend.py
anzr299 Aug 20, 2024
443dce7
Add embedding metatype for torch fx as a subtype
anzr299 Aug 20, 2024
03d16f8
replace embedding metatype with torch fx subtype in torch fx graph bu…
anzr299 Aug 20, 2024
5226934
1. Adjust the torch fx weights compression backend to use fx embeddin…
anzr299 Aug 20, 2024
3cdb7b3
Update test for weight compression. Include test to see if
anzr299 Aug 20, 2024
28f7053
Fix FX metatype mapping
anzr299 Aug 20, 2024
cb0bf6b
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 20, 2024
8b3c6e2
Add metatypes registry for torch fx specific embedding metatype and c…
anzr299 Aug 20, 2024
79ec939
Add copyright to new torch fx operator_metatypes file
anzr299 Aug 20, 2024
7accaf2
Add weights compression graph test
anzr299 Aug 26, 2024
5b11455
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 26, 2024
1cb55c2
Merge branch 'develop' of https://github.com/anzr299/nncf into fx_com…
anzr299 Aug 26, 2024
71c50ff
pre-commit fix
anzr299 Aug 26, 2024
9f68831
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 28, 2024
2cb0a41
Handle Lora correction in torch fx weights compression
anzr299 Aug 28, 2024
a9c3d57
Add graph test for compressed models in test_models
anzr299 Aug 28, 2024
0172ad1
pre commit fix
anzr299 Aug 28, 2024
f590200
1. Moved Embedding FX metatype from `experimental/torch/fx` to torch …
anzr299 Aug 29, 2024
0c7be62
shared weights support in torch fx graph builder and constant update …
anzr299 Aug 29, 2024
0a1157d
Update tests for more description
anzr299 Aug 29, 2024
0eff5cb
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
c7b9093
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
93ecc4e
add torch fx in supported backends
anzr299 Aug 30, 2024
b6ad458
Remove Compressed reference graphs
anzr299 Aug 30, 2024
e7097bd
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
64b9ba7
add test for shared weights
anzr299 Sep 2, 2024
2665666
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 2, 2024
fb74267
Merge branch 'develop' of https://github.com/anzr299/nncf into fx_com…
anzr299 Sep 2, 2024
287cb2c
pre-commit fix
anzr299 Sep 2, 2024
449f767
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 3, 2024
c79dfc2
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 4, 2024
a10cb68
Add test for shared node decompressor call
anzr299 Sep 4, 2024
1c144a5
update backend supported in docs
anzr299 Sep 4, 2024
c5291b7
pre-commit fix
anzr299 Sep 4, 2024
174fb32
remove todo
anzr299 Sep 4, 2024
45a5274
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 10, 2024
b46d00e
add get_dtype and get_shape methods to torch fx weights compression b…
anzr299 Sep 10, 2024
32f5098
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 12, 2024
1819241
get the updated constant name from graph
anzr299 Sep 16, 2024
8a6b6d5
updated constant name from graph
anzr299 Sep 16, 2024
502c6c3
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 16, 2024
3503674
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 17, 2024
71901c5
update shared constants transformation
anzr299 Sep 20, 2024
bd5ff1f
pre commit fix
anzr299 Sep 20, 2024
b6a29ab
update docs
anzr299 Sep 20, 2024
7dd9782
refactor get weight name and port ids
anzr299 Sep 20, 2024
bbfeff0
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 20, 2024
48848be
update docs from X to Torch FX
anzr299 Sep 20, 2024
20544fd
fix shared weights attribute
anzr299 Sep 20, 2024
60ef615
Merge branch 'fx_compress_weights' of https://github.com/anzr299/nncf…
anzr299 Sep 20, 2024
fb89a4d
Fix Suggestions
anzr299 Sep 20, 2024
002758b
pre commit fix
anzr299 Sep 20, 2024
fe4d390
update is_shared attribute
anzr299 Sep 20, 2024
2ca11f8
Add tests for cosntant update transformation
anzr299 Sep 20, 2024
2be2487
pre commit fix
anzr299 Sep 20, 2024
fc543c9
Add test for edge shape
anzr299 Sep 20, 2024
02861e9
make decompressor name more readible
anzr299 Sep 20, 2024
33afddb
fix model_devices and precision test
anzr299 Sep 20, 2024
15bfeb0
Update is_shared attribute using a one liner
anzr299 Sep 20, 2024
04ed994
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 23, 2024
7683b5d
add test for nncf node is_shared attribute before applying transforma…
anzr299 Sep 23, 2024
fa56e7e
Change code to include _capture_model function for torch FX graph cap…
anzr299 Sep 23, 2024
fd9498a
pre-commit fix
anzr299 Sep 23, 2024
782b509
Fix is_shared attribute test
anzr299 Sep 23, 2024
48d050b
pre- commit fix
anzr299 Sep 23, 2024
3477d7c
add reference for checking shared constant unification transformation
anzr299 Sep 23, 2024
cbc2106
Add synthetic model with embedding to test models and include create …
anzr299 Sep 23, 2024
229517c
add reference graphs
anzr299 Sep 23, 2024
fde56b7
Include assert in shared attribute test
anzr299 Sep 24, 2024
30ff3d2
Fix reference graphs structure
anzr299 Sep 24, 2024
f26a7a0
pre-commit fix
anzr299 Sep 24, 2024
1d0a866
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 24, 2024
49d3dec
Change FXEmbedding metatype to PTAtenEmbeddingMetatype
anzr299 Sep 24, 2024
2e7e639
Move shared constants unification transformation to `apply_quantizati…
anzr299 Sep 24, 2024
817c233
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 25, 2024
26a4ff4
Corrections, comments and refactoring
anzr299 Sep 26, 2024
065bacb
Add seperate error message for dataset attribute
anzr299 Sep 26, 2024
3942d45
fix comments
anzr299 Sep 26, 2024
14096b7
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,52 @@
from nncf.common.graph.layer_attributes import Dtype
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.logging import nncf_logger
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
from nncf.torch.graph.operator_metatypes import PTOperatorMetatype


class GraphConverter:
"""
Builds the NNCFGraph from an torch.fx.GraphModule instance.
"""

def _get_node_subtype(
node: torch.fx.Node, metatype: om.OperatorMetatype, model: torch.fx.GraphModule
) -> om.OperatorMetatype:
"""
Attempts to retrieve correct subtype for the given node.

:param node: Given node.
:param metatype: Given node metatype.
:param model: Target GraphModule instance.
:return: Correct subtype of the given node if it is exist or the original node metatype otherwise.
"""
if metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]:
if len(node.args) < 7:
return metatype
constant_node = node.args[1]
if constant_node.op != "get_attr":
return metatype
weight = get_tensor_constant_from_node(constant_node, model)
out_channels = weight.shape[0]
groups = node.args[6]
if out_channels > 1 and out_channels == groups:
return {
om.PTConv1dMetatype: om.PTDepthwiseConv1dSubtype,
om.PTConv2dMetatype: om.PTDepthwiseConv2dSubtype,
om.PTConv3dMetatype: om.PTDepthwiseConv3dSubtype,
}[metatype]
elif metatype in [om.PTEmbeddingMetatype]:
weight_node = node.args[0]
if weight_node.op == "get_attr":
return om.FXEmbeddingMetatype

return metatype

@staticmethod
def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]:
def _get_node_type_and_metatype(node: torch.fx.Node, model: torch.fx.GraphModule) -> Tuple[str, om.OperatorMetatype]:
"""
Retrieves node's type and metatype.

Expand All @@ -53,6 +88,7 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe
# TODO(dlyakhov): get correct nodes types from this nodes as well
node_type = str(node.target)
node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
node_metatype = GraphConverter._get_node_subtype(node, node_metatype, model)
else:
node_type = node.op
node_metatype = UnknownMetatype
Expand All @@ -74,8 +110,7 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
nncf_graph = PTNNCFGraph()

for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node)

node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model)
nncf_graph.add_nncf_node(
node_name=source_node.name,
node_type=node_type,
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -89,7 +124,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
model, source_node, source_nncf_node, dist_node, idx
)

nncf_graph.add_edge_between_nncf_nodes(
source_nncf_node.node_id,
dist_node_id,
Expand All @@ -115,7 +149,7 @@ def get_edge_params(
:param source_node: Source node in format of torch.fx.Node.
:param source_nncf_node: Source node in format of NNCFNode.
:param dist_node: Distance node in format of torch.fx.Node.
:param output_idx: Output indes of the source_node.
:param output_idx: Output index of the source_node.
:return: Tuple of edge parameters: edge input port id, edge output port id and
edge tensor shape.
"""
Expand Down
44 changes: 44 additions & 0 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@
from nncf.data import Dataset
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import revert_quantization_transformations
from nncf.parameters import CompressWeightsMode
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import SensitivityMetric
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
from nncf.scopes import IgnoredScope

DEFAULT_RANGE_TYPE = "mean_min_max"
Expand Down Expand Up @@ -105,3 +109,43 @@ def quantize_impl(
quantized_model = _disallow_eval_train(quantized_model)

return quantized_model


def compress_weights_impl(
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
model: torch.fx.GraphModule,
dataset: Dataset,
mode: CompressWeightsMode,
ratio: float,
group_size: int,
ignored_scope: IgnoredScope,
all_layers: bool,
sensitivity_metric: SensitivityMetric,
awq: bool,
subset_size: int,
scale_estimation: bool,
gptq: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> torch.fx.GraphModule:
"""
Implementation of the `compress_weights()` method for the Torch Fx backend.
"""

compression_algorithm = WeightCompression(
mode,
ratio,
group_size,
ignored_scope,
all_layers,
sensitivity_metric,
awq,
subset_size,
scale_estimation,
gptq,
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
compressed_model = GraphModule(compressed_model, compressed_model.graph)
compressed_model = _disallow_eval_train(compressed_model)

return compressed_model
21 changes: 14 additions & 7 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
TransformationFNType = Callable[[torch.fx.GraphModule], None]


def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
def _set_new_node_meta(
new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module, model: torch.fx.GraphModule
):
"""
Sets correct meta \"val\" value to the new node.

Expand All @@ -36,7 +38,11 @@ def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target
New node expected to have only one input node.
:param target_module: Module which is being called by the new node.
"""
val = prev_node.meta["val"]
val = (
prev_node.meta["val"]
if prev_node.op not in ["get_attr"]
else get_tensor_constant_from_node(prev_node, model).data
)
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved
val = val if isinstance(val, tuple) else (val,)
retval = []
for t in val:
Expand Down Expand Up @@ -70,16 +76,16 @@ def module_insertion_transformation(model: torch.fx.GraphModule):
target_node = get_graph_node_by_name(graph, target_point.target_node_name)

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
_set_new_node_meta(new_node, target_node, module_to_insert, model)
with graph.inserting_after(target_node):
for user in target_node.users:
for user in list(target_node.users):
if user is new_node:
continue
user.replace_input_with(target_node, new_node)

else:
prev_node = target_node.args[target_point.input_port_id]
_set_new_node_meta(new_node, prev_node, module_to_insert)
_set_new_node_meta(new_node, prev_node, module_to_insert, model)
target_node.replace_input_with(prev_node, new_node)

return module_insertion_transformation
Expand Down Expand Up @@ -131,17 +137,18 @@ def bias_update_transformation(model: torch.fx.GraphModule):
return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor, input_port_id: int = 1) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.

:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:param input_port_id: Port Id of the constant.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id)

return constant_update_transformation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend

self._backend_entity = PTWeightCompressionAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend

self._backend_entity = FXWeightCompressionAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
Expand Down
Loading
Loading