From 45c19763d513876b913711c55c25020ba77a7451 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 5 Dec 2024 14:08:02 +0100 Subject: [PATCH 1/7] Add TOSA table as custom edge op Edge operators that are lowered to TOSA TABLEs are converted to a custom edge IR table-op. Signed-off-by: Oscar Andersson Change-Id: I147008c30b9b46c7b8ae1a1c15bc540fea614a69 --- backends/arm/_passes/arm_pass_manager.py | 8 ++ backends/arm/_passes/insert_table_ops.py | 118 +++++++++++++++++++++++ backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_exp.py | 54 ++--------- backends/arm/operators/op_log.py | 53 ++-------- backends/arm/operators/op_reciprocal.py | 57 ++--------- backends/arm/operators/op_rsqrt.py | 54 ++--------- backends/arm/operators/op_sigmoid.py | 55 ++--------- backends/arm/operators/op_table.py | 41 ++++++++ backends/arm/operators/op_tanh.py | 62 ++---------- 10 files changed, 212 insertions(+), 291 deletions(-) create mode 100644 backends/arm/_passes/insert_table_ops.py create mode 100644 backends/arm/operators/op_table.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b4bb809b85..3e5d4e7438 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -33,6 +33,7 @@ FoldAndAnnotateQParamsPass, QuantizeFullArgument, ) +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, ) @@ -94,10 +95,17 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.exp.default, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.log.default, + exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, ] ) ) + 
self.add_pass(InsertTableOpsPass(exported_program)) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py new file mode 100644 index 0000000000..41f1c25924 --- /dev/null +++ b/backends/arm/_passes/insert_table_ops.py @@ -0,0 +1,118 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm.tosa_quant_utils import QuantArgs +from executorch.exir import ExportedProgram + +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule +from torch.library import impl, Library + +lib = Library("tosa", "DEF") +lib.define("_table(Tensor self) -> Tensor") + + +@impl(lib, "_table") +def _table_impl(*args, **kwargs): + return args[0] + + +class InsertTableOpsPass(ExportPass): + """ + For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these + edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target). + When lowering the _table node target_str will be used to find the corresponding torch operator + which will be used to produce the table values in operators/op_table.py. 
+ """ + + table_ops = { + exir_ops.edge.aten.exp.default: torch.exp, + exir_ops.edge.aten.log.default: torch.log, + exir_ops.edge.aten.reciprocal.default: torch.reciprocal, + exir_ops.edge.aten.rsqrt.default: torch.rsqrt, + exir_ops.edge.aten.sigmoid.default: torch.sigmoid, + exir_ops.edge.aten.tanh.default: torch.tanh, + } + + def __init__(self, exported_program: ExportedProgram): + super().__init__() + self.exported_program = exported_program + + def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None: + """ + Add buffer to self.exported_program.state_dict + """ + self.exported_program.state_dict[buffer_name] = buffer + + def generate_table_values( + self, + torch_op: Callable[[torch.Tensor], torch.Tensor], + in_quantargs: QuantArgs, + out_quantargs: QuantArgs, + ) -> torch.Tensor: + def f(x: torch.Tensor) -> torch.Tensor: + x = in_quantargs.dequantize_value(x) + x = torch_op(x) + return out_quantargs.quantize_value(x) + + input_dtype = in_quantargs.dtype + steps = in_quantargs.qmax - in_quantargs.qmin + 1 + return f( + torch.linspace( + start=in_quantargs.qmin, + end=in_quantargs.qmax, + steps=steps, + # use torch.int64 to avoid overflow when dequantizing (subtracting zp). + # e.g. 
torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8) + dtype=torch.int64, + ) + ).to(dtype=input_dtype) + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in self.table_ops: + continue + input_qparams = node.meta["input_qparams"] + output_qparams = node.meta["output_qparams"] + if len(input_qparams) == 0 or len(output_qparams) == 0: + # We only want to replace the node if it's quantized + continue + # Create table node + with graph_module.graph.inserting_before(node): + table_node = create_node( + graph=graph_module.graph, + op_target=torch.ops.tosa._table, + args=(node.args[0],), + ) + assert len(input_qparams) == 1 + assert len(output_qparams) == 1 + # Generate table buffer + buffer = self.generate_table_values( + torch_op=self.table_ops[node.target], + in_quantargs=input_qparams[0], + out_quantargs=output_qparams[0], + ) + # Register buffer in self.exported_program.state_dict + self.register_buffer(buffer_name=table_node.name, buffer=buffer) + node.replace_all_uses_with(table_node) + graph_module.graph.erase_node(node) + table_node.meta["input_qparams"] = input_qparams + table_node.meta["output_qparams"] = output_qparams + modified = True + + if modified: + # retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + + graph_module.recompile() + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 6db9c968f0..38b73fe0f6 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -37,6 +37,7 @@ op_squeeze, op_sub, op_sum, + op_table, op_tanh, op_to_copy, op_transpose, diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 7a0b4e104f..26433582d9 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py 
@@ -6,30 +6,25 @@ # pyre-unsafe from typing import List -import numpy as np - import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_quant_utils import ( - dequantize_value, - get_quant_arg_downstream, - get_quant_arg_upstream, - QuantArgs, - quantize_value, -) from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class ExpVisitor(NodeVisitor): +class ExpVisitor_0_80_MI(NodeVisitor): target = "aten.exp.default" + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + def __init__(self, *args): super().__init__(*args) @@ -43,41 +38,6 @@ def define_node( ) -> None: assert len(node.all_input_nodes) == 1 + assert inputs[0].dtype == output.dtype == ts.DType.FP32 - if is_quant_node: - # Assume quantized input is 8 bit. - - # Create attribute for 8 bit table lookup. - input_node = node.all_input_nodes[0] - in_quantargs = get_quant_arg_upstream(input_node) - output_node = list(node.users)[0] - out_quantargs = get_quant_arg_downstream(output_node) - - table = exp_table_8bit(in_quantargs, out_quantargs) - table_attr = ts.TosaSerializerAttribute() - table_attr.TableAttribute(table) - - tosa_graph.addOperator( - TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr - ) - else: - tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name]) - - -def exp_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): - """ - Returns a table mapping 256 entries to exp([qmin,qmax]) - """ - - def exp(x): - # Convert quantized input to floating point exp input space. - v = dequantize_value(x, in_quantargs) - # Compute exp. - v = np.exp(v) - # Convert exp output back to quantized space. 
- return quantize_value(v, out_quantargs) - - return [ - exp(x) - for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) - ] + tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 76adc2325e..ffff21c6c8 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -6,22 +6,14 @@ # pyre-unsafe from typing import List -import numpy as np - import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_quant_utils import ( - dequantize_value, - get_quant_arg_downstream, - get_quant_arg_upstream, - QuantArgs, - quantize_value, -) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -30,6 +22,9 @@ class LogVisitor(NodeVisitor): target = "aten.log.default" + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + def __init__(self, *args): super().__init__(*args) @@ -41,44 +36,8 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - assert len(node.all_input_nodes) == 1 assert len(node.users) == 1 + assert inputs[0].dtype == output.dtype == ts.DType.FP32 - if is_quant_node: - # Assume quantized input is 8 bit. - - # Create attribute for 8 bit table lookup. 
- input_node = node.all_input_nodes[0] - in_quantargs = get_quant_arg_upstream(input_node) - output_node = list(node.users)[0] - out_quantargs = get_quant_arg_downstream(output_node) - - table = log_table_8bit(in_quantargs, out_quantargs) - table_attr = ts.TosaSerializerAttribute() - table_attr.TableAttribute(table) - - tosa_graph.addOperator( - TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr - ) - else: - tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name]) - - -def log_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): - """ - Returns a table mapping 256 entries to log([qmin,qmax]) - """ - - def log(x): - # Convert quantized input to floating point log input space. - v = dequantize_value(x, in_quantargs) - # Compute log. - v = np.log(v) - # Convert log output back to quantized space. - return quantize_value(v, out_quantargs) - - return [ - log(x) - for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) - ] + tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 2f2758b0f5..024fb63a5d 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -6,8 +6,6 @@ # pyre-unsafe from typing import List -import numpy as np - import serializer.tosa_serializer as ts import torch from executorch.backends.arm.operators.node_visitor import ( @@ -15,20 +13,17 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - dequantize_value, - get_quant_arg_downstream, - get_quant_arg_upstream, - QuantArgs, - quantize_value, -) +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp @register_node_visitor -class DivVisitor(NodeVisitor): +class ReciprocalVisitor_080_MI(NodeVisitor): target = 
"aten.reciprocal.default" + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + def __init__(self, *args): super().__init__(*args) @@ -40,43 +35,5 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - # 1/X - - if is_quant_node: - input = inputs[0] - input_qargs = get_quant_arg_upstream(node.all_input_nodes[0]) - output_qargs = get_quant_arg_downstream(list(node.users)[0]) - - div_table = div_table_8bit(input_qargs, output_qargs) - - table_attr = ts.TosaSerializerAttribute() - table_attr.TableAttribute(div_table) - tosa_graph.addOperator( - TosaOp.Op().TABLE, [input.name], [output.name], table_attr - ) - - else: - tosa_graph.addOperator( - TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] - ) - - -def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): - """ - Returns a table mapping 256 entries to div([qmin,qmax]) - """ - - def div(x): - # Convert quantized input to floating point div input space. - v1 = dequantize_value(x, in_quantargs) - # Compute div. - v2 = 1.0 / v1 - # Convert div output back to quantized space. 
- v3 = quantize_value(v2, out_quantargs) - - return v3 - - return [ - div(x) - for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) - ] + assert inputs[0].dtype == output.dtype == ts.DType.FP32 + tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index b503a323b1..49218645d7 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -6,7 +6,6 @@ # pyre-unsafe from typing import List -import numpy as np import serializer.tosa_serializer as ts import torch from executorch.backends.arm.operators.node_visitor import ( @@ -14,20 +13,20 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - dequantize_value, - get_quant_arg_downstream, - get_quant_arg_upstream, - QuantArgs, - quantize_value, -) +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp @register_node_visitor -class RsqrtVisitor(NodeVisitor): +class RsqrtVisitor_080_MI(NodeVisitor): target = "aten.rsqrt.default" + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + + def __init__(self, *args): + super().__init__(*args) + def define_node( self, node: torch.fx.Node, @@ -36,38 +35,5 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: - # Assume quantized input is 8 bit. - # Create attribute for 8 bit table lookup. 
- input_node = node.all_input_nodes[0] - in_quantargs = get_quant_arg_upstream(input_node) - output_node = list(node.users)[0] - out_quantargs = get_quant_arg_downstream(output_node) - table = rsqrt_table_8bit(in_quantargs, out_quantargs) - table_attr = ts.TosaSerializerAttribute() - table_attr.TableAttribute(table) - tosa_graph.addOperator( - TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr - ) - else: - tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) - - -def rsqrt_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): - """ - Returns a table mapping 256 entries to rqsrt([qmin,qmax]) - Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_rsqrt - """ - - def rqsrt(x): - # Convert quantized input to floating point rqsrt input space. - v = dequantize_value(x, in_quantargs) - # Compute rqsrt. - v = 1 / np.sqrt(v) - # Convert rqsrt output back to quantized space. - return quantize_value(v, out_quantargs) - - return [ - rqsrt(x) - for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) - ] + assert inputs[0].dtype == output.dtype == ts.DType.FP32 + tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index e299e99b43..d9c93fc7ed 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -6,30 +6,25 @@ # pyre-unsafe from typing import List -import numpy as np - import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_quant_utils import ( - dequantize_value, - get_quant_arg_downstream, - get_quant_arg_upstream, - QuantArgs, - quantize_value, -) from serializer.tosa_serializer import TosaOp from torch.fx 
import Node @register_node_visitor -class SigmoidVisitor(NodeVisitor): +class SigmoidVisitor_080_MI(NodeVisitor): target = "aten.sigmoid.default" + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + def __init__(self, *args): super().__init__(*args) @@ -44,42 +39,6 @@ def define_node( assert len(node.all_input_nodes) == 1 assert len(node.users) == 1 + assert inputs[0].dtype == output.dtype == ts.DType.FP32 - if is_quant_node: - # Assume quantized input is 8 bit. - - # Create attribute for 8 bit table lookup. - input_node = node.all_input_nodes[0] - in_quantargs = get_quant_arg_upstream(input_node) - output_node = list(node.users)[0] - out_quantargs = get_quant_arg_downstream(output_node) - - table = sigmoid_table_8bit(in_quantargs, out_quantargs) - table_attr = ts.TosaSerializerAttribute() - table_attr.TableAttribute(table) - - tosa_graph.addOperator( - TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr - ) - else: - tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) - - -def sigmoid_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): - """ - Returns a table mapping 256 entries to sigmoid([qmin,qmax]) - Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_sigmoid - """ - - def sigmoid(x): - # Convert quantized input to floating point sigmoid input space. - v = dequantize_value(x, in_quantargs) - # Compute sigmoid. - v = 1.0 / (1.0 + np.exp(-v)) - # Convert sigmoid output back to quantized space. 
- return quantize_value(v, out_quantargs) - - return [ - sigmoid(x) - for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) - ] + tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py new file mode 100644 index 0000000000..0a5892f067 --- /dev/null +++ b/backends/arm/operators/op_table.py @@ -0,0 +1,41 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import List + +import numpy as np + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class TableVisitor(NodeVisitor): + target = "_table" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert node.name in self._exported_program.state_dict.keys() + assert inputs[0].dtype == output.dtype == ts.DType.INT8 + table = self._exported_program.state_dict[node.name] + table_attr = ts.TosaSerializerAttribute() + table_attr.TableAttribute(np.array(table)) + tosa_graph.addOperator( + TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr + ) diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 2c84580edc..5fa2a52beb 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -6,30 +6,24 @@ # pyre-unsafe from typing import List -import numpy as np - import serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from 
executorch.backends.arm.tosa_mapping import TosaArg - -from executorch.backends.arm.tosa_quant_utils import ( - dequantize_value, - get_quant_arg_downstream, - get_quant_arg_upstream, - QuantArgs, - quantize_value, -) +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class TanhVisitor(NodeVisitor): +class TanhVisitor_080_MI(NodeVisitor): target = "aten.tanh.default" + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + def __init__(self, *args): super().__init__(*args) @@ -41,47 +35,5 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - - assert len(node.all_input_nodes) == 1 - - if is_quant_node: - # Assume quantized input is 8 bit. - assert len(node.users) == 1 - - # Create attribute for 8 bit table lookup. - input_node = node.all_input_nodes[0] - in_quantargs = get_quant_arg_upstream(input_node) - output_node = list(node.users)[0] - out_quantargs = get_quant_arg_downstream(output_node) - - table = tanh_table_8bit(in_quantargs, out_quantargs) - table_attr = ts.TosaSerializerAttribute() - table_attr.TableAttribute(table) - - tosa_graph.addOperator( - TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr - ) - else: - tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name]) - - -def tanh_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs): - """ - Returns a table mapping 256 entries to tanh([qmin,qmax]) - Reference: https://www.mlplatform.org/tosa/tosa_spec.html#_tanh - """ - - def tanh(x): - # Convert quantized input to floating point tanh input space. - v = dequantize_value(x, in_quantargs) - # Compute tanh. - v = np.exp(-2.0 * v) - v = (1.0 - v) / (1.0 + v) - - # Convert tanh output back to quantized space. 
- return quantize_value(v, out_quantargs) - - return [ - tanh(x) - for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8) - ] + assert inputs[0].dtype == output.dtype == ts.DType.FP32 + tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name]) From 4c802948b57ded50134f9ecf322f5481f01054b4 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 5 Dec 2024 14:20:34 +0100 Subject: [PATCH 2/7] Add support for concat q/dq folding This is a special case where node.args can be lists with many incoming dq-nodes. Signed-off-by: Oscar Andersson Change-Id: Icf511a8bdeaaffb597b18455ab7f1fbd947ce3ca --- backends/arm/_passes/arm_pass_manager.py | 5 +- .../fold_qdq_with_annotated_qparams_pass.py | 73 +++++++++++-------- 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 3e5d4e7438..446c3cfe1c 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -90,14 +90,15 @@ def transform_to_backend_pipeline( self.add_pass( FoldAndAnnotateQParamsPass( [ - exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.cat.default, exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.sigmoid.default, diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index f078cf2118..aaa0c5272b 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -80,6 +80,46 @@ def __init__(self, targeted_ops: 
Iterable[EdgeOpOverload]) -> None: super().__init__() self.targeted_ops = targeted_ops + def fold_and_annotate_arg( + self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int + ): + input_qparams = None + nodes_to_remove = set() + for arg in arg_list: + if not isinstance(arg, Node): + return + """ + Make sure arg has requires_grad set to False + For parameters that are not quantized, sometimes (i.e. convolution) + the Parameter(FakeTensor(...)) has requires_grad set to True, which + causes the retracing of the graph to fail with: + + E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch. + E + E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + E Original traceback: + E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward + E x = conv(x) + """ + if arg.op == "placeholder": + arg.meta["val"].requires_grad = False + + arg_quant_params = None + if arg.target == dq_op: + arg_quant_params = QuantArgs.from_operator(arg.target, arg.args) + # add arg to nodes_to_remove to fold the dq-node + nodes_to_remove.add(arg) + if input_qparams is not None and input_qparams != arg_quant_params: + # Two args are quantized differently + raise RuntimeError("Input qparams does not match!") + input_qparams = arg_quant_params + if input_qparams is not None: + node.meta["input_qparams"][i] = input_qparams + for n in nodes_to_remove: + assert n.target == dq_op + n.replace_all_uses_with(n.args[0]) + graph_module.graph.erase_node(n) + def call(self, graph_module: GraphModule) -> PassResult: # Loop over the graph nodes and find any node in 
the 'targeted_ops' list. @@ -98,36 +138,11 @@ def call(self, graph_module: GraphModule) -> PassResult: n.meta["input_qparams"] = {} n.meta["output_qparams"] = {} for i, arg in enumerate(n.args): - if not isinstance(arg, Node): - continue - - # Make sure arg has requires_grad set to False - # For parameters that are not quantized, sometimes (i.e. convolution) - # the Parameter(FakeTensor(...)) has requires_grad set to True, which - # causes the retracing of the graph to fail with: - # - # E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch. - # E - # E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) - # E Original traceback: - # E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward - # E x = conv(x) - # - if arg.op == "placeholder": - arg.meta["val"].requires_grad = False - - if arg.target != dq_op: - continue - - # arg.target for argument i is a dequant node, extract the information - n.meta["input_qparams"][i] = QuantArgs.from_operator( - arg.target, arg.args - ) + if isinstance(arg, list): + self.fold_and_annotate_arg(graph_module, n, arg, i) - # arg.args[0] is the tensor input, replace the input usage - tensor_input = cast(Node, arg.args[0]) - n.replace_input_with(arg, tensor_input) - graph_module.graph.erase_node(arg) + elif isinstance(arg, Node): + self.fold_and_annotate_arg(graph_module, n, [arg], i) # Copy the users, since we are modifying it. 
users_copy = copy.copy(n.users) From 83fb36463afb92147c0412994da83e675293cf5d Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 5 Dec 2024 14:30:14 +0100 Subject: [PATCH 3/7] Increase q/dq folding coverage Add support for q/dq folding of more operators such as hardtanh, maxpool2d, mul, relu, select, sub, to_copy. Signed-off-by: Oscar Andersson Change-Id: Ifdabda4c927dade41c000859054696844c546f7b --- backends/arm/_passes/arm_pass_manager.py | 18 ++-- backends/arm/operators/op_div.py | 51 --------- backends/arm/operators/op_hardtanh.py | 13 +-- backends/arm/operators/op_max_pool2d.py | 26 +++-- backends/arm/operators/op_mul.py | 129 +++++++++++++---------- backends/arm/operators/op_relu.py | 11 +- backends/arm/operators/op_select.py | 4 +- backends/arm/operators/op_sub.py | 84 ++++++++++++--- backends/arm/operators/op_to_copy.py | 2 - backends/arm/test/ops/test_scalars.py | 2 +- backends/arm/test/ops/test_select.py | 2 - 11 files changed, 181 insertions(+), 161 deletions(-) delete mode 100644 backends/arm/operators/op_div.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 446c3cfe1c..26e52bfea5 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -79,12 +79,6 @@ def transform_to_backend_pipeline( self.add_pass(DecomposeVarPass()) self.add_pass(ConvertMeanDimToAveragePool()) self.add_pass(DecomposeMeanDimPass()) - self.add_pass(MatchArgRanksPass(exported_program)) - self.add_pass(DecomposeDivPass()) - self.add_pass(KeepDimsFalseToSqueezePass()) - self.add_pass(ConvertSplitToSlicePass()) - self.add_pass(Conv1dUnsqueezePass(exported_program)) - self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(DecomposeLinearPass()) self.add_pass(QuantizeFullArgument()) self.add_pass( @@ -96,17 +90,29 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.full.default, + 
exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.max_pool2d.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.reciprocal.default, + exir_ops.edge.aten.relu.default, exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, ] ) ) self.add_pass(InsertTableOpsPass(exported_program)) + self.add_pass(MatchArgRanksPass(exported_program)) + self.add_pass(DecomposeDivPass()) + self.add_pass(KeepDimsFalseToSqueezePass()) + self.add_pass(ConvertSplitToSlicePass()) + self.add_pass(Conv1dUnsqueezePass(exported_program)) + self.add_pass(DecomposeSoftmaxesPass()) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/operators/op_div.py b/backends/arm/operators/op_div.py deleted file mode 100644 index 2332e807c4..0000000000 --- a/backends/arm/operators/op_div.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import tosa_shape -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class DivVisitor(NodeVisitor): - target = "aten.div.Tensor" - - # Only supported for MI - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - # FP32 Div is implemented as output=x/y -> output=x*1/y e.g. MUL(x,RECIPROCAL(y)) - recip = tosa_graph.addIntermediate( - tosa_shape(inputs[1].shape, inputs[1].dim_order), inputs[1].dtype - ) - tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[1].name], [recip.name]) - - attr = ts.TosaSerializerAttribute() - attr.MulAttribute(0) - tosa_graph.addOperator( - TosaOp.Op().MUL, [inputs[0].name, recip.name], [output.name], attr - ) diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index e726028206..544e00c5a2 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -8,16 +8,16 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_upstream, - quantize_value, -) +from executorch.backends.arm.tosa_quant_utils 
import quantize_value from serializer.tosa_serializer import TosaOp @@ -38,9 +38,10 @@ def define_node( ) -> None: attr = ts.TosaSerializerAttribute() - if is_quant_node: + if inputs[0].dtype == ts.DType.INT8: # Get quant parameters - qargs = get_quant_arg_upstream(node.all_input_nodes[0]) + input_qparams = get_input_qparams(node) + qargs = input_qparams[0] # Convert to quantized representation clamp_min_qs = quantize_value(inputs[1].number, qargs) clamp_max_qs = quantize_value(inputs[2].number, qargs) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 0a4092e3a9..9cc40c47df 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -8,16 +8,15 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - get_quant_arg_upstream, -) - from serializer.tosa_serializer import TosaOp @@ -46,19 +45,18 @@ def define_node( except IndexError: padding = [0, 0, 0, 0] - accumulator_type = input_tensor.dtype - - if is_quant_node: - # Accumulator type always is int8 when input tensor is an integer type. - accumulator_type = ts.DType.INT8 + accumulator_type = output.dtype # Initilize zero point to zero. 
input_zp = 0 - output_zp = 0 + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + input_zp = input_qparams[0].zp - if is_quant_node: - input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp - output_zp = get_quant_arg_downstream(list(node.users)[0]).zp + output_zp = 0 + if output.dtype == ts.DType.INT8: + output_qparams = get_output_qparams(node) + output_zp = output_qparams[0].zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index ad578aa1f0..84c489790d 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -5,26 +5,34 @@ # pyre-unsafe -from typing import cast, List +from typing import List import executorch.backends.arm.tosa_quant_utils as tqutils import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp @register_node_visitor -class MulVisitor(NodeVisitor): +class MulVisitor_080_BI(NodeVisitor): target = "aten.mul.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + def define_node( self, node: torch.fx.Node, @@ -33,57 +41,68 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: + assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8 + input_A = inputs[0] + input_B = inputs[1] + input_qparams = get_input_qparams(node) + input_A_qargs = input_qparams[0] + input_B_qargs = input_qparams[1] + input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) + input_B.shape = tutils.tosa_shape(input_B.shape, 
input_B.dim_order) + + # Rescale inputs to INT32 with zp=0 + input_A_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_A, + input_A_qargs.zp, + rescale_scale=1.0, + ) + input_B_rescaled = tqutils.build_rescale_to_int32( + tosa_graph, + input_B, + input_B_qargs.zp, + rescale_scale=1.0, + ) + + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) + + # Do the INT32 Mul + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift=0) + tosa_graph.addOperator( + TosaOp.Op().MUL, + [ + input_A_rescaled.name, + input_B_rescaled.name, + ], + [mul_output.name], + attr, + ) + output_scale = input_A_qargs.scale * input_B_qargs.scale + tqutils.insert_rescale_op_to_int8(tosa_graph, mul_output, output_scale, node) + + +@register_node_visitor +class MulVisitor_080_MI(MulVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] - if is_quant_node: - input_A = inputs[0] - input_B = inputs[1] - input_A_qargs = tqutils.get_quant_arg_upstream( - cast(torch.fx.Node, node.args[0]) - ) - input_B_qargs = tqutils.get_quant_arg_upstream( - cast(torch.fx.Node, node.args[1]) - ) - - input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) - input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) - output_shape = tutils.tosa_shape(output.shape, output.dim_order) - - # Rescale inputs to INT32 with zp=0 - input_A_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_A, - input_A_qargs.zp, - rescale_scale=1.0, - ) - input_B_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_B, - input_B_qargs.zp, - rescale_scale=1.0, - ) - - mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) - - # Do the INT32 Mul - attr = ts.TosaSerializerAttribute() - attr.MulAttribute(shift=0) - tosa_graph.addOperator( - TosaOp.Op().MUL, - [ - input_A_rescaled.name, - 
input_B_rescaled.name, - ], - [mul_output.name], - attr, - ) - - tqutils.rescale_node_back_to_int8( - node, mul_output, input_A_qargs.scale * input_B_qargs.scale, tosa_graph - ) - - else: - attr = ts.TosaSerializerAttribute() - attr.MulAttribute(shift=0) - tosa_graph.addOperator( - TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr - ) + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + if inputs[0].dtype == ts.DType.INT8: + return super().define_node(node, tosa_graph, inputs, output, is_quant_node) + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift=0) + tosa_graph.addOperator( + TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr + ) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index a3a7c82ab8..0641a5d983 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -8,6 +8,9 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts import torch.fx +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -37,10 +40,10 @@ def define_node( clamp_max_fp = 0.0 clamp_min_qs = 0 clamp_max_qs = 0 - if is_quant_node: - out_qargs = tqutils.get_quant_arg_downstream(list(node.users)[0]) - clamp_min_qs = tqutils.quantize_value(0, out_qargs) - clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) + if inputs[0].dtype == ts.DType.INT8: + out_qargs = get_output_qparams(node) + clamp_min_qs = tqutils.quantize_value(0, out_qargs[0]) + clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0]) else: clamp_min_fp = 0 diff --git a/backends/arm/operators/op_select.py b/backends/arm/operators/op_select.py index 77875507d9..eddd5c4adf 100644 --- 
a/backends/arm/operators/op_select.py +++ b/backends/arm/operators/op_select.py @@ -50,9 +50,7 @@ def define_node( expanded_shape = tuple(1 if i == dim else shape[i] for i in range(rank)) expanded_shape = tosa_shape(expanded_shape, input_node.dim_order) - output_reshaped = tosa_graph.addIntermediate( - expanded_shape, ts.DType.INT8 if is_quant_node else output.dtype - ) + output_reshaped = tosa_graph.addIntermediate(expanded_shape, output.dtype) attr_slice = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index b86a5ea3ad..6125158eb9 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -16,14 +16,19 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class SubVisitor(NodeVisitor): +class SubVisitor_080_BI(NodeVisitor): target = "aten.sub.Tensor" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -35,32 +40,77 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: - input_nodes = tutils.get_two_inputs(node) + # Specification (0.80) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + # Handle int8 (quantized) and int32 + assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] - # Rescale inputs to 32 bit - rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( - input_nodes, tosa_graph + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.SUB + rescaled_inputs = inputs - # Prepare sub output tensor + if 
output.dtype == ts.DType.INT8: broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + sub_output = output - # Do the INT32 Sub - tosa_graph.addOperator( - TosaOp.Op().SUB, - [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, - ], - [sub_output.name], - ) + # Do the INT32 Sub + tosa_graph.addOperator( + TosaOp.Op().SUB, + [ + rescaled_inputs[0].name, + rescaled_inputs[1].name, + ], + [sub_output.name], + None, + ) + if output.dtype == ts.DType.INT8: # Scale output back to 8 bit - tqutils.rescale_node_back_to_int8(node, sub_output, scale, tosa_graph) + # pyre-ignore + tqutils.insert_rescale_op_to_int8(tosa_graph, sub_output, scale_back, node) + + +@register_node_visitor +class SubVisitor_080_MI(SubVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + # Specification (0.80) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Sub lowering + assert inputs[0].dtype == ts.DType.FP32 + assert output.dtype == ts.DType.FP32 + + # MI lowering tosa_graph.addOperator( TosaOp.Op().SUB, [inputs[0].name, inputs[1].name], diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index 15077d6df7..c0e4f0de4c 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -38,6 +38,4 @@ def define_node( output: TosaArg, 
is_quant_node: bool, ) -> None: - assert not is_quant_node, "Casting of quantized values is not supported." - assert inputs tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index f03d8f72d1..bcf294de4a 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -157,7 +157,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple): def test_MI(self, test_name: str, op: torch.nn.Module, x, y): expected_exception = None if any(token in test_name for token in ("Sub_int", "Sub__int")): - expected_exception = ValueError + expected_exception = (AssertionError, ValueError) elif test_name.endswith("_st"): expected_exception = AttributeError diff --git a/backends/arm/test/ops/test_select.py b/backends/arm/test/ops/test_select.py index f44e61c64f..c39b20a731 100644 --- a/backends/arm/test/ops/test_select.py +++ b/backends/arm/test/ops/test_select.py @@ -117,8 +117,6 @@ def _test_select_ethos_BI_pipeline( .check(["torch.ops.quantized_decomposed"]) .to_edge() .partition() - .dump_artifact() - .dump_operator_distribution() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() ) From f8e1c3f5038735179a49cc16e2ea7dc94b099c83 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 5 Dec 2024 14:32:46 +0100 Subject: [PATCH 4/7] Add support for sum q/dq folding The sum operator is retraced to an int64 output dtype after q/dq folding. This patch adds a pass to manually force the dtype to be int8.
Signed-off-by: Oscar Andersson Change-Id: Ifa737a398c5a878d52cd76a2392499905da085ce --- backends/arm/_passes/arm_pass_manager.py | 5 +- .../fold_qdq_with_annotated_qparams_pass.py | 30 ++++ backends/arm/operators/op_sum.py | 140 +++++++++++------- 3 files changed, 119 insertions(+), 56 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 26e52bfea5..79e99e4118 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -32,6 +32,7 @@ from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, QuantizeFullArgument, + RetraceFoldedDtypesPass, ) from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( @@ -102,14 +103,16 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.tanh.default, ] ) ) + self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(InsertTableOpsPass(exported_program)) + self.add_pass(KeepDimsFalseToSqueezePass()) self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeDivPass()) - self.add_pass(KeepDimsFalseToSqueezePass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index aaa0c5272b..aa5358fe17 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -196,3 +196,33 @@ def call(self, graph_module: GraphModule) -> PassResult: modified = True return PassResult(graph_module, modified) + + +class RetraceFoldedDtypesPass(ExportPass): 
+ """ + FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced + some operators are retraced to types that cannot be handled by TOSA. One + such example is sum.dim_IntList: + q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ... + After folding it becomes: + q (int8) -> sum (int64) -> ... + This pass changes types of ops in self.targeted_ops, such as sum, so that + the output type of each such op matches the dtype of its output_qparams. + """ + + targeted_ops = { + exir_ops.edge.aten.sum.dim_IntList, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta) + + node_kwargs = kwargs.copy() + output_qparams = meta["output_qparams"] + if len(output_qparams) == 0: + return super().call_operator(op, args, kwargs, meta) + + output_dtype = output_qparams[0].dtype + node_kwargs["dtype"] = output_dtype + return super().call_operator(op, args, node_kwargs, meta) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 26d29f5179..a4d2d8f914 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -16,14 +16,19 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from serializer.tosa_serializer import TosaOp from torch.fx import Node @register_node_visitor -class AddVisitor(NodeVisitor): +class SumVisitor_080_BI(NodeVisitor): target = "aten.sum.dim_IntList" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -35,64 +40,89 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - input_node = inputs[0] - input_shape = list(input_node.shape) + input_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) - dim_list = [dim % len(input_node.shape) for dim in dim_list] + dim_list = [dim % len(input_shape) for
dim in dim_list] keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass" - if is_quant_node: + # Rescale input to 32 bit + rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32( + tosa_graph, + [inputs[0]], + node, + ) + + prev_node = rescaled_inputs[0] + reduced_shape = input_shape + + # Reduce all dims in dim_list one-by-one. + for dim in dim_list: + # When reduced, the size of the dim becomes 1. + reduced_shape[dim] = 1 + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(inputs[0].dim_order.index(dim)) + + next_node = tosa_graph.addIntermediate( + tutils.tosa_shape(reduced_shape, inputs[0].dim_order), + dtype=ts.DType.INT32, + ) + + tosa_graph.addOperator( + TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr + ) + + prev_node = next_node + tqutils.insert_rescale_op_to_int8(tosa_graph, prev_node, scale, node) + + +@register_node_visitor +class SumVisitor_080_MI(SumVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + if inputs[0].dtype == ts.DType.INT8: + return super().define_node(node, tosa_graph, inputs, output, is_quant_node) + input_name = inputs[0].name + reduced_shape = list(inputs[0].shape) + dim_list = cast(list[int], inputs[1].special) + dim_list = [dim % len(reduced_shape) for dim in dim_list] + keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) + assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass" + + # Reduce all dims in dim_list one-by-one. 
+ for dim in dim_list: + # When reduced, the size of the dim becomes 1 + reduced_shape[dim] = 1 + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(inputs[0].dim_order.index(dim)) + + if dim == dim_list[-1]: + output_name = output.name + else: + output_name = tosa_graph.addIntermediate( + tutils.tosa_shape(reduced_shape, inputs[0].dim_order), + dtype=ts.DType.FP32, + ).name - # Rescale input to 32 bit - rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( - [node.all_input_nodes[0]], tosa_graph + tosa_graph.addOperator( + TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr ) - prev_node = rescaled_inputs[0] - reduced_shape = input_shape - - # Reduce all dims in dim_list one-by-one. - for dim in dim_list: - # When reduced, the size of the dim becomes 1. - reduced_shape[dim] = 1 - - attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(input_node.dim_order.index(dim)) - - next_node = tosa_graph.addIntermediate( - tutils.tosa_shape(reduced_shape, input_node.dim_order), - dtype=ts.DType.INT32, - ) - - tosa_graph.addOperator( - TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr - ) - - prev_node = next_node - tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph) - else: - input_name = input_node.name - reduced_shape = input_shape - - # Reduce all dims in dim_list one-by-one. 
- for dim in dim_list: - # When reduced, the size of the dim becomes 1 - reduced_shape[dim] = 1 - - attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(input_node.dim_order.index(dim)) - - if dim == dim_list[-1]: - output_name = output.name - else: - output_name = tosa_graph.addIntermediate( - tutils.tosa_shape(reduced_shape, input_node.dim_order), - dtype=ts.DType.FP32, - ).name - - tosa_graph.addOperator( - TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr - ) - - input_name = output_name + input_name = output_name From 0b381e4af185be0d6eb30a492343c488f7837a6e Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 5 Dec 2024 14:42:26 +0100 Subject: [PATCH 5/7] Complete q/dq folding coverage Add support for q/dq folding for the remaining supported ops in Arm backend. Signed-off-by: Oscar Andersson Change-Id: I9012b4a501ce018c9771c729706be3b031a5c7ae --- .../arm/_passes/annotate_decomposed_matmul.py | 91 +++++++++++ backends/arm/_passes/arm_pass_manager.py | 38 +++-- backends/arm/_passes/conv1d_unsqueeze_pass.py | 22 +-- .../arm/_passes/size_adjust_conv2d_pass.py | 10 +- backends/arm/operators/op_bmm.py | 44 +++--- backends/arm/operators/op_mm.py | 143 +++++++++++------- backends/arm/process_node.py | 5 +- backends/arm/test/misc/test_debug_feats.py | 3 +- backends/arm/test/ops/test_bmm.py | 7 +- backends/arm/test/ops/test_clone.py | 50 +----- 10 files changed, 244 insertions(+), 169 deletions(-) create mode 100644 backends/arm/_passes/annotate_decomposed_matmul.py diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py new file mode 100644 index 0000000000..44f99fd1a7 --- /dev/null +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -0,0 +1,91 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import itertools +from typing import Any, Dict, List + +import torch +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions, + SourcePartition, +) + + +class AnnotateDecomposedMatmulPass(ExportPass): + """ + torch.matmul can be decomposed in many ways, for instance: + dq -> matmul -> q can become + dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding + difficult. This pass finds all matmul partitions and annotates the + matmul-op of each (can be mm or bmm). + """ + + def call(self, graph_module: GraphModule): + matmul_partitions: Dict[Any, List[SourcePartition]] = get_source_partitions( + graph_module.graph, + [ + torch.matmul, + ], + None, + ) + matmul_partitions = list( + itertools.chain.from_iterable(matmul_partitions.values()) + ) + matmul_targets = { + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.bmm.default, + } + for partition in matmul_partitions: + quantized_input = all( + input_node.target == dq_op for input_node in partition.input_nodes + ) + matmul_node = [ + node for node in partition.nodes if node.target in matmul_targets + ][0] + if quantized_input: + matmul_args = matmul_node.all_input_nodes + for i in range(len(matmul_args)): + input_node = partition.input_nodes[i] + matmul_input_node = matmul_args[i] + # Remove partition input dq-node + input_node.replace_all_uses_with(input_node.args[0]) + graph_module.graph.erase_node(input_node) + input_node_qargs = input_node.args[1:] + with graph_module.graph.inserting_before(matmul_node): + # Create new dq-node before matmul + dq_node = create_node( + graph=graph_module.graph, + op_target=dq_op, + ) + dq_node.args = (matmul_input_node, *input_node_qargs) + 
matmul_node.replace_input_with(matmul_input_node, dq_node) + + partition_output = list(partition.output_nodes[0].users)[0] + quantized_output = partition_output.target == q_op + if quantized_output: + output_node_qargs = partition_output.args[1:] + with graph_module.graph.inserting_after(matmul_node): + # Create q-node after matmul + q_node = create_node( + graph=graph_module.graph, + op_target=q_op, + ) + matmul_node.replace_all_uses_with(q_node) + q_node.args = (matmul_node, *output_node_qargs) + # Remove partition output q-node + partition_output.replace_all_uses_with(partition_output.args[0]) + graph_module.graph.erase_node(partition_output) + + # retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 79e99e4118..6d747d8129 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -11,6 +11,9 @@ from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import ( AnnotateChannelsLastDimOrder, ) +from executorch.backends.arm._passes.annotate_decomposed_matmul import ( + AnnotateDecomposedMatmulPass, +) from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( @@ -69,51 +72,64 @@ def transform_to_backend_pipeline( self, exported_program: ExportedProgram, compile_spec: list[CompileSpec] ): """Apply passes before transforming program to backend""" - self.add_pass(CastInt64ToInt32Pass(exported_program)) + self.add_pass(DecomposeLinearPass()) self.add_pass(RemoveGetItemPass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - self.add_pass(SizeAdjustConv2DPass()) - 
self.add_pass(RemoveClonePass()) - self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(DecomposeLayerNormPass()) - self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(DecomposeVarPass()) self.add_pass(ConvertMeanDimToAveragePool()) self.add_pass(DecomposeMeanDimPass()) - self.add_pass(DecomposeLinearPass()) + self.add_pass(ConvertSplitToSlicePass()) + # TODO MLETORCH-558 + self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeFullArgument()) self.add_pass( FoldAndAnnotateQParamsPass( [ + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.cat.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.clone.default, exir_ops.edge.aten.exp.default, + exir_ops.edge.aten.expand_copy.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.max_pool2d.default, - exir_ops.edge.aten.maximum.default, - exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mm.default, exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.permute_copy.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.view_copy.default, ] ) ) self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(InsertTableOpsPass(exported_program)) - self.add_pass(KeepDimsFalseToSqueezePass()) + self.add_pass(ConvertExpandCopyToRepeatPass()) + self.add_pass(UnsqueezeBeforeRepeatPass()) + 
self.add_pass(CastInt64ToInt32Pass(exported_program)) + self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) + self.add_pass(SizeAdjustConv2DPass()) + self.add_pass(RemoveClonePass()) self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeDivPass()) - self.add_pass(ConvertSplitToSlicePass()) + self.add_pass(KeepDimsFalseToSqueezePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) for spec in compile_spec: diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 158e4bf452..16c6f6b209 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -12,10 +12,8 @@ from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_param_tensor, - insert_q_dq_pair, is_param_node, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass): supports 2d and 3d convolution. This is done by modifying the graph to do the following: 1) unsqueeze the convolution's input from 3d to 4d - 2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze - 3) perform a conv2d (with a modified version of the original conv1d args) - 4) squeeze the output back down to 3d. - 5) if all users of squeeze are quantized, insert q/dq-pair before squeeze + 2) perform a conv2d (with a modified version of the original conv1d args) + 3) squeeze the output back down to 3d. 
""" def __init__(self, exported_program: ExportedProgram) -> None: @@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule): continue kernel_node = node.args[1] - if kernel_node.target == dq_op: - kernel_node = kernel_node.args[0] if not is_param_node(self.exported_program, kernel_node): raise AssertionError( @@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) node.replace_input_with(input_node, unsqueeze_before) - # If Quantized we must insert unsqueeze --> q --> dq --> node - if input_node.target == dq_op: - q_params = input_node.args[1:] - insert_q_dq_pair(graph, unsqueeze_before, q_params) - with graph.inserting_after(node): squeeze_after = create_node( graph, @@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule): for user in original_users: user.replace_input_with(node, squeeze_after) - # If quantized, insert conv2d --> q --> dq --> squeeze - if all( - original_user.target == q_op for original_user in original_users - ): - q_params = original_users[0].args[1:] - insert_q_dq_pair(graph, node, q_params) - graph_module.recompile() # Since we are overriding "call", we need to call the parent's "call" # to retrace the graph and regenerate metadata diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_conv2d_pass.py index c7bd27dcce..08da9a74c9 100644 --- a/backends/arm/_passes/size_adjust_conv2d_pass.py +++ b/backends/arm/_passes/size_adjust_conv2d_pass.py @@ -9,7 +9,6 @@ from typing import cast, Optional import torch.fx -from executorch.backends.arm.tosa_quant_utils import is_node_quantized from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch._ops import OpOverload @@ -113,14 +112,7 @@ def call(self, graph_module: torch.fx.GraphModule): slice_node = graph.create_node( "call_function", self.slice_op, (last_node,) + args ) - if is_node_quantized(last_node): - q_params = last_node.args[1:] - 
dq_node = insert_q_dq_pair( - graph_module.graph, slice_node, q_params - ) - last_node = dq_node - else: - last_node = slice_node + last_node = slice_node conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node) modified_graph = True diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 8c9bd7ac2a..acd74630c2 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -8,18 +8,17 @@ from typing import List import serializer.tosa_serializer as ts -import torch.fx +import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - build_rescale, - get_quant_arg_downstream, - get_quant_arg_upstream, -) -from executorch.backends.arm.tosa_utils import get_two_inputs +from executorch.backends.arm.tosa_quant_utils import build_rescale from serializer.tosa_serializer import TosaOp @@ -38,23 +37,27 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - input0, input1 = get_two_inputs(node) + assert inputs[0].dtype == inputs[1].dtype, "Both inputs must be of same type" + assert inputs[0].dtype in [ + ts.DType.INT8, + ts.DType.FP32, + ], "Only int8 and float32 supported" # aten.bmm maps directly to MATMUL # NOTE: For now, only INT8 & FP32 is supported # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. 
- if is_quant_node: - input0_q_params = get_quant_arg_upstream(input0) - input1_q_params = get_quant_arg_upstream(input1) - input0_zp = input0_q_params.zp - input1_zp = input1_q_params.zp + + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + input0_zp = input_qparams[0].zp + input1_zp = input_qparams[1].zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: - input0_zp, input1_zp = 0, 0 bmm_output_name = output.name + input0_zp, input1_zp = 0, 0 # Add the MATMUL to the TOSA graph. attr = ts.TosaSerializerAttribute() @@ -62,18 +65,17 @@ def define_node( tosa_graph.addOperator( TosaOp.Op().MATMUL, - [input0.name, input1.name], + [inputs[0].name, inputs[1].name], [bmm_output_name], attr, ) # As INT8 accumulates into INT32, we need to rescale it back to INT8 - if is_quant_node: - output_q_params = get_quant_arg_downstream(list(node.users)[0]) - + if output.dtype == ts.DType.INT8: + output_qparams = get_output_qparams(node)[0] final_output_scale = ( - input0_q_params.scale * input1_q_params.scale - ) / output_q_params.scale + input_qparams[0].scale * input_qparams[1].scale + ) / output_qparams.scale build_rescale( tosa_fb=tosa_graph, @@ -84,6 +86,6 @@ def define_node( output_type=ts.DType.INT8, output_shape=bmm_result.shape, input_zp=0, - output_zp=output_q_params.zp, + output_zp=output_qparams.zp, is_double_round=False, ) diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index 81334de16c..efdb0d611e 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -9,28 +9,29 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from 
executorch.backends.arm.tosa_quant_utils import ( - build_rescale, - get_quant_arg_downstream, - get_quant_arg_upstream, -) -from executorch.backends.arm.tosa_utils import ( - build_reshape, - expand_dims, - get_two_inputs, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_utils import build_reshape, expand_dims from serializer.tosa_serializer import TosaOp @register_node_visitor -class MMVisitor(NodeVisitor): +class MMVisitor_080_BI(NodeVisitor): target = "aten.mm.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + def __init__(self, *args): super().__init__(*args) @@ -42,31 +43,25 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - input0, input1 = get_two_inputs(node) - # For atem.mm, the two inputs are of rank 2 # For TOSA it needs to be rank 3 # So they need to be reshaped from (H, W) to (1, H, W) - # NOTE: For now, only INT8 & FP32 is supported - reshape_dtype = ts.DType.INT8 if is_quant_node else ts.DType.FP32 + reshape_dtype = output.dtype input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0) input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0) # The output also needs to be rank 3 output_new_shape = (1, output.shape[0], output.shape[1]) - # For INT8, we need to get the zero point, otherwise it is 0 - input0_zp, input1_zp = 0, 0 - if is_quant_node: - input0_zp = get_quant_arg_upstream(input0).zp - input1_zp = get_quant_arg_upstream(input1).zp + input_qparams = get_input_qparams(node) + assert len(input_qparams) == 2 + input0_qparams = input_qparams[0] + input1_qparams = input_qparams[1] - mat_mul_result = tosa_graph.addIntermediate( - output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype - ) + mat_mul_result = tosa_graph.addIntermediate(output_new_shape, ts.DType.INT32) attr = ts.TosaSerializerAttribute() - 
attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp) + attr.MatMulAttribute(A_zp=input0_qparams.zp, B_zp=input1_qparams.zp) tosa_graph.addOperator( TosaOp.Op().MATMUL, @@ -75,13 +70,8 @@ def define_node( attr, ) - if is_quant_node: - reshape_intermediate = tosa_graph.addIntermediate( - output.shape, ts.DType.INT32 - ) - reshape_output_name = reshape_intermediate.name - else: - reshape_output_name = output.name + reshape_intermediate = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + reshape_output_name = reshape_intermediate.name # Reshape the final output back to rank 2 build_reshape( @@ -89,25 +79,72 @@ def define_node( ) # As INT8 accumulates into INT32, we need to rescale it back to INT8 - if is_quant_node: - input0_q_params = get_quant_arg_upstream(input0) - input1_q_params = get_quant_arg_upstream(input1) - output_q_params = get_quant_arg_downstream(list(node.users)[0]) - - final_output_scale = ( - input0_q_params.scale * input1_q_params.scale - ) / output_q_params.scale - - # As the input will be INT32, the input_zp must be set to 0 - build_rescale( - tosa_fb=tosa_graph, - scale=final_output_scale, - # pyre-ignore[61]: Uninitialized local [61]: Local variable `reshape_intermediate` is undefined, or not always defined. 
- input_node=reshape_intermediate, - output_name=output.name, - output_type=ts.DType.INT8, - output_shape=reshape_intermediate.shape, - input_zp=0, - output_zp=output_q_params.zp, - is_double_round=False, - ) + output_qparams = get_output_qparams(node) + assert len(output_qparams) == 1 + + final_output_scale = ( + input0_qparams.scale * input1_qparams.scale + ) / output_qparams[0].scale + + # As the input will be INT32, the input_zp must be set to 0 + build_rescale( + tosa_fb=tosa_graph, + scale=final_output_scale, + input_node=reshape_intermediate, + output_name=output.name, + output_type=output.dtype, + output_shape=reshape_intermediate.shape, + input_zp=0, + output_zp=output_qparams[0].zp, + is_double_round=False, + ) + + +@register_node_visitor +class MMVisitor_080_MI(MMVisitor_080_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + if inputs[0].dtype == ts.DType.INT8: + return super().define_node(node, tosa_graph, inputs, output, is_quant_node) + reshape_dtype = output.dtype + # For atem.mm, the two inputs are of rank 2 + # For TOSA it needs to be rank 3 + # So they need to be reshaped from (H, W) to (1, H, W) + input0_reshaped = expand_dims(tosa_graph, inputs[0], reshape_dtype, 0) + input1_reshaped = expand_dims(tosa_graph, inputs[1], reshape_dtype, 0) + + # The output also needs to be rank 3 + output_new_shape = (1, output.shape[0], output.shape[1]) + + # Set zps to 0 + input0_zp, input1_zp = 0, 0 + attr = ts.TosaSerializerAttribute() + attr.MatMulAttribute(A_zp=input0_zp, B_zp=input1_zp) + mat_mul_result = tosa_graph.addIntermediate(output_new_shape, output.dtype) + reshape_output_name = output.name + + tosa_graph.addOperator( + TosaOp.Op().MATMUL, + 
[input0_reshaped.name, input1_reshaped.name], + [mat_mul_result.name], + attr, + ) + # Reshape the final output back to rank 2 + build_reshape( + tosa_graph, mat_mul_result.name, output.shape, reshape_output_name + ) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 1b9b96b5ad..6a07862682 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -19,6 +19,7 @@ from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( + dq_q_ops, get_quantized_node_output_dtype, is_node_quantized, ) @@ -43,9 +44,9 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) - is_quant_node = is_node_quantized(node) + is_quant_node = output.dtype == ts.DType.INT8 or node.target in dq_q_ops if is_quant_node: - output_dtype = map_dtype(get_quantized_node_output_dtype(node)) + output_dtype = ts.DType.INT8 else: output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 47d259da47..b5ff882537 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -52,8 +52,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), ) .export() - .to_edge() - .partition() + .to_edge_transform_and_lower() .dump_artifact(dump_file) .dump_artifact() ) diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index 3a3c2772bd..0b830fa46b 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -35,7 +35,10 @@ def forward(self, x, y): return torch.bmm(x, y) class MatMul(torch.nn.Module): - test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))] + test_parameters = [ + (torch.rand(2, 3, 5), 
torch.rand(2, 5, 2)), + (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), + ] def forward(self, x, y): return torch.matmul(x, y) @@ -89,7 +92,7 @@ def _test_bmm_tosa_BI_pipeline( .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) + .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) def _test_bmm_ethosu_BI_pipeline( diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 2ec2f621fa..300ebb6f37 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -11,20 +11,17 @@ import unittest from typing import Tuple -import pytest - import torch from executorch.backends.arm.quantizer.arm_quantizer import ( ArmQuantizer, get_symmetric_quantization_config, ) -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.xnnpack.test.tester.tester import Quantize -from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized @@ -80,41 +77,6 @@ def _test_clone_tosa_BI_pipeline( .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) - def _test_clone_tosa_ethos_pipeline( - self, - compile_spec: list[CompileSpec], - module: torch.nn.Module, - test_data: Tuple[torch.Tensor], - ): - quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) - tester = ( - ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.clone.default": 1}) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - 
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - def _test_clone_tosa_u55_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - self._test_clone_tosa_ethos_pipeline( - common.get_u55_compile_spec(), module, test_data - ) - - def _test_clone_tosa_u85_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - self._test_clone_tosa_ethos_pipeline( - common.get_u85_compile_spec(), module, test_data - ) - @parameterized.expand(Clone.test_parameters) def test_clone_tosa_MI(self, test_tensor: torch.Tensor): self._test_clone_tosa_MI_pipeline(self.Clone(), (test_tensor,)) @@ -122,13 +84,3 @@ def test_clone_tosa_MI(self, test_tensor: torch.Tensor): @parameterized.expand(Clone.test_parameters) def test_clone_tosa_BI(self, test_tensor: torch.Tensor): self._test_clone_tosa_BI_pipeline(self.Clone(), (test_tensor,)) - - @parameterized.expand(Clone.test_parameters) - @pytest.mark.corstone_fvp - def test_clone_u55_BI(self, test_tensor: torch.Tensor): - self._test_clone_tosa_u55_pipeline(self.Clone(), (test_tensor,)) - - @parameterized.expand(Clone.test_parameters) - @pytest.mark.corstone_fvp - def test_clone_u85_BI(self, test_tensor: torch.Tensor): - self._test_clone_tosa_u85_pipeline(self.Clone(), (test_tensor,)) From b8343a239a696c9960ca2a56012660f28b6ad88a Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 28 Nov 2024 14:43:46 +0100 Subject: [PATCH 6/7] Remove is_quant_node from NodeVisitor.define_node Signed-off-by: Oscar Andersson Change-Id: Ibb17add461dc79e022a7f4accde29f9f9d61b16d --- backends/arm/operators/node_visitor.py | 1 - backends/arm/operators/op_add.py | 4 +--- backends/arm/operators/op_avg_pool2d.py | 4 +--- backends/arm/operators/op_batch_norm.py | 1 - backends/arm/operators/op_bmm.py | 1 - backends/arm/operators/op_cat.py | 1 - backends/arm/operators/op_conv2d.py | 1 - backends/arm/operators/op_dequant.py | 3 +-- backends/arm/operators/op_exp.py | 1 - backends/arm/operators/op_full.py | 1 
- backends/arm/operators/op_get_item.py | 3 +-- backends/arm/operators/op_hardtanh.py | 1 - backends/arm/operators/op_log.py | 1 - backends/arm/operators/op_max.py | 1 - backends/arm/operators/op_max_pool2d.py | 1 - backends/arm/operators/op_min.py | 1 - backends/arm/operators/op_mm.py | 4 +--- backends/arm/operators/op_mul.py | 4 +--- backends/arm/operators/op_permute.py | 1 - backends/arm/operators/op_quant.py | 3 +-- backends/arm/operators/op_reciprocal.py | 1 - backends/arm/operators/op_relu.py | 1 - backends/arm/operators/op_repeat.py | 1 - backends/arm/operators/op_rshift.py | 1 - backends/arm/operators/op_rsqrt.py | 1 - backends/arm/operators/op_select.py | 1 - backends/arm/operators/op_sigmoid.py | 1 - backends/arm/operators/op_slice.py | 1 - backends/arm/operators/op_squeeze.py | 1 - backends/arm/operators/op_sub.py | 4 +--- backends/arm/operators/op_sum.py | 4 +--- backends/arm/operators/op_table.py | 1 - backends/arm/operators/op_tanh.py | 1 - backends/arm/operators/op_to_copy.py | 1 - backends/arm/operators/op_transpose.py | 1 - backends/arm/operators/op_unsqueeze.py | 1 - backends/arm/operators/op_upsample_nearest2d.py | 1 - backends/arm/operators/op_view.py | 1 - backends/arm/process_node.py | 7 +++---- 39 files changed, 12 insertions(+), 57 deletions(-) diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py index 87ef6ed4c6..8609e5e391 100644 --- a/backends/arm/operators/node_visitor.py +++ b/backends/arm/operators/node_visitor.py @@ -39,7 +39,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: raise NotImplementedError("NodeVisitor must be extended.") diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index a81e52c5c6..74f00354ed 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -38,7 +38,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: 
TosaArg, - is_quant_node: bool, ) -> None: # Specification (0.80) states that input and output types # should all be the same @@ -96,7 +95,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # Specification (0.80) states that input and output types # should all be the same @@ -104,7 +102,7 @@ def define_node( if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output, is_quant_node) + super().define_node(node, tosa_graph, inputs, output) else: # FP32 Add lowering assert inputs[0].dtype == ts.DType.FP32 diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 9d8dd13e7e..fecddac659 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -75,7 +75,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: input_tensor = inputs[0] assert input_tensor.dtype == ts.DType.INT8 @@ -107,14 +106,13 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert ( inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 ), "Only FP32 and INT8 supported" if inputs[0].dtype == ts.DType.INT8: - super().define_node(node, tosa_graph, inputs, output, is_quant_node) + super().define_node(node, tosa_graph, inputs, output) if inputs[0].dtype == ts.DType.FP32: accumulator_type = ts.DType.FP32 diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py index c3b9bb0c43..ce5998cb72 100644 --- a/backends/arm/operators/op_batch_norm.py +++ b/backends/arm/operators/op_batch_norm.py @@ -42,7 +42,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # Decompose batch norm 
into sequence (activations, weights, bias, running_mean, running_var, momentum, epsilon) = ( diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index acd74630c2..6ff9c0fd56 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -35,7 +35,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == inputs[1].dtype, "Both inputs must be of same type" diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 652eb39737..e249942d0b 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -30,7 +30,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: tensors = inputs[0].special diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 5913cb0c34..42156da013 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -55,7 +55,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: input, weight, bias, stride, pad, dilation, _, _, group = inputs diff --git a/backends/arm/operators/op_dequant.py b/backends/arm/operators/op_dequant.py index afa1dda946..022f4e45ce 100644 --- a/backends/arm/operators/op_dequant.py +++ b/backends/arm/operators/op_dequant.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -29,7 +29,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: item_name = inputs[0].name ## Simply add an identityOp diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 26433582d9..46f4980975 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -34,7 +34,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert len(node.all_input_nodes) == 1 diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index 23a13dd486..7964e58226 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -31,7 +31,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: shape = tosa_shape(inputs[0].special, output.dim_order) diff --git a/backends/arm/operators/op_get_item.py b/backends/arm/operators/op_get_item.py index a696b33aa7..f7372262c6 100644 --- a/backends/arm/operators/op_get_item.py +++ b/backends/arm/operators/op_get_item.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -29,7 +29,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: item_name = inputs[0].name ## Simply add an identityOp diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 544e00c5a2..28f28bd9c6 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -34,7 +34,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index ffff21c6c8..868eeb9443 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -34,7 +34,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert len(node.all_input_nodes) == 1 assert len(node.users) == 1 diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py index 58c0d44821..660a2cf0af 100644 --- a/backends/arm/operators/op_max.py +++ b/backends/arm/operators/op_max.py @@ -38,7 +38,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == inputs[1].dtype diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 9cc40c47df..69a46cc305 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -33,7 +33,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: input_tensor = inputs[0] diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py index 61b9e459ca..2282d9e1cf 100644 --- a/backends/arm/operators/op_min.py +++ b/backends/arm/operators/op_min.py @@ -39,7 +39,6 @@ def define_node( tosa_graph: 
ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == inputs[1].dtype diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index efdb0d611e..7e2db21426 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -41,7 +41,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # For atem.mm, the two inputs are of rank 2 # For TOSA it needs to be rank 3 @@ -117,10 +116,9 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: if inputs[0].dtype == ts.DType.INT8: - return super().define_node(node, tosa_graph, inputs, output, is_quant_node) + return super().define_node(node, tosa_graph, inputs, output) reshape_dtype = output.dtype # For atem.mm, the two inputs are of rank 2 # For TOSA it needs to be rank 3 diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 84c489790d..ec0d4b16c2 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -39,7 +39,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8 input_A = inputs[0] @@ -97,10 +96,9 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: if inputs[0].dtype == ts.DType.INT8: - return super().define_node(node, tosa_graph, inputs, output, is_quant_node) + return super().define_node(node, tosa_graph, inputs, output) attr = ts.TosaSerializerAttribute() attr.MulAttribute(shift=0) tosa_graph.addOperator( diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 8142d6d654..16d3d4a04e 100644 --- a/backends/arm/operators/op_permute.py +++ 
b/backends/arm/operators/op_permute.py @@ -78,7 +78,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. diff --git a/backends/arm/operators/op_quant.py b/backends/arm/operators/op_quant.py index 8f83e79442..fcf9372c11 100644 --- a/backends/arm/operators/op_quant.py +++ b/backends/arm/operators/op_quant.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -29,7 +29,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: item_name = inputs[0].name ## Simply add an identityOp diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 024fb63a5d..121b78fed6 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -33,7 +33,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == output.dtype == ts.DType.FP32 tosa_graph.addOperator(TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 0641a5d983..a3bd355197 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -32,7 +32,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 1e4dc4e23c..fd76a52052 100644 --- a/backends/arm/operators/op_repeat.py 
+++ b/backends/arm/operators/op_repeat.py @@ -29,7 +29,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: multiples = inputs[1].special diff --git a/backends/arm/operators/op_rshift.py b/backends/arm/operators/op_rshift.py index 94b3f8b86d..2c1f4d5bbe 100644 --- a/backends/arm/operators/op_rshift.py +++ b/backends/arm/operators/op_rshift.py @@ -28,7 +28,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: input_shape = inputs[0].shape input_0_rank = len(input_shape) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 49218645d7..1cc3e8fcff 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -33,7 +33,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == output.dtype == ts.DType.FP32 tosa_graph.addOperator(TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_select.py b/backends/arm/operators/op_select.py index eddd5c4adf..b047a5dd47 100644 --- a/backends/arm/operators/op_select.py +++ b/backends/arm/operators/op_select.py @@ -33,7 +33,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert len(inputs) == 3 diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index d9c93fc7ed..0c28c0ed00 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -34,7 +34,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert len(node.all_input_nodes) == 1 diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 89de8ed2af..9327e005b6 100644 --- 
a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -30,7 +30,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # aten.slice_copy supports slicing in 1d at a time. diff --git a/backends/arm/operators/op_squeeze.py b/backends/arm/operators/op_squeeze.py index 0429d214ff..e5962fd684 100644 --- a/backends/arm/operators/op_squeeze.py +++ b/backends/arm/operators/op_squeeze.py @@ -27,7 +27,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: shape = inputs[0].shape rank = len(shape) diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 6125158eb9..0c569a6ffd 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -38,7 +38,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # Specification (0.80) states that input and output types # should all be the same @@ -96,7 +95,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: # Specification (0.80) states that input and output types # should all be the same @@ -104,7 +102,7 @@ def define_node( if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output, is_quant_node) + super().define_node(node, tosa_graph, inputs, output) else: # FP32 Sub lowering assert inputs[0].dtype == ts.DType.FP32 diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index a4d2d8f914..dcc194a656 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -38,7 +38,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: input_shape = 
list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) @@ -94,10 +93,9 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: if inputs[0].dtype == ts.DType.INT8: - return super().define_node(node, tosa_graph, inputs, output, is_quant_node) + return super().define_node(node, tosa_graph, inputs, output) input_name = inputs[0].name reduced_shape = list(inputs[0].shape) dim_list = cast(list[int], inputs[1].special) diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py index 0a5892f067..bfaaf4578e 100644 --- a/backends/arm/operators/op_table.py +++ b/backends/arm/operators/op_table.py @@ -29,7 +29,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert node.name in self._exported_program.state_dict.keys() assert inputs[0].dtype == output.dtype == ts.DType.INT8 diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 5fa2a52beb..a1e91be4ff 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -33,7 +33,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert inputs[0].dtype == output.dtype == ts.DType.FP32 tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py index c0e4f0de4c..256e54f3a2 100644 --- a/backends/arm/operators/op_to_copy.py +++ b/backends/arm/operators/op_to_copy.py @@ -36,6 +36,5 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py index 8d08d6f3b9..42675be34b 100644 --- 
a/backends/arm/operators/op_transpose.py +++ b/backends/arm/operators/op_transpose.py @@ -33,7 +33,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: output_rank = len(output.shape) perms = [dim % output_rank for dim in inputs[1].special] diff --git a/backends/arm/operators/op_unsqueeze.py b/backends/arm/operators/op_unsqueeze.py index c14128fdc8..ddd7bd7957 100644 --- a/backends/arm/operators/op_unsqueeze.py +++ b/backends/arm/operators/op_unsqueeze.py @@ -31,7 +31,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: dim = inputs[1].number diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index 50a2d1d185..68fcb521d9 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -32,7 +32,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: assert ( inputs[0].shape is not None and output.shape is not None diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index 8667df590d..3489795ed5 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -31,7 +31,6 @@ def define_node( tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, ) -> None: attr = ts.TosaSerializerAttribute() new_shape = tosa_shape(inputs[1].special, output.dim_order) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 6a07862682..6aa663b81e 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -19,7 +19,7 @@ from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - dq_q_ops, + 
dq_op, get_quantized_node_output_dtype, is_node_quantized, ) @@ -44,8 +44,8 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) - is_quant_node = output.dtype == ts.DType.INT8 or node.target in dq_q_ops - if is_quant_node: + is_dq_node = node.target == dq_op + if is_dq_node: output_dtype = ts.DType.INT8 else: output_dtype = output.dtype @@ -64,7 +64,6 @@ def process_call_function( tosa_graph, inputs, output, - is_quant_node, ) else: raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}") From cdd6c91dfbec63a1515f62cd5aa14a7306d88f00 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Fri, 20 Dec 2024 11:52:27 +0100 Subject: [PATCH 7/7] Fix pyre issues Address issues from pyre and add similar # pyre-ignores as in https://github.com/pytorch/executorch/pull/7362. Signed-off-by: Oscar Andersson Change-Id: I6feaa611dcd539b3b0d21a6a7dd696ef7db691ef --- .../arm/_passes/annotate_decomposed_matmul.py | 16 ++++++-------- .../fold_qdq_with_annotated_qparams_pass.py | 22 ++++++++++++++----- backends/arm/_passes/insert_table_ops.py | 9 ++++---- backends/arm/operators/op_bmm.py | 8 ++++--- backends/arm/operators/op_hardtanh.py | 4 +++- backends/arm/operators/op_max_pool2d.py | 6 +++-- backends/arm/operators/op_mm.py | 6 +++-- backends/arm/operators/op_mul.py | 4 +++- backends/arm/operators/op_relu.py | 2 ++ backends/arm/tosa_quant_utils.py | 2 +- 10 files changed, 51 insertions(+), 28 deletions(-) diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 44f99fd1a7..034939945f 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. 
import itertools -from typing import Any, Dict, List import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -13,10 +12,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule -from torch.fx.passes.utils.source_matcher_utils import ( - get_source_partitions, - SourcePartition, -) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions class AnnotateDecomposedMatmulPass(ExportPass): @@ -28,8 +24,8 @@ class AnnotateDecomposedMatmulPass(ExportPass): matmul-op (can be mm or bmm). """ - def call(self, graph_module: GraphModule): - matmul_partitions: Dict[Any, List[SourcePartition]] = get_source_partitions( + def call(self, graph_module: GraphModule) -> PassResult: + matmul_partitions = get_source_partitions( graph_module.graph, [ torch.matmul, @@ -56,7 +52,7 @@ def call(self, graph_module: GraphModule): input_node = partition.input_nodes[i] matmul_input_node = matmul_args[i] # Remove partition input dq-node - input_node.replace_all_uses_with(input_node.args[0]) + input_node.replace_all_uses_with(input_node.all_input_nodes[0]) graph_module.graph.erase_node(input_node) input_node_qargs = input_node.args[1:] with graph_module.graph.inserting_before(matmul_node): @@ -81,7 +77,9 @@ def call(self, graph_module: GraphModule): matmul_node.replace_all_uses_with(q_node) q_node.args = (matmul_node, *output_node_qargs) # Remove partition output q-node - partition_output.replace_all_uses_with(partition_output.args[0]) + partition_output.replace_all_uses_with( + partition_output.all_input_nodes[0] + ) graph_module.graph.erase_node(partition_output) # retrace the graph to update the fake tensor types diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index aa5358fe17..045506f19d 100644 --- 
a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -6,14 +6,20 @@ import copy -from typing import cast, Iterable +from typing import cast, Dict, Iterable, Set, Tuple from executorch.backends.arm.tosa_quant_utils import QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import ( + Argument, + ExportPass, + NodeMetadata, + PassResult, + ProxyValue, +) from torch.fx import GraphModule, Node q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default @@ -82,7 +88,7 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None: def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int - ): + ) -> None: input_qparams = None nodes_to_remove = set() for arg in arg_list: @@ -210,11 +216,17 @@ class RetraceFoldedDtypesPass(ExportPass): the output type of that matches the type of the output_qparams. """ - targeted_ops = { + targeted_ops: Set[EdgeOpOverload] = { exir_ops.edge.aten.sum.dim_IntList, } - def call_operator(self, op, args, kwargs, meta): + def call_operator( + self, + op, # pyre-ignore + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: if op not in self.targeted_ops: return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 41f1c25924..57a8376d40 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Callable +from typing import Callable, Dict import torch from executorch.backends.arm._passes.arm_pass_utils import create_node @@ -12,6 +12,7 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule @@ -22,7 +23,7 @@ @impl(lib, "_table") -def _table_impl(*args, **kwargs): +def _table_impl(*args, **kwargs): # pyre-ignore return args[0] @@ -34,7 +35,7 @@ class InsertTableOpsPass(ExportPass): which will be used to produce the table values in operators/op_table.py. """ - table_ops = { + table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = { exir_ops.edge.aten.exp.default: torch.exp, exir_ops.edge.aten.log.default: torch.log, exir_ops.edge.aten.reciprocal.default: torch.reciprocal, @@ -43,7 +44,7 @@ class InsertTableOpsPass(ExportPass): exir_ops.edge.aten.tanh.default: torch.tanh, } - def __init__(self, exported_program: ExportedProgram): + def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 6ff9c0fd56..821df84ee3 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -9,6 +9,8 @@ import serializer.tosa_serializer as ts import torch + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -49,7 +51,7 @@ def define_node( # for a later rescale. 
if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) + input_qparams = get_input_qparams(node) # pyre-ignore[16] input0_zp = input_qparams[0].zp input1_zp = input_qparams[1].zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) @@ -71,9 +73,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if output.dtype == ts.DType.INT8: - output_qparams = get_output_qparams(node)[0] + output_qparams = get_output_qparams(node)[0] # pyre-ignore[16] final_output_scale = ( - input_qparams[0].scale * input_qparams[1].scale + input_qparams[0].scale * input_qparams[1].scale # pyre-ignore[61] ) / output_qparams.scale build_rescale( diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 28f28bd9c6..bfbab55b92 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -8,6 +8,8 @@ import serializer.tosa_serializer as ts import torch + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -39,7 +41,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # Get quant parameters - input_qparams = get_input_qparams(node) + input_qparams = get_input_qparams(node) # pyre-ignore[16] qargs = input_qparams[0] # Convert to quantized representation clamp_min_qs = quantize_value(inputs[1].number, qargs) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 69a46cc305..6cb5f0490e 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -8,6 +8,8 @@ import serializer.tosa_serializer as ts import torch + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -49,12 +51,12 @@ def define_node( # Initilize zero point to zero. input_zp = 0 if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) + input_qparams = get_input_qparams(node) # pyre-ignore[16] input_zp = input_qparams[0].zp output_zp = 0 if output.dtype == ts.DType.INT8: - output_qparams = get_output_qparams(node) + output_qparams = get_output_qparams(node) # pyre-ignore[16] output_zp = output_qparams[0].zp attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index 7e2db21426..9266b5a03c 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -9,6 +9,8 @@ import serializer.tosa_serializer as ts import torch + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -52,7 +54,7 @@ def define_node( # The output also needs to be rank 3 output_new_shape = (1, output.shape[0], output.shape[1]) - input_qparams = get_input_qparams(node) + input_qparams = get_input_qparams(node) # pyre-ignore[16] assert len(input_qparams) == 2 input0_qparams = input_qparams[0] input1_qparams = input_qparams[1] @@ -78,7 +80,7 @@ def define_node( ) # As INT8 accumulates into INT32, we need to rescale it back to INT8 - output_qparams = get_output_qparams(node) + output_qparams = get_output_qparams(node) # pyre-ignore[16] assert len(output_qparams) == 1 final_output_scale = ( diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index ec0d4b16c2..c6a315d445 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -12,6 +12,8 @@ import serializer.tosa_serializer as ts import torch + +# 
pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -43,7 +45,7 @@ def define_node( assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8 input_A = inputs[0] input_B = inputs[1] - input_qparams = get_input_qparams(node) + input_qparams = get_input_qparams(node) # pyre-ignore[16] input_A_qargs = input_qparams[0] input_B_qargs = input_qparams[1] input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index a3bd355197..267639a432 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -8,6 +8,8 @@ import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts import torch.fx + +# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_output_qparams, ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index ab2d8befdc..dff7b12cdd 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -145,7 +145,7 @@ def quantize_value(self, x): self.qmax, ).to(self.dtype) - def dequantize_value(self, qx: int) -> float: + def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: return (qx - self.zp) * self.scale def __eq__(self, other):