From 01d4038de50534adf40da38af66023e283971834 Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Tue, 6 Aug 2024 11:35:11 +0000
Subject: [PATCH] 2024-08-06 nightly release
 (de300e0ca12627f83ac31a4341fac7f01a55f077)

---
 .github/workflows/android-perf.yml | 68 ++-
 backends/cadence/aot/compiler.py | 65 ++-
 backends/cadence/aot/functions.yaml | 2 +-
 backends/cadence/aot/ops_registrations.py | 9 +-
 backends/cadence/aot/quantizer/fusion_pass.py | 22 +
 backends/cadence/aot/quantizer/patterns.py | 4 +-
 .../operators/quantized_relu_out.cpp | 36 +-
 backends/qualcomm/aot/ir/qcir_utils.cpp | 18 +-
 backends/qualcomm/aot/ir/qcir_utils.h | 4 +-
 .../qualcomm/runtime/QnnExecuTorchBackend.cpp | 4 +-
 backends/qualcomm/runtime/SharedBuffer.cpp | 7 +
 backends/qualcomm/scripts/build.sh | 24 +-
 backends/qualcomm/tests/test_qnn_delegate.py | 15 +-
 backends/qualcomm/tests/utils.py | 86 ++-
 .../vulkan/runtime/graph/ComputeGraph.cpp | 18 +-
 backends/vulkan/runtime/graph/ComputeGraph.h | 28 +-
 .../runtime/graph/ops/glsl/indexing_utils.h | 2 +-
 .../ops/glsl/int8_tensor_to_nchw_noint8.glsl | 54 ++
 .../ops/glsl/nchw_to_int8_tensor_noint8.glsl | 74 +++
 .../graph/ops/glsl/nchw_to_tensor.glsl | 2 +-
 .../runtime/graph/ops/glsl/q_8w_linear.glsl | 2 +-
 .../graph/ops/glsl/tensor_to_nchw.glsl | 2 +-
 .../vulkan/runtime/graph/ops/impl/Staging.cpp | 31 +-
 .../runtime/graph/ops/utils/StagingUtils.cpp | 18 +-
 .../runtime/graph/ops/utils/StagingUtils.h | 8 +-
 backends/vulkan/test/glsl/all_shaders.yaml | 17 +-
 .../vulkan/test/glsl/idx_fill_texture.glsl | 14 +-
 backends/vulkan/test/utils/test_utils.cpp | 24 +-
 backends/vulkan/test/utils/test_utils.h | 5 +
 .../vulkan/test/vulkan_compute_api_test.cpp | 94 ++-
 backends/xnnpack/CMakeLists.txt | 16 +-
 backends/xnnpack/operators/op_conv2d.py | 15 +-
 backends/xnnpack/partition/TARGETS | 2 +
 backends/xnnpack/partition/config/TARGETS | 20 +
 backends/xnnpack/partition/config/__init__.py | 72 +++
 .../xnnpack/partition/config/gemm_configs.py | 282 +++++++++
 .../partition/config/generic_node_configs.py | 263 +++++++++
 .../xnnpack/partition/config/node_configs.py | 91 +++
 .../partition/config/xnnpack_config.py | 203 +++++++
 .../xnnpack/partition/graphs/bilinear_2d.py | 15 +-
 .../xnnpack/partition/xnnpack_partitioner2.py | 87 +++
 backends/xnnpack/passes/__init__.py | 3 +
 .../xnnpack/passes/tag_implicit_q_dq_pass.py | 9 +-
 backends/xnnpack/runtime/XNNCompiler.cpp | 163 ++----
 .../xnnpack/serialization/runtime_schema.fbs | 354 ------------
 .../serialization/schema_version_history.txt | 0
 backends/xnnpack/serialization/targets.bzl | 4 +-
 backends/xnnpack/test/ops/abs.py | 4 +-
 backends/xnnpack/test/ops/add.py | 39 +-
 backends/xnnpack/test/ops/avgpool2d.py | 24 +-
 backends/xnnpack/test/ops/bilinear2d.py | 7 +-
 backends/xnnpack/test/ops/cat.py | 10 +-
 backends/xnnpack/test/ops/ceil.py | 4 +-
 backends/xnnpack/test/ops/clamp.py | 8 +-
 backends/xnnpack/test/ops/conv1d.py | 27 +-
 backends/xnnpack/test/ops/conv2d.py | 13 +-
 backends/xnnpack/test/ops/div.py | 8 +-
 backends/xnnpack/test/ops/elu.py | 20 +-
 backends/xnnpack/test/ops/hardtanh.py | 18 +-
 backends/xnnpack/test/ops/linear.py | 536 ++++++++++--------
 backends/xnnpack/test/ops/max_dim.py | 43 +-
 backends/xnnpack/test/ops/maximum.py | 8 +-
 backends/xnnpack/test/ops/maxpool2d.py | 33 +-
 backends/xnnpack/test/ops/multiply.py | 20 +-
 backends/xnnpack/test/ops/permute.py | 24 +-
 .../xnnpack/test/ops/quantize_per_tensor.py | 16 +-
 backends/xnnpack/test/ops/relu.py | 4 +-
 backends/xnnpack/test/ops/sigmoid.py | 4 +-
 backends/xnnpack/test/ops/softmax.py | 12 +-
 backends/xnnpack/test/tester/tester.py | 108 +++-
 backends/xnnpack/utils/TARGETS | 1 +
 backends/xnnpack/utils/configs.py | 8 +-
 backends/xnnpack/utils/quant_utils.py | 82 ++-
 backends/xnnpack/utils/utils.py | 16 +-
 examples/cadence/models/wav2vec2.py | 65 +++
 examples/models/llama2/runner/runner.cpp | 87 +--
 examples/models/llama2/runner/runner.h | 29 +-
 examples/models/llama2/runner/targets.bzl | 1 +
 examples/models/llama2/targets.bzl | 1 +
 examples/models/phi-3-mini/CMakeLists.txt | 45 +-
 examples/models/phi-3-mini/eager.py | 6 +-
 .../models/phi-3-mini/export_phi-3-mini.py | 35 +-
 examples/models/phi-3-mini/main.cpp | 92 +--
 examples/models/phi-3-mini/runner.cpp | 109 ++++
 examples/models/phi-3-mini/runner.h | 50 ++
 .../phi-3-mini/sentence_piece_tokenizer.h | 43 --
 examples/qualcomm/CMakeLists.txt | 3 -
 .../executor_runner/qnn_executor_runner.cpp | 5 +-
 .../qualcomm/llama2/qaihub_runner/runner.cpp | 13 +-
 examples/qualcomm/scripts/utils.py | 42 +-
 exir/_serialize/_program.py | 24 +-
 exir/backend/backend_api.py | 6 +-
 exir/backend/canonical_partitioners/TARGETS | 17 +
 .../config_partitioner.py | 204 +++++++
 exir/backend/test/TARGETS | 1 -
 .../demos/rpc/executor_backend_preprocess.py | 10 +-
 exir/emit/_emit_program.py | 11 +-
 exir/emit/_emitter.py | 124 ++--
 exir/emit/test/test_emit.py | 37 ++
 exir/memory_planning.py | 16 +-
 exir/passes/TARGETS | 11 +
 exir/passes/__init__.py | 2 +
 exir/passes/dim_order_ops_registry.py | 12 +
 exir/passes/memory_format_ops_pass.py | 60 +-
 exir/passes/remove_mixed_type_operators.py | 1 +
 exir/passes/weights_to_outputs_pass.py | 91 +++
 exir/program/TARGETS | 1 +
 exir/program/_program.py | 5 +
 exir/schema.py | 1 +
 exir/tensor.py | 5 -
 exir/tests/TARGETS | 11 +
 exir/tests/test_joint_graph.py | 91 +++
 exir/tests/test_passes.py | 100 +++-
 extension/android/jni/jni_layer_llama.cpp | 4 +-
 extension/llm/README.md | 47 ++
 extension/llm/runner/TARGETS | 8 +
 extension/llm/runner/stats.h | 123 ++++
 extension/llm/runner/targets.bzl | 10 +
 extension/llm/tokenizer/bpe_tokenizer.cpp | 2 +-
 extension/llm/tokenizer/bpe_tokenizer.h | 2 +
 extension/llm/tokenizer/tokenizer.h | 5 +-
 extension/training/optimizer/targets.bzl | 2 +-
 extension/training/optimizer/test/targets.bzl | 2 +-
 kernels/prim_ops/et_copy_index.cpp | 3 +-
 runtime/executor/program.h | 2 +-
 test/end2end/exported_module.py | 38 +-
 test/models/export_program.py | 25 +
 test/models/targets.bzl | 1 +
 128 files changed, 3617 insertions(+), 1577 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl
 create mode 100644 backends/xnnpack/partition/config/TARGETS
 create mode 100644 backends/xnnpack/partition/config/__init__.py
 create mode 100644 backends/xnnpack/partition/config/gemm_configs.py
 create mode 100644 backends/xnnpack/partition/config/generic_node_configs.py
 create mode 100644 backends/xnnpack/partition/config/node_configs.py
 create mode 100644 backends/xnnpack/partition/config/xnnpack_config.py
 create mode 100644 backends/xnnpack/partition/xnnpack_partitioner2.py
 delete mode 100644 backends/xnnpack/serialization/runtime_schema.fbs
 create mode 100644 backends/xnnpack/serialization/schema_version_history.txt
 create mode 100644 examples/cadence/models/wav2vec2.py
 create mode 100644 examples/models/phi-3-mini/runner.cpp
 create mode 100644 examples/models/phi-3-mini/runner.h
 delete mode 100644 examples/models/phi-3-mini/sentence_piece_tokenizer.h
 create mode 100644
exir/backend/canonical_partitioners/config_partitioner.py create mode 100644 exir/passes/weights_to_outputs_pass.py create mode 100644 exir/tests/test_joint_graph.py create mode 100644 extension/llm/README.md create mode 100644 extension/llm/runner/TARGETS create mode 100644 extension/llm/runner/stats.h create mode 100644 extension/llm/runner/targets.bzl diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index 78f41ada20..a8223eef2c 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -38,26 +38,52 @@ concurrency: permissions: read-all jobs: - set-models: + set-parameters: runs-on: linux.2xlarge outputs: - models: ${{ steps.set-models.outputs.models }} + models: ${{ steps.set-parameters.outputs.models }} + devices: ${{ steps.set-parameters.outputs.devices }} + delegates: ${{ steps.set-parameters.outputs.delegates }} steps: - - name: Set models - id: set-models + - name: Set parameters + id: set-parameters shell: bash run: | set -ex MODELS="${{ inputs.models }}" + DEVICES="${{ inputs.devices }}" + DELEGATES="${{ inputs.delegates }}" + + # Mapping devices to their corresponding device-pool-arn + declare -A DEVICE_POOL_ARNS + DEVICE_POOL_ARNS[samsung_galaxy_s2x]="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa" + + # Resolve device names with their corresponding ARNs + if [[ ! $(echo "$DEVICES" | jq empty 2>/dev/null) ]]; then + DEVICES=$(echo "$DEVICES" | jq -Rc 'split(",")') + fi + declare -a MAPPED_ARNS=() + for DEVICE in $(echo "$DEVICES" | jq -r '.[]'); do + if [[ -z "${DEVICE_POOL_ARNS[$DEVICE]}" ]]; then + echo "Error: No ARN found for device '$DEVICE'. Abort." >&2 + exit 1 + fi + MAPPED_ARNS+=("${DEVICE_POOL_ARNS[$DEVICE]}") + done + echo "models=$(echo $MODELS | jq -Rc 'split(",")')" >> $GITHUB_OUTPUT + MAPPED_ARNS_JSON=$(printf '%s\n' "${MAPPED_ARNS[@]}" | jq -R . | jq -s .) + echo "devices=$(echo "$MAPPED_ARNS_JSON" | jq -c .)" >> $GITHUB_OUTPUT + echo "delegates=$(echo $DELEGATES | jq -Rc 'split(",")')" >> $GITHUB_OUTPUT export-models: name: export-models uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - needs: set-models + needs: set-parameters strategy: matrix: - model: ${{ fromJson(needs.set-models.outputs.models) }} + model: ${{ fromJson(needs.set-parameters.outputs.models) }} + delegate: ${{ fromJson(needs.set-parameters.outputs.delegates) }} fail-fast: false with: runner: linux.2xlarge @@ -72,32 +98,33 @@ jobs: PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake" echo "Exporting model: ${{ matrix.model }}" - export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }} + export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }}_${{ matrix.delegate }} + # TODO(T197546696): Note that the following scripts/steps only work for llama. It's expected to fail for other models+delegates. # Install requirements for export_llama PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh # Test llama2 PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh "${{ matrix.model }}.pt" "cmake" "fp32" "xnnpack+custom+qe" "${ARTIFACTS_DIR_NAME}"\ - # Upload artifacts to S3. The artifacts are needed not only by the device farm but also TorchChat + # Upload models to S3. 
The artifacts are needed not only by the device farm but also TorchChat upload-models: needs: export-models runs-on: linux.2xlarge steps: - - name: Download the artifacts from GitHub + - name: Download the models from GitHub uses: actions/download-artifact@v3 with: # The name here needs to match the name of the upload-artifact parameter name: android-models path: ${{ runner.temp }}/artifacts/ - - name: Verify the artifacts + - name: Verify the models shell: bash working-directory: ${{ runner.temp }}/artifacts/ run: | ls -lah ./ - - name: Upload the artifacts to S3 + - name: Upload the models to S3 uses: seemethere/upload-artifact-s3@v5 with: s3-bucket: gha-artifacts @@ -110,7 +137,7 @@ jobs: build-llm-demo: name: build-llm-demo uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - needs: set-models + needs: set-parameters strategy: matrix: tokenizer: [bpe] @@ -139,20 +166,20 @@ jobs: needs: build-llm-demo runs-on: linux.2xlarge steps: - - name: Download the artifacts from GitHub + - name: Download the apps from GitHub uses: actions/download-artifact@v3 with: # The name here needs to match the name of the upload-artifact parameter name: android-apps path: ${{ runner.temp }}/artifacts/ - - name: Verify the artifacts + - name: Verify the apps shell: bash working-directory: ${{ runner.temp }}/artifacts/ run: | ls -lah ./ - - name: Upload the artifacts to S3 + - name: Upload the apps to S3 uses: seemethere/upload-artifact-s3@v5 with: s3-bucket: gha-artifacts @@ -169,20 +196,21 @@ jobs: contents: read uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main needs: - - set-models + - set-parameters - upload-models - upload-android-apps strategy: matrix: - model: ${{ fromJson(needs.set-models.outputs.models) }} + model: ${{ fromJson(needs.set-parameters.outputs.models) }} + delegate: ${{ fromJson(needs.set-parameters.outputs.delegates) }} + device: ${{ fromJson(needs.set-parameters.outputs.devices) }} with: device-type: android runner: linux.2xlarge test-infra-ref: '' # This is the ARN of ExecuTorch project on AWS project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 - # This is the custom Android device pool that only includes Samsung Galaxy S2x - device-pool-arn: arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa + device-pool-arn: ${{ matrix.device }} # Uploaded to S3 from the previous job, the name of the app comes from the project itself. # Unlike models there are limited numbers of build flavor for apps, and the model controls whether it should build with bpe/tiktoken tokenizer. # It's okay to build all possible apps with all possible flavors in job "build-llm-demo". 
However, in this job, once a model is given, there is only @@ -193,4 +221,4 @@ jobs: # The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77 # Uploaded to S3 from the previous job - extra-data: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/${{ matrix.model }}/model.zip + extra-data: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/${{ matrix.model }}_${{ matrix.delegate }}/model.zip diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 39511ae917..509e254b55 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -7,6 +7,7 @@ # pyre-strict import logging +from typing import Optional import torch @@ -36,16 +37,24 @@ from torch.export.exported_program import ExportedProgram -def quantize_pt2( +# Note: this is not meant as a primary API since it can create inconsistencies +# if the quantizer here is different from the quantizer used to convert. It is +# however useful for unit tests to separate the converted model from the fused +# model, to be able to get reference numerics. +# If this does not apply, please use quantize_and_fuse_pt2 instead. +def convert_pt2( model: torch.nn.Module, inputs: tuple[object, ...], + quantizer: CadenceQuantizer, ) -> torch.fx.GraphModule: """ - Instantiate the CadenceQuantizer (PTQ), prepare, convert and fuse the model. - Returns a GraphModule with the quantized model. + Prepare and convert a model using the given quantizer. + The quantizer must be supplied and be the same as the one used to + fuse the model later, if applicable. If you do not expect that behavior, + please use quantize_and_fuse_pt2 instead, which will instantiate a + default quantizer for you if needed. + Returns a GraphModule with the converted model. """ - # Quantizer - quantizer = CadenceQuantizer() # Export with dynamo model_exp = capture_pre_autograd_graph(model, inputs) @@ -62,12 +71,54 @@ def quantize_pt2( # Convert converted_model = convert_pt2e(prepared_model) + return converted_model + + +# Note: this is not meant as a primary API since it can create inconsistencies +# if the quantizer here is different from the quantizer used to convert. It is +# however useful for unit tests to separate the converted model from the fused +# model, to be able to get reference numerics. +# If this does not apply, please use quantize_and_fuse_pt2 instead. +def fuse_pt2( + converted_graph_module: torch.fx.GraphModule, + quantizer: CadenceQuantizer, +) -> torch.fx.GraphModule: + """ + Fuse a converted graph module using the given quantizer. + The quantizer must be the same as the one used to convert the model. + If you do not expect that behavior, please use quantize_and_fuse_pt2 instead, + which will instantiate a default quantizer for you if needed. + Returns a GraphModule with the fused model. + """ # Get patterns and apply fusion of dq -> op -> q to qop # pyre-ignore[16]: no attribute patterns = [q.pattern for q in quantizer.quantizers] - QuantFusion(patterns)(converted_model) + QuantFusion(patterns)(converted_graph_module) - return converted_model + return converted_graph_module + + +# Note: this is the one-liner API to quantize and fuse a model. 
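The one-liner mentioned in the note above is the quantize_pt2 function defined immediately below (the docstrings mention quantize_and_fuse_pt2, while the one-liner this hunk adds is named quantize_pt2). As a hedged illustration only, here is a minimal usage sketch of the new convert/fuse split and of the one-liner; the toy module, example inputs, and module import paths are assumptions inferred from the file locations in this diff, not taken from the patch itself:

import torch
from executorch.backends.cadence.aot.compiler import convert_pt2, fuse_pt2, quantize_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
example_inputs = (torch.randn(1, 16),)

# Two-step path: useful in unit tests to inspect reference numerics of the
# converted-but-not-yet-fused graph. The same quantizer instance must be used
# for both steps, as the docstrings above require.
quantizer = CadenceQuantizer()
converted_gm = convert_pt2(model, example_inputs, quantizer)
fused_gm = fuse_pt2(converted_gm, quantizer)

# One-liner path: instantiates a default CadenceQuantizer when none is passed.
quantized_gm = quantize_pt2(model, example_inputs)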
+def quantize_pt2( + model: torch.nn.Module, + inputs: tuple[object, ...], + quantizer: Optional[CadenceQuantizer] = None, +) -> torch.fx.GraphModule: + """ + Prepare, convert and fuse the model using the given quantizer. + Returns a GraphModule with the quantized model. + """ + # Quantizer + if not quantizer: + quantizer = CadenceQuantizer() + + # Get converted graph module + converted_gm = convert_pt2(model, inputs, quantizer) + + # Get fused model + fused_gm = fuse_pt2(converted_gm, quantizer) + + return fused_gm # Export the model and lower it to an ExportedProgram (in aten IR) diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index dbfe1e3639..b31bb20549 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -145,7 +145,7 @@ - arg_meta: null kernel_name: impl::reference::quantized_linear_out -- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null kernel_name: impl::reference::quantized_relu_out diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index adcf086873..a4d856ebed 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -43,9 +43,11 @@ "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) -lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)") lib.define( - "quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)" + "quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)" +) +lib.define( + "quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) 
out) -> Tensor (a!)" ) lib.define( @@ -168,6 +170,9 @@ def quantized_layer_norm_meta( def quantized_relu_meta( X: torch.Tensor, X_zero_point: torch.Tensor, + out_zero_point: int, + out_multiplier: torch.Tensor, + out_shift: torch.Tensor, ): return X.new_empty(X.size(), dtype=torch.uint8) diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 4c43172a92..7c05e9b867 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -287,7 +287,15 @@ def get_args_and_kwargs_relu( graph_module: GraphModule, inputs_inputs: List[fx.Node], dequants_inputs: List[fx.Node], + quant_node: fx.Node, ) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]: + input_scale = dequants_inputs[0].args[1] + # pyre-fixme[58]: Unsupported operand types + requantize_scale = input_scale / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) + # Make the args and kwargs for the replacement op args = tuple(inputs_inputs) @@ -296,9 +304,22 @@ def get_args_and_kwargs_relu( ([1], dequants_inputs[0].args[2]), {"dtype": torch.int32}, ) + out_multiplier_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_multiplier[0].item()), + {"dtype": torch.int32}, + ) + out_shift_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_shift[0].item()), + {"dtype": torch.int32}, + ) kwargs = { "X_zero_point": X_zero_point, + "out_zero_point": quant_node.args[2], + "out_multiplier": out_multiplier_, + "out_shift": out_shift_, } return args, kwargs @@ -420,6 +441,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 graph_module, inputs_inputs, dequants_inputs, + quant_node, ) fused = graph_module.graph.call_function( pattern.replacement_op(), diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 7043bae571..c5eb3b964d 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -303,9 +303,7 @@ def get_anchors( inputs=[(relu_node, 0)], weights=[], biases=[], - output=[ - (relu_node, SharedQuantizationSpec((relu_node.args[0], relu_node))) - ], + output=[(relu_node,)], ) def replacement_op(self) -> OpOverload: diff --git a/backends/cadence/reference/operators/quantized_relu_out.cpp b/backends/cadence/reference/operators/quantized_relu_out.cpp index 54f6b723c6..bcfd28b5bc 100644 --- a/backends/cadence/reference/operators/quantized_relu_out.cpp +++ b/backends/cadence/reference/operators/quantized_relu_out.cpp @@ -16,19 +16,30 @@ namespace native { using Tensor = exec_aten::Tensor; using RuntimeContext = torch::executor::RuntimeContext; -// Note: this kernel assumes that the input and output share quantization -// parameters. If that is not the case, it will produce incorrect results. 
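The reference kernel that follows replaces the old assumption noted in the removed comment (input and output sharing quantization parameters) with an explicit requantization step: the fusion pass derives requantize_scale = input_scale / output_scale and encodes it as an int32 out_multiplier plus a power-of-two out_shift, which the kernel turns back into a float scale. Below is a host-side sketch of that encoding for illustration only; encode_scale and decode_scale are hypothetical helpers and do not reproduce quantize_tensor_multiplier's exact rounding or the kernel's sign convention.

import math

def encode_scale(requantize_scale: float) -> tuple[int, int]:
    # Decompose the scale as mantissa * 2**exponent with mantissa in [0.5, 1),
    # then store the mantissa as a Q31 fixed-point integer.
    mantissa, exponent = math.frexp(requantize_scale)
    out_multiplier = round(mantissa * (1 << 31))
    out_shift = exponent
    return out_multiplier, out_shift

def decode_scale(out_multiplier: int, out_shift: int) -> float:
    # Mirrors the reconstruction in the kernel below (ignoring the sign it
    # applies to the stored multiplier): multiplier / 2**31 * 2**shift.
    return out_multiplier / (1 << 31) * (2.0 ** out_shift)

# Example: input scale 0.05, output scale 0.02 -> requantize scale 2.5.
m, s = encode_scale(0.05 / 0.02)
print(m, s, decode_scale(m, s))  # 1342177280 2 2.5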
template void quantized_relu_( const Tensor& input, const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, Tensor& output) { T q_zero_point = in_zero_point.const_data_ptr()[0]; const T* __restrict__ in = input.const_data_ptr(); T* __restrict__ out = output.mutable_data_ptr(); + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + for (size_t i = 0, e = input.numel(); i < e; ++i) { - out[i] = in[i] > q_zero_point ? in[i] : q_zero_point; + const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0; + out[i] = kernels::quantize(temp, out_scale, out_zero_point); } } @@ -36,11 +47,26 @@ void quantized_relu_out( RuntimeContext& ctx, const Tensor& input, const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, Tensor& output) { if (input.scalar_type() == exec_aten::ScalarType::Byte) { - quantized_relu_(input, in_zero_point, output); + quantized_relu_( + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); } else if (input.scalar_type() == exec_aten::ScalarType::Char) { - quantized_relu_(input, in_zero_point, output); + quantized_relu_( + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); } else { ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); } diff --git a/backends/qualcomm/aot/ir/qcir_utils.cpp b/backends/qualcomm/aot/ir/qcir_utils.cpp index e025b8667a..75446bb733 100755 --- a/backends/qualcomm/aot/ir/qcir_utils.cpp +++ b/backends/qualcomm/aot/ir/qcir_utils.cpp @@ -100,7 +100,7 @@ Qnn_DataType_t ToDataType(qcir::DataType type) { } flatbuffers::Offset ToQuantizeParam( - const Qnn_QuantizeParams_t& param, + const Qnn_Tensor_t& tensor, flatbuffers::FlatBufferBuilder* builder) { static const std::unordered_map def_map{ {QNN_DEFINITION_IMPL_GENERATED, qcir::QuantizeDef::IMPL_GENERATED}, @@ -124,6 +124,7 @@ flatbuffers::Offset ToQuantizeParam( int32_t axis = 0; uint32_t bitwidth = 0; + auto param = QNN_VER_PTR(tensor)->quantizeParams; auto quant_type = type_map.at(param.quantizationEncoding); std::vector data; std::vector scales; @@ -160,7 +161,9 @@ flatbuffers::Offset ToQuantizeParam( } } break; default: - QNN_EXECUTORCH_LOG_ERROR("QNN_QUANTIZATION_ENCODING_UNDEFINED detected"); + QNN_EXECUTORCH_LOG_WARN( + "QNN_QUANTIZATION_ENCODING_UNDEFINED detected: %s", + QNN_VER_PTR(tensor)->name); break; } return CreateQuantizeParamDirect( @@ -174,7 +177,7 @@ flatbuffers::Offset ToQuantizeParam( &data); } -Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) { +Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) { static const std::unordered_map def_map{ {qcir::QuantizeDef::IMPL_GENERATED, QNN_DEFINITION_IMPL_GENERATED}, {qcir::QuantizeDef::DEFINED, QNN_DEFINITION_DEFINED}, @@ -196,6 +199,7 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) { }; Qnn_QuantizeParams_t p = QNN_QUANTIZE_PARAMS_INIT; + auto param = tensor->qparam(); p.encodingDefinition = def_map.at(param->def()); p.quantizationEncoding = type_map.at(param->type()); switch (p.quantizationEncoding) { @@ -225,7 +229,9 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) { const_cast(param->offsets()->data()); } 
break; default: - QNN_EXECUTORCH_LOG_ERROR("qcir::QuantizeType::UNDEFINED detected"); + QNN_EXECUTORCH_LOG_WARN( + "qcir::QuantizeType::UNDEFINED detected: %s", + tensor->name()->c_str()); break; } return p; @@ -248,7 +254,7 @@ flatbuffers::Offset ToTensor( &shape, ToTensorType(QNN_VER_PTR(tensor)->type), ToDataType(QNN_VER_PTR(tensor)->dataType), - ToQuantizeParam(QNN_VER_PTR(tensor)->quantizeParams, builder), + ToQuantizeParam(tensor, builder), &buffer); } @@ -261,7 +267,7 @@ Qnn_Tensor_t ToTensor(const tensor_type& tensor) { QNN_VER_PTR(t)->name = tensor->name()->c_str(); QNN_VER_PTR(t)->type = ToTensorType(tensor->type()); QNN_VER_PTR(t)->dataType = ToDataType(tensor->dtype()); - QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor->qparam()); + QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor); QNN_VER_PTR(t)->rank = tensor->shape()->size(); QNN_VER_PTR(t)->dimensions = const_cast(tensor->shape()->data()); QNN_VER_PTR(t)->clientBuf.dataSize = tensor->data()->size(); diff --git a/backends/qualcomm/aot/ir/qcir_utils.h b/backends/qualcomm/aot/ir/qcir_utils.h index 30a5481f9f..890dfa33ca 100755 --- a/backends/qualcomm/aot/ir/qcir_utils.h +++ b/backends/qualcomm/aot/ir/qcir_utils.h @@ -26,9 +26,9 @@ qcir::DataType ToDataType(Qnn_DataType_t type); Qnn_DataType_t ToDataType(qcir::DataType type); flatbuffers::Offset ToQuantizeParam( - const Qnn_QuantizeParams_t& param, + const Qnn_Tensor_t& tensor, flatbuffers::FlatBufferBuilder* builder); -Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& type); +Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor); flatbuffers::Offset ToTensor( const Qnn_Tensor_t& tensor, diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index f08f688cf9..36512c4ff2 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -23,12 +23,12 @@ Result QnnExecuTorchBackend::init( ArrayRef compile_specs) const { // covert SizedBuffer to qnn ExecuTorch option QnnExecuTorchContextBinary qnn_context_blob; - const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options; + const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options = nullptr; qnn_context_blob.buffer = const_cast(processed->data()); qnn_context_blob.nbytes = processed->size(); - // covert CompileSpec to qnn ExecuTorch option + // convert CompileSpec to qnn ExecuTorch option for (auto& compile_spec : compile_specs) { if (std::strcmp(compile_spec.key, QNN_COMPILE_SPEC) == 0) qnn_executorch_options = diff --git a/backends/qualcomm/runtime/SharedBuffer.cpp b/backends/qualcomm/runtime/SharedBuffer.cpp index 430c8f757a..3fa62d09cd 100644 --- a/backends/qualcomm/runtime/SharedBuffer.cpp +++ b/backends/qualcomm/runtime/SharedBuffer.cpp @@ -87,7 +87,12 @@ SharedBuffer& SharedBuffer::GetSharedBufferManager() { std::lock_guard lk(init_mutex_); static SharedBuffer shared_buffer_manager; if (!shared_buffer_manager.GetInitialize()) { +#if defined(__aarch64__) Error status = shared_buffer_manager.Load(); +#else + // For x86_64 platform + Error status = Error::Ok; +#endif if (status == Error::Ok) { shared_buffer_manager.SetInitialize(true); } @@ -96,9 +101,11 @@ SharedBuffer& SharedBuffer::GetSharedBufferManager() { } SharedBuffer::~SharedBuffer() { +#if defined(__aarch64__) if (initialize_) { SharedBuffer::GetSharedBufferManager().UnLoad(); } +#endif }; void* SharedBuffer::AllocMem(size_t bytes, size_t alignment) { diff --git a/backends/qualcomm/scripts/build.sh 
b/backends/qualcomm/scripts/build.sh index 3712a83fde..d6b1da62fc 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -107,19 +107,33 @@ if [ "$BUILD_X86_64" = true ]; then rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT fi cd $BUILD_ROOT + # TODO: Use CMAKE_BUILD_TYPE=RelWithDebInfo, and handle flatcc issues cmake \ - -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \ -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \ -DEXECUTORCH_BUILD_QNN=ON \ + -DEXECUTORCH_BUILD_SDK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ - -DBUCK2=$BUCK2 \ -S $PRJ_ROOT \ -B $BUILD_ROOT \ - cmake \ - --build $BUILD_ROOT \ - -t "PyQnnManagerAdaptor" "PyQnnWrapperAdaptor" -j16 + cmake --build $BUILD_ROOT -j16 --target install rm -f $PRJ_ROOT/backends/qualcomm/python/* cp -fv $BUILD_ROOT/backends/qualcomm/Py* "$PRJ_ROOT/backends/qualcomm/python" + + EXAMPLE_ROOT=examples/qualcomm + CMAKE_PREFIX_PATH="${BUILD_ROOT}/lib/cmake/ExecuTorch;${BUILD_ROOT}/third-party/gflags;" + + cmake $PRJ_ROOT/$EXAMPLE_ROOT \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$EXAMPLE_ROOT + + cmake --build $EXAMPLE_ROOT -j16 fi diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 80fc71ef7c..c1c070ca3c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -147,6 +147,7 @@ def test_qnn_backend_element_wise_ceil(self): def test_qnn_backend_element_wise_div(self): eps = 1e-03 + torch.manual_seed(8) test_comb = [ { QCOM_MODULE: [Div()], # noqa: F405 @@ -721,6 +722,7 @@ def test_qnn_backend_element_wise_ceil(self): def test_qnn_backend_element_wise_div(self): eps = 1e-03 + torch.manual_seed(8) test_comb = [ { QCOM_MODULE: [Div()], # noqa: F405 @@ -1323,7 +1325,6 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) - @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -1338,7 +1339,7 @@ def test_qnn_backend_profile_op(self): module, sample_input, expected_partitions=1, - expected_profile_events=25, + expected_profile_events=24, ) def test_qnn_backend_shared_buffer(self): @@ -1488,7 +1489,6 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) - @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=False) @@ -1504,7 +1504,7 @@ def test_qnn_backend_profile_op(self): module, sample_input, expected_partitions=1, - expected_profile_events=26, + expected_profile_events=25, ) def test_qnn_backend_shared_buffer(self): @@ -2288,6 +2288,12 @@ def setup_environment(): help="Path to open source software model repository", type=str, ) + parser.add_argument( + "-x", + "--enable_x86_64", + help="Enable unittest to be executed on x86_64 platform", + action="store_true", + ) args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -2304,6 +2310,7 @@ def setup_environment(): TestQNN.error_only = args.error_only TestQNN.oss_repo = 
args.oss_repo TestQNN.shared_buffer = args.shared_buffer + TestQNN.enable_x86_64 = args.enable_x86_64 return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index f31f07562b..ef0ac0f202 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -27,7 +27,11 @@ QcomChipset, ) from executorch.backends.qualcomm.utils.utils import capture_program -from executorch.examples.qualcomm.scripts.utils import SimpleADB +from executorch.examples.qualcomm.scripts.utils import ( + generate_inputs, + make_output_dir, + SimpleADB, +) from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -133,6 +137,7 @@ class TestQNN(unittest.TestCase): use_16a16w: str = "16a16w" use_16a4w: str = "16a4w" shared_buffer: bool = False + enable_x86_64: bool = False def _assert_outputs_equal(self, model_output, ref_output): self.assertTrue(len(ref_output) == len(model_output)) @@ -201,16 +206,16 @@ def verify_output( tmp_dir, ) - device_output_dir = f"{tmp_dir}/outputs" - device_outputs = [] + output_dir = f"{tmp_dir}/outputs" + outputs = [] etdump_path = f"{tmp_dir}/etdump.etdp" def post_process(): - for i, f in enumerate(sorted(os.listdir(device_output_dir))): - filename = os.path.join(device_output_dir, f) + for i, f in enumerate(sorted(os.listdir(output_dir))): + filename = os.path.join(output_dir, f) output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype) output = torch.from_numpy(output).reshape(ref_outputs[i].shape) - device_outputs.append(output) + outputs.append(output) def validate_profile(): inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path) @@ -218,23 +223,58 @@ def validate_profile(): len(inspector.to_dataframe().index) == expected_profile_events ) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=self.build_folder, - pte_path=pte_fname, - workspace="/data/local/tmp/qnn_executorch_test", - device_id=self.device, - host_id=self.host, - soc_model=self.model, - error_only=self.error_only, - ) - adb.push(inputs=[sample_inputs], input_list=input_list) - adb.execute() - adb.pull(output_path=tmp_dir, callback=post_process) - self._assert_outputs_equal(device_outputs, ref_outputs) + if self.enable_x86_64: + generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list) + make_output_dir(output_dir) + + target = "x86_64-linux-clang" + qnn_sdk = os.environ.get("QNN_SDK_ROOT", None) + assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable" + + build_path = "build_x86_64" + cmds = [ + # export LD_LIBRARY_PATH to QNN_SDK_ROOT + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{self.executorch_root}/{build_path}/lib && " + # qnn_executor_runner + f"{self.executorch_root}/{build_path}/examples/qualcomm/qnn_executor_runner", + f"--model_path {pte_fname}", + f"--input_list_path {tmp_dir}/input_list.txt", + f"--output_folder_path {output_dir}", + ] + + subprocess.run( + " ".join(cmds), + shell=True, + executable="/bin/bash", + capture_output=True, + cwd=tmp_dir, + ) + + # Verify the outputs + post_process() + self._assert_outputs_equal(outputs, ref_outputs) + + # Verify the etdump + if expected_profile_events != -1: + validate_profile() + else: + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=self.build_folder, + pte_path=pte_fname, + workspace="/data/local/tmp/qnn_executorch_test", + device_id=self.device, + host_id=self.host, + soc_model=self.model, + error_only=self.error_only, + ) 
+ adb.push(inputs=[sample_inputs], input_list=input_list) + adb.execute() + adb.pull(output_path=tmp_dir, callback=post_process) + self._assert_outputs_equal(outputs, ref_outputs) - if expected_profile_events != -1: - adb.pull_etdump(etdump_path, callback=validate_profile) + if expected_profile_events != -1: + adb.pull_etdump(etdump_path, callback=validate_profile) def lower_module_and_test_output( self, @@ -362,6 +402,8 @@ def _insert_clone( (node,), ) inserted_node.meta["val"] = node.meta["val"] + if "quant_attrs" in node.meta: + inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"] for user in users: user.replace_input_with(node, inserted_node) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 2046e78e88..fb2c379c1b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -319,24 +319,20 @@ utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) { return image_extents_of(idx); } -utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { +utils::uvec3 ComputeGraph::create_local_wg_size( + const utils::uvec3 global_wg_size) { if (config_.enable_local_wg_size_override) { return config_.local_wg_size_override; } - if (is_buffer_storage(idx)) { - return {64u, 1u, 1u}; - } - - const utils::uvec3 image_extents = image_extents_of(idx); utils::uvec3 local_group_size = {4, 4, 4}; - if (image_extents.data[2u] == 1) { - if (image_extents.data[1u] == 1) { + if (global_wg_size.data[2u] == 1) { + if (global_wg_size.data[1u] == 1) { local_group_size.data[0u] = 64; local_group_size.data[1u] = 1; local_group_size.data[2u] = 1; - } else if (image_extents.data[1u] < 8) { + } else if (global_wg_size.data[1u] < 8) { local_group_size.data[0u] = 16; local_group_size.data[1u] = 4; local_group_size.data[2u] = 1; @@ -349,6 +345,10 @@ utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { return local_group_size; } +utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { + return create_local_wg_size(image_extents_of(idx)); +} + void ComputeGraph::copy_into_staging( const ValueRef idx, const void* data, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 5237a7746d..898a856291 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -180,7 +180,9 @@ class ComputeGraph final { return values_.at(idx).type(); } - // Get Tensor Property + // + // Tensor Properties Accessors + // std::vector sizes_of(const ValueRef idx) const; @@ -226,7 +228,9 @@ class ComputeGraph final { return values_.at(idx).toTensor().ntexels_ubo(); } + // // Scalar Value Extraction + // template T extract_scalar(const ValueRef idx) { @@ -459,9 +463,7 @@ class ComputeGraph final { utils::uvec3 create_global_wg_size(const ValueRef idx); /* - * Suggest a local workgroup size for a given `api::vTensor` value, assuming - * that every shader invocation calculates one texel element of the output - * tensor. + * Suggest a local workgroup size for a given global workgroup size. * * The local workgroup size will be formed to try and minimize the number of * inactive invocations. @@ -469,6 +471,13 @@ class ComputeGraph final { * Currently, the local workgroup size is hard-coded to contain a total of 64 * shader invocations. In the future, this value can be configured. 
*/ + utils::uvec3 create_local_wg_size(const utils::uvec3 global_wg_size); + + /* + * Convenience function to suggest a local workgroup size for a given + * `api::vTensor` value, assuming that every shader invocation calculates one + * texel element of the output tensor. + */ utils::uvec3 create_local_wg_size(const ValueRef idx); // @@ -500,6 +509,17 @@ class ComputeGraph final { void resize_input(const int64_t idx, const std::vector& new_sizes); void propagate_resize(); + // + // Miscellaneous Utilities + // + + /* + * Check whether the GPU supports 8 bit buffers. + */ + inline bool int8_buffers_enabled() const { + return context_->adapter_ptr()->has_full_int8_buffers_support(); + } + // // Debug support (implemented in Logging.cpp) // diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 0ecfb83eac..d3264e43a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -80,7 +80,7 @@ ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) { * Returns: The (x, y, z, n) texel position corresponding to the first element * of the texel at the specified buffer index */ -ivec4 to_texel_pos(int buf_i, ivec4 strides, int packed_dim) { +ivec4 to_tensor_idx(int buf_i, ivec4 strides, int packed_dim) { ivec4 idx; for (int i = 3; i >= 0; i--) { if (i != packed_dim) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl new file mode 100644 index 0000000000..21290d0ce8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +#extension GL_EXT_control_flow_attributes : require + +${layout_declare_tensor(0, "r", "t_in", "int8", "texture3d")} +${layout_declare_buffer(1, "w", "nchw_out", "int")} +${layout_declare_ubo(2, "ivec4", "tensor_sizes")} +${layout_declare_ubo(3, "int", "out_ntexels")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + if (out_buf_idx >= out_ntexels) { + return; + } + + ivec4 values; + int in_buf_idx = 4 * out_buf_idx; + + [[unroll]] for (int i = 0; i < 4; ++i) { + const ivec4 tensor_idx = from_nchw_buffer_i(in_buf_idx, tensor_sizes); + const ivec4 texture_pos = to_texture_elem_pos( + tensor_idx, tensor_sizes, packed_dim); + values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w]; + in_buf_idx++; + } + + // Manually pack 4x 8-bit integers into a 32 bit integer. Note that little + // endian is assumed, since most processors use little endian. Thus the + // "later" values are placed in most significant bytes. 
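As an illustrative aside (not part of the shader), the byte layout described in the comment above can be checked on the host: placing the "later" values in the more significant bytes of the packed word is exactly a little-endian reinterpretation of four signed bytes. The GLSL expression that performs the packing follows.

import struct

def pack4(values):
    # Same arithmetic as the GLSL expression below: values[3] lands in the
    # most significant byte of the 32-bit word.
    return ((values[3] & 0xFF) << 24) | ((values[2] & 0xFF) << 16) \
        | ((values[1] & 0xFF) << 8) | (values[0] & 0xFF)

vals = [-1, 2, -3, 4]
packed = pack4(vals)
raw = struct.pack("<I", packed & 0xFFFFFFFF)   # little-endian byte order
assert list(struct.unpack("4b", raw)) == vals  # original int8 values recovered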
+ int packed = ((values[3] & 0xFF) << 24) + | ((values[2] & 0xFF) << 16) + | ((values[1] & 0xFF) << 8) + | ((values[0] & 0xFF)); + + nchw_out[out_buf_idx] = packed; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl new file mode 100644 index 0000000000..378cf09d12 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +#extension GL_EXT_control_flow_attributes : require + +${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")} +${layout_declare_buffer(1, "r", "nchw_in", "int")} +${layout_declare_ubo(2, "ivec4", "tensor_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +/* + * Extends sign of int8 + */ +int extend_sign(int x) { + if (x >> 7 == 1) { + return x | 0xFFFFFF00; + } + return x; +} + +ivec4 read_texel(ivec4 tensor_idx) { + const ivec4 buf_indices = get_texel_nchw_buffer_ixs( + tensor_idx, tensor_sizes, packed_dim); + + int shift = (1 << 8) - 1; + ivec4 masks; + // Masks used to unpack 4x 8-bit values from a 32 bit integer. Note that + // little endian is assumed, as most processors use little endian. Thus the + // most significant bytes correspond to the "latter" packed values. + masks.x = shift << (8 * (buf_indices.x % 4)); + masks.y = shift << (8 * (buf_indices.y % 4)); + masks.z = shift << (8 * (buf_indices.z % 4)); + masks.w = shift << (8 * (buf_indices.w % 4)); + + ivec4 out_tex = ivec4(0); + + [[unroll]] for (int i = 0; i < 4; ++i) { + if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) { + int in_texel = nchw_in[buf_indices[i] / 4]; + int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4)); + extracted_val = extend_sign(extracted_val); + out_tex[i] = extracted_val; + } + } + + return out_tex; +} + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim); + + if (any(greaterThanEqual(tensor_idx, tensor_sizes))) { + return; + } + + write_texel(t_out, pos, read_texel(tensor_idx)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl index c0bbc5183a..c218482b09 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl @@ -62,7 +62,7 @@ void main() { return; } - ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim); + ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim); tensor_idx[packed_dim] *= 4; t_out[t_id] = read_texel(tensor_idx); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 139c82866f..37988f21ec 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -53,7 +53,7 @@ void main() { return; } - const ivec4 out_pos = to_texel_pos(t_id, out_strides, 0); + const ivec4 out_pos = to_tensor_idx(t_id, out_strides, 0); VEC4_T outtex = 
q_8w_linear(out_pos, mat1_sizes.x); write_texel(t_out, t_id, outtex); diff --git a/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl index 78d8346428..d545e5d86e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl @@ -61,7 +61,7 @@ void main() { } const VEC4_T intex = t_in[t_id]; - ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim); + ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim); tensor_idx[packed_dim] *= 4; write_out_texel(intex, tensor_idx); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 2e5e9addfb..79b463d7ef 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -21,8 +21,8 @@ void add_staging_to_tensor_node( const ValueRef out_tensor) { VK_CHECK_COND(graph.val_is_staging(in_staging)); - vkapi::ShaderInfo shader = - get_nchw_to_tensor_shader(*graph.get_tensor(out_tensor)); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + *graph.get_tensor(out_tensor), graph.int8_buffers_enabled()); vkapi::ParamsBindList ubos({graph.sizes_ubo(out_tensor)}); if (graph.is_buffer_storage(out_tensor)) { @@ -55,10 +55,26 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); - vkapi::ShaderInfo shader = - get_tensor_to_nchw_shader(*graph.get_tensor(in_tensor)); + vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( + *graph.get_tensor(in_tensor), graph.int8_buffers_enabled()); + utils::uvec3 global_wg_size = graph.create_global_wg_size(in_tensor); vkapi::ParamsBindList ubos({graph.sizes_ubo(in_tensor)}); + + // Normally, the tensor_to_nchw shader is structured so that each thread reads + // one texel from the input texture and writes each component of the texel + // into the corresponding location in the output buffer. However, this shader + // is structured slightly differently in that each thread writes out a + // complete 32 bit integer (containing 4 packed 8-bit integers) into the + // output buffer. Therefore, the global work group size for this shader will + // be the number of elements in the output buffer divided by 4, as opposed to + // the extents of the input texture. 
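To make the dispatch-size comment above concrete, here is a small worked example; the numbers are illustrative, and both the channels-packed texture-extent formula and the assumption that the staging buffer is padded to a multiple of 4 elements are inferences, not values taken from the runtime. The C++ that applies this sizing follows.

import math

sizes = [1, 3, 5, 7]                 # N, C, H, W of a hypothetical int8 tensor
numel = math.prod(sizes)             # 105 elements in the NCHW staging buffer

# Ordinary tensor_to_nchw dispatch: one invocation per texel of the input
# texture (assumed channels-packed extents of (W, H, ceil(C / 4) * N)).
texture_extents = (sizes[3], sizes[2], math.ceil(sizes[1] / 4) * sizes[0])  # (7, 5, 1) -> 35 invocations

# int8_tensor_to_nchw_noint8 dispatch: one invocation per packed 32-bit word
# of the staging buffer, i.e. its element count divided by 4, as a 1-D dispatch.
global_wg_size = (math.ceil(numel / 4), 1, 1)  # (27, 1, 1)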
+ if (shader.kernel_name == "int8_tensor_to_nchw_noint8") { + uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4; + global_wg_size = {buffer_len, 1, 1}; + ubos.append({graph.ntexels_ubo(in_tensor)}); + } + if (graph.is_buffer_storage(in_tensor)) { ubos.append({ graph.texel_strides_ubo(in_tensor), @@ -69,8 +85,8 @@ void add_tensor_to_staging_node( graph.execute_nodes().emplace_back(new ExecuteNode( graph, shader, - graph.create_global_wg_size(in_tensor), - graph.create_local_wg_size(in_tensor), + global_wg_size, + graph.create_local_wg_size(global_wg_size), // Input and Outputs {{in_tensor, vkapi::MemoryAccessType::READ}, {out_staging, vkapi::MemoryAccessType::WRITE}}, @@ -86,7 +102,8 @@ ValueRef prepack( const utils::GPUMemoryLayout layout) { ValueRef v = graph.add_tensor_like(vref, layout); - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(*graph.get_tensor(v)); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + *graph.get_tensor(v), graph.int8_buffers_enabled()); vkapi::ParamsBindList ubos({graph.sizes_ubo(v)}); if (graph.is_buffer_storage(v)) { diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index d681618d9d..2ade34e425 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -95,10 +95,17 @@ void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes) { memset(data_ptr, 0, staging.nbytes()); } -vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst) { +vkapi::ShaderInfo get_nchw_to_tensor_shader( + const api::vTensor& v_dst, + const bool int8_buffer_enabled) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); + if (v_dst.dtype() == vkapi::kChar && + v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) { + return VK_KERNEL(nchw_to_int8_tensor_noint8); + } + kernel_name = "nchw_to_tensor"; add_dtype_suffix(kernel_name, v_dst); add_storage_type_suffix(kernel_name, v_dst); @@ -106,10 +113,17 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst) { return VK_KERNEL_FROM_STR(kernel_name); } -vkapi::ShaderInfo get_tensor_to_nchw_shader(const api::vTensor& v_src) { +vkapi::ShaderInfo get_tensor_to_nchw_shader( + const api::vTensor& v_src, + bool int8_buffer_enabled) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); + if (v_src.dtype() == vkapi::kChar && + v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) { + return VK_KERNEL(int8_tensor_to_nchw_noint8); + } + kernel_name = "tensor_to_nchw"; add_dtype_suffix(kernel_name, v_src); add_storage_type_suffix(kernel_name, v_src); diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index dfe86a9e26..cabc17f30e 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -31,7 +31,11 @@ void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes); // Functions to get shaders // -vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst); -vkapi::ShaderInfo get_tensor_to_nchw_shader(const api::vTensor& v_src); +vkapi::ShaderInfo get_nchw_to_tensor_shader( + const api::vTensor& v_dst, + bool int8_buffer_enabled = true); +vkapi::ShaderInfo get_tensor_to_nchw_shader( + const api::vTensor& v_src, + bool int8_buffer_enabled = true); } // namespace vkcompute diff --git 
a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index edba41b7ea..37403c97ac 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -47,21 +47,12 @@ idx_fill_buffer: idx_fill_texture: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - PACKING: CHANNELS_PACKED generate_variant_forall: - PACKING: - - VALUE: "CHANNELS_PACKED" - SUFFIX: "C_packed" - - VALUE: "WIDTH_PACKED" - SUFFIX: "W_packed" - - VALUE: "HEIGHT_PACKED" - SUFFIX: "H_packed" DTYPE: - - VALUE: "half" - SUFFIX: "half" - - VALUE: "float" - SUFFIX: "float" + - VALUE: half + - VALUE: float + - VALUE: int + - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/glsl/idx_fill_texture.glsl b/backends/vulkan/test/glsl/idx_fill_texture.glsl index 1f75cadf49..8914d2b892 100644 --- a/backends/vulkan/test/glsl/idx_fill_texture.glsl +++ b/backends/vulkan/test/glsl/idx_fill_texture.glsl @@ -12,21 +12,17 @@ #define VEC4_T ${texel_type(DTYPE)} -#define POS ${get_pos[NDIM]("pos")} - #include "indexing_utils.h" layout(std430) buffer; -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; - -layout(set = 0, binding = 1) uniform PRECISION restrict Sizes { - ivec4 sizes; -}; +${layout_declare_tensor(0, "w", "image_out", DTYPE, "texture3d")} +${layout_declare_ubo(1, "ivec4", "sizes")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int packed_dim = C_DIM; +layout(constant_id = 4) const int offset = 10; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -37,6 +33,6 @@ void main() { } const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); - VEC4_T texel = VEC4_T(buf_indices); - imageStore(image_out, POS, texel); + VEC4_T texel = VEC4_T(buf_indices) + offset; + imageStore(image_out, pos, texel); } diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index c55c286acc..649c0c82d6 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -76,7 +76,8 @@ void record_nchw_to_image_op( SV(v_dst.packed_dim_whcn_idx())}; context->submit_compute_job( - get_nchw_to_tensor_shader(v_dst), + get_nchw_to_tensor_shader( + v_dst, context->adapter_ptr()->has_full_int8_buffers_support()), pipeline_barrier, v_dst.image_extents(), adaptive_work_group_size(v_dst.image_extents()), @@ -112,6 +113,27 @@ void record_image_to_nchw_op( v_src.sizes_ubo()); } +void record_int8_image_to_nchw_noint8_op( + api::Context* const context, + api::vTensor& v_src, + api::StorageBuffer& dst_buffer) { + vkapi::PipelineBarrier pipeline_barrier{}; + uint32_t buffer_len = utils::safe_downcast(dst_buffer.numel() / 4); + utils::uvec3 global_wg_size = {buffer_len, 1, 1}; + context->submit_compute_job( + VK_KERNEL(int8_tensor_to_nchw_noint8), + pipeline_barrier, + global_wg_size, + adaptive_work_group_size(global_wg_size), + {v_src.packed_dim_whcn_idx()}, + VK_NULL_HANDLE, + 0, + v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE), + dst_buffer.buffer(), + v_src.sizes_ubo(), + v_src.ntexels_ubo()); +} + void record_conv2d_prepack_weights_op( api::Context* const context, vkapi::VulkanBuffer& src_buffer, diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 89e16131c9..3dd9497e69 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ 
b/backends/vulkan/test/utils/test_utils.h @@ -82,6 +82,11 @@ void record_image_to_nchw_op( api::vTensor& v_src, vkapi::VulkanBuffer& dst_buffer); +void record_int8_image_to_nchw_noint8_op( + api::Context* const context, + api::vTensor& v_src, + api::StorageBuffer& dst_buffer); + void record_conv2d_prepack_weights_op( api::Context* const context, vkapi::VulkanBuffer& src_buffer, diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 5f4fa519ca..6f0879c422 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -1692,25 +1693,21 @@ void run_from_gpu_test( if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { return; } - if ((dtype == vkapi::kChar || dtype == vkapi::kQInt8) && - !context()->adapter_ptr()->has_full_int8_buffers_support()) { - return; - } vTensor vten = vTensor(context(), sizes, dtype, storage_type, memory_layout); std::string kernel_name("idx_fill_texture"); - add_memory_layout_suffix(kernel_name, vten); add_dtype_suffix(kernel_name, vten); + int32_t offset = -50; + { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::SpecVarList specialization_constants = {vten.packed_dim_whcn_idx()}; context()->submit_compute_job( VK_KERNEL_FROM_STR(kernel_name), pipeline_barrier, vten.image_extents(), {4, 4, 4}, - specialization_constants, + {vten.packed_dim_whcn_idx(), offset}, VK_NULL_HANDLE, 0, vten.image( @@ -1722,7 +1719,12 @@ void run_from_gpu_test( StorageBuffer staging_buffer(context(), dtype, vten.gpu_numel()); - record_image_to_nchw_op(context(), vten, staging_buffer.buffer()); + if (dtype == vkapi::kChar && + !context()->adapter_ptr()->has_full_int8_buffers_support()) { + record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer); + } else { + record_image_to_nchw_op(context(), vten, staging_buffer.buffer()); + } submit_to_gpu(); @@ -1730,12 +1732,12 @@ void run_from_gpu_test( copy_staging_to_ptr(staging_buffer, data_out.data(), staging_buffer.nbytes()); for (int i = 0; i < vten.numel(); i++) { - CHECK_VALUE(data_out, i, i); + CHECK_VALUE(data_out, i, i + offset); } } template -void run_to_gpu_test( +void round_trip_test( std::vector& sizes, utils::GPUMemoryLayout memory_layout = utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, @@ -1744,10 +1746,6 @@ void run_to_gpu_test( if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { return; } - if ((dtype == vkapi::kChar || dtype == vkapi::kQInt8) && - !context()->adapter_ptr()->has_full_int8_buffers_support()) { - return; - } vTensor vten = vTensor(context(), sizes, dtype, storage_type, memory_layout); @@ -1756,16 +1754,22 @@ void run_to_gpu_test( std::vector data_in(staging_buffer_in.numel()); for (int i = 0; i < staging_buffer_in.numel(); i++) { - data_in[i] = i; + data_in[i] = T(i * -1); } copy_ptr_to_staging(data_in.data(), staging_buffer_in, vten.gpu_nbytes()); // Output staging buffer StorageBuffer staging_buffer_out(context(), dtype, vten.gpu_numel()); - // Copy data in and out of the tensor record_nchw_to_image_op(context(), staging_buffer_in.buffer(), vten); - record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer()); + + // Copy data in and out of the tensor + if (dtype == vkapi::kChar && + !context()->adapter_ptr()->has_full_int8_buffers_support()) { + record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer_out); + } else { + record_image_to_nchw_op(context(), vten, 
staging_buffer_out.buffer()); + } // Execute command buffer submit_to_gpu(); @@ -1777,11 +1781,51 @@ void run_to_gpu_test( // All indices should be equal to the input data for (int i = 0; i < vten.numel(); i++) { - CHECK_VALUE(data_out, i, i); + CHECK_VALUE(data_out, i, data_in[i]); + } +} + +template +void compute_graph_round_trip_test( + std::vector& sizes, + utils::GPUMemoryLayout memory_layout = + utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, + vkapi::ScalarType dtype = vkapi::kFloat, + utils::StorageType storage_type = utils::StorageType::TEXTURE_3D) { + if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { + return; + } + + GraphConfig config; + ComputeGraph graph(config); + + ValueRef r_tensor = + graph.add_tensor(sizes, dtype, storage_type, memory_layout); + ValueRef r_staging_in = graph.set_input_tensor(r_tensor); + ValueRef r_staging_out = graph.set_output_tensor(r_tensor); + + graph.prepare(); + graph.encode_execute(); + + vTensorPtr tensor = graph.get_tensor(r_tensor); + + std::vector data_in(tensor->numel()); + for (int i = 0; i < data_in.size(); i++) { + data_in[i] = T(i * -1); + } + graph.copy_into_staging(r_staging_in, data_in.data(), data_in.size()); + + graph.execute(); + + std::vector data_out(tensor->gpu_numel()); + graph.copy_from_staging(r_staging_out, data_out.data(), data_out.size()); + + for (int i = 0; i < data_in.size(); i++) { + CHECK_VALUE(data_out, i, data_in[i]); } } -TEST(VulkanToFromGPUShaderTest, to_gpu_and_from_gpu_test_texture) { +TEST(VulkanToFromGPUShaderTest, round_trip_tests) { // The below tests will fill each texel element with the value of the linear // buffer index that corresponds to it. The texel at position (0, 0, 0) will // be filled with the values [0, 1, 2, 3], the texel at position (1, 0, 0) @@ -1824,11 +1868,17 @@ TEST(VulkanToFromGPUShaderTest, to_gpu_and_from_gpu_test_texture) { }; #define RUN_TESTS(ctype, dtype) \ - run_to_gpu_test( \ + round_trip_test( \ + sizes, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, dtype); \ + round_trip_test( \ + sizes, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, dtype); \ + round_trip_test( \ + sizes, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, dtype); \ + compute_graph_round_trip_test( \ sizes, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, dtype); \ - run_to_gpu_test( \ + compute_graph_round_trip_test( \ sizes, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, dtype); \ - run_to_gpu_test( \ + compute_graph_round_trip_test( \ sizes, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, dtype); for (auto& sizes : to_test) { diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 1ac7867f3c..b0b80d8633 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -37,21 +37,10 @@ set(_common_compile_options -Wno-deprecated-declarations -fPIC) set(_xnnpack_schema__include_dir "${CMAKE_BINARY_DIR}/schema/include") # Paths to headers generated from the .fbs files. 
-set(_xnnpack_flatbuffer__outputs) -foreach(fbs_file ${_xnnpack_schema__srcs}) - string(REGEX REPLACE "([^/]+)[.]fbs$" "\\1_generated.h" generated - "${fbs_file}" - ) - list(APPEND _xnnpack_flatbuffer__outputs - "${_xnnpack_schema__include_dir}/executorch/${generated}" - ) -endforeach() - set(_xnnpack_schema__outputs) foreach(fbs_file ${_xnnpack_schema__srcs}) - string(REGEX REPLACE "runtime_([^/]+)[.]fbs$" "\\1_generated.h" generated - "${fbs_file}" - ) + string(REGEX REPLACE "([^/]+)[.]fbs$" "\\1_generated.h" + generated "${fbs_file}") list(APPEND _xnnpack_schema__outputs "${_xnnpack_schema__include_dir}/executorch/${generated}" ) @@ -64,7 +53,6 @@ add_custom_command( ${FLATC_EXECUTABLE} --cpp --cpp-std c++11 --scoped-enums -o "${_xnnpack_schema__include_dir}/executorch/backends/xnnpack/serialization" ${_xnnpack_schema__srcs} - COMMAND mv ${_xnnpack_flatbuffer__outputs} ${_xnnpack_schema__outputs} WORKING_DIRECTORY ${EXECUTORCH_ROOT} COMMENT "Generating xnnpack_schema headers" VERBATIM diff --git a/backends/xnnpack/operators/op_conv2d.py b/backends/xnnpack/operators/op_conv2d.py index 5661a9a4d3..28da480574 100644 --- a/backends/xnnpack/operators/op_conv2d.py +++ b/backends/xnnpack/operators/op_conv2d.py @@ -52,6 +52,9 @@ def define_node( ) # NHWC input kwargs["input1_id"] = vals_to_ids[get_input_node(node, 0)] + # filter shape for pytorch convolution is (oc, inc/groups, height, width) + # shape for xnnpack convolution is (oc, height, width, inc/groups), to convert + # to the proper shape, this is essentially a NCHW to NHWC conversion kernel_node = get_input_node(node, 1) kernel_shape = get_shape(kernel_node) groups = cast(int, node.args[8]) @@ -65,19 +68,13 @@ def define_node( is_depthwise_conv = (group_input_channels == 1) and ( group_output_channels % group_input_channels == 0 ) - # filter - # filter shape for pytorch convolution is (oc, inc/groups, height, width) - # shape for xnnpack convolution is (oc, height, width, inc/groups), to convert - # to the proper shape, this is essentially a NCHW to NHWC conversion - weight_node = get_input_node(node, 1) weight_quant_params = QuantParams.from_weights( - weight_node, self._exported_program + kernel_node, self._exported_program ) - - fp32_static_weights = weight_node.meta["val"].dtype == torch.float16 + fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16 self.define_tensor( - weight_node, + kernel_node, xnn_graph, vals_to_ids, convert_to_nhwc=True, diff --git a/backends/xnnpack/partition/TARGETS b/backends/xnnpack/partition/TARGETS index f11695460f..13cf15bfc5 100644 --- a/backends/xnnpack/partition/TARGETS +++ b/backends/xnnpack/partition/TARGETS @@ -6,6 +6,7 @@ runtime.python_library( name = "xnnpack_partitioner", srcs = [ "xnnpack_partitioner.py", + "xnnpack_partitioner2.py", ], visibility = [ "//executorch/...", @@ -15,6 +16,7 @@ runtime.python_library( ":configs", ":partitioner_graphs", "//executorch/backends/xnnpack:xnnpack_preprocess", + "//executorch/backends/xnnpack/partition/config:xnnpack_partitioner_configs", "//executorch/exir:delegate", "//executorch/exir:lib", "//executorch/exir/backend:partitioner", diff --git a/backends/xnnpack/partition/config/TARGETS b/backends/xnnpack/partition/config/TARGETS new file mode 100644 index 0000000000..adfbf95a72 --- /dev/null +++ b/backends/xnnpack/partition/config/TARGETS @@ -0,0 +1,20 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "xnnpack_partitioner_configs", + srcs = glob([ + 
"*.py", + ]), + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/exir:lib", + "//executorch/exir/backend:partitioner", + "//executorch/exir/backend:utils", + "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib", + ], +) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py new file mode 100644 index 0000000000..f1f51b27b6 --- /dev/null +++ b/backends/xnnpack/partition/config/__init__.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Type + +from executorch.backends.xnnpack.partition.config.gemm_configs import ( + AddmmConfig, + ConvolutionConfig, + LinearConfig, +) + +from executorch.backends.xnnpack.partition.config.generic_node_configs import ( + AbsConfig, + AddConfig, + AvgPoolingConfig, + CatConfig, + CeilConfig, + ClampConfig, + DeQuantizedPerTensorConfig, + DivConfig, + # EluConfig, + HardtanhConfig, + MaximumConfig, + MaxPool2dConfig, + MulConfig, + PermuteConfig, + QuantizedPerTensorConfig, + ReLUConfig, + SigmoidConfig, + SoftmaxConfig, +) +from executorch.backends.xnnpack.partition.config.node_configs import ( + BatchNormConfig, + MaxDimConfig, +) +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + XNNPartitionerConfig, +) + +ALL_PARTITIONER_CONFIGS: List[Type[XNNPartitionerConfig]] = [ + # GEMM-like Configs + AddmmConfig, + LinearConfig, + ConvolutionConfig, + # BatchNorm Config + BatchNormConfig, + # Single Node Configs + HardtanhConfig, + AbsConfig, + AvgPoolingConfig, + AddConfig, + CatConfig, + CeilConfig, + ClampConfig, + DivConfig, + MaxDimConfig, + MaxPool2dConfig, + MaximumConfig, + MulConfig, + SoftmaxConfig, + SigmoidConfig, + PermuteConfig, + # EluConfig, # Waiting for PyTorch Pin Update + ReLUConfig, + # Quantization Op Configs + QuantizedPerTensorConfig, + DeQuantizedPerTensorConfig, +] diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py new file mode 100644 index 0000000000..fc9562d33a --- /dev/null +++ b/backends/xnnpack/partition/config/gemm_configs.py @@ -0,0 +1,282 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import chain +from typing import List, Optional, Tuple + +import torch +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, + XNNPartitionerConfig, +) +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_dynamic_qdq, + is_per_channel, + is_qparam, + is_quant, +) +from executorch.backends.xnnpack.utils.utils import ( + get_input_node, + is_getitem, + is_node, + is_param_node, +) +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, +) +from torch.export import ExportedProgram + + +class GEMMConfig(XNNPartitionerConfig): + """ + GEMM-like ops like Convolution, Addmm, Linear, mostly behave in the same way, in which we + have some weight, bias, and activation node. 
The only difference between these types + of ops are that the weight, bias, and activations are in different indicies of the + nodes arguments, this class helps to generalize the logic needed to partition these + different ops + """ + + def __init__(self, weight_idx, bias_idx, act_idx, fused_acts): + super().__init__() + self.weight_idx = weight_idx + self.bias_idx = bias_idx + self.act_idx = act_idx + self.fused_acts = fused_acts + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + # short circuit if we don't pass common constraints + return False + + precision = self._detect_precision(node) + if precision not in self.enabled_precision_types: + # detected precision but it is either disabled or not supported + return False + + is_valid, _ = self.get_deps(node, ep, precision) + return is_valid + + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + partition = [node] + precision = self._detect_precision(node) + _, deps = self.get_deps(node, ep, precision) + partition.extend(deps) + + return partition + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return None + + def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType: + weight = get_input_node(node, self.weight_idx) + + if not is_dequant(weight): + return ConfigPrecisionType.FP32 + + activation = get_input_node(node, self.act_idx) + if is_dynamic_qdq(activation): + return ConfigPrecisionType.DYNAMIC_QUANT + + return ConfigPrecisionType.STATIC_QUANT + + def get_deps( + self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType + ) -> Tuple[bool, List[torch.fx.Node]]: + """ + Gets all dependencies for this gemm partition. 
Returns a tuple of + a bool indicating if the deps are valid and a list of all the + dep nodes + """ + valid_bias, bias_deps = self._get_bias_deps(node, ep, precision) + valid_weight, weight_deps = self._get_weight_deps(node, ep, precision) + valid_act, act_deps = self._get_act_deps(node, ep, precision) + valid_output, output_deps = self._get_output_deps(node, ep, precision) + + valid_deps = valid_bias and valid_weight and valid_act and valid_output + deps = list(chain(bias_deps, weight_deps, act_deps, output_deps)) + + return valid_deps, deps + + def _get_weight_deps( + self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType + ) -> Tuple[bool, List[torch.fx.Node]]: + gemm_deps = [] + if precision == ConfigPrecisionType.FP32: + # First find the weight + weight_node = get_input_node(node, self.weight_idx) + if not is_param_node(ep, weight_node): + return (False, []) # weight must be a static param + gemm_deps.append(weight_node) + + return (True, gemm_deps) + else: + # Quantized Weight deps + dequant_node = get_input_node(node, self.weight_idx) + if not is_dequant(dequant_node): + return False, [] + gemm_deps.append(dequant_node) + weight = get_input_node(dequant_node, 0) + if not is_param_node(ep, weight): + return False, [] + gemm_deps.append(weight) + + if is_per_channel(dequant_node): + if len(dequant_node.all_input_nodes) < 2: + # Expected channel quantized to have scale/zp nodes + return False, [] + + gemm_deps.extend(dequant_node.all_input_nodes[1:3]) + return (True, gemm_deps) + + def _get_output_deps( + self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType + ) -> Tuple[bool, List[torch.fx.Node]]: + gemm_deps = [] + if precision == ConfigPrecisionType.STATIC_QUANT: + # Look for fused activations and tail end quant node + node_users = list(node.users.keys()) + if len(node_users) != 1: + # Expect quantized node to have a single output (fused act or dequant) + return False, [] + + # Check if the quantized pattern has a fused activation + n_output = node_users[0] + if ( + n_output.op == "call_function" + and format_target_name(n_output.target.__name__) in self.fused_acts + ): + gemm_deps.append(n_output) + fused_out_users = list(n_output.users.keys()) + if len(fused_out_users) == 1: + n_output = fused_out_users[0] + + if not is_quant(n_output): + # Expected gemm_node --> fused_act (optional) --> dequant + return (False, []) + gemm_deps.append(n_output) + elif precision == ConfigPrecisionType.FP32: + # Look for fused activations only, and partition with fp32 op + node_users = list(node.users.keys()) + if len(node_users) == 1: + n_output = node_users[0] + if ( + n_output.op == "call_function" + and format_target_name(n_output.target.__name__) in self.fused_acts + ): + gemm_deps.append(n_output) + + # FP32 and Dynamic Quant have no output dependencies + return (True, gemm_deps) + + def _get_bias_deps( + self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType + ) -> Tuple[bool, List[torch.fx.Node]]: + gemm_deps = [] + if len(node.all_input_nodes) > 2: + bias_node = get_input_node(node, self.bias_idx) + if bias_node: + if not is_param_node(ep, bias_node): + return (False, []) # bias node must be a static param + gemm_deps.append(bias_node) + + return (True, gemm_deps) + + def _get_act_deps( + self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType + ) -> Tuple[bool, List[torch.fx.Node]]: + gemm_deps = [] + if precision == ConfigPrecisionType.FP32: + return (True, []) + else: + dq_input = 
get_input_node(node, self.act_idx) + if not is_dequant(dq_input): + # Expected static quant input to be dequant node + return False, [] + gemm_deps.append(dq_input) + if precision == ConfigPrecisionType.STATIC_QUANT: + # if static quant we are done after finding first dq_input + return (True, gemm_deps) + + # q input node + q_input = get_input_node(dq_input, 0) + if not is_quant(q_input): + return (False, []) + + gemm_deps.append(q_input) + if not (is_node(q_input.args[1]) and is_node(q_input.args[2])): + # expected to find getitem node from choose qparam + return (False, []) + + getitem1 = get_input_node(q_input, 1) + getitem2 = get_input_node(q_input, 2) + + if not (is_getitem(getitem1) and is_getitem(getitem2)): + # expected getitem node from choose qparam + return (False, []) + + gemm_deps.extend([getitem1, getitem2]) + choose_qparam = get_input_node(getitem1, 0) + if not is_qparam(choose_qparam): + # expected to find choose_qparam node + return (False, []) + gemm_deps.append(choose_qparam) + return (True, gemm_deps) + + +class LinearConfig(GEMMConfig): + target_name = "linear.default" + + def __init__(self): + super().__init__( + weight_idx=1, + bias_idx=2, + act_idx=0, + fused_acts=["relu.default", "hardtanh.default"], + ) + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return torch.ops.aten.linear.default + + def supported_precision_types(self): + return [ + ConfigPrecisionType.DYNAMIC_QUANT, + ConfigPrecisionType.FP32, + ConfigPrecisionType.STATIC_QUANT, + ] + + +class AddmmConfig(GEMMConfig): + target_name = "addmm.default" + + def __init__(self): + super().__init__(weight_idx=2, bias_idx=0, act_idx=1, fused_acts=[]) + + def supported_precision_types(self): + return [ + ConfigPrecisionType.FP32, + ConfigPrecisionType.STATIC_QUANT, + ] + + +class ConvolutionConfig(GEMMConfig): + target_name = "convolution.default" + + def __init__(self): + super().__init__( + weight_idx=1, + bias_idx=2, + act_idx=0, + fused_acts=["relu.default", "hardtanh.default"], + ) + + def supported_precision_types(self): + return [ + ConfigPrecisionType.FP32, + ConfigPrecisionType.STATIC_QUANT, + ] diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py new file mode 100644 index 0000000000..caf61f33c9 --- /dev/null +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
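For orientation (illustrative only, not part of this patch): the GEMM configs above all follow the same recipe, so supporting another GEMM-like op is mostly a matter of telling GEMMConfig where the weight, bias, and activation sit in the node's arguments. The MmConfig class below is a hypothetical sketch that mirrors LinearConfig/AddmmConfig/ConvolutionConfig and assumes the modules introduced by this change are importable.

from typing import List, Optional

import torch
from executorch.backends.xnnpack.partition.config.gemm_configs import GEMMConfig
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)


class MmConfig(GEMMConfig):
    # Hypothetical config for aten.mm-style nodes; not added by this patch.
    target_name = "mm.default"

    def __init__(self):
        # aten.mm(input, weight): the activation is arg 0 and the weight is
        # arg 1. There is no bias argument, so bias_idx points past the real
        # args and the base class simply finds no bias dependency.
        super().__init__(weight_idx=1, bias_idx=2, act_idx=0, fused_acts=[])

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.mm.default

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        return [ConfigPrecisionType.FP32]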
+ +from typing import cast, List, Optional + +import torch +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, + XNNPartitionerConfig, +) +from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, +) +from torch.export import ExportedProgram + + +class GenericNodePartitionerConfig(XNNPartitionerConfig): + def __init__(self, fused_act: Optional[List[str]] = None): + """ + fused_act is a list of node target names that can be fused with this + node under quantization + """ + self.fused_acts = fused_act or [] + super().__init__() + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + return self.check_common_constraints(node, ep) + + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + deps = [node] + quantized_deps = [] + if ConfigPrecisionType.STATIC_QUANT in self.enabled_precision_types: + # try to partition dequant inputs and quant outputs if static quant is enabled + if [(is_dequant(dq_input)) for dq_input in node.all_input_nodes].count( + False + ): + # if not all inputs are dequant nodes then it isn't quantized + return deps + + quantized_deps.extend(node.all_input_nodes) + + # check if quantized pattern has fused activation + if len(node.users) != 1: + return deps + + node_output = list(node.users)[0] + if ( + node_output.op == "call_function" + and format_target_name(node_output.target.__name__) in self.fused_acts + ): + quantized_deps.append(node_output) + fused_out_users = list(node_output.users.keys()) + if len(fused_out_users) == 1: + node_output = fused_out_users[0] + + if not is_quant(node_output): + # Expected node --> fused_act (optional) --> dequant + return deps + + quantized_deps.append(node_output) + + return deps + quantized_deps + + +class QuantizedPerTensorConfig(GenericNodePartitionerConfig): + target_name = "quantize_per_tensor.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.STATIC_QUANT] + + +class DeQuantizedPerTensorConfig(GenericNodePartitionerConfig): + target_name = "dequantize_per_tensor.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.STATIC_QUANT] + + +class HardtanhConfig(GenericNodePartitionerConfig): + target_name = "hardtanh.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class AddConfig(GenericNodePartitionerConfig): + target_name = "add.Tensor" + + def __init__(self): + super().__init__(fused_act=["relu.default"]) + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class ReLUConfig(GenericNodePartitionerConfig): + target_name = "relu.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class AbsConfig(GenericNodePartitionerConfig): + target_name = "abs.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class AvgPoolingConfig(GenericNodePartitionerConfig): + target_name = "avg_pool2d.default" + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK does not support ceil_mode = True and 
count_include_pad = True + Additionally, we only support divisor_override if divisor_override = pooling region + """ + if not self.check_common_constraints(node, ep): + return False + + args = node.args + + ceil_mode = False # default is False + if len(args) >= 5: + ceil_mode = cast(bool, args[4]) + + count_include_pad = True # default is True + if len(args) >= 6: + count_include_pad = cast(bool, args[5]) + + kernel_size = cast(List[int], args[1]) + pooling_region = kernel_size[0] * kernel_size[1] + divisor_override = pooling_region # Default divisor is pooling_region + if len(args) >= 7: + divisor_override = cast(int, args[6]) + + return ( + not (ceil_mode or count_include_pad) and divisor_override == pooling_region + ) + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class CatConfig(GenericNodePartitionerConfig): + target_name = "cat.default" + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + Only support concatenation of 2 - 4 tensors + """ + if not self.check_common_constraints(node, ep): + return False + + num_tensors = len(node.all_input_nodes) + return num_tensors >= 2 and num_tensors <= 4 + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class CeilConfig(GenericNodePartitionerConfig): + target_name = "ceil.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class ClampConfig(GenericNodePartitionerConfig): + target_name = "clamp.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class DivConfig(GenericNodePartitionerConfig): + target_name = "div.Tensor" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class EluConfig(GenericNodePartitionerConfig): + target_name = "elu.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return torch.ops.aten.elu.default + + +class SoftmaxConfig(GenericNodePartitionerConfig): + target_name = "_softmax.default" + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + Check that dim is always the last dim + """ + if not self.check_common_constraints(node, ep): + return False + + dim = cast(int, node.args[1]) + node_input = node.all_input_nodes[0] + tensor_dims = node_input.meta["val"].dim() + return dim == -1 or dim == tensor_dims - 1 + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class PermuteConfig(GenericNodePartitionerConfig): + target_name = "permute_copy.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class SigmoidConfig(GenericNodePartitionerConfig): + target_name = "sigmoid.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class MulConfig(GenericNodePartitionerConfig): + target_name = "mul.Tensor" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + +class MaximumConfig(GenericNodePartitionerConfig): + 
target_name = "maximum.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class MaxPool2dConfig(GenericNodePartitionerConfig): + target_name = "max_pool2d.default" + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK's maxpool2d does not support ceil mode + """ + if not self.check_common_constraints(node, ep): + return False + + is_ceil_mode = len(node.args) >= 6 and cast(bool, node.args[5]) + return not is_ceil_mode + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return torch.ops.aten.max_pool2d.default diff --git a/backends/xnnpack/partition/config/node_configs.py b/backends/xnnpack/partition/config/node_configs.py new file mode 100644 index 0000000000..11a18543b8 --- /dev/null +++ b/backends/xnnpack/partition/config/node_configs.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from typing import List, Optional + +import torch +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, + XNNPartitionerConfig, +) +from executorch.backends.xnnpack.passes.fuse_batch_norm_with_conv import ( + FuseBatchNormWithConvPass, +) +from torch.export import ExportedProgram + + +class BatchNormConfig(XNNPartitionerConfig): + target_name = "_native_batch_norm_legit_no_training.default" + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + return False + + bn = node + conv = node.all_input_nodes[0] + + return FuseBatchNormWithConvPass.can_fuse(conv, bn, ep) + + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + deps = [node] + + # weight, bias, running_mean, running_var + deps.extend(node.all_input_nodes[1:5]) + + # All the users of batchnorm node must be getitem ops. batchnorm + # returns a 3-element tuple. Each user must only access the first + # element of the tuple. 
+ if [ + (user.target == operator.getitem and user.args[1] == 0) + for user in node.users + ].count(False): + return [] + + deps.extend(list(node.users.keys())) + return deps + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + +class MaxDimConfig(XNNPartitionerConfig): + target_name = "max.dim" + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + # We support max_dim as long as we don't return indices + supported_dtypes = {torch.float32, torch.float16, torch.int8, torch.qint8} + node_val = node.meta.get("val") + output_0 = node_val[0] + # Don't check indicies dtype + if output_0.dtype not in supported_dtypes: + return False + + max_input = node.all_input_nodes[0] + if max_input.meta.get("val").dtype not in supported_dtypes: + return False + + # Make sure that all users are getitems of the first output + for user in node.users: + if not (user.target == operator.getitem and user.args[1] == 0): + return False + + return True + + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + getitems = list(node.users) + + return [node] + getitems + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return None + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py new file mode 100644 index 0000000000..840ffbd43b --- /dev/null +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import abstractmethod +from enum import Enum +from typing import List, Optional + +import torch +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, + PartitionerConfig, +) +from torch.export import ExportedProgram + + +class ConfigPrecisionType(Enum): + FP32 = 1 + STATIC_QUANT = 2 + DYNAMIC_QUANT = 3 + + +# TODO: add WhyNotPartition to XNNPartitionerConfig +class XNNPartitionerConfig(PartitionerConfig): + """ + Base partitioner config for XNNPACK Partitioner Configs. Base wrapper class + for all XNNPACK Partitioner Configs allows us to apply control over + all PartitionerConfigs. XNNPACK Partitioner config also sets a property + for supported precision types. This allows partitioner configs to set + the precision types they support, and let users toggle which precision + types they want to enable + """ + + def __init__(self): + super().__init__() + self.enabled_precision_types = self.supported_precision_types() + + def get_partition( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + """ + Overriding abstract method get_partition. + + Returns the partitioned nodes from get_node_and_deps, but also labels them + with the name of the XNNPartitionerConfig class which return this set of nodes. 
+ This enforces that all partitions returned by XNNPartitioner configs are labeled + with the partitioner config which returned them + """ + partitioned_nodes = self.get_node_and_deps(node, ep) + # label partitioned nodes with the name of the partitioner config + for node in partitioned_nodes: + if "xnn_partitioner_config" in node.meta: + node.meta["xnn_partitioner_config"].append(self.__class__.__name__) + else: + node.meta["xnn_partitioner_config"] = [self.__class__.__name__] + + return partitioned_nodes + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + # By default if not specified, we do not halt decomposition for those configs + return None + + @abstractmethod + def supported_precision_types(self) -> List[ConfigPrecisionType]: + """ + Returns the supported PrecisionType of this partitioner config + """ + pass + + @abstractmethod + def get_node_and_deps( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + """ + Takes in a node and its exported program and returns a list of nodes + and its dependencies that need to be partitioned together + + Args: + node: Node to be partitioned + ep: Exported program of the graph module + Returns: + List of nodes that can be partitioned + """ + pass + + def set_enabled_precision_types( + self, precision_types: Optional[List[ConfigPrecisionType]] + ): + """ + Set the enabled precisions. + + We take the intersection of the precision_types we wish to enable with + the precision types that this config supports. If enabled_precisions is empty, i.e. + the config does not support any of the precision types we want to enable, + then we will not partition nothing and return false at the common constraints + """ + + if precision_types: + enabled_precisions = [] + for precision in precision_types: + if precision in self.supported_precision_types(): + enabled_precisions.append(precision) + + self.enabled_precision_types = enabled_precisions + + def check_common_constraints( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> bool: + """ + Checks common xnnpack constraints + + Args: + node (torch.fx.Node): Node to check common constraints against + ep (ExportedProgram): Exported Program to check constraints against + + Returns: + True or False whether this node is partitionable + """ + assert ( + node.op == "call_function" + and format_target_name(node.target.__name__) # pyre-ignore + == self.target_name + ) + + if len(self.enabled_precision_types) == 0: + return False + + has_valid_dtypes = self._check_node_has_valid_dtype(node) + if not has_valid_dtypes: + return False + + return True + + def _check_inputs_are_valid_dtypes(self, node, valid_dtypes): + # Check inputs are valid dtypes + # Gather all args which are nodes + args_to_check = [] + for arg in node.args: + if isinstance(arg, list) or isinstance(arg, tuple): + for item in arg: + if isinstance(item, torch.fx.Node): + args_to_check.append(item) + + if isinstance(arg, torch.fx.Node): + args_to_check.append(arg) + + for arg in args_to_check: + arg_val = arg.meta.get("val", None) + + if arg_val is None or isinstance(arg_val, tuple): + continue + + # Being conservative for now, UX >> Perf + # TODO: We need a pass to scrub these out. 
+ if not isinstance(arg_val, torch.Tensor): + return False + + # XNNPACK does not support empty tensors + if arg_val.numel() == 0: + return False + + if arg_val.dtype not in valid_dtypes: + return False + + return True + + def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): + # Check outputs are valid dtype + node_val = node.meta.get("val", None) + if node_val is None: + return True + + if not isinstance(node_val, tuple): + node_val = (node_val,) + + for val in node_val: + if not isinstance(val, torch.Tensor): + return False + + if val.dtype not in valid_dtypes: + return False + + return True + + def _check_node_has_valid_dtype(self, node): + valid_dtypes = { + torch.float32, + torch.float16, + torch.int8, + torch.qint8, + } + if ( + node.op != "placeholder" + and node.op != "call_function" + and node.op != "get_attr" + ): + return False + + return self._check_inputs_are_valid_dtypes( + node, valid_dtypes + ) and self._check_outputs_are_valid_dtypes(node, valid_dtypes) diff --git a/backends/xnnpack/partition/graphs/bilinear_2d.py b/backends/xnnpack/partition/graphs/bilinear_2d.py index a971cb9244..0040439f84 100644 --- a/backends/xnnpack/partition/graphs/bilinear_2d.py +++ b/backends/xnnpack/partition/graphs/bilinear_2d.py @@ -37,12 +37,15 @@ def forward(self, x): ] for align_corners in [True, False]: for config in capture_configs: - edge = exir.capture( - bilinear2d(align_corners), sample_inputs, config - ).to_edge( - config=get_xnnpack_edge_compile_config(), - ) - _bilinear2d_graphs[edge.exported_program.graph_module] = align_corners + for skip_dim_order_flag in [True, False]: + edge = exir.capture( + bilinear2d(align_corners), sample_inputs, config + ).to_edge( + config=get_xnnpack_edge_compile_config( + skip_dim_order=skip_dim_order_flag + ) + ) + _bilinear2d_graphs[edge.exported_program.graph_module] = align_corners return _bilinear2d_graphs diff --git a/backends/xnnpack/partition/xnnpack_partitioner2.py b/backends/xnnpack/partition/xnnpack_partitioner2.py new file mode 100644 index 0000000000..c7e792b1c3 --- /dev/null +++ b/backends/xnnpack/partition/xnnpack_partitioner2.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
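Before the implementation below, a hedged usage sketch (illustrative only, not part of this patch) of the config-based XnnpackPartitioner that xnnpack_partitioner2.py introduces. The import paths assume the modules created in this change; the variable names are arbitrary.

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import (
    XnnpackPartitioner,
)

# Partition only statically quantized patterns.
quant_only_partitioner = XnnpackPartitioner(
    config_precisions=ConfigPrecisionType.STATIC_QUANT,
)

# Partition fp32 and dynamically quantized patterns, keeping each matched
# node (plus its deps) in its own delegated partition instead of merging
# overlapping matches into larger subgraphs.
per_op_partitioner = XnnpackPartitioner(
    config_precisions=[
        ConfigPrecisionType.FP32,
        ConfigPrecisionType.DYNAMIC_QUANT,
    ],
    per_op_mode=True,
)

In per_op_mode the partitioner keeps only the first config match for a node; any later match that overlaps an already-seen node is dropped rather than merged, which is what generate_per_op_partitions below implements.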
+ +import itertools +from typing import List, Optional, Type, Union + +from executorch.backends.xnnpack.partition.config import ALL_PARTITIONER_CONFIGS +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, + XNNPartitionerConfig, +) + +from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend +from executorch.exir.backend.backend_details import ExportedProgram +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + ConfigerationBasedPartitioner, +) +from executorch.exir.backend.partitioner import DelegationSpec +from torch.fx.passes.infra.partitioner import Partition + + +class XnnpackPartitioner(ConfigerationBasedPartitioner): + def __init__( + self, + configs: Optional[List[Type[XNNPartitionerConfig]]] = None, + config_precisions: Optional[ + Union[ConfigPrecisionType, List[ConfigPrecisionType]] + ] = None, + per_op_mode=False, + ): + delegation_spec = DelegationSpec(XnnpackBackend.__name__, []) + configs_to_use = configs or ALL_PARTITIONER_CONFIGS + # Can do logic and have extra args to filter/delete/select + # Certain configs based on user specification + initialized_configs = [] + if isinstance(config_precisions, ConfigPrecisionType): + config_precisions = [config_precisions] + + for config in configs_to_use: + # Config Classes given to XnnpackPartitioner should no longer be abstract + initialized = config() # pyre-ignore + initialized.set_enabled_precision_types(config_precisions) + initialized_configs.append(initialized) + + # per_op_mode takes the first match from a partitioner config, any + # subsequent matches that overlap with the first match are not partitioned + self.per_op_mode = per_op_mode + super().__init__(delegation_spec, initialized_configs) + + def generate_partitions(self, ep: ExportedProgram) -> List[Partition]: + """ + generate_partitions is different if partitioner is set to per_op_mode + for per_op_mode we only need to generate unmerged partitions instead + of using the default generate_partitions method. + """ + if self.per_op_mode: + return self.generate_per_op_partitions(ep) + else: + return super().generate_partitions(ep) + + def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]: + """ + Uses configs to generate per_op_partitions. That is no partitions are + merged together. All partitions (node + deps) returned by PartitionerConfigs + are put into their own partition. 
+ """ + partitions = [] + matched_nodes = self.get_matched_nodes_from_configs(ep) + partition_id = itertools.count() + nodes_seen = set() + for match in matched_nodes: + match_set = set(match) + # We only create partitions from the first PartitionerConfig match + # if a subsequent partitioner match contains the same node, we do + # not create a partition for it + if match_set.isdisjoint(nodes_seen): + partitions.append( + Partition( + id=next(partition_id), + nodes=match_set, + ) + ) + nodes_seen.update(match_set) + return partitions diff --git a/backends/xnnpack/passes/__init__.py b/backends/xnnpack/passes/__init__.py index 1ca4fe307f..c3a85e4aa8 100644 --- a/backends/xnnpack/passes/__init__.py +++ b/backends/xnnpack/passes/__init__.py @@ -27,6 +27,7 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.const_prop_pass import ConstPropPass +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from executorch.exir.program._program import _transform from torch._export.pass_base import PassType @@ -50,6 +51,8 @@ def __init__( if not passes: # All the XNNPACK passes self.passes = [ + # TODO - remove this pass once we have a better support for dim_order ops lowering + DimOrderOpsRevertPass, ConvertToUpsampleBilinear2d, ConvertToLinearPass, ConvertToSDPAPass, diff --git a/backends/xnnpack/passes/tag_implicit_q_dq_pass.py b/backends/xnnpack/passes/tag_implicit_q_dq_pass.py index 2d41429eb1..0aa2e1291e 100644 --- a/backends/xnnpack/passes/tag_implicit_q_dq_pass.py +++ b/backends/xnnpack/passes/tag_implicit_q_dq_pass.py @@ -139,10 +139,8 @@ def get_ending_implicit_q_nodes( ): return [next_node] elif self.is_output_node(next_node): - # Check if second_node (which is between dq and output nodes) - # is aten.linear.default - if self.is_dynamically_quantized(start_node): - return [] + # if node following dq is output node + return None else: # Check if nodes between the dq node and the next q match # a supported quant chain @@ -193,6 +191,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: ending_implicit_q_nodes = [] for user in first_node.users: + if self.is_dynamically_quantized(user): + # if the dq is a dynamic dq, then it is implicit + break user_end_nodes = self.get_ending_implicit_q_nodes(user) if user_end_nodes is None: # This user isn't part of a "supported" group diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 8c8db60065..722dabdfbe 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -124,17 +124,9 @@ const uint8_t* getConstantDataPtr( const uint8_t* constant_data_ptr) { auto buffer_idx = tensor_value->constant_buffer_idx(); if (buffer_idx) { - if (!constant_data_ptr) { - // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC - // window - const auto& constant_buffer = *flatbuffer_graph->constant_buffer(); - return constant_buffer[buffer_idx]->storage()->data(); - } else { - const auto& constant_data_offsets = *flatbuffer_graph->constant_data(); - uint64_t constant_data_offset = - constant_data_offsets[buffer_idx]->offset(); - return constant_data_ptr + constant_data_offset; - } + const auto& constant_data_offsets = *flatbuffer_graph->constant_data(); + uint64_t constant_data_offset = constant_data_offsets[buffer_idx]->offset(); + return constant_data_ptr + constant_data_offset; } return nullptr; @@ -194,105 +186,29 @@ Error defineTensor( xnn_status status; // The type we might have to convert to - 
auto dq_datatype = getDataType(tensor_value->dq_datatype()); - - if (dq_datatype != xnn_datatype::xnn_datatype_invalid) { - if (dq_datatype != xnn_datatype::xnn_datatype_qint8) { - ET_CHECK_OR_RETURN_ERROR( - false, - Internal, - "Only int8_t is supported for dq_datatype for now, got: %d", - dq_datatype); - } else { - ET_CHECK_OR_RETURN_ERROR( - (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT), - Internal, - "Dynamic quantization of tensor is only allowed for the external input tensor value for now! got flags: %u", - tensor_value->flags()); - } - } + auto datatype = getDataType(tensor_value->datatype()); if (qtensor_value == nullptr) { // FP32 tensor - if (!isQuantizedDataType(dq_datatype)) { - // Define non-quantied tensor - status = xnn_define_tensor_value( - /*subgraph=*/subgraph_ptr, - /*datatype=*/getDataType(tensor_value->datatype()), - /*num_dims=*/tensor_value->num_dims(), - /*dims=*/dims_data.data(), - /*data=*/buffer_ptr, - /*external_id=*/tensor_value->external_id(), - /*flags=*/tensor_value->flags(), - /*id_out=*/&id); - } else if (dq_datatype != xnn_datatype::xnn_datatype_invalid) { - ET_CHECK_OR_RETURN_ERROR( - isQuantizedDataType(dq_datatype), - Internal, - "Dynamic quantization can only produce supported quantized dtypes"); - ET_CHECK_OR_RETURN_ERROR( - tensor_value->external_id() != XNN_INVALID_VALUE_ID, - Internal, - "Dynamic quantization can only work with external inputs for now, got an internal ID"); - ET_CHECK_OR_RETURN_ERROR( - buffer_ptr == nullptr, - Internal, - "Dynamic quantization can only work with external inputs for now, got const data"); - - switch (dq_datatype) { - case xnn_datatype::xnn_datatype_qint8: { - // HACK TO Maintain FC/BC for ASR this will be removed after 01/2024 - - // When encountering a dynamically quantized tensor via dq_datatype, - // which is the old flow for serializing dynamically quantized linear. - // We replace the definition of a single tensor with a new dynamic - // Quantization pattern. 
We change the pattern from: - // serialized_qd_input - // to - // (fp32_input --> convert --> qdint8_input) - - status = xnn_define_dynamically_quantized_tensor_value( - /*subgraph=*/subgraph_ptr, - /*datatype=*/xnn_datatype_qdint8, - /*num_dims=*/tensor_value->num_dims(), - /*num_nonbatch_dims=*/1, // always do per token quantization - /*dims=*/dims_data.data(), - /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id - /*flags=*/0, // this is netiher external input or output - /*id_out=*/&id); - - // this is the FP16 or FP32 external value that is being dynamically - // quantized - uint32_t float_id; - enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype()); - status = xnn_define_tensor_value( - /*subgraph=*/subgraph_ptr, - /*datatype=*/fp_datatype, - /*num_dims=*/tensor_value->num_dims(), - /*dims=*/dims_data.data(), - /*data=*/buffer_ptr, - /*external_id=*/tensor_value->external_id(), - /*flags=*/tensor_value->flags(), - /*id_out=*/&float_id); - - // Define dynamic conversion from float to qdint8 - status = xnn_define_convert( - /*subgraph=*/subgraph_ptr, - /*input_id=*/float_id, - /*output_id=*/id, - /*flags=*/0); - break; - } - default: - ET_CHECK_OR_RETURN_ERROR( - false, - NotImplemented, - "Unhandled Dyanmic Quantization dtype: %d", - dq_datatype); - } - } else { - ET_CHECK_OR_RETURN_ERROR(false, NotImplemented, "Unhandled fp32 tensor"); - } + ET_CHECK_OR_RETURN_ERROR( + !isQuantizedDataType(datatype), + Internal, + "xnn_datatype is quantized, but is not quantized tensor value"); + + status = xnn_define_tensor_value( + /*subgraph=*/subgraph_ptr, + /*datatype=*/datatype, + /*num_dims=*/tensor_value->num_dims(), + /*dims=*/dims_data.data(), + /*data=*/buffer_ptr, + /*external_id=*/tensor_value->external_id(), + /*flags=*/tensor_value->flags(), + /*id_out=*/&id); + ET_CHECK_OR_RETURN_ERROR( + xnn_status_success == status, + Internal, + "Failed to define tensor with id %i", + id); } else { // define tensor for quantized switch (qtensor_value->quant_params_type()) { @@ -306,7 +222,7 @@ Error defineTensor( qparams->zero_point()); status = xnn_define_quantized_tensor_value( /*subgraph=*/subgraph_ptr, - /*datatype=*/getDataType(tensor_value->datatype()), + /*datatype=*/datatype, /*zero_point=*/qparams->zero_point(), /*scale=*/qparams->scale(), /*num_dims=*/tensor_value->num_dims(), @@ -319,9 +235,8 @@ Error defineTensor( } case fb_xnnpack::XNNQuantParams::PerChannelQuant: { auto qparams = qtensor_value->quant_params_as_PerChannelQuant(); - enum xnn_datatype dtype = getDataType(tensor_value->datatype()); int32_t zero_point = - (dtype == xnn_datatype::xnn_datatype_qcint4 ? 8 : 0); + (datatype == xnn_datatype::xnn_datatype_qcint4 ? 
8 : 0); ET_LOG( Debug, @@ -329,11 +244,11 @@ Error defineTensor( buffer_ptr, qparams->scale()->size(), qparams->channel_dim(), - dtype, + datatype, zero_point); status = xnn_define_channelwise_quantized_tensor_value_v2( /*subgraph=*/subgraph_ptr, - /*datatype=*/dtype, + /*datatype=*/datatype, /*zero_point=*/zero_point, /*scale=*/qparams->scale()->data(), /*num_dims=*/tensor_value->num_dims(), @@ -346,7 +261,6 @@ Error defineTensor( break; } case fb_xnnpack::XNNQuantParams::PerChannelGroupQuant: { - xnn_datatype datatype = getDataType(tensor_value->datatype()); ET_CHECK_OR_RETURN_ERROR( datatype == xnn_datatype::xnn_datatype_qbint4, Internal, @@ -410,7 +324,7 @@ Error defineTensor( "Dynamically Quantized Tensors currently only support per token quantization"); status = xnn_define_dynamically_quantized_tensor_value( /*subgraph=*/subgraph_ptr, - /*datatype=*/getDataType(tensor_value->datatype()), + /*datatype=*/datatype, /*num_dims=*/tensor_value->num_dims(), /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(), /*dims=*/dims_data.data(), @@ -1594,23 +1508,24 @@ __ET_NODISCARD Error XNNCompiler::compileModel( constant_data = reinterpret_cast(buffer_pointer) + header->constant_data_offset; } else if (header.error() == Error::NotFound) { - flatbuffer_data = reinterpret_cast(buffer_pointer); + ET_LOG( + Error, + "XNNHeader version mismatch: '%.4s' != expected '%.4s'", + // Header Magic and FlatbufferIdentifier are same offset and size + flatbuffers::GetBufferIdentifier(buffer_pointer), + XNNHeader::kMagic); + return header.error(); } else { ET_LOG(Error, "XNNHeader may be corrupt"); return header.error(); } - // Temporarily support identifier XN00 and XN01 - bool is_supported_version = - strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN00", 4) == - 0 || - strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN01", 4) == - 0; ET_CHECK_OR_RETURN_ERROR( - is_supported_version, + fb_xnnpack::XNNGraphBufferHasIdentifier(flatbuffer_data), DelegateInvalidCompatibility, - "XNNPACK Delegate Serialization Format version identifier '%.4s' != expected XN00 or XN01'", - flatbuffers::GetBufferIdentifier(flatbuffer_data)); + "XNNPACK Delegate flatbuffer version mismatch: '%.4s' != expected '%.4s'", + flatbuffers::GetBufferIdentifier(flatbuffer_data), + fb_xnnpack::XNNGraphIdentifier()); auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data); // initialize xnnpack diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs deleted file mode 100644 index 5ace211149..0000000000 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. - -namespace fb_xnnpack; - -// Update after any BC breaking changes -file_identifier "XN00"; - -// datatype for xnn-values -enum XNNDatatype : short { - /// Invalid data type. Valid Values never have this datatype. - xnn_datatype_invalid = 0, - /// IEEE754 single-precision floating-point. - xnn_datatype_fp32 = 1, - /// IEEE754 half-precision floating-point. - xnn_datatype_fp16 = 2, - /// Quantized 8-bit signed integer with shared per-Value quantization parameters. - xnn_datatype_qint8 = 3, - /// Quantized 8-bit unsigned integer with shared per-Value quantization parameters. - xnn_datatype_quint8 = 4, - /// Quantized 32-bit signed integer with shared per-Value quantization parameters. - xnn_datatype_qint32 = 5, - /// Quantized 8-bit signed integer with shared per-channel quantization parameters. 
- xnn_datatype_qcint8 = 6, - /// Quantized 32-bit signed integer with shared per-channel quantization parameters. - xnn_datatype_qcint32 = 7, - /// Quantized 4-bit signed integer with shared per-channel quantization parameters. - xnn_datatype_qcint4 = 8, - /// Dynamically quantized 8-bit signed integer with per-batch quantization parameters. - xnn_datatype_qdint8 = 9, - /// Quantized 4-bit signed integer with shared blockwise quantization parameters. - xnn_datatype_qbint4 = 10, -} - -// type of quantization -union XNNQuantParams { - PerChannelQuant, - PerTensorQuant, - PerTokenDynamicQuant, - PerChannelGroupQuant, -} - -// taken from executorch -// Data buffer abstraction. -table Buffer { - storage:[ubyte] (force_align: 16); -} - -table PerChannelQuant { - scale:[float]; - channel_dim:int; -} - -table PerTokenDynamicQuant { - num_nonbatch_dims:int; -} - -table PerTensorQuant { - scale:float; - zero_point:int; -} - -table PerChannelGroupQuant { - scale:[float]; - channel_dim:int; - group_size:int; -} - -table XNNTensorValue { - // type of the tensor elements. - datatype:XNNDatatype; - // number of dimensions in the shape. - num_dims:uint; - // pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. - // XNNPACK does not keep any pointers to this array after the function returns. - dims:[uint]; - // Index to the program's constant buffer table, value 0 is reserved to indicate non constant - constant_buffer_idx:uint; - // external ID for the Value. The ID must be within the range of reserved Value IDs specified on - // the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be - // created for the Value. - external_id:uint; - // binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT - // and XNN_VALUE_FLAG_EXTERNAL_OUTPUT. - flags:uint; - // pointer to the variable that will be initialized with the Value ID upon successful return. If a - // valid @a external_id was provided, the variable will be initialized with the @a external_id value. - id_out:uint; - // does this value need to be quantized dynamically at runtime? 
- // if we are quantizing at runtime, this field points to a target dtype - dq_datatype:XNNDatatype = xnn_datatype_invalid; -} - -table XNNQuantizedTensorValue { - // Base Tensor Value - tensor_value:XNNTensorValue; - // Quantization parameters - quant_params:XNNQuantParams; -} - -union XNodeUnion { - XNNAdd: _XNNNode2x1, - XNNFullyConnected, - XNNSoftmax: _XNNNode1x1, - XNNSigmoid: _XNNNode1x1, - XNNStaticTranspose, - XNNClamp: _XNNNode1x1, - XNNConv2d: _XNNNodeConv, - XNNDiv: _XNNNode2x1, - XNNStaticResizeBilinear2D, - XNNStaticConstantPad, - XNNAvgPooling2d: _XNNPooling2D, - XNNMinimum: _XNNNode2x1, - XNNDepthwiseConv2d: _XNNNodeConv, - XNNMaxPooling2d: _XNNPooling2D, - XNNMultiply: _XNNNode2x1, - XNNSubtract: _XNNNode2x1, - XNNFloor: _XNNNode1x1, - XNNConvert: _XNNNode1x1, - XNNGlobalAvgPooling2d: _XNNNode1x1, - XNNStaticReshape, - XNNArgMaxPooling2d, - XNNSquareRoot: _XNNNode1x1, - XNNCeiling: _XNNNode1x1, - XNNHardswish: _XNNNode1x1, - XNNLeakyReLU, - XNNMaximum: _XNNNode2x1, - XNNNegate: _XNNNode1x1, - XNNSquare: _XNNNode1x1, - XNNELU, - XNNAbs: _XNNNode1x1, - XNNPReLU: _XNNNode2x1, - XNNConcatenate2: _XNNCat, - XNNConcatenate3: _XNNCat, - XNNConcatenate4: _XNNCat, - XNNStaticSlice, - XNNScaledDotProductAttention, -} - -union XValueUnion { - XNNTensorValue, - XNNQuantizedTensorValue, -} - -table OutputMinMax { - output_min:float; - output_max:float; -} - -table XNode { - xnode_union:XNodeUnion; - // An int which can be linked back to the node in the origin graph - debug_handle:uint; - output_min_max:OutputMinMax; -} - -table XValue { - xvalue_union:XValueUnion; -} - -table XNNStaticTranspose { - num_dims:uint; - perm:[uint]; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNStaticResizeBilinear2D { - new_height:uint; - new_width:uint; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNStaticConstantPad { - pre_paddings:[uint]; - post_paddings:[uint]; - padding_value:float; - input_id:uint; - output_id:uint; - flags:uint; -} - -// A node with two input and one output -// Not meant to be used directly -table _XNNNode2x1 { - input1_id:uint; - input2_id:uint; - output_id:uint; - flags:uint; -} - -// A node with one input and one output -// Not meant to be used directly -table _XNNNode1x1 { - input_id:uint; - output_id:uint; - flags:uint; -} - -table _XNNCat { - axis: uint; - input1_id: uint; - input2_id: uint; - input3_id: uint; - input4_id: uint; - output_id: uint; - flags: uint; -} - -table XNNELU { - alpha:float; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNFullyConnected { - input1_id:uint; - filter_id:uint; - bias_id:uint; - output_id:uint; - flags:uint; -} - -table _XNNNodeConv { - padding_top:uint; - padding_right:uint; - padding_bottom:uint; - padding_left:uint; - kernel_height:uint; - kernel_width:uint; - subsampling_height:uint; - subsampling_width:uint; - dilation_height:uint; - dilation_width:uint; - group_input_channels:uint; - group_output_channels:uint; - groups:uint; - adjustment_height:uint; - adjustment_width:uint; - input1_id:uint; - filter_id:uint; - bias_id:uint; - output_id:uint; - flags:uint; -} - -table _XNNPooling2D { - padding_top: uint; - padding_right: uint; - padding_bottom: uint; - padding_left: uint; - pooling_height: uint; - pooling_width: uint; - stride_height: uint; - stride_width: uint; - dilation_height: uint; - dilation_width: uint; - input_id: uint; - output_id: uint; - flags: uint; -} - -table XNNStaticReshape { - num_dims:uint; - new_shape:[uint]; - input_id: uint; - output_id: uint; - flags: uint; -} 
- -table XNNStaticSlice { - num_dims:uint; - offsets:[uint]; - sizes:[uint]; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNScaledDotProductAttention { - query_id:uint; - key_id:uint; - value_id:uint; - scale_id:uint; - mask_id:uint; - output_id:uint; - flags:uint; -} - -table XNNArgMaxPooling2d { - padding_top: uint; - padding_right: uint; - padding_bottom: uint; - padding_left: uint; - pooling_height: uint; - pooling_width: uint; - input_id: uint; - output_value_id: uint; - output_index_id: uint; - flags: uint; -} - -table XNNLeakyReLU { - negative_slope: float; - input_id: uint; - output_id: uint; - flags: uint; -} - -// Describes data offsets for constant data -table ConstantDataOffset { - // Constant data offsets are relative to the constant data base offset provided - // in the XNNPACKHeader. - offset: uint64; - - // The size in bytes of valid data starting at the offset. The constant data - // may be followed by padding before the next piece of constant data - size: uint64; -} - -table XNNGraph { - // Schema version. - version:string; - xnodes:[XNode]; - xvalues:[XValue]; - - // Number of external inputs/outputs - num_externs:uint; - - // Ids of external inputs - input_ids:[uint]; - - // Ids of external outputs - output_ids:[uint]; - - // Tables of constant data, used for constant Values (e.g. - // data field of weight tensors). Each constant is assigned an index into the table - // which are each individually aligned. 0 index is reserved to be pointed to by non-constant - // Tensors. Exactly one of constant_buffer and constant_data must be non-empty - constant_buffer:[Buffer]; - - // the list index is memory buffer id, the value is the memory buffer size. - mem_buffer_sizes: [uint]; - - // List of the constant data that follows the XNNGraph in this file. Each constant data is assigned an index into - // the table. 0 index is reserved to be pointed to by non-constant Tensor. Exactly one of constant_buffer and - // constant_data must be non-empty - constant_data:[ConstantDataOffset]; -} - -root_type XNNGraph; diff --git a/backends/xnnpack/serialization/schema_version_history.txt b/backends/xnnpack/serialization/schema_version_history.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backends/xnnpack/serialization/targets.bzl b/backends/xnnpack/serialization/targets.bzl index 5eeab3de2b..05e6b89f12 100644 --- a/backends/xnnpack/serialization/targets.bzl +++ b/backends/xnnpack/serialization/targets.bzl @@ -4,14 +4,14 @@ def define_common_targets(): runtime.genrule( name = "gen_xnnpack_schema", srcs = [ - "runtime_schema.fbs", + "schema.fbs", ], # We're only generating a single file, so it seems like we could use # `out`, but `flatc` takes a directory as a parameter, not a single # file. Use `outs` so that `${OUT}` is expanded as the containing # directory instead of the file itself. 
outs = { - "schema_generated.h": ["runtime_schema_generated.h"], + "schema_generated.h": ["schema_generated.h"], }, cmd = " ".join([ "$(exe {})".format(runtime.external_dep_location("flatc")), diff --git a/backends/xnnpack/test/ops/abs.py b/backends/xnnpack/test/ops/abs.py index 2906654dfb..fba91db05c 100644 --- a/backends/xnnpack/test/ops/abs.py +++ b/backends/xnnpack/test/ops/abs.py @@ -24,9 +24,7 @@ def _test_abs(self, inputs): Tester(self.Abs(), inputs) .export() .check_count({"torch.ops.aten.abs.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_abs_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_abs_default"]) .to_executorch() diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py index f8c202bd7f..784a9d3bbf 100644 --- a/backends/xnnpack/test/ops/add.py +++ b/backends/xnnpack/test/ops/add.py @@ -47,9 +47,7 @@ def _test_add(self, inputs): Tester(self.Add(), inputs) .export() .check_count({"torch.ops.aten.add.Tensor": 4}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() @@ -71,9 +69,7 @@ def test_fp32_add_constant(self): Tester(self.AddConstant(torch.randn(4, 4, 4)), inputs) .export() .check_count({"torch.ops.aten.add.Tensor": 4}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() @@ -88,9 +84,7 @@ def test_qs8_add_constant(self): .quantize() .export() .check_count({"torch.ops.aten.add.Tensor": 4}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .to_executorch() @@ -106,9 +100,7 @@ def test_qs8_add(self): .export() .check_count({"torch.ops.aten.add.Tensor": 4}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -129,9 +121,7 @@ def test_qs8_add2(self): .export() .check_count({"torch.ops.aten.add.Tensor": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -152,9 +142,7 @@ def test_qs8_add3(self): .export() .check_count({"torch.ops.aten.add.Tensor": 4}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -179,10 +167,7 @@ def test_fp32_add_relu(self): .export() .check_count({"torch.ops.aten.add.Tensor": 1}) .check_count({"torch.ops.aten.relu.default": 1}) - .to_edge() - 
.check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}) - .check_count({"executorch_exir_dialects_edge__ops_aten_relu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_not(["executorch_exir_dialects_edge__ops_aten_add_Tensor"]) .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -200,10 +185,7 @@ def test_qs8_add_relu(self): .check_count({"torch.ops.aten.add.Tensor": 1}) .check_count({"torch.ops.aten.relu.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}) - .check_count({"executorch_exir_dialects_edge__ops_aten_relu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() @@ -246,10 +228,7 @@ def forward(self, x, z): {"torch.ops.aten.add.Tensor": 1, "torch.ops.aten.relu.default": 1} ) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}) - .check_count({"executorch_exir_dialects_edge__ops_aten_relu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() diff --git a/backends/xnnpack/test/ops/avgpool2d.py b/backends/xnnpack/test/ops/avgpool2d.py index edb92d09a3..b471fd914c 100644 --- a/backends/xnnpack/test/ops/avgpool2d.py +++ b/backends/xnnpack/test/ops/avgpool2d.py @@ -33,11 +33,7 @@ def _test_argpool2d(self, inputs): Tester(self.AvgPool2d(), inputs) .export() .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) .to_executorch() @@ -62,11 +58,7 @@ def test_fp32_avgpool2d_ceil_mode_unsupported(self): Tester(self.AvgPool2d(ceil_mode=True), inputs) .export() .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_not(["torch.ops.higher_order.executorch_call_delegate"]) ) @@ -79,11 +71,7 @@ def test_fp32_avgpool2d_count_include_pad_unsupported(self): Tester(self.AvgPool2d(count_include_pad=True), inputs) .export() .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_not(["torch.ops.higher_order.executorch_call_delegate"]) ) @@ -96,10 +84,6 @@ def test_fp32_avgpool2d_divisor_override(self): Tester(self.AvgPool2d(divisor_override=5), inputs) .export() .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_not(["torch.ops.higher_order.executorch_call_delegate"]) ) diff --git a/backends/xnnpack/test/ops/bilinear2d.py b/backends/xnnpack/test/ops/bilinear2d.py index ab9d3d3c11..d3c8535069 100644 --- a/backends/xnnpack/test/ops/bilinear2d.py +++ b/backends/xnnpack/test/ops/bilinear2d.py @@ -65,12 +65,15 @@ def forward(self, x): ) return a + # Since we may 
or may not enable dim order, use these ops only for + # check_not since we have `to_copy` and `to_dim_order_copy` in the list. ops = { "executorch_exir_dialects_edge__ops_aten_sub_Tensor", "executorch_exir_dialects_edge__ops_aten_mul_Tensor", "executorch_exir_dialects_edge__ops_aten_index_Tensor", "executorch_exir_dialects_edge__ops_aten_arange_start_step", "executorch_exir_dialects_edge__ops_aten__to_copy_default", + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default", "executorch_exir_dialects_edge__ops_aten_add_Tensor", "executorch_exir_dialects_edge__ops_aten_clamp_default", } @@ -81,7 +84,6 @@ def test_fp32_static_resize_bilinear2d(self): Tester(self.StaticResizeBilinear2dModule(), example_inputs) .export() .to_edge() - .check(self.ops) .partition() .check_not(self.ops) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -90,13 +92,12 @@ def test_fp32_static_resize_bilinear2d(self): .run_method_and_compare_outputs() ) - def test_fp32_static_resize_bilinear2d_with_align_cornesr(self): + def test_fp32_static_resize_bilinear2d_with_align_corners(self): example_inputs = (torch.randn(2, 3, 4, 5),) ( Tester(self.StaticResizeBilinear2dModuleWithAlignCorners(), example_inputs) .export() .to_edge() - .check(self.ops) .partition() .check_not(self.ops) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) diff --git a/backends/xnnpack/test/ops/cat.py b/backends/xnnpack/test/ops/cat.py index 15524c0134..23fca91f5b 100644 --- a/backends/xnnpack/test/ops/cat.py +++ b/backends/xnnpack/test/ops/cat.py @@ -56,11 +56,7 @@ def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2): } ) - ( - tester.to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1}) - .partition() - ) + tester.to_edge_transform_and_lower() if quant: tester.check_not(["torch.ops.quantized_decomposed"]) @@ -155,9 +151,7 @@ def test_fp32_cat_unsupported(self): Tester(self.Cat5(), inputs) .export() .check_count({"torch.ops.aten.cat": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1}) ) diff --git a/backends/xnnpack/test/ops/ceil.py b/backends/xnnpack/test/ops/ceil.py index 8d59f3b35d..6dbebf3650 100644 --- a/backends/xnnpack/test/ops/ceil.py +++ b/backends/xnnpack/test/ops/ceil.py @@ -24,9 +24,7 @@ def _test_ceil(self, inputs): Tester(self.Ceil(), inputs) .export() .check_count({"torch.ops.aten.ceil.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_ceil_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_ceil_default"]) .to_executorch() diff --git a/backends/xnnpack/test/ops/clamp.py b/backends/xnnpack/test/ops/clamp.py index c52fd011f8..9fb8935553 100644 --- a/backends/xnnpack/test/ops/clamp.py +++ b/backends/xnnpack/test/ops/clamp.py @@ -26,9 +26,7 @@ def _test_clamp(self, module, inputs): Tester(module, inputs) .export() .check_count({"torch.ops.aten.clamp.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_clamp_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"]) .to_executorch() @@ -64,9 +62,7 @@ def test_qs8_clamp(self): .export() 
.check_count({"torch.ops.aten.clamp.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_clamp_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/conv1d.py b/backends/xnnpack/test/ops/conv1d.py index 1759b1452d..ae4ca6884f 100644 --- a/backends/xnnpack/test/ops/conv1d.py +++ b/backends/xnnpack/test/ops/conv1d.py @@ -7,13 +7,16 @@ import unittest import torch -from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackFloatingPointPartitioner, +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, +) +from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import ( + XnnpackPartitioner, ) from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn from executorch.backends.xnnpack.test.tester import RunPasses, Tester -from executorch.backends.xnnpack.test.tester.tester import Partition +from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower from executorch.exir.passes.constant_prop_pass import constant_prop_pass @@ -93,8 +96,8 @@ def _test_conv1d( conv_count, quantized=False, dynamic_shape=None, - partition=None, passes=None, + stage=None, skip_to_executorch=False, ): tester = ( @@ -104,15 +107,9 @@ def _test_conv1d( else Tester(module, inputs) ) .export() - .check_count({"torch.ops.aten.conv1d.default": conv_count}) - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_aten_convolution_default": conv_count - } - ) .run_passes(passes) - .partition(partition) + .check_count({"torch.ops.aten.conv1d.default": conv_count}) + .to_edge_transform_and_lower(stage) .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) ) @@ -170,7 +167,11 @@ def test_qs8_conv1d_with_floating_point_partitioner(self): 1, quantized=True, dynamic_shape=dynamic_shapes, - partition=Partition(XnnpackFloatingPointPartitioner()), + stage=ToEdgeTransformAndLower( + partitioners=[ + XnnpackPartitioner(config_precisions=ConfigPrecisionType.FP32) + ] + ), passes=RunPasses(pass_functions=[constant_prop_pass]), skip_to_executorch=True, ) diff --git a/backends/xnnpack/test/ops/conv2d.py b/backends/xnnpack/test/ops/conv2d.py index 4a281e2265..95b22bb3f8 100644 --- a/backends/xnnpack/test/ops/conv2d.py +++ b/backends/xnnpack/test/ops/conv2d.py @@ -164,14 +164,13 @@ def _test( ( tester.export() .check_count({"torch.ops.aten.conv2d": conv_count}) - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_aten_convolution_default": conv_count - } - ) - .partition() + .to_edge_transform_and_lower() .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) + .check_not( + [ + "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" + ] + ) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() .serialize() diff --git a/backends/xnnpack/test/ops/div.py b/backends/xnnpack/test/ops/div.py index 3815b2f084..9bca5feed4 100644 --- a/backends/xnnpack/test/ops/div.py +++ b/backends/xnnpack/test/ops/div.py @@ -32,9 +32,7 @@ def _test_div(self, inputs): Tester(self.Div(), inputs) .export() .check_count({"torch.ops.aten.div.Tensor": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1}) - .partition() + 
.to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) .to_executorch() @@ -62,9 +60,7 @@ def test_fp32_div_single_input(self): Tester(self.DivSingleInput(), inputs) .export() .check_count({"torch.ops.aten.div.Tensor": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_div_Tensor": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) .to_executorch() diff --git a/backends/xnnpack/test/ops/elu.py b/backends/xnnpack/test/ops/elu.py index 03bdfc508d..f976c29d79 100644 --- a/backends/xnnpack/test/ops/elu.py +++ b/backends/xnnpack/test/ops/elu.py @@ -28,9 +28,7 @@ def _test_elu(self, inputs): Tester(self.ELU(), inputs) .export() .check_count({"torch.ops.aten.elu.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_elu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -42,17 +40,17 @@ def _test_elu(self, inputs): .run_method_and_compare_outputs() ) - @unittest.skip("T171810227 - Missing recomposition for ELU") + @unittest.skip("PyTorch Pin Update Required") def _test_fp16_elu(self): inputs = (torch.randn(1, 3, 3).to(torch.float16),) self._test_elu(inputs) - @unittest.skip("T171810227 - Missing recomposition for ELU") + @unittest.skip("PyTorch Pin Update Required") def _test_fp32_elu(self): inputs = (torch.randn(1, 3, 3),) self._test_elu(inputs) - @unittest.skip("T171810227 - Missing recomposition for ELU") + @unittest.skip("Update Quantizer to quantize Elu") def _test_qs8_elu(self): inputs = (torch.randn(1, 3, 4, 4),) ( @@ -61,9 +59,7 @@ def _test_qs8_elu(self): .export() .check_count({"torch.ops.aten.elu.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_elu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -76,7 +72,7 @@ def _test_qs8_elu(self): .run_method_and_compare_outputs() ) - @unittest.skip("T171810227 - Missing recomposition for ELU") + @unittest.skip("Update Quantizer to quantize Elu") def _test_qs8_elu_functional(self): inputs = (torch.randn(1, 3, 4, 4),) ( @@ -85,9 +81,7 @@ def _test_qs8_elu_functional(self): .export() .check_count({"torch.ops.aten.elu.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_elu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/hardtanh.py b/backends/xnnpack/test/ops/hardtanh.py index d13624663c..e35e840e3c 100644 --- a/backends/xnnpack/test/ops/hardtanh.py +++ b/backends/xnnpack/test/ops/hardtanh.py @@ -29,11 +29,7 @@ def test_fp32_hardtanh(self): Tester(self.HardTanh(), (input,)) .export() .check_count({"torch.ops.aten.hardtanh.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) .to_executorch() @@ -48,11 +44,7 @@ def 
test_fp32_hardtanh_bound(self): Tester(self.HardTanh(-2.0, 2.0), (input,)) .export() .check_count({"torch.ops.aten.hardtanh.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"]) .to_executorch() @@ -74,11 +66,7 @@ def test_qs8_hardtanh(self): torch.ops.aten.hardtanh.default: 1, } ) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py index 2ce1c2d3c3..52f1334a6b 100644 --- a/backends/xnnpack/test/ops/linear.py +++ b/backends/xnnpack/test/ops/linear.py @@ -10,12 +10,21 @@ from typing import Optional, Tuple import torch +from executorch.backends.xnnpack.partition.config.xnnpack_config import ( + ConfigPrecisionType, +) from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackDynamicallyQuantizedPartitioner, ) +from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import ( + XnnpackPartitioner, +) from executorch.backends.xnnpack.test.tester import Quantize, Tester -from executorch.backends.xnnpack.test.tester.tester import Partition +from executorch.backends.xnnpack.test.tester.tester import ( + Partition, + ToEdgeTransformAndLower, +) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, @@ -68,11 +77,11 @@ def test_fp32_addmm(self): class AddMMModule(torch.nn.Module): def __init__(self, in_size, out_size): super().__init__() - self.mat = torch.nn.Parameter(torch.randn(out_size, in_size)) + self.mat = torch.nn.Parameter(torch.randn(in_size, out_size)) self.bias = torch.nn.Parameter(torch.randn(1, out_size)) def forward(self, x): - return torch.addmm(self.bias, x, torch.transpose(self.mat, 0, 1)) + return torch.addmm(self.bias, x, self.mat) self._test_linear( lambda in_size, out_size: AddMMModule(in_size, out_size), @@ -358,6 +367,223 @@ def forward(self, x, y): atol=1e-1, ) + def test_qd8_fp32_per_token_weight_per_channel_int8(self): + self._run_manual_dqlinear_tests(8, torch.float) + + def test_qd8_fp32_per_token_weight_per_channel_int4(self): + self._run_manual_dqlinear_tests(4, torch.float) + + # This fails because the output tensor dtype is different, but if you squint and ignore that and look at the values, + # it is not too bad. + # Difference: max: 0.042601585388183594, abs: 0.042601585388183594. + # -- Model vs. 
Reference -- + # Numel: 68, 68 + # Median: -0.7754800915718079, -0.7755751013755798 + # Mean: -0.6128872036933899, -0.6143574714660645 + # Max: 12.518657684326172, 12.516003608703613 + # Min: -20.070953369140625, -20.077701568603516 + @unittest.skip("Need to fix the dq_per_channel output dtype") + def _test_qd8_fp16_per_token_weight_per_channel_int8(self): + self._run_manual_dqlinear_tests(8, torch.float16) + + @unittest.skip("Need to fix the dq_per_channel output dtype") + def _test_qd8_fp16_per_token_weight_per_channel_int4(self): + self._run_manual_dqlinear_tests(4, torch.float16) + + def test_qd8_fp32_per_token_weight_per_channel_group_int4(self): + M_sizes = [1, 2, 17, 31] + K_sizes = [8, 32, 64, 128] + bl_sizes = [8, 16, 16, 32] + N_sizes = [2, 17, 92, 128] + + for use_bias in [True, False]: + for i, _ in enumerate(M_sizes): + M = int(M_sizes[i]) + K = int(K_sizes[i]) + N = int(N_sizes[i]) + bl = int(bl_sizes[i]) + mod = self.ManualDQLinear( + input_channels=K, + output_channels=N, + weight_n_bit=4, + dtype=torch.float, + group_size=bl, + force_groupwise_quant=True, + use_bias=use_bias, + ) + + inputs = (torch.randn(1, M, K),) + self._test_manual_dq_linear( + mod, + inputs, + weight_groupwise=True, + use_bias=use_bias, + ) + + @unittest.skip("Need to fix the dq_per_channel_group output dtype") + def _test_qd8_fp16_per_token_weight_per_channel_group_int4(self): + M_sizes = [1, 2, 17, 31] + K_sizes = [8, 32, 64, 128] + bl_sizes = [8, 16, 16, 32] + N_sizes = [2, 17, 92, 128] + + for use_bias in [True, False]: + for i, _ in enumerate(M_sizes): + M = int(M_sizes[i]) + K = int(K_sizes[i]) + N = int(N_sizes[i]) + bl = int(bl_sizes[i]) + mod = self.ManualDQLinear( + input_channels=K, + output_channels=N, + weight_n_bit=4, + dtype=torch.float16, + group_size=bl, + force_groupwise_quant=True, + use_bias=use_bias, + ) + + inputs = (torch.randn(1, M, K, dtype=torch.float16),) + self._test_manual_dq_linear( + mod, + inputs, + weight_groupwise=True, + use_bias=use_bias, + atol=0.1, + rtol=0.1, + ) + + def _test_linear( + self, + make_module, + uses_bias, + num_batch_dims=1, + quant_type=None, + dtype: torch.dtype = torch.float, + atol=1e-03, + ): + edge_op = ( + "executorch_exir_dialects_edge__ops_aten_addmm_default" + if uses_bias + else "executorch_exir_dialects_edge__ops_aten_mm_default" + ) + + in_sizes = [3, 4, 4] + input_sizes = [4, 37, 17] + output_sizes = [4, 17, 37] + + quant = quant_type is not None + + """ + Note that torch.nn.Linear maps to aten.mm.default (no bias) or aten.addmm.default (bias), + which ares then transformed into aten.linear.default by the ConvertToLinear pass. 
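The docstring above notes that torch.nn.Linear shows up as aten.mm.default (no bias) or aten.addmm.default (bias) before the ConvertToLinear pass recomposes it into aten.linear.default. A standalone way to inspect which aten ops a given export actually contains (not part of the patch; the exact op set depends on the PyTorch version and decomposition settings):

import torch
from torch.export import export


class Tiny(torch.nn.Module):
    def __init__(self, bias: bool):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2, bias=bias)

    def forward(self, x):
        return self.fc(x)


for bias in (True, False):
    ep = export(Tiny(bias).eval(), (torch.randn(3, 4),))
    ops = sorted({str(n.target) for n in ep.graph.nodes if n.op == "call_function"})
    print(f"bias={bias}: {ops}")  # look for the linear / addmm / mm variants here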
+ """ + for i, _ in enumerate(in_sizes): + in_size = int(in_sizes[i]) + input_size = int(input_sizes[i]) + output_size = int(output_sizes[i]) + input_shape = [in_size] * num_batch_dims + [input_size] + print(f"Testing input_shape {input_shape} with {output_size} out_channels") + + module = make_module(input_size, output_size).eval().to(dtype) + inputs = (torch.randn(input_shape).to(dtype),) + dynamic_shape = {} + for i in range(num_batch_dims): + dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size) + + dynamic_shape = (dynamic_shape,) + print(dynamic_shape) + + tester = Tester(module, inputs, dynamic_shapes=dynamic_shape) + + if quant: + if quant_type == "per_channel": + quant_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + elif quant_type == "per_tensor": + quant_config = get_symmetric_quantization_config( + is_per_channel=False, + is_dynamic=False, + ) + else: + raise ValueError(f"Unsupported quant type {quant_type}") + tester.quantize(Quantize(quantization_config=quant_config)) + + tester.export() + if quant: + tester.check(["torch.ops.quantized_decomposed"]) + + tester.to_edge_transform_and_lower() + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + tester.check_not([edge_op]) + + if quant: + tester.check_not([edge_op, "torch.ops.quantized_decomposed"]) + + tester.to_executorch() + tester.serialize() + tester.run_method_and_compare_outputs(qtol=quant, atol=atol) + + def _test_dqlinear( + self, + module, + inputs, + dynamic_shapes, + linear_count=1, + is_per_channel=False, + uses_bias=False, + qconfig: Optional[QuantizationConfig] = None, + atol=5e-02, + ): + edge_op = ( + "executorch_exir_dialects_edge__ops_aten_addmm_default" + if uses_bias + else "executorch_exir_dialects_edge__ops_aten_mm_default" + ) + + quant_config = qconfig or get_symmetric_quantization_config( + is_per_channel=is_per_channel, + is_dynamic=True, + ) + for legacy_partitioner in (True, False): + for per_op_mode in (True, False): + tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes) + tester.quantize(Quantize(quantization_config=quant_config)) + + tester.export() + + if legacy_partitioner: + tester.to_edge() + tester.partition( + Partition(XnnpackDynamicallyQuantizedPartitioner()) + ) + else: + tester.to_edge_transform_and_lower( + ToEdgeTransformAndLower( + [ + XnnpackPartitioner( + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, + per_op_mode=per_op_mode, + ) + ] + ) + ) + num_call_delegates = ( + linear_count if legacy_partitioner or per_op_mode else 1 + ) + tester.check_count( + { + "torch.ops.higher_order.executorch_call_delegate": num_call_delegates + } + ) + tester.check_not([edge_op]) + + tester.to_executorch() + tester.serialize() + tester.run_method_and_compare_outputs(atol=atol) + class ManualDQLinear(torch.nn.Module): def __init__( self, @@ -507,7 +733,6 @@ def group_quantize_tensor_symmetric( def fwd_input_per_token(self, input: torch.Tensor) -> torch.Tensor: ip_quant_min = -128 ip_quant_max = 127 - input = input.to(self.op_dtype) ( ip_scales, ip_zero_points, @@ -532,7 +757,6 @@ def fwd_input_per_token(self, input: torch.Tensor) -> torch.Tensor: torch.int8, self.op_dtype, ) - input = input.to(self.op_dtype) return input def quant_weight_per_channel(self): @@ -596,14 +820,14 @@ def fwd_weight_per_channel_group(self) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor: # Input - input = self.fwd_input_per_token(input).to(self.op_dtype) + input = self.fwd_input_per_token(input) # 
Weights w = ( self.fwd_weight_per_channel_group() if self.w_scales.ndim == 2 else self.fwd_weight_per_channel() - ).to(self.op_dtype) + ) assert isinstance(w, torch.Tensor) return torch.nn.functional.linear(input, w, self.bias) @@ -625,41 +849,65 @@ def _test_manual_dq_linear( weight_dq_edge_op = ( "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_group_default" if weight_groupwise - else "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default" + else "torch.ops.quantized_decomposed.dequantize_per_channel.default" ) - ( - Tester(mod, inputs) - .export() - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_per_token_asymmetric_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_token_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_token_default": 1, - weight_dq_edge_op: 1, - linear_edge_op: 1, - } - ) - .partition(Partition(partitioner=XnnpackDynamicallyQuantizedPartitioner())) - .check_count( - { - "torch.ops.higher_order.executorch_call_delegate": 1, - } + weight_dq_aten_op = ( + "torch.ops.quantized_decomposed.dequantize_per_channel_group.default" + if weight_groupwise + else "torch.ops.quantized_decomposed.dequantize_per_channel.default" + ) + for legacy_partitioner in (True, False): + tester = ( + Tester(mod, inputs) + .export() + .check_count( + { + "torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default": 1, + "torch.ops.quantized_decomposed.quantize_per_token.default": 1, + "torch.ops.quantized_decomposed.dequantize_per_token.default": 1, + weight_dq_aten_op: 1, + "torch.ops.aten.linear.default": 1, + } + ) ) - .check_not( - [ - "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_per_token_asymmetric_default", - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_token_default", - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_token_default", - weight_dq_edge_op, - linear_edge_op, - ] + + if legacy_partitioner: + tester.to_edge() + tester.partition(Partition(XnnpackDynamicallyQuantizedPartitioner())) + else: + ( + tester.to_edge_transform_and_lower( + ToEdgeTransformAndLower( + [ + XnnpackPartitioner( + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, + per_op_mode=True, + ) + ] + ) + ) + ) + + ( + tester.check_count( + { + "torch.ops.higher_order.executorch_call_delegate": 1, + } + ) + .check_not( + [ + "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_per_token_asymmetric_default", + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_token_default", + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_token_default", + weight_dq_edge_op, + linear_edge_op, + ] + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(atol=atol, rtol=rtol) ) - .to_executorch() - .serialize() - .run_method_and_compare_outputs(atol=atol, rtol=rtol) - ) def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype): in_sizes = [1, 4, 4] @@ -681,215 +929,3 @@ def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype): inputs = (torch.randn(1, in_size, input_size).to(op_dtype),) self._test_manual_dq_linear(mod, inputs, use_bias=use_bias) - - def test_qd8_fp32_per_token_weight_per_channel_int8(self): - self._run_manual_dqlinear_tests(8, torch.float) - - def test_qd8_fp32_per_token_weight_per_channel_int4(self): - 
self._run_manual_dqlinear_tests(4, torch.float) - - # This fails because the output tensor dtype is different, but if you squint and ignore that and look at the values, - # it is not too bad. - # Difference: max: 0.042601585388183594, abs: 0.042601585388183594. - # -- Model vs. Reference -- - # Numel: 68, 68 - # Median: -0.7754800915718079, -0.7755751013755798 - # Mean: -0.6128872036933899, -0.6143574714660645 - # Max: 12.518657684326172, 12.516003608703613 - # Min: -20.070953369140625, -20.077701568603516 - @unittest.skip("Need to fix the dq_per_channel output dtype") - def _test_qd8_fp16_per_token_weight_per_channel_int8(self): - self._run_manual_dqlinear_tests(8, torch.float16) - - @unittest.skip("Need to fix the dq_per_channel output dtype") - def _test_qd8_fp16_per_token_weight_per_channel_int4(self): - self._run_manual_dqlinear_tests(4, torch.float16) - - def test_qd8_fp32_per_token_weight_per_channel_group_int4(self): - M_sizes = [1, 2, 17, 31] - K_sizes = [8, 32, 64, 128] - bl_sizes = [8, 16, 16, 32] - N_sizes = [2, 17, 92, 128] - - for use_bias in [True, False]: - for i, _ in enumerate(M_sizes): - M = int(M_sizes[i]) - K = int(K_sizes[i]) - N = int(N_sizes[i]) - bl = int(bl_sizes[i]) - mod = self.ManualDQLinear( - input_channels=K, - output_channels=N, - weight_n_bit=4, - dtype=torch.float, - group_size=bl, - force_groupwise_quant=True, - use_bias=use_bias, - ) - - inputs = (torch.randn(1, M, K),) - self._test_manual_dq_linear( - mod, - inputs, - weight_groupwise=True, - use_bias=use_bias, - ) - - def test_qd8_fp16_per_token_weight_per_channel_group_int4(self): - M_sizes = [1, 2, 17, 31] - K_sizes = [8, 32, 64, 128] - bl_sizes = [8, 16, 16, 32] - N_sizes = [2, 17, 92, 128] - - for use_bias in [True, False]: - for i, _ in enumerate(M_sizes): - M = int(M_sizes[i]) - K = int(K_sizes[i]) - N = int(N_sizes[i]) - bl = int(bl_sizes[i]) - mod = self.ManualDQLinear( - input_channels=K, - output_channels=N, - weight_n_bit=4, - dtype=torch.float16, - group_size=bl, - force_groupwise_quant=True, - use_bias=use_bias, - ) - - inputs = (torch.randn(1, M, K, dtype=torch.float16),) - self._test_manual_dq_linear( - mod, - inputs, - weight_groupwise=True, - use_bias=use_bias, - atol=0.1, - rtol=0.1, - ) - - def _test_linear( - self, - make_module, - uses_bias, - num_batch_dims=1, - quant_type=None, - dtype: torch.dtype = torch.float, - atol=1e-03, - ): - aten_op, edge_op = ( - ( - "aten.addmm.default", - "executorch_exir_dialects_edge__ops_aten_addmm_default", - ) - if uses_bias - else ( - "aten.mm.default", - "executorch_exir_dialects_edge__ops_aten_mm_default", - ) - ) - - in_sizes = [3, 4, 4] - input_sizes = [4, 37, 17] - output_sizes = [4, 17, 37] - - quant = quant_type is not None - - """ - Note that torch.nn.Linear maps to aten.mm.default (no bias) or aten.addmm.default (bias), - which ares then transformed into aten.linear.default by the ConvertToLinear pass. 
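The qd8 tests in this file pair dynamically quantized per-token activations with int8 or int4 weights quantized per output channel (optionally in groups). The weight half of that naming can be made concrete with a few lines of arithmetic; this is an illustrative sketch, not the quantizer these tests actually drive:

import torch

# Symmetric per-channel int8 weight quantization: one scale per output channel (dim 0).
w = torch.randn(8, 16)                                              # [out_channels, in_channels]
scales = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp_min(1e-8)  # per-channel scale
q = torch.clamp(torch.round(w / scales), -127, 127).to(torch.int8)
w_dq = q.to(torch.float32) * scales                                 # dequantized weight the kernel effectively sees
print(f"max abs error: {(w - w_dq).abs().max():.6f}")               # bounded by roughly scale / 2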
- """ - for i, _ in enumerate(in_sizes): - in_size = int(in_sizes[i]) - input_size = int(input_sizes[i]) - output_size = int(output_sizes[i]) - input_shape = [in_size] * num_batch_dims + [input_size] - print(f"Testing input_shape {input_shape} with {output_size} out_channels") - - module = make_module(input_size, output_size).eval().to(dtype) - inputs = (torch.randn(input_shape).to(dtype),) - dynamic_shape = {} - for i in range(num_batch_dims): - dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size) - - dynamic_shape = (dynamic_shape,) - print(dynamic_shape) - - tester = Tester(module, inputs, dynamic_shapes=dynamic_shape) - - if quant: - if quant_type == "per_channel": - quant_config = get_symmetric_quantization_config( - is_per_channel=True, - is_dynamic=False, - ) - elif quant_type == "per_tensor": - quant_config = get_symmetric_quantization_config( - is_per_channel=False, - is_dynamic=False, - ) - else: - raise ValueError(f"Unsupported quant type {quant_type}") - tester.quantize(Quantize(quantization_config=quant_config)) - - tester.export() - if quant: - tester.check(["torch.ops.quantized_decomposed"]) - - tester.to_edge() - tester.check_count({edge_op: 1}) - - tester.partition() - tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - tester.check_not([edge_op]) - - if quant: - tester.check_not([edge_op, "torch.ops.quantized_decomposed"]) - - tester.to_executorch() - tester.serialize() - tester.run_method_and_compare_outputs(qtol=quant, atol=atol) - - def _test_dqlinear( - self, - module, - inputs, - dynamic_shapes, - linear_count=1, - is_per_channel=False, - uses_bias=False, - qconfig: Optional[QuantizationConfig] = None, - atol=5e-02, - ): - aten_op, edge_op = ( - ( - "aten.addmm.default", - "executorch_exir_dialects_edge__ops_aten_addmm_default", - ) - if uses_bias - else ( - "aten.mm.default", - "executorch_exir_dialects_edge__ops_aten_mm_default", - ) - ) - - quant_config = qconfig or get_symmetric_quantization_config( - is_per_channel=is_per_channel, - is_dynamic=True, - ) - - tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes) - tester.quantize(Quantize(quantization_config=quant_config)) - - tester.export() - tester.to_edge() - tester.check_count({edge_op: linear_count}) - - tester.partition( - Partition(partitioner=XnnpackDynamicallyQuantizedPartitioner()) - ) - tester.check(["torch.ops.higher_order.executorch_call_delegate"]) - tester.check_not([edge_op]) - - tester.to_executorch() - tester.serialize() - tester.run_method_and_compare_outputs(atol=atol) diff --git a/backends/xnnpack/test/ops/max_dim.py b/backends/xnnpack/test/ops/max_dim.py index e16a4f8b15..c660a5a6d2 100644 --- a/backends/xnnpack/test/ops/max_dim.py +++ b/backends/xnnpack/test/ops/max_dim.py @@ -13,14 +13,12 @@ class TestMaxDim(unittest.TestCase): class Max(torch.nn.Module): def forward(self, x): - x = torch.add(x, x) max_values_1, max_indices_1 = torch.max(x, dim=2, keepdim=True) max_values_2, max_indices_2 = torch.max(x, dim=3, keepdim=True) return (max_values_1, max_indices_1, max_values_2, max_indices_2) class MaxNoIndices(torch.nn.Module): def forward(self, x): - x = torch.add(x, x) max_values_1, _ = torch.max(x, dim=2, keepdim=True) max_values_2, _ = torch.max(x, dim=3, keepdim=True) return (max_values_1, max_values_2) @@ -30,39 +28,36 @@ def _test_max_dim(self, inputs): Tester(self.Max(), inputs) .export() .check_count({"torch.ops.aten.max.dim": 2}) - .to_edge() + .to_edge_transform_and_lower() + 
.check_not(["torch.ops.higher_order.executorch_call_delegate"]) .check_count({"executorch_exir_dialects_edge__ops_aten_max_dim": 2}) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 2}) + ) + + def _test_max_dim_no_indicies(self, inputs): + ( + Tester(self.MaxNoIndices(), inputs) + .export() + .check_count({"torch.ops.aten.max.dim": 2}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"]) .to_executorch() .serialize() .run_method_and_compare_outputs() ) - @unittest.skip("T171468483 - Fails to partition due to index output dtype.") - def _test_fp16_max_dim(self): + def test_fp16_max_dim_with_indicies(self): inputs = (torch.randn(16, 3, 12, 12).to(torch.float16),) self._test_max_dim(inputs) - @unittest.skip("T171468483 - Fails to partition due to index output dtype.") - def _test_fp32_max_dim(self): + def test_fp32_max_dim_with_indices(self): inputs = (torch.randn(16, 3, 12, 12),) self._test_max_dim(inputs) - @unittest.skip("T171468483 - Fails to partition due to index output dtype.") - def _test_fp32_max_dim_no_indices(self): + def test_fp32_max_dim_no_indices(self): inputs = (torch.randn(16, 3, 12, 12),) - ( - Tester(self.MaxNoIndices(), inputs) - .export() - .check_count({"torch.ops.aten.max.dim": 2}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_max_dim": 2}) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 2}) - .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"]) - .to_executorch() - .serialize() - .run_method_and_compare_outputs() - ) + self._test_max_dim_no_indicies(inputs) + + def test_fp16_max_dim_no_indices(self): + inputs = (torch.randn(16, 3, 12, 12).to(torch.float16),) + self._test_max_dim_no_indicies(inputs) diff --git a/backends/xnnpack/test/ops/maximum.py b/backends/xnnpack/test/ops/maximum.py index feff02744d..30dfa5503a 100644 --- a/backends/xnnpack/test/ops/maximum.py +++ b/backends/xnnpack/test/ops/maximum.py @@ -23,9 +23,7 @@ def _test_maximum(self, inputs): Tester(self.Maximum(), inputs) .export() .check_count({"torch.ops.aten.maximum.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_maximum_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"]) .to_executorch() @@ -56,9 +54,7 @@ def test_fp32_maximum_broadcast(self): Tester(self.Maximum(), inputs) .export() .check_count({"torch.ops.aten.maximum.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_maximum_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"]) .to_executorch() diff --git a/backends/xnnpack/test/ops/maxpool2d.py b/backends/xnnpack/test/ops/maxpool2d.py index bbc76743b0..1031852176 100644 --- a/backends/xnnpack/test/ops/maxpool2d.py +++ b/backends/xnnpack/test/ops/maxpool2d.py @@ -55,14 +55,7 @@ def _test_maxpool2d(self, inputs): Tester(self.MaxPool2d(3, 1, 0, 1), inputs) .export() .check_count({"torch.ops.aten.max_pool2d.default": 1}) - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, - } - ) - .check(["getitem"]) - .partition() + .to_edge_transform_and_lower() 
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -91,13 +84,7 @@ def test_fp32_maxpool2d_unsupported(self): Tester(self.MaxPool2dUnsupported(), inputs) .export() .check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1}) - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 - } - ) - .partition() + .to_edge_transform_and_lower() # We expect it not be be delegated. .check_count( { @@ -115,13 +102,7 @@ def test_fp32_maxpool2d_unsupported_ceilmode(self): Tester(self.MaxPool2dUnsupportedCeilMode(), inputs) .export() .check_count({"torch.ops.aten.max_pool2d.default": 1}) - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 - } - ) - .partition() + .to_edge_transform_and_lower() # We expect it not be be delegated. .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) .check_count( @@ -153,13 +134,7 @@ def forward(self, x): .export() .check_count({"torch.ops.aten.max_pool2d.default": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 - } - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/multiply.py b/backends/xnnpack/test/ops/multiply.py index d151f58bd6..db50bc5dd4 100644 --- a/backends/xnnpack/test/ops/multiply.py +++ b/backends/xnnpack/test/ops/multiply.py @@ -36,9 +36,7 @@ def _test_mul(self, inputs): Tester(self.Mul(), inputs) .export() .check_count({"torch.ops.aten.mul.Tensor": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"]) .to_executorch() @@ -65,9 +63,7 @@ def test_qs8_mul(self): .export() .check_count({"torch.ops.aten.mul.Tensor": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -88,9 +84,7 @@ def test_qs8_mul2(self): .export() .check_count({"torch.ops.aten.mul.Tensor": 1}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -111,9 +105,7 @@ def test_qs8_mul_functional(self): .export() .check_count({"torch.ops.aten.mul.Tensor": 3}) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 3}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -139,9 +131,7 @@ def test_qs8_mul_relu(self): } ) .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/permute.py b/backends/xnnpack/test/ops/permute.py index 2c99537675..b348fc8af6 100644 --- 
a/backends/xnnpack/test/ops/permute.py +++ b/backends/xnnpack/test/ops/permute.py @@ -36,11 +36,7 @@ def _test_permute(self, inputs): Tester(self.Permute([0, 2, 3, 1]), inputs) .export() .check_count({"torch.ops.aten.permute.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"]) .to_executorch() @@ -62,11 +58,7 @@ def test_fp32_permute_copy(self): Tester(self.PermuteCopy([0, 2, 3, 1]), inputs) .export() .check_count({"torch.ops.aten.permute_copy.default": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_permute_copy_default"]) .to_executorch() @@ -86,11 +78,7 @@ def test_qs8_permute(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, } ) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -115,11 +103,7 @@ def test_qs8_permute_copy(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, } ) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/quantize_per_tensor.py b/backends/xnnpack/test/ops/quantize_per_tensor.py index f912428a8a..c211798753 100644 --- a/backends/xnnpack/test/ops/quantize_per_tensor.py +++ b/backends/xnnpack/test/ops/quantize_per_tensor.py @@ -24,13 +24,7 @@ def forward(self, x): ( Tester(Quant(), inputs) .export() - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1 - } - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ @@ -60,13 +54,7 @@ def forward(self, x): ( Tester(Dequant(), inputs) .export() - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1 - } - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not( [ diff --git a/backends/xnnpack/test/ops/relu.py b/backends/xnnpack/test/ops/relu.py index 3ab1c72b57..8672b1d3e4 100644 --- a/backends/xnnpack/test/ops/relu.py +++ b/backends/xnnpack/test/ops/relu.py @@ -26,9 +26,7 @@ def test_fp32_relu(self): Tester(self.Relu(), inputs) .export() .check_count({"torch.ops.aten.relu.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_relu_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"]) .to_executorch() diff --git a/backends/xnnpack/test/ops/sigmoid.py b/backends/xnnpack/test/ops/sigmoid.py index 3dde395922..a9acd4df6d 100644 --- a/backends/xnnpack/test/ops/sigmoid.py +++ b/backends/xnnpack/test/ops/sigmoid.py @@ -25,9 +25,7 @@ def _test_sigmoid(self, 
inputs): Tester(self.Sigmoid(), inputs) .export() .check_count({"torch.ops.aten.sigmoid.default": 1}) - .to_edge() - .check_count({"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1}) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"]) .to_executorch() diff --git a/backends/xnnpack/test/ops/softmax.py b/backends/xnnpack/test/ops/softmax.py index 697b6f9294..cc544a28a2 100644 --- a/backends/xnnpack/test/ops/softmax.py +++ b/backends/xnnpack/test/ops/softmax.py @@ -29,11 +29,7 @@ def _test_softmax(self, inputs): Tester(self.Softmax(dim), inputs) .export() .check_count({"torch.ops.aten.softmax": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten__softmax_default": 1} - ) - .partition() + .to_edge_transform_and_lower() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"]) .to_executorch() @@ -63,11 +59,7 @@ def test_fp32_softmax_unsupported(self): Tester(self.Softmax(dim), inputs) .export() .check_count({"torch.ops.aten.softmax": 1}) - .to_edge() - .check_count( - {"executorch_exir_dialects_edge__ops_aten__softmax_default": 1} - ) - .partition() + .to_edge_transform_and_lower() # Should not be delegated .check(["executorch_exir_dialects_edge__ops_aten__softmax_default"]) ) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index a93c20a6b1..bd3971523d 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -11,11 +11,10 @@ import sys from abc import ABC, abstractmethod from collections import Counter, OrderedDict -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.export._trace as export_trace -from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.backends.xnnpack.passes import XNNPACKPassManager from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config from executorch.exir import ( @@ -29,6 +28,7 @@ from executorch.exir.backend.partitioner import Partitioner from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.print_program import pretty_print, print_program +from executorch.exir.program._program import _to_edge_transform_and_lower logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -40,6 +40,7 @@ logger.warning(f"{e=}") pass +from executorch.exir.program._program import _transform from torch._export.pass_base import PassType from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer.quantizer import Quantizer @@ -234,23 +235,76 @@ def __init__( ): self.pass_list = pass_list self.pass_functions = pass_functions - self.edge_dialect_program = None + self.edge_or_aten_program = None - def run(self, artifact: EdgeProgramManager, inputs=None) -> None: - self.edge_dialect_program = artifact - if self.pass_list: - pass_manager = XNNPACKPassManager( - artifact.exported_program(), self.pass_list - ) - self.edge_dialect_program._edge_programs["forward"] = ( - pass_manager.transform() - ) - if self.pass_functions: - assert isinstance(self.pass_functions, list) - for pass_function in self.pass_functions: - 
self.edge_dialect_program._edge_programs["forward"] = pass_function( - self.edge_dialect_program.exported_program() + def run( + self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None + ) -> None: + if isinstance(artifact, EdgeProgramManager): + self.edge_or_aten_program = artifact + if self.pass_list: + pass_manager = XNNPACKPassManager( + artifact.exported_program(), self.pass_list ) + self.edge_or_aten_program._edge_programs["forward"] = ( + pass_manager.transform() + ) + if self.pass_functions: + assert isinstance(self.pass_functions, list) + for pass_function in self.pass_functions: + self.edge_or_aten_program._edge_programs["forward"] = pass_function( + self.edge_or_aten_program.exported_program() + ) + else: + transformed_ep = artifact + if self.pass_list: + assert isinstance(self.pass_list, list) + for pass_ in self.pass_list: + transformed_ep = _transform(transformed_ep, pass_()) + + if self.pass_functions: + assert isinstance(self.pass_functions, list) + for pass_function in self.pass_functions: + transformed_ep = pass_function(transformed_ep) + + self.edge_or_aten_program = transformed_ep + + @property + def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]: + return self.edge_or_aten_program + + @property + def graph_module(self) -> str: + if isinstance(self.edge_or_aten_program, EdgeProgramManager): + return self.edge_or_aten_program.exported_program().graph_module + else: + return self.edge_or_aten_program.graph_module + + +@register_stage +class ToEdgeTransformAndLower(Stage): + def __init__( + self, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + ): + from executorch.backends.xnnpack.partition.xnnpack_partitioner2 import ( + XnnpackPartitioner, + ) + + self.partitioners = partitioners or [XnnpackPartitioner()] + self.edge_compile_conf = ( + edge_compile_config or get_xnnpack_edge_compile_config() + ) + self.edge_dialect_program = None + + def run(self, artifact: ExportedProgram, inputs=None) -> None: + artifact_to_run = copy.deepcopy(artifact) + self.edge_dialect_program = _to_edge_transform_and_lower( + artifact_to_run, + compile_config=self.edge_compile_conf, + partitioner=self.partitioners, + ) @property def artifact(self) -> EdgeProgramManager: @@ -264,6 +318,10 @@ def graph_module(self) -> str: @register_stage class Partition(Stage): def __init__(self, partitioner: Optional[Partitioner] = None): + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) + self.partitioner = partitioner or XnnpackPartitioner() self.delegate_module = None @@ -372,13 +430,22 @@ def __init__( self.pipeline = { self.stage_name(Quantize): [self.stage_name(Export)], self.stage_name(Export): [ + self.stage_name(RunPasses), self.stage_name(ToEdge), + self.stage_name(ToEdgeTransformAndLower), + ], + self.stage_name(ToEdgeTransformAndLower): [ + self.stage_name(RunPasses), + self.stage_name(ToExecutorch), ], self.stage_name(ToEdge): [ self.stage_name(Partition), self.stage_name(RunPasses), ], - self.stage_name(RunPasses): [self.stage_name(Partition)], + self.stage_name(RunPasses): [ + self.stage_name(Partition), + self.stage_name(ToEdgeTransformAndLower), + ], # TODO Make this Stage optional self.stage_name(Partition): [self.stage_name(ToExecutorch)], self.stage_name(ToExecutorch): [self.stage_name(Serialize)], @@ -502,6 +569,11 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None): to_edge_stage.edge_compile_conf._skip_dim_order = True return 
self._run_stage(to_edge_stage) + def to_edge_transform_and_lower( + self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None + ): + return self._run_stage(to_edge_and_transform_stage or ToEdgeTransformAndLower()) + def run_passes(self, run_passes_stage: Optional[RunPasses] = None): return self._run_stage(run_passes_stage or RunPasses()) diff --git a/backends/xnnpack/utils/TARGETS b/backends/xnnpack/utils/TARGETS index b542006e3b..55615e1106 100644 --- a/backends/xnnpack/utils/TARGETS +++ b/backends/xnnpack/utils/TARGETS @@ -9,6 +9,7 @@ python_library( "//caffe2:torch", "//executorch/exir:lib", "//executorch/exir:pass_manager", + "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib", "//executorch/exir/dialects:lib", "//pytorch/ao:torchao", # @manual ], diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py index 3fe290606c..9dda84c5e5 100644 --- a/backends/xnnpack/utils/configs.py +++ b/backends/xnnpack/utils/configs.py @@ -12,8 +12,12 @@ ### XNNPACK Configs ### -def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig: - return exir.EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True) +def get_xnnpack_edge_compile_config( + skip_dim_order: bool = True, +) -> exir.EdgeCompileConfig: + return exir.EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=skip_dim_order + ) def get_transform_passes(additional_passes=None) -> List[PassType]: diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py index 0b6e7e496a..d5a7ec7fd0 100644 --- a/backends/xnnpack/utils/quant_utils.py +++ b/backends/xnnpack/utils/quant_utils.py @@ -5,28 +5,74 @@ # LICENSE file in the root directory of this source tree. import torch -from executorch.exir.dialects._ops import ops as exir_ops - -DQ_TARGETS = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default, - exir_ops.edge.quantized_decomposed.dequantize_per_token.default, +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, +) + +_Q_OPS = { + "quantize_per_tensor.tensor", + "quantize_per_tensor.default", + "quantize_per_channel.default", + "quantize_per_channel_group.default", + "quantize_per_token.default", } -Q_TARGETS = { - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default, - exir_ops.edge.quantized_decomposed.quantize_per_token.default, +_DQ_OPS = { + "dequantize_per_tensor.tensor", + "dequantize_per_tensor.default", + "dequantize_per_channel.default", + "dequantize_per_channel_group.default", + "dequantize_per_token.default", } -def is_quant(tensor: torch.fx.Node) -> bool: - return tensor.target in Q_TARGETS +_QPARAM_OPS = { + "choose_qparams.tensor", + "choose_qparams_per_token_asymmetric.default", +} + +_DYNAMIC_OPS = { + "quantize_per_tensor.tensor", + "quantize_per_token.default", + "dequantize_per_tensor.tensor", + "dequantize_per_token.default", +} + + +def is_dynamic_qdq(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + + return node_name in 
_DYNAMIC_OPS + + +def is_qparam(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + + return node_name in _QPARAM_OPS + + +def is_quant(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + + return node_name in _Q_OPS + + +def is_dequant(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + + return node_name in _DQ_OPS + +def is_per_channel(node: torch.fx.Node) -> bool: + if not (is_quant(node) or is_dequant(node)): + return False -def is_dequant(tensor: torch.fx.Node) -> bool: - return tensor.target in DQ_TARGETS + return "per_channel" in node.target.__name__ # pyre-ignore diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py index 5c76922472..b802d73c16 100644 --- a/backends/xnnpack/utils/utils.py +++ b/backends/xnnpack/utils/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import cast, Optional, Tuple +from typing import Any, cast, Optional, Tuple import executorch.exir as exir import torch @@ -62,6 +62,20 @@ def check_or_raise(condition: bool, err: str) -> None: raise RuntimeError(err) +def is_node(node: Any) -> bool: + """ + returns true if node is a torch.fx.Node, otherwise false + """ + return isinstance(node, torch.fx.Node) + + +def is_getitem(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + + return node.target.__name__ == "getitem" # pyre-ignore + + def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node: return cast(torch.fx.Node, node.args[input_index]) diff --git a/examples/cadence/models/wav2vec2.py b/examples/cadence/models/wav2vec2.py new file mode 100644 index 0000000000..5db9ea2a6d --- /dev/null +++ b/examples/cadence/models/wav2vec2.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch + +from executorch.backends.cadence.aot.export_example import export_model +from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + # The wrapper is needed to avoid issues with the optional second arguments + # of Wav2Vec2Models. 
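For context, the quant_utils helpers above now classify nodes by the formatted target name instead of exact edge-dialect overloads, so the same checks work on aten and edge graphs. A minimal sketch of that classification, using a hypothetical SimpleNamespace stand-in for a torch.fx.Node (illustration only, assuming the modules added in this patch):

```python
from types import SimpleNamespace

from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_dynamic_qdq
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)

# The dialect namespace is stripped, leaving only the op name and overload.
assert (
    format_target_name("quantized_decomposed.dequantize_per_tensor.tensor")
    == "dequantize_per_tensor.tensor"
)

# Hypothetical stand-in exposing the two fields the helpers read from a node.
fake_target = SimpleNamespace(
    __name__="quantized_decomposed.dequantize_per_tensor.tensor"
)
fake_node = SimpleNamespace(op="call_function", target=fake_target)

assert is_dequant(fake_node)      # formatted name is in _DQ_OPS
assert is_dynamic_qdq(fake_node)  # the .tensor overload is treated as dynamic q/dq
```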
+ class Wav2Vec2ModelWrapper(torch.nn.Module): + def __init__(self, model: Wav2Vec2Model): + super().__init__() + self.model = model + + def forward(self, x): + out, _ = self.model(x) + return out + + _model = wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=0.1, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=0.1, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=0.0, + encoder_dropout=0.1, + encoder_layer_norm_first=False, + encoder_layer_drop=0.1, + aux_num_out=None, + ) + _model.eval() + + model = Wav2Vec2ModelWrapper(_model) + model.eval() + + # test input + audio_len = 1680 + example_inputs = (torch.rand(1, audio_len),) + + export_model(model, example_inputs) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp index b624e6eb43..cd5346bacd 100644 --- a/examples/models/llama2/runner/runner.cpp +++ b/examples/models/llama2/runner/runner.cpp @@ -32,11 +32,6 @@ #include namespace torch::executor { -namespace { -static constexpr auto kTopp = 0.9f; -void printReport(const Runner::Stats& stats); -std::string statsToJsonString(const Runner::Stats& stats); -} // namespace Runner::Runner( const std::string& model_path, @@ -96,7 +91,7 @@ Error Runner::load() { sampler_ = std::make_unique( vocab_size_, temperature_, - kTopp, + ::executorch::llm::kTopp, static_cast(std::time(nullptr))); return Error::Ok; @@ -479,7 +474,7 @@ Error Runner::generate( stats_.num_prompt_tokens = num_prompt_tokens; stats_.num_generated_tokens = pos - num_prompt_tokens; - printReport(stats_); + ::executorch::llm::print_report(stats_); if (stats_callback) { stats_callback(stats_); } @@ -487,84 +482,6 @@ Error Runner::generate( return Error::Ok; } -namespace { -void printReport(const Runner::Stats& stats) { - printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str()); - - ET_LOG( - Info, - "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, - stats.num_prompt_tokens, - stats.num_generated_tokens); - - ET_LOG( - Info, - "\tModel Load Time:\t\t%f (seconds)", - ((double)(stats.model_load_end_ms - stats.model_load_start_ms) / - stats.SCALING_FACTOR_UNITS_PER_SECOND)); - double inference_time_ms = - (double)(stats.inference_end_ms - stats.inference_start_ms); - ET_LOG( - Info, - "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", - inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND, - - (stats.num_generated_tokens) / - (double)(stats.inference_end_ms - stats.inference_start_ms) * - stats.SCALING_FACTOR_UNITS_PER_SECOND); - double prompt_eval_time = - (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); - ET_LOG( - Info, - "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", - prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, - (stats.num_prompt_tokens) / prompt_eval_time * - stats.SCALING_FACTOR_UNITS_PER_SECOND); - - double eval_time = - (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); - ET_LOG( - Info, - "\t\tGenerated %" PRIu64 - " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", - stats.num_generated_tokens, - eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, - stats.num_generated_tokens / eval_time * - stats.SCALING_FACTOR_UNITS_PER_SECOND); - - // Time to first token is measured from the start of inference, excluding - // 
model load time. - ET_LOG( - Info, - "\tTime to first generated token:\t%f (seconds)", - ((double)(stats.first_token_ms - stats.inference_start_ms) / - stats.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", - stats.num_prompt_tokens + stats.num_generated_tokens, - (double)stats.aggregate_sampling_time_ms / - stats.SCALING_FACTOR_UNITS_PER_SECOND); -} - -std::string statsToJsonString(const Runner::Stats& stats) { - std::stringstream ss; - ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," - << "\"generated_tokens\":" << stats.num_generated_tokens << "," - << "\"model_load_start_ms\":" << stats.model_load_start_ms << "," - << "\"model_load_end_ms\":" << stats.model_load_end_ms << "," - << "\"inference_start_ms\":" << stats.inference_start_ms << "," - << "\"inference_end_ms\":" << stats.inference_end_ms << "," - << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," - << "\"first_token_ms\":" << stats.first_token_ms << "," - << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms - << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" - << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; - return ss.str(); -} -} // namespace - void Runner::stop() { shouldStop_ = true; } diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h index 7b9d2763fc..407527531d 100644 --- a/examples/models/llama2/runner/runner.h +++ b/examples/models/llama2/runner/runner.h @@ -15,14 +15,17 @@ #include #include #include +#include #include +#include #include #include #include #include namespace torch::executor { +using Stats = ::executorch::llm::Stats; class Runner { public: @@ -31,32 +34,6 @@ class Runner { const std::string& tokenizer_path, const float temperature = 0.8f); - struct Stats { - // Scaling factor for timestamps - in this case, we use ms. - const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; - // Time stamps for the different stages of the execution - // model_load_start_ms: Start of model loading. - long model_load_start_ms; - // model_load_end_ms: End of model loading. - long model_load_end_ms; - // inference_start_ms: Immediately after the model is loaded (or we check - // for model load), measure the inference time. - long inference_start_ms; - // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right - // before the inference loop starts - long prompt_eval_end_ms; - // first_token: Timestamp when the first generated token is emitted - long first_token_ms; - // inference_end_ms: End of inference/generation. - long inference_end_ms; - // Keep a running total of the time spent in sampling. 
- long aggregate_sampling_time_ms; - // Token count from prompt - int64_t num_prompt_tokens; - // Token count from generated (total - prompt) - int64_t num_generated_tokens; - }; - bool is_loaded() const; Error load(); Error generate( diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl index c8b63b6a54..c0a892e14d 100644 --- a/examples/models/llama2/runner/targets.bzl +++ b/examples/models/llama2/runner/targets.bzl @@ -33,6 +33,7 @@ def define_common_targets(): ], exported_deps = [ "//executorch/backends/xnnpack:xnnpack_backend", + "//executorch/extension/llm/runner:stats", "//executorch/extension/llm/sampler:sampler" + aten_suffix, "//executorch/extension/evalue_util:print_evalue" + aten_suffix, "//executorch/extension/runner_util:managed_tensor" + aten_suffix, diff --git a/examples/models/llama2/targets.bzl b/examples/models/llama2/targets.bzl index a0c41d03e0..6cf398097d 100644 --- a/examples/models/llama2/targets.bzl +++ b/examples/models/llama2/targets.bzl @@ -10,6 +10,7 @@ def define_common_targets(): srcs = [ "main.cpp", ], + compiler_flags = ["-Wno-global-constructors"], preprocessor_flags = [ "-DUSE_ATEN_LIB", ] if aten else [], diff --git a/examples/models/phi-3-mini/CMakeLists.txt b/examples/models/phi-3-mini/CMakeLists.txt index 5dddf7eb71..39358e088e 100644 --- a/examples/models/phi-3-mini/CMakeLists.txt +++ b/examples/models/phi-3-mini/CMakeLists.txt @@ -4,6 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + cmake_minimum_required(VERSION 3.19) project(phi_3_mini_runner) @@ -18,22 +27,26 @@ option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON) option(EXECUTORCH_BUILD_XNNPACK "" ON) add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../.. - ${CMAKE_BINARY_DIR}/../../..) -add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/third-party/sentencepiece - ${CMAKE_BINARY_DIR}/sentencepiece) + ${CMAKE_CURRENT_SOURCE_DIR}/../../.. ${CMAKE_BINARY_DIR}/../../.. 
+) +if(NOT TARGET gflags) + add_subdirectory( + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/gflags + ${CMAKE_BINARY_DIR}/gflags + ) +endif() -add_executable(phi_3_mini_runner main.cpp) +add_executable( + phi_3_mini_runner + main.cpp runner.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/sampler/sampler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/tokenizer/bpe_tokenizer.cpp +) target_include_directories( - phi_3_mini_runner - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/third-party/sentencepiece/src) + phi_3_mini_runner + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/gflags/src +) target_link_libraries( - phi_3_mini_runner - PRIVATE - executorch - extension_module_static - optimized_native_cpu_ops_lib - xnnpack_backend - sentencepiece) + phi_3_mini_runner PRIVATE executorch extension_module_static + optimized_native_cpu_ops_lib xnnpack_backend gflags +) diff --git a/examples/models/phi-3-mini/eager.py b/examples/models/phi-3-mini/eager.py index f3d6f9a224..8b57b5a24c 100644 --- a/examples/models/phi-3-mini/eager.py +++ b/examples/models/phi-3-mini/eager.py @@ -43,11 +43,7 @@ def _generate_token_with_kv_cache(args, model, prompt_tokens): print("Generating tokens:", end="", flush=True) model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1]) - - for input_pos in range(prompt_tokens.shape[-1]): - result = model.forward( - input_ids=prompt_tokens[:, input_pos : input_pos + 1], - ) + result = model.forward(input_ids=prompt_tokens) current_token = torch.argmax(result, dim=-1).item() print(f" {current_token}", end="", flush=True) diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index af933ff292..a6d202a42b 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -23,7 +23,7 @@ XNNPACKQuantizer, ) -from transformers import AutoTokenizer, Phi3ForCausalLM +from transformers import Phi3ForCausalLM from .phi_3_mini import Phi3Mini @@ -40,18 +40,16 @@ def main(args) -> None: max_batch_size=1, max_seq_len=args.seq_len, ) - tokenizer = AutoTokenizer.from_pretrained(model_name) - - tokens = tokenizer.encode("Tell me a story", return_tensors="pt") - for input_pos in range(tokens.shape[-1]): - result = model.forward( - input_ids=tokens[:, input_pos : input_pos + 1], - ) - current_token = torch.argmax(result, dim=-1).item() - example_inputs = ( - torch.tensor([[current_token]], dtype=torch.long, requires_grad=False), + torch.tensor( + [[1048, 263, 931, 746]], dtype=torch.long, requires_grad=False + ), ) + dynamic_shapes = { + "input_ids": { + 1: torch.export.Dim("sequence_length", min=1, max=args.seq_len) + } + } xnnpack_quant_config = get_symmetric_quantization_config( is_per_channel=True, is_dynamic=True @@ -59,7 +57,9 @@ def main(args) -> None: xnnpack_quantizer = XNNPACKQuantizer() xnnpack_quantizer.set_global(xnnpack_quant_config) - model = capture_pre_autograd_graph(model, example_inputs) + model = capture_pre_autograd_graph( + model, example_inputs, dynamic_shapes=dynamic_shapes + ) model = prepare_pt2e(model, xnnpack_quantizer) model(*example_inputs) model = convert_pt2e(model, fold_quantize=False) @@ -69,16 +69,17 @@ def main(args) -> None: model = torch.export._trace._export( model, example_inputs, + dynamic_shapes=dynamic_shapes, strict=False, pre_dispatch=False, ) edge_config = get_xnnpack_edge_compile_config() edge_manager = to_edge(model, compile_config=edge_config) - edge_manager = 
edge_manager.to_backend(XnnpackPartitioner()) + edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True)) et_program = edge_manager.to_executorch() - with open("phi-3-mini.pte", "wb") as file: + with open(args.output_name, "wb") as file: file.write(et_program.buffer) @@ -91,4 +92,10 @@ def main(args) -> None: default=128, help="Maximum number of tokens including prompt to generate", ) + parser.add_argument( + "-o", + "--output_name", + default="phi-3-mini.pte", + help="Override the output filename of the saved pte model file.", + ) main(parser.parse_args()) diff --git a/examples/models/phi-3-mini/main.cpp b/examples/models/phi-3-mini/main.cpp index 0158c95ad8..7aedcb75b2 100644 --- a/examples/models/phi-3-mini/main.cpp +++ b/examples/models/phi-3-mini/main.cpp @@ -6,85 +6,45 @@ * LICENSE file in the root directory of this source tree. */ -// main.cpp +#include -#include +#include -#include -#include +DEFINE_string( + model_path, + "phi-3-mini.pte", + "File path for model serialized in flatbuffer format."); -#include "sentence_piece_tokenizer.h" +DEFINE_string(tokenizer_path, "tokenizer.bin", "File path for tokenizer."); -using namespace torch::executor; +DEFINE_string(prompt, "Tell me a story", "Prompt."); -// The value of the phi-3-mini `<|endoftext|>` token. -#define ENDOFTEXT_TOKEN 32000 -#define VOCABULARY_SIZE 32064 +DEFINE_double( + temperature, + 0.8f, + "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); -// TODO(lunwenh): refactor and share with llama -void generate( - Module& llm_model, - std::string& prompt, - SentencePieceTokenizer& tokenizer, - size_t max_output_length) { - // Convert the input text into a list of integers (tokens) that represents - // it, using the string-to-token mapping that the model was trained on. - // Each token is an integer that represents a word or part of a word. - std::vector input_tokens = tokenizer.encode(prompt); +DEFINE_int32( + seq_len, + 128, + "Total number of tokens to generate (prompt + output)."); - std::cout << "Generating tokens ..." << std::endl; +int main(int32_t argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); - std::vector output_tokens; + const char* model_path = FLAGS_model_path.c_str(); - for (size_t i = 0; i < max_output_length; i++) { - ManagedTensor tensor_tokens( - input_tokens.data(), - {1, static_cast(input_tokens.size())}, - ScalarType::Long); - std::vector inputs = {tensor_tokens.get_aliasing_tensor()}; + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); - Result> result_evalue = llm_model.forward(inputs); + const char* prompt = FLAGS_prompt.c_str(); - const auto error = result_evalue.error(); - Tensor logits_tensor = result_evalue.get()[0].toTensor(); - const auto sentence_length = logits_tensor.size(1); - std::vector logits( - logits_tensor.data_ptr() + - (sentence_length - 1) * VOCABULARY_SIZE, - logits_tensor.data_ptr() + sentence_length * VOCABULARY_SIZE); + double temperature = FLAGS_temperature; - // Sample the next token from the logits. - int64_t next_token = - std::max_element(logits.begin(), logits.end()) - logits.begin(); + int32_t seq_len = FLAGS_seq_len; - std::cout << next_token << "\t"; - std::cout.flush(); + ::torch::executor::Runner runner(model_path, tokenizer_path, temperature); - // Break if we reached the end of the text. - if (next_token == ENDOFTEXT_TOKEN) { - break; - } + runner.generate(prompt, seq_len); - output_tokens.push_back(next_token); - - // Update next input. 
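Stepping back to the export script changes above: the per-token prefill loop is gone, and the prompt is exported once with a symbolic sequence dimension. A toy sketch of the same dynamic-shapes mechanics (a stand-in module rather than phi-3, and torch.export.export rather than the pre-autograd capture, purely for illustration):

```python
import torch
from torch.export import Dim, export


class ToyTokenModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(32064, 8)

    def forward(self, input_ids):
        return self.embedding(input_ids).sum(dim=-1)


example_inputs = (torch.tensor([[1048, 263, 931, 746]], dtype=torch.long),)
# Dimension 1 (the sequence length) may vary between 1 and 128 at runtime,
# mirroring the dynamic_shapes spec used for phi-3-mini above.
dynamic_shapes = {"input_ids": {1: Dim("sequence_length", min=1, max=128)}}

ep = export(ToyTokenModel(), example_inputs, dynamic_shapes=dynamic_shapes)
print(ep.range_constraints)  # dim 1 of input_ids is now symbolic
```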
- input_tokens.push_back(next_token); - } - - std::cout << std::endl; - std::cout << tokenizer.decode(output_tokens) << std::endl; -} - -int main() { - // Set up the prompt. This provides the seed text for the model to elaborate. - std::cout << "Enter model prompt: "; - std::string prompt; - std::getline(std::cin, prompt); - - SentencePieceTokenizer tokenizer("tokenizer.model"); - - Module model("phi-3-mini.pte", Module::LoadMode::MmapUseMlockIgnoreErrors); - - const auto max_output_tokens = 128; - generate(model, prompt, tokenizer, max_output_tokens); + return 0; } diff --git a/examples/models/phi-3-mini/runner.cpp b/examples/models/phi-3-mini/runner.cpp new file mode 100644 index 0000000000..6d365bfe36 --- /dev/null +++ b/examples/models/phi-3-mini/runner.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include +#include + +namespace torch::executor { + +#define SAMPLER_TOP 0.9f +#define ENDOFTEXT_TOKEN 32000 +#define VOCABULARY_SIZE 32064 + +Runner::Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature) + : module_(std::make_unique(model_path, Module::LoadMode::File)), + tokenizer_(std::make_unique()), + sampler_(std::make_unique( + VOCABULARY_SIZE, + temperature, + SAMPLER_TOP, + static_cast(std::time(nullptr)))) { + ET_CHECK_MSG( + tokenizer_->load(tokenizer_path) == Error::Ok, + "Failed to load tokenizer at %s", + tokenizer_path.c_str()); + ET_LOG( + Info, + "Created Phi-3-mini runner: model_path=%s, tokenizer_path=%s", + model_path.c_str(), + tokenizer_path.c_str()); +} + +void Runner::generate(const std::string& prompt, std::size_t max_seq_len) { + auto encode_res = tokenizer_->encode(prompt, 0, 0); + ET_CHECK_MSG( + encode_res.error() == Error::Ok, "Failed to encode %", prompt.c_str()); + auto input_tokens = encode_res.get(); + + std::cout << "Prefilling tokens ..." << std::endl; + for (auto token : input_tokens) { + std::cout << token << " "; + } + std::cout << std::endl; + std::cout.flush(); + auto prev_token = input_tokens.back(); + auto current_token = prefill(input_tokens); + + std::cout << "Generating tokens ..." 
<< std::endl; + std::cout << tokenizer_->decode(prev_token, current_token).get(); + std::cout.flush(); + + std::size_t seq_len = input_tokens.size() + 1; + + while (current_token != ENDOFTEXT_TOKEN && seq_len < max_seq_len) { + prev_token = current_token; + current_token = run_model_step(current_token); + std::cout << tokenizer_->decode(prev_token, current_token).get(); + std::cout.flush(); + + ++seq_len; + } + + std::cout << std::endl; +} + +uint64_t Runner::logits_to_token(const exec_aten::Tensor& logits_tensor) { + return sampler_->sample(logits_tensor.data_ptr()); +} + +uint64_t Runner::prefill(std::vector& tokens) { + ManagedTensor input_tokens( + tokens.data(), + {1, static_cast(tokens.size())}, + ScalarType::Long); + std::vector inputs = {input_tokens.get_aliasing_tensor()}; + + auto result = module_->forward(inputs); + ET_CHECK_MSG(result.error() == Error::Ok, "Failed to prefill tokens"); + + return logits_to_token(result.get()[0].toTensor()); +} + +uint64_t Runner::run_model_step(uint64_t token) { + ManagedTensor input_token(&token, {1, 1}, ScalarType::Long); + std::vector inputs = {input_token.get_aliasing_tensor()}; + + auto result = module_->forward(inputs); + ET_CHECK_MSG( + result.error() == Error::Ok, + "Failed to run forward() for token %" PRIu64, + token); + + return logits_to_token(result.get()[0].toTensor()); +} + +} // namespace torch::executor diff --git a/examples/models/phi-3-mini/runner.h b/examples/models/phi-3-mini/runner.h new file mode 100644 index 0000000000..15022751a8 --- /dev/null +++ b/examples/models/phi-3-mini/runner.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple phi-3-mini runner that includes preprocessing and post processing +// logic. The module takes in a string as input and emits a string as output. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace torch::executor { + +class Runner { + public: + explicit Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature = 0.8f); + + /** + * Generates response for a given prompt. + * + * @param[in] prompt The prompt to generate a response for. + * @param[in] max_seq_len The maximum length of the sequence to generate, + * including prompt. + */ + void generate(const std::string& prompt, std::size_t max_seq_len); + + private: + uint64_t logits_to_token(const exec_aten::Tensor& logits_tensor); + uint64_t prefill(std::vector& tokens); + uint64_t run_model_step(uint64_t token); + + std::unique_ptr module_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; +}; + +} // namespace torch::executor diff --git a/examples/models/phi-3-mini/sentence_piece_tokenizer.h b/examples/models/phi-3-mini/sentence_piece_tokenizer.h deleted file mode 100644 index 3428a30c83..0000000000 --- a/examples/models/phi-3-mini/sentence_piece_tokenizer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -#include - -// TODO(lunwenh): Add unit tests -class SentencePieceTokenizer { - public: - SentencePieceTokenizer(const std::string& filePath) { - const auto status = processor_.Load(filePath); - if (!status.ok()) { - std::ostringstream errorMessageStream; - errorMessageStream << "Failed to load SentencePiece model from " - << filePath << " with error " << status.ToString(); - throw std::runtime_error(errorMessageStream.str()); - } - processor_.SetEncodeExtraOptions("bos"); - } - - std::vector encode(const std::string& piece) { - std::vector ids; - processor_.Encode(piece, &ids); - std::vector idsLong(ids.begin(), ids.end()); - return idsLong; - } - - std::string decode(const std::vector& ids) { - std::vector idsInt(ids.begin(), ids.end()); - std::string piece; - processor_.Decode(idsInt, &piece); - return piece; - } - - private: - sentencepiece::SentencePieceProcessor processor_; -}; diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 94aae08de8..6bfbdea058 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -7,9 +7,6 @@ set(CMAKE_CXX_STANDARD 17) # qnn_executor_runner: Like executor_runner but with QNN -if(NOT ${ANDROID}) - message(FATAL_ERROR "Not building Android, quitting...") -endif() cmake_minimum_required(VERSION 3.19) project(qualcomm_runner_example) diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp index 0ae6e4e6e4..7871cafc24 100644 --- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp @@ -31,6 +31,7 @@ #include +#include #include #include @@ -202,10 +203,8 @@ int main(int argc, char** argv) { // be used by a single thread at at time, but it can be reused. // torch::executor::ETDumpGen etdump_gen = torch::executor::ETDumpGen(); - // TODO: So far we have issues with etdump_gen during load_method. Enable it - // after the issues are fixed. 
Result method = - program->load_method(method_name, &memory_manager, nullptr); + program->load_method(method_name, &memory_manager, &etdump_gen); ET_CHECK_MSG( method.ok(), "Loading of method %s failed with status 0x%" PRIx32, diff --git a/examples/qualcomm/llama2/qaihub_runner/runner.cpp b/examples/qualcomm/llama2/qaihub_runner/runner.cpp index 32a89c9700..2f8a01f4e9 100644 --- a/examples/qualcomm/llama2/qaihub_runner/runner.cpp +++ b/examples/qualcomm/llama2/qaihub_runner/runner.cpp @@ -23,7 +23,9 @@ #include #include +#if defined(__aarch64__) #include "arm_neon.h" +#endif namespace torch { namespace executor { @@ -108,9 +110,11 @@ Error Runner::load() { int32_t Runner::logitsToToken(const Tensor& logits_tensor) { static std::vector logits_f(vocab_size_); + const uint16_t* logits = logits_tensor.data_ptr(); + +#if defined(__aarch64__) static int32x4_t offset = vmovq_n_s32(logits_offset_); static float32x4_t scale = vmovq_n_f32(logits_scale_); - const uint16_t* logits = logits_tensor.data_ptr(); // dequantize for (int i = 0; i < vocab_size_; i += 4) { const uint16_t* in = logits + i; @@ -121,6 +125,13 @@ int32_t Runner::logitsToToken(const Tensor& logits_tensor) { float32x4_t shifted_f = vcvtq_f32_s32(shifted); vst1q_f32(out, vmulq_f32(shifted_f, scale)); } +#else + // dequantize + for (int i = 0; i < vocab_size_; i++) { + logits_f[i] = (logits[i] - logits_offset_) * logits_scale_; + } +#endif + return sampler_->sample(logits_f.data()); } diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py index 1e4b1c6968..8211dc4581 100755 --- a/examples/qualcomm/scripts/utils.py +++ b/examples/qualcomm/scripts/utils.py @@ -106,24 +106,18 @@ def push(self, inputs=None, input_list=None, files=None): f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so", ] - # prepare input list - if input_list is not None: - input_list_file = f"{self.working_dir}/{self.input_list_filename}" - with open(input_list_file, "w") as f: - f.write(input_list) - f.flush() - artifacts.append(input_list_file) + input_list_file, input_files = generate_inputs( + self.working_dir, self.input_list_filename, inputs, input_list + ) + # prepare input list + artifacts.append(input_list_file) for artifact in artifacts: self._adb(["push", artifact, self.workspace]) # input data - if inputs is not None: - for idx, data in enumerate(inputs): - for i, d in enumerate(data): - file_name = f"{self.working_dir}/input_{idx}_{i}.raw" - d.detach().numpy().tofile(file_name) - self._adb(["push", file_name, self.workspace]) + for file_name in input_files: + self._adb(["push", file_name, self.workspace]) # custom files if files is not None: @@ -437,3 +431,25 @@ def parse_skip_delegation_node(args): print("Skipping following node ops: ", skip_node_op_set) return skip_node_id_set, skip_node_op_set + + +def generate_inputs(dest_path: str, file_name: str, inputs=None, input_list=None): + input_list_file = "" + input_files = [] + + # Prepare input list + if input_list is not None: + input_list_file = f"{dest_path}/{file_name}" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + + # Prepare input data + if inputs is not None: + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{dest_path}/input_{idx}_{i}.raw" + d.detach().numpy().tofile(file_name) + input_files.append(file_name) + + return input_list_file, input_files diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index bb5bdc9aa7..a82b947cec 100644 --- a/exir/_serialize/_program.py 
+++ b/exir/_serialize/_program.py @@ -309,7 +309,7 @@ def _extract_delegate_segments( def _extract_constant_segment( constant_buffer: List[Buffer], - tensor_alignment: int, + tensor_alignment: Optional[int] = None, ) -> Tuple[Cord, List[int]]: """Copies the tensors from the provided list into a Cord and tracks the offsets of each tensor. @@ -329,7 +329,11 @@ def _extract_constant_segment( buffer = constant_buffer[i] constant_segment_data.append(buffer.storage) buffer_length = len(buffer.storage) - pad_length = _padding_required(buffer_length, tensor_alignment) + pad_length = ( + _padding_required(buffer_length, tensor_alignment) + if tensor_alignment is not None + else 0 + ) if i < len(constant_buffer) - 1: constant_segment_data.append(b"\x00" * pad_length) constant_segment_offsets.append(current_offset) @@ -341,6 +345,7 @@ def _extract_constant_segment( def serialize_pte_binary( program: Program, *, + mutable_data: Optional[List[Buffer]] = None, extract_delegate_segments: bool = False, extract_constant_segment: bool = False, segment_alignment: int = 4096, @@ -396,6 +401,21 @@ def serialize_pte_binary( # Add to the aggregate segments cord. segments.append(constant_segment_data) + if mutable_data is not None: + mutable_segment_data, mutable_segment_offsets = _extract_constant_segment( + mutable_data, + tensor_alignment=None, # data is copied at Method load so no need to align. + ) + if len(mutable_segment_data) > 0: + # Update program.mutable_segment_data with constant subsegment offset information. + program.mutable_data_segments = [ + SubsegmentOffsets( + segment_index=len(segments), offsets=mutable_segment_offsets + ), + ] + # Add to the aggregate segments cord. + segments.append(mutable_segment_data) + if extract_delegate_segments: _extract_delegate_segments(program, segments) diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 2e7f1b3cdf..25c793287d 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -269,8 +269,7 @@ def _partition_and_lower_one_graph_module( if node.name in toplevel_signature.inputs_to_buffers: # Delete the consumed buffers - buffer_name = toplevel_signature.inputs_to_buffers.pop(node.name) - toplevel_signature.buffers.remove(buffer_name) + buffer_name = toplevel_signature.inputs_to_buffers.get(node.name) if buffer_name in owning_program.state_dict: owning_program.state_dict.pop(buffer_name) else: @@ -278,8 +277,7 @@ def _partition_and_lower_one_graph_module( tagged_graph_module.graph.erase_node(node) elif node.name in toplevel_signature.inputs_to_parameters: # Delete the consumed parameters - param_name = toplevel_signature.inputs_to_parameters.pop(node.name) - toplevel_signature.parameters.remove(param_name) + param_name = toplevel_signature.inputs_to_parameters.get(node.name) owning_program.state_dict.pop(param_name) tagged_graph_module.graph.erase_node(node) diff --git a/exir/backend/canonical_partitioners/TARGETS b/exir/backend/canonical_partitioners/TARGETS index 06b028e259..22a6e2c51b 100644 --- a/exir/backend/canonical_partitioners/TARGETS +++ b/exir/backend/canonical_partitioners/TARGETS @@ -36,3 +36,20 @@ runtime.python_library( "//executorch/exir/backend:partitioner", ], ) + +runtime.python_library( + name = "config_partitioner_lib", + srcs = [ + "config_partitioner.py", + ], + visibility = [ + "//executorch/...", + "//executorch/exir/backend/...", + "//executorch/test/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/backend:partitioner", + ], +) diff --git 
a/exir/backend/canonical_partitioners/config_partitioner.py b/exir/backend/canonical_partitioners/config_partitioner.py new file mode 100644 index 0000000000..1a9bcc33e8 --- /dev/null +++ b/exir/backend/canonical_partitioners/config_partitioner.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Callable, Dict, Iterable, List, Optional, Tuple + +import torch +from executorch.exir.backend.backend_details import ExportedProgram +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_partitions_from_list_of_nodes, +) +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from torch.fx.passes.infra.partitioner import Partition + + +def format_target_name(target_name: str) -> str: + """ + We remove the dialect name space from the target name. We generally + do not care for the op dialect specific name space ("aten.", "quantized_decomposed.") + but rather the op itself. Se remove the dialect-specific name space from the + name and return the op name itself + """ + names = target_name.split(".") + if len(names) > 2: + names.pop(0) + + return ".".join(names) + + +class PartitionerConfig(ABC): + """ + Class used to represent a PartitionerConfig. + + PartitionerConfig is used by config-based partitioner to partition identify + nodes to be delegated. User overrides the methods: + - target_name + - check_constraints + - get_partition + - get_original_aten + + The Config-Based Partitioner then uses these overridden methods to find nodes + which match target_name, check_constraints, and if true, returns the partition + (list of nodes) which represent the node and its dependencies. get_original_aten + is used to halt decomposition to edge_dialect if the node can be delegated by + the specified backend. + """ + + @classmethod + @property + @abstractmethod + def target_name(cls) -> str: + """ + Target name for this partitioner config. When the Config-Based Partitioner + encounters a node with a matching target name, it uses this config's methods to + checks the constraints of this node and get all of its dependencies. + the target name is formatted to remove the dialect-specific name space. + i.e. linear.default + """ + pass + + @abstractmethod + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + Takes in a node and returns true if the node is partitionable. + + Args: + node: Node to be partitioned + ep: Exported program of the graph module + Returns: + True or False whether this node is partitionable + """ + pass + + @abstractmethod + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + """ + Returns the original aten dialect op, this is for to_edge_transform_and_lower + API, so that this config can be used to stop decomposition of this original + aten op + """ + pass + + @abstractmethod + def get_partition( + self, node: torch.fx.Node, ep: ExportedProgram + ) -> List[torch.fx.Node]: + """ + Returns the partitioned nodes from get_node_and_deps, but also labels them + with the name of the PartitionerConfig class which return this set of nodes. 
+ + Returns an empty list of the node and deps do not satisfy the checked constraints + """ + pass + + +class ConfigerationBasedPartitioner(Partitioner): + def __init__( + self, + delegation_spec: DelegationSpec, + partitioner_configs: Iterable[PartitionerConfig], + ): + """ + Configeration based partitioner. We supply the partitioner with a set of configerations + which describe the node type, constraints, and any dependencies required to be partitioned + with the node. We use the configerations to partition the graph module. + """ + super().__init__() + # Initialize partitioner configs map {"target_name": PartitionerConfig} + self.target_partitioner_configs: Dict[str, PartitionerConfig] = {} + for config in partitioner_configs: + target_name = config.target_name + if target_name in self.target_partitioner_configs: + other_config = self.target_partitioner_configs[target_name] + raise RuntimeError( + f"PartitionerConfig: {config} and {other_config} have the same target_name: {target_name}" + ) + else: + self.target_partitioner_configs[target_name] = config + + self.delegation_spec = delegation_spec + + def ops_to_not_decompose( + self, + ep: ExportedProgram, + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + def filter_fn(node: torch.fx.Node) -> bool: + """ + The partitioner configs we initialize with have check_constraints function, + to determine if this op is indeed partitionable. We grab the check_constraint + function of this op from the config and use it to filter. + """ + if node.op != "call_function": + return False + target_name = format_target_name(node.target.__name__) # pyre-ignore + + if target_name in self.target_partitioner_configs: + config = self.target_partitioner_configs[target_name] + # only filter_fn if config has original_aten + if config.get_original_aten(): + return self.target_partitioner_configs[ + target_name + ].check_constraints(node, ep) + + return False + + # Get list of original aten targets which we do not want to decomp + do_not_decomp = [] + for node_config in self.target_partitioner_configs.values(): + original_aten = node_config.get_original_aten() + if original_aten is not None: + do_not_decomp.append(original_aten) + + return (do_not_decomp, filter_fn) + + def get_matched_nodes_from_configs( + self, ep: ExportedProgram + ) -> List[List[torch.fx.Node]]: + # gather supported nodes + matched_nodes = [] + gm = ep.graph_module + for node in gm.graph.nodes: + if node.op == "call_function": + target = format_target_name(node.target.__name__) + if target in self.target_partitioner_configs: + node_config = self.target_partitioner_configs[target] + if node_config.check_constraints(node, ep): + matched_nodes.append(node_config.get_partition(node, ep)) + + return matched_nodes + + def generate_partitions(self, ep: ExportedProgram) -> List[Partition]: + matched_nodes = self.get_matched_nodes_from_configs(ep) + # create partitions + partitions = generate_partitions_from_list_of_nodes( + ep.graph_module, + matched_nodes, + ) + return partitions + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + partitions = self.generate_partitions(exported_program) + + # tag nodes + partition_tags: Dict[str, DelegationSpec] = {} + for partition in partitions: + for node in partition.nodes: + delegation_tag = f"tag{partition.id}" + if "delegation_tag" in node.meta: + raise RuntimeError( + f"Partitioner Erro found node {node} in partition {node.meta['delegation_tag']} and partition {delegation_tag}" + ) + 
node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index 7a6d68aef1..ed58b06b3d 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -269,7 +269,6 @@ python_unittest( "test_utils.py", ], deps = [ - "fbsource//third-party/pypi/pandas:pandas", ":op_partitioner_demo", "//caffe2:torch", "//executorch/exir:lib", diff --git a/exir/backend/test/demos/rpc/executor_backend_preprocess.py b/exir/backend/test/demos/rpc/executor_backend_preprocess.py index aa286af300..0e5b8a8d3d 100644 --- a/exir/backend/test/demos/rpc/executor_backend_preprocess.py +++ b/exir/backend/test/demos/rpc/executor_backend_preprocess.py @@ -4,15 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + from typing import final, List -from executorch.exir import ExirExportedProgram from executorch.exir.backend.backend_details import ( BackendDetails, ExportedProgram, PreprocessResult, ) from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.program._program import EdgeProgramManager @final @@ -23,10 +25,8 @@ def preprocess( compile_specs: List[CompileSpec], ) -> PreprocessResult: return PreprocessResult( - processed_bytes=ExirExportedProgram( - exported_program=edge_program, - # Indicates that edge_program is already in edge dialect. - after_to_edge_passes=True, + processed_bytes=EdgeProgramManager( + edge_programs=edge_program, ) .to_executorch() .buffer, diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index 6b545e0a7d..0aebab649e 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -17,7 +17,8 @@ _TopLevelEmitter, ) from executorch.exir.error import ExportError, ExportErrorType -from executorch.exir.schema import Program, SubsegmentOffsets + +from executorch.exir.schema import Buffer, Program, SubsegmentOffsets from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION from torch.export.exported_program import ExportedProgram, OutputKind from torch.utils import _pytree as pytree @@ -44,6 +45,8 @@ class EmitterOutput: str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]] ] + mutable_data: Optional[List[Buffer]] + def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule: gm = exported_program.graph_module @@ -156,5 +159,11 @@ def emit_program( segments=[], # Subsegment offsets may be added at serialization time. constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]), + mutable_data_segments=None, # Will be filled in during serialization + ), + mutable_data=( + program_state.mutable_buffer + if len(program_state.mutable_buffer) > 1 + else None ), ) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index f57ed15d10..f51b4113c8 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -110,8 +110,11 @@ class _ProgramState: # emitted graph modules, not any weights emitted from itself. This should speed up the lookup, # from O(N) to O(1) cached_spec_hash_values: Dict[str, int] = field(default_factory=dict) + cached_spec_mutable_hash_values: Dict[str, int] = field(default_factory=dict) # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder. 
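Tying the config-based partitioner above together: a backend describes each supported op with a PartitionerConfig and then hands a list of configs plus a DelegationSpec to ConfigerationBasedPartitioner. A hypothetical single-op config could look like the sketch below (AbsConfig is illustrative only; the real XNNPACK configs live under backends/xnnpack/partition/config/):

```python
from typing import List, Optional

import torch
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    PartitionerConfig,
)
from torch.export import ExportedProgram


class AbsConfig(PartitionerConfig):
    # Dialect namespace already stripped, i.e. "aten.abs.default" -> "abs.default".
    target_name = "abs.default"

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        # A real config would check dtypes, ranks, quantization patterns, etc.
        return True

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        # Returning the aten overload keeps it from being decomposed by
        # to_edge_transform_and_lower, so the backend sees the whole op.
        return torch.ops.aten.abs.default

    def get_partition(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        # Just the node itself; configs for fused patterns would also return
        # the weight/bias/q-dq nodes that must be lowered together with it.
        return [node]
```

A ConfigerationBasedPartitioner built from a DelegationSpec and [AbsConfig()] then runs check_constraints on every matching node and tags the returned partitions for delegation.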
constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")]) + # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder. + mutable_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")]) # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference, # and should be copied to Program.backend_delegate_data. backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list) @@ -326,68 +329,83 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue: def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: """Constructs an EValue from the given TensorSpec.""" - if not spec.const: - if spec.mem_id is not None: - # Tensor is an activation. - self._internal_assert_emitter( - isinstance(spec.mem_id, int) and spec.mem_id >= 0, - self.node, - "Non-const tensor should be an activation tensor", - ) - self._internal_assert_emitter( - isinstance(spec.mem_offset, int) and spec.mem_offset >= 0, - self.node, - "Non-const tensor should be an activation tensor", - ) - allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset) - else: - # Tensor is an input/placeholder. - allocation_info = None + allocation_info = None + buffer_idx = 0 - # For non-constant tensors, constant_buffer = 0. - return EValue(make_tensor_value(0, allocation_info, spec)) + # Need to memory plan + # Some users set mem_id on all tensors and then rely on the + # default algos to set offsets, so need to check both. + if spec.mem_id is not None and spec.mem_offset is not None: + # Tensor is an activation. + self._internal_assert_emitter( + isinstance(spec.mem_id, int) and spec.mem_id >= 0, + self.node, + f"Non-const tensor should be an activation tensor: mem_id {spec.mem_id}", + ) - # Constant tensor. Reserve a buffer for the constant tensor. - spec_array_type = ( - ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes() - ) + self._internal_assert_emitter( + isinstance(spec.mem_offset, int) and spec.mem_offset >= 0, + self.node, + f"Non-const tensor should be an activation tensor: mem_offset {spec.mem_offset}", + ) + allocation_info = make_allocation_info(spec.mem_id, spec.mem_offset) - buffer_data = ( - bytes( - ctypes.cast( - typing.cast(torch.UntypedStorage, spec.storage).data_ptr(), - ctypes.POINTER(spec_array_type), - ).contents + if spec.const: + # Tensor with a blob we need to serialize. May not actually be constant at runtime + # if it's a weight with an associated gradient + spec_array_type = ( + ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes() ) - if spec.allocated_memory != 0 - else b"" - ) - hashed = hashlib.sha256(buffer_data).hexdigest() + buffer_data = ( + bytes( + ctypes.cast( + typing.cast(torch.UntypedStorage, spec.storage).data_ptr(), + ctypes.POINTER(spec_array_type), + ).contents + ) + if spec.allocated_memory != 0 + else b"" + ) - buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1) + hashed = hashlib.sha256(buffer_data).hexdigest() - # Haven't seen this constant before - if buffer_idx == -1: - # Update buffer_idx to point to the end of the list where we are adding the new buffer. 
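As a plain-Python summary of the branching this rewritten helper implements (illustrative only, not emitter code):

```python
def classify_tensor(is_const: bool, is_memory_planned: bool) -> str:
    """Which buffer list a tensor's data lands in, following the emitter logic here."""
    if not is_const:
        # Activations and graph inputs carry no serialized blob; they reference buffer 0.
        return "no data (buffer index 0)"
    if is_memory_planned:
        # A tensor that has data to serialize *and* an allocation (e.g. a weight
        # with a gradient in a training graph) is deduped into mutable_buffer.
        return "mutable_buffer"
    # Ordinary frozen weights keep going to constant_buffer, as before.
    return "constant_buffer"
```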
- buffer = Buffer(storage=buffer_data) - buffer_idx = len(self.program_state.constant_buffer) - self.program_state.allocated_specs.append(spec) - # +1 because the first buffer location is reserved - self.program_state.cached_spec_hash_values[hashed] = buffer_idx - self.program_state.constant_buffer.append(buffer) + if allocation_info: + buffer_idx = self.program_state.cached_spec_mutable_hash_values.get( + hashed, -1 + ) + else: + buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1) + + # Haven't seen this constant before + if buffer_idx == -1: + # Update buffer_idx to point to the end of the list where we are adding the new buffer. + buffer = Buffer(storage=buffer_data) + self.program_state.allocated_specs.append(spec) + # +1 because the first buffer location is reserved + + if allocation_info: + buffer_idx = len(self.program_state.mutable_buffer) + self.program_state.cached_spec_mutable_hash_values[hashed] = ( + buffer_idx + ) + self.program_state.mutable_buffer.append(buffer) + else: + buffer_idx = len(self.program_state.constant_buffer) + self.program_state.cached_spec_hash_values[hashed] = buffer_idx + self.program_state.constant_buffer.append(buffer) - if spec.const and spec.nbytes() != len(buffer_data): - raise InternalError( - self._emit_node_specific_error( - self.node, - f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}", + if spec.const and spec.nbytes() != len(buffer_data): + raise InternalError( + self._emit_node_specific_error( + self.node, + f"Tensor spec has buffer of size {len(buffer_data)}, but expected nbytes of {spec.nbytes()}", + ) ) - ) # For constant tensors, allocation_info = None. - return EValue(make_tensor_value(buffer_idx, None, spec)) + return EValue(make_tensor_value(buffer_idx, allocation_info, spec)) def _get_list_tuple_jit_type( self, val: Union[Tuple[_Argument], List[_Argument]] @@ -770,7 +788,7 @@ def forward(self, x,y): # Increment iter_idx to mark that we have completed an iteration. op_index, op = self._get_operator( name="executorch_prim::add", - overload="int", + overload="Scalar", ) kernel = Instruction( KernelCall( @@ -787,7 +805,7 @@ def forward(self, x,y): # section. op_index, op = self._get_operator( name="executorch_prim::eq", - overload="int", + overload="Scalar", ) kernel = Instruction( KernelCall( @@ -809,7 +827,7 @@ def forward(self, x,y): # Reset iter_idx in case we plan to run the model again. 
op_index, op = self._get_operator( name="executorch_prim::sub", - overload="int", + overload="Scalar", ) kernel = Instruction( KernelCall( diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index bb670e2ad6..f1b980a9ae 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -722,6 +722,43 @@ def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: "executorch_prim::sub", ) + def test_load_emit_map(self) -> None: + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + return control_flow.map(map_fn, x, y) + + f = Foo() + + inputs = (torch.ones(4, 4), torch.ones(4)) + module = to_edge( + export(f, inputs), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + _load_for_executorch_from_buffer(module.to_executorch().buffer) + + def test_run_emit_map(self) -> None: + class Foo(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + return control_flow.map(map_fn, x, y) + + f = Foo() + + inputs = (torch.ones(4, 4), torch.ones(4)) + module = to_edge( + export(f, inputs), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) + buffer = module.to_executorch().buffer + loaded_model = _load_for_executorch_from_buffer(buffer) + outputs = loaded_model(inputs)[0] + torch.allclose(outputs, f(*inputs)) + def test_dim_order(self) -> None: class SimpleLinear(torch.nn.Module): def __init__(self) -> None: diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 9e722f48d5..859bd06901 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -427,7 +427,11 @@ def collect_specs_from_nodes( # noqa: C901 continue if ignore_graph_output and spec in graph_output_tensors: continue - if ignore_const and spec.const: + if ( + ignore_const + and spec.const + and not node.meta.get("weight_has_gradient", False) + ): continue if dedup: if spec in unique_spec: @@ -768,7 +772,9 @@ def apply_algo( ) insert_calls_to_free(graph_module, specs) - def handle_submodule(submodule_nd: torch.fx.Node) -> None: + def handle_submodule( + submodule_nd: torch.fx.Node, alloc_graph_input: bool = False + ) -> None: nonlocal bufsizes assert submodule_nd.op == "get_attr" submodule = getattr(graph_module, submodule_nd.target) @@ -780,7 +786,7 @@ def handle_submodule(submodule_nd: torch.fx.Node) -> None: submodule, alignment, graph_signature, - alloc_graph_input=False, + alloc_graph_input=alloc_graph_input, alloc_graph_output=True, ) submodule.meta.update({"non_const_buffer_sizes": bufsizes}) @@ -795,7 +801,9 @@ def handle_submodule(submodule_nd: torch.fx.Node) -> None: # TODO: Add test coverage for map operator once dynamo tracing is # fully supported for this. 
T142287208 for map_node in get_map_nodes(graph_module): - handle_submodule(typing.cast(torch.fx.Node, map_node.args[0])) + handle_submodule( + typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True + ) graph_module.meta.update({"non_const_buffer_sizes": bufsizes}) diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 7dcde950b9..4e59af26ea 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -25,6 +25,7 @@ python_library( ":spec_prop_pass", ":sym_shape_eval_pass", ":sym_to_tensor_pass", + ":weights_to_outputs_pass", "//caffe2:torch", "//executorch/exir:common", "//executorch/exir:control_flow", @@ -62,6 +63,16 @@ python_library( ], ) +python_library( + name = "weights_to_outputs_pass", + srcs = [ + "weights_to_outputs_pass.py", + ], + deps = [ + "//caffe2:torch", + ], +) + python_library( name = "const_prop_pass", srcs = [ diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 594de3f79a..99507ccdc9 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -54,6 +54,7 @@ from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass +from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass from torch import fx from torch._subclasses import FakeTensor from torch.fx.passes.infra.pass_base import PassBase, PassResult @@ -69,6 +70,7 @@ "MemoryPlanningPass", "HintBasedSymShapeEvalPass", "insert_write_back_for_buffers_pass", + "weights_to_outputs_pass", ] Argument = Optional[ diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index 7fed005b3c..27fc03f941 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -45,3 +45,15 @@ def _to_dim_order_copy_out_impl(*args, **kwargs): DimOrderOpsMap = { "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, } + +""" +Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup +""" +MemoryFormatOpsMap = { + "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default, +} + +# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts. +assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap) + +# TODO stricter check for 1:1 mapping diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py index 5a3c0f3a91..32678bf408 100644 --- a/exir/passes/memory_format_ops_pass.py +++ b/exir/passes/memory_format_ops_pass.py @@ -9,13 +9,19 @@ import torch from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.dim_order_utils import get_dim_order +from executorch.exir.dim_order_utils import get_dim_order, get_memory_format from executorch.exir.pass_base import ExportPass, ProxyValue -from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap +from executorch.exir.passes.dim_order_ops_registry import ( + DimOrderOpsMap, + MemoryFormatOpsMap, +) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) +# TODO - these passes are too specialized on a single to_copy op. +# We should be able to replace (or revert) any of the dim_order ops in the future. 
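The registry now carries both directions of the mapping, and the two passes below are thin wrappers around these lookups. A quick sketch of the round trip (assuming the registry changes in this patch):

```python
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.dim_order_ops_registry import (
    DimOrderOpsMap,
    MemoryFormatOpsMap,
)

# MemoryFormatOpsPass: aten._to_copy -> dim_order_ops._to_dim_order_copy
to_dim_order = DimOrderOpsMap["aten._to_copy.default"]
assert to_dim_order == exir_ops.edge.dim_order_ops._to_dim_order_copy.default

# DimOrderOpsRevertPass: dim_order_ops._to_dim_order_copy -> aten._to_copy
to_copy = MemoryFormatOpsMap["dim_order_ops._to_dim_order_copy.default"]
assert to_copy == exir_ops.edge.aten._to_copy.default
```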
+ class MemoryFormatOpsPass(ExportPass): """ @@ -53,7 +59,55 @@ def call_operator(self, op, args, kwargs, meta): f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}" ) - t = DimOrderOpsMap[op.__name__] + t = DimOrderOpsMap.get(op.__name__, None) + assert t is not None, f"{op.__name__} not found in DimOrderOpsMap" + + return super().call_operator( + t, + args, + nkwargs, + meta, + ) + + +class DimOrderOpsRevertPass(ExportPass): + """ + This pass is to revert the dim_order ops back to the memory format ops. + """ + + def call_operator(self, op, args, kwargs, meta): + if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap): + return super().call_operator( + op, + args, + kwargs, + meta, + ) + + # new kwargs with dim_order, and no memory_format for the new op + nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable + + # can always get the shape, assuming rank is specialized + if isinstance(args[0], ProxyValue) and args[0].is_tensor(): + ndim = args[0].to_tensor().dim() + elif isinstance(args[0], torch.Tensor): + ndim = args[0].dim() + else: + assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}" + + # get the "to" memory format for the EdgeOp + default_dim_order = list(range(ndim)) + dim_order = nkwargs.pop("dim_order", default_dim_order) + + nkwargs["memory_format"] = get_memory_format(dim_order) + + logger.debug( + f" _to_dim_order_copy = dim_order: {dim_order}." + f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}." + ) + + t = MemoryFormatOpsMap.get(op.__name__, None) + assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap" return super().call_operator( t, diff --git a/exir/passes/remove_mixed_type_operators.py b/exir/passes/remove_mixed_type_operators.py index 93d6689ae1..701a8269f1 100644 --- a/exir/passes/remove_mixed_type_operators.py +++ b/exir/passes/remove_mixed_type_operators.py @@ -23,6 +23,7 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata): # noqa: C901 promotion_type_allow_list = { torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, } diff --git a/exir/passes/weights_to_outputs_pass.py b/exir/passes/weights_to_outputs_pass.py new file mode 100644 index 0000000000..216830c2e6 --- /dev/null +++ b/exir/passes/weights_to_outputs_pass.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from torch.export import ExportedProgram +from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument + + +def weights_to_outputs_pass( + exported_program: ExportedProgram, +) -> ExportedProgram: + """ + This pass is for training graphs with gradients returned. It flags the weights as having a gradient attached, + and appends them to the outputs in order to make the weights easier to handle in memory planning and the emitter. + + Args: + exported_program: The ExportedProgram to update. + + Returns: + The modified ExportedProgram. 
+ """ + if ( + len([node for node in exported_program.graph.nodes if node.op == "placeholder"]) + == 0 + ): + return exported_program + + gs = exported_program.graph_signature + gm = exported_program.graph_module + + # Check for/ get gradients + grad_targets = [ + spec.target + for spec in gs.output_specs + if spec.kind == OutputKind.GRADIENT_TO_PARAMETER + ] + + # If no gradients, return + if len(grad_targets) == 0: + return exported_program + + inputs_to_params = gs.inputs_to_parameters + + # Get output node + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break + assert output_node is not None + + # Get place holder nodes with gradients + placeholder_nodes = [ + node + for node in gm.graph.nodes + if node.op == "placeholder" and node.target in inputs_to_params.keys() + ] + + # Flag these placeholder nodes as having a gradient attached so that memory planning will operate on them. + for node in placeholder_nodes: + node.meta["weight_has_gradient"] = True + + # add to output node + new_output_nodes = [] + new_output_nodes.extend(output_node.args[0]) + new_output_nodes.extend(placeholder_nodes) + # Remove old outputs + new_output = gm.graph.output(tuple(new_output_nodes)) + output_node.replace_all_uses_with(new_output) + gm.graph.erase_node(output_node) + + # add to output signature + for node in placeholder_nodes: + gs.output_specs.append( + OutputSpec( + OutputKind.TOKEN, # This is a hack. We are returning the raw weights here to make it easier for memory + # planning and the emitter. There is no outputkind.Parameter so I am using TOKEN which is currently unused in Edge. + TensorArgument(node.target), + None, + ) + ) + + # Cleanup the graph. + exported_program.graph.eliminate_dead_code() + exported_program.graph_module.recompile() + + return exported_program diff --git a/exir/program/TARGETS b/exir/program/TARGETS index ef4e619e1e..730c9e93ae 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -40,6 +40,7 @@ python_library( "//executorch/exir/passes:replace_aten_with_edge_pass", "//executorch/exir/passes:replace_view_copy_with_view_pass", "//executorch/exir/passes:spec_prop_pass", + "//executorch/exir/passes:weights_to_outputs_pass", "//executorch/exir/verification:verifier", ], ) diff --git a/exir/program/_program.py b/exir/program/_program.py index fd6253a8aa..dda2da7fa7 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import copy import io import logging @@ -43,6 +45,7 @@ ReplaceViewCopyWithViewPass, ) from executorch.exir.passes.spec_prop_pass import SpecPropPass +from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass from executorch.exir.print_program import pretty_print, print_program from executorch.exir.schema import Program from executorch.exir.tracer import _default_decomposition_table @@ -1227,6 +1230,7 @@ def to_executorch( execution_programs: Dict[str, ExportedProgram] = {} for name, program in self._edge_programs.items(): + program = weights_to_outputs_pass(program) program = unsafe_remove_auto_functionalized_pass(program) gm, new_signature = insert_write_back_for_buffers_pass(program) new_gm = program.graph_module @@ -1322,6 +1326,7 @@ def __init__( # Serialize emitter output, ready to be written to a file. 
self._pte_data: Cord = _serialize_pte_binary( program=self._emitter_output.program, + mutable_data=self._emitter_output.mutable_data, extract_delegate_segments=backend_config.extract_delegate_segments, extract_constant_segment=backend_config.extract_constant_segment, segment_alignment=backend_config.segment_alignment, diff --git a/exir/schema.py b/exir/schema.py index e9b589f839..706bc61140 100644 --- a/exir/schema.py +++ b/exir/schema.py @@ -265,3 +265,4 @@ class Program: backend_delegate_data: List[BackendDelegateInlineData] segments: List[DataSegment] constant_segment: SubsegmentOffsets + mutable_data_segments: Optional[List[SubsegmentOffsets]] = None diff --git a/exir/tensor.py b/exir/tensor.py index da35c2c491..7380a96ebc 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -326,11 +326,6 @@ def to_list( else: return x - internal_assert( - not spec.const or not allocation_info, - "We only create non-constant tensors as the constant tensors are directly written to buffer", - ) - tensor_size = to_list(spec.shape) tensor_dim_order = to_list(spec.dim_order) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 7ef04d1283..8c20467ae6 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -101,6 +101,17 @@ python_unittest( ], ) +python_unittest( + name = "joint_graph", + srcs = [ + "test_joint_graph.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + ], +) + python_unittest( name = "error", srcs = [ diff --git a/exir/tests/test_joint_graph.py b/exir/tests/test_joint_graph.py new file mode 100644 index 0000000000..0aa724479b --- /dev/null +++ b/exir/tests/test_joint_graph.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict +import unittest + +import torch +import torch._dynamo + +from executorch.exir import to_edge +from torch.export._trace import _export +from torch.export.experimental import _export_forward_backward +from torch.export.exported_program import OutputKind + + +class TestJointGraph(unittest.TestCase): + def test_joint_graph(self) -> None: + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + return self.loss(self.linear(x).softmax(dim=0), y) + + m = Module() + example_inputs = (torch.ones(3), torch.tensor([1.0, 0.0, 0.0])) + m(*example_inputs) + ep = _export(m, example_inputs, pre_dispatch=True) + joint_ep = _export_forward_backward(ep) + edge = to_edge(joint_ep) + + output_node = None + for node in edge.exported_program().graph.nodes: + if node.op == "output": + output_node = node + break + + orig_outputs = len(output_node.args[0]) + + et = edge.to_executorch() + + weight_output_specs = [ + spec + for spec in et.exported_program().graph_signature.output_specs + if spec.kind == OutputKind.TOKEN + ] + + output_node = None + for node in et.exported_program().graph.nodes: + if node.op == "output": + output_node = node + break + + weight_outputs = len(output_node.args[0]) + + # make sure 2 new outputs are added to both the node and the spec + self.assertEqual(len(weight_output_specs), 2) # linear layer weight and bias + self.assertEqual( + weight_outputs - orig_outputs, 2 + ) # linear layer weight and bias + + # assert that the weight and bias have proper data_buffer_idx and allocation_info + self.assertEqual( + et.executorch_program.execution_plan[0] # pyre-ignore + .values[0] + .val.data_buffer_idx, + 1, + ) + self.assertEqual( + et.executorch_program.execution_plan[0] # pyre-ignore + .values[1] + .val.data_buffer_idx, + 2, + ) + self.assertEqual( + et.executorch_program.execution_plan[0] # pyre-ignore + .values[0] + .val.allocation_info.memory_offset_low, + 0, + ) + self.assertEqual( + et.executorch_program.execution_plan[0] # pyre-ignore + .values[1] + .val.allocation_info.memory_offset_low, + 48, + ) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 61d3af8afb..99ec648145 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -37,10 +37,11 @@ from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from executorch.exir.passes.normalize_view_copy_base_pass import ( NormalizeViewCopyBasePass, ) - from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass @@ -1676,3 +1677,100 @@ def forward(self, text_tokens): ) new_ep = constant_prop_pass(edge_manager._edge_programs["forward"]) _ = copy.deepcopy(new_ep.module_call_graph) + + def test_dim_order_revert_pass(self) -> None: + aten_op_str = "torch.ops.aten._to_copy.default" + edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default" + edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" + + class Module(torch.nn.Module): + """ + A simple module that has a single to op that converts to channels last and then back to contiguous. + Assuming contiguous input. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.to(memory_format=torch.channels_last).to( + memory_format=torch.contiguous_format + ) + x.to(memory_format=torch.channels_last).to( + memory_format=torch.contiguous_format + ) + + @staticmethod + def to_copy_count(): + return 4 + + def _do_checks( + test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str] + ) -> None: + for not_allowed in not_allowed_list: + FileCheck().check_count(allowed, allowed_count, exactly=True).check_not( + not_allowed + ).run(test_str) + + m = Module() + n = m.to_copy_count() + input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format) + + # 1. vanilla export, no edge ops + ep = export( + m, + (input,), + ) + _do_checks( + ep.graph_module.code, + aten_op_str, + n, + [edge_aten_op_str, edge_dim_order_op_str], + ) + + # 2a. to edge without dim orders, we should see edge aten ops but not dim order ops + edge_prog = to_edge( + ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True) + )._edge_programs["forward"] + _do_checks( + edge_prog.graph_module.code, + edge_aten_op_str, + n, + [aten_op_str, edge_dim_order_op_str], + ) + + # 3a. expect no change after the pass, we should see edge aten ops but not dim order ops + new_res = DimOrderOpsRevertPass()(edge_prog.graph_module) + self.assertIsNotNone(new_res) + _do_checks( + new_res.graph_module.code, + edge_aten_op_str, + n, + [aten_op_str, edge_dim_order_op_str], + ) + + # 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops + edge_prog_dim_order = to_edge( + ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False) + )._edge_programs["forward"] + _do_checks( + edge_prog_dim_order.graph_module.code, + edge_dim_order_op_str, + n, + [aten_op_str, edge_aten_op_str], + ) + + # 3b. expect edge aten ops after the pass, we should see not see the edge dim order ops + new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module) + self.assertIsNotNone(new_res_dim_order) + _do_checks( + new_res_dim_order.graph_module.code, + edge_aten_op_str, + n, + [aten_op_str, edge_dim_order_op_str], + ) + + output_no_dim_order = new_res.graph_module(input) + output_no_dim_order_revert = new_res_dim_order.graph_module(input) + self.assertTrue( + torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0]) + ) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index b4fe80f022..90be2c68c4 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -73,7 +73,7 @@ class ExecuTorchLlamaCallbackJni method(self(), s); } - void onStats(const Runner::Stats& result) const { + void onStats(const Stats& result) const { static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic(); static const auto method = cls->getMethod("onStats"); double eval_time = @@ -132,7 +132,7 @@ class ExecuTorchLlamaJni prompt->toStdString(), 128, [callback](std::string result) { callback->onResult(result); }, - [callback](const Runner::Stats& result) { callback->onStats(result); }); + [callback](const Stats& result) { callback->onStats(result); }); return 0; } diff --git a/extension/llm/README.md b/extension/llm/README.md new file mode 100644 index 0000000000..dfc193e41e --- /dev/null +++ b/extension/llm/README.md @@ -0,0 +1,47 @@ +This subtree contains libraries and utils of running generative AI, including Large Language Models (LLM) using ExecuTorch. 
+Below is a list of subfolders.
+## export
+Model preparation code is in the _export_ folder. The main entry point is the _LLMEdgeManager_ class. It hosts a _torch.nn.Module_, with a list of methods that can be used to prepare the LLM model for the ExecuTorch runtime.
+Note that ExecuTorch supports two [quantization APIs](https://pytorch.org/docs/stable/quantization.html#quantization-api-summary): eager mode quantization (aka source transform based quantization), and PyTorch 2 Export based quantization (aka pt2e quantization).
+Typical methods include:
+- _set_output_dir_: where users want to save the exported .pte file.
+- _to_dtype_: override the data type of the module.
+- _source_transform_: execute a series of source transform passes. Some transform passes include
+  - weight-only quantization, which can be done at source (eager mode) level.
+  - replacing some torch operators with custom operators. For example, _replace_sdpa_with_custom_op_.
+- _capture_pre_autograd_graph_: get a graph that is ready for pt2 graph-based quantization.
+- _pt2e_quantize_: quantize with passed-in quantizers.
+  - util functions in _quantizer_lib.py_ can help to get different quantizers based on the needs.
+- _export_to_edge_: export to the edge dialect.
+- _to_backend_: lower the graph to an acceleration backend.
+- _to_executorch_: get the executorch graph with optional optimization passes.
+- _save_to_pte_: finally, the lowered and optimized graph can be saved into a .pte file for the runtime.
+
+Example usage of _LLMEdgeManager_ can be found in executorch/examples/models/llama2 and executorch/examples/models/llava.
+
+Once the .pte file is exported and saved, it can be loaded and run by a runner.
+
+## tokenizer
+Currently, we support two types of tokenizers: sentencepiece and Tiktoken.
+- In Python:
+  - _utils.py_: get the tokenizer from a model file path, based on the file format.
+  - _tokenizer.py_: rewrite a sentencepiece tokenizer model to a serialization format that the runtime can load.
+- In C++:
+  - _tokenizer.h_: a simple tokenizer interface. Actual tokenizer classes can be implemented based on this. In this folder, we provide two tokenizer implementations:
+    - _bpe_tokenizer_. The rewritten tokenizer artifact (see _tokenizer.py_ above) is required for the bpe tokenizer to work.
+    - _tiktoken_. It's for Llama 3 and Llama 3.1.
+
+## sampler
+A sampler class in C++ to sample the logits given some hyperparameters.
+
+## custom_ops
+It hosts a custom sdpa operator. This sdpa operator implements CPU flash attention and avoids copies by taking the kv cache as one of its arguments.
+- _sdpa_with_kv_cache.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
+- _op_sdpa.cpp_: the optimized operator implementation and registration of _sdpa_with_kv_cache.out_.
+
+## runner
+It hosts the library components used in a C++ LLM runner. Currently, it hosts _stats.h_, which tracks runtime status such as token counts and latency.
+
+With the components above, an actual runner can be built for a model or a series of models. An example is in //executorch/examples/models/llama2/runner, where a C++ runner is built to run Llama 2, 3, 3.1, and other models using the same architecture.
+
+Usage can also be found in the [torchchat repo](https://github.com/pytorch/torchchat/tree/main/runner).
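+As a quick reference, below is a rough sketch of the export flow described above. It is illustrative only: the import path and constructor arguments shown for _LLMEdgeManager_ are assumptions and may differ from the real builder; only the method names come from the list above (see executorch/examples/models/llama2 for real usage).
+
+```python
+# Illustrative sketch only. The import path and constructor arguments are
+# assumptions, not the exact API; consult examples/models/llama2 for real usage.
+from executorch.extension.llm.export.builder import LLMEdgeManager  # assumed path
+
+# `model` is your eager torch.nn.Module, `example_inputs` its sample inputs.
+manager = LLMEdgeManager(model=model, example_inputs=example_inputs)  # assumed ctor
+manager.set_output_dir("out")      # where the exported .pte file will be saved
+manager.source_transform([])       # optional eager-mode transforms (e.g. weight-only quant)
+manager.export_to_edge()           # lower to the edge dialect
+manager.to_backend([])             # optionally delegate to an acceleration backend
+manager.to_executorch()            # apply ExecuTorch passes
+manager.save_to_pte("model.pte")   # serialize the lowered graph to a .pte file
+```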
diff --git a/extension/llm/runner/TARGETS b/extension/llm/runner/TARGETS new file mode 100644 index 0000000000..2341af9282 --- /dev/null +++ b/extension/llm/runner/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/llm/runner/stats.h b/extension/llm/runner/stats.h new file mode 100644 index 0000000000..31dd5e71cf --- /dev/null +++ b/extension/llm/runner/stats.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Runner stats for LLM +#pragma once +#include +#include +// patternlint-disable-next-line executorch-cpp-nostdinc +#include + +#include +namespace executorch::llm { + +struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Start of model loading. + long model_load_start_ms; + // model_load_end_ms: End of model loading. + long model_load_end_ms; + // inference_start_ms: Immediately after the model is loaded (or we check + // for model load), measure the inference time. + long inference_start_ms; + // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right + // before the inference loop starts + long prompt_eval_end_ms; + // first_token: Timestamp when the first generated token is emitted + long first_token_ms; + // inference_end_ms: End of inference/generation. + long inference_end_ms; + // Keep a running total of the time spent in sampling. 
+ long aggregate_sampling_time_ms; + // Token count from prompt + int64_t num_prompt_tokens; + // Token count from generated (total - prompt) + int64_t num_generated_tokens; +}; + +static constexpr auto kTopp = 0.9f; + +inline std::string stats_to_json_string(const Stats& stats) { + std::stringstream ss; + ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," + << "\"generated_tokens\":" << stats.num_generated_tokens << "," + << "\"model_load_start_ms\":" << stats.model_load_start_ms << "," + << "\"model_load_end_ms\":" << stats.model_load_end_ms << "," + << "\"inference_start_ms\":" << stats.inference_start_ms << "," + << "\"inference_end_ms\":" << stats.inference_end_ms << "," + << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," + << "\"first_token_ms\":" << stats.first_token_ms << "," + << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms + << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" + << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; + return ss.str(); +} + +inline void print_report(const Stats& stats) { + printf("PyTorchObserver %s\n", stats_to_json_string(stats).c_str()); + + ET_LOG( + Info, + "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, + stats.num_prompt_tokens, + stats.num_generated_tokens); + + ET_LOG( + Info, + "\tModel Load Time:\t\t%f (seconds)", + ((double)(stats.model_load_end_ms - stats.model_load_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + double inference_time_ms = + (double)(stats.inference_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND, + + (stats.num_generated_tokens) / + (double)(stats.inference_end_ms - stats.inference_start_ms) * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + double prompt_eval_time = + (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + (stats.num_prompt_tokens) / prompt_eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + double eval_time = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + ET_LOG( + Info, + "\t\tGenerated %" PRIu64 + " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + stats.num_generated_tokens, + eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + stats.num_generated_tokens / eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + // Time to first token is measured from the start of inference, excluding + // model load time. 
+ ET_LOG( + Info, + "\tTime to first generated token:\t%f (seconds)", + ((double)(stats.first_token_ms - stats.inference_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", + stats.num_prompt_tokens + stats.num_generated_tokens, + (double)stats.aggregate_sampling_time_ms / + stats.SCALING_FACTOR_UNITS_PER_SECOND); +} + +} // namespace executorch::llm diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl new file mode 100644 index 0000000000..81a3d32ba8 --- /dev/null +++ b/extension/llm/runner/targets.bzl @@ -0,0 +1,10 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_library( + name = "stats", + exported_headers = ["stats.h"], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + ) diff --git a/extension/llm/tokenizer/bpe_tokenizer.cpp b/extension/llm/tokenizer/bpe_tokenizer.cpp index d1edaa0f08..07d138548d 100644 --- a/extension/llm/tokenizer/bpe_tokenizer.cpp +++ b/extension/llm/tokenizer/bpe_tokenizer.cpp @@ -190,7 +190,7 @@ BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const { std::vector tokens; // add optional BOS token, if desired - if (bos > 0) { + if (bos >= 0) { while (bos--) { tokens.push_back(bos_tok_); } diff --git a/extension/llm/tokenizer/bpe_tokenizer.h b/extension/llm/tokenizer/bpe_tokenizer.h index d398ec1235..7ea8402583 100644 --- a/extension/llm/tokenizer/bpe_tokenizer.h +++ b/extension/llm/tokenizer/bpe_tokenizer.h @@ -19,6 +19,8 @@ struct TokenIndex { int32_t id; }; +// A simple Byte Pair Encoding (BPE) Tokenizer. Note that the current C++ code +// won't work with this class, it needs to go through tokenizer.py first. class BPETokenizer : public Tokenizer { public: explicit BPETokenizer(); diff --git a/extension/llm/tokenizer/tokenizer.h b/extension/llm/tokenizer/tokenizer.h index ae193a257a..b49dc245eb 100644 --- a/extension/llm/tokenizer/tokenizer.h +++ b/extension/llm/tokenizer/tokenizer.h @@ -6,12 +6,12 @@ * LICENSE file in the root directory of this source tree. */ -// A simple Byte Pair Encoding (BPE) Tokenizer. Note that the vanila tokenizer -// model won't work with this class, it needs to go through tokenizer.py first. #pragma once #include +// patternlint-disable-next-line executorch-cpp-nostdinc #include +// patternlint-disable-next-line executorch-cpp-nostdinc #include #include @@ -20,6 +20,7 @@ namespace torch { namespace executor { +// A tokenizer interface. 
class Tokenizer { public: explicit Tokenizer() {} diff --git a/extension/training/optimizer/targets.bzl b/extension/training/optimizer/targets.bzl index 84043d27c9..69682feaee 100644 --- a/extension/training/optimizer/targets.bzl +++ b/extension/training/optimizer/targets.bzl @@ -26,7 +26,7 @@ def define_common_targets(): ] runtime.cxx_library( - name = "optimizer" + aten_suffix, + name = "sgd" + aten_suffix, srcs = [ "sgd.cpp", ], diff --git a/extension/training/optimizer/test/targets.bzl b/extension/training/optimizer/test/targets.bzl index 7ffa74d614..11269bfa18 100644 --- a/extension/training/optimizer/test/targets.bzl +++ b/extension/training/optimizer/test/targets.bzl @@ -15,7 +15,7 @@ def define_common_targets(): "sgd_test.cpp", ], deps = [ - "//executorch/extension/training/optimizer:optimizer" + aten_suffix, + "//executorch/extension/training/optimizer:sgd" + aten_suffix, "//executorch/runtime/core:core", "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], diff --git a/kernels/prim_ops/et_copy_index.cpp b/kernels/prim_ops/et_copy_index.cpp index cd34abfdab..40cf9a7e55 100644 --- a/kernels/prim_ops/et_copy_index.cpp +++ b/kernels/prim_ops/et_copy_index.cpp @@ -94,7 +94,8 @@ void et_copy_index(RuntimeContext& context, EValue** stack) { expected_output_size[i + 1] = copy_from.sizes()[i]; } - if (copy_to.sizes()[0] != expected_output_size[0]) { + if (copy_to.sizes()[0] < expected_output_size[0]) { + // Resize `copy_to` to the expected output size. const void* data_ptr = copy_to.const_data_ptr(); Error err = resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()}); diff --git a/runtime/executor/program.h b/runtime/executor/program.h index 5d2ec67a3e..802c112213 100644 --- a/runtime/executor/program.h +++ b/runtime/executor/program.h @@ -209,7 +209,7 @@ class Program final { /** * Loads a segment by index. * - * @param[in] SegmentInfo Struct containing an index to load from the + * @param[in] segment_info Struct containing an index to load from the * Program.segments list. The other fields of the struct, such as * `segment_type` and `descriptor`, need to also be correct. * diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py index 1cf38cc54e..656b570512 100644 --- a/test/end2end/exported_module.py +++ b/test/end2end/exported_module.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + """Test helper for exporting an nn.Module to an ExecuTorch program.""" import functools @@ -22,6 +24,8 @@ ) from torch import nn from torch.export import export +from torch.export._trace import _export +from torch.export.experimental import _export_forward_backward class ExportedModule: @@ -65,6 +69,7 @@ def export( capture_config=None, extract_constant_segment: bool = True, skip_type_promotion: bool = False, + export_joint_graph: bool = False, ) -> "ExportedModule": """ Creates a new ExportedModule for the specified module class. @@ -157,15 +162,30 @@ def __init__(self, method): # variant, along with some other transformations. for method_name, method_input in method_name_to_args.items(): # if not isinstance(eager_module, torch.nn.Module): - exported_methods[method_name] = export( - eager_module, - method_input, - dynamic_shapes=( - method_name_to_dynamic_shapes[method_name] - if method_name_to_dynamic_shapes - else None - ), - ) + if export_joint_graph: + # _export was having issues with WrapperModule. 
+ assert method_name == "forward" + ep = _export( + eager_module, + method_input, + dynamic_shapes=( + method_name_to_dynamic_shapes[method_name] + if method_name_to_dynamic_shapes + else None + ), + pre_dispatch=True, + ) + exported_methods[method_name] = _export_forward_backward(ep) + else: + exported_methods[method_name] = export( + eager_module, + method_input, + dynamic_shapes=( + method_name_to_dynamic_shapes[method_name] + if method_name_to_dynamic_shapes + else None + ), + ) exec_prog = to_edge( exported_methods, diff --git a/test/models/export_program.py b/test/models/export_program.py index c6d744d058..7941af376f 100644 --- a/test/models/export_program.py +++ b/test/models/export_program.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import argparse import inspect import os @@ -164,6 +166,23 @@ def get_method_names_to_export() -> List[str]: return ["forward", "forward2"] +class ModuleSimpleTrain(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + return self.loss(self.linear(x).softmax(dim=0), y) + + def get_random_inputs(self): + return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0])) + + @staticmethod + def export_joint(): + return True + + # # Main logic. # @@ -175,11 +194,15 @@ def export_module_to_program( skip_type_promotion: bool, ): """Exports the module and returns the serialized program data.""" + torch.manual_seed(0) # Look for an optional @staticmethod that defines custom trace params. export_kwargs: Dict[str, Any] = {} if hasattr(module_class, "get_export_kwargs"): # pyre-ignore[16]: pyre doesn't know about get_export_kwargs. export_kwargs = module_class.get_export_kwargs() + export_joint = False + if hasattr(module_class, "export_joint"): + export_joint = module_class.export_joint() # pyre-ignore if hasattr(module_class, "get_method_names_to_export"): # pyre-ignore[16]: pyre doesn't know about get_export_kwargs. methods = module_class.get_method_names_to_export() @@ -190,6 +213,7 @@ def export_module_to_program( methods, extract_constant_segment=extract_constant_segment, skip_type_promotion=skip_type_promotion, + export_joint_graph=export_joint, **export_kwargs, ) return module.executorch_program.buffer @@ -199,6 +223,7 @@ def main() -> None: # These args are optimized for genrule usage. There's a lot of startup # overhead for this tool, so it's faster to export multiple models at once # when possible. + torch.manual_seed(0) parser = argparse.ArgumentParser( prog="export_program", description="Exports nn.Module models to ExecuTorch .pte files", diff --git a/test/models/targets.bzl b/test/models/targets.bzl index e44c6eb0c7..ad907304ed 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -66,6 +66,7 @@ def define_common_targets(): "ModuleMultipleEntry", "ModuleIndex", "ModuleDynamicCatUnallocatedIO", + "ModuleSimpleTrain", ] # Generates Executorch .pte program files for various modules at build time.