From 373fe8c2a668da4a09066eb4bc9770427afbbb6f Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Wed, 28 Aug 2024 12:27:33 +0800 Subject: [PATCH 1/3] Qualcomm AI Engine Direct - Optimization and fix mutable buffer issue Summary: - Add a pass to convert linear to conv2d: We found the accuracy drop because of QNN Linear op in llama3. And it will be fixed with convert linear to conv2d pass. - Workaround the issue about mutable buffer for index_put op: We add a pass to replace the input of index_put op. Under the workaround, it will result in performance regression. - Insert copy op for int64 inputs to convert int64 to int32 in i64toi32 pass - Support QNN RMS Norm and use native rms norm in llama_transformer - Add a pass to compose rms norm --- backends/qualcomm/builders/__init__.py | 2 + backends/qualcomm/builders/node_visitor.py | 2 +- backends/qualcomm/builders/op_conv2d.py | 86 ++++-------- backends/qualcomm/builders/op_rms_norm.py | 127 ++++++++++++++++++ backends/qualcomm/builders/qnn_constants.py | 7 + .../passes/annotate_and_quant_scalar.py | 1 + backends/qualcomm/passes/i64_to_i32.py | 24 ++++ .../qualcomm/passes/recompose_rms_norm.py | 76 +++++++++++ .../passes/replace_index_put_input.py | 54 ++++++++ .../qualcomm/quantizer/custom_annotation.py | 10 +- backends/qualcomm/quantizer/utils.py | 25 ++++ backends/qualcomm/tests/models.py | 10 ++ backends/qualcomm/tests/test_qnn_delegate.py | 13 ++ backends/qualcomm/utils/utils.py | 7 + examples/models/llama2/export_llama_lib.py | 17 ++- examples/models/llama2/llama_transformer.py | 6 +- .../llama2/source_transformation/sdpa.py | 5 +- extension/llm/export/builder.py | 1 + extension/llm/export/partitioner_lib.py | 4 +- extension/llm/export/quantizer_lib.py | 9 +- 20 files changed, 409 insertions(+), 77 deletions(-) create mode 100644 backends/qualcomm/builders/op_rms_norm.py create mode 100644 backends/qualcomm/passes/recompose_rms_norm.py create mode 100644 backends/qualcomm/passes/replace_index_put_input.py diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index d3bf98bae7..79c02e2207 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -38,6 +38,7 @@ op_quantize, op_relu, op_reshape, + op_rms_norm, op_rsqrt, op_select_copy, op_sigmoid, @@ -92,6 +93,7 @@ op_quantize, op_relu, op_reshape, + op_rms_norm, op_rsqrt, op_select_copy, op_sigmoid, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index e07a745df5..514bc6efd7 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -202,7 +202,7 @@ def get_quant_tensor_value( dtype = quant_configs[QCOM_DTYPE] - tensor = tensor.div(scale + 1e-6).add(zero_point).round().to(dtype) + tensor = tensor.div(scale).add(zero_point).round().to(dtype) # Make the backends access data correctly if quant_configs.get(QCOM_BITWIDTH) == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index 909cc6a21f..4b58edbac6 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -10,16 +10,7 @@ import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import ( - QCOM_DATA, - QCOM_DTYPE, - QCOM_QUANT_ATTRS, - QCOM_QUANT_MAX, - QCOM_QUANT_MIN, - QCOM_SCALE, - QCOM_ZERO_POINT, -) -from executorch.exir.dialects._ops import ops as exir_ops +from 
executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import ( @@ -94,52 +85,6 @@ def _add_conv_op_parameter( return conv_op - def _get_bias_tensor( - self, - node: torch.fx.Node, - nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], - num_output_channel: int, - ) -> PyQnnWrapper.PyQnnOpWrapper: - # build dummy node if bias is not given - bias_node = ( - node.args[2] - if node.args[2] is not None - else torch.fx.Node( - node.graph, - node.name + "_runtime_bias", - "call_function", - exir_ops.edge.aten.full.default, - (), # args - {}, # kwargs - ) - ) - # zeros tensor to meet HTP constraint if bias is not given - bias_tensor = ( - get_parameter(bias_node, self.edge_program) - if node.args[2] is not None - else torch.zeros(num_output_channel) - ) - # insert quant attribute to meet HTP constraint if bias is not given - if ( - node.args[2] is None - and (bias_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS)) is not None - ): - quant_attrs = bias_quant_attrs.copy() - quant_attrs[QCOM_ZERO_POINT] = 0 - quant_attrs[QCOM_SCALE] = 0 - quant_attrs[QCOM_DTYPE] = torch.int32 - quant_attrs[QCOM_QUANT_MAX] = torch.iinfo(torch.int32).max - quant_attrs[QCOM_QUANT_MIN] = torch.iinfo(torch.int32).min + 1 - bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs - - return self.define_tensor( - bias_node, - bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - is_input_tensor=False, - ) - def _define_conv1d( self, node: torch.fx.Node, @@ -204,9 +149,17 @@ def _define_conv1d( is_input_tensor=False, ) conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper] - conv_input_tensors.append( - self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1]) - ) + if node.args[2] is not None: + bias_node = node.args[2] + bias_tensor = get_parameter(bias_node, self.edge_program) + bias_tensor_wrapper = self.define_tensor( + bias_node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + conv_input_tensors.append(bias_tensor_wrapper) stride = [1] + cast(List[int], node.args[3]) padding = [0] + cast(List[int], node.args[4]) @@ -312,9 +265,18 @@ def define_node( is_input_tensor=False, ) conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] - conv_input_tensors.append( - self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1]) - ) + + if node.args[2] is not None: + bias_node = node.args[2] + bias_tensor = get_parameter(bias_node, self.edge_program) + bias_tensor_wrapper = self.define_tensor( + bias_node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + conv_input_tensors.append(bias_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py new file mode 100644 index 0000000000..e99b1f47ba --- /dev/null +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -0,0 +1,127 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
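+#
+# Maps aten.rms_norm.default onto the QNN RmsNorm op: the gamma weight is
+# emitted as a static tensor, a zero bias tensor is synthesized to fill the
+# op's bias input, normalization is restricted to the last input axis, and
+# epsilon is passed as a scalar parameter.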
+ +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np + +import torch +from executorch.backends.qualcomm.builders.utils import get_parameter +from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpRmsNorm, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class RmsNormVisitor(NodeVisitor): + target = ["aten.rms_norm.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + # args of node : ['input', 'normalized_shape', 'weight', 'eps'] + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + # should be a immutable list + normalized_shapes = node.args[1] + if ( + len(normalized_shapes) != 1 + and normalized_shapes[0] != input_tensor.shape[-1] + ): + print("Only supports normalization with last input dimension") + return + axes = [node.args[0].meta["val"].dim() - 1] + axes_shape = [len(axes)] + + weight_node = node.args[2] + weight_tensor = get_parameter(weight_node, self.edge_program) + weight_tensor_wrapper = self.define_tensor( + weight_node, + weight_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + + # Fake node, nn moudle seems to be inconsistant with document + bias_tensor = torch.zeros(weight_tensor.shape) + bias_node = torch.fx.Node( + node.graph, + node.name + "_runtime_bias", + "call_function", + exir_ops.edge.aten.tensor.default, + (), # args + {}, # kwargs + ) + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + bias_tensor_wrapper = self.define_tensor( + bias_node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + + epsilon = node.args[3] + if isinstance(epsilon, torch.fx.Node): + epsilon = get_parameter(epsilon, self.edge_program) + epsilon = ( + epsilon + if isinstance(epsilon, float) + else torch.finfo(epsilon.dtype).eps + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + rms_nrom_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpRmsNorm.op_name, + ) + + rms_nrom_op.AddInputTensors( + [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper] + ) + rms_nrom_op.AddOutputTensors([output_tensor_wrapper]) + rms_nrom_op.AddScalarParam( + OpRmsNorm.param_epsilon, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(epsilon)}, + ) + rms_nrom_op.AddTensorParam( + OpRmsNorm.param_axes, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(axes_shape), + axes_shape, + np.array(axes, dtype=np.uint32), + True, + ) + + return rms_nrom_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 4a87e5dbbb..8ac702f2ad 100644 --- 
a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -278,6 +278,13 @@ class OpResizeNearestNeighbor: param_half_pixel_centers: str = "half_pixel_centers" +@dataclass(init=False, frozen=True) +class OpRmsNorm: + op_name: str = "RmsNorm" + param_epsilon: str = "epsilon" + param_axes: str = "axes" + + @dataclass(init=False, frozen=True) class OpScatterNd: op_name: str = "ScatterNd" diff --git a/backends/qualcomm/passes/annotate_and_quant_scalar.py b/backends/qualcomm/passes/annotate_and_quant_scalar.py index 5f111ee9c8..1ec2ac64b5 100644 --- a/backends/qualcomm/passes/annotate_and_quant_scalar.py +++ b/backends/qualcomm/passes/annotate_and_quant_scalar.py @@ -78,6 +78,7 @@ def _annotate_scalar_node( float, torch.float32, torch.int32, + torch.int64, ]: return diff --git a/backends/qualcomm/passes/i64_to_i32.py b/backends/qualcomm/passes/i64_to_i32.py index 7814a3ff0d..1d2171cc37 100644 --- a/backends/qualcomm/passes/i64_to_i32.py +++ b/backends/qualcomm/passes/i64_to_i32.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. import torch from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +from torch._subclasses.fake_tensor import FakeTensor class I64toI32(ExportPass): @@ -16,6 +18,8 @@ class I64toI32(ExportPass): def __init__(self, edge_program: torch.export.ExportedProgram): super(I64toI32, self).__init__() self.edge_program = edge_program + # pyre-ignore[4] + self.copy_op = exir_ops.edge.aten._to_copy.default def _update_meta(self, node: torch.fx.node) -> None: meta_val = node.meta["val"] @@ -32,6 +36,10 @@ def _update_meta(self, node: torch.fx.node) -> None: if meta_val.dtype == torch.int64: node.meta["val"] = meta_val.to(torch.float) + # pyre-ignore[2] + def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: + return isinstance(node_val, FakeTensor) and node_val.dtype == dtype + def _cast_to_int32(self, graph_module: torch.fx.GraphModule): for n in graph_module.graph.nodes: if is_constant(n, self.edge_program): @@ -39,6 +47,22 @@ def _cast_to_int32(self, graph_module: torch.fx.GraphModule): if param.dtype == torch.int64: # QNN does not support int64 self._update_meta(n) + elif n.op == "placeholder": + node_val = n.meta["val"] + if self._is_tensor_of_dtype(node_val, torch.int64): + with graph_module.graph.inserting_after(n): + args = (n,) + to_dst_node = graph_module.graph.create_node( + "call_function", + self.copy_op, + args, + {"dtype": torch.int32}, + ) + to_dst_node.meta["val"] = node_val.to(torch.int32) + + # Replace usage of the src dtype result with the dst dtype result. + n.replace_all_uses_with(to_dst_node) + to_dst_node.args = (n,) def call(self, graph_module: torch.fx.GraphModule): self._cast_to_int32(graph_module) diff --git a/backends/qualcomm/passes/recompose_rms_norm.py b/backends/qualcomm/passes/recompose_rms_norm.py new file mode 100644 index 0000000000..b26de8bd79 --- /dev/null +++ b/backends/qualcomm/passes/recompose_rms_norm.py @@ -0,0 +1,76 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
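+#
+# Matches the subgraph that torch.nn.RMSNorm decomposes into (via source
+# partitions) and folds it back into a single
+# exir_ops.edge.aten.rms_norm.default node, recovering eps from the add node
+# and gamma from the inputs of the partition's output node.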
+import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + +from .utils import dq_ops + + +class RecomposeRmsNorm(ExportPass): + """ + Merge decomposed operators back to one super node. + """ + + def __init__(self): + super().__init__() + + def _get_eps_node(self, nodes): + # eps: one of inputs of add node + add_node = [n for n in nodes if hasattr(n, "name") and "add" in n.name][0] + for a in add_node.args: + if isinstance(a, float) or a.op != "call_function": + return a + + def _get_gamma_node(self, output_node): + # gamma: one of inputs of output node + for a in output_node.args: + if a.op != "call_function" or a.target in dq_ops: + return a + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + partitions = get_source_partitions(graph, [torch.nn.RMSNorm]) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + input_len = len(src_partition.input_nodes) + if input_len == 1: + input_node = src_partition.input_nodes[0] + elif input_len == 2: + inp_0, inp_1 = src_partition.input_nodes + input_node = inp_0 if len(inp_0.users) == 2 else inp_1 + else: + raise RuntimeError( + f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs" + ) + + output_node = src_partition.output_nodes[0] + eps_node = self._get_eps_node(src_partition.nodes) + gamma_node = self._get_gamma_node(output_node) + + with graph.inserting_before(output_node): + # args schema + # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + rms_node = graph.create_node( + "call_function", + exir_ops.edge.aten.rms_norm.default, + ( + input_node, + list(gamma_node.meta["val"].shape), + gamma_node, + eps_node, + ), + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, rms_node) + # copy metadata + rms_node.meta = output_node.meta + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/replace_index_put_input.py b/backends/qualcomm/passes/replace_index_put_input.py new file mode 100644 index 0000000000..1eb210cf67 --- /dev/null +++ b/backends/qualcomm/passes/replace_index_put_input.py @@ -0,0 +1,54 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
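+#
+# Mutable-buffer workaround: when an index_put node is followed by a copy
+# back into its mutable buffer (e.g. the KV cache), rewire the index_put to
+# consume the buffer placeholder instead of its frozen, dequantized input,
+# and migrate the quant attributes with the dequant encoding mapped back to
+# the corresponding quant op.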
+import torch +from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class ReplaceIndexPutInput(ExportPass): + """ + Index put input workaround for quantized module + """ + + dq_q_map = { + # per tensor + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + # per channel + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + } + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(ReplaceIndexPutInput, self).__init__() + self.edge_program = edge_program + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + for node in graph.nodes: + if node.target == exir_ops.edge.aten.index_put.default: + if ( + copy_node := list(node.users)[0] + ) and copy_node.target == exir_ops.edge.aten.copy.default: + m_buffer_node = copy_node.args[0] + bad_frozen_node = node.args[0] + if QCOM_QUANT_ATTRS in bad_frozen_node.meta: + m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[ + QCOM_QUANT_ATTRS + ] + m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] = ( + self.dq_q_map[ + m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] + ] + ) + with graph.inserting_after(bad_frozen_node): + node.replace_input_with(bad_frozen_node, m_buffer_node) + else: + continue + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index b2c86e50d3..9cde50b9c7 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -91,15 +91,17 @@ def is_edge_condition(node: Node): def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): if is_edge_condition(node): return - if node.target == torch.ops.aten.index_put_.default: + if node.target in [ + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + ]: annotate_index_put(node, quantization_config) annotate_matmul_input1(node.args[0], quantization_config) elif node.target == torch.ops.aten.cat.default: annotate_cat(node, quantization_config) # Expect that the inputs of the cat op are select ops - for arg in node.args[0][1:]: - annotate_single_in_single_out(arg, quantization_config) - annotate_matmul_input1(node.args[0][0], quantization_config) + for arg in node.args[0]: + annotate_matmul_input1(arg, quantization_config) else: annotate_single_in_single_out(node, quantization_config) annotate_matmul_input1(node.args[0], quantization_config) diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index d31b4753a3..5f299f9bc6 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -684,6 +684,31 @@ def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> Non annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.rms_norm.default]) +def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None: + act_node = node.args[0] + weight_node = node.args[2] + + if _is_annotated([node]): + return + 
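+    # Reuse the activation qspec for both the input activation and the gamma
+    # weight so the op sees a uniform encoding; see the TODO below.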
+ # TODO current only support 16a16w + _annotate_input_qspec_map( + node, + act_node, + quantization_config.input_activation, + ) + + _annotate_input_qspec_map( + node, + weight_node, + quantization_config.input_activation, + ) + nodes_to_mark_annotated = [node] + _annotate_output_qspec(node, quantization_config.output_activation) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + @register_annotator([torch.ops.aten.rsqrt.default]) def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 319cc6092c..127f704e8c 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -734,6 +734,16 @@ def forward(self, x): ) +class RmsNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.eps = 1e-5 + self.rms = torch.nn.RMSNorm([4], 1e-5) + + def forward(self, x): + return self.rms(x) + + class Rsqrt(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index cba23f935c..71e3b13ff8 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -393,6 +393,11 @@ def test_qnn_backend_reshape(self): sample_input = (torch.randn([3, 4]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rms_norm(self): + module = RmsNorm() # noqa: F405 + sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) @@ -1000,6 +1005,14 @@ def test_qnn_backend_reshape(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rms_norm(self): + module = RmsNorm() # noqa: F405 + sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a4w + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 6dc0c4c3c8..3e274a0ce7 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -38,7 +38,11 @@ from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import ( RecomposePixelUnshuffle, ) +from executorch.backends.qualcomm.passes.recompose_rms_norm import RecomposeRmsNorm from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy +from executorch.backends.qualcomm.passes.replace_index_put_input import ( + ReplaceIndexPutInput, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( _soc_info_table, QcomChipset, @@ -56,6 +60,7 @@ convert_to_option, ) from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC + from executorch.exir import ExirExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule @@ -201,6 +206,7 @@ def _transform(edge_program: ExportedProgram) -> None: graph_module = edge_program.graph_module RemoveRedundancy()(graph_module) RecomposePixelUnshuffle()(graph_module) + 
RecomposeRmsNorm()(graph_module) ConvertToLinear()(graph_module) ConvertPReLU(edge_program)(graph_module) ConvertBmmToMatmul()(graph_module) @@ -211,6 +217,7 @@ def _transform(edge_program: ExportedProgram) -> None: AnnotateDecomposed(edge_program)(graph_module) FoldQDQ()(graph_module) LayoutTransform(edge_program)(graph_module) + ReplaceIndexPutInput(edge_program)(graph_module) # Since QDQ nodes are stripped, update graph signature again to validate program edge_program._graph_signature = _get_updated_graph_signature( diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index f6abc3aaf4..401788bea1 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -43,6 +43,7 @@ get_qnn_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace +from torch._export import capture_pre_autograd_graph from ..model_factory import EagerModelFactory from .source_transformation.quantize import ( @@ -406,9 +407,15 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: if args.use_kv_cache: if args.qnn: + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` + from executorch.backends.qualcomm.utils.utils import ( + convert_linear_to_conv2d, + ) + transforms.append(replace_kv_cache_with_simple_kv_cache) transforms.append(replace_sdpa_with_flex_sdpa) transforms.append(replace_causal_mask) + transforms.append(convert_linear_to_conv2d) elif args.coreml or args.mps: # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition @@ -552,7 +559,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program - canonicalize_program(builder.edge_manager.exported_program()) + # TODO: Need to remove this once we have better way to handle buffer size + canonicalize_program( + builder.edge_manager.exported_program(), custom_buffer_size=542048256 + ) builder = builder.to_executorch() @@ -569,7 +579,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program - canonicalize_program(builder.edge_manager.exported_program()) + # TODO: Need to remove this once we have better way to handle buffer size + canonicalize_program( + builder.edge_manager.exported_program(), custom_buffer_size=542048256 + ) builder = builder.to_executorch() diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 0c93115ee3..4b67825fa8 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -416,8 +416,8 @@ def __init__(self, layer_id: int, args: ModelArgs): self.block_sparse_moe = MOEFeedForward(args) else: self.feed_forward = FeedForward(args) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN h = self.attention.forward( @@ -443,7 +443,7 @@ def __init__(self, params: ModelArgs): self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, 
params)) - self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.use_kv_cache = params.use_kv_cache self.generate_full_logits = params.generate_full_logits diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py index 8e5de7d97a..c48fdf0ae5 100644 --- a/examples/models/llama2/source_transformation/sdpa.py +++ b/examples/models/llama2/source_transformation/sdpa.py @@ -118,8 +118,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ - if n_rep == 1: - return hidden_states + # TODO: Encounter the bug about source partition, need to investigate more on it. + # if n_rep == 1: + # return hidden_states new_kv = [] batch, n_heads, seqlen, head_dim = hidden_states.shape diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index bc64ae869f..e7b08aa9be 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -77,6 +77,7 @@ def __init__( verbose: bool = False, metadata: Optional[dict] = None, dynamic_shapes: Optional[Any] = None, + export_fn=capture_pre_autograd_graph, ): self.model = model # graph module returned from capture_pre_autograd_graph diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index e75d5bef3f..2f4c87d6fd 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -130,11 +130,11 @@ def get_qnn_partitioner( ) except ImportError: raise ImportError( - "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html" + "Please install the Qualcomm backend following https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html" ) use_fp16 = True - skip_node_op_set = {"llama.fallback.default"} + skip_node_op_set = {"llama.fallback.default", "aten.embedding.default"} if pt2e_quantize is not None: use_fp16 = False diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 7fc53358c5..45d9932724 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -180,8 +180,9 @@ def get_qnn_quantizer( # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w # TODO: enable it after the issue is fixed logging.warning( - "Disable per channel quantization for linear due to the error with QNN HTP 16a16w." + "Disable per channel quantization for linear and conv due to the error with QNN HTP 16a16w." 
) + qnn_quantizer.set_per_channel_conv_quant(enable=False) qnn_quantizer.set_per_channel_linear_quant(enable=False) qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) qnn_quantizer.set_bit16_op_quant_config( @@ -208,6 +209,12 @@ def get_qnn_quantizer( quantization_mode is None ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" qnn_quantizer.add_custom_quant_annotations(custom_annotations) + qnn_quantizer.add_discard_ops( + [ + torch.ops.aten.embedding.default, + ] + ) + return qnn_quantizer, quant_dtype From fcd42fb33073b1d79668c2d6bcf990be1e73c6b1 Mon Sep 17 00:00:00 2001 From: Sheng Feng Wu Date: Thu, 5 Sep 2024 20:46:24 -0700 Subject: [PATCH 2/3] Use transform to replace rms_norm --- examples/models/llama2/TARGETS | 1 + examples/models/llama2/export_llama_lib.py | 3 ++- examples/models/llama2/llama_transformer.py | 7 +++--- .../llama2/source_transformation/rms_norm.py | 23 +++++++++++++++++++ extension/llm/export/builder.py | 1 - 5 files changed, 30 insertions(+), 5 deletions(-) create mode 100644 examples/models/llama2/source_transformation/rms_norm.py diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index 467949a5eb..18a10fb9fd 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -71,6 +71,7 @@ runtime.python_library( "export_llama_lib.py", "model.py", "source_transformation/quantize.py", + "source_transformation/rms_norm.py", "source_transformation/rope.py", "source_transformation/sdpa.py", ], diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 401788bea1..968117eef2 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -43,13 +43,13 @@ get_qnn_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace -from torch._export import capture_pre_autograd_graph from ..model_factory import EagerModelFactory from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, ) +from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis from .source_transformation.sdpa import ( replace_causal_mask, @@ -415,6 +415,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: transforms.append(replace_kv_cache_with_simple_kv_cache) transforms.append(replace_sdpa_with_flex_sdpa) transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) transforms.append(convert_linear_to_conv2d) elif args.coreml or args.mps: diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 4b67825fa8..534d90c6ed 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -39,6 +39,7 @@ def __init__(self, dim: int, eps: float = 1e-6): """ super().__init__() + self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) @@ -416,8 +417,8 @@ def __init__(self, layer_id: int, args: ModelArgs): self.block_sparse_moe = MOEFeedForward(args) else: self.feed_forward = FeedForward(args) - self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN h = 
self.attention.forward( @@ -443,7 +444,7 @@ def __init__(self, params: ModelArgs): self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params)) - self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps) + self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.use_kv_cache = params.use_kv_cache self.generate_full_logits = params.generate_full_logits diff --git a/examples/models/llama2/source_transformation/rms_norm.py b/examples/models/llama2/source_transformation/rms_norm.py new file mode 100644 index 0000000000..ff7e8b6745 --- /dev/null +++ b/examples/models/llama2/source_transformation/rms_norm.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.examples.models.llama2.llama_transformer import RMSNorm + + +def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, RMSNorm): + rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps) + rms_norm.weight = child.weight + setattr( + module, + name, + rms_norm, + ) + else: + replace_rms_norm_with_native_rms_norm(child) + return module diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index e7b08aa9be..bc64ae869f 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -77,7 +77,6 @@ def __init__( verbose: bool = False, metadata: Optional[dict] = None, dynamic_shapes: Optional[Any] = None, - export_fn=capture_pre_autograd_graph, ): self.model = model # graph module returned from capture_pre_autograd_graph From 9b98827aaeaca217340c1e0aafcd67a237d4b442 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Sat, 7 Sep 2024 10:25:49 +0800 Subject: [PATCH 3/3] temporarily remove test-llama-runner-qnn-linux --- .github/workflows/pull.yml | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index ca13d9bbd2..259ebb1986 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -372,38 +372,3 @@ jobs: # Run pytest with coverage pytest -c /dev/null -v -n auto --cov=./ --cov-report=xml backends/arm/test - - - test-llama-runner-qnn-linux: - name: test-llama-runner-qnn-linux - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - strategy: - matrix: - dtype: [fp32] - build-tool: [cmake] - mode: [qnn] - fail-fast: false - with: - runner: linux.2xlarge - docker-image: executorch-ubuntu-22.04-clang12-android - submodules: 'true' - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - timeout: 900 - script: | - # The generic Linux job chooses to use base env, not the one setup by the image - CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") - conda activate "${CONDA_ENV}" - - DTYPE=${{ matrix.dtype }} - BUILD_TOOL=${{ matrix.build-tool }} - MODE=${{ matrix.mode }} - - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh - PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh - - # Setup executorch - PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh buck2 - # Install requirements for export_llama - PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh - # Test llama2 - 
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
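
A note on patch 1: the convert_linear_to_conv2d transform added to
export_llama_lib.py is imported from executorch.backends.qualcomm.utils.utils,
so its implementation is not part of this series. Conceptually it relies on the
standard equivalence between a linear layer and a 1x1 convolution, sketched
below as a minimal illustration only (not the actual pass; the layouts and
reshapes used by the real transform may differ):

    import torch
    import torch.nn.functional as F

    def linear_as_conv2d(x, weight, bias=None):
        # x: [batch, seq, in_features]; weight: [out_features, in_features]
        b, s, in_f = x.shape
        out_f = weight.shape[0]
        # View the sequence as an NCHW tensor with in_features as the channel dim.
        x4 = x.reshape(b, s, 1, in_f).permute(0, 3, 1, 2)        # [b, in_f, s, 1]
        # A 1x1 conv with the linear weight reshaped to [out_f, in_f, 1, 1]
        # performs the same contraction over in_features.
        y4 = F.conv2d(x4, weight.view(out_f, in_f, 1, 1), bias)  # [b, out_f, s, 1]
        return y4.permute(0, 2, 3, 1).reshape(b, s, out_f)

    x = torch.randn(1, 8, 16)
    lin = torch.nn.Linear(16, 32)
    assert torch.allclose(lin(x), linear_as_conv2d(x, lin.weight, lin.bias), atol=1e-5)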