Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchFX] INT8 Weights Compression Support #2891

Merged
merged 92 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 89 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
297fdb4
weights compression init
anzr299 Aug 14, 2024
534e294
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 16, 2024
06ca5a3
compression complete
anzr299 Aug 16, 2024
b4b2603
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 19, 2024
c770d2c
Modify graph builder to include support for embedding op
anzr299 Aug 19, 2024
70b00f9
modify function to set new node meta for new module insertion to fx g…
anzr299 Aug 19, 2024
c7fa7f2
Add weights compression support for torch fx
anzr299 Aug 19, 2024
667b8a5
Add test for torch fx weights compression
anzr299 Aug 19, 2024
dca2374
reorder comments
anzr299 Aug 19, 2024
6f693c9
variable names fix
anzr299 Aug 19, 2024
159a615
Fix messages, use transformation for updating weight
anzr299 Aug 19, 2024
7a896d6
Minor mypy fix
anzr299 Aug 19, 2024
0de1d9b
fix set_weight
anzr299 Aug 19, 2024
f9e5d7c
Update torch_fx_backend.py
anzr299 Aug 20, 2024
443dce7
Add embedding metatype for torch fx as a subtype
anzr299 Aug 20, 2024
03d16f8
replace embedding metatype with torch fx subtype in torch fx graph bu…
anzr299 Aug 20, 2024
5226934
1. Adjust the torch fx weights compression backend to use fx embeddin…
anzr299 Aug 20, 2024
3cdb7b3
Update test for weight compression. Include test to see if
anzr299 Aug 20, 2024
28f7053
Fix FX metatype mapping
anzr299 Aug 20, 2024
cb0bf6b
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 20, 2024
8b3c6e2
Add metatypes registry for torch fx specific embedding metatype and c…
anzr299 Aug 20, 2024
79ec939
Add copyright to new torch fx operator_metatypes file
anzr299 Aug 20, 2024
7accaf2
Add weights compression graph test
anzr299 Aug 26, 2024
5b11455
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 26, 2024
1cb55c2
Merge branch 'develop' of https://github.com/anzr299/nncf into fx_com…
anzr299 Aug 26, 2024
71c50ff
pre-commit fix
anzr299 Aug 26, 2024
9f68831
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Aug 28, 2024
2cb0a41
Handle Lora correction in torch fx weights compression
anzr299 Aug 28, 2024
a9c3d57
Add graph test for compressed models in test_models
anzr299 Aug 28, 2024
0172ad1
pre commit fix
anzr299 Aug 28, 2024
f590200
1. Moved Embedding FX metatype from `experimental/torch/fx` to torch …
anzr299 Aug 29, 2024
0c7be62
shared weights support in torch fx graph builder and constant update …
anzr299 Aug 29, 2024
0a1157d
Update tests for more description
anzr299 Aug 29, 2024
0eff5cb
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
c7b9093
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
93ecc4e
add torch fx in supported backends
anzr299 Aug 30, 2024
b6ad458
Remove Compressed reference graphs
anzr299 Aug 30, 2024
e7097bd
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Aug 30, 2024
64b9ba7
add test for shared weights
anzr299 Sep 2, 2024
2665666
Merge branch 'openvinotoolkit:develop' into develop
anzr299 Sep 2, 2024
fb74267
Merge branch 'develop' of https://github.com/anzr299/nncf into fx_com…
anzr299 Sep 2, 2024
287cb2c
pre-commit fix
anzr299 Sep 2, 2024
449f767
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 3, 2024
c79dfc2
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 4, 2024
a10cb68
Add test for shared node decompressor call
anzr299 Sep 4, 2024
1c144a5
update backend supported in docs
anzr299 Sep 4, 2024
c5291b7
pre-commit fix
anzr299 Sep 4, 2024
174fb32
remove todo
anzr299 Sep 4, 2024
45a5274
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 10, 2024
b46d00e
add get_dtype and get_shape methods to torch fx weights compression b…
anzr299 Sep 10, 2024
32f5098
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 12, 2024
1819241
get the updated constant name from graph
anzr299 Sep 16, 2024
8a6b6d5
updated constant name from graph
anzr299 Sep 16, 2024
502c6c3
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 16, 2024
3503674
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 17, 2024
71901c5
update shared constants transformation
anzr299 Sep 20, 2024
bd5ff1f
pre commit fix
anzr299 Sep 20, 2024
b6a29ab
update docs
anzr299 Sep 20, 2024
7dd9782
refactor get weight name and port ids
anzr299 Sep 20, 2024
bbfeff0
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 20, 2024
48848be
update docs from X to Torch FX
anzr299 Sep 20, 2024
20544fd
fix shared weights attribute
anzr299 Sep 20, 2024
60ef615
Merge branch 'fx_compress_weights' of https://github.com/anzr299/nncf…
anzr299 Sep 20, 2024
fb89a4d
Fix Suggestions
anzr299 Sep 20, 2024
002758b
pre commit fix
anzr299 Sep 20, 2024
fe4d390
update is_shared attribute
anzr299 Sep 20, 2024
2ca11f8
Add tests for cosntant update transformation
anzr299 Sep 20, 2024
2be2487
pre commit fix
anzr299 Sep 20, 2024
fc543c9
Add test for edge shape
anzr299 Sep 20, 2024
02861e9
make decompressor name more readible
anzr299 Sep 20, 2024
33afddb
fix model_devices and precision test
anzr299 Sep 20, 2024
15bfeb0
Update is_shared attribute using a one liner
anzr299 Sep 20, 2024
04ed994
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 23, 2024
7683b5d
add test for nncf node is_shared attribute before applying transforma…
anzr299 Sep 23, 2024
fa56e7e
Change code to include _capture_model function for torch FX graph cap…
anzr299 Sep 23, 2024
fd9498a
pre-commit fix
anzr299 Sep 23, 2024
782b509
Fix is_shared attribute test
anzr299 Sep 23, 2024
48d050b
pre- commit fix
anzr299 Sep 23, 2024
3477d7c
add reference for checking shared constant unification transformation
anzr299 Sep 23, 2024
cbc2106
Add synthetic model with embedding to test models and include create …
anzr299 Sep 23, 2024
229517c
add reference graphs
anzr299 Sep 23, 2024
fde56b7
Include assert in shared attribute test
anzr299 Sep 24, 2024
30ff3d2
Fix reference graphs structure
anzr299 Sep 24, 2024
f26a7a0
pre-commit fix
anzr299 Sep 24, 2024
1d0a866
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 24, 2024
49d3dec
Change FXEmbedding metatype to PTAtenEmbeddingMetatype
anzr299 Sep 24, 2024
2e7e639
Move shared constants unification transformation to `apply_quantizati…
anzr299 Sep 24, 2024
817c233
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 25, 2024
26a4ff4
Corrections, comments and refactoring
anzr299 Sep 26, 2024
065bacb
Add seperate error message for dataset attribute
anzr299 Sep 26, 2024
3942d45
fix comments
anzr299 Sep 26, 2024
14096b7
Merge branch 'openvinotoolkit:develop' into fx_compress_weights
anzr299 Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Weights Compression

[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with, and PyTorch is also supported.
[OpenVINO](https://github.com/openvinotoolkit/openvino) is the preferred backend to run Weights Compression with. PyTorch and Torch FX are also supported.

### The algorithm description

Expand Down Expand Up @@ -800,7 +800,7 @@ Accuracy/footprint trade-off for `microsoft/Phi-3-mini-4k-instruct`:

### Limitations

- The algorithm is supported for OpenVINO and PyTorch models.
- The algorithm is supported for OpenVINO, PyTorch and Torch FX models.
- The compression applies in-place.
- The compressed model is not trainable.
- INT4_SYM, INT4_ASYM, NF4 and E2M1 modes, grouped quantization and mixed precision selection is available for OpenVINO backend only.
Expand Down
30 changes: 24 additions & 6 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Counter
from typing import Tuple

import torch.fx
Expand Down Expand Up @@ -64,6 +65,22 @@ def _get_layer_attributes(
)
return None

def _map_fx_unique_metatypes(node: torch.fx.Node, metatype: om.OperatorMetatype) -> om.OperatorMetatype:
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
"""
Attempts to retrieve correct subtype for the given node.

:param node: Given node.
:param metatype: Given node metatype.
:param model: Target GraphModule instance.
:return: Correct FX metatype of the given node if it is exist or the original node metatype otherwise.
"""
if metatype in [om.PTEmbeddingMetatype]:
weight_node = node.args[0]
if weight_node.op == "get_attr":
return om.PTAtenEmbeddingMetatype

return metatype

@staticmethod
def _get_node_type_and_metatype(
node: torch.fx.Node, model: torch.fx.GraphModule
Expand Down Expand Up @@ -115,16 +132,18 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
:param model: torch fx GraphModule.
:return: NNCFGraph.
"""

nncf_graph = PTNNCFGraph()

const_targets_counter = Counter([node.target for node in model.graph.nodes if node.op == "get_attr"])
for source_node in model.graph.nodes:
node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node, model)
node_metatype = GraphConverter._map_fx_unique_metatypes(source_node, node_metatype)
is_shared_node = source_node.op in ("get_attr",) and (
const_targets_counter[source_node.target] > 1 or len(source_node.users) > 1
)

nncf_graph.add_nncf_node(
node_name=source_node.name,
node_type=node_type,
node_metatype=node_metatype,
node_name=source_node.name, node_type=node_type, node_metatype=node_metatype, is_shared=is_shared_node
)

for source_node in model.graph.nodes:
Expand All @@ -134,7 +153,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
model, source_node, source_nncf_node, dist_node, idx
)

nncf_graph.add_edge_between_nncf_nodes(
source_nncf_node.node_id,
dist_node_id,
Expand All @@ -160,7 +178,7 @@ def get_edge_params(
:param source_node: Source node in format of torch.fx.Node.
:param source_nncf_node: Source node in format of NNCFNode.
:param dist_node: Distance node in format of torch.fx.Node.
:param output_idx: Output indes of the source_node.
:param output_idx: Output index of the source_node.
:return: Tuple of edge parameters: edge input port id, edge output port id and
edge tensor shape.
"""
Expand Down
50 changes: 49 additions & 1 deletion nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@
from nncf.data import Dataset
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import revert_quantization_transformations
from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation
from nncf.parameters import CompressWeightsMode
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import SensitivityMetric
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
from nncf.scopes import IgnoredScope

DEFAULT_RANGE_TYPE = "mean_min_max"
Expand All @@ -49,7 +54,7 @@ def quantize_impl(
model_type: Optional[ModelType] = None,
ignored_scope: Optional[IgnoredScope] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
) -> torch.nn.Module:
) -> torch.fx.GraphModule:
"""
Implementation of the `quantize()` method for the Torch FX backend.
"""
Expand Down Expand Up @@ -103,3 +108,46 @@ def quantize_impl(
quantized_model = _disallow_eval_train(quantized_model)

return quantized_model


def compress_weights_impl(
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
model: torch.fx.GraphModule,
dataset: Dataset,
mode: CompressWeightsMode,
ratio: float,
group_size: int,
ignored_scope: IgnoredScope,
all_layers: bool,
sensitivity_metric: SensitivityMetric,
awq: bool,
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> torch.fx.GraphModule:
"""
Implementation of the `compress_weights()` method for the Torch Fx backend.
"""

compression_algorithm = WeightCompression(
mode,
ratio,
group_size,
ignored_scope,
all_layers,
sensitivity_metric,
awq,
subset_size,
scale_estimation,
gptq,
lora_correction,
advanced_parameters,
)
shared_constants_unification_transformation(model)
graph = NNCFGraphFactory.create(model)
compressed_model = compression_algorithm.apply(model, graph, dataset=dataset)
compressed_model = GraphModule(compressed_model, compressed_model.graph)
compressed_model = _disallow_eval_train(compressed_model)

return compressed_model
61 changes: 47 additions & 14 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
TransformationFNType = Callable[[torch.fx.GraphModule], None]


def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
def _set_new_node_meta(
new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module, model: torch.fx.GraphModule
):
"""
Sets correct meta \"val\" value to the new node.

Expand All @@ -37,7 +39,11 @@ def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target
New node expected to have only one input node.
:param target_module: Module which is being called by the new node.
"""
val = prev_node.meta["val"]
val = (
prev_node.meta["val"]
if prev_node.op not in ["get_attr"]
else get_tensor_constant_from_node(prev_node, model).data
)
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved
val = val if isinstance(val, tuple) else (val,)
retval = []
for t in val:
Expand Down Expand Up @@ -71,16 +77,16 @@ def module_insertion_transformation(model: torch.fx.GraphModule):
target_node = get_graph_node_by_name(graph, target_point.target_node_name)

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
_set_new_node_meta(new_node, target_node, module_to_insert, model)
with graph.inserting_after(target_node):
for user in target_node.users:
for user in list(target_node.users):
if user is new_node:
continue
user.replace_input_with(target_node, new_node)

else:
prev_node = target_node.args[target_point.input_port_id]
_set_new_node_meta(new_node, prev_node, module_to_insert)
_set_new_node_meta(new_node, prev_node, module_to_insert, model)
target_node.replace_input_with(prev_node, new_node)

return module_insertion_transformation
Expand Down Expand Up @@ -136,17 +142,43 @@ def bias_update_transformation(model: torch.fx.GraphModule):
return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def shared_constants_unification_transformation(model: torch.fx.GraphModule):
"""
checks fx graph for shared constants and eliminates redundant
shared constant while keeping only the first instance of the constant node.
This unification transformation is cruicial since the current algorithms(min_max, solver, BC, etc.)
for torch fx do not utilize the is_shared attribute of nodes for shared constants.

:param model: Target Torch FX GraphModule
:return: Transformation which attaches shared constants to nodes and removes redundant constants.
"""
prev_targets = {}

for source_node in model.graph.nodes:
dist_node = list(source_node.users)
if source_node.target in prev_targets and source_node.op in ("get_attr",):
dist_node[0].replace_input_with(source_node, prev_targets[source_node.target])
else:
prev_targets[source_node.target] = source_node

model.graph.eliminate_dead_code()
model.recompile()


def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int = 1
) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.

:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:param input_port_id: Port Id of the constant.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id)

return constant_update_transformation

Expand All @@ -161,9 +193,6 @@ def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value:
:param input_port_id: Target constant input port id.
"""
graph = model.graph
with graph.inserting_before(node):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args = list(node.args)
# A bias node suppose to have constant on the second input port.
if args[input_port_id].op != "get_attr":
Expand All @@ -174,11 +203,14 @@ def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value:

# Update metadata of the new constant node.
previous_const = args[input_port_id]
new_constant.meta = copy(previous_const.meta)
new_constant.meta["val"] = value
consumer_nodes = list(previous_const.users)
# This list of consumer nodes will always be topologically sorted
# To ensure the updated node has the right order,
# we insert constant node before the node placed at the highest order in topological order.
with graph.inserting_before(consumer_nodes[0]):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args[input_port_id] = new_constant
node.args = tuple(args)
previous_const.replace_all_uses_with(new_constant, propagate_meta=True)
graph.eliminate_dead_code()


Expand Down Expand Up @@ -509,6 +541,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)
shared_constants_unification_transformation(model)


def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(

@property
def available_backends(self) -> List[BackendType]:
return [BackendType.OPENVINO, BackendType.TORCH]
return [BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]

def _set_backend_entity(self, model: TModel) -> None:
"""
Expand All @@ -152,6 +152,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTWeightCompressionAlgoBackend

self._backend_entity = PTWeightCompressionAlgoBackend()
elif model_backend == BackendType.TORCH_FX:
from nncf.quantization.algorithms.weight_compression.torch_fx_backend import FXWeightCompressionAlgoBackend

self._backend_entity = FXWeightCompressionAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend):
TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
}
MATMUL_METATYPES = [om.PTLinearMetatype, om.PTMatMulMetatype, om.PTAddmmMetatype]
EMBEDDING_METATYPES = [om.PTEmbeddingMetatype]
EMBEDDING_METATYPES = [om.PTEmbeddingMetatype, om.PTAtenEmbeddingMetatype]
CONVOLUTION_METATYPES = [
om.PTConv1dMetatype,
om.PTConv2dMetatype,
Expand Down
Loading
Loading