From 1207cd95ba712b7b8ba219e05a1b291e7660735a Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 27 Aug 2024 11:35:10 +0000 Subject: [PATCH] 2024-08-27 nightly release (dc66414c70dbec763fa37aba7908d50373299435) --- backends/arm/arm_partitioner.py | 3 + backends/arm/operators/__init__.py | 3 + backends/arm/operators/op_cat.py | 45 ++ backends/arm/operators/op_placeholder.py | 25 +- backends/arm/operators/op_relu.py | 55 ++ backends/arm/operators/op_unsqueeze.py | 51 ++ backends/arm/quantizer/arm_quantizer.py | 1 + backends/arm/quantizer/arm_quantizer_utils.py | 10 +- .../quantization_annotation/__init__.py | 1 + .../quantization_annotation/cat_annotator.py | 66 ++ backends/arm/test/misc/test_lifted_tensor.py | 42 ++ backends/arm/test/ops/test_cat.py | 131 ++++ backends/arm/test/ops/test_conv_combos.py | 14 +- backends/arm/test/ops/test_relu.py | 120 ++++ backends/arm/test/ops/test_unsqueeze.py | 103 +++ backends/arm/test/tester/arm_tester.py | 16 +- backends/cadence/aot/compiler.py | 6 +- backends/qualcomm/tests/test_qnn_delegate.py | 49 ++ backends/transforms/addmm_mm_to_linear.py | 4 +- backends/transforms/decompose_sdpa.py | 2 +- .../xnnpack/partition/config/gemm_configs.py | 9 + .../partition/config/generic_node_configs.py | 90 ++- .../xnnpack/partition/config/node_configs.py | 22 +- .../partition/config/xnnpack_config.py | 8 +- .../xnnpack/partition/xnnpack_partitioner.py | 14 + .../channels_last_tagged_reshape_pass.py | 2 +- backends/xnnpack/passes/convert_to_sdpa.py | 2 +- backends/xnnpack/test/ops/mean_dim.py | 13 + .../llama2/source_transformation/quantize.py | 4 +- examples/models/llava/export_llava.py | 35 +- examples/models/llava/test/test_pte.py | 11 + .../models/phi-3-mini/export_phi-3-mini.py | 2 +- examples/qualcomm/CMakeLists.txt | 5 + .../stable_diffusion/CMakeLists.txt | 26 + .../qaihub_scripts/stable_diffusion/README.md | 35 + .../stable_diffusion/install_requirements.sh | 3 + .../qaihub_stable_diffusion.py | 472 +++++++++++++ .../qaihub_stable_diffusion_runner.cpp | 140 ++++ .../stable_diffusion/runner/runner.cpp | 621 ++++++++++++++++++ .../stable_diffusion/runner/runner.h | 141 ++++ .../stable_diffusion/stable_diffusion_lib.py | 22 + exir/backend/utils.py | 31 +- exir/capture/_config.py | 7 +- exir/emit/_emitter.py | 14 +- exir/lowered_backend_module.py | 4 +- exir/pass_base.py | 8 +- exir/passes/__init__.py | 1 + exir/passes/remove_noop_pass.py | 2 +- exir/program/_program.py | 41 +- exir/tests/test_passes.py | 2 +- exir/tests/test_quantization.py | 2 +- exir/tracer.py | 6 +- exir/verification/arg_validator.py | 6 +- extension/llm/export/builder.py | 4 +- extension/llm/export/partitioner_lib.py | 15 +- extension/llm/export/quantizer_lib.py | 8 +- 56 files changed, 2479 insertions(+), 96 deletions(-) create mode 100644 backends/arm/operators/op_cat.py create mode 100644 backends/arm/operators/op_relu.py create mode 100644 backends/arm/operators/op_unsqueeze.py create mode 100644 backends/arm/quantizer/quantization_annotation/cat_annotator.py create mode 100644 backends/arm/test/misc/test_lifted_tensor.py create mode 100644 backends/arm/test/ops/test_cat.py create mode 100644 backends/arm/test/ops/test_relu.py create mode 100644 backends/arm/test/ops/test_unsqueeze.py create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/README.md create mode 100755 examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh create mode 100644 
examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h create mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index f73d97480b..0dc3d36b5c 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -39,6 +39,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.expand_copy.default, + exir_ops.edge.aten.cat.default, exir_ops.edge.aten.permute_copy.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.convolution.default, @@ -51,12 +52,14 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.relu.default, exir_ops.edge.aten._softmax.default, exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.unsqueeze_copy.default, operator.getitem, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 94a16d8c94..dc1fcc8e2c 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -9,6 +9,7 @@ op_addmm, op_avg_pool2d, op_batch_norm, + op_cat, op_conv2d, op_dequant, op_div, @@ -20,10 +21,12 @@ op_mul, op_permute, op_quant, + op_relu, op_repeat, op_sigmoid, op_slice, op_softmax, op_sub, + op_unsqueeze, op_view, ) diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py new file mode 100644 index 0000000000..f2b4165657 --- /dev/null +++ b/backends/arm/operators/op_cat.py @@ -0,0 +1,45 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class CatVisitor(NodeVisitor): + target = "aten.cat.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + tensors = inputs[0].special + dim = 0 if len(inputs) < 2 else inputs[1].number + rank = len(output.shape) + dim = (dim + rank) % rank + dim = output.dim_order.index(dim) + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(dim) + + tosa_graph.addOperator( + TosaOp.Op().CONCAT, [tensor.name for tensor in tensors], [output.name], attr + ) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 0b2e65f45d..918a270bb0 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -5,7 +5,7 @@ import numpy as np import serializer.tosa_serializer as ts -import torch +import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( get_quant_arg_dtype, @@ -130,6 +130,21 @@ def process_inputs_to_buffers( ) +def process_inputs_to_lifted_tensor_constants( + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + edge_program: ExportedProgram, +): + arg = TosaArg(node) + tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[ + arg.name + ] + tensor = edge_program.tensor_constants[tensor_name] + tensor_data = tensor.detach().numpy() + + tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name) + + def process_placeholder( node: torch.fx.Node, tosa_graph: ts.TosaSerializer, @@ -145,5 +160,11 @@ def process_placeholder( process_inputs_to_parameters(node, tosa_graph, edge_program) elif node.name in edge_program.graph_signature.inputs_to_buffers: process_inputs_to_buffers(node, tosa_graph, edge_program) + elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants: + process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program) + elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs: + raise NotImplementedError( + "Placeholder is of type 'lifted custom object' which is not supported." + ) else: - raise RuntimeError(f"Unknown placeholder {node.name}") + raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.") diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py new file mode 100644 index 0000000000..5afe1ac7bc --- /dev/null +++ b/backends/arm/operators/op_relu.py @@ -0,0 +1,55 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
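As a side note on the CatVisitor above: the concatenation axis supplied by aten.cat is first wrapped into the positive range and then remapped through the output's dim_order, because TOSA tensors may be stored in a permuted (e.g. NHWC) layout. A minimal sketch of that axis arithmetic, with a hypothetical helper name that is not part of the patch:

```python
# Illustrative only; `normalize_cat_dim` is an assumed name, not patch code.
def normalize_cat_dim(dim: int, rank: int, dim_order: tuple) -> int:
    dim = (dim + rank) % rank      # wrap negative dims into [0, rank)
    return dim_order.index(dim)    # position of that dim in the permuted layout

# Concatenating rank-4 NCHW tensors along dim=-3 (the channel dim) lands on
# axis 3 once the data is stored in NHWC order (0, 2, 3, 1).
assert normalize_cat_dim(-3, 4, (0, 2, 3, 1)) == 3
```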
+ +import executorch.backends.arm.tosa_quant_utils as tqutils +import serializer.tosa_serializer as ts +import torch.fx +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class ReluVisitor(NodeVisitor): + target = "aten.relu.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: list[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + attr = ts.TosaSerializerAttribute() + + clamp_min_fp = 0.0 + clamp_max_fp = 0.0 + clamp_min_qs = 0 + clamp_max_qs = 0 + if is_quant_node: + out_qargs = tqutils.get_quant_node_args(list(node.users)[0]) + clamp_min_qs = tqutils.quantize_value(0, out_qargs) + clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) + + else: + clamp_min_fp = 0 + clamp_max_fp = float("inf") + + attr.ClampAttribute( + tosa_graph.builder, + clamp_min_qs, + clamp_max_qs, + clamp_min_fp, + clamp_max_fp, + ) + + tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr) diff --git a/backends/arm/operators/op_unsqueeze.py b/backends/arm/operators/op_unsqueeze.py new file mode 100644 index 0000000000..a7ff8ce0b4 --- /dev/null +++ b/backends/arm/operators/op_unsqueeze.py @@ -0,0 +1,51 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Follows this specification: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html + +import serializer.tosa_serializer as ts +import torch.fx +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class UnsqueezeVisitor(NodeVisitor): + target = "aten.unsqueeze_copy.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: list[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + dim = inputs[1].number + shape = inputs[0].shape + rank = len(shape) + + assert -rank - 1 <= dim < rank + 1 + if dim < 0: + dim = dim + rank + 1 + + new_shape = list(shape) + new_shape.insert(dim, 1) + new_shape = tosa_shape(new_shape, output.dim_order) + + attr = ts.TosaSerializerAttribute() + attr.ReshapeAttribute(new_shape) + tosa_graph.addOperator( + TosaOp.Op().RESHAPE, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 8d5edf386a..2692038352 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -267,6 +267,7 @@ class ArmQuantizer(Quantizer): "mul", "sigmoid", "mm", + "cat", ] def __init__(self) -> None: diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index c5da32a40a..417aa454a8 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -102,12 +102,19 @@ def is_input_ok_for_quantization(input_act: Node, gm: GraphModule): ) +def get_node_target(module: torch.nn.Module | GraphModule, 
target_str: str): + targets = target_str.split(".") + for target in targets[:-1]: + module = module.get_submodule(target) + return getattr(module, targets[-1]) + + def is_input_large_scalar(node: Node, gm: GraphModule): """Check if input is a large scalar value. So that we can skip quantization for the node since histc op (in HistogramObserver) only works for values up to certain upper bound """ if node.op == "get_attr" and isinstance(node.target, str): - tensor = getattr(gm, node.target) + tensor = get_node_target(gm, node.target) # torch.histc works until this upper bound HISTC_UPPER_BOUND = 3.4028235e15 return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND @@ -131,6 +138,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool: return op in [ torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default, + torch.ops.aten.relu.default, torch.ops.aten.mean.default, torch.ops.aten.mean.dim, torch.ops.aten.permute.default, diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index 60808d2f23..68ad522fee 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -49,6 +49,7 @@ def decorator(annotator: AnnotatorType): from . import ( # noqa adaptive_ang_pool2d_annotator, add_annotator, + cat_annotator, conv_annotator, linear_annotator, max_pool2d_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/cat_annotator.py b/backends/arm/quantizer/quantization_annotation/cat_annotator.py new file mode 100644 index 0000000000..40dd19526b --- /dev/null +++ b/backends/arm/quantizer/quantization_annotation/cat_annotator.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +from typing import Callable, List, Optional + +import torch.fx +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_annotation import register_annotator +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, +) +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +@register_annotator("cat") +def _annotate_cat( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn) + cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values())) + annotated_partitions = [] + for cat_partition in cat_partitions: + annotated_partitions.append(cat_partition.nodes) + cat_node = cat_partition.output_nodes[0] + if arm_quantizer_utils.is_annotated(cat_node): + continue + + input_acts = cat_node.args[0] + input_act0 = input_acts[0] + + input_act_qspec = quantization_config.get_input_act_qspec() + shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node)) + + input_qspec_map = {} + + # First input is set to input qspec from the quantization config. 
+ if isinstance(input_act0, Node): + if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm): + continue + input_qspec_map[input_act0] = input_act_qspec + + # For the rest of the inputs, share qspec with first. + # If we can't quantize any of the inputs, abort annotation. + for input_act in input_acts[1:]: + if isinstance(input_act, Node): + if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm): + continue + if input_act is not input_act0: + input_qspec_map[input_act] = shared_with_input0_qspec + + if input_qspec_map is not None: + cat_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=shared_with_input0_qspec, + _annotated=True, + ) + return annotated_partitions diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py new file mode 100644 index 0000000000..90aa7e2950 --- /dev/null +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -0,0 +1,42 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + + +class LiftedTensor(torch.nn.Module): + + def __init__(self): + super().__init__() + self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]]) + + def forward(self, x: torch.Tensor, length) -> torch.Tensor: + sliced = self.lifted_tensor[:, :length] + return sliced + x + + +class TestLiftedTensor(unittest.TestCase): + """Tests the ArmPartitioner with a placeholder of type lifted tensor.""" + + def test_partition_lifted_tensor(self): + tester = ( + ArmTester( + LiftedTensor(), + example_inputs=(torch.ones(2, 2), 2), + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .to_edge() + .dump_artifact() + ) + signature = tester.get_artifact().exported_program().graph_signature + assert len(signature.lifted_tensor_constants) > 0 + tester.partition() + tester.to_executorch() + tester.run_method_and_compare_outputs((torch.ones(2, 2), 2)) diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py new file mode 100644 index 0000000000..f677aa5590 --- /dev/null +++ b/backends/arm/test/ops/test_cat.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
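For context on the cat annotator above: SharedQuantizationSpec ties the remaining inputs and the output of a cat to the observer on the (first input, cat node) edge, so every tensor around the concatenation ends up with one scale and zero point. A rough, self-contained sketch of that wiring on a toy FX graph (the first input would additionally receive the config's input qspec, omitted here; the module and variable names are illustrative):

```python
import torch
import torch.fx
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)

class TinyCat(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y], dim=0)

gm = torch.fx.symbolic_trace(TinyCat())
cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)
first_input = cat_node.args[0][0]

# Every other input and the output reuse the quantization parameters observed
# on the edge between the first input and the cat node.
shared = SharedQuantizationSpec((first_input, cat_node))
cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
    input_qspec_map={inp: shared for inp in cat_node.args[0][1:]},
    output_qspec=shared,
    _annotated=True,
)
```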
+ +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + + +class TestCat(unittest.TestCase): + + class Cat(torch.nn.Module): + test_parameters = [ + ((torch.ones(1), torch.ones(1)), 0), + ((torch.ones(1, 2), torch.randn(1, 5), torch.randn(1, 1)), 1), + ( + ( + torch.ones(1, 2, 5), + torch.randn(1, 2, 4), + torch.randn(1, 2, 2), + torch.randn(1, 2, 1), + ), + -1, + ), + ((torch.randn(2, 2, 4, 4), torch.randn(2, 2, 4, 1)), 3), + ( + ( + 10000 * torch.randn(2, 3, 1, 4), + torch.randn(2, 7, 1, 4), + torch.randn(2, 1, 1, 4), + ), + -3, + ), + ] + + def __init__(self): + super().__init__() + + def forward(self, tensors: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor: + return torch.cat(tensors, dim=dim) + + def _test_cat_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .check_count({"torch.ops.aten.cat.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_cat_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.cat.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_cat_u55_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_u55_compile_spec(), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.cat.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + @parameterized.expand(Cat.test_parameters) + def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int): + test_data = (operands, dim) + self._test_cat_tosa_MI_pipeline(self.Cat(), test_data) + + def test_cat_4d_tosa_MI(self): + square = torch.ones((2, 2, 2, 2)) + for dim in range(-3, 3): + test_data = ((square, square), dim) + self._test_cat_tosa_MI_pipeline(self.Cat(), test_data) + + @parameterized.expand(Cat.test_parameters) + def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int): + test_data = (operands, dim) + self._test_cat_tosa_BI_pipeline(self.Cat(), test_data) + + @parameterized.expand(Cat.test_parameters) + def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int): + test_data = (operands, dim) + self._test_cat_u55_BI_pipeline(self.Cat(), test_data) diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 88006df1a0..31051ef8f7 
100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -102,7 +102,7 @@ def forward(self, x): return self.adaptive_avg_pool2d(x) -class ComboConvBatchnormRelu(torch.nn.Module): +class ComboConvBatchnormRelu6(torch.nn.Module): edge_op_list = [ "executorch_exir_dialects_edge__ops_aten_convolution_default", "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default", @@ -235,16 +235,16 @@ def test_conv_meandim_u55_BI(self): ############################## ## Conv + batch norm + relu ## ############################## - def test_conv_batchnorm_relu_tosa_MI(self): - model = ComboConvBatchnormRelu() + def test_conv_batchnorm_relu6_tosa_MI(self): + model = ComboConvBatchnormRelu6() self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs()) - def test_conv_batchnorm_relu_tosa_BI(self): - model = ComboConvBatchnormRelu() + def test_conv_batchnorm_relu6_tosa_BI(self): + model = ComboConvBatchnormRelu6() self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs()) - def test_conv_batchnorm_relu_u55_BI(self): - model = ComboConvBatchnormRelu() + def test_conv_batchnorm_relu6_u55_BI(self): + model = ComboConvBatchnormRelu6() self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs()) ################## diff --git a/backends/arm/test/ops/test_relu.py b/backends/arm/test/ops/test_relu.py new file mode 100644 index 0000000000..d2ca8540f4 --- /dev/null +++ b/backends/arm/test/ops/test_relu.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.xnnpack.test.tester.tester import Quantize +from parameterized import parameterized + + +test_data_suite = [ + # (test_name, test_data) + ("zeros", torch.zeros(1, 10, 10, 10)), + ("ones", torch.ones(10, 10, 10)), + ("rand", torch.rand(10, 10) - 0.5), + ("randn_pos", torch.randn(10) + 10), + ("randn_neg", torch.randn(10) - 10), + ("ramp", torch.arange(-16, 16, 0.2)), +] + + +class TestRelu(unittest.TestCase): + class Relu(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(x) + + def _test_relu_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .check(["torch.ops.aten.relu.default"]) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_relu_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize(Quantize(quantizer, 
get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.relu.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_relu_tosa_u55_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.tensor] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_u55_compile_spec(), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.relu.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge() + .partition() + .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + @parameterized.expand(test_data_suite) + def test_relu_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + ): + self._test_relu_tosa_MI_pipeline(self.Relu(), (test_data,)) + + @parameterized.expand(test_data_suite) + def test_relu_tosa_BI(self, test_name: str, test_data: torch.Tensor): + self._test_relu_tosa_BI_pipeline(self.Relu(), (test_data,)) + + @parameterized.expand(test_data_suite) + def test_relu_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor): + self._test_relu_tosa_u55_BI_pipeline(self.Relu(), (test_data,)) diff --git a/backends/arm/test/ops/test_unsqueeze.py b/backends/arm/test/ops/test_unsqueeze.py new file mode 100644 index 0000000000..6da6a196c0 --- /dev/null +++ b/backends/arm/test/ops/test_unsqueeze.py @@ -0,0 +1,103 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
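Circling back to the ReluVisitor added earlier: in the quantized case the CLAMP bounds are not 0.0 and +inf but their quantized counterparts, derived from the output's scale and zero point. A hedged sketch of that mapping (the helper below is illustrative and simply saturates at the int8 maximum rather than calling the backend's own quantize_value):

```python
def quantize_clamp_bound(v: float, scale: float, zero_point: int,
                         qmin: int = -128, qmax: int = 127) -> int:
    # round(v / scale) + zero_point, saturated to the quantized range
    if v == float("inf"):
        return qmax
    return max(qmin, min(qmax, round(v / scale) + zero_point))

# With scale=0.05 and zero_point=-10, relu's clamp(0.0, inf) becomes clamp(-10, 127).
assert (quantize_clamp_bound(0.0, 0.05, -10),
        quantize_clamp_bound(float("inf"), 0.05, -10)) == (-10, 127)
```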
+ +# +# Tests the unsqueeze op which copies the data of the input tensor (possibly with new data format) +# + +import unittest +from typing import Sequence, Tuple + +import torch + +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import Quantize +from parameterized import parameterized + + +class TestSimpleUnsqueeze(unittest.TestCase): + class Unsqueeze(torch.nn.Module): + shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 5, 5)] + test_parameters: list[tuple[torch.Tensor]] = [(torch.ones(n),) for n in shapes] + + def forward(self, x: torch.Tensor, dim): + return x.unsqueeze(dim) + + def _test_unsqueeze_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .check_count({"torch.ops.aten.unsqueeze.default": 1}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_unsqueeze_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.unsqueeze.default": 1}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_unsqueeze_tosa_u55_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, int] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_u55_compile_spec(), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.unsqueeze.default": 1}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + @parameterized.expand(Unsqueeze.test_parameters) + def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor): + for i in range(-test_tensor.dim() - 1, test_tensor.dim() + 1): + self._test_unsqueeze_tosa_MI_pipeline(self.Unsqueeze(), (test_tensor, i)) + + @parameterized.expand(Unsqueeze.test_parameters) + def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor): + self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0)) + + @parameterized.expand(Unsqueeze.test_parameters) + def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor): + self._test_unsqueeze_tosa_u55_pipeline(self.Unsqueeze(), (test_tensor, 0)) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 41fc907fdf..8a02c63d7a 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -242,16 +242,18 @@ def run_method_and_compare_outputs( # Loop inputs and compare reference stage with the compared stage. 
for run_iteration in range(num_runs): reference_input = inputs if inputs else next(self.generate_random_inputs()) - if is_nhwc: - test_input = self.transpose_data_format(reference_input, "NHWC") - else: - test_input = reference_input # Test parameters can include constants that are used in eager mode but are already set as attributes # in TOSA. Therefore, only accept torch.Tensor inputs. - test_input = [ - tensor for tensor in test_input if isinstance(tensor, torch.Tensor) - ] + test_input: list[torch.Tensor] = [] + for arg in reference_input: + if isinstance(arg, torch.Tensor): + test_input.append(arg) + if isinstance(arg, tuple) and isinstance(arg[0], torch.Tensor): + test_input.extend(list(arg)) + + if is_nhwc: + test_input = self.transpose_data_format(test_input, "NHWC") input_shapes = [ generated_input.shape if hasattr(generated_input, "shape") else (1,) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 405f8b5db4..e1494f8d20 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -60,13 +60,13 @@ def convert_pt2( # Export with dynamo model_gm = capture_pre_autograd_graph(model, inputs) - if model_gm_has_SDPA(model_gm): + if model_gm_has_SDPA(model_gm): # pyre-fixme[6] # Decompose SDPA - DecomposeScaledDotProductAttention(False)(model_gm) + DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6] # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882 # for details). - result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) + result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6] assert result is not None model_gm = result.graph_module diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index dd704c35c0..08fd907c40 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1998,6 +1998,55 @@ def test_llama3_8b(self): model_out = msg["result"] self.assertTrue(model_out.startswith(prompt)) + def test_stable_diffusion(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + prompt = "a photo of an astronaut riding a horse on mars" + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--text_encoder_bin", + f"{self.artifact_dir}/text_encoder.serialized.bin", + "--unet_bin", + f"{self.artifact_dir}/unet.serialized.bin", + "--vae_bin", + f"{self.artifact_dir}/vae.serialized.bin", + "--vocab_json", + f"{self.artifact_dir}/vocab.json", + "--num_time_steps", + "20", + "--ip", + self.ip, + "--port", + str(self.port), + "--prompt", + f"{prompt}", + "--fix_latents", + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + # For the default settings and prompt, the expected results will be {PSNR: 23.258, SSIM: 0.852} + self.assertGreaterEqual(msg["PSNR"], 20) + self.assertGreaterEqual(msg["SSIM"], 0.8) + class TestExampleScript(TestQNN): def required_envs(self, conditions=None) -> bool: diff --git a/backends/transforms/addmm_mm_to_linear.py b/backends/transforms/addmm_mm_to_linear.py index 
7855de617b..358cbb7ac1 100644 --- a/backends/transforms/addmm_mm_to_linear.py +++ b/backends/transforms/addmm_mm_to_linear.py @@ -130,7 +130,7 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph: "call_function", ops.aten.linear.default, args ) node.replace_all_uses_with(linear_node) - output_val = linear_node.target( + output_val = linear_node.target( # pyre-fixme[29] args[0].meta["val"], args[1].meta["val"], args[2].meta["val"] ) else: @@ -147,7 +147,7 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph: "call_function", ops.aten.linear.default, args ) node.replace_all_uses_with(linear_node) - output_val = linear_node.target( + output_val = linear_node.target( # pyre-fixme[29] args[0].meta["val"], args[1].meta["val"] ) linear_node.meta = node.meta diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py index 6dbbf564f5..329dab96df 100644 --- a/backends/transforms/decompose_sdpa.py +++ b/backends/transforms/decompose_sdpa.py @@ -34,7 +34,7 @@ def call( # refer to pytorch/test/test_decomp.py decomposed_module = make_fx( node.target, - decomposition_table=get_decompositions( + decomposition_table=get_decompositions( # pyre-fixme[6] [ torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, ] diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py index a20285483b..54c07ad5ab 100644 --- a/backends/xnnpack/partition/config/gemm_configs.py +++ b/backends/xnnpack/partition/config/gemm_configs.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging from itertools import chain from typing import cast, List, Optional, Tuple @@ -31,12 +32,16 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) +from executorch.exir.backend.utils import WhyNoPartition from torch.export import ExportedProgram from torch.fx.passes.utils.source_matcher_utils import ( get_source_partitions, SourcePartition, ) +logger = logging.getLogger(__name__) +why = WhyNoPartition(logger=logger) + class GEMMConfig(XNNPartitionerConfig): """ @@ -60,6 +65,8 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False is_valid, _ = self.get_deps(node, ep) + if not is_valid: + why(node, "Failed to get valid dependent nodes.") return is_valid def get_node_and_deps( @@ -282,10 +289,12 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: conv_stride = cast(List[int], node.args[3]) if len(conv_stride) > 2: + why(node, "Only support 1D + 2D Conv") return False # Only support 1D + 2D Conv transposed = cast(bool, node.args[6]) if transposed: + why(node, "Transposed Conv is not supported") return False # Currently don't support transposed conv return True diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index e309a3bd03..69defae021 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
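The WhyNoPartition helper threaded through the XNNPACK configs above replaces silent `return False` statements in the constraint checks with a logged, human-readable reason. The same pattern reduced to plain standard-library logging, so it can be read in isolation (function and argument names here are illustrative, not part of the patch):

```python
import logging
from typing import Sequence

logger = logging.getLogger(__name__)

def check_conv_constraints(stride: Sequence[int], transposed: bool) -> bool:
    # Log why the node is rejected instead of silently returning False.
    if len(stride) > 2:
        logger.debug("not partitioning node: only 1D and 2D convolutions are supported")
        return False
    if transposed:
        logger.debug("not partitioning node: transposed convolution is not supported")
        return False
    return True
```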
+import logging from typing import cast, List, Optional import torch @@ -16,8 +17,12 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) +from executorch.exir.backend.utils import WhyNoPartition from torch.export import ExportedProgram +logger = logging.getLogger(__name__) +why = WhyNoPartition(logger=logger) + class GenericNodePartitionerConfig(XNNPartitionerConfig): def __init__(self, fused_act: Optional[List[str]] = None): @@ -141,9 +146,22 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: if len(args) >= 7: divisor_override = cast(int, args[6]) - return ( - not (ceil_mode or count_include_pad) and divisor_override == pooling_region - ) + if ceil_mode: + why(node, reason="ceil mode is not supported") + return False + + if count_include_pad: + why( + node, + reason="zero-padding in the averaging calculation is not supported", + ) + return False + + if divisor_override != pooling_region: + why(node, reason="divisor override is not supported") + return False + + return True def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] @@ -160,7 +178,15 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False num_tensors = len(node.all_input_nodes) - return num_tensors >= 2 and num_tensors <= 4 + + if not (num_tensors >= 2 and num_tensors <= 4): + why( + node, + reason=f"only support concatenation of 2 - 4 tensors, got {num_tensors} tensors", + ) + return False + + return True def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] @@ -210,7 +236,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: dim = cast(int, node.args[1]) node_input = node.all_input_nodes[0] tensor_dims = node_input.meta["val"].dim() - return dim == -1 or dim == tensor_dims - 1 + + if not (dim == -1 or dim == tensor_dims - 1): + why( + node, + reason=f"dim must be the last dim, got dim = {dim} for tensor of rank {tensor_dims}", + ) + return False + return True def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] @@ -255,7 +288,10 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5]) - return not is_ceil_mode + if is_ceil_mode: + why(node, reason="ceil mode is not supported") + return False + return True def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] @@ -309,7 +345,20 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: dims = node.args[1] output_dims = node.meta["val"].dim() - return dims in ([-2, -1], [-1, -2]) and output_dims == 4 + if dims not in ([-2, -1], [-1, -2]): + why( + node, + reason="mean.dim only supports averaging 4D tensors across the innermost dimensions", + ) + return False + + if output_dims != 4: + why( + node, + reason=f"mean.dim only supports averaging 4D tensors, got tensor of rank {output_dims}", + ) + return False + return True def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] @@ -340,7 +389,15 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False power = node.args[1] - return isinstance(power, int) and power == 2 + + if not isinstance(power, 
int): + why(node, reason=f"only support int powers, got {power}") + return False + + if power != 2: + why(node, reason=f"only support power == 2, got {power}") + return False + return True def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] @@ -372,10 +429,18 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: for dim in input_shape: if not isinstance(dim, int) or dim == 0: + why( + node, + reason=f"input tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.", + ) return False for dim in output_shape: if not isinstance(dim, int) or dim == 0: + why( + node, + reason=f"output tensor has invalid shape, dim: {dim} of type {type(dim)}. Expecting non-zero, int values.", + ) return False return True @@ -431,7 +496,14 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False mask_node = node.all_input_nodes[3] mask_rank = mask_node.meta["val"].dim() - return mask_rank == 2 + if mask_rank != 2: + why( + node, + reason=f"mask must have rank 2, got mask of rank {mask_rank}", + ) + return False + + return True def get_original_aten(self) -> Optional[torch._ops.OpOverload]: return torch.ops.aten.scaled_dot_product_attention.default diff --git a/backends/xnnpack/partition/config/node_configs.py b/backends/xnnpack/partition/config/node_configs.py index 501216eaae..1e4d1f05fe 100644 --- a/backends/xnnpack/partition/config/node_configs.py +++ b/backends/xnnpack/partition/config/node_configs.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import operator from typing import List, Optional @@ -19,8 +20,12 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) +from executorch.exir.backend.utils import WhyNoPartition from torch.export import ExportedProgram +logger = logging.getLogger(__name__) +why = WhyNoPartition(logger=logger) + class BatchNormConfig(XNNPartitionerConfig): target_name = "_native_batch_norm_legit_no_training.default" @@ -38,9 +43,15 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: conv_name = format_target_name(conv.target.__name__) # pyre-ignore if conv_name not in ["convolution.default"]: + why(node, f"Invalid conv target {conv_name}") + return False + + can_fuse = FuseBatchNormWithConvPass.can_fuse(conv, bn, ep) + if not can_fuse: + why(node, "BatchNorm cannot be fused with Convolution") return False - return FuseBatchNormWithConvPass.can_fuse(conv, bn, ep) + return True def get_node_and_deps( self, node: torch.fx.Node, ep: ExportedProgram @@ -76,15 +87,18 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: output_0 = node_val[0] # Don't check indicies dtype if output_0.dtype not in supported_dtypes: + why(node, f"Unsupported output dtype {output_0.dtype}") return False max_input = node.all_input_nodes[0] if max_input.meta.get("val").dtype not in supported_dtypes: + why(node, f"Unsupported input dtype {max_input.meta.get('val').dtype}") return False # Make sure that all users are getitems of the first output for user in node.users: if not (user.target == operator.getitem and user.args[1] == 0): + why(node, "Unsupported user of max.dim") return False return True @@ -111,7 +125,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: return False weight = node.all_input_nodes[1] - return 
is_param_node(ep, weight) + is_param = is_param_node(ep, weight) + if not is_param: + why(node, "Prelu weight must be a parameter") + return False + return True def get_original_aten(self) -> Optional[torch._ops.OpOverload]: return torch.ops.aten.prelu.default diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index 840ffbd43b..f39a651e19 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging from abc import abstractmethod from enum import Enum from typing import List, Optional @@ -13,8 +14,12 @@ format_target_name, PartitionerConfig, ) +from executorch.exir.backend.utils import WhyNoPartition from torch.export import ExportedProgram +logger = logging.getLogger(__name__) +why = WhyNoPartition(logger=logger) + class ConfigPrecisionType(Enum): FP32 = 1 @@ -22,7 +27,6 @@ class ConfigPrecisionType(Enum): DYNAMIC_QUANT = 3 -# TODO: add WhyNotPartition to XNNPartitionerConfig class XNNPartitionerConfig(PartitionerConfig): """ Base partitioner config for XNNPACK Partitioner Configs. Base wrapper class @@ -125,10 +129,12 @@ def check_common_constraints( ) if len(self.enabled_precision_types) == 0: + why(node, reason="not enabled precision types") return False has_valid_dtypes = self._check_node_has_valid_dtype(node) if not has_valid_dtypes: + why(node, reason="invalid dtype") return False return True diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index f582ea753f..9afbefebce 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import itertools + +import logging from typing import List, Optional, Type, Union from executorch.backends.xnnpack.partition.config import ALL_PARTITIONER_CONFIGS @@ -21,6 +23,9 @@ from executorch.exir.backend.partitioner import DelegationSpec from torch.fx.passes.infra.partitioner import Partition +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + class XnnpackPartitioner(ConfigerationBasedPartitioner): def __init__( @@ -30,7 +35,16 @@ def __init__( Union[ConfigPrecisionType, List[ConfigPrecisionType]] ] = None, per_op_mode=False, + verbose: bool = False, ): + """ + @verbose: if True, print out more information about the partitioner. + Default level is WARNING. If verbose is True, level is set to DEBUG. 
+ """ + if verbose: + logger.setLevel(logging.DEBUG) + logger.debug("Verbose logging enabled for XNNPACK partitioner.") + delegation_spec = DelegationSpec(XnnpackBackend.__name__, []) configs_to_use = configs or ALL_PARTITIONER_CONFIGS # Can do logic and have extra args to filter/delete/select diff --git a/backends/xnnpack/passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/passes/channels_last_tagged_reshape_pass.py index f1f9a69acc..692f1a9d14 100644 --- a/backends/xnnpack/passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/passes/channels_last_tagged_reshape_pass.py @@ -124,7 +124,7 @@ def create_call_function_node( "call_function", target=target, args=args, - kwargs=( + kwargs=( # pyre-fixme[6] {"memory_format": memory_format} if memory_format is not None else {} ), ) diff --git a/backends/xnnpack/passes/convert_to_sdpa.py b/backends/xnnpack/passes/convert_to_sdpa.py index 76bb24cc94..97aca5491d 100644 --- a/backends/xnnpack/passes/convert_to_sdpa.py +++ b/backends/xnnpack/passes/convert_to_sdpa.py @@ -83,7 +83,7 @@ def create_sdpa( kwargs={"scale": scale}, ) - sdpa_node.meta["val"] = sdpa_node.target( + sdpa_node.meta["val"] = sdpa_node.target( # pyre-fixme[29] *[n.meta["val"] for n in match.placeholder_nodes], scale=scale, ) diff --git a/backends/xnnpack/test/ops/mean_dim.py b/backends/xnnpack/test/ops/mean_dim.py index e39d3aee08..3bac5f3239 100644 --- a/backends/xnnpack/test/ops/mean_dim.py +++ b/backends/xnnpack/test/ops/mean_dim.py @@ -56,6 +56,19 @@ def test_fp32_mean_dim_unsupported(self): .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1}) ) + def test_fp32_mean_dim_unsupported_3d(self): + """ + XNNPack mean.dim implementation only supports 4D tensors. + """ + inputs = (torch.randn(1, 5, 4),) + ( + Tester(self.MeanDim((-1, -2)), inputs) + .export() + .check_count({"torch.ops.aten.mean.dim": 1}) + .to_edge_transform_and_lower() + .check_count({"executorch_exir_dialects_edge__ops_aten_mean_dim": 1}) + ) + def test_qs8_mean_dim(self): inputs = (torch.randn(1, 5, 4, 4),) ( diff --git a/examples/models/llama2/source_transformation/quantize.py b/examples/models/llama2/source_transformation/quantize.py index bb014145bd..4f3eaf1125 100644 --- a/examples/models/llama2/source_transformation/quantize.py +++ b/examples/models/llama2/source_transformation/quantize.py @@ -96,7 +96,7 @@ def quantize( try: # torchao 0.3+ - from torchao._eval import InputRecorder + from torchao._eval import InputRecorder # pyre-fixme[21] except ImportError: from torchao.quantization.GPTQ import InputRecorder # pyre-ignore @@ -110,7 +110,7 @@ def quantize( ) inputs = ( - InputRecorder( + InputRecorder( # pyre-fixme[16] tokenizer, calibration_seq_length, None, # input_prep_func diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 390528844f..4f2aa6576b 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -23,7 +23,15 @@ replace_sdpa_with_custom_op, ) from executorch.examples.models.llava.model import LlavaModel -from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) + +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.quant_fusion_pass import QuantFusionPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.extension.llm.export.builder import DType, LLMEdgeManager 
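A short usage note on the `verbose` flag added to XnnpackPartitioner above: enabling it lowers the partitioner's logger to DEBUG, so the WhyNoPartition messages emitted by the individual configs explain which constraint rejected a node. A minimal sketch, assuming the import path implied by this patch:

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# With verbose=True the partitioner logs per-node reasons for skipping delegation.
partitioner = XnnpackPartitioner(verbose=True)
```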
from executorch.extension.llm.tokenizer.tokenizer import Tokenizer @@ -199,7 +207,23 @@ def export_all(llava_model: LlavaModel): compile_config=EdgeCompileConfig(_check_ir_validity=False), ) - executorch_program = lowered_and_edge.to_executorch() + executorch_program = lowered_and_edge.to_executorch( + ExecutorchBackendConfig( + extract_constant_segment=True, + extract_delegate_segments=True, + passes=[ + QuantFusionPass(), + ], + memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False), + sym_shape_eval_pass={ + "image_encoder": ConstraintBasedSymShapeEvalPass(), + }, + ) + ) + for execution_plan in executorch_program._emitter_output.program.execution_plan: + logging.info( + f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}" + ) return executorch_program @@ -253,13 +277,6 @@ def main(): with open(args.pte_name, "wb") as f: executorch_program.write_to_file(f) - logging.info( - "Required memory for activation in bytes: {}".format( - executorch_program._emitter_output.program.execution_plan[ - 0 - ].non_const_buffer_sizes - ), - ) logging.info(f"Exported ExecuTorch program to {args.pte_name}") # artifacts diff --git a/examples/models/llava/test/test_pte.py b/examples/models/llava/test/test_pte.py index cdf24761c5..d793b2ae22 100644 --- a/examples/models/llava/test/test_pte.py +++ b/examples/models/llava/test/test_pte.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import sys import torch @@ -17,6 +18,10 @@ from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.DEBUG, format=FORMAT) + + def main(): args = sys.argv[1:] llava_module = _load_for_executorch(args[0]) @@ -41,7 +46,10 @@ def main(): start_pos += pte_prefill_before_img.shape[1] # pte prefill image + logging.warning("Image encoder started") pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0] + logging.warning("Image encoder finished") + logging.warning("Image token prefill started") pte_prefill_img = llava_module.run_method( "text_model", ( @@ -49,11 +57,13 @@ def main(): pte_embeds_img, ), )[0] + logging.warning("Image token prefill finished") print(pte_prefill_img) start_pos += pte_prefill_img.shape[1] # pte prefill prompt after img + logging.warning("Text token prefill started") pte_embeds_after_img = llava_module.run_method( "token_embedding", (prompt_after_image,) )[0] @@ -61,6 +71,7 @@ def main(): "text_model", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img), )[0] + logging.warning("Text token prefill finished") print(pte_prefill_after_img) # being tested, using llama_transformer diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index ab5e04c307..553fded67f 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -67,7 +67,7 @@ def export(args) -> None: model = capture_pre_autograd_graph( model, example_inputs, dynamic_shapes=dynamic_shapes ) - model = prepare_pt2e(model, xnnpack_quantizer) + model = prepare_pt2e(model, xnnpack_quantizer) # pyre-fixme[6] model(*example_inputs) model = convert_pt2e(model, fold_quantize=False) DuplicateDynamicQuantChainPass()(model) diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index fd9c1388b2..94af209cb6 
100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -81,3 +81,8 @@ add_subdirectory( add_subdirectory( ${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama ) + +# build qaihub_stable_diffusion_runner +add_subdirectory( + ${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/stable_diffusion +) diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt new file mode 100644 index 0000000000..c897f5f9f8 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# preprocess qaihub_stable_diffusion_runner_src files +set(_qaihub_stable_diffusion_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/qaihub_stable_diffusion_runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h +) + +# build qaihub_stable_diffusion_runner +add_executable(qaihub_stable_diffusion_runner ${_qaihub_stable_diffusion_runner__srcs}) +target_include_directories(qaihub_stable_diffusion_runner + PUBLIC ${_common_include_directories} +) +target_link_libraries(qaihub_stable_diffusion_runner + qnn_executorch_backend + executorch_no_prim_ops + extension_data_loader + extension_module + gflags +) +target_compile_options(qaihub_stable_diffusion_runner PUBLIC ${_common_compile_options}) diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md b/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md new file mode 100644 index 0000000000..21b3370df7 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md @@ -0,0 +1,35 @@ +# Summary + +## Overview +This file provides you the instructions to run Stable-Diffusion-v2.1 with different parameters via Qualcomm HTP backend. We will demonstrate how to run Stable Diffusion v2.1 on mobile devices using context binaries from Qualcomm AI Hub’s Stable Diffusion v2.1 + +Please check corresponding section for more information. + +## Stable-Diffusion-v2.1 +The model architecture, scheduler, and time embedding are from the [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). + +### Instructions +#### Step 1: Setup +1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. +2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend. + +#### Step2: Prepare Model +1. Download the context binaries for TextEncoder, UNet, and VAEDecoder under https://huggingface.co/qualcomm/Stable-Diffusion-v2.1/tree/main +2. Download vocab.json under https://huggingface.co/openai/clip-vit-base-patch32/tree/main + + +#### Step3: Install Requirements +Before running the code, you need to install the necessary Python packages. + +We have verified the code with `diffusers`==0.29.0 and `piq`==0.8.0. 
Please follow the instructions here to install the required items: +```bash +sh examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh +``` + +#### Step4: Run default example +In this example, we execute the script for 20 time steps with the `prompt` 'a photo of an astronaut riding a horse on mars': +```bash +python examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py -a ${ARTIFACTS} -b build_android -m ${SOC_MODEL} --s ${SERIAL_NUM} --text_encoder_bin ${PATH_TO_TEXT_ENCODER_CONTEXT_BINARY} --unet_bin ${PATH_TO_UNET_CONTEXT_BINARY} --vae_bin ${PATH_TO_VAE_CONTEXT_BINARY} --vocab_json ${PATH_TO_VOCAB_JSON_FILE} --num_time_steps 20 --prompt "a photo of an astronaut riding a horse on mars" +``` +- Please replace `${PATH_TO_TEXT_ENCODER_CONTEXT_BINARY}`, `${PATH_TO_UNET_CONTEXT_BINARY}`, and `${PATH_TO_VAE_CONTEXT_BINARY}` with the actual paths to your AI Hub context binary files. +- Please replace `${PATH_TO_VOCAB_JSON_FILE}` with the actual path to your vocab.json file. diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh b/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh new file mode 100755 index 0000000000..bbb4767bee --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh @@ -0,0 +1,3 @@ +# For Stable Diffusion V2.1 +pip install diffusers==0.29.0 +pip install piq==0.8.0 diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py new file mode 100644 index 0000000000..862db31f17 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py @@ -0,0 +1,472 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
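The export script that follows feeds float tensors (such as the UNet time embeddings) into fixed-point AI Hub graphs, so it maps them to uint16 using the per-tensor scale and offset read back from the context binaries. A minimal sketch of that affine (de)quantization, mirroring the `get_quant_data` helper defined below (the function names in the sketch are illustrative and not part of the script):

```python
# Sketch only. Mirrors the uint16 affine quantization convention used below;
# QNN offsets may be negative, so both directions use |offset|.
# Requires a PyTorch build with torch.uint16 (the script below relies on it too).
import torch


def quantize_u16(x: torch.Tensor, scale: float, offset: int) -> torch.Tensor:
    # q = clip(x / scale + |offset|, 0, 65535)
    return x.div(scale).add(abs(offset)).clamp(0, 65535).to(torch.uint16)


def dequantize_u16(q: torch.Tensor, scale: float, offset: int) -> torch.Tensor:
    # inverse mapping: x = (q - |offset|) * scale
    return (q.to(torch.float32) - abs(offset)) * scale


t = torch.randn(1, 1280)                          # e.g. a UNet time embedding
q = quantize_u16(t, scale=1e-3, offset=-32768)    # illustrative scale/offset values
t_hat = dequantize_u16(q, scale=1e-3, offset=-32768)
```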
+ +import gc +import json +import os +from multiprocessing.connection import Client + +import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor +import numpy as np +import piq +import torch +from diffusers import EulerDiscreteScheduler, UNet2DConditionModel +from diffusers.models.embeddings import get_timestep_embedding +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + canonicalize_program, + from_context_binary, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + generate_qnn_executorch_option, +) + +from executorch.examples.qualcomm.qaihub_scripts.stable_diffusion.stable_diffusion_lib import ( + StableDiffusion, +) +from executorch.examples.qualcomm.utils import ( + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir.backend.backend_api import to_backend +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from PIL import Image +from torchvision.transforms import ToTensor + +target_names = ("text_encoder", "unet", "vae") + + +def get_quant_data( + encoding: dict, data: torch.Tensor, input_model: str, input_index: int +): + scale = encoding[f"{input_model}_input"]["scale"][input_index] + offset = encoding[f"{input_model}_input"]["offset"][input_index] + if offset < 0: + quant_data = data.div(scale).sub(offset).clip(min=0, max=65535).detach() + else: + quant_data = data.div(scale).add(offset).clip(min=0, max=65535).detach() + + return quant_data.to(dtype=torch.uint16) + + +def get_encoding( + path_to_shard: str, + compiler_specs: str, + get_input: bool, + get_output: bool, + num_input: int, + num_output: int, +): + encoding_list = [] + with open(path_to_shard, "rb") as f: + ctx_bin = f.read() + qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(compiler_specs), ctx_bin + ) + assert qnn_mgr.Init().value == 0, "failed to load context binary" + qnn_mgr.AllocateTensor() + if get_input: + encoding_input = {"scale": [], "offset": []} + for i in range(num_input): + inputs = qnn_mgr.GetGraphInputs()[i] + encoding = inputs.GetEncodings() + encoding_input["scale"].append(encoding.data["scale"].item()) + encoding_input["offset"].append(encoding.data["offset"].item()) + encoding_list.append(encoding_input) + if get_output: + encoding_output = {"scale": [], "offset": []} + for i in range(num_output): + outputs = qnn_mgr.GetGraphOutputs()[i] + encoding = outputs.GetEncodings() + encoding_output["scale"].append(encoding.data["scale"].item()) + encoding_output["offset"].append(encoding.data["offset"].item()) + encoding_list.append(encoding_output) + qnn_mgr.Destroy() + return encoding_list + + +def get_encodings( + path_to_shard_encoder: str, + path_to_shard_unet: str, + path_to_shard_vae: str, + compiler_specs, +): + text_encoder_encoding = get_encoding( + path_to_shard=path_to_shard_encoder, + compiler_specs=compiler_specs, + get_input=False, + get_output=True, + num_input=1, + num_output=1, + ) + unet_encoding = get_encoding( + path_to_shard=path_to_shard_unet, + compiler_specs=compiler_specs, + get_input=True, + get_output=True, + num_input=3, + num_output=1, + ) + vae_encoding = get_encoding( + path_to_shard=path_to_shard_vae, + compiler_specs=compiler_specs, + get_input=True, + get_output=True, + num_input=1, + num_output=1, + ) + + return ( + text_encoder_encoding[0], + unet_encoding[0], + unet_encoding[1], + vae_encoding[0], + vae_encoding[1], + ) + + +def 
get_time_embedding(timestep, time_embedding): + timestep = torch.tensor([timestep]) + t_emb = get_timestep_embedding(timestep, 320, True, 0) + emb = time_embedding(t_emb) + + return emb + + +def build_args_parser(): + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="Path for storing generated artifacts by this example. Default ./stable_diffusion_qai_hub", + default="./stable_diffusion_qai_hub", + type=str, + ) + + parser.add_argument( + "--pte_prefix", + help="Prefix of pte files name. Default qaihub_stable_diffusion", + default="qaihub_stable_diffusion", + type=str, + ) + + parser.add_argument( + "--text_encoder_bin", + type=str, + default=None, + help="[For AI hub ctx binary] Path to Text Encoder.", + required=True, + ) + + parser.add_argument( + "--unet_bin", + type=str, + default=None, + help="[For AI hub ctx binary] Path to UNet.", + required=True, + ) + + parser.add_argument( + "--vae_bin", + type=str, + default=None, + help="[For AI hub ctx binary] Path to Vae Decoder.", + required=True, + ) + + parser.add_argument( + "--prompt", + default="a photo of an astronaut riding a horse on mars", + type=str, + help="Prompt to generate image from.", + ) + + parser.add_argument( + "--num_time_steps", + default=20, + type=int, + help="The number of diffusion time steps.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="Strength of guidance (higher means more influence from prompt).", + ) + + parser.add_argument( + "--vocab_json", + type=str, + help="Path to tokenizer vocab.json file. Can get vocab.json under https://huggingface.co/openai/clip-vit-base-patch32/tree/main", + required=True, + ) + + parser.add_argument( + "--pre_gen_pte", + help="folder path to pre-compiled ptes", + default=None, + type=str, + ) + + parser.add_argument( + "--fix_latents", + help="Enable this option to fix the latents in the unet diffuse step.", + action="store_true", + ) + + return parser + + +def broadcast_ut_result(output_image, seed): + sd = StableDiffusion(seed) + to_tensor = ToTensor() + target = sd(args.prompt, 512, 512, args.num_time_steps) + target = to_tensor(target).unsqueeze(0) + output_tensor = to_tensor( + Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0]) + ).unsqueeze(0) + + psnr_piq = piq.psnr(target, output_tensor) + ssim_piq = piq.ssim(target, output_tensor) + print(f"PSNR: {round(psnr_piq.item(), 3)}, SSIM: {round(ssim_piq.item(), 3)}") + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"PSNR": psnr_piq.item(), "SSIM": ssim_piq.item()})) + + +def save_result(output_image): + img = Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0]) + save_path = f"{args.artifact}/outputs/output_image.jpg" + img.save(save_path) + print(f"Output image saved at {save_path}") + + +def gen_pte_from_ctx_bin(args, compiler_specs): + # Create custom operators as context loader + bundle_programs = [ + from_context_binary(args.text_encoder_bin, "ctx_loader_0"), + from_context_binary(args.unet_bin, "ctx_loader_1"), + from_context_binary(args.vae_bin, "ctx_loader_2"), + ] + + # Lower with QnnBackend + lowered_modules = [ + to_backend("QnnBackend", prog["edge_program"], compiler_specs) + for prog in bundle_programs + ] + # Setup spill-fill buffer for relieving runtime memory usage + canonicalize_program(lowered_modules) + # export pte files + pte_files = [] + for target_name in target_names: + memory_planning_pass = MemoryPlanningPass( + 
memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=False, + ) + pte_files.append(f"{args.artifact}/{args.pte_prefix}_{target_name}.pte") + with open(pte_files[-1], "wb") as file: + file.write( + lowered_modules[0].buffer( + extract_delegate_segments=True, memory_planning=memory_planning_pass + ) + ) + # GC for reducing host memory consuming + bundle_programs.pop(0) + lowered_modules.pop(0) + gc.collect() + + return pte_files + + +def inference(args, compiler_specs, pte_files): + # Loading a pretrained EulerDiscreteScheduler from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. + scheduler = EulerDiscreteScheduler.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", revision="main" + ) + + # Loading a pretrained UNet2DConditionModel (which includes the time embedding) from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. + time_embedding = UNet2DConditionModel.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision="main" + ).time_embedding + + scheduler.set_timesteps(args.num_time_steps) + scheduler.config.prediction_type = "epsilon" + # Get encoding of unet and vae + ( + encoder_output, + unet_input, + unet_output, + vae_input, + vae_output, + ) = get_encodings( + args.text_encoder_bin, + args.unet_bin, + args.vae_bin, + compiler_specs, + ) + encoding = { + "encoder_output": encoder_output, + "unet_input": unet_input, + "unet_output": unet_output, + "vae_input": vae_input, + "vae_output": vae_output, + } + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=args.build_folder, + pte_path=pte_files, + workspace=f"/data/local/tmp/executorch/{args.pte_prefix}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + runner="examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner", + ) + + input_unet = () + input_list_unet = "" + + for i, t in enumerate(scheduler.timesteps): + time_emb = get_quant_data( + encoding, get_time_embedding(t, time_embedding), "unet", 1 + ) + input_list_unet += f"input_{i}_0.raw\n" + input_unet = input_unet + (time_emb,) + + qnn_executor_runner_args = [ + f"--text_encoder_path {adb.workspace}/{args.pte_prefix}_text_encoder.pte", + f"--unet_path {adb.workspace}/{args.pte_prefix}_unet.pte", + f"--vae_path {adb.workspace}/{args.pte_prefix}_vae.pte", + f"--input_list_path {adb.workspace}/input_list.txt", + f"--output_folder_path {adb.output_folder}", + f'--prompt "{args.prompt}"', + f"--guidance_scale {args.guidance_scale}", + f"--num_time_steps {args.num_time_steps}", + f"--vocab_json {adb.workspace}/vocab.json", + ] + if args.fix_latents: + qnn_executor_runner_args.append("--fix_latents") + + text_encoder_output_scale = encoding["encoder_output"]["scale"][0] + text_encoder_output_offset = encoding["encoder_output"]["offset"][0] + unet_input_latent_scale = encoding["unet_input"]["scale"][0] + unet_input_latent_offset = encoding["unet_input"]["offset"][0] + unet_input_text_emb_scale = encoding["unet_input"]["scale"][2] + unet_input_text_emb_offset = encoding["unet_input"]["offset"][2] + unet_output_scale = encoding["unet_output"]["scale"][0] + unet_output_offset = encoding["unet_output"]["offset"][0] + vae_input_scale = encoding["vae_input"]["scale"][0] + vae_input_offset = encoding["vae_input"]["offset"][0] + vae_output_scale = encoding["vae_output"]["scale"][0] + vae_output_offset = encoding["vae_output"]["offset"][0] + + qnn_executor_runner_args = qnn_executor_runner_args + [ + 
f"--text_encoder_output_scale {text_encoder_output_scale}", + f"--text_encoder_output_offset {text_encoder_output_offset}", + f"--unet_input_latent_scale {unet_input_latent_scale}", + f"--unet_input_latent_offset {unet_input_latent_offset}", + f"--unet_input_text_emb_scale {unet_input_text_emb_scale}", + f"--unet_input_text_emb_offset {unet_input_text_emb_offset}", + f"--unet_output_scale {unet_output_scale}", + f"--unet_output_offset {unet_output_offset}", + f"--vae_input_scale {vae_input_scale}", + f"--vae_input_offset {vae_input_offset}", + f"--vae_output_scale {vae_output_scale}", + f"--vae_output_offset {vae_output_offset}", + ] + + qnn_executor_runner_args = " ".join( + [ + f"cd {adb.workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. &&", + f"./qaihub_stable_diffusion_runner {' '.join(qnn_executor_runner_args)}", + ] + ) + + files = [args.vocab_json] + + if args.fix_latents: + seed = 42 + latents = torch.randn((1, 4, 64, 64), generator=torch.manual_seed(seed)).to( + "cpu" + ) + # We need to explicitly permute after init tensor or else the random value will be different + latents = latents.permute(0, 2, 3, 1).contiguous() + latents = latents * scheduler.init_noise_sigma + flattened_tensor = latents.view(-1) + # Save the flattened tensor to a .raw file + with open(os.path.join(args.artifact, "latents.raw"), "wb") as file: + file.write(flattened_tensor.numpy().tobytes()) + files.append(os.path.join(args.artifact, "latents.raw")) + + adb.push(inputs=input_unet, input_list=input_list_unet, files=files) + adb.execute(custom_runner_cmd=qnn_executor_runner_args) + + output_image = [] + + def post_process_vae(): + with open(f"{args.artifact}/outputs/output_0_0.raw", "rb") as f: + output_image.append( + np.fromfile(f, dtype=np.float32).reshape(1, 512, 512, 3) + ) + + adb.pull(output_path=args.artifact, callback=post_process_vae) + + if args.fix_latents: + broadcast_ut_result(output_image, seed) + else: + save_result(output_image) + + +def main(args): + os.makedirs(args.artifact, exist_ok=True) + + # common part for compile & inference + backend_options = generate_htp_compiler_spec( + use_fp16=False, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, args.model), + backend_options=backend_options, + is_from_context_binary=True, + ) + + if args.pre_gen_pte is None: + pte_files = gen_pte_from_ctx_bin(args, compiler_specs) + assert ( + len(pte_files) == 3 + ), f"Error: Expected 3 PTE files, but got {len(pte_files)} files." + + else: + pte_files = [ + f"{args.pre_gen_pte}/{args.pte_prefix}_{target_name}.pte" + for target_name in target_names + ] + if args.compile_only: + return + + inference(args, compiler_specs, pte_files) + + +if __name__ == "__main__": # noqa: C901 + parser = build_args_parser() + args = parser.parse_args() + + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp new file mode 100644 index 0000000000..687a260c4a --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +DEFINE_string( + text_encoder_path, + "qaihub_stable_diffusion_text_encoder.pte", + "Text Encoder Model serialized in flatbuffer format."); +DEFINE_string( + unet_path, + "qaihub_stable_diffusion_unet.pte", + "Unet Model serialized in flatbuffer format."); +DEFINE_string( + vae_path, + "qaihub_stable_diffusion_vae.pte", + "Vae Model serialized in flatbuffer format."); +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); +DEFINE_string( + input_list_path, + "input_list.txt", + "Input list storing time embedding."); +DEFINE_string( + vocab_json, + "vocab.json", + "Json path to retrieve a list of vocabs."); +DEFINE_string( + prompt, + "a photo of an astronaut riding a horse on mars", + "User input prompt"); +DEFINE_int32(num_time_steps, 20, "Number of time steps."); +DEFINE_double(guidance_scale, 7.5, "Guidance Scale"); + +DEFINE_double(text_encoder_output_scale, 0.0, "Text encoder output scale"); +DEFINE_int32(text_encoder_output_offset, 0, "Text encoder output offset"); +DEFINE_double(unet_input_latent_scale, 0.0, "Unet input latent scale"); +DEFINE_int32(unet_input_latent_offset, 0, "Unet input latent offset"); +DEFINE_double(unet_input_text_emb_scale, 0.0, "Unet input text emb scale"); +DEFINE_int32(unet_input_text_emb_offset, 0, "Unet input text emb offset"); +DEFINE_double(unet_output_scale, 0.0, "Unet output scale"); +DEFINE_int32(unet_output_offset, 0, "Unet output offset"); +DEFINE_double(vae_input_scale, 0.0, "Vae input scale"); +DEFINE_int32(vae_input_offset, 0, "Vae input offset"); +DEFINE_double(vae_output_scale, 0.0, "Vae output scale"); +DEFINE_int32(vae_output_offset, 0, "Vae output offset"); +DEFINE_bool( + fix_latents, + false, + "Enable this option to fix the latents in the unet diffuse step."); + +void usage_message() { + std::string usage_message = + "This is a sample executor runner capable of executing stable diffusion models." + "Users will need binary .pte program files for text_encoder, unet, and vae. 
Below are the options to retrieve required .pte program files:\n" + "For further information on how to generate the .pte program files and example command to execute this runner, please refer to qaihub_stable_diffsion.py."; + gflags::SetUsageMessage(usage_message); +} + +int main(int argc, char** argv) { + using namespace torch::executor; + runtime_init(); + usage_message(); + gflags::ParseCommandLineFlags(&argc, &argv, true); + bool is_default = + gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_scale") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_offset") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_scale") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_offset") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_scale") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_offset") + .is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_output_scale").is_default || + gflags::GetCommandLineFlagInfoOrDie("unet_output_offset").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_input_scale").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_input_offset").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_output_scale").is_default || + gflags::GetCommandLineFlagInfoOrDie("vae_output_offset").is_default; + + ET_CHECK_MSG( + !is_default, + "Please provide scale and offset for unet latent input, unet output, and vae input/output." + "Please refer to qaihub_stable_diffusion.py if you are unsure how to retrieve these values."); + + ET_LOG(Info, "Stable Diffusion runner started"); + std::vector models_path = { + FLAGS_text_encoder_path, FLAGS_unet_path, FLAGS_vae_path}; + + // Create stable_diffusion_runner + Runner runner( + models_path, + FLAGS_num_time_steps, + FLAGS_guidance_scale, + FLAGS_text_encoder_output_scale, + FLAGS_text_encoder_output_offset, + FLAGS_unet_input_latent_scale, + FLAGS_unet_input_latent_offset, + FLAGS_unet_input_text_emb_scale, + FLAGS_unet_input_text_emb_offset, + FLAGS_unet_output_scale, + FLAGS_unet_output_offset, + FLAGS_vae_input_scale, + FLAGS_vae_input_offset, + FLAGS_vae_output_scale, + FLAGS_vae_output_offset, + FLAGS_output_folder_path, + FLAGS_fix_latents); + + ET_CHECK_MSG( + runner.init_tokenizer(FLAGS_vocab_json) == Error::Ok, + "Runner failed to init tokenizer"); + + ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); + + ET_CHECK_MSG( + runner.parse_input_list(FLAGS_input_list_path) == Error::Ok, + "Failed to parse time embedding input list"); + ET_CHECK_MSG( + runner.generate(FLAGS_prompt) == Error::Ok, "Runner failed to generate"); + + ET_CHECK_MSG( + runner.print_performance() == Error::Ok, + "Runner failed to print performance"); + + return 0; +} diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp new file mode 100644 index 0000000000..a997397855 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp @@ -0,0 +1,621 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple stable diffusion runner that includes preprocessing and post +// processing logic. The module takes in a string as input and emits a tensor as +// output. 
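As a reading aid before the implementation, here is a toy, self-contained Python outline of the sampling loop that `Runner::generate()` below carries out; the three model calls are stand-in functions so the sketch runs on its own, whereas the real runner drives quantized ExecuTorch modules and shuttles every tensor through uint16 (de)quantization:

```python
import numpy as np

rng = np.random.default_rng(42)


# Stand-ins so the outline is runnable; shapes follow the runner's comments.
def text_encoder(tokens):
    return rng.standard_normal((1, 77, 1024))


def unet(latent, t_emb, text_emb):
    return rng.standard_normal(latent.shape)


def vae_decode(latent):
    return latent  # the real VAE maps (1, 64, 64, 4) latents to a (1, 512, 512, 3) image


def generate(prompt_tokens, sigmas, guidance_scale=7.5):
    cond_emb = text_encoder(prompt_tokens)
    uncond_emb = text_encoder([])                              # empty prompt for CFG
    latent = rng.standard_normal((1, 64, 64, 4)) * sigmas[0]   # sigmas[0] is the largest
    for i in range(len(sigmas) - 1):
        x = latent / np.sqrt(sigmas[i] ** 2 + 1)               # scale_model_input
        eps_text = unet(x, None, cond_emb)
        eps_uncond = unet(x, None, uncond_emb)
        eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)  # classifier-free guidance
        latent = latent + eps * (sigmas[i + 1] - sigmas[i])    # Euler step, epsilon prediction
    return vae_decode(latent)


# Illustrative sigma schedule; the runner derives its own from the beta schedule.
image = generate(prompt_tokens=[], sigmas=np.linspace(14.6, 0.0, 21))
```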
+ +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace executor { + +Runner::Runner( + const std::vector& models_path, + const int num_time_steps, + const float guidance_scale, + const float text_encoder_output_scale, + const int text_encoder_output_offset, + const float unet_input_latent_scale, + const int unet_input_latent_offset, + const float unet_input_text_emb_scale, + const float unet_input_text_emb_offset, + const float unet_output_scale, + const int unet_output_offset, + const float vae_input_scale, + const int vae_input_offset, + const float vae_output_scale, + const int vae_output_offset, + const std::string output_path, + const bool fix_latents) + : num_time_steps_(num_time_steps), + guidance_scale_(guidance_scale), + text_encoder_output_scale_(text_encoder_output_scale), + text_encoder_output_offset_(text_encoder_output_offset), + unet_input_latent_scale_(unet_input_latent_scale), + unet_input_latent_offset_(unet_input_latent_offset), + unet_input_text_emb_scale_(unet_input_text_emb_scale), + unet_input_text_emb_offset_(unet_input_text_emb_offset), + unet_output_scale_(unet_output_scale), + unet_output_offset_(unet_output_offset), + vae_input_scale_(vae_input_scale), + vae_input_offset_(vae_input_offset), + vae_output_scale_(vae_output_scale), + vae_output_offset_(vae_output_offset), + output_path_(output_path), + fix_latents_(fix_latents) { + for (int i = 0; i < models_path.size(); i++) { + modules_.push_back(std::make_unique( + models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); + ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str()); + } +} + +std::vector> Runner::get_methods_meta() { + std::vector> methods_meta; + for (std::unique_ptr& module : modules_) { + methods_meta.emplace_back(module->method_meta("forward")); + } + return methods_meta; +} + +bool Runner::is_loaded() const { + bool loaded = true; + for (const std::unique_ptr& module : modules_) { + loaded &= module->is_loaded(); + } + return loaded; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + stats_.model_load_start_ms = util::time_in_ms(); + for (auto& module : modules_) { + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward")); + } + stats_.model_load_end_ms = util::time_in_ms(); + return Error::Ok; +} + +Error Runner::parse_input_list(std::string& path) { + // Fill in data for input + std::ifstream input_list(path); + time_emb_list_.reserve(num_time_steps_); + ET_CHECK_MSG(input_list.is_open(), "Input list error opening file"); + std::string time_emb_file; + for (int i = 0; i < num_time_steps_; i++) { + std::getline(input_list, time_emb_file); + std::ifstream is; + is.open(time_emb_file, std::ios::binary); + is.seekg(0, std::ios::end); + size_t filesize = is.tellg(); + is.seekg(0, std::ios::beg); + std::vector time_emb; + time_emb.resize(filesize / sizeof(uint16_t)); + is.read(reinterpret_cast(time_emb.data()), filesize); + time_emb_list_.push_back(time_emb); + } + return Error::Ok; +} + +Error Runner::init_tokenizer(const std::string& vocab_json_path) { + ET_LOG(Info, "Loading Tokenizer from json"); + stats_.tokenizer_load_start_ms = util::time_in_ms(); + std::ifstream fin(vocab_json_path); + auto update_map = [this](std::string& target, std::regex& re) { + std::smatch sm; + std::regex_search(target, sm, re); + // replace special character, please extend this if any cornor case found + std::string text = sm[1]; + std::unordered_map post_process = { + {"\"", 
std::regex(R"(\\\")")}, + {" ", std::regex(R"()")}, + {"\\", std::regex(R"(\\\\)")}}; + for (auto& p : post_process) { + text = std::regex_replace(text, p.second, p.first); + } + vocab_to_token_map_[text] = std::stoi(sm[2]); + }; + + if (fin.is_open()) { + std::string line, text; + while (getline(fin, line)) { + text += line; + } + fin.close(); + + std::regex re_anchor(R"(\d,\")"); + std::regex re_pattern(R"(\{?\"(.*)\":([\d]+)\}?)"); + auto begin = std::sregex_iterator(text.begin(), text.end(), re_anchor); + auto end = std::sregex_iterator(); + size_t pos = 0; + for (std::sregex_iterator iter = begin; iter != end; ++iter) { + std::smatch match; + size_t len = iter->position() - pos + 1; + std::string target = text.substr(pos, len); + update_map(target, re_pattern); + pos = iter->position() + 1; + } + // process last vocabulary + std::string target = text.substr(pos); + update_map(target, re_pattern); + } + stats_.tokenizer_load_end_ms = util::time_in_ms(); + return Error::Ok; +} + +std::vector Runner::tokenize(std::string prompt) { + std::string bos("<|startoftext|>"), eos("<|endoftext|>"); + std::vector vocabs; + vocabs.reserve(max_tokens_); + std::vector tokens(1, vocab_to_token_map_[bos]); + + // pretokenize + // ref: https://github.com/monatis/clip.cpp + // https://huggingface.co/openai/clip-vit-base-patch32 + std::string text; + std::regex re( + R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"); + std::smatch sm; + while (std::regex_search(prompt, sm, re)) { + for (auto& v : sm) { + vocabs.push_back(v); + } + prompt = sm.suffix(); + } + for (std::string& v : vocabs) { + std::string word = (v[0] == ' ') ? v.substr(1) : v; + word += " "; + auto iter = vocab_to_token_map_.find(word); + if (iter != vocab_to_token_map_.end()) { + tokens.push_back(iter->second); + continue; + } + for (int i = 0; i < v.size(); ++i) { + for (int j = v.size() - 1; j >= i; --j) { + std::string token = v.substr(i, j - 1 + 1); + auto iter = vocab_to_token_map_.find(token); + if (iter != vocab_to_token_map_.end()) { + tokens.push_back(iter->second); + i = j + 1; + break; + } else if (j == i) { + ET_LOG(Error, "unknown token found: %s", token.c_str()); + } + } + } + } + tokens.push_back(vocab_to_token_map_[eos]); + return tokens; +} + +std::vector Runner::gen_latent_from_file() { + std::vector tensor_vector; + std::ifstream file("latents.raw", std::ios::binary); + if (!file.is_open()) { + ET_LOG(Error, "Error opening file!"); + return tensor_vector; + } + + // Read the tensor data + float value; + while (file.read(reinterpret_cast(&value), sizeof(float))) { + tensor_vector.push_back(value); + } + file.close(); + return tensor_vector; +} + +std::vector Runner::gen_random_latent(float sigma) { + std::random_device rnd_device; + std::mt19937 mersenne_engine{rnd_device()}; + std::normal_distribution dist{0.0f, 1.0f}; + + constexpr int latent_size = 1 * 64 * 64 * 4; + std::vector random_vector(latent_size); + + for (float& value : random_vector) { + value = dist(mersenne_engine) * sigma; + } + return random_vector; +} + +std::vector Runner::get_time_steps() { + std::vector time_steps(num_time_steps_); + for (int i = 0; i < num_time_steps_; ++i) { + time_steps[i] = (num_train_timesteps_ - 1) * + (1.0f - static_cast(i) / (num_time_steps_ - 1)); + } + return time_steps; +} + +std::vector Runner::get_sigmas(const std::vector& time_steps) { + float start = std::sqrt(beta_start_); + float end = std::sqrt(beta_end_); + std::vector betas(num_train_timesteps_); + float step = 
(end - start) / (num_train_timesteps_ - 1); + for (int i = 0; i < num_train_timesteps_; ++i) { + float value = start + i * step; + betas[i] = 1 - (value * value); + } + + std::vector alphas_cumprod(num_train_timesteps_); + float cumprod = 1.0; + for (int i = 0; i < num_train_timesteps_; ++i) { + cumprod *= betas[i]; + alphas_cumprod[i] = cumprod; + } + + std::vector sigmas(num_train_timesteps_); + for (int i = 0; i < num_train_timesteps_; ++i) { + sigmas[i] = std::sqrt((1.0 - alphas_cumprod[i]) / alphas_cumprod[i]); + } + + std::vector res(time_steps.size()); + for (size_t i = 0; i < time_steps.size(); ++i) { + float index = + static_cast(i) * (sigmas.size() - 1) / (time_steps.size() - 1); + size_t lower_index = static_cast(std::floor(index)); + size_t upper_index = static_cast(std::ceil(index)); + + float weight = index - lower_index; + res[i] = + (1.0 - weight) * sigmas[lower_index] + weight * sigmas[upper_index]; + } + std::reverse(res.begin(), res.end()); + res.push_back(0); + + return res; +} + +void Runner::scale_model_input( + const std::vector& latents, + std::vector& latent_model_input, + float sigma) { + for (int i = 0; i < latents.size(); i++) { + latent_model_input[i] = (latents[i] / std::sqrt(sigma * sigma + 1)); + } +} + +void Runner::quant_tensor( + const std::vector& fp_vec, + std::vector& quant_vec, + float scale, + int offset) { + offset = abs(offset); + for (int i = 0; i < fp_vec.size(); i++) { + quant_vec[i] = static_cast((fp_vec[i] / scale) + offset); + } +} + +void Runner::dequant_tensor( + const std::vector& quant_vec, + std::vector& fp_vec, + float scale, + int offset) { + offset = abs(offset); + for (int i = 0; i < quant_vec.size(); i++) { + fp_vec[i] = (quant_vec[i] - offset) * scale; + } +} + +// Using the same algorithm as EulerDiscreteScheduler in python. 
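(Note that in `get_sigmas` above, the vector named `betas` actually holds alpha_t = 1 - beta_t for a scaled-linear beta schedule, so the running product is the usual alpha_bar_t and sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), interpolated onto the chosen timesteps.) For reference, the update performed by `Runner::step()` written as a small Python sketch of the epsilon-prediction Euler step, with illustrative names:

```python
import numpy as np


def euler_step(sample, model_output, sigmas, i):
    sigma = sigmas[i]
    pred_original = sample - sigma * model_output   # estimated clean latent
    derivative = (sample - pred_original) / sigma   # equals model_output for epsilon prediction
    dt = sigmas[i + 1] - sigma                      # negative: the noise level shrinks
    return sample + derivative * dt


sigmas = np.array([14.6, 9.7, 6.3, 0.0])            # illustrative, descending schedule
x = np.random.standard_normal((1, 64, 64, 4)) * sigmas[0]
x = euler_step(x, np.random.standard_normal(x.shape), sigmas, i=0)
```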
+void Runner::step( + const std::vector& model_output, + const std::vector& sigmas, + std::vector& sample, + std::vector& prev_sample, + int step_index) { + float sigma = sigmas[step_index]; + float dt = sigmas[step_index + 1] - sigma; + + for (int i = 0; i < sample.size(); ++i) { + float sigma_hat = sample[i] - (sigma * model_output[i]); + prev_sample[i] = (sample[i] - sigma_hat) / sigma; + prev_sample[i] = sample[i] + (prev_sample[i] * dt); + } + sample = prev_sample; +} + +Error Runner::generate(std::string prompt) { + ET_LOG(Info, "Start generating"); + stats_.generate_start_ms = util::time_in_ms(); + + // Start tokenize + stats_.tokenizer_parsing_start_ms = util::time_in_ms(); + std::vector cond_tokens = tokenize(prompt); + cond_tokens.resize(max_tokens_); + std::vector uncond_tokens = tokenize(""); + uncond_tokens.resize(max_tokens_); + stats_.tokenizer_parsing_end_ms = util::time_in_ms(); + + std::vector> method_metas = get_methods_meta(); + + MethodMeta encoder_method_meta = method_metas[0].get(); + // Initialize text_encoder input tensors: cond/uncond tokenized_input[1,77] + ManagedTensor managed_cond_tokens( + cond_tokens.data(), + {1, 77}, + encoder_method_meta.input_tensor_meta(0)->scalar_type()); + ManagedTensor managed_uncond_tokens( + uncond_tokens.data(), + {1, 77}, + encoder_method_meta.input_tensor_meta(0)->scalar_type()); + Tensor cond_tokens_tensor = managed_cond_tokens.get_aliasing_tensor(); + Tensor uncond_tokens_tensor = managed_uncond_tokens.get_aliasing_tensor(); + // Initialize text_encoder output tensors: cond/uncond embedding[1, 77, 1024] + constexpr int emb_size = 1 * 77 * 1024; + std::vector cond_emb_vec(emb_size); + std::vector uncond_emb_vec(emb_size); + std::vector fp_emb_vec(emb_size); + ManagedTensor managed_cond_emb( + cond_emb_vec.data(), + {1, 77, 1024}, + encoder_method_meta.output_tensor_meta(0)->scalar_type()); + ManagedTensor managed_uncond_emb( + uncond_emb_vec.data(), + {1, 77, 1024}, + encoder_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor cond_emb_tensor = managed_cond_emb.get_aliasing_tensor(); + Tensor uncond_emb_tensor = managed_uncond_emb.get_aliasing_tensor(); + modules_[0]->set_output_data_ptr(cond_emb_tensor, 0); + long encoder_start = util::time_in_ms(); + auto cond_res = modules_[0]->forward({cond_tokens_tensor}); + stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start); + modules_[0]->set_output_data_ptr(uncond_emb_tensor, 0); + encoder_start = util::time_in_ms(); + auto uncond_res = modules_[0]->forward({uncond_tokens_tensor}); + stats_.text_encoder_execution_time += (util::time_in_ms() - encoder_start); + + // Initialize unet parameters + MethodMeta unet_method_meta = method_metas[1].get(); + std::vector time_steps = get_time_steps(); + std::vector sigmas = get_sigmas(time_steps); + float max_sigma = *std::max_element(sigmas.begin(), sigmas.end()); + std::vector latent; + if (fix_latents_) { + latent = gen_latent_from_file(); + } else { + latent = gen_random_latent(max_sigma); + } + std::vector prev_sample(latent.size()); + + // Initialize unet input tensors + // 1. latent[1,64,64,4] + // 2. time_embedding[1,1280] + // 3. 
cond/uncond embedding[1,77,1024] + std::vector latent_model_input(latent.size()); + std::vector fp_latent_model_input(latent.size()); + ManagedTensor managed_latent( + latent_model_input.data(), + {1, 64, 64, 4}, + unet_method_meta.input_tensor_meta(0)->scalar_type()); + Tensor latent_tensor = managed_latent.get_aliasing_tensor(); + std::vector managed_time_emb_tensors; + std::vector time_emb_tensors; + managed_time_emb_tensors.reserve(num_time_steps_); + time_emb_tensors.reserve(num_time_steps_); + for (int step_index = 0; step_index < num_time_steps_; step_index++) { + managed_time_emb_tensors.emplace_back(ManagedTensor( + time_emb_list_[step_index].data(), + {1, 1280}, + unet_method_meta.input_tensor_meta(1)->scalar_type())); + time_emb_tensors.emplace_back( + managed_time_emb_tensors.back().get_aliasing_tensor()); + } + // requantize text encoders output + dequant_tensor( + cond_emb_vec, + fp_emb_vec, + text_encoder_output_scale_, + text_encoder_output_offset_); + quant_tensor( + fp_emb_vec, + cond_emb_vec, + unet_input_text_emb_scale_, + unet_input_text_emb_offset_); + dequant_tensor( + uncond_emb_vec, + fp_emb_vec, + text_encoder_output_scale_, + text_encoder_output_offset_); + quant_tensor( + fp_emb_vec, + uncond_emb_vec, + unet_input_text_emb_scale_, + unet_input_text_emb_offset_); + + // Initialize unet output tensors: text/uncond noise_pred[1,64,64,4] + std::vector noise_pred_text(latent.size()); + std::vector noise_pred_uncond(latent.size()); + std::vector fp_noise_pred_text(noise_pred_text.size()); + std::vector fp_noise_pred_uncond(noise_pred_uncond.size()); + ManagedTensor managed_noise_pred_text( + noise_pred_text.data(), + {1, 64, 64, 4}, + unet_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor noise_pred_text_tensor = managed_noise_pred_text.get_aliasing_tensor(); + ManagedTensor managed_noise_pred_uncond( + noise_pred_uncond.data(), + {1, 64, 64, 4}, + unet_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor noise_pred_uncond_tensor = + managed_noise_pred_uncond.get_aliasing_tensor(); + + // Execute unet + for (int step_index = 0; step_index < num_time_steps_; step_index++) { + long start_post_process = util::time_in_ms(); + scale_model_input(latent, fp_latent_model_input, sigmas[step_index]); + + quant_tensor( + fp_latent_model_input, + latent_model_input, + unet_input_latent_scale_, + unet_input_latent_offset_); + + stats_.unet_aggregate_post_processing_time += + (util::time_in_ms() - start_post_process); + modules_[1]->set_output_data_ptr(noise_pred_text_tensor, 0); + long start_unet_execution = util::time_in_ms(); + auto cond_res = modules_[1]->forward( + {latent_tensor, time_emb_tensors[step_index], cond_emb_tensor}); + stats_.unet_aggregate_execution_time += + (util::time_in_ms() - start_unet_execution); + modules_[1]->set_output_data_ptr(noise_pred_uncond_tensor, 0); + start_unet_execution = util::time_in_ms(); + auto uncond_res = modules_[1]->forward( + {latent_tensor, + time_emb_tensors[step_index], + uncond_emb_tensor}); // results in noise_pred_uncond_vec + stats_.unet_aggregate_execution_time += + (util::time_in_ms() - start_unet_execution); + + // start unet post processing + start_post_process = util::time_in_ms(); + + dequant_tensor( + noise_pred_text, + fp_noise_pred_text, + unet_output_scale_, + unet_output_offset_); + dequant_tensor( + noise_pred_uncond, + fp_noise_pred_uncond, + unet_output_scale_, + unet_output_offset_); + + for (int i = 0; i < fp_noise_pred_text.size(); i++) { + fp_noise_pred_text[i] = fp_noise_pred_uncond[i] + + 
guidance_scale_ * (fp_noise_pred_text[i] - fp_noise_pred_uncond[i]); + } + step(fp_noise_pred_text, sigmas, latent, prev_sample, step_index); + stats_.unet_aggregate_post_processing_time += + (util::time_in_ms() - start_post_process); + } + + // Start VAE + MethodMeta vae_method_meta = method_metas[2].get(); + // Initialize vae input tensor : latent[1,64,64,4] + std::vector vae_input(latent.size()); + ManagedTensor managed_vae_input( + vae_input.data(), + {1, 64, 64, 4}, + vae_method_meta.input_tensor_meta(0)->scalar_type()); + Tensor vae_input_tensor = managed_vae_input.get_aliasing_tensor(); + // Intialize vae output tensor: output[1,512,512,3] + constexpr int image_size = 1 * 512 * 512 * 3; + std::vector q_out(image_size); + std::vector out(image_size); + ManagedTensor managed_output( + q_out.data(), + {1, 512, 512, 3}, + vae_method_meta.output_tensor_meta(0)->scalar_type()); + Tensor output_tensor = managed_output.get_aliasing_tensor(); + + quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_); + + modules_[2]->set_output_data_ptr(output_tensor, 0); + long start_vae_execution = util::time_in_ms(); + auto vae_res = modules_[2]->forward({vae_input_tensor}); + stats_.vae_execution_time = (util::time_in_ms() - start_vae_execution); + stats_.generate_end_ms = util::time_in_ms(); + + // Dequant uint16 output to fp32 output + dequant_tensor(q_out, out, vae_output_scale_, vae_output_offset_); + + // Saving outputs + auto output_file_name = output_path_ + "/output_0_0.raw"; + std::ofstream fout(output_file_name.c_str(), std::ios::binary); + fout.write( + reinterpret_cast(out.data()), out.size() * sizeof(float)); + fout.close(); + + return Error::Ok; +} + +Error Runner::print_performance() { + ET_LOG(Info, "\tTotal Number of steps:\t\t\t\t%d", num_time_steps_); + + ET_LOG( + Info, + "\tTokenizer Load Time:\t\t\t\t%f (seconds)", + ((double)(stats_.tokenizer_load_end_ms - stats_.tokenizer_load_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tModel Load Time:\t\t\t\t%f (seconds)", + ((double)(stats_.model_load_end_ms - stats_.model_load_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tGenerate Time(Tokenize + Encoder + UNet + VAE):\t%f (seconds)", + ((double)(stats_.generate_end_ms - stats_.generate_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tTokenize Time:\t\t\t\t\t%f (seconds)", + ((double)(stats_.tokenizer_parsing_end_ms - + stats_.tokenizer_parsing_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tText Encoder Execution Time:\t\t\t%f (seconds)", + ((double)(stats_.text_encoder_execution_time) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tUnet Aggregate (Cond + Uncond) Execution Time:\t%f (seconds)", + ((double)stats_.unet_aggregate_execution_time / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + ET_LOG( + Info, + "\tUnet Average Execution Time:\t\t\t%f (seconds)", + ((double)(stats_.unet_aggregate_execution_time / (num_time_steps_ * 2)) / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + ET_LOG( + Info, + "\tUnet Aggregate Post-Processing Time:\t\t%f (seconds)", + ((double)(stats_.unet_aggregate_post_processing_time) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tUnet Average Post-Processing Time:\t\t%f (seconds)", + ((double)(stats_.unet_aggregate_post_processing_time / + (num_time_steps_ * 2)) / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + ET_LOG( + Info, + "\tVAE Execution Time:\t\t\t\t%f 
(seconds)", + ((double)(stats_.vae_execution_time) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + return Error::Ok; +} + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h new file mode 100644 index 0000000000..e081ab80cc --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h @@ -0,0 +1,141 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple diffusion runner that includes preprocessing and post processing +// logic. The module takes in a string as input and emites a tensor as output. + +#pragma once + +#include +#include +#include + +#include + +namespace torch { +namespace executor { + +class Runner { + public: + explicit Runner( + const std::vector& models_path, + const int num_time_steps, + const float guidance_scale, + const float text_encoder_output_scale, + const int text_encoder_output_offset, + const float unet_input_latent_scale, + const int unet_input_latent_offset, + const float unet_input_text_emb_scale, + const float unet_input_text_emb_offset, + const float unet_output_scale, + const int unet_output_offset, + const float vae_input_scale, + const int vae_input_offset, + const float vae_output_scale, + const int vae_output_offset, + const std::string output_path, + const bool fix_latents); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Model loading time + long model_load_start_ms; + long model_load_end_ms; + + // tokenizer loading time + long tokenizer_load_start_ms = 0; + long tokenizer_load_end_ms = 0; + + // tokenizer parsing time + long tokenizer_parsing_start_ms = 0; + long tokenizer_parsing_end_ms = 0; + + // Total time to run generate + long generate_start_ms = 0; + long generate_end_ms = 0; + + // text encoder execution time + long text_encoder_execution_time = 0; + + // Unet aggregation execution time over n steps for cond + uncond + long unet_aggregate_execution_time = 0; + + // UNet aggregation post processing time over n steps for cond + uncond. + // This is the time from processing unet's output until feeding it into the + // next iteration. 
+ long unet_aggregate_post_processing_time = 0; + + // VAE execution time + long vae_execution_time = 0; + }; + + bool is_loaded() const; + Error load(); + Error init_tokenizer(const std::string& vocab_json_path); + Error print_performance(); + std::vector tokenize(std::string prompt); + std::vector gen_latent_from_file(); + std::vector gen_random_latent(float sigma); + void step( + const std::vector& model_output, + const std::vector& sigmas, + std::vector& sample, + std::vector& prev_sample, + int step_index); + std::vector> get_methods_meta(); + std::vector get_time_steps(); + std::vector get_sigmas(const std::vector& time_steps); + void scale_model_input( + const std::vector& vec, + std::vector& latent_model_input, + float sigma); + Error parse_input_list(std::string& path); + Error generate(std::string prompt); + void quant_tensor( + const std::vector& fp_vec, + std::vector& quant_vec, + float scale, + int offset); + void dequant_tensor( + const std::vector& quant_vec, + std::vector& fp_vec, + float scale, + int offset); + + private: + Stats stats_; + std::vector> modules_; + std::vector> time_emb_list_; + std::unordered_map vocab_to_token_map_; + + std::string output_path_; + int num_time_steps_; + float guidance_scale_; + float text_encoder_output_scale_; + int text_encoder_output_offset_; + float unet_input_latent_scale_; + int unet_input_latent_offset_; + float unet_input_text_emb_scale_; + int unet_input_text_emb_offset_; + float unet_output_scale_; + int unet_output_offset_; + float vae_input_scale_; + int vae_input_offset_; + float vae_output_scale_; + int vae_output_offset_; + const float beta_start_ = 0.00085; + const float beta_end_ = 0.012; + const int num_train_timesteps_ = 1000; + const int max_tokens_ = 77; + const bool fix_latents_ = false; +}; + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py new file mode 100644 index 0000000000..8ec5783131 --- /dev/null +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py @@ -0,0 +1,22 @@ +import torch +from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline + + +class StableDiffusion: + def __init__(self, seed=42): + self.model_id: str = "stabilityai/stable-diffusion-2-1-base" + self.generator = torch.manual_seed(seed) + self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler" + ) + + self.pipe = StableDiffusionPipeline.from_pretrained( + self.model_id, scheduler=self.scheduler, torch_dtype=torch.float32 + ) + self.pipe = self.pipe.to("cpu") + + def __call__(self, prompt, height, width, num_time_steps): + image = self.pipe( + prompt, height, width, num_time_steps, generator=self.generator + ).images[0] + return image diff --git a/exir/backend/utils.py b/exir/backend/utils.py index b5072604d2..2b768fe7c2 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -28,9 +28,6 @@ T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -log: logging.Logger = logging.getLogger(__name__) - - # NB: Set this to None to handle validation from MobileBert @lru_cache(maxsize=None) def is_same_node( @@ -499,3 +496,31 @@ def insert_delegate_mapping_entry( # pyre-ignore Warning from Union[int, st] keys self._debug_handle_map[identifier] = filtered_debug_handles return identifier + + +class WhyNoPartition: + """ + Simple helper class for partitioners to log why a node was not 
lowered. + + Example usage: + + # In your backend partitioner file(s) + why = WhyNoPartition(logger=your_backend_logger) + + # hypothetical function that checks if a node can be lowered + if not can_be_lowered(node): + why(node, "This node was not lowered because ...") + """ + + def __init__(self, logger: logging.Logger): + self.logger = logger + self.node: Optional[torch.fx.Node] = None + self.reason: str = "" + + def __call__(self, node: torch.fx.Node, reason: str) -> None: + self.node = node + self.reason = reason + self.logger.debug(self) + + def __str__(self) -> str: + return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}." diff --git a/exir/capture/_config.py b/exir/capture/_config.py index d959f10403..42dc170c19 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -82,7 +82,12 @@ class ExecutorchBackendConfig: # If provided, the minimum alignment of delegate data in the program. Must # be a power of 2. If not provided, uses the value in the schema file. delegate_alignment: Optional[int] = None - sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass() + + # A single sym shape eval pass can be defined for all the programs in the + # EdgeProgramManager or can be defined per program. + sym_shape_eval_pass: Union[PassType, Dict[str, PassType]] = ( + HintBasedSymShapeEvalPass() + ) # If set to true, view_copy operations will be converted to lightweight # view operations in the ET runtime diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index f51b4113c8..2d2cc0f3f1 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1270,7 +1270,7 @@ def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan def fetch_attr(self, target: _Target) -> _AbstractValue: """Fetch weights and other module parameters. If the attribute is a tensor, emit it.""" - attr = super().fetch_attr(target) + attr = super().fetch_attr(target) # pyre-fixme[6] if isinstance(attr, torch.Tensor): return self._emit_evalue( @@ -1286,7 +1286,7 @@ def fetch_attr(self, target: _Target) -> _AbstractValue: else: return attr - def call_module( + def call_module( # pyre-fixme[14] self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> None: """Unsupported in execution IR, so unhandled by the emitter.""" @@ -1294,7 +1294,7 @@ def call_module( self._emit_node_specific_error(self.node, "call_module is not supported") ) - def call_method( + def call_method( # pyre-fixme[14] self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> _EmitterValue: """Unsupported in execution IR, so unhandled by the emitter.""" @@ -1302,7 +1302,7 @@ def call_method( self._emit_node_specific_error(self.node, "call_method is not supported") ) - def placeholder( + def placeholder( # pyre-fixme[14] self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> _AbstractValue: """Performs actions for the placeholder node of a graph module. @@ -1324,7 +1324,7 @@ def placeholder( self.placeholder_count += 1 return value - def output( + def output( # pyre-fixme[14] self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> None: """Performs actions for the output node of a graph module. 
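Tying back to the `sym_shape_eval_pass` change in `ExecutorchBackendConfig` above: since the field now accepts a `Dict[str, PassType]`, a multi-method program can select a different symbolic-shape evaluation pass per method, with unlisted methods falling back to the default. A hedged sketch; the method names are hypothetical and the import paths are assumed rather than taken from this patch:

```python
# Hypothetical configuration for a program with "image_encoder" and "text_model" methods.
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes.sym_shape_eval_pass import (
    ConstraintBasedSymShapeEvalPass,
    HintBasedSymShapeEvalPass,
)

config = ExecutorchBackendConfig(
    sym_shape_eval_pass={
        # bound dynamic shapes in the image encoder via range constraints
        "image_encoder": ConstraintBasedSymShapeEvalPass(),
        "text_model": HintBasedSymShapeEvalPass(),
        # any method not listed here uses the default HintBasedSymShapeEvalPass()
    }
)
```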
@@ -1354,7 +1354,7 @@ def output( ) self.chain.instructions.append(instruction) - def call_function( + def call_function( # pyre-fixme[14] self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> _EmitterValue: """Performs actions for the call_function node of a graph module. @@ -1412,7 +1412,7 @@ def call_function( ) ) - def run( + def run( # pyre-fixme[14] self, *args: _Argument, initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None, diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 2c2cd8eb0d..4d07fdcdf0 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -139,7 +139,7 @@ def buffer( segment_alignment: int = 4096, constant_tensor_alignment: Optional[int] = None, delegate_alignment: Optional[int] = None, - memory_planning: MemoryPlanningPass = None, + memory_planning: MemoryPlanningPass = None, # pyre-fixme[9] ) -> bytes: """ Returns a buffer containing the serialized ExecuTorch binary. @@ -161,7 +161,7 @@ def buffer( def program( self, emit_stacktrace: bool = False, - memory_planning: MemoryPlanningPass = None, + memory_planning: MemoryPlanningPass = None, # pyre-fixme[9] ) -> Program: # Fix autodpes introuces cyclic dependencies: # program -> verifier -> lowered_backend_module -> program diff --git a/exir/pass_base.py b/exir/pass_base.py index dd55641f25..3b1a2928e2 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -177,7 +177,7 @@ def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None: self.fake_tensor_mode: Optional[FakeTensorMode] = None self.submodules: Dict[torch.nn.Module, str] = {} - def trace(self) -> None: + def trace(self) -> None: # pyre-fixme[14,15] raise ExportPassBaseError("ExportTracer doesn't support trace().") def create_arg(self, a: Argument) -> torch.fx.Node: @@ -290,7 +290,7 @@ def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None: self.callback = callback self.node: torch.fx.Node = next(iter(gm.graph.nodes)) - def placeholder( + def placeholder( # pyre-fixme[14] self, target: str, args: Tuple[Argument, ...], @@ -351,7 +351,7 @@ def call_function( else: raise ExportPassBaseError(f"Unsupported target type: {target}") - def get_attr( + def get_attr( # pyre-fixme[14] self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] ) -> Argument: return super().get_attr(target, args, kwargs) @@ -364,7 +364,7 @@ def call_module( ) -> None: raise ExportPassBaseError("call_module is not supported.") - def call_method( + def call_method( # pyre-fixme[14] self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] ) -> None: raise ExportPassBaseError("call_method is not supported.") diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 99507ccdc9..7a0623040f 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -302,6 +302,7 @@ def make_alloc_node( "Memory allocator node needs FakeTensor val or TensorMetadata to proceed" ) + # pyre-fixme[6] alloc = graph_module.graph.call_function(memory.alloc, (alloc_spec,)) alloc.meta["val"] = val alloc.meta["tensor_meta"] = tensor_meta diff --git a/exir/passes/remove_noop_pass.py b/exir/passes/remove_noop_pass.py index c834ca9294..d9b9955663 100644 --- a/exir/passes/remove_noop_pass.py +++ b/exir/passes/remove_noop_pass.py @@ -40,7 +40,7 @@ def eliminate_dq_q( qparams_q = list(user.args)[1:] if qparams_dq != qparams_q: continue - user.replace_all_uses_with(node.args[0]) + user.replace_all_uses_with(node.args[0]) # pyre-fixme[6] class 
RemoveNoopPass(ExportPass): diff --git a/exir/program/_program.py b/exir/program/_program.py index 9031ce39e6..849eae4f6f 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -13,7 +13,6 @@ import torch import torch._export - from executorch.exir._serialize import _serialize_pte_binary from executorch.exir._serialize._cord import Cord from executorch.exir.backend.backend_api import to_backend @@ -23,6 +22,7 @@ from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap from executorch.exir.error import ExportError from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.pass_base import PassBase from executorch.exir.pass_manager import PassType from executorch.exir.passes import ( base_post_op_replace_passes, @@ -641,25 +641,48 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": return new_ep -def pre_memory_planning_passes(config: ExecutorchBackendConfig) -> List[PassType]: +def pre_memory_planning_passes( + config: ExecutorchBackendConfig, name: Optional[str] = None +) -> List[PassType]: + """ + Returns a list of passes to run before memory planning. + Get the sym shape eval pass based on the method name, if the pass is not in the dict, use the default pass. + """ + # Handle symbolic shape eval pass + if isinstance(config.sym_shape_eval_pass, dict): + default_pass = ExecutorchBackendConfig().sym_shape_eval_pass + if not name: + sym_shape_eval_pass = default_pass + # pyre-ignore: Undefined attribute [16] + sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass) + elif isinstance(config.sym_shape_eval_pass, PassBase): + sym_shape_eval_pass = config.sym_shape_eval_pass + else: + raise RuntimeError( + f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}" + ) if config.remove_view_copy: - # pyre-ignore return [ NormalizeViewCopyBasePass(), dead_code_elimination_pass, ReplaceViewCopyWithViewPass(), - config.sym_shape_eval_pass, + sym_shape_eval_pass, config.to_out_var_pass, ] else: - # pyre-ignore return [ - config.sym_shape_eval_pass, + sym_shape_eval_pass, config.to_out_var_pass, ] -def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]: +def edge_to_executorch_passes( + config: ExecutorchBackendConfig, name: Optional[str] = None +) -> List[PassType]: + """ + Returns a list of passes to lower from edge to executorch. + Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass. + """ passes: List[PassType] = [ *config.passes, SpecPropPass(), @@ -668,7 +691,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType] # there exists an unbacked symint operation. 
EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), - ] + pre_memory_planning_passes(config) + ] + pre_memory_planning_passes(config, name) return passes @@ -1234,7 +1257,7 @@ def to_executorch( program = unsafe_remove_auto_functionalized_pass(program) gm, new_signature = insert_write_back_for_buffers_pass(program) new_gm = program.graph_module - for p in edge_to_executorch_passes(config): + for p in edge_to_executorch_passes(config, name): new_gm_res = p(new_gm) assert new_gm_res is not None new_gm = new_gm_res.graph_module diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 99ec648145..a167a67dd9 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1421,7 +1421,7 @@ def quantize_model( quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config() quantizer.set_global(quantization_config) - m = prepare_pt2e(m, quantizer) + m = prepare_pt2e(m, quantizer) # pyre-fixme[6] m = convert_pt2e(m, fold_quantize=True) ep = torch.export.export(m, example_inputs) dq_nodes_pre = count_dq_nodes(ep.graph_module) diff --git a/exir/tests/test_quantization.py b/exir/tests/test_quantization.py index ca85386db6..ebe9477522 100644 --- a/exir/tests/test_quantization.py +++ b/exir/tests/test_quantization.py @@ -58,7 +58,7 @@ def test_resnet(self) -> None: quantizer = XNNPACKQuantizer() operator_config = get_symmetric_quantization_config(is_per_channel=True) quantizer.set_global(operator_config) - m = prepare_pt2e(m, quantizer) + m = prepare_pt2e(m, quantizer) # pyre-fixme[6] self.assertEqual( id(m.activation_post_process_3), id(m.activation_post_process_2) ) diff --git a/exir/tracer.py b/exir/tracer.py index 1a8709a237..c4593cca8e 100644 --- a/exir/tracer.py +++ b/exir/tracer.py @@ -272,7 +272,7 @@ def __torch_function__( kwargs = {} if torch.is_inference_mode_enabled(): if func is torch.nn.functional.layer_norm: - args, kwargs = normalize_function(func, args, kwargs) + args, kwargs = normalize_function(func, args, kwargs) # pyre-fixme[23] input, normalized_shape = args normalized_shape = list(normalized_shape) return cls.__torch_dispatch__( @@ -470,13 +470,13 @@ def create_arg(self, a: Value) -> torch.fx.Node: # noqa: C901 self.submodules[a] = name_submodule return self.create_node("get_attr", self.submodules[a], (), {}) - return super().create_arg(a) + return super().create_arg(a) # pyre-fixme[7] @staticmethod def get() -> "DispatchTracer": return TRACER - def trace( + def trace( # pyre-fixme[14,15] self, root: Callable[..., Value], concrete_args: Tuple[Value, ...] = (), diff --git a/exir/verification/arg_validator.py b/exir/verification/arg_validator.py index 65ab146782..c087944b12 100644 --- a/exir/verification/arg_validator.py +++ b/exir/verification/arg_validator.py @@ -62,7 +62,7 @@ def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs): return kernel_arg - def call_function( # noqa: C901 + def call_function( # noqa: C901 # pyre-fixme[14] self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument] ) -> Any: """ @@ -73,7 +73,7 @@ def call_function( # noqa: C901 ): if isinstance(target, HigherOrderOperator): raise RunHigherOrderOperatorError("Can't run delegate") - return super().call_function(target, args, kwargs) + return super().call_function(target, args, kwargs) # pyre-fixme[6] # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist. 
tensor_arg_types: Dict[str, Optional[torch.dtype]] = {} @@ -126,4 +126,4 @@ def call_function( # noqa: C901 valid = target._schema.dtype_constraint.validate(tensor_arg_types) if not valid: self.violating_ops[target] = tensor_arg_types - return super().call_function(target, args, kwargs) + return super().call_function(target, args, kwargs) # pyre-fixme[6] diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 28afef20d0..eccb3317e7 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -161,6 +161,7 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + # pyre-fixme[8] self.pre_autograd_graph_module = capture_pre_autograd_graph( self.model, self.example_inputs, dynamic_shapes=dynamic_shape ) @@ -209,11 +210,12 @@ def export_to_edge(self) -> "LLMEdgeManager": # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): if self.pre_autograd_graph_module is None: + # pyre-fixme[8] self.pre_autograd_graph_module = capture_pre_autograd_graph( self.model, self.example_inputs, dynamic_shapes=dynamic_shape ) self.edge_manager = export_to_edge( - self.pre_autograd_graph_module, + self.pre_autograd_graph_module, # pyre-fixme[6] self.example_inputs, dynamic_shapes=dynamic_shape, edge_constant_methods=self.metadata, diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 501ef6fa6b..ab98f2543f 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -52,7 +52,7 @@ def get_mps_partitioner(use_kv_cache: bool = False): ) compile_specs = [CompileSpec("use_fp16", bytes([True]))] - return MPSPartitioner(compile_specs) + return MPSPartitioner(compile_specs) # pyre-fixme[16] def get_coreml_partitioner( @@ -92,14 +92,14 @@ def get_coreml_partitioner( # if use_kv_cache: # minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) - compile_specs = CoreMLBackend.generate_compile_specs( + compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=minimum_deployment_target, compute_precision=ct.precision(ct.precision.FLOAT16.value), # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU` compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()], - model_type=CoreMLBackend.MODEL_TYPE.MODEL, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16] ) - return CoreMLPartitioner( + return CoreMLPartitioner( # pyre-fixme[16] compile_specs=compile_specs, ) @@ -136,9 +136,10 @@ def get_qnn_partitioner( if pt2e_quantize is not None: use_fp16 = False - return QnnPartitioner( - generate_qnn_executorch_compiler_spec( - soc_model=QcomChipset.SM8650, # default to SM8650 + return QnnPartitioner( # pyre-fixme[16] + generate_qnn_executorch_compiler_spec( # pyre-fixme[16] + soc_model=QcomChipset.SM8650, # default to SM8650 # pyre-fixme[16] + # pyre-fixme[16] backend_options=generate_htp_compiler_spec(use_fp16=use_fp16), debug=False, saver=False, diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 8514e5d255..36d2f630b0 100644 --- 
a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -146,7 +146,7 @@ def get_qnn_quantizer( quantization_mode: Optional[str] = None, ): try: - from executorch.backends.qualcomm.quantizer.custom_annotation import ( + from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] custom_annotate_llama_matmul_16a8w, ) @@ -168,15 +168,15 @@ def get_qnn_quantizer( assert ( backend == "qnn" ), f"The quantization config is for backend {backend} instead of qnn." - qnn_quantizer = QnnQuantizer() + qnn_quantizer = QnnQuantizer() # pyre-fixme[16] qnn_quantizer.set_per_channel_conv_quant(enable=True) qnn_quantizer.set_per_channel_linear_quant(enable=True) # more custom quantization are supported including 16a4w etc. default to 8bit quantized custom_annotations = () if quant_config == "8a8w": - quant_dtype = QuantDtype.use_8a8w + quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16] elif quant_config == "16a16w": - quant_dtype = QuantDtype.use_16a16w + quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16] qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) qnn_quantizer.set_bit16_op_quant_config( # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
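# --- Illustrative sketch, not part of the diff ---------------------------
# The partitioner_lib.py and quantizer_lib.py hunks above wrap the Qualcomm
# backend helpers. A condensed sketch of what those helpers build, assuming
# the QNN backend is installed; the import paths are assumptions inferred
# from how these names are used in the hunks, not copied from this patch:
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
    QcomChipset,
)
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)

# 8-bit activations / 8-bit weights with per-channel conv and linear quant,
# as in get_qnn_quantizer(..., quant_config="8a8w").
qnn_quantizer = QnnQuantizer()
qnn_quantizer.set_per_channel_conv_quant(enable=True)
qnn_quantizer.set_per_channel_linear_quant(enable=True)
quant_dtype = QuantDtype.use_8a8w

# Partitioner targeting the default SM8650 SoC, as in get_qnn_partitioner();
# use_fp16 is False here because a quantization config is assumed to be set.
partitioner = QnnPartitioner(
    generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=generate_htp_compiler_spec(use_fp16=False),
        debug=False,
        saver=False,
    )
)
# --------------------------------------------------------------------------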