[Experimental][TorchFX] quantize_pt2e + X86Quantizer introduction #3121

Open · wants to merge 6 commits into develop

Changes from 4 commits
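Taken together, the new files let a stock torch.ao `Quantizer` (such as `X86InductorQuantizer`) drive NNCF's post-training quantization pipeline over a captured `torch.fx.GraphModule`. A minimal usage sketch of how the pieces compose; the `quantize_pt2e` entry point itself is not part of this diff, so model capture, NNCFGraph construction, and the `fx_quantizer` module path are assumptions:

```python
import torch
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

import nncf
from nncf.experimental.common.quantization.algorithms.post_training.algorithm import ExperimentalPostTrainingQuantization

# Module path assumed for illustration; only base_quantizer's location is visible in this diff.
from nncf.experimental.common.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer

# Configure a stock torch.ao quantizer for x86.
torch_quantizer = X86InductorQuantizer()
torch_quantizer.set_global(get_default_x86_inductor_quantization_config())

fx_model: torch.fx.GraphModule = ...     # captured model, e.g. via torch.export
nncf_graph = ...                         # NNCFGraph built for fx_model by NNCF's FX backend
calibration_dataset = nncf.Dataset(...)  # calibration samples

# The adapter exposes the torch.ao quantizer through the NNCFQuantizer interface.
algo = ExperimentalPostTrainingQuantization(
    quantizer=NNCFFXQuantizer(torch_quantizer),
    subset_size=300,
)
quantized_model = algo.apply(fx_model, nncf_graph, dataset=calibration_dataset)
```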
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,102 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Callable, List, Optional, TypeVar

from nncf import Dataset
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.quantization.algorithms.post_training.pipeline import experimental_create_ptq_pipeline
from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters
from nncf.quantization.algorithms.algorithm import Algorithm

TModel = TypeVar("TModel")
TPass = Callable[[TModel], TModel]


class ExperimentalPostTrainingQuantization(Algorithm):
"""
    Implements the experimental post-training quantization algorithm, which includes:
    1) SmoothQuant
    2) MinMaxRangeInit
    3) FastBiasCorrection or BiasCorrection
"""

def __init__(
self,
quantizer: NNCFQuantizer,
subset_size: int = 300,
fast_bias_correction: Optional[bool] = True,
smooth_quant: bool = False,
bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
):
"""
        :param quantizer: NNCFQuantizer to use in the MinMaxRangeInit algorithm.
:param subset_size: Size of a subset to calculate activations
statistics used for quantization.
:param fast_bias_correction: Setting this option to `False` enables a different
bias correction method which is more accurate, in general, and takes
more time but requires less memory. None disables the bias correction algorithm.
:param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
:param bias_correction_params: Contains advanced parameters for fine-tuning bias correction algorithm.
:param smooth_quant_params: Contains advanced alpha parameters for SmoothQuant algorithm.
:param activations_range_estimator_params: Contains parameters for estimating the range
of activations of the model.
:param weights_range_estimator_params: Contains parameters for estimating the range
of weights of the model.
"""
self._pipeline = experimental_create_ptq_pipeline(
quantizer=quantizer,
subset_size=subset_size,
fast_bias_correction=fast_bias_correction,
smooth_quant=smooth_quant,
bias_correction_params=bias_correction_params,
smooth_quant_params=smooth_quant_params,
activations_range_estimator_params=activations_range_estimator_params,
weights_range_estimator_params=weights_range_estimator_params,
)

@property
def available_backends(self) -> List[BackendType]:
backends = set(BackendType)
for algorithm in itertools.chain.from_iterable(self._pipeline.pipeline_steps):
backends = backends.intersection(algorithm.available_backends)
return list(backends)

def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
return self._pipeline.get_statistic_points_for_step(0, model, graph)

def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
) -> TModel:
if dataset is None and len(self._pipeline.pipeline_steps) > 1:
raise ValueError(
"A dataset is required for the post-training quantization "
"algorithm to collect statistics for intermediate models."
)

step_index_to_statistics = None
if statistic_points:
step_index_to_statistics = {0: statistic_points}

return self._pipeline.run_from_step(model, dataset, graph, 0, step_index_to_statistics)
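`get_statistic_points` exposes the statistic points of the first pipeline step, so a caller can collect statistics once and replay them on a later `apply` call (the method maps them to step 0 internally). A hedged sketch, assuming `model`, `graph`, `dataset`, and an `NNCFQuantizer` instance `my_quantizer` already exist:

```python
algo = ExperimentalPostTrainingQuantization(quantizer=my_quantizer)

# Ask the pipeline which statistic points its first step needs...
points = algo.get_statistic_points(model, graph)
# ...calibrate them externally (e.g. via a backend StatisticsAggregator), then
# hand them back so apply() skips statistics re-collection for step 0:
quantized_model = algo.apply(model, graph, statistic_points=points, dataset=dataset)
```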
@@ -0,0 +1,113 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, TypeVar

from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer
from nncf.experimental.common.quantization.algorithms.range_estimator.range_estimator import MinMaxRangeEstimator
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import RangeEstimatorParameters
from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD
from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.pipeline import Pipeline
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant

TModel = TypeVar("TModel")


def experimental_create_ptq_pipeline(
quantizer: NNCFQuantizer,
subset_size: int = 300,
fast_bias_correction: Optional[bool] = True,
smooth_quant: bool = False,
bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
) -> Pipeline:
"""
Creates an experimental post-training quantization pipeline.

The experimental post-training quantization pipeline includes the following steps:
1) SmoothQuant
2) MinMaxRangeInit
3) FastBiasCorrection or BiasCorrection

    :param quantizer: NNCFQuantizer to use in the MinMaxRangeInit algorithm.
:param subset_size: Size of a subset to calculate activations
statistics used for quantization.
:param fast_bias_correction: Setting this option to `False` enables a different
bias correction method which is more accurate, in general, and takes
more time but requires less memory. None disables the bias correction algorithm.
:param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
:param bias_correction_params: Contains advanced parameters for fine-tuning bias correction algorithm.
:param smooth_quant_params: Contains advanced alpha parameters for SmoothQuant algorithm.
:param activations_range_estimator_params: Contains parameters for estimating the range
of activations of the model.
:param weights_range_estimator_params: Contains parameters for estimating the range
of weights of the model.
:return: An experimental post-training quantization pipeline.
"""

# Build the post-training quantization pipeline.
pipeline_steps = []

if smooth_quant_params is None:
smooth_quant_params = AdvancedSmoothQuantParameters()

    if smooth_quant and (smooth_quant_params.convolution >= 0 or smooth_quant_params.matmul >= 0):
alpha_map = {"convolution": smooth_quant_params.convolution, "matmul": smooth_quant_params.matmul}
pipeline_steps.append([SmoothQuant(subset_size, False, alpha_map=alpha_map)])

    # Add the `MinMaxRangeEstimator` algorithm as the next pipeline step.
pipeline_steps.append(
[
MinMaxRangeEstimator(
quantizer=quantizer,
subset_size=subset_size,
inplace_statistics=False,
activations_range_estimator_params=activations_range_estimator_params,
weights_range_estimator_params=weights_range_estimator_params,
)
]
)

if fast_bias_correction is not None:
        # Add `FastBiasCorrection` or `BiasCorrection` as an additional algorithm
        # inside the same pipeline step as `MinMaxRangeEstimator`, executed after
        # it.
if fast_bias_correction:
threshold = FAST_BIAS_CORRECTION_THRESHOLD
bias_correction_subset_size = subset_size
bias_correction_cls = FastBiasCorrection
else:
threshold = BIAS_CORRECTION_THRESHOLD
bias_correction_subset_size = max(int(subset_size * 0.2), 1)
bias_correction_cls = BiasCorrection

if bias_correction_params is None:
bias_correction_params = AdvancedBiasCorrectionParameters()

if bias_correction_params.threshold is not None:
threshold = bias_correction_params.threshold

pipeline_steps[-1].append(
bias_correction_cls(
bias_correction_subset_size,
threshold,
bias_correction_params.apply_for_all_nodes,
)
)

return Pipeline(pipeline_steps)
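For orientation, a `Pipeline` is a list of steps, each step being a list of algorithms that share one statistics-collection pass. With the defaults above (`smooth_quant=False`, `fast_bias_correction=True`) a single step is expected. A small sketch, assuming an `NNCFQuantizer` instance `my_quantizer`:

```python
pipeline = experimental_create_ptq_pipeline(quantizer=my_quantizer)

# Expected default layout: one step with MinMaxRangeEstimator followed by
# FastBiasCorrection, both fed by the same statistics pass.
for step_index, step in enumerate(pipeline.pipeline_steps):
    print(step_index, [type(algorithm).__name__ for algorithm in step])
# -> 0 ['MinMaxRangeEstimator', 'FastBiasCorrection']
```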
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TypeVar

from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup

TModel = TypeVar("TModel")


class NNCFQuantizer(ABC):
@abstractmethod
def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
"""
Builds SingleConfigQuantizerSetup for the given model.

        :param model: Backend-specific model for which the quantization target points are sought.
:param nncf_graph: NNCFGraph instance.
:return: SingleConfigQuantizerSetup for the given model.
"""
@@ -0,0 +1,129 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
from copy import deepcopy
from typing import Dict, Tuple, Union

import torch
import torch.fx
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase
from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec

import nncf
from nncf.common.graph.graph import NNCFGraph
from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer

EdgeOrNode = Union[Tuple[torch.fx.Node, torch.fx.Node], torch.fx.Node]


class NNCFFXQuantizer(NNCFQuantizer):
def __init__(self, quantizer: Quantizer):
self._quantizer = quantizer

def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup:
        annotated_model = deepcopy(model)

        annotated_model = self._quantizer.transform_for_annotation(annotated_model)
        self._quantizer.annotate(annotated_model)
        self._quantizer.validate(annotated_model)
        return self.get_quantizer_config_from_annotated_model(annotated_model)

@staticmethod
    def get_quantizer_config_from_annotated_model(annotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated_model)

q_map = defaultdict(list)
for edge, qspec in edge_or_node_to_qspec.items():
if not isinstance(edge, tuple):
continue
from_n, to_n = edge
q_map[from_n].append(to_n)

q_setup = SingleConfigQuantizerSetup()
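        # Every consumer reached from `from_n` is assumed to share one qspec,
        # so it is read once from the first edge.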
for from_n, to_nodes in q_map.items():
to_n = to_nodes[0]
qspec = edge_or_node_to_qspec[(from_n, to_n)]
if qspec is None:
continue
if isinstance(qspec, QuantizationSpec):
if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
per_channel = True
elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
per_channel = False
else:
raise nncf.InternalError(f"Unknown qscheme: {qspec.qscheme}")
                # torch.int8 ranges are signed; torch.uint8 ranges are unsigned.
                signed = qspec.dtype is torch.int8
mode = (
QuantizationMode.SYMMETRIC
if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
else QuantizationMode.ASYMMETRIC
)
qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel)
qps = []
                # A constant input that is not wired to the activation port (0) is treated as a weight.
if from_n.op == "get_attr" and to_n.args.index(from_n) != 0:
qip = WeightQuantizationInsertionPoint(to_n.name)
qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])
qps.append(qp)
else:
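                    # Quantize the producer's output once only when every user of
                    # `from_n` consumes the quantized value; otherwise quantize
                    # each consuming edge at its input port.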
if len(from_n.users) == len(to_nodes):
qip = ActivationQuantizationInsertionPoint(from_n.name)
qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes])
qps.append(qp)
else:
for to_n_ in to_nodes:
input_port_id = to_n_.args.index(from_n)
qip = ActivationQuantizationInsertionPoint(to_n_.name, input_port_id)
qp = SingleConfigQuantizationPoint(qip, qconfig, [to_n_.name])
qps.append(qp)

for qp in qps:
q_setup.add_independent_quantization_point(qp)

elif isinstance(qspec, SharedQuantizationSpec):
pass
else:
raise nncf.InternalError(f"Unknown torch.ao quantization spec: {qspec}")

return q_setup


def _get_edge_or_node_to_qspec(
model: torch.fx.GraphModule,
) -> Dict[EdgeOrNode, QuantizationSpecBase]:
"""
Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.

:param model: torch.fx.GraphModule instance.
:return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
"""
edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
for n in model.graph.nodes:
if hasattr(n, "meta") and "quantization_annotation" in n.meta:
qa = n.meta["quantization_annotation"]
for input_to_n, qspec in qa.input_qspec_map.items():
input_edge = (input_to_n, n)
edge_or_node_to_qspec[input_edge] = qspec
if qa.output_qspec is not None:
output_node = n
qspec = qa.output_qspec
edge_or_node_to_qspec[output_node] = qspec
return edge_or_node_to_qspec
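The traversal above consumes the standard torch.ao annotation protocol: a quantizer stores a `QuantizationAnnotation` under `node.meta["quantization_annotation"]`, with per-input specs in `input_qspec_map` and an optional `output_qspec`. A minimal sketch of the metadata shape this function reads; `input_node` and `conv_node` are hypothetical placeholders for `torch.fx.Node` objects:

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec

act_qspec = QuantizationSpec(
    dtype=torch.int8,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
    input_qspec_map={input_node: act_qspec},  # quantizes the (input_node, conv_node) edge
    output_qspec=act_qspec,                   # quantizes conv_node's output
)
# _get_edge_or_node_to_qspec(model) would then contain:
#   {(input_node, conv_node): act_qspec, conv_node: act_qspec}
```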