Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm backend: qdq folding support for remaining operators #7340

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateDecomposedMatmulPass(ExportPass):
"""
torch.matmul can be decomposed in many ways, for instance:
dq -> matmul -> q can become
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
difficult. This helper function find all matmul partitions and annotate its
matmul-op (can be mm or bmm).
"""

def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
graph_module.graph,
[
torch.matmul,
],
None,
)
matmul_partitions = list(
itertools.chain.from_iterable(matmul_partitions.values())
)
matmul_targets = {
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.bmm.default,
}
for partition in matmul_partitions:
quantized_input = all(
input_node.target == dq_op for input_node in partition.input_nodes
)
matmul_node = [
node for node in partition.nodes if node.target in matmul_targets
][0]
if quantized_input:
matmul_args = matmul_node.all_input_nodes
for i in range(len(matmul_args)):
input_node = partition.input_nodes[i]
matmul_input_node = matmul_args[i]
# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
input_node_qargs = input_node.args[1:]
with graph_module.graph.inserting_before(matmul_node):
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=dq_op,
)
dq_node.args = (matmul_input_node, *input_node_qargs)
matmul_node.replace_input_with(matmul_input_node, dq_node)

partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target == q_op
if quantized_output:
output_node_qargs = partition_output.args[1:]
with graph_module.graph.inserting_after(matmul_node):
# Create q-node after matmul
q_node = create_node(
graph=graph_module.graph,
op_target=q_op,
)
matmul_node.replace_all_uses_with(q_node)
q_node.args = (matmul_node, *output_node_qargs)
# Remove partition output q-node
partition_output.replace_all_uses_with(
partition_output.all_input_nodes[0]
)
graph_module.graph.erase_node(partition_output)

# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module

graph_module.recompile()
return PassResult(graph_module, True)
58 changes: 46 additions & 12 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
AnnotateDecomposedMatmulPass,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
Expand All @@ -32,7 +35,9 @@
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
Expand Down Expand Up @@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(DecomposeLinearPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(DecomposeVarPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
# TODO MLETORCH-558
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(
FoldAndAnnotateQParamsPass(
Expand All @@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.view_copy.default,
]
)
)
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
Expand Down
22 changes: 2 additions & 20 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_param_tensor,
insert_q_dq_pair,
is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand All @@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
supports 2d and 3d convolution. This is done by modifying the graph to do the
following:
1) unsqueeze the convolution's input from 3d to 4d
2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
3) perform a conv2d (with a modified version of the original conv1d args)
4) squeeze the output back down to 3d.
5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
2) perform a conv2d (with a modified version of the original conv1d args)
3) squeeze the output back down to 3d.
"""

def __init__(self, exported_program: ExportedProgram) -> None:
Expand Down Expand Up @@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

kernel_node = node.args[1]
if kernel_node.target == dq_op:
kernel_node = kernel_node.args[0]

if not is_param_node(self.exported_program, kernel_node):
raise AssertionError(
Expand Down Expand Up @@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule):
)
node.replace_input_with(input_node, unsqueeze_before)

# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target == dq_op:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params)

with graph.inserting_after(node):
squeeze_after = create_node(
graph,
Expand All @@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule):
for user in original_users:
user.replace_input_with(node, squeeze_after)

# If quantized, insert conv2d --> q --> dq --> squeeze
if all(
original_user.target == q_op for original_user in original_users
):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, node, q_params)

graph_module.recompile()
# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
Expand Down
119 changes: 88 additions & 31 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@

import copy

from typing import cast, Iterable
from typing import cast, Dict, Iterable, Set, Tuple

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import (
Argument,
ExportPass,
NodeMetadata,
PassResult,
ProxyValue,
)
from torch.fx import GraphModule, Node

q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
Expand Down Expand Up @@ -80,6 +86,46 @@ def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
super().__init__()
self.targeted_ops = targeted_ops

def fold_and_annotate_arg(
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
) -> None:
input_qparams = None
nodes_to_remove = set()
for arg in arg_list:
if not isinstance(arg, Node):
return
"""
Make sure arg has requires_grad set to False
For parameters that are not quantized, sometimes (i.e. convolution)
the Parameter(FakeTensor(...)) has requires_grad set to True, which
causes the retracing of the graph to fail with:

E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
E
E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
E Original traceback:
E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
E x = conv(x)
"""
if arg.op == "placeholder":
arg.meta["val"].requires_grad = False

arg_quant_params = None
if arg.target == dq_op:
arg_quant_params = QuantArgs.from_operator(arg.target, arg.args)
# add arg to nodes_to_remove to fold the dq-node
nodes_to_remove.add(arg)
if input_qparams is not None and input_qparams != arg_quant_params:
# Two args are quantized differently
raise RuntimeError("Input qparams does not match!")
input_qparams = arg_quant_params
if input_qparams is not None:
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
assert n.target == dq_op
n.replace_all_uses_with(n.args[0])
graph_module.graph.erase_node(n)

def call(self, graph_module: GraphModule) -> PassResult:

# Loop over the graph nodes and find any node in the 'targeted_ops' list.
Expand All @@ -98,36 +144,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
n.meta["input_qparams"] = {}
n.meta["output_qparams"] = {}
for i, arg in enumerate(n.args):
if not isinstance(arg, Node):
continue

# Make sure arg has requires_grad set to False
# For parameters that are not quantized, sometimes (i.e. convolution)
# the Parameter(FakeTensor(...)) has requires_grad set to True, which
# causes the retracing of the graph to fail with:
#
# E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
# E
# E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
# E Original traceback:
# E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
# E x = conv(x)
#
if arg.op == "placeholder":
arg.meta["val"].requires_grad = False

if arg.target != dq_op:
continue

# arg.target for argument i is a dequant node, extract the information
n.meta["input_qparams"][i] = QuantArgs.from_operator(
arg.target, arg.args
)
if isinstance(arg, list):
self.fold_and_annotate_arg(graph_module, n, arg, i)

# arg.args[0] is the tensor input, replace the input usage
tensor_input = cast(Node, arg.args[0])
n.replace_input_with(arg, tensor_input)
graph_module.graph.erase_node(arg)
elif isinstance(arg, Node):
self.fold_and_annotate_arg(graph_module, n, [arg], i)

# Copy the users, since we are modifying it.
users_copy = copy.copy(n.users)
Expand Down Expand Up @@ -181,3 +202,39 @@ def call(self, graph_module: GraphModule) -> PassResult:
modified = True

return PassResult(graph_module, modified)


class RetraceFoldedDtypesPass(ExportPass):
"""
FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
some operators are retraced to types that cannot be handled by TOSA. One
such example is sum.dim_IntList:
q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
After folding it becomes:
q (int8) -> sum (int64) -> ...
This pass changes types of ops in self.targeted_ops, such as sum, so that
the output type of that matches the type of the output_qparams.
"""

targeted_ops: Set[EdgeOpOverload] = {
exir_ops.edge.aten.sum.dim_IntList,
}

def call_operator(
self,
op, # pyre-ignore
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta)

node_kwargs = kwargs.copy()
output_qparams = meta["output_qparams"]
if len(output_qparams) == 0:
return super().call_operator(op, args, kwargs, meta)

output_dtype = output_qparams[0].dtype
node_kwargs["dtype"] = output_dtype
return super().call_operator(op, args, node_kwargs, meta)
Loading
Loading