diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 3a0cd57ddb..c33cc533c0 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -1,5 +1,7 @@ mpmath==1.3.0 -numpy==1.25.2 +numpy==1.21.3; python_version == '3.10' +numpy==1.23.2; python_version == '3.11' +numpy; python_version >= '3.12' PyYAML==6.0.1 ruamel.yaml==0.17.32 sympy==1.12 @@ -8,6 +10,8 @@ tomli==2.0.1 torchsr==1.0.4 transformers==4.38.0 zstd==1.5.5.1 +pandas==2.0.3; python_version == '3.10' +pandas; python_version >= '3.11' pytest==7.2.0 pytest-cov==4.1.0 expecttest==0.1.6 diff --git a/.ci/scripts/gather_test_models.py b/.ci/scripts/gather_test_models.py index 55289140c4..36a64e4241 100755 --- a/.ci/scripts/gather_test_models.py +++ b/.ci/scripts/gather_test_models.py @@ -27,6 +27,7 @@ # This one causes timeout on smaller runner, the root cause is unclear (T161064121) "dl3": "linux.12xlarge", "emformer_join": "linux.12xlarge", + "emformer_predict": "linux.12xlarge", } } @@ -35,9 +36,11 @@ # Just some examples on how custom timeout can be set "linux": { "mobilebert": 90, + "emformer_predict": 360, }, "macos": { "mobilebert": 90, + "emformer_predict": 360, }, } @@ -84,7 +87,11 @@ def model_should_run_on_event(model: str, event: str) -> bool: """ if event == "pull_request": return model in ["mv3", "vit"] - return True + elif event == "push": + # 'emformer_predict' runs very slowly. Only run it periodically. + return model not in ["emformer_predict"] + else: + return True def model_should_run_on_target_os(model: str, target_os: str) -> bool: diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index 30b77ee38f..ae795b12ab 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -13,6 +13,7 @@ MODEL_NAME=$1 # stories110M.pt BUILD_TOOL=$2 # buck2 or cmake DTYPE=$3 # fp16 or fp32 MODE=${4:-"xnnpack+custom"} # portable or xnnpack+custom or xnnpack+custom+qe +UPLOAD_DIR=${5:-} if [[ $# -lt 4 ]]; then # Assuming 4 mandatory args echo "Expecting at least 4 positional arguments" echo "Usage: [...]" @@ -126,6 +127,15 @@ cleanup_files() { rm params.json } +prepare_artifacts_upload() { + if [ -n "$UPLOAD_DIR" ]; then + echo "Preparing for uploading generated artifacts" + mkdir -p "${UPLOAD_DIR}" + zip -j "model.zip" "${MODEL_NAME}" tokenizer.bin + cp "model.zip" "${UPLOAD_DIR}" + fi +} + # Download and create artifacts. PARAMS="params.json" touch "${PARAMS}" @@ -205,6 +215,7 @@ if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then echo "Actual result: ${RESULT}" echo "Success" + prepare_artifacts_upload cleanup_files else echo "Expected result prefix: ${EXPECTED_PREFIX}" diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml new file mode 100644 index 0000000000..a8223eef2c --- /dev/null +++ b/.github/workflows/android-perf.yml @@ -0,0 +1,224 @@ +name: android-perf + +on: + schedule: + - cron: 0 0 * * * + # Note: GitHub has an upper limit of 10 inputs + workflow_dispatch: + inputs: + models: + description: Models to be benchmarked + required: false + type: string + default: stories110M + devices: + description: Target devices to run benchmark + required: false + type: string + default: samsung_galaxy_s2x + delegates: + description: Backend delegates + required: false + type: string + default: xnnpack + threadpool: + description: Run with threadpool?
+ required: false + type: boolean + default: false + benchmark_configs: + description: The list of configs used the benchmark + required: false + type: string + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + set-parameters: + runs-on: linux.2xlarge + outputs: + models: ${{ steps.set-parameters.outputs.models }} + devices: ${{ steps.set-parameters.outputs.devices }} + delegates: ${{ steps.set-parameters.outputs.delegates }} + steps: + - name: Set parameters + id: set-parameters + shell: bash + run: | + set -ex + MODELS="${{ inputs.models }}" + DEVICES="${{ inputs.devices }}" + DELEGATES="${{ inputs.delegates }}" + + # Mapping devices to their corresponding device-pool-arn + declare -A DEVICE_POOL_ARNS + DEVICE_POOL_ARNS[samsung_galaxy_s2x]="arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa" + + # Resolve device names with their corresponding ARNs + if [[ ! $(echo "$DEVICES" | jq empty 2>/dev/null) ]]; then + DEVICES=$(echo "$DEVICES" | jq -Rc 'split(",")') + fi + declare -a MAPPED_ARNS=() + for DEVICE in $(echo "$DEVICES" | jq -r '.[]'); do + if [[ -z "${DEVICE_POOL_ARNS[$DEVICE]}" ]]; then + echo "Error: No ARN found for device '$DEVICE'. Abort." >&2 + exit 1 + fi + MAPPED_ARNS+=("${DEVICE_POOL_ARNS[$DEVICE]}") + done + + echo "models=$(echo $MODELS | jq -Rc 'split(",")')" >> $GITHUB_OUTPUT + MAPPED_ARNS_JSON=$(printf '%s\n' "${MAPPED_ARNS[@]}" | jq -R . | jq -s .) + echo "devices=$(echo "$MAPPED_ARNS_JSON" | jq -c .)" >> $GITHUB_OUTPUT + echo "delegates=$(echo $DELEGATES | jq -Rc 'split(",")')" >> $GITHUB_OUTPUT + + export-models: + name: export-models + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + needs: set-parameters + strategy: + matrix: + model: ${{ fromJson(needs.set-parameters.outputs.models) }} + delegate: ${{ fromJson(needs.set-parameters.outputs.delegates) }} + fail-fast: false + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-clang12 + submodules: 'true' + timeout: 60 + upload-artifact: android-models + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake" + echo "Exporting model: ${{ matrix.model }}" + export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded/${{ matrix.model }}_${{ matrix.delegate }} + + # TODO(T197546696): Note that the following scripts/steps only work for llama. It's expected to fail for other models+delegates. + # Install requirements for export_llama + PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh + # Test llama2 + PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh "${{ matrix.model }}.pt" "cmake" "fp32" "xnnpack+custom+qe" "${ARTIFACTS_DIR_NAME}"\ + + # Upload models to S3. 
The artifacts are needed not only by the device farm but also TorchChat + upload-models: + needs: export-models + runs-on: linux.2xlarge + steps: + - name: Download the models from GitHub + uses: actions/download-artifact@v3 + with: + # The name here needs to match the name of the upload-artifact parameter + name: android-models + path: ${{ runner.temp }}/artifacts/ + + - name: Verify the models + shell: bash + working-directory: ${{ runner.temp }}/artifacts/ + run: | + ls -lah ./ + + - name: Upload the models to S3 + uses: seemethere/upload-artifact-s3@v5 + with: + s3-bucket: gha-artifacts + s3-prefix: | + ${{ github.repository }}/${{ github.run_id }}/artifact + retention-days: 1 + if-no-files-found: ignore + path: ${{ runner.temp }}/artifacts/ + + build-llm-demo: + name: build-llm-demo + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + needs: set-parameters + strategy: + matrix: + tokenizer: [bpe] + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-clang12-android + submodules: 'true' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + upload-artifact: android-apps + script: | + set -eux + + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh cmake + export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded + + # TODO: This needs to be replaced with a generic loader .apk + # Build LLM Demo for Android + bash build/build_android_llm_demo.sh ${{ matrix.tokenizer }} ${ARTIFACTS_DIR_NAME} + + # Upload artifacts to S3. The artifacts are needed not only by the device farm but also TorchChat + upload-android-apps: + needs: build-llm-demo + runs-on: linux.2xlarge + steps: + - name: Download the apps from GitHub + uses: actions/download-artifact@v3 + with: + # The name here needs to match the name of the upload-artifact parameter + name: android-apps + path: ${{ runner.temp }}/artifacts/ + + - name: Verify the apps + shell: bash + working-directory: ${{ runner.temp }}/artifacts/ + run: | + ls -lah ./ + + - name: Upload the apps to S3 + uses: seemethere/upload-artifact-s3@v5 + with: + s3-bucket: gha-artifacts + s3-prefix: | + ${{ github.repository }}/${{ github.run_id }}/artifact + retention-days: 14 + if-no-files-found: ignore + path: ${{ runner.temp }}/artifacts/ + + # Let's see how expensive this job is, we might want to tone it down by running it periodically + benchmark-on-device: + permissions: + id-token: write + contents: read + uses: pytorch/test-infra/.github/workflows/mobile_job.yml@main + needs: + - set-parameters + - upload-models + - upload-android-apps + strategy: + matrix: + model: ${{ fromJson(needs.set-parameters.outputs.models) }} + delegate: ${{ fromJson(needs.set-parameters.outputs.delegates) }} + device: ${{ fromJson(needs.set-parameters.outputs.devices) }} + with: + device-type: android + runner: linux.2xlarge + test-infra-ref: '' + # This is the ARN of ExecuTorch project on AWS + project-arn: arn:aws:devicefarm:us-west-2:308535385114:project:02a2cf0f-6d9b-45ee-ba1a-a086587469e6 + device-pool-arn: ${{ matrix.device }} + # Uploaded to S3 from the previous job, the name of the app comes from the project itself. + # Unlike models there are limited numbers of build flavor for apps, and the model controls whether it should build with bpe/tiktoken tokenizer. 
+ # It's okay to build all possible apps with all possible flavors in job "build-llm-demo". However, in this job, once a model is given, there is only + # one app+flavor that could load and run the model. + # TODO: Hard code llm_demo_bpe for now in this job. + android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug.apk + android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/llm_demo_bpe/app-debug-androidTest.apk + # The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml + test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77 + # Uploaded to S3 from the previous job + extra-data: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/${{ matrix.model }}_${{ matrix.delegate }}/model.zip diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 36099ca651..591a0328b7 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -193,7 +193,7 @@ jobs: strategy: fail-fast: false with: - runner: linux.12xlarge + runner: linux.24xlarge docker-image: executorch-ubuntu-22.04-clang12 submodules: 'true' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -210,11 +210,19 @@ jobs: bash examples/models/llava/install_requirements.sh # run export_llava.sh - python examples/models/llava/export_llava.py + python examples/models/llava/export_llava.py --use-sdpa-with-kv-cache --pte-name llava_custom_sdpa.pte # verify file exists - if [ ! -f "llava_combined_xnnpack.pte" ]; then - echo "llava_combined_xnnpack.pte not found!" + if [ ! -f "llava_custom_sdpa.pte" ]; then + echo "llava_custom_sdpa.pte not found!" + exit 1 + fi + + python examples/models/llava/export_llava.py --no-use-sdpa-with-kv-cache --pte-name llava.pte + + # verify file exists + if [ ! -f "llava.pte" ]; then + echo "llava.pte not found!" 
exit 1 fi diff --git a/CMakeLists.txt b/CMakeLists.txt index d00ac243e2..01c1ea847c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -774,7 +774,7 @@ endif() if(EXECUTORCH_BUILD_KERNELS_CUSTOM) # TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/examples/models/llama2/custom_ops + ${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops ) endif() diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm index e7846256e6..927df0483f 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm @@ -655,7 +655,7 @@ - (void)prewarmRecentlyUsedAssetsWithMaxCount:(NSUInteger)maxCount { NSError *prewarmError = nil; if (![asset prewarmAndReturnError:&prewarmError]) { - ETCoreMLLogError(localError, + ETCoreMLLogError(prewarmError, "%@: Failed to prewarm asset with identifier = %@", NSStringFromClass(strongSelf.assetManager.class), asset.identifier); diff --git a/backends/apple/coreml/runtime/delegate/backend_delegate.mm b/backends/apple/coreml/runtime/delegate/backend_delegate.mm index f6eb7a83fd..efa3dd2472 100644 --- a/backends/apple/coreml/runtime/delegate/backend_delegate.mm +++ b/backends/apple/coreml/runtime/delegate/backend_delegate.mm @@ -157,7 +157,7 @@ - (BOOL)_loadAndReturnError:(NSError * _Nullable __autoreleasing *)error { if (self.config.should_prewarm_asset) { [modelManager prewarmRecentlyUsedAssetsWithMaxCount:1]; } - + return YES; } @@ -188,9 +188,14 @@ - (ModelHandle*)loadModelFromAOTData:(NSData*)data return nil; } - return [self.impl loadModelFromAOTData:data - configuration:configuration - error:error]; + auto handle = [self.impl loadModelFromAOTData:data + configuration:configuration + error:error]; + if ((handle != NULL) && self.config.should_prewarm_model) { + [self.impl prewarmModelWithHandle:handle error:nil]; + } + + return handle; } - (BOOL)executeModelWithHandle:(ModelHandle*)handle diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index 8ef5a79d3f..f187191fee 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -166,43 +166,6 @@ def get_intermediate_path(compile_spec: List[CompileSpec]) -> str: return None -def generate_ethosu_compile_spec( - config: str, - permute_memory_to_nhwc: Optional[bool] = None, - quantize_io: Optional[bool] = None, - system_config: Optional[str] = None, - memory_mode: Optional[str] = None, - extra_flags: Optional[str] = None, - config_ini: Optional[str] = "Arm/vela.ini", -) -> List[CompileSpec]: - return ( - ArmCompileSpecBuilder() - .ethosu_compile_spec( - config, - system_config=system_config, - memory_mode=memory_mode, - extra_flags=extra_flags, - config_ini=config_ini, - ) - .set_permute_memory_format(permute_memory_to_nhwc) - .set_quantize_io(quantize_io) - .build() - ) - - -def generate_tosa_compile_spec( - permute_memory_to_nhwc: Optional[bool] = None, - output_path: Optional[str] = None, -) -> List[CompileSpec]: - return ( - ArmCompileSpecBuilder() - .tosa_compile_spec() - .set_permute_memory_format(permute_memory_to_nhwc) - .dump_intermediate_artifacts_to(output_path) - .build() - ) - - @final class ArmBackend(BackendDetails): @staticmethod diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py index 54cfafcc9b..56dac5d248 100644 --- a/backends/arm/arm_partitioner.py +++ b/backends/arm/arm_partitioner.py @@ -47,6 +47,7 @@ def 
is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.clone.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 79c507816d..e868b584cf 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,6 +19,7 @@ op_permute, op_quant, op_sigmoid, + op_slice, op_softmax, op_sub, op_view, diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py new file mode 100644 index 0000000000..8d59835ff0 --- /dev/null +++ b/backends/arm/operators/op_slice.py @@ -0,0 +1,55 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class SliceVisitor(NodeVisitor): + target = "aten.slice_copy.Tensor" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + + # aten.slice_copy supports slicing in 1d at a time. + # The arguments are dimension of slicing, start index and end index. + assert len(inputs) == 4 + input_node, dim, start, end = inputs + + # Translate and check parameters in Pytorch dim order. + shape = input_node.shape + dim = dim.number + end = (shape[dim] + end.number) % shape[dim] + size = end - start.number + assert size > 0 + assert size <= shape[dim] + + # Convert aten args to Tosa's start and size attributes and in TOSA dim order. 
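+            # Illustrative example (an assumed case, not taken from the original patch):
+            # slicing a (10, 10) tensor as x[1:3] along dim 0 gives dim=0, start=1, end=3,
+            # so with dim_order (0, 1) this yields start_attr=[1, 0] and size_attr=[2, 10].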
+ attr = ts.TosaSerializerAttribute() + start_attr = [start.number if i == dim else 0 for i in input_node.dim_order] + size_attr = [size if i == dim else shape[i] for i in input_node.dim_order] + attr.SliceAttribute(start_attr, size_attr) + + tosa_graph.addOperator( + TosaOp.Op().SLICE, [input_node.name], [output.name], attr + ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 3e1aceefe1..397ba68565 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -17,8 +17,11 @@ import torch import torch.nn.functional as F + +from executorch.backends.arm.quantizer import arm_quantizer_utils from executorch.backends.arm.quantizer.arm_quantizer_utils import ( convert_scalars_to_attrs, + mark_nodes_as_annotated, propagate_annotation, ) from executorch.backends.arm.quantizer.quantization_annotation import ( @@ -41,6 +44,10 @@ ) from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) from torch.fx import GraphModule, Node __all__ = [ @@ -263,6 +270,7 @@ class ArmQuantizer(Quantizer): def __init__(self) -> None: super().__init__() self.global_config: Optional[QuantizationConfig] = None + self.io_config: Optional[QuantizationConfig] = None self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} @@ -294,6 +302,11 @@ def set_module_name( self.module_name_config[module_name] = quantization_config return self + def set_io(self, quantization_config): + """Set quantization_config for input and output nodes.""" + self.io_config = quantization_config + return self + def transform_for_annotation(self, model: GraphModule) -> GraphModule: """An initial pass for transforming the graph to prepare it for annotation. Currently transforms scalar values to tensor attributes. 
@@ -358,8 +371,33 @@ def _annotate_for_static_quantization_config( self.global_config, _get_not_module_type_or_name_filter(tp_list, module_name_list), ) + + if self.io_config: + self._annotate_io(model, self.io_config) + return model + def _annotate_io( + self, + model: GraphModule, + quantization_config: QuantizationConfig, + ): + for node in model.graph.nodes: + if arm_quantizer_utils.is_annotated(node): + continue + if node.op == "placeholder": + _annotate_output_qspec( + node, + quantization_config.get_output_act_qspec(), + ) + mark_nodes_as_annotated([node]) + if node.op == "output": + parent = node.all_input_nodes[0] + _annotate_input_qspec_map( + node, parent, quantization_config.get_input_act_qspec() + ) + mark_nodes_as_annotated([node]) + def validate(self, model: GraphModule) -> None: pass diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index ee2844e668..89703f89b0 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -140,7 +140,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool: torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.view_copy.default, torch.ops.aten.view.default, - torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.slice.Tensor, torch.ops.aten.flatten.using_ints, torch.ops.aten.dropout.default, ] diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 906164aac3..f85fd1f2da 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -89,17 +89,26 @@ def get_tosa_compile_spec(permute_memory_to_nhwc=True, custom_path=None): """ Default compile spec for TOSA tests. """ + return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build() + + +def get_tosa_compile_spec_unbuilt( + permute_memory_to_nhwc=False, custom_path=None +) -> ArmCompileSpecBuilder: + """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify + the compile spec before calling .build() to finalize it. + """ intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_") if not os.path.exists(intermediate_path): os.makedirs(intermediate_path, exist_ok=True) - compile_spec = ( + compile_spec_builder = ( ArmCompileSpecBuilder() .tosa_compile_spec() .set_permute_memory_format(permute_memory_to_nhwc) .dump_intermediate_artifacts_to(intermediate_path) - .build() ) - return compile_spec + + return compile_spec_builder def get_u55_compile_spec( @@ -108,7 +117,20 @@ def get_u55_compile_spec( """ Default compile spec for Ethos-U55 tests. """ + return get_u55_compile_spec_unbuilt( + permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path + ).build() + + +def get_u55_compile_spec_unbuilt( + permute_memory_to_nhwc=False, quantize_io=False, custom_path=None +) -> ArmCompileSpecBuilder: + """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify + the compile spec before calling .build() to finalize it. 
+ """ artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u55_") + if not os.path.exists(artifact_path): + os.makedirs(artifact_path, exist_ok=True) compile_spec = ( ArmCompileSpecBuilder() .ethosu_compile_spec( @@ -120,6 +142,5 @@ def get_u55_compile_spec( .set_quantize_io(is_option_enabled("quantize_io") or quantize_io) .set_permute_memory_format(permute_memory_to_nhwc) .dump_intermediate_artifacts_to(artifact_path) - .build() ) return compile_spec diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index 9a0702c900..aa9703f9eb 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -41,6 +41,8 @@ def forward(self, x): class TestDumpPartitionedArtifact(unittest.TestCase): + """Tests dumping the partition artifact in ArmTester. Both to file and to stdout.""" + def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None): ( ArmTester( @@ -96,6 +98,8 @@ def test_BI_artifact(self): class TestNumericalDiffPrints(unittest.TestCase): + """Tests trigging the exception printout from the ArmTester's run and compare function.""" + def test_numerical_diff_prints(self): model = Linear(20, 30) tester = ( @@ -120,3 +124,28 @@ def test_numerical_diff_prints(self): pass # Implicit pass test else: self.fail() + + +class TestDumpOperatorsAndDtypes(unittest.TestCase): + def test_dump_ops_and_dtypes(self): + model = Linear(20, 30) + ( + ArmTester( + model, + example_inputs=model.get_inputs(), + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .dump_dtype_distribution() + .dump_operator_distribution() + .export() + .dump_dtype_distribution() + .dump_operator_distribution() + .to_edge() + .dump_dtype_distribution() + .dump_operator_distribution() + .partition() + .dump_dtype_distribution() + .dump_operator_distribution() + ) + # Just test that there are no execeptions. 
diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index eae5d4358a..248153a518 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -22,8 +22,9 @@ class TestMobileNetV2(unittest.TestCase): + """Tests MobileNetV2.""" - mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights) + mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT) mv2 = mv2.eval() normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 622d811822..3bd2b2605c 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -17,6 +17,8 @@ class TestSimpleAdd(unittest.TestCase): + """Tests a single add op, x+x and x+y.""" + class Add(torch.nn.Module): test_parameters = [ (torch.FloatTensor([1, 2, 3, 5, 7]),), diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index fb2609939f..32a0e5555a 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -28,6 +28,8 @@ class TestAvgPool2d(unittest.TestCase): + """Tests AvgPool2d.""" + class AvgPool2d(torch.nn.Module): def __init__( self, diff --git a/backends/arm/test/ops/test_batch_norm.py b/backends/arm/test/ops/test_batch_norm.py index 0d6f9dea2c..4935e910d6 100644 --- a/backends/arm/test/ops/test_batch_norm.py +++ b/backends/arm/test/ops/test_batch_norm.py @@ -497,6 +497,8 @@ class TestBatchNorm2d(unittest.TestCase): + """Tests BatchNorm2d.""" + class BatchNorm2d(torch.nn.Module): def __init__( self, diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index edfaafbcc2..8386283f24 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -8,19 +8,25 @@ # Tests the clone op which copies the data of the input tensor (possibly with new data format) # -import logging import unittest from typing import Tuple import torch + +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from parameterized import parameterized -logger = logging.getLogger(__name__) +from executorch.backends.xnnpack.test.tester.tester import Quantize +from parameterized import parameterized class TestSimpleClone(unittest.TestCase): + """Tests clone.""" + class Clone(torch.nn.Module): sizes = [10, 15, 50, 100] test_parameters = [(torch.ones(n),) for n in sizes] @@ -53,13 +59,14 @@ def _test_clone_tosa_MI_pipeline( def _test_clone_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec(), ) - .quantize() + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() .check_count({"torch.ops.aten.clone.default": 1}) .to_edge() @@ -72,13 +79,14 @@ def _test_clone_tosa_BI_pipeline( def _test_clone_tosa_u55_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_u55_compile_spec(), ) - .quantize() + 
.quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() .check_count({"torch.ops.aten.clone.default": 1}) .to_edge() @@ -91,16 +99,10 @@ def _test_clone_tosa_u55_pipeline( def test_clone_tosa_MI(self, test_tensor: torch.Tensor): self._test_clone_tosa_MI_pipeline(self.Clone(), (test_tensor,)) - # Expected to fail since ArmQuantizer cannot quantize a Clone layer - # TODO MLETROCH-125 @parameterized.expand(Clone.test_parameters) - @unittest.expectedFailure def test_clone_tosa_BI(self, test_tensor: torch.Tensor): self._test_clone_tosa_BI_pipeline(self.Clone(), (test_tensor,)) - # Expected to fail since ArmQuantizer cannot quantize a Clone layer - # TODO MLETROCH-125 @parameterized.expand(Clone.test_parameters) - @unittest.expectedFailure def test_clone_u55_BI(self, test_tensor: torch.Tensor): self._test_clone_tosa_u55_pipeline(self.Clone(), (test_tensor,)) diff --git a/backends/arm/test/ops/test_conv.py b/backends/arm/test/ops/test_conv.py index 614d056072..9ebfe77da2 100644 --- a/backends/arm/test/ops/test_conv.py +++ b/backends/arm/test/ops/test_conv.py @@ -244,6 +244,8 @@ def forward(self, x): class TestConv2D(unittest.TestCase): + """Tests Conv2D, both single ops and multiple Convolutions in series.""" + def _test_conv2d_tosa_MI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 41f76ccbb7..88006df1a0 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -154,6 +154,8 @@ def forward(self, x): class TestConvCombos(unittest.TestCase): + """Tests conv combined with other ops.""" + def _test_conv_combo_tosa_MI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 7eacbac432..9b3f79e6a1 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -9,6 +9,8 @@ from typing import Tuple +import pytest + import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.ops.test_conv import Conv2d @@ -130,6 +132,9 @@ class TestDepthwiseConv2D(unittest.TestCase): + """Tests Conv2D where groups == in_channels and out_channels = K * in_channels. 
This + is a special case that enables depthwise convolution.""" + def _test_dw_conv2d_tosa_MI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): @@ -189,7 +194,9 @@ def _test_dw_conv2d_u55_BI_pipeline( def test_dw_conv2d_tosa_MI(self, test_name, model): self._test_dw_conv2d_tosa_MI_pipeline(model, model.get_inputs()) + # TODO: Investigate flakiness (MLTORCH-307) @parameterized.expand(testsuite) + @pytest.mark.flaky(reruns=3) def test_dw_conv2d_tosa_BI(self, test_name, model): self._test_dw_conv2d_tosa_BI_pipeline(model, model.get_inputs()) diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index b13581dca1..60a0b8a4cc 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -78,6 +78,8 @@ class TestDiv(unittest.TestCase): + """Tests division.""" + class Div(torch.nn.Module): def __init__( self, diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py index 4f01b1c8f9..1be7f59ab8 100644 --- a/backends/arm/test/ops/test_full.py +++ b/backends/arm/test/ops/test_full.py @@ -19,6 +19,8 @@ class TestFull(unittest.TestCase): + """Tests the full op which creates a tensor of a given shape filled with a given value.""" + class Full(torch.nn.Module): # A single full op def forward(self): diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 0e6747fe27..33f62955ec 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -91,6 +91,7 @@ class TestLinear(unittest.TestCase): + """Tests the linear operation y = Ax + b.""" _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. ) @@ -155,7 +156,7 @@ def _test_linear_tosa_BI_pipeline( def _test_linear_tosa_u55_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): - ( + tester = ( ArmTester( module, example_inputs=test_data, @@ -169,8 +170,12 @@ def _test_linear_tosa_u55_BI_pipeline( .partition() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() + .serialize() ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) + @parameterized.expand(test_data_suite_rank1 + test_data_suite_rank4) def test_linear_tosa_MI( self, diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index 433661e99e..e0db958f74 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -40,6 +40,8 @@ class TestMeanDim(unittest.TestCase): + """Tests MeanDim, called AdaptiveAvgPool2d in PyTorch.""" + class MeanDim(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py new file mode 100644 index 0000000000..a1c1e29cbc --- /dev/null +++ b/backends/arm/test/ops/test_slice.py @@ -0,0 +1,116 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
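+
+#
+# Tests the slice op which extracts a contiguous sub-tensor along one dimension at a time.
+#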
+ +import unittest +from typing import Tuple + +import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.xnnpack.test.tester.tester import Quantize +from parameterized import parameterized + + +class TestSimpleSlice(unittest.TestCase): + + class Slice(torch.nn.Module): + + sizes = [(10), (10, 10), (10, 10, 10), ((1, 12, 10, 10))] + test_tensors = [(torch.ones(n),) for n in sizes] + + def forward(self, x: torch.Tensor): + if x.dim() == 1: + return x[3:-3] + elif x.dim() == 2: + return x[1:3, 3:5] + elif x.dim() == 3: + return x[0:7, 0:1, 0:8] + elif x.dim() == 4: + return x[:, 2:5, 3:5, 4:5] + + def _test_slice_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: torch.Tensor + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec(), + ) + .export() + .check(["torch.ops.aten.slice.Tensor"]) + .to_edge() + .check(["executorch_exir_dialects_edge__ops_aten_slice_copy"]) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_slice_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor], permute: bool + ): + + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec( + permute_memory_to_nhwc=permute + ), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check(["torch.ops.aten.slice.Tensor"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_slice_u55_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_u55_compile_spec(), + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check(["torch.ops.aten.slice.Tensor"]) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + ) + + @parameterized.expand(Slice.test_tensors) + def test_slice_tosa_MI(self, tensor): + self._test_slice_tosa_MI_pipeline(self.Slice(), (tensor,)) + + @parameterized.expand(Slice.test_tensors[:2]) + def test_slice_nchw_tosa_BI(self, test_tensor: torch.Tensor): + self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,), False) + + @parameterized.expand(Slice.test_tensors[2:]) + def test_slice_nhwc_tosa_BI(self, test_tensor: torch.Tensor): + self._test_slice_tosa_BI_pipeline(self.Slice(), (test_tensor,), True) + + # Fails during Vela compilation when trying to use a Tuple as a Named tuple, + # Could be Vela Issue, wait until Regor. 
+ @parameterized.expand(Slice.test_tensors) + @unittest.expectedFailure + def test_slice_u55_BI(self, test_tensor: torch.Tensor): + self._test_slice_u55_BI_pipeline(self.Slice(), (test_tensor,)) diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py index b2ef115dad..b3b6230daa 100644 --- a/backends/arm/test/ops/test_softmax.py +++ b/backends/arm/test/ops/test_softmax.py @@ -28,6 +28,8 @@ class TestSoftmax(unittest.TestCase): + """Tests softmax.""" + class Softmax(torch.nn.Module): def __init__(self, dim: int = -1): super().__init__() diff --git a/backends/arm/test/ops/test_view.py b/backends/arm/test/ops/test_view.py index 5dcd1fe73f..1f51261bf7 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -8,19 +8,25 @@ # Tests the view op which changes the size of a Tensor without changing the underlying data. # -import logging import unittest from typing import Tuple import torch + +from executorch.backends.arm.quantizer.arm_quantizer import ( + ArmQuantizer, + get_symmetric_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester -from parameterized import parameterized -logger = logging.getLogger(__name__) +from executorch.backends.xnnpack.test.tester.tester import Quantize +from parameterized import parameterized class TestSimpleView(unittest.TestCase): + """Tests the view operation.""" + class View(torch.nn.Module): sizes = [10, 15, 50, 100] @@ -50,13 +56,14 @@ def _test_view_tosa_MI_pipeline( def _test_view_tosa_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_tosa_compile_spec(), ) - .quantize() + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() .check_count({"torch.ops.aten.view.default": 1}) .to_edge() @@ -69,13 +76,14 @@ def _test_view_tosa_BI_pipeline( def _test_view_u55_BI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] ): + quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) ( ArmTester( module, example_inputs=test_data, compile_spec=common.get_u55_compile_spec(), ) - .quantize() + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) .export() .check_count({"torch.ops.aten.view.default": 1}) .to_edge() @@ -88,16 +96,10 @@ def _test_view_u55_BI_pipeline( def test_view_tosa_MI(self, test_tensor: torch.Tensor): self._test_view_tosa_MI_pipeline(self.View(), (test_tensor,)) - # Expected to fail since ArmQuantizer cannot quantize a View layer. - # TODO MLETROCH-125 @parameterized.expand(View.test_parameters) - @unittest.expectedFailure def test_view_tosa_BI(self, test_tensor: torch.Tensor): self._test_view_tosa_BI_pipeline(self.View(), (test_tensor,)) - # Expected to fail since ArmQuantizer cannot quantize a View layer. 
- # TODO MLETROCH-125 @parameterized.expand(View.test_parameters) - @unittest.expectedFailure def test_view_u55_BI(self, test_tensor: torch.Tensor): self._test_view_u55_BI_pipeline(self.View(), (test_tensor,)) diff --git a/backends/arm/test/passes/test_tag_io_quant_pass.py b/backends/arm/test/passes/test_tag_io_quant_pass.py index 8757cf99d8..9f292bb7ca 100644 --- a/backends/arm/test/passes/test_tag_io_quant_pass.py +++ b/backends/arm/test/passes/test_tag_io_quant_pass.py @@ -22,6 +22,7 @@ def forward(self, x): class TestTagIOQuantPass(unittest.TestCase): + """Tests the TagIOQuantPass which tags q/dq nodes on model inputs and outputs to not include them in our partitions.""" def _tosa_BI_u55_pipeline(self, module: torch.nn.Module): ( diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 19d76e13b4..58c99a9201 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -510,4 +510,20 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict: with open(os.path.join(tmp, "output.json"), "r") as f: json_out = json.load(f) + # Cast float tensors to proper dtype. + try: + for region in json_out["regions"]: + for block in region["blocks"]: + for tensor in block["tensors"]: + if "data" in tensor: + if tensor["type"] == "FP32": + data = np.array(tensor["data"]) + data = data.astype(np.int8) + data = np.frombuffer(data, dtype=np.float32) + data = data.reshape(tensor["shape"]) + tensor["data"] = data + except Exception: + # This is just nice-to-have if it works, don't care if it fails. + pass + return json_out diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 97ab67b3d1..be5ea7dd71 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -4,7 +4,10 @@ # LICENSE file in the root directory of this source tree. 
import logging -from typing import Any, List, Literal, Optional, Tuple + +from collections import Counter +from pprint import pformat +from typing import Any, List, Literal, Optional, Tuple, Union import executorch.backends.xnnpack.test.tester.tester as tester @@ -31,6 +34,7 @@ from executorch.backends.xnnpack.test.tester import Tester from executorch.exir import EdgeCompileConfig from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch.fx import Graph logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -39,7 +43,6 @@ class Partition(tester.Partition): def dump_artifact(self, path_to_dump: Optional[str]): super().dump_artifact(path_to_dump) - from pprint import pformat to_print = None for spec in self.graph_module.lowered_module_0.compile_specs: @@ -55,12 +58,7 @@ def dump_artifact(self, path_to_dump: Optional[str]): to_print = f"\n Vela command stream: \n{to_print}" break assert to_print is not None, "No TOSA nor Vela compile spec found" - - if path_to_dump: - with open(path_to_dump, "a") as fp: - fp.write(to_print) - else: - print(to_print) + _dump_str(to_print, path_to_dump) class Serialize(tester.Serialize): @@ -272,6 +270,66 @@ def run_method_and_compare_outputs( return self + def get_graph(self, stage: str | None = None) -> Graph: + if stage is None: + stage = self.cur + artifact = self.get_artifact(stage) + if self.cur == self.stage_name(tester.ToEdge) or self.cur == self.stage_name( + Partition + ): + graph = artifact.exported_program().graph + elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name( + tester.Quantize + ): + graph = artifact.graph + else: + raise RuntimeError( + "Can only get a graph from Quantize, ToEdge, Export, and Partition stages." + ) + + return graph + + def dump_operator_distribution( + self, path_to_dump: Optional[str] = None + ) -> ArmQuantizer: + """Dump a dictionary with {operator: operator count} for the operators in the + graph of the current stage. + + Returns self for daisy-chaining. + """ + graph = self.get_graph(self.cur) + op_dist = _get_operator_distribution(graph) + to_print = self.cur + " operators: " + _format_dict(op_dist) + "\n" + _dump_str(to_print, path_to_dump) + return self + + def dump_dtype_distribution( + self, path_to_dump: Optional[str] = None + ) -> ArmQuantizer: + """Dump a dictionary with {dtype: dtype count} for the dtypes of the nodes in the + graph of the current stage. + + Returns self for daisy-chaining. + """ + graph = self.get_graph(self.cur) + op_dist = _get_dtype_distribution(graph) + to_print = self.cur + " placeholder data types: " + _format_dict(op_dist) + "\n" + _dump_str(to_print, path_to_dump) + return self + + @staticmethod + def _calculate_reference_output( + module: Union[torch.fx.GraphModule, torch.nn.Module], inputs + ) -> torch.Tensor: + """ + Note: I'd prefer to use the base class method here, but since it use the + exported program, I can't. The partitioner stage clears the state_dict + of the exported program, which causes an issue when evaluating the + module. + """ + + return module.forward(*inputs) + def transpose_data_format( self, data: Tuple[torch.Tensor], to: Literal["NHWC", "NCHW"] ): @@ -331,3 +389,37 @@ def _compare_outputs( ) logger.error(f"{atol=}, {rtol=}, {qtol=}") raise e + + +def _get_dtype_distribution(graph: Graph) -> dict: + """Counts the occurences of placeholder data types in a graph. 
+ The result is a dict {'data type':'number of placeholders'} + """ + return Counter( + [ + node.meta["val"].dtype + for node in list(graph.nodes) + if node.op == "placeholder" + ] + ) + + +def _get_operator_distribution(graph: Graph) -> dict[str, int]: + """Counts the occurences of operator names in a graph. + The result is a dict {'operator name':'number of nodes'} + """ + return Counter( + [str(node.target) for node in list(graph.nodes) if node.op == "call_function"] + ) + + +def _dump_str(to_print: str, path_to_dump: Optional[str] = None): + if path_to_dump: + with open(path_to_dump, "a") as fp: + fp.write(to_print) + else: + print(to_print) + + +def _format_dict(to_print: dict) -> str: + return pformat(to_print, compact=True, indent=1) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index bd4ec660a6..79646c1293 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -28,13 +28,13 @@ python_library( "compiler.py", ], deps = [ - "fbsource//third-party/pypi/pyre-extensions:pyre-extensions", ":passes", ":utils", "//caffe2:torch", "//executorch/backends/cadence/aot/quantizer:fusion_pass", "//executorch/backends/cadence/aot/quantizer:quantizer", "//executorch/backends/transforms:decompose_sdpa", + "//executorch/backends/transforms:remove_clone_ops", "//executorch/exir:lib", ], ) @@ -49,5 +49,7 @@ python_library( "//caffe2:torch", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", + "//executorch/exir/passes:lib", + "//executorch/exir/passes:spec_prop_pass", ], ) diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 302252c42a..509e254b55 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -7,27 +7,28 @@ # pyre-strict import logging +from typing import Optional import torch from executorch.backends.cadence.aot.passes import ( + InitializePipeline, + RemoveNopExpandOpPass, RemoveZeroSizedCatArgsPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplacePT2DequantWithCadenceDequantPass, ReplacePT2QuantWithCadenceQuantPass, ReplaceScalarTensorWithFullPass, ReplaceSqueezeAndUnsqueezeWithViewPass, ) from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion -from executorch.backends.cadence.aot.quantizer.quantizer import ( - CadenceAtenQuantizer, - CadenceQuantizer, -) +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer from executorch.backends.cadence.aot.utils import model_is_quantized from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, ) +from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from pyre_extensions import assert_is_instance from torch._export import capture_pre_autograd_graph from torch.ao.quantization.pt2e.export_utils import model_is_exported from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -36,16 +37,24 @@ from torch.export.exported_program import ExportedProgram -def quantize_pt2( +# Note: this is not meant as a primary API since it can create inconsistencies +# if the quantizer here is different from the quantizer used to convert. It is +# however useful for unit tests to separate the converted model from the fused +# model, to be able to get reference numerics. +# If this does not apply, please use quantize_and_fuse_pt2 instead. 
+def convert_pt2( model: torch.nn.Module, inputs: tuple[object, ...], + quantizer: CadenceQuantizer, ) -> torch.fx.GraphModule: """ - Instantiate the CadenceQuantizer (PTQ), prepare, convert and fuse the model. - Returns a GraphModule with the quantized model. + Prepare and convert a model using the given quantizer. + The quantizer must be supplied and be the same as the one used to + fuse the model later, if applicable. If you do not expect that behavior, + please use quantize_and_fuse_pt2 instead, which will instantiate a + default quantizer for you if needed. + Returns a GraphModule with the converted model. """ - # Quantizer - quantizer = CadenceQuantizer() # Export with dynamo model_exp = capture_pre_autograd_graph(model, inputs) @@ -62,14 +71,54 @@ def quantize_pt2( # Convert converted_model = convert_pt2e(prepared_model) + return converted_model + + +# Note: this is not meant as a primary API since it can create inconsistencies +# if the quantizer here is different from the quantizer used to convert. It is +# however useful for unit tests to separate the converted model from the fused +# model, to be able to get reference numerics. +# If this does not apply, please use quantize_and_fuse_pt2 instead. +def fuse_pt2( + converted_graph_module: torch.fx.GraphModule, + quantizer: CadenceQuantizer, +) -> torch.fx.GraphModule: + """ + Fuse a converted graph module using the given quantizer. + The quantizer must be the same as the one used to convert the model. + If you do not expect that behavior, please use quantize_and_fuse_pt2 instead, + which will instantiate a default quantizer for you if needed. + Returns a GraphModule with the fused model. + """ # Get patterns and apply fusion of dq -> op -> q to qop - patterns = [ - assert_is_instance(q, CadenceAtenQuantizer).pattern - for q in quantizer.quantizers - ] - QuantFusion(patterns)(converted_model) + # pyre-ignore[16]: no attribute + patterns = [q.pattern for q in quantizer.quantizers] + QuantFusion(patterns)(converted_graph_module) - return converted_model + return converted_graph_module + + +# Note: this is the one-liner API to quantize and fuse a model. +def quantize_pt2( + model: torch.nn.Module, + inputs: tuple[object, ...], + quantizer: Optional[CadenceQuantizer] = None, +) -> torch.fx.GraphModule: + """ + Prepare, convert and fuse the model using the given quantizer. + Returns a GraphModule with the quantized model. 
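+
+    Roughly equivalent two-step sketch of what this helper does:
+        quantizer = quantizer or CadenceQuantizer()
+        converted_gm = convert_pt2(model, inputs, quantizer)
+        fused_gm = fuse_pt2(converted_gm, quantizer)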
+ """ + # Quantizer + if not quantizer: + quantizer = CadenceQuantizer() + + # Get converted graph module + converted_gm = convert_pt2(model, inputs, quantizer) + + # Get fused model + fused_gm = fuse_pt2(converted_gm, quantizer) + + return fused_gm # Export the model and lower it to an ExportedProgram (in aten IR) @@ -148,8 +197,12 @@ def export_to_cadence( # Run a couple required passes for quant/dequant ops cadence_program_manager = edge_program_manager.transform( [ + InitializePipeline(), RemoveZeroSizedCatArgsPass(), + ReplaceLogicalNotBooleanWhereWithWherePass(), ReplaceScalarTensorWithFullPass(), + RemoveCloneOpsTransform(), + RemoveNopExpandOpPass(), ReplaceSqueezeAndUnsqueezeWithViewPass(), ReplacePT2QuantWithCadenceQuantPass(), ReplacePT2DequantWithCadenceDequantPass(), diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index f79d5f870d..dbfe1e3639 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -62,16 +62,31 @@ - arg_meta: null kernel_name: torch::executor::full_out +- op: mean.out + kernels: + - arg_meta: null + kernel_name: torch::executor::mean_dim_out + - op: mul.out kernels: - arg_meta: null kernel_name: torch::executor::mul_out +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::mul_scalar_out + - op: permute_copy.out kernels: - arg_meta: null kernel_name: torch::executor::permute_copy_out +- op: rsqrt.out + kernels: + - arg_meta: null + kernel_name: torch::executor::rsqrt_out + - op: sigmoid.out kernels: - arg_meta: null @@ -134,3 +149,8 @@ kernels: - arg_meta: null kernel_name: impl::reference::quantized_relu_out + +func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_matmul_out diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index c877a7149d..adcf086873 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from math import prod from typing import Optional, Tuple import torch @@ -186,28 +187,29 @@ def quantized_matmul_meta( X_size = list(X.size()) Y_size = list(Y.size()) - assert len(X_size) == len( - Y_size - ), "quantized matmul not supported for tensors of different dimensions" - - if len(X_size) == 3: - assert ( - X_size[0] == Y_size[0] - ), "quantized matmul only supported for batch dimension of same size" - if transposed: - assert X_size[2] == Y_size[2], "matrices cannot be multiplied" - out_size = X_size[:2] + [Y_size[1]] - else: - assert X_size[2] == Y_size[1], "matrices cannot be multiplied" - out_size = X_size[:2] + [Y_size[2]] - elif len(X_size) == 2: - if transposed: - assert X_size[1] == Y_size[1], "matrices cannot be multiplied" - out_size = [X_size[0], Y_size[0]] - else: - assert X_size[1] == Y_size[0], "matrices cannot be multiplied" - out_size = [X_size[0], Y_size[1]] + # Get the batch dimensions for both tensors + X_batch_dims = X_size[:-2] + Y_batch_dims = Y_size[:-2] + + # If they don't match, check that they're compatible + if X_batch_dims != Y_batch_dims: + assert prod(X_batch_dims) == prod( + Y_batch_dims + ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}" + + # Get the matmul output size + if transposed: + assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied" + mat_size = [X_size[-2], Y_size[-2]] else: - raise AssertionError("quantized matmul only supported for 2D or 3D tensors") + assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied" + mat_size = [X_size[-2], Y_size[-1]] + + # Combine the larger batch dimensions with the matmul output size + out_size = ( + X_batch_dims + mat_size + if len(X_batch_dims) > len(Y_batch_dims) + else Y_batch_dims + mat_size + ) return X.new_empty(out_size, dtype=X.dtype) diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index ca8a44f00c..db419bfb5e 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -4,18 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, Tuple +# pyre-strict + +from typing import Any, cast, Dict, Sequence, Tuple import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.passes import dead_code_elimination_pass +from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch._subclasses import FakeTensor from torch.utils._pytree import tree_map_only - -# pyre-strict - # Similar to what's done in executorch/exir/pass_base.py Argument = Any # pyre-ignore @@ -173,3 +174,95 @@ def call_operator( init_args[0] = new_args args = tuple(args) return super().call_operator(op, args, kwargs, meta) + + +class RemoveNopExpandOpPass(ExportPass): + """ + For an expand op, if the operator shape matches the expand shape, then the + expand is a nop. 
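+    For example, expand_copy(x, [2, 3]) where x already has shape (2, 3) is removed and
+    the input is returned unchanged.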
+ """ + + def call_operator( + self, + op, # pyre-ignore + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if get_edge_overload_packet(op) not in { + exir_ops.edge.aten.expand_copy, + exir_ops.edge.aten.expand, + }: + return super().call_operator(op, args, kwargs, meta) + + # Parse the args, and check for nop condition + arg0 = cast(ProxyValue, args[0]) + arg1 = cast(Sequence[int], args[1]) + in_tensor = arg0.to_tensor() + if list(in_tensor.shape) == list(arg1): + return arg0 + + return super().call_operator(op, args, kwargs, meta) + + +class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): + """ + A where op with a logical_not and a boolean tensor can be replaced + by a where op with flipped inputs and the initial boolean tensor. + """ + + def replace_logical_nop_where_with_where( + self, graph_module: torch.fx.GraphModule + ) -> None: + graph = graph_module.graph + for node in graph.nodes: + # We are only interested in where nodes + if node.target != exir_ops.edge.aten.where.self: + continue + + # If the third arg is not a logical_not, bail. + if node.args[0].target != exir_ops.edge.aten.logical_not.default: + continue + + # Get the third arg node and its input + logical_not_node = node.args[0] + logical_not_input_tensor = ( + logical_not_node.args[0].to_tensor() + if isinstance(logical_not_node.args[0], ProxyValue) + else logical_not_node.args[0] + ) + + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_tensor.meta["spec"].dtype != torch.bool: + continue + + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with graph.inserting_before(node): + linear_node = graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_node.args[0], node.args[2], node.args[1]), + ) + # Replace all the uses + node.replace_all_uses_with(linear_node) + + graph_module.recompile() + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.replace_logical_nop_where_with_where(graph_module) + result = super().call(graph_module) + return result + + +class InitializePipeline(ExportPass): + """ + Initialize the Jarvis pipeline. This should invariably be the first pass to + run. 
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + dead_code_elimination_pass(graph_module) + result = SpecPropPass()(graph_module) + assert result is not None + return result diff --git a/backends/cadence/aot/quantizer/TARGETS b/backends/cadence/aot/quantizer/TARGETS index 8b3449cd85..6290626216 100644 --- a/backends/cadence/aot/quantizer/TARGETS +++ b/backends/cadence/aot/quantizer/TARGETS @@ -31,7 +31,6 @@ python_library( ], typing = True, deps = [ - "fbsource//third-party/pypi/pyre-extensions:pyre-extensions", ":patterns", ":utils", "//caffe2:torch", diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 4cd3c6bfb4..51bace9168 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -26,7 +26,6 @@ is_annotated, no_outside_users, ) -from pyre_extensions import assert_is_instance from torch import fx @@ -100,14 +99,11 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: continue for output, *custom_spec in anchors.output: - assert_is_instance(output, fx.Node).meta["quantization_annotation"] = ( - QuantizationAnnotation( - # pyre-ignore[6]: incompatible parameter type - output_qspec=( - custom_spec[0] if custom_spec else output_act_qspec - ), - _annotated=True, - ) + # pyre-ignore[16]: no attribute + output.meta["quantization_annotation"] = QuantizationAnnotation( + # pyre-ignore[6]: incompatible parameter type + output_qspec=(custom_spec[0] if custom_spec else output_act_qspec), + _annotated=True, ) def annotate_inputs( @@ -118,16 +114,17 @@ def annotate_inputs( spec: Optional[QuantizationSpec], ) -> None: for node, idx, *custom_spec in inputs: - _node = assert_is_instance(node, fx.Node) - annotation = _node.meta.get( + # pyre-ignore[16]: no attribute + annotation = node.meta.get( "quantization_annotation", QuantizationAnnotation(_annotated=True), ) - # pyre-ignore[6]: incompatible parameter type - annotation.input_qspec_map[_node.args[idx]] = ( + # pyre-ignore[16]: no attribute + annotation.input_qspec_map[node.args[idx]] = ( custom_spec[0] if custom_spec else spec ) - _node.meta["quantization_annotation"] = annotation + # pyre-ignore[16]: no attribute + node.meta["quantization_annotation"] = annotation annotate_inputs(anchors.inputs, input_act_qspec) annotate_inputs(anchors.weights, weight_qspec) diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt index c22dc0c997..c81e934850 100644 --- a/backends/cadence/reference/operators/CMakeLists.txt +++ b/backends/cadence/reference/operators/CMakeLists.txt @@ -32,12 +32,15 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_floath.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mean.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_rsqrt.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" 
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp" @@ -60,7 +63,8 @@ target_include_directories(aten_ops_cadence PUBLIC ${ROOT_DIR}/.. add_library( custom_ops "quantized_linear_out.cpp" "quantized_conv_out.cpp" "quantized_relu_out.cpp" "quantized_layer_norm.cpp" - "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp") + "quantize_per_tensor.cpp" "dequantize_per_tensor.cpp" + "quantized_matmul_out.cpp") target_include_directories(custom_ops PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} ${_common_include_directories}) diff --git a/backends/cadence/reference/operators/quantized_matmul_out.cpp b/backends/cadence/reference/operators/quantized_matmul_out.cpp index 95df35caba..49dd222a96 100644 --- a/backends/cadence/reference/operators/quantized_matmul_out.cpp +++ b/backends/cadence/reference/operators/quantized_matmul_out.cpp @@ -13,6 +13,9 @@ namespace impl { namespace reference { namespace native { +using Tensor = exec_aten::Tensor; +using RuntimeContext = torch::executor::RuntimeContext; + // The quantized matmul. The quantized matmul accumulates in a wider register, // whose type is TA. template < @@ -50,27 +53,32 @@ __attribute__((noinline)) void qmatmul( } } -template +template void inline _typed_quantized_matmul( const Tensor& X, int64_t X_zero_point, const Tensor& Y, int64_t Y_zero_point, - const c10::optional& bias, + const exec_aten::optional& bias, int64_t out_multiplier, int64_t out_shift, int64_t out_zero_point, bool transposed, Tensor& out) { - ctype* __restrict__ out_data = out.mutable_data_ptr(); - const ctype* __restrict__ X_data = X.const_data_ptr(); - const ctype* __restrict__ Y_data = Y.const_data_ptr(); + size_t batch_size = getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); for (size_t i = 0; i < batch_size; ++i) { - const ctype* x = X_data + i * leading_dim * in_dim; - const ctype* y = Y_data + i * in_dim * out_dim; - ctype* z = out_data + i * leading_dim * out_dim; + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; if (transposed) { - qmatmul( + qmatmul( z, static_cast(out_multiplier), static_cast(out_shift), @@ -83,7 +91,7 @@ void inline _typed_quantized_matmul( in_dim, out_dim); } else { - qmatmul( + qmatmul( z, static_cast(out_multiplier), static_cast(out_shift), @@ -101,24 +109,18 @@ void inline _typed_quantized_matmul( } void quantized_matmul_out( + RuntimeContext& ctx, const Tensor& X, int64_t X_zero_point, const Tensor& Y, int64_t Y_zero_point, - const c10::optional& bias, + const exec_aten::optional& bias, int64_t out_multiplier, int64_t out_shift, int64_t out_zero_point, bool transposed, Tensor& out) { - (void)bias; - - size_t batch_size = getLeadingDims(X, X.dim() - 2); - size_t leading_dim = X.size(X.dim() - 2); - size_t out_dim = Y.size(Y.dim() - 1 - transposed); - size_t in_dim = X.size(X.dim() - 1); - - if (out.ScalarType() == at::ScalarType::Byte) { + if (out.scalar_type() == at::ScalarType::Byte) { _typed_quantized_matmul( X, X_zero_point, @@ -130,7 +132,7 @@ void quantized_matmul_out( out_zero_point, transposed, out); - } else if (out.ScalarType() == at::ScalarType::Char) { + } else if (out.scalar_type() == at::ScalarType::Char) { _typed_quantized_matmul( X, X_zero_point, diff --git 
a/backends/qualcomm/aot/ir/qcir_utils.cpp b/backends/qualcomm/aot/ir/qcir_utils.cpp index e025b8667a..75446bb733 100755 --- a/backends/qualcomm/aot/ir/qcir_utils.cpp +++ b/backends/qualcomm/aot/ir/qcir_utils.cpp @@ -100,7 +100,7 @@ Qnn_DataType_t ToDataType(qcir::DataType type) { } flatbuffers::Offset ToQuantizeParam( - const Qnn_QuantizeParams_t& param, + const Qnn_Tensor_t& tensor, flatbuffers::FlatBufferBuilder* builder) { static const std::unordered_map def_map{ {QNN_DEFINITION_IMPL_GENERATED, qcir::QuantizeDef::IMPL_GENERATED}, @@ -124,6 +124,7 @@ flatbuffers::Offset ToQuantizeParam( int32_t axis = 0; uint32_t bitwidth = 0; + auto param = QNN_VER_PTR(tensor)->quantizeParams; auto quant_type = type_map.at(param.quantizationEncoding); std::vector data; std::vector scales; @@ -160,7 +161,9 @@ flatbuffers::Offset ToQuantizeParam( } } break; default: - QNN_EXECUTORCH_LOG_ERROR("QNN_QUANTIZATION_ENCODING_UNDEFINED detected"); + QNN_EXECUTORCH_LOG_WARN( + "QNN_QUANTIZATION_ENCODING_UNDEFINED detected: %s", + QNN_VER_PTR(tensor)->name); break; } return CreateQuantizeParamDirect( @@ -174,7 +177,7 @@ flatbuffers::Offset ToQuantizeParam( &data); } -Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) { +Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) { static const std::unordered_map def_map{ {qcir::QuantizeDef::IMPL_GENERATED, QNN_DEFINITION_IMPL_GENERATED}, {qcir::QuantizeDef::DEFINED, QNN_DEFINITION_DEFINED}, @@ -196,6 +199,7 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) { }; Qnn_QuantizeParams_t p = QNN_QUANTIZE_PARAMS_INIT; + auto param = tensor->qparam(); p.encodingDefinition = def_map.at(param->def()); p.quantizationEncoding = type_map.at(param->type()); switch (p.quantizationEncoding) { @@ -225,7 +229,9 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) { const_cast(param->offsets()->data()); } break; default: - QNN_EXECUTORCH_LOG_ERROR("qcir::QuantizeType::UNDEFINED detected"); + QNN_EXECUTORCH_LOG_WARN( + "qcir::QuantizeType::UNDEFINED detected: %s", + tensor->name()->c_str()); break; } return p; @@ -248,7 +254,7 @@ flatbuffers::Offset ToTensor( &shape, ToTensorType(QNN_VER_PTR(tensor)->type), ToDataType(QNN_VER_PTR(tensor)->dataType), - ToQuantizeParam(QNN_VER_PTR(tensor)->quantizeParams, builder), + ToQuantizeParam(tensor, builder), &buffer); } @@ -261,7 +267,7 @@ Qnn_Tensor_t ToTensor(const tensor_type& tensor) { QNN_VER_PTR(t)->name = tensor->name()->c_str(); QNN_VER_PTR(t)->type = ToTensorType(tensor->type()); QNN_VER_PTR(t)->dataType = ToDataType(tensor->dtype()); - QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor->qparam()); + QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor); QNN_VER_PTR(t)->rank = tensor->shape()->size(); QNN_VER_PTR(t)->dimensions = const_cast(tensor->shape()->data()); QNN_VER_PTR(t)->clientBuf.dataSize = tensor->data()->size(); diff --git a/backends/qualcomm/aot/ir/qcir_utils.h b/backends/qualcomm/aot/ir/qcir_utils.h index 30a5481f9f..890dfa33ca 100755 --- a/backends/qualcomm/aot/ir/qcir_utils.h +++ b/backends/qualcomm/aot/ir/qcir_utils.h @@ -26,9 +26,9 @@ qcir::DataType ToDataType(Qnn_DataType_t type); Qnn_DataType_t ToDataType(qcir::DataType type); flatbuffers::Offset ToQuantizeParam( - const Qnn_QuantizeParams_t& param, + const Qnn_Tensor_t& tensor, flatbuffers::FlatBufferBuilder* builder); -Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& type); +Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor); flatbuffers::Offset ToTensor( const Qnn_Tensor_t& tensor, diff 
--git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index c4fbdeae14..d3bf98bae7 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -23,6 +23,8 @@ op_hardsigmoid, op_hardswish, op_hardtanh, + op_index, + op_index_put, op_layer_norm, op_linear, op_log_softmax, @@ -75,6 +77,8 @@ op_hardswish, op_hardtanh, op_hardsigmoid, + op_index, + op_index_put, op_layer_norm, op_linear, op_log_softmax, diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py new file mode 100644 index 0000000000..6f8dc558fe --- /dev/null +++ b/backends/qualcomm/builders/op_index.py @@ -0,0 +1,83 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Index(NodeVisitor): + # schema = aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + target = ["aten.index.Tensor"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + if len(node.args[1]) > 1: + # TODO consider to implement it in a recursive way. 
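
The op_index visitor above lowers aten.index.Tensor with a single index tensor to a QNN Gather along axis 0. The intended semantics can be checked against eager PyTorch with a small sketch; the shapes below are illustrative (they match the new unit test) and the reference expression is only one way to express the same gather:

import torch

x = torch.randn(8, 172, 64)
idx = torch.tensor([[0, 1], [2, 3], [4, 5]])

# x[idx] gathers slices of x along dim 0; the index tensor's shape is
# prepended to the remaining dims of x, giving (3, 2, 172, 64) here.
gathered = x[idx]
reference = x.index_select(0, idx.flatten()).view(*idx.shape, *x.shape[1:])
assert torch.equal(gathered, reference)
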
+ raise NotImplementedError("Not support tuple of tensor.") + + indices_node = node.args[1][0] + indices_tensor = self.get_tensor(indices_node, node).to(torch.int32) + assert indices_tensor.size(0) != 0, "Not support empty indices list" + + indices_tensor_wrapper = self.define_tensor( + indices_node, + indices_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper] + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + gather_output_tensors = [output_tensor_wrapper] + + gather_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpGather.op_name, + ) + gather_op.AddInputTensors(gather_input_tensors) + gather_op.AddOutputTensors(gather_output_tensors) + + # If support tuple of tensor, need to refine it based on len + gather_op.AddScalarParam( + OpGather.param_axis, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {"data": np.int32(0)}, + ) + + return gather_op diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py new file mode 100644 index 0000000000..af5311dfb2 --- /dev/null +++ b/backends/qualcomm/builders/op_index_put.py @@ -0,0 +1,83 @@ +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class IndexPutVisitor(NodeVisitor): + target = ["aten.index_put.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + indicies_node = node.args[1] + indices_list = [ + self.get_tensor(idx, idx) for idx in indicies_node if idx is not None + ] + + # Unpack the tuple + indices_unpacked = [torch.flatten(idx) for idx in indices_list] + + # Convert to 2-D tensor + indices_qnn = torch.cat(indices_unpacked).unsqueeze(0) + indice_node = [n for n in indicies_node if isinstance(n, torch.fx.Node)] + # TODO consider to write a pass to combine to one input tensor for indices + assert len(indice_node) == 1, "Not support mutilple indices tensor" + + indices_tensor_wrapper = self.define_tensor( + indice_node[0], + indices_qnn, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + value_node = node.args[2] + + value_tensor = self.get_tensor(value_node, node) + + value_tensor_wrapper = self.define_tensor( + value_node, + value_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + index_put_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + 
QNN_OP_PACKAGE_NAME_QTI_AISW, + OpScatterNd.op_name, + ) + index_put_op.AddInputTensors( + [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper] + ) + index_put_op.AddOutputTensors([output_tensor_wrapper]) + + return index_put_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index dca47ebeec..4a87e5dbbb 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -124,13 +124,6 @@ class OpExpandDims: param_axis: str = "axis" -@dataclass(init=False, frozen=True) -class OpReduceSum: - op_name: str = "ReduceSum" - param_axes: str = "axes" - param_keep_dims: str = "keep_dims" - - @dataclass(init=False, frozen=True) class OpFullyConnected: op_name: str = "FullyConnected" @@ -144,13 +137,14 @@ class OpGather: @dataclass(init=False, frozen=True) -class OpGelu: - op_name: str = "Gelu" +class OpGatherND: + op_name: str = "GatherNd" + param_batch_dims: str = "batch_dims" @dataclass(init=False, frozen=True) -class OpSqrt: - op_name: str = "ElementWiseSquareRoot" +class OpGelu: + op_name: str = "Gelu" @dataclass(init=False, frozen=True) @@ -246,6 +240,13 @@ class OpReduceMean: param_keep_dims: str = "keep_dims" +@dataclass(init=False, frozen=True) +class OpReduceSum: + op_name: str = "ReduceSum" + param_axes: str = "axes" + param_keep_dims: str = "keep_dims" + + @dataclass(init=False, frozen=True) class OpRelu: op_name: str = "Relu" @@ -277,6 +278,12 @@ class OpResizeNearestNeighbor: param_half_pixel_centers: str = "half_pixel_centers" +@dataclass(init=False, frozen=True) +class OpScatterNd: + op_name: str = "ScatterNd" + param_reduction: str = "reduction" + + @dataclass(init=False, frozen=True) class OpSigmoid: op_name: str = "Sigmoid" @@ -307,6 +314,11 @@ class OpSplit: param_split_index: str = "split_index" +@dataclass(init=False, frozen=True) +class OpSqrt: + op_name: str = "ElementWiseSquareRoot" + + @dataclass(init=False, frozen=True) class OpSqueeze: op_name: str = "Squeeze" diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 61935cf353..c60afc2dd3 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -13,8 +13,7 @@ exir_ops.edge.aten.clone.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, - exir_ops.edge.aten.index.Tensor, - exir_ops.edge.aten.index_put.default, + exir_ops.edge.aten.copy.default, ] allow_list_operator = [ diff --git a/backends/qualcomm/passes/annotate_quant_attrs.py b/backends/qualcomm/passes/annotate_quant_attrs.py index 199d26b026..0dc39d2a4d 100644 --- a/backends/qualcomm/passes/annotate_quant_attrs.py +++ b/backends/qualcomm/passes/annotate_quant_attrs.py @@ -94,9 +94,11 @@ def _dequant_fold_params(self, n, quant_attrs, param): def _annotate_quant_attrs( self, graph_module: torch.fx.GraphModule ) -> torch.fx.GraphModule: + # Keep track of const params that has been dequant, so it does not get + # dequant multiple times if the const param has more than 1 user + visited_const_param = set() for n in graph_module.graph.nodes: self._annotate_requant(n) - # With fold_quant enabled, check if the input of dq op is quantized param. 
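
The op_index_put visitor earlier in this diff maps aten.index_put to QNN ScatterNd; the KV-cache style update exercised by the new IndexPut test module behaves as follows in eager mode (a minimal sketch with illustrative shapes, mirroring the test's buffer layout):

import torch

k_cache = torch.zeros(1, 1024, 12, 64)
input_pos = torch.tensor([2])
k_val = torch.randn(1, 1, 12, 64)

# index_put_ with [None, input_pos] writes k_val into position 2 of dim 1
# of the cache, leaving every other slot untouched.
torch.ops.aten.index_put_(k_cache, [None, input_pos], k_val)
assert torch.equal(k_cache[:, 2], k_val[:, 0])
assert torch.count_nonzero(k_cache[:, :2]) == 0
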
param = None if n.target in dq_ops: @@ -106,7 +108,8 @@ def _annotate_quant_attrs( quant_attrs = get_quant_attrs(self.edge_program, n) self._annotate_source_nodes(n, quant_attrs) - if param is not None: + if param is not None and n.args[0] not in visited_const_param: + visited_const_param.add(n.args[0]) self._dequant_fold_params(n, quant_attrs, param) return graph_module diff --git a/backends/qualcomm/passes/recompose_pixel_unshuffle.py b/backends/qualcomm/passes/recompose_pixel_unshuffle.py index cadc310bbb..a47f3d119a 100644 --- a/backends/qualcomm/passes/recompose_pixel_unshuffle.py +++ b/backends/qualcomm/passes/recompose_pixel_unshuffle.py @@ -35,7 +35,13 @@ def call(self, graph_module: torch.fx.GraphModule): for node in graph.nodes: if node.op == "call_function" and node.target == self.reshape_target: with graph.inserting_after(node): - premute_node = node.args[0] + + # Clone op still exists between permute and reshape_target during quantization, + # so we need to check for args[0].args[0] to get permute node + if self.quantization_capture: + premute_node = node.args[0].args[0] + else: + premute_node = node.args[0] if any( [ len(node.args[1]) != 4, diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 91e31b62e4..d51e016473 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -12,7 +12,6 @@ RecomposePixelUnshuffle, ) from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange -from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, @@ -182,7 +181,6 @@ def set_per_channel_linear_quant(self, enable: bool) -> None: self._update_per_channel_weight_quant_ops(linear_ops, enable) def transform_for_annotation(self, model: GraphModule) -> GraphModule: - model = RemoveRedundancy()(model).graph_module model = ReduceDynamicRange()(model).graph_module model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module model = DecomposeScaledDotProductAttention()(model).graph_module diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index f2265daf32..d31b4753a3 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -784,6 +784,38 @@ def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> N ) +@register_annotator([torch.ops.aten.index.Tensor]) +def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + input_qspec_map = {} + input = node.args[0] + input_qspec_map[input] = quantization_config.input_activation + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((input, node)), + _annotated=True, + ) + + +@register_annotator( + [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default] +) +def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: + input = node.args[0] + value = node.args[2] + + input_qspec_map = {} + input_qspec_map[input] = quantization_config.input_activation + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + 
input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((input, node)), + _annotated=True, + ) + + @register_annotator([torch.ops.aten.expand.default]) def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None: annotate_in_out_obs_sharing_op(node, quantization_config) diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index f08f688cf9..36512c4ff2 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -23,12 +23,12 @@ Result QnnExecuTorchBackend::init( ArrayRef compile_specs) const { // covert SizedBuffer to qnn ExecuTorch option QnnExecuTorchContextBinary qnn_context_blob; - const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options; + const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options = nullptr; qnn_context_blob.buffer = const_cast(processed->data()); qnn_context_blob.nbytes = processed->size(); - // covert CompileSpec to qnn ExecuTorch option + // convert CompileSpec to qnn ExecuTorch option for (auto& compile_spec : compile_specs) { if (std::strcmp(compile_spec.key, QNN_COMPILE_SPEC) == 0) qnn_executorch_options = diff --git a/backends/qualcomm/runtime/SharedBuffer.cpp b/backends/qualcomm/runtime/SharedBuffer.cpp index 430c8f757a..3fa62d09cd 100644 --- a/backends/qualcomm/runtime/SharedBuffer.cpp +++ b/backends/qualcomm/runtime/SharedBuffer.cpp @@ -87,7 +87,12 @@ SharedBuffer& SharedBuffer::GetSharedBufferManager() { std::lock_guard lk(init_mutex_); static SharedBuffer shared_buffer_manager; if (!shared_buffer_manager.GetInitialize()) { +#if defined(__aarch64__) Error status = shared_buffer_manager.Load(); +#else + // For x86_64 platform + Error status = Error::Ok; +#endif if (status == Error::Ok) { shared_buffer_manager.SetInitialize(true); } @@ -96,9 +101,11 @@ SharedBuffer& SharedBuffer::GetSharedBufferManager() { } SharedBuffer::~SharedBuffer() { +#if defined(__aarch64__) if (initialize_) { SharedBuffer::GetSharedBufferManager().UnLoad(); } +#endif }; void* SharedBuffer::AllocMem(size_t bytes, size_t alignment) { diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index 3712a83fde..d6b1da62fc 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -107,19 +107,33 @@ if [ "$BUILD_X86_64" = true ]; then rm -rf $BUILD_ROOT && mkdir $BUILD_ROOT fi cd $BUILD_ROOT + # TODO: Use CMAKE_BUILD_TYPE=RelWithDebInfo, and handle flatcc issues cmake \ - -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \ -DQNN_SDK_ROOT=${QNN_SDK_ROOT} \ -DEXECUTORCH_BUILD_QNN=ON \ + -DEXECUTORCH_BUILD_SDK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ - -DBUCK2=$BUCK2 \ -S $PRJ_ROOT \ -B $BUILD_ROOT \ - cmake \ - --build $BUILD_ROOT \ - -t "PyQnnManagerAdaptor" "PyQnnWrapperAdaptor" -j16 + cmake --build $BUILD_ROOT -j16 --target install rm -f $PRJ_ROOT/backends/qualcomm/python/* cp -fv $BUILD_ROOT/backends/qualcomm/Py* "$PRJ_ROOT/backends/qualcomm/python" + + EXAMPLE_ROOT=examples/qualcomm + CMAKE_PREFIX_PATH="${BUILD_ROOT}/lib/cmake/ExecuTorch;${BUILD_ROOT}/third-party/gflags;" + + cmake $PRJ_ROOT/$EXAMPLE_ROOT \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$EXAMPLE_ROOT + + cmake --build $EXAMPLE_ROOT -j16 
fi diff --git a/backends/qualcomm/targets.bzl b/backends/qualcomm/targets.bzl new file mode 100644 index 0000000000..cb89bb24ef --- /dev/null +++ b/backends/qualcomm/targets.bzl @@ -0,0 +1,28 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + + +def generate_schema_header(rule_name, srcs, headers, default_header): + """Generate header file given flatbuffer schema + """ + runtime.genrule( + name = rule_name, + srcs = srcs, + # We're only generating a single file, so it seems like we could use + # `out`, but `flatc` takes a directory as a parameter, not a single + # file. Use `outs` so that `${OUT}` is expanded as the containing + # directory instead of the file itself. + outs = {header: [header] for header in headers}, + default_outs = [default_header], + cmd = " ".join([ + "$(exe {})".format(runtime.external_dep_location("flatc")), + "--cpp", + "--cpp-std c++11", + "--gen-mutable", + "--scoped-enums", + "-o ${OUT}", + "${SRCS}", + # Let our infra know that the file was generated. + " ".join(["&& echo // @" + "generated >> ${OUT}/" + header for header in headers]), + ]), + visibility = [], # Private + ) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index fe72b1e893..ff52fc61b5 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -443,6 +443,29 @@ def forward(self, x): return self.hardtanh(x) +class Index(torch.nn.Module): + def __init__(self): + super().__init__() + self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]]) + self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]]) + + def forward(self, x): + return x[self.idx0] + x[self.idx1] + + +class IndexPut(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "k_cache", + torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + ) + + def forward(self, input_pos, k_val): + k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) + return k_out + + class LayerNorm(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 508a027da6..c1c070ca3c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -147,6 +147,7 @@ def test_qnn_backend_element_wise_ceil(self): def test_qnn_backend_element_wise_div(self): eps = 1e-03 + torch.manual_seed(8) test_comb = [ { QCOM_MODULE: [Div()], # noqa: F405 @@ -256,6 +257,19 @@ def test_qnn_backend_hardtanh(self): sample_input = (torch.randn([2, 5, 1, 3]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_index(self): + module = Index() # noqa: F405 + sample_input = (torch.randn([8, 172, 64]),) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_index_put(self): + module = IndexPut() # noqa: F405 + sample_input = ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_interpolate_bilinear_2d(self): module = ResizeBilinear2D() # noqa: F405 sample_input = (torch.randn(2, 3, 4, 5),) @@ -708,6 +722,7 @@ def test_qnn_backend_element_wise_ceil(self): def test_qnn_backend_element_wise_div(self): eps = 1e-03 + torch.manual_seed(8) test_comb = [ { QCOM_MODULE: [Div()], # noqa: F405 @@ -827,6 +842,21 @@ def test_qnn_backend_hardtanh(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def 
test_qnn_backend_index(self): + module = Index() # noqa: F405 + sample_input = (torch.randn([8, 172, 64]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_index_put(self): + module = IndexPut() # noqa: F405 + sample_input = ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_interpolate_bilinear_2d(self): module = ResizeBilinear2D() # noqa: F405 sample_input = (torch.randn(2, 3, 4, 5),) @@ -1295,7 +1325,6 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) - @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -1310,7 +1339,7 @@ def test_qnn_backend_profile_op(self): module, sample_input, expected_partitions=1, - expected_profile_events=25, + expected_profile_events=24, ) def test_qnn_backend_shared_buffer(self): @@ -1460,7 +1489,6 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) - @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=False) @@ -1476,7 +1504,7 @@ def test_qnn_backend_profile_op(self): module, sample_input, expected_partitions=1, - expected_profile_events=26, + expected_profile_events=25, ) def test_qnn_backend_shared_buffer(self): @@ -1581,8 +1609,11 @@ def test_fbnet(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 90) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 90) def test_gMLP(self): if not self.required_envs([self.image_dataset]): @@ -1614,8 +1645,11 @@ def test_gMLP(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 90) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 90) def test_ssd300_vgg16(self): if not self.required_envs([self.pretrained_weight, self.oss_repo]): @@ -1649,7 +1683,10 @@ def test_ssd300_vgg16(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["mAP"], 0.70) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["mAP"], 0.70) def test_dino_v2(self): if not self.required_envs([self.image_dataset]): @@ -1680,8 +1717,11 @@ def test_dino_v2(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 70) - self.assertGreaterEqual(msg["top_5"], 85) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 70) + self.assertGreaterEqual(msg["top_5"], 85) def test_esrgan(self): if not self.required_envs(): @@ -1714,8 +1754,11 @@ def test_esrgan(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["PSNR"], 24) - self.assertGreaterEqual(msg["SSIM"], 0.8) + 
if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["PSNR"], 24) + self.assertGreaterEqual(msg["SSIM"], 0.8) def test_squeezenet(self): if not self.required_envs([self.image_dataset]): @@ -1747,8 +1790,11 @@ def test_squeezenet(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 40) - self.assertGreaterEqual(msg["top_5"], 70) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 40) + self.assertGreaterEqual(msg["top_5"], 70) class TestExampleScript(TestQNN): @@ -1794,8 +1840,11 @@ def test_mobilenet_v2(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 80) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) def test_mobilenet_v3(self): if not self.required_envs([self.image_dataset]): @@ -1829,8 +1878,11 @@ def test_mobilenet_v3(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 80) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) def test_inception_v3(self): if not self.required_envs([self.image_dataset]): @@ -1864,8 +1916,11 @@ def test_inception_v3(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 80) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) def test_inception_v4(self): if not self.required_envs([self.image_dataset]): @@ -1899,8 +1954,11 @@ def test_inception_v4(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 80) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) def test_vit(self): if not self.required_envs([self.image_dataset]): @@ -1934,8 +1992,11 @@ def test_vit(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["top_1"], 70) - self.assertGreaterEqual(msg["top_5"], 90) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 70) + self.assertGreaterEqual(msg["top_5"], 90) def test_edsr(self): if not self.required_envs(): @@ -1968,8 +2029,11 @@ def test_edsr(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["PSNR"], 25) - self.assertGreaterEqual(msg["SSIM"], 0.8) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["PSNR"], 25) + self.assertGreaterEqual(msg["SSIM"], 0.8) def test_deeplab_v3(self): if not self.required_envs(): @@ -2002,9 +2066,12 @@ def test_deeplab_v3(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertGreaterEqual(msg["PA"], 0.85) - self.assertGreaterEqual(msg["MPA"], 0.70) - self.assertGreaterEqual(msg["MIoU"], 0.55) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["PA"], 0.85) + self.assertGreaterEqual(msg["MPA"], 0.70) + self.assertGreaterEqual(msg["MIoU"], 
0.55) def test_stories_single_llama(self): if not self.required_envs(): @@ -2049,8 +2116,11 @@ def test_stories_single_llama(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - model_out = msg["result"][0] - self.assertTrue(model_out.startswith(golden_start_with)) + if "Error" in msg: + self.fail(msg["Error"]) + else: + model_out = msg["result"][0] + self.assertTrue(model_out.startswith(golden_start_with)) def test_mobilebert(self): if not self.required_envs([self.pretrained_weight]): @@ -2085,9 +2155,12 @@ def test_mobilebert(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - cpu, htp = msg["CPU"], msg["HTP"] - for k, v in cpu.items(): - self.assertLessEqual(abs(v[0] - htp[k][0]), 2) + if "Error" in msg: + self.fail(msg["Error"]) + else: + cpu, htp = msg["CPU"], msg["HTP"] + for k, v in cpu.items(): + self.assertLessEqual(abs(v[0] - htp[k][0]), 2) @unittest.skip("will be enabled after TODOs got resolved") def test_ptq_mobilebert(self): @@ -2127,9 +2200,12 @@ def test_ptq_mobilebert(self): conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - cpu, htp = msg["CPU"], msg["HTP"] - for k, v in cpu.items(): - self.assertLessEqual(abs(v[0] - htp[k][0]), 5) + if "Error" in msg: + self.fail(msg["Error"]) + else: + cpu, htp = msg["CPU"], msg["HTP"] + for k, v in cpu.items(): + self.assertLessEqual(abs(v[0] - htp[k][0]), 5) def test_export_example(self): if not self.required_envs([self.model_name]): @@ -2212,6 +2288,12 @@ def setup_environment(): help="Path to open source software model repository", type=str, ) + parser.add_argument( + "-x", + "--enable_x86_64", + help="Enable unittest to be executed on x86_64 platform", + action="store_true", + ) args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -2228,6 +2310,7 @@ def setup_environment(): TestQNN.error_only = args.error_only TestQNN.oss_repo = args.oss_repo TestQNN.shared_buffer = args.shared_buffer + TestQNN.enable_x86_64 = args.enable_x86_64 return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index f31f07562b..ef0ac0f202 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -27,7 +27,11 @@ QcomChipset, ) from executorch.backends.qualcomm.utils.utils import capture_program -from executorch.examples.qualcomm.scripts.utils import SimpleADB +from executorch.examples.qualcomm.scripts.utils import ( + generate_inputs, + make_output_dir, + SimpleADB, +) from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -133,6 +137,7 @@ class TestQNN(unittest.TestCase): use_16a16w: str = "16a16w" use_16a4w: str = "16a4w" shared_buffer: bool = False + enable_x86_64: bool = False def _assert_outputs_equal(self, model_output, ref_output): self.assertTrue(len(ref_output) == len(model_output)) @@ -201,16 +206,16 @@ def verify_output( tmp_dir, ) - device_output_dir = f"{tmp_dir}/outputs" - device_outputs = [] + output_dir = f"{tmp_dir}/outputs" + outputs = [] etdump_path = f"{tmp_dir}/etdump.etdp" def post_process(): - for i, f in enumerate(sorted(os.listdir(device_output_dir))): - filename = os.path.join(device_output_dir, f) + for i, f in enumerate(sorted(os.listdir(output_dir))): + filename = os.path.join(output_dir, f) output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype) output = torch.from_numpy(output).reshape(ref_outputs[i].shape) - device_outputs.append(output) + 
outputs.append(output) def validate_profile(): inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path) @@ -218,23 +223,58 @@ def validate_profile(): len(inspector.to_dataframe().index) == expected_profile_events ) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=self.build_folder, - pte_path=pte_fname, - workspace="/data/local/tmp/qnn_executorch_test", - device_id=self.device, - host_id=self.host, - soc_model=self.model, - error_only=self.error_only, - ) - adb.push(inputs=[sample_inputs], input_list=input_list) - adb.execute() - adb.pull(output_path=tmp_dir, callback=post_process) - self._assert_outputs_equal(device_outputs, ref_outputs) + if self.enable_x86_64: + generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list) + make_output_dir(output_dir) + + target = "x86_64-linux-clang" + qnn_sdk = os.environ.get("QNN_SDK_ROOT", None) + assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable" + + build_path = "build_x86_64" + cmds = [ + # export LD_LIBRARY_PATH to QNN_SDK_ROOT + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{self.executorch_root}/{build_path}/lib && " + # qnn_executor_runner + f"{self.executorch_root}/{build_path}/examples/qualcomm/qnn_executor_runner", + f"--model_path {pte_fname}", + f"--input_list_path {tmp_dir}/input_list.txt", + f"--output_folder_path {output_dir}", + ] + + subprocess.run( + " ".join(cmds), + shell=True, + executable="/bin/bash", + capture_output=True, + cwd=tmp_dir, + ) + + # Verify the outputs + post_process() + self._assert_outputs_equal(outputs, ref_outputs) + + # Verify the etdump + if expected_profile_events != -1: + validate_profile() + else: + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=self.build_folder, + pte_path=pte_fname, + workspace="/data/local/tmp/qnn_executorch_test", + device_id=self.device, + host_id=self.host, + soc_model=self.model, + error_only=self.error_only, + ) + adb.push(inputs=[sample_inputs], input_list=input_list) + adb.execute() + adb.pull(output_path=tmp_dir, callback=post_process) + self._assert_outputs_equal(outputs, ref_outputs) - if expected_profile_events != -1: - adb.pull_etdump(etdump_path, callback=validate_profile) + if expected_profile_events != -1: + adb.pull_etdump(etdump_path, callback=validate_profile) def lower_module_and_test_output( self, @@ -362,6 +402,8 @@ def _insert_clone( (node,), ) inserted_node.meta["val"] = node.meta["val"] + if "quant_attrs" in node.meta: + inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"] for user in users: user.replace_input_with(node, inserted_node) diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 26436a0eb9..08d7f96a6b 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -8,6 +8,8 @@ import operator +from executorch.backends.vulkan.passes.custom_ops_defs import grid_priors_op # noqa + from executorch.exir.dialects._ops import ops as exir_ops @@ -129,6 +131,7 @@ def __contains__(self, op): exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.zeros.default, exir_ops.edge.aten.zeros_like.default, + exir_ops.edge.et_vk.grid_priors.default, ] diff --git a/backends/vulkan/passes/custom_ops_defs.py b/backends/vulkan/passes/custom_ops_defs.py index 67e7db828a..62f21bfee6 100644 --- a/backends/vulkan/passes/custom_ops_defs.py +++ b/backends/vulkan/passes/custom_ops_defs.py @@ -48,15 +48,18 @@ def conv_with_clamp_impl( conv_with_clamp_op = 
getattr(getattr(torch.ops, namespace), name) +# The dimension of x should be larger than 1 def grid_priors_impl( - height, - width, + x, stride, offset, ): - shift_x = (torch.arange(0, width) + offset) * stride - shift_y = (torch.arange(0, height) + offset) * stride - shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) + height, width = x.shape[-2:] + # Need to specify device of torch.arange to avoid executorch exporting error + shift_x = (torch.arange(0, width, device=x.device) + offset) * stride + shift_y = (torch.arange(0, height, device=x.device) + offset) * stride + # Need to specify indexing parameter ('ij' is the default value) to avoid executorch exporting error + shift_xx, shift_yy = torch.meshgrid([shift_y, shift_x], indexing="ij") shift_xx = shift_xx.reshape(-1) shift_yy = shift_yy.reshape(-1) shifts = torch.stack((shift_yy, shift_xx), dim=-1) @@ -64,6 +67,24 @@ def grid_priors_impl( name = "grid_priors" -lib.define(f"{name}(int height, int width, int stride, float offset) -> Tensor") -lib.impl(name, grid_priors_impl) +lib.define(f"{name}(Tensor self, int stride, float offset) -> Tensor") +lib.impl(name, grid_priors_impl, "CompositeExplicitAutograd") grid_priors_op = getattr(getattr(torch.ops, namespace), name) + + +# When lowering to executorch, ops are converted from default to out variant. Hence, custom ops define both variants. +def grid_priors_out_impl( + x, + stride, + offset, + out, +): + out = grid_priors_impl(x, stride, offset) + return out + + +name = "grid_priors_out" +lib.define( + f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd") diff --git a/backends/vulkan/passes/test_custom_ops.py b/backends/vulkan/passes/test_custom_ops.py index a1a3a40f67..c68dd6d679 100644 --- a/backends/vulkan/passes/test_custom_ops.py +++ b/backends/vulkan/passes/test_custom_ops.py @@ -97,14 +97,15 @@ class GridPriors(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, height, width, stride, offset): - return torch.ops.et_vk.grid_priors(height, width, stride, offset) + def forward(self, x, stride, offset): + return torch.ops.et_vk.grid_priors(x, stride, offset) model = GridPriors() - sample_input = (2, 3, 4, 0.5) + sample_input = (torch.rand(2, 5, 2, 3), 4, 0.5) custom_out = model(*sample_input) - def calculate_expected_output(height, width, stride, offset): + def calculate_expected_output(x, stride, offset): + height, width = x.shape[-2:] shift_x = (torch.arange(0, width) + offset) * stride shift_y = (torch.arange(0, height) + offset) * stride shift_xx, shift_yy = torch.meshgrid(shift_y, shift_x) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index c9e3aaa31e..c734ed395e 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -482,6 +482,7 @@ def __init__( src_dir_paths: Union[str, List[str]], env: Dict[Any, Any], glslc_path: Optional[str], + glslc_flags: str = "", ) -> None: if isinstance(src_dir_paths, str): self.src_dir_paths = [src_dir_paths] @@ -490,6 +491,7 @@ def __init__( self.env = env self.glslc_path = glslc_path + self.glslc_flags = glslc_flags self.glsl_src_files: Dict[str, str] = {} self.template_yaml_files: List[str] = [] @@ -668,19 +670,23 @@ def process_shader(shader_paths_pair): if self.glslc_path is not None: spv_out_path = os.path.join(output_dir, f"{shader_name}.spv") - cmd = [ - self.glslc_path, - "-fshader-stage=compute", - glsl_out_path, - 
"-o", - spv_out_path, - "--target-env=vulkan1.1", - "-Werror", - ] + [ - arg - for src_dir_path in self.src_dir_paths - for arg in ["-I", src_dir_path] - ] + cmd = ( + [ + self.glslc_path, + "-fshader-stage=compute", + glsl_out_path, + "-o", + spv_out_path, + "--target-env=vulkan1.1", + "-Werror", + ] + + [ + arg + for src_dir_path in self.src_dir_paths + for arg in ["-I", src_dir_path] + ] + + self.glslc_flags.split() + ) subprocess.check_call(cmd) @@ -966,6 +972,8 @@ def main(argv: List[str]) -> int: parser.add_argument("-c", "--glslc-path", required=True, help="") parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp") parser.add_argument("-o", "--output-path", required=True, help="") + parser.add_argument("--optimize_size", action="store_true", help="") + parser.add_argument("--optimize", action="store_true", help="") parser.add_argument( "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs" ) @@ -984,7 +992,15 @@ def main(argv: List[str]) -> int: if not os.path.exists(options.tmp_dir_path): os.makedirs(options.tmp_dir_path) - shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path) + glslc_flags = "" + if options.optimize_size: + glslc_flags += "-Os" + elif options.optimize: + glslc_flags += "-O" + + shader_generator = SPVGenerator( + options.glsl_paths, env, options.glslc_path, glslc_flags + ) output_spv_files = shader_generator.generateSPV(options.tmp_dir_path) genCppFiles( diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 2046e78e88..fb2c379c1b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -319,24 +319,20 @@ utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) { return image_extents_of(idx); } -utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { +utils::uvec3 ComputeGraph::create_local_wg_size( + const utils::uvec3 global_wg_size) { if (config_.enable_local_wg_size_override) { return config_.local_wg_size_override; } - if (is_buffer_storage(idx)) { - return {64u, 1u, 1u}; - } - - const utils::uvec3 image_extents = image_extents_of(idx); utils::uvec3 local_group_size = {4, 4, 4}; - if (image_extents.data[2u] == 1) { - if (image_extents.data[1u] == 1) { + if (global_wg_size.data[2u] == 1) { + if (global_wg_size.data[1u] == 1) { local_group_size.data[0u] = 64; local_group_size.data[1u] = 1; local_group_size.data[2u] = 1; - } else if (image_extents.data[1u] < 8) { + } else if (global_wg_size.data[1u] < 8) { local_group_size.data[0u] = 16; local_group_size.data[1u] = 4; local_group_size.data[2u] = 1; @@ -349,6 +345,10 @@ utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { return local_group_size; } +utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { + return create_local_wg_size(image_extents_of(idx)); +} + void ComputeGraph::copy_into_staging( const ValueRef idx, const void* data, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 5237a7746d..898a856291 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -180,7 +180,9 @@ class ComputeGraph final { return values_.at(idx).type(); } - // Get Tensor Property + // + // Tensor Properties Accessors + // std::vector sizes_of(const ValueRef idx) const; @@ -226,7 +228,9 @@ class ComputeGraph final { return values_.at(idx).toTensor().ntexels_ubo(); } + // // Scalar 
Value Extraction + // template T extract_scalar(const ValueRef idx) { @@ -459,9 +463,7 @@ class ComputeGraph final { utils::uvec3 create_global_wg_size(const ValueRef idx); /* - * Suggest a local workgroup size for a given `api::vTensor` value, assuming - * that every shader invocation calculates one texel element of the output - * tensor. + * Suggest a local workgroup size for a given global workgroup size. * * The local workgroup size will be formed to try and minimize the number of * inactive invocations. @@ -469,6 +471,13 @@ class ComputeGraph final { * Currently, the local workgroup size is hard-coded to contain a total of 64 * shader invocations. In the future, this value can be configured. */ + utils::uvec3 create_local_wg_size(const utils::uvec3 global_wg_size); + + /* + * Convenience function to suggest a local workgroup size for a given + * `api::vTensor` value, assuming that every shader invocation calculates one + * texel element of the output tensor. + */ utils::uvec3 create_local_wg_size(const ValueRef idx); // @@ -500,6 +509,17 @@ class ComputeGraph final { void resize_input(const int64_t idx, const std::vector& new_sizes); void propagate_resize(); + // + // Miscellaneous Utilities + // + + /* + * Check whether the GPU supports 8 bit buffers. + */ + inline bool int8_buffers_enabled() const { + return context_->adapter_ptr()->has_full_int8_buffers_support(); + } + // // Debug support (implemented in Logging.cpp) // diff --git a/backends/vulkan/runtime/graph/ops/glsl/grid_priors.glsl b/backends/vulkan/runtime/graph/ops/glsl/grid_priors.glsl new file mode 100644 index 0000000000..93a2c53e01 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/grid_priors.glsl @@ -0,0 +1,38 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_ubo(1, "ivec4", "in_sizes")} +${layout_declare_ubo(2, "ivec4", "out_sizes")} +${layout_declare_ubo(3, "int", "stride", "float", "offset")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim); + + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { + return; + } + int width = in_sizes.x; + VEC4_T outtex; + if (pos.x == 0) { + float value = (pos.y % width + offset) * stride; + outtex = VEC4_T(value, 0, 0, 0); + } else if (pos.x == 1) { + float value = (pos.y / width + offset) * stride; + outtex = VEC4_T(value, 0, 0, 0); + } + + imageStore(t_out, pos, outtex); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/grid_priors.yaml b/backends/vulkan/runtime/graph/ops/glsl/grid_priors.yaml new file mode 100644 index 0000000000..654edca610 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/grid_priors.yaml @@ -0,0 +1,12 @@ +grid_priors: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: C_packed + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: grid_priors diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 0ecfb83eac..d3264e43a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -80,7 +80,7 @@ ivec4 
from_nchw_buffer_i(int buf_i, ivec4 sizes) { * Returns: The (x, y, z, n) texel position corresponding to the first element * of the texel at the specified buffer index */ -ivec4 to_texel_pos(int buf_i, ivec4 strides, int packed_dim) { +ivec4 to_tensor_idx(int buf_i, ivec4 strides, int packed_dim) { ivec4 idx; for (int i = 3; i >= 0; i--) { if (i != packed_dim) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl new file mode 100644 index 0000000000..21290d0ce8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/int8_tensor_to_nchw_noint8.glsl @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +#extension GL_EXT_control_flow_attributes : require + +${layout_declare_tensor(0, "r", "t_in", "int8", "texture3d")} +${layout_declare_buffer(1, "w", "nchw_out", "int")} +${layout_declare_ubo(2, "ivec4", "tensor_sizes")} +${layout_declare_ubo(3, "int", "out_ntexels")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +void main() { + const int out_buf_idx = int(gl_GlobalInvocationID.x); + if (out_buf_idx >= out_ntexels) { + return; + } + + ivec4 values; + int in_buf_idx = 4 * out_buf_idx; + + [[unroll]] for (int i = 0; i < 4; ++i) { + const ivec4 tensor_idx = from_nchw_buffer_i(in_buf_idx, tensor_sizes); + const ivec4 texture_pos = to_texture_elem_pos( + tensor_idx, tensor_sizes, packed_dim); + values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w]; + in_buf_idx++; + } + + // Manually pack 4x 8-bit integers into a 32 bit integer. Note that little + // endian is assumed, since most processors use little endian. Thus the + // "later" values are placed in most significant bytes. + int packed = ((values[3] & 0xFF) << 24) + | ((values[2] & 0xFF) << 16) + | ((values[1] & 0xFF) << 8) + | ((values[0] & 0xFF)); + + nchw_out[out_buf_idx] = packed; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl new file mode 100644 index 0000000000..378cf09d12 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_tensor_noint8.glsl @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
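
The two *_noint8 shaders above and below work around missing 8-bit buffer support by packing four int8 values into one 32-bit word (little endian) and unpacking them again with byte masks plus manual sign extension. The same arithmetic in plain Python, as a sanity check of the packing scheme (illustrative only, not part of the diff):

def pack_int8x4(values):
    # Little endian: earlier values land in the less significant bytes.
    return ((values[3] & 0xFF) << 24) | ((values[2] & 0xFF) << 16) \
        | ((values[1] & 0xFF) << 8) | (values[0] & 0xFF)

def extend_sign(x):
    # Mirror the shader's sign extension of an 8-bit value stored in an int.
    return x | ~0xFF if (x >> 7) & 1 else x

def unpack_int8x4(packed):
    return [extend_sign((packed >> (8 * i)) & 0xFF) for i in range(4)]

values = [-1, 0, 127, -128]
assert unpack_int8x4(pack_int8x4(values)) == values
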
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +#extension GL_EXT_control_flow_attributes : require + +${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")} +${layout_declare_buffer(1, "r", "nchw_in", "int")} +${layout_declare_ubo(2, "ivec4", "tensor_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int packed_dim = C_DIM; + +/* + * Extends sign of int8 + */ +int extend_sign(int x) { + if (x >> 7 == 1) { + return x | 0xFFFFFF00; + } + return x; +} + +ivec4 read_texel(ivec4 tensor_idx) { + const ivec4 buf_indices = get_texel_nchw_buffer_ixs( + tensor_idx, tensor_sizes, packed_dim); + + int shift = (1 << 8) - 1; + ivec4 masks; + // Masks used to unpack 4x 8-bit values from a 32 bit integer. Note that + // little endian is assumed, as most processors use little endian. Thus the + // most significant bytes correspond to the "latter" packed values. + masks.x = shift << (8 * (buf_indices.x % 4)); + masks.y = shift << (8 * (buf_indices.y % 4)); + masks.z = shift << (8 * (buf_indices.z % 4)); + masks.w = shift << (8 * (buf_indices.w % 4)); + + ivec4 out_tex = ivec4(0); + + [[unroll]] for (int i = 0; i < 4; ++i) { + if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) { + int in_texel = nchw_in[buf_indices[i] / 4]; + int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4)); + extracted_val = extend_sign(extracted_val); + out_tex[i] = extracted_val; + } + } + + return out_tex; +} + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim); + + if (any(greaterThanEqual(tensor_idx, tensor_sizes))) { + return; + } + + write_texel(t_out, pos, read_texel(tensor_idx)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl index c0bbc5183a..c218482b09 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_tensor.glsl @@ -62,7 +62,7 @@ void main() { return; } - ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim); + ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim); tensor_idx[packed_dim] *= 4; t_out[t_id] = read_texel(tensor_idx); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 139c82866f..37988f21ec 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -53,7 +53,7 @@ void main() { return; } - const ivec4 out_pos = to_texel_pos(t_id, out_strides, 0); + const ivec4 out_pos = to_tensor_idx(t_id, out_strides, 0); VEC4_T outtex = q_8w_linear(out_pos, mat1_sizes.x); write_texel(t_out, t_id, outtex); diff --git a/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl index 78d8346428..d545e5d86e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/tensor_to_nchw.glsl @@ -61,7 +61,7 @@ void main() { } const VEC4_T intex = t_in[t_id]; - ivec4 tensor_idx = to_texel_pos(t_id, gpu_strides, packed_dim); + ivec4 tensor_idx = to_tensor_idx(t_id, gpu_strides, packed_dim); tensor_idx[packed_dim] *= 4; write_out_texel(intex, tensor_idx); } diff --git a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp 
b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp new file mode 100644 index 0000000000..17b6b351db --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +namespace vkcompute { + +struct GridPriorsParam final { + int32_t stride; + float offset; +}; + +void resize_grid_priors_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(extra_args[0]); + std::vector in_sizes = in->sizes(); + int64_t height = in_sizes.at(in_sizes.size() - 2); + int64_t width = in_sizes.at(in_sizes.size() - 1); + std::vector sizes = {height * width, 2}; + out->virtual_resize(sizes); +} + +void add_grid_priors_node( + ComputeGraph& graph, + const ValueRef& in, + const ValueRef& stride_ref, + const ValueRef& offset_ref, + const ValueRef& out) { + vTensorPtr t_out = graph.get_tensor(out); + vTensorPtr t_in = graph.get_tensor(in); + int32_t stride = graph.extract_scalar(stride_ref); + float offset = graph.extract_scalar(offset_ref); + + std::string kernel_name = "grid_priors"; + kernel_name.reserve(kShaderNameReserve); + add_dtype_suffix(kernel_name, *t_out); + + GridPriorsParam param = {stride, offset}; + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + graph.create_global_wg_size(out), + graph.create_local_wg_size(out), + // Inputs and Outputs + { + {out, vkapi::MemoryAccessType::WRITE}, + }, + // Shader params buffers + { + t_in->sizes_ubo(), + t_out->sizes_ubo(), + graph.create_params_buffer(param), + }, + // Specialization Constants + {}, + resize_grid_priors_node, + {in})); +} + +void grid_priors(ComputeGraph& graph, const std::vector& args) { + return add_grid_priors_node(graph, args[0], args[1], args[2], args[3]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.grid_priors.default, grid_priors); +} +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 2e5e9addfb..79b463d7ef 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -21,8 +21,8 @@ void add_staging_to_tensor_node( const ValueRef out_tensor) { VK_CHECK_COND(graph.val_is_staging(in_staging)); - vkapi::ShaderInfo shader = - get_nchw_to_tensor_shader(*graph.get_tensor(out_tensor)); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + *graph.get_tensor(out_tensor), graph.int8_buffers_enabled()); vkapi::ParamsBindList ubos({graph.sizes_ubo(out_tensor)}); if (graph.is_buffer_storage(out_tensor)) { @@ -55,10 +55,26 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); - vkapi::ShaderInfo shader = - get_tensor_to_nchw_shader(*graph.get_tensor(in_tensor)); + vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( + *graph.get_tensor(in_tensor), graph.int8_buffers_enabled()); + utils::uvec3 global_wg_size = graph.create_global_wg_size(in_tensor); vkapi::ParamsBindList ubos({graph.sizes_ubo(in_tensor)}); + + // Normally, the tensor_to_nchw shader is structured so that each thread reads + // one texel from the input texture and writes each component of the texel 
+ // into the corresponding location in the output buffer. However, this shader + // is structured slightly differently in that each thread writes out a + // complete 32 bit integer (containing 4 packed 8-bit integers) into the + // output buffer. Therefore, the global work group size for this shader will + // be the number of elements in the output buffer divided by 4, as opposed to + // the extents of the input texture. + if (shader.kernel_name == "int8_tensor_to_nchw_noint8") { + uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4; + global_wg_size = {buffer_len, 1, 1}; + ubos.append({graph.ntexels_ubo(in_tensor)}); + } + if (graph.is_buffer_storage(in_tensor)) { ubos.append({ graph.texel_strides_ubo(in_tensor), @@ -69,8 +85,8 @@ void add_tensor_to_staging_node( graph.execute_nodes().emplace_back(new ExecuteNode( graph, shader, - graph.create_global_wg_size(in_tensor), - graph.create_local_wg_size(in_tensor), + global_wg_size, + graph.create_local_wg_size(global_wg_size), // Input and Outputs {{in_tensor, vkapi::MemoryAccessType::READ}, {out_staging, vkapi::MemoryAccessType::WRITE}}, @@ -86,7 +102,8 @@ ValueRef prepack( const utils::GPUMemoryLayout layout) { ValueRef v = graph.add_tensor_like(vref, layout); - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(*graph.get_tensor(v)); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + *graph.get_tensor(v), graph.int8_buffers_enabled()); vkapi::ParamsBindList ubos({graph.sizes_ubo(v)}); if (graph.is_buffer_storage(v)) { diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index d681618d9d..2ade34e425 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -95,10 +95,17 @@ void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes) { memset(data_ptr, 0, staging.nbytes()); } -vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst) { +vkapi::ShaderInfo get_nchw_to_tensor_shader( + const api::vTensor& v_dst, + const bool int8_buffer_enabled) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); + if (v_dst.dtype() == vkapi::kChar && + v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) { + return VK_KERNEL(nchw_to_int8_tensor_noint8); + } + kernel_name = "nchw_to_tensor"; add_dtype_suffix(kernel_name, v_dst); add_storage_type_suffix(kernel_name, v_dst); @@ -106,10 +113,17 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst) { return VK_KERNEL_FROM_STR(kernel_name); } -vkapi::ShaderInfo get_tensor_to_nchw_shader(const api::vTensor& v_src) { +vkapi::ShaderInfo get_tensor_to_nchw_shader( + const api::vTensor& v_src, + bool int8_buffer_enabled) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); + if (v_src.dtype() == vkapi::kChar && + v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) { + return VK_KERNEL(int8_tensor_to_nchw_noint8); + } + kernel_name = "tensor_to_nchw"; add_dtype_suffix(kernel_name, v_src); add_storage_type_suffix(kernel_name, v_src); diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index dfe86a9e26..cabc17f30e 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -31,7 +31,11 @@ void set_staging_zeros(api::StorageBuffer& staging, const size_t nbytes); // Functions to get shaders // 
-vkapi::ShaderInfo get_nchw_to_tensor_shader(const api::vTensor& v_dst); -vkapi::ShaderInfo get_tensor_to_nchw_shader(const api::vTensor& v_src); +vkapi::ShaderInfo get_nchw_to_tensor_shader( + const api::vTensor& v_dst, + bool int8_buffer_enabled = true); +vkapi::ShaderInfo get_tensor_to_nchw_shader( + const api::vTensor& v_src, + bool int8_buffer_enabled = true); } // namespace vkcompute diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 477e54a2d7..da40f0a720 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -262,6 +262,9 @@ def get_or_create_value_for(self, arg: _Argument): raise RuntimeError(f"Cannot create value for arg of type {type(arg)}") def process_placeholder_node(self, node: Node) -> None: + # ignores any tensors that don't get used in any ops + if len(node.users) == 0: + return None ids = self.create_node_value(node) if not self.is_param_node(node): if isinstance(ids, int): diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 981552f17a..e8b232098b 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -1,12 +1,15 @@ +load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") def get_vulkan_compiler_flags(): return ["-Wno-missing-prototypes", "-Wno-global-constructors"] def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False): - gen_vulkan_spv_target = "//executorch/backends/vulkan:gen_vulkan_spv_bin" - glslc_path = "//caffe2/fb/vulkan/dotslash:glslc" + gen_vulkan_spv_target = "//xplat/executorch/backends/vulkan:gen_vulkan_spv_bin" + glslc_path = "//xplat/caffe2/fb/vulkan/dotslash:glslc" + if is_fbcode: + gen_vulkan_spv_target = "//executorch/backends/vulkan:gen_vulkan_spv_bin" glslc_path = "//caffe2/fb/vulkan/tools:glslc" glsl_paths = [] @@ -15,21 +18,25 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False): for target, subpath in spv_filegroups.items(): glsl_paths.append("$(location {})/{}".format(target, subpath)) - genrule_cmd = [ - "$(exe {})".format(gen_vulkan_spv_target), - "--glsl-paths {}".format(" ".join(glsl_paths)), - "--output-path $OUT", - "--glslc-path=$(exe {})".format(glslc_path), - "--tmp-dir-path=$OUT", - ] + genrule_cmd = ( + "$(exe {}) ".format(gen_vulkan_spv_target) + + "--glsl-paths {} ".format(" ".join(glsl_paths)) + + "--output-path $OUT " + + "--glslc-path=$(exe {}) ".format(glslc_path) + + "--tmp-dir-path=$OUT " + + select({ + "DEFAULT": "", + "ovr_config//os:android": "--optimize", + }) + ) genrule_name = "gen_{}_cpp".format(name) - runtime.genrule( + buck_genrule( name = genrule_name, outs = { "{}.cpp".format(name): ["spv.cpp"], }, - cmd = " ".join(genrule_cmd), + cmd = genrule_cmd, default_outs = ["."], labels = ["uses_dotslash"], ) diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index edba41b7ea..37403c97ac 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -47,21 +47,12 @@ idx_fill_buffer: idx_fill_texture: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - PACKING: CHANNELS_PACKED generate_variant_forall: - PACKING: - - VALUE: "CHANNELS_PACKED" - SUFFIX: "C_packed" - - VALUE: "WIDTH_PACKED" - SUFFIX: "W_packed" - - VALUE: "HEIGHT_PACKED" - SUFFIX: "H_packed" DTYPE: - - VALUE: "half" - SUFFIX: "half" - - VALUE: "float" 
- SUFFIX: "float" + - VALUE: half + - VALUE: float + - VALUE: int + - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/glsl/idx_fill_texture.glsl b/backends/vulkan/test/glsl/idx_fill_texture.glsl index 1f75cadf49..8914d2b892 100644 --- a/backends/vulkan/test/glsl/idx_fill_texture.glsl +++ b/backends/vulkan/test/glsl/idx_fill_texture.glsl @@ -12,21 +12,17 @@ #define VEC4_T ${texel_type(DTYPE)} -#define POS ${get_pos[NDIM]("pos")} - #include "indexing_utils.h" layout(std430) buffer; -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; - -layout(set = 0, binding = 1) uniform PRECISION restrict Sizes { - ivec4 sizes; -}; +${layout_declare_tensor(0, "w", "image_out", DTYPE, "texture3d")} +${layout_declare_ubo(1, "ivec4", "sizes")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int packed_dim = C_DIM; +layout(constant_id = 4) const int offset = 10; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -37,6 +33,6 @@ void main() { } const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); - VEC4_T texel = VEC4_T(buf_indices); - imageStore(image_out, POS, texel); + VEC4_T texel = VEC4_T(buf_indices) + offset; + imageStore(image_out, pos, texel); } diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 778ce67787..b3fde403f7 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1632,3 +1632,21 @@ def forward(self, x): (torch.tensor([[[0, 1], [0, 1]], [[4, 2], [3, 3]]]),), memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def test_vulkan_backend_grid_priors(self): + class GridPriorsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.ops.et_vk.grid_priors( + x, + stride=8, + offset=0.5, + ) + + self.lower_module_and_test_output( + GridPriorsModule(), + (torch.rand(size=[1, 5, 2, 3]),), + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index c55c286acc..649c0c82d6 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -76,7 +76,8 @@ void record_nchw_to_image_op( SV(v_dst.packed_dim_whcn_idx())}; context->submit_compute_job( - get_nchw_to_tensor_shader(v_dst), + get_nchw_to_tensor_shader( + v_dst, context->adapter_ptr()->has_full_int8_buffers_support()), pipeline_barrier, v_dst.image_extents(), adaptive_work_group_size(v_dst.image_extents()), @@ -112,6 +113,27 @@ void record_image_to_nchw_op( v_src.sizes_ubo()); } +void record_int8_image_to_nchw_noint8_op( + api::Context* const context, + api::vTensor& v_src, + api::StorageBuffer& dst_buffer) { + vkapi::PipelineBarrier pipeline_barrier{}; + uint32_t buffer_len = utils::safe_downcast(dst_buffer.numel() / 4); + utils::uvec3 global_wg_size = {buffer_len, 1, 1}; + context->submit_compute_job( + VK_KERNEL(int8_tensor_to_nchw_noint8), + pipeline_barrier, + global_wg_size, + adaptive_work_group_size(global_wg_size), + {v_src.packed_dim_whcn_idx()}, + VK_NULL_HANDLE, + 0, + v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE), + dst_buffer.buffer(), + v_src.sizes_ubo(), + v_src.ntexels_ubo()); +} + void record_conv2d_prepack_weights_op( api::Context* const context, vkapi::VulkanBuffer& 
src_buffer, diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 89e16131c9..3dd9497e69 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -82,6 +82,11 @@ void record_image_to_nchw_op( api::vTensor& v_src, vkapi::VulkanBuffer& dst_buffer); +void record_int8_image_to_nchw_noint8_op( + api::Context* const context, + api::vTensor& v_src, + api::StorageBuffer& dst_buffer); + void record_conv2d_prepack_weights_op( api::Context* const context, vkapi::VulkanBuffer& src_buffer, diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 9260475ab6..6f0879c422 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -1692,25 +1693,21 @@ void run_from_gpu_test( if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { return; } - if ((dtype == vkapi::kChar || dtype == vkapi::kQInt8) && - !context()->adapter_ptr()->has_full_int8_buffers_support()) { - return; - } vTensor vten = vTensor(context(), sizes, dtype, storage_type, memory_layout); std::string kernel_name("idx_fill_texture"); - add_memory_layout_suffix(kernel_name, vten); add_dtype_suffix(kernel_name, vten); + int32_t offset = -50; + { vkapi::PipelineBarrier pipeline_barrier{}; - vkapi::SpecVarList specialization_constants = {vten.packed_dim_whcn_idx()}; context()->submit_compute_job( VK_KERNEL_FROM_STR(kernel_name), pipeline_barrier, vten.image_extents(), {4, 4, 4}, - specialization_constants, + {vten.packed_dim_whcn_idx(), offset}, VK_NULL_HANDLE, 0, vten.image( @@ -1722,7 +1719,12 @@ void run_from_gpu_test( StorageBuffer staging_buffer(context(), dtype, vten.gpu_numel()); - record_image_to_nchw_op(context(), vten, staging_buffer.buffer()); + if (dtype == vkapi::kChar && + !context()->adapter_ptr()->has_full_int8_buffers_support()) { + record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer); + } else { + record_image_to_nchw_op(context(), vten, staging_buffer.buffer()); + } submit_to_gpu(); @@ -1730,12 +1732,12 @@ void run_from_gpu_test( copy_staging_to_ptr(staging_buffer, data_out.data(), staging_buffer.nbytes()); for (int i = 0; i < vten.numel(); i++) { - CHECK_VALUE(data_out, i, i); + CHECK_VALUE(data_out, i, i + offset); } } template -void run_to_gpu_test( +void round_trip_test( std::vector& sizes, utils::GPUMemoryLayout memory_layout = utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, @@ -1744,10 +1746,6 @@ void run_to_gpu_test( if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { return; } - if ((dtype == vkapi::kChar || dtype == vkapi::kQInt8) && - !context()->adapter_ptr()->has_full_int8_buffers_support()) { - return; - } vTensor vten = vTensor(context(), sizes, dtype, storage_type, memory_layout); @@ -1756,16 +1754,22 @@ void run_to_gpu_test( std::vector data_in(staging_buffer_in.numel()); for (int i = 0; i < staging_buffer_in.numel(); i++) { - data_in[i] = i; + data_in[i] = T(i * -1); } copy_ptr_to_staging(data_in.data(), staging_buffer_in, vten.gpu_nbytes()); // Output staging buffer StorageBuffer staging_buffer_out(context(), dtype, vten.gpu_numel()); - // Copy data in and out of the tensor record_nchw_to_image_op(context(), staging_buffer_in.buffer(), vten); - record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer()); + + // Copy data in and out of the tensor + if (dtype == vkapi::kChar && 
+ !context()->adapter_ptr()->has_full_int8_buffers_support()) { + record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer_out); + } else { + record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer()); + } // Execute command buffer submit_to_gpu(); @@ -1777,11 +1781,51 @@ void run_to_gpu_test( // All indices should be equal to the input data for (int i = 0; i < vten.numel(); i++) { - CHECK_VALUE(data_out, i, i); + CHECK_VALUE(data_out, i, data_in[i]); + } +} + +template +void compute_graph_round_trip_test( + std::vector& sizes, + utils::GPUMemoryLayout memory_layout = + utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, + vkapi::ScalarType dtype = vkapi::kFloat, + utils::StorageType storage_type = utils::StorageType::TEXTURE_3D) { + if (dtype == vkapi::kHalf && !context()->adapter_ptr()->has_16bit_storage()) { + return; + } + + GraphConfig config; + ComputeGraph graph(config); + + ValueRef r_tensor = + graph.add_tensor(sizes, dtype, storage_type, memory_layout); + ValueRef r_staging_in = graph.set_input_tensor(r_tensor); + ValueRef r_staging_out = graph.set_output_tensor(r_tensor); + + graph.prepare(); + graph.encode_execute(); + + vTensorPtr tensor = graph.get_tensor(r_tensor); + + std::vector data_in(tensor->numel()); + for (int i = 0; i < data_in.size(); i++) { + data_in[i] = T(i * -1); + } + graph.copy_into_staging(r_staging_in, data_in.data(), data_in.size()); + + graph.execute(); + + std::vector data_out(tensor->gpu_numel()); + graph.copy_from_staging(r_staging_out, data_out.data(), data_out.size()); + + for (int i = 0; i < data_in.size(); i++) { + CHECK_VALUE(data_out, i, data_in[i]); } } -TEST(VulkanToFromGPUShaderTest, to_gpu_and_from_gpu_test_texture) { +TEST(VulkanToFromGPUShaderTest, round_trip_tests) { // The below tests will fill each texel element with the value of the linear // buffer index that corresponds to it. 
The texel at position (0, 0, 0) will // be filled with the values [0, 1, 2, 3], the texel at position (1, 0, 0) @@ -1824,11 +1868,17 @@ TEST(VulkanToFromGPUShaderTest, to_gpu_and_from_gpu_test_texture) { }; #define RUN_TESTS(ctype, dtype) \ - run_to_gpu_test( \ + round_trip_test( \ sizes, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, dtype); \ - run_to_gpu_test( \ + round_trip_test( \ sizes, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, dtype); \ - run_to_gpu_test( \ + round_trip_test( \ + sizes, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, dtype); \ + compute_graph_round_trip_test( \ + sizes, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, dtype); \ + compute_graph_round_trip_test( \ + sizes, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, dtype); \ + compute_graph_round_trip_test( \ sizes, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, dtype); for (auto& sizes : to_test) { @@ -2203,3 +2253,75 @@ TEST(VulkanComputeGraphOpsTest, conv2d_prepack_test) { 0, 3, 9, 0, 0, 6, 12, 0, 0, 5, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); } + +void test_grid_priors( + std::vector input_sizes, + std::vector output_sizes, + int stride, + double offset, + const std::vector& data_out_expected) { + GraphConfig config; + ComputeGraph graph(config); + + // Build graph + IOValueRef in = graph.add_input_tensor( + input_sizes, + vkapi::kFloat, + utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED); + IOValueRef out; + out.value = graph.add_tensor( + output_sizes, + vkapi::kFloat, + utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED); + + VK_GET_OP_FN("et_vk.grid_priors.default") + (graph, + {in.value, + graph.add_scalar(stride), + graph.add_scalar(offset), + out.value}); + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + vTensorPtr t_in = graph.get_tensor(in.value); + vTensorPtr t_out = graph.get_tensor(out.value); + // Resize input + graph.propagate_resize(); + + // run graph + graph.execute(); + + std::vector output_data(t_out->gpu_numel()); + graph.copy_from_staging(out.staging, output_data.data(), output_data.size()); + + // check results + int h_out = utils::val_at(-2, t_out->sizes()); + int w_out = utils::val_at(-1, t_out->sizes()); + for (size_t i = 0; i < h_out; ++i) { + for (size_t j = 0; j < w_out; ++j) { + size_t idx_out = i * w_out + j; + CHECK_VALUE(output_data, idx_out, data_out_expected[idx_out]); + } + } +} + +TEST(VulkanComputeGraphOpsTest, grid_priors_test) { + test_grid_priors( + /*input size = */ {1, 5, 2, 3}, + /*output size = */ {6, 2}, + /*stride = */ 1, + /*offset = */ 0.0, + /*data_out_expected = */ {0, 0, 1, 0, 2, 0, 0, 1, 1, 1, 2, 1}); + + test_grid_priors( + /*input size = */ {1, 5, 2, 3}, + /*output size = */ {6, 2}, + /*stride = */ 8, + /*offset = */ 0.5, + /*data_out_expected = */ {4, 4, 12, 4, 20, 4, 4, 12, 12, 12, 20, 12}); +} diff --git a/backends/vulkan/tools/gpuinfo/TARGETS b/backends/vulkan/tools/gpuinfo/TARGETS index e9dd22e92d..10e3acb4b8 100644 --- a/backends/vulkan/tools/gpuinfo/TARGETS +++ b/backends/vulkan/tools/gpuinfo/TARGETS @@ -23,6 +23,7 @@ buck_filegroup( vulkan_spv_shader_lib( name = "gpuinfo_shader_lib", + is_fbcode = True, spv_filegroups = { ":gpuinfo_shaders": "glsl", }, diff --git a/backends/vulkan/tools/gpuinfo/config.json b/backends/vulkan/tools/gpuinfo/config.json new file mode 100644 index 0000000000..afb5cbc6c5 --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/config.json @@ -0,0 +1,48 @@ +{ + "reg_count": { + "enabled": true, + "threshold": 3, + "compensate": 0.1 + }, + 
"buf_cacheline_size": { + "enabled": true, + "threshold": 10, + "compensate": 0.1 + }, + "buffer_bandwidth": { + "enabled": true, + "range": 134217728, + "nflush": 4, + "nunroll": 16, + "niter": 10 + }, + "ubo_bandwidth": { + "enabled": true, + "range": 134217728, + "nflush": 4, + "nunroll": 16, + "niter": 10 + }, + "shared_bandwidth": { + "enabled": true, + "nflush": 4, + "nunroll": 16, + "niter": 10 + }, + "warp_size": { + "enabled": true, + "threshold": 3, + "compensate": 0.1 + }, + "tex_bandwidth": { + "enabled": true, + "nflush": 4, + "nunroll": 16, + "niter": 10 + }, + "tex_cacheline_concurr": { + "enabled": true, + "threshold": 3, + "compensate": 0.1 + } +} diff --git a/backends/vulkan/tools/gpuinfo/glsl/buf_bandwidth.glsl b/backends/vulkan/tools/gpuinfo/glsl/buf_bandwidth.glsl index c16ad5d14b..38c9befec6 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/buf_bandwidth.glsl +++ b/backends/vulkan/tools/gpuinfo/glsl/buf_bandwidth.glsl @@ -26,6 +26,11 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int niter = 1; layout(constant_id = 4) const int nvec = 1; layout(constant_id = 5) const int local_group_size = 1; +// The address mask works as a modulo because x % 2^n == x & (2^n - 1). +// This will help us limit address accessing to a specific set of unique +// addresses depending on the access size we want to measure. +layout(constant_id = 6) const int addr_mask = 1; +layout(constant_id = 7) const int workgroup_width = 1; $if MEMTYPE == "shared": shared vec4 A[nvec]; @@ -36,15 +41,7 @@ void main() { A[gl_LocalInvocationID[0]][0] = gl_LocalInvocationID[0]; memoryBarrierShared(); - // The address mask works as a modulo because x % 2^n == x & (2^n - 1). - // This will help us limit address accessing to a specific set of unique - // addresses depending on the access size we want to measure. - const int addr_mask = nvec - 1; vec4 sum = vec4(0); - - // This is to distribute the accesses to unique addresses across the workgroups, once the - // size of the access excedes the workgroup width. - const uint workgroup_width = local_group_size * niter * ${NUNROLL}; uint offset = (gl_WorkGroupID[0] * workgroup_width + gl_LocalInvocationID[0]) & addr_mask; int i = 0; diff --git a/backends/vulkan/tools/gpuinfo/glsl/tex_bandwidth.glsl b/backends/vulkan/tools/gpuinfo/glsl/tex_bandwidth.glsl new file mode 100644 index 0000000000..7ab67bd2d0 --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/glsl/tex_bandwidth.glsl @@ -0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +${layout_declare_sampler(0, "r", "A", DTYPE)} +${layout_declare_buffer(1, "w", "B", DTYPE, "PRECISION", False)} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int niter = 1; +layout(constant_id = 4) const int nvec = 1; +layout(constant_id = 5) const int local_group_size = 1; +// The address mask works as a modulo because x % 2^n == x & (2^n - 1). +// This will help us limit address accessing to a specific set of unique +// addresses depending on the access size we want to measure. 
+layout(constant_id = 6) const int addr_mask = 1; +layout(constant_id = 7) const int workgroup_width = 1; + +void main() { + vec4 sum = vec4(0); + uint offset = (gl_WorkGroupID[0] * workgroup_width + gl_LocalInvocationID[0]) & addr_mask; + + int i = 0; + for (; i < niter; ++i){ + VEC4_T in_texel; + $for j in range(int(NUNROLL)): + $if DIM == 0: + in_texel = texelFetch(A, ivec3(offset, 0, 0), 0); + $elif DIM == 1: + in_texel = texelFetch(A, ivec3(0, offset, 0), 0); + $elif DIM == 2: + in_texel = texelFetch(A, ivec3(0, 0, offset), 0); + + sum *= in_texel; + + // On each unroll, a new unique address will be accessed through the offset, + // limited by the address mask to a specific set of unique addresses + offset = (offset + local_group_size) & addr_mask; + } + + // This is to ensure no compiler optimizations occur + vec4 zero = vec4(i>>31); + + B[gl_LocalInvocationID[0]] = sum + zero; +} diff --git a/backends/vulkan/tools/gpuinfo/glsl/tex_bandwidth.yaml b/backends/vulkan/tools/gpuinfo/glsl/tex_bandwidth.yaml new file mode 100644 index 0000000000..84da6938fd --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/glsl/tex_bandwidth.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +tex_bandwidth: + parameter_names_with_default_values: + DTYPE: float + NUNROLL: "16" + generate_variant_forall: + DIM: + - RANGE: [0, 2] + shader_variants: + - NAME: tex_bandwidth diff --git a/backends/vulkan/tools/gpuinfo/glsl/tex_cacheline_concurr.glsl b/backends/vulkan/tools/gpuinfo/glsl/tex_cacheline_concurr.glsl new file mode 100644 index 0000000000..62659c7bb8 --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/glsl/tex_cacheline_concurr.glsl @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +${layout_declare_sampler(0, "r", "in_tex", DTYPE)} +${layout_declare_buffer(1, "w", "out_buf", DTYPE, "PRECISION", False)} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int niter = 1; + +void main() { + vec4 sum = vec4(0); + int i = 0; + for (; i < niter; ++i){ + $if DIM == 0: + sum += texelFetch(in_tex, ivec3(gl_GlobalInvocationID[0], 0, 0), 0); + $elif DIM == 1: + sum += texelFetch(in_tex, ivec3(0, gl_GlobalInvocationID[0], 0), 0); + $elif DIM == 2: + sum += texelFetch(in_tex, ivec3(0, 0, gl_GlobalInvocationID[0]), 0); + } + + // This is to ensure no compiler optimizations occur + vec4 zero = vec4(i>>31); + + out_buf[0] = sum + zero; +} diff --git a/backends/vulkan/tools/gpuinfo/glsl/tex_cacheline_concurr.yaml b/backends/vulkan/tools/gpuinfo/glsl/tex_cacheline_concurr.yaml new file mode 100644 index 0000000000..6b557c9f66 --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/glsl/tex_cacheline_concurr.yaml @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +tex_cacheline_concurr: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DIM: + - RANGE: [0, 2] + shader_variants: + - NAME: tex_cacheline_concurr diff --git a/backends/vulkan/tools/gpuinfo/include/app.h b/backends/vulkan/tools/gpuinfo/include/app.h new file mode 100644 index 0000000000..a46e9e6b9a --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/include/app.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include "utils.h" + +namespace gpuinfo { + +class App { + private: + folly::dynamic config_; + + public: + size_t buf_cache_size; + uint32_t max_shared_mem_size; + uint32_t sm_count; + uint32_t nthread_logic; + uint32_t subgroup_size; + uint32_t max_tex_width; + uint32_t max_tex_height; + uint32_t max_tex_depth; + + App() { + context()->initialize_querypool(); + + std::cout << context()->adapter_ptr()->stringize() << std::endl + << std::endl; + + auto cl_device = get_cl_device(); + + sm_count = cl_device.getInfo(); + nthread_logic = cl_device.getInfo(); + buf_cache_size = cl_device.getInfo(); + max_shared_mem_size = cl_device.getInfo(); + max_tex_width = cl_device.getInfo(); + max_tex_height = cl_device.getInfo(); + max_tex_depth = cl_device.getInfo(); + + VkPhysicalDeviceSubgroupProperties subgroup_props{}; + VkPhysicalDeviceProperties2 props2{}; + + props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; + props2.pNext = &subgroup_props; + subgroup_props.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; + vkGetPhysicalDeviceProperties2( + context()->adapter_ptr()->physical_handle(), &props2); + subgroup_size = subgroup_props.subgroupSize; + + std::cout << std::endl; + std::cout << "SM count," << sm_count << std::endl; + std::cout << "Logic Thread Count," << nthread_logic << std::endl; + std::cout << "Cache Size," << buf_cache_size << std::endl; + std::cout << "Shared Memory Size," << max_shared_mem_size << std::endl; + std::cout << "SubGroup Size," << subgroup_size << std::endl; + std::cout << "MaxTexWidth," << max_tex_width << std::endl; + std::cout << "MaxTexHeight," << max_tex_height << std::endl; + std::cout << "MaxTexDepth," << max_tex_depth << std::endl; + } + + float get_config(const std::string& test, const std::string& key) const { + if (config_[test].empty()) { + throw std::runtime_error("Missing config for " + test); + } + + if (!config_[test][key].isNumber()) { + throw std::runtime_error( + "Config for " + test + "." + key + " is not a number"); + } + + float value; + if (config_[test][key].isDouble()) { + value = config_[test][key].getDouble(); + } else { + value = config_[test][key].getInt(); + } + + std::cout << "Read value for " << test << "." 
<< key << " = " << value + << std::endl; + return value; + } + + bool enabled(const std::string& test) const { + if (config_.empty() || config_[test].empty() || + !config_[test]["enabled"].isBool()) { + return true; + } + return config_[test]["enabled"].getBool(); + } + + void load_config(std::string file_path) { + std::ifstream file(file_path); + std::stringstream buffer; + buffer << file.rdbuf(); + const std::string json_str = buffer.str(); + if (json_str.empty()) { + throw std::runtime_error( + "Failed to read config file from " + file_path + "."); + } + config_ = folly::parseJson(json_str); + } +}; +} // namespace gpuinfo diff --git a/backends/vulkan/tools/gpuinfo/include/architecture.h b/backends/vulkan/tools/gpuinfo/include/architecture.h new file mode 100644 index 0000000000..0d312ee87c --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/include/architecture.h @@ -0,0 +1,285 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "app.h" +#include "stats.h" +#include "utils.h" + +using namespace vkapi; + +namespace gpuinfo { + +void reg_count(const App& app) { + if (!app.enabled("reg_count")) { + std::cout << "Skipped Register Count" << std::endl; + return; + } + + std::cout << std::endl; + std::cout << "------ Register Count ------" << std::endl; + const uint32_t NREG_MIN = 1; + const uint32_t NREG_MAX = 512; + const uint32_t NREG_STEP = 1; + + const double COMPENSATE = app.get_config("reg_count", "compensate"); + const double THRESHOLD = app.get_config("reg_count", "threshold"); + + const uint32_t NGRP_MIN = 1; + const uint32_t NGRP_MAX = 64; + const uint32_t NGRP_STEP = 1; + + uint32_t NITER; + + auto bench = [&](uint32_t ngrp, uint32_t nreg) { + StorageBuffer buffer(context(), vkapi::kFloat, 1); + vkapi::PipelineBarrier pipeline_barrier{}; + + auto shader_name = "reg_count_" + std::to_string(nreg); + + auto time = benchmark_on_gpu(shader_name, 30, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + {1, ngrp, 1}, + {1, 1, 1}, + {SV(NITER)}, + VK_NULL_HANDLE, + 0, + buffer.buffer()); + }); + return time; + }; + + ensure_min_niter(1000, NITER, [&]() { return bench(1, NREG_MIN); }); + + uint32_t nreg_max; + + DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); + uint32_t nreg = NREG_MIN; + for (; nreg <= NREG_MAX; nreg += NREG_STEP) { + double time = bench(1, nreg); + std::cout << "Testing nreg=\t" << nreg << "\tTime=\t" << time << "\tus" + << std::endl; + if (dj.push(time)) { + nreg -= NREG_STEP; + nreg_max = nreg; + break; + } + } + if (nreg >= NREG_MAX) { + std::cout << "Unable to conclude a maximal register count" << std::endl; + nreg_max = NREG_STEP; + } else { + std::cout << nreg_max << " registers are available at most" << std::endl; + } + + auto find_ngrp_by_nreg = [&](const uint32_t nreg) { + DtJumpFinder<3> dj(COMPENSATE, THRESHOLD); + for (auto ngrp = NGRP_MIN; ngrp <= NGRP_MAX; ngrp += NGRP_STEP) { + auto time = bench(ngrp, nreg); + std::cout << "Testing occupation (nreg=\t" << nreg << "\t); ngrp=\t" + << ngrp << "\t, time=\t" << time << "\tus" << std::endl; + + if (dj.push(time)) { + ngrp -= NGRP_STEP; + std::cout << "Using " << nreg << " registers can have " << ngrp + << " concurrent single-thread workgroups" << std::endl; + return ngrp; + } + } + std::cout + << "Unable to conclude a maximum number of concurrent 
single-thread workgroups when " + << nreg << " registers are occupied" << std::endl; + return (uint32_t)1; + }; + + uint32_t ngrp_full, ngrp_half; + ngrp_full = find_ngrp_by_nreg(nreg_max); + ngrp_half = find_ngrp_by_nreg(nreg_max / 2); + + std::string reg_ty; + + if (ngrp_full * 1.5 < ngrp_half) { + std::cout << "All physical threads in an SM share " << nreg_max + << " registers" << std::endl; + reg_ty = "Pooled"; + + } else { + std::cout << "Each physical thread has " << nreg_max << " registers" + << std::endl; + reg_ty = "Dedicated"; + } + + std::cout << std::endl << std::endl; + std::cout << "MaxRegisters," << nreg_max << std::endl; + std::cout << "ConcurrentWorkgroupsFullReg," << ngrp_full << std::endl; + std::cout << "ConcurrentWorkgroupsHalfReg," << ngrp_half << std::endl; + std::cout << "RegisterType," << reg_ty << std::endl; +} + +// Warp size is a difficult metric to obtain because the hardware limitations +// do not always coincide with the way the SM divides the workload. For +// instance, the hardware can have a warp size of 64 threads, but an SM might +// be able to simulate concurrency of 128 threads with a single scheduler. + +// Because of this, it is important to measure the warp size in different ways +// that can evidence both the physical limitations of the hardware and the +// actual behavior of the driver. + +// Additionally, the SM can behave in two different ways when the assigned +// workload is smaller than the warp size. + +// In Case 1, like ARM Mali, the SM can assign dummy workloads to fill empty +// threads and maintain a uniform workload. + +// In Case 2, like in Adreno, the driver might decide to pack multiple workloads +// together and dispatch them at once. +void warp_size(const App& app, const bool verbose = false) { + if (!app.enabled("warp_size")) { + std::cout << "Skipped Warp Size" << std::endl; + return; + } + + std::cout << "\n------ Warp Size ------" << std::endl; + + // Method A: Stress test with a kernel that uses complex ALU operations like + // integer division to avoid latency hiding. Increase the number of threads + // until a jump in latency is detected. + + // This timing-based method helps us identify physical warp sizes. It also + // helps with Case 2, when threads of multiple warps are managed by the same + // scheduler at the same time. + const double COMPENSATE = app.get_config("warp_size", "compensate"); + const double THRESHOLD = app.get_config("warp_size", "threshold"); + + uint32_t NITER; + + auto bench = [&](uint32_t nthread) { + StorageBuffer out_buf(context(), vkapi::kInt, app.nthread_logic); + vkapi::PipelineBarrier pipeline_barrier{}; + + auto shader_name = "warp_size_physical"; + + auto time = benchmark_on_gpu(shader_name, 10, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + // Large number of work groups selected to potentially saturate all + // ALUs and thus have a better baseline for comparison. + {nthread, 1024, 1}, + {nthread, 1, 1}, + {SV(NITER)}, + VK_NULL_HANDLE, + 0, + out_buf.buffer()); + }); + + return time; + }; + + ensure_min_niter(1000, NITER, [&]() { return bench(1); }); + + uint32_t warp_size = app.subgroup_size; + DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); + + // We increase the number of threads until we hit a jump in the data.
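+ // Illustrative example (assumed numbers, not measured): on a device with a + // 64-wide warp, the benched time typically stays flat for nthread = 1..64 + // and jumps at 65, so the loop below would record warp_size = 64.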
+ uint32_t nthread = 1; + for (; nthread <= app.nthread_logic; ++nthread) { + double time = bench(nthread); + std::cout << "nthread=\t" << nthread << "\t(\t" << time << "\tus)" + << std::endl; + if (dj.push(time)) { + warp_size = nthread - 1; + break; + } + } + if (nthread >= app.nthread_logic) { + std::cout + << "Unable to conclude a physical warp size. Assuming warp_size == subgroup_size" + << std::endl; + } + + // Method B: Let all the threads in a warp race and atomically fetch-add + // a counter, then store the counter values to the output buffer in the + // scheduling order of these threads. If all the order numbers follow an + // ascending order, then the threads are likely executing within a warp. + // Threads in different warps are not managed by the same scheduler, so they + // would race for the same ID out of order, unaware of each other. + + // This method evidences the actual driver behavior when running + // concurrent work, regardless of the physical limitations of the hardware. + + // Likewise, this method helps us identify warp sizes when the SM + // sub-divides its ALUs into independent groups, like the three execution + // engines in a Mali G76 core. It helps warp-probing in Case 1 because it + // doesn't depend on kernel timing, so the extra wait time doesn't lead to + // inaccuracy. + auto bench_sm = [&](uint32_t nthread) { + StorageBuffer out_buf(context(), vkapi::kInt, app.nthread_logic); + vkapi::PipelineBarrier pipeline_barrier{}; + + auto shader_name = "warp_size_scheduler"; + + benchmark_on_gpu(shader_name, 1, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + {nthread, 1, 1}, + {nthread, 1, 1}, + {}, + VK_NULL_HANDLE, + 0, + out_buf.buffer()); + }); + + std::vector data(app.nthread_logic); + copy_staging_to_ptr(out_buf, data.data(), out_buf.nbytes()); + + if (verbose) { + std::stringstream ss; + for (auto j = 0; j < nthread; ++j) { + ss << data[j] << " "; + } + std::cout << ss.str() << std::endl; + } + + // Check up to which point the data is in ascending order. + int32_t last = -1; + int32_t j = 0; + for (; j < nthread; ++j) { + if (last >= data[j]) { + break; + } + last = data[j]; + } + + return j; + }; + + // Test increasing sizes until the data is no longer in ascending order. + uint32_t warp_size_scheduler = warp_size; + int i = 1; + for (; i <= app.nthread_logic; ++i) { + uint32_t nascend = bench_sm(i); + if (nascend != i) { + warp_size_scheduler = nascend; + break; + } + } + if (i > app.nthread_logic) { + std::cout << "Unable to conclude an SM Warp Size." << std::endl; + } + + std::cout << "PhysicalWarpSize," << warp_size << std::endl; + std::cout << "SMWarpSize," << warp_size_scheduler << std::endl; +} +}; // namespace gpuinfo diff --git a/backends/vulkan/tools/gpuinfo/include/buffers.h b/backends/vulkan/tools/gpuinfo/include/buffers.h new file mode 100644 index 0000000000..c8cf93c4a1 --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/include/buffers.h @@ -0,0 +1,216 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include "app.h" +#include "stats.h" +#include "utils.h" + +using namespace vkapi; + +namespace gpuinfo { + +void buf_cacheline_size(const App& app) { + if (!app.enabled("buf_cacheline_size")) { + std::cout << "Skipped Buffer Cacheline Size" << std::endl; + return; + } + + std::cout << std::endl; + std::cout << "------ Buffer Cacheline Size ------" << std::endl; + + const double COMPENSATE = app.get_config("buf_cacheline_size", "compensate"); + const double THRESHOLD = app.get_config("buf_cacheline_size", "threshold"); + + const uint32_t PITCH = app.buf_cache_size / app.nthread_logic; + const uint32_t BUF_SIZE = app.buf_cache_size; + const uint32_t MAX_STRIDE = PITCH; + + uint32_t NITER; + + auto bench = [&](int stride) { + StorageBuffer in_buf(context(), vkapi::kFloat, BUF_SIZE); + StorageBuffer out_buf(context(), vkapi::kFloat, 1); + vkapi::PipelineBarrier pipeline_barrier{}; + + auto shader_name = "buf_cacheline_size"; + + auto time = benchmark_on_gpu(shader_name, 100, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + {app.nthread_logic, 1, 1}, + {app.nthread_logic, 1, 1}, + {SV(NITER), SV(stride), SV(PITCH)}, + VK_NULL_HANDLE, + 0, + in_buf.buffer(), + out_buf.buffer()); + }); + return time; + }; + + ensure_min_niter(1000, NITER, [&]() { return bench(1); }); + + uint32_t cacheline_size; + + DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); + uint32_t stride = 1; + for (; stride <= MAX_STRIDE; ++stride) { + double time = bench(stride); + std::cout << "Testing stride=\t" << stride << "\t, time=\t" << time + << std::endl; + + if (dj.push(time)) { + cacheline_size = stride * sizeof(float); + break; + } + } + if (stride >= MAX_STRIDE) { + std::cout << "Unable to conclude a top level buffer cacheline size." + << std::endl; + cacheline_size = MAX_STRIDE * sizeof(float); + } + + std::cout << "BufTopLevelCachelineSize," << cacheline_size << std::endl; +} + +void _bandwidth( + const App& app, + const std::string memtype, + const uint32_t range) { + auto memtype_lower = memtype; + std::transform( + memtype_lower.begin(), + memtype_lower.end(), + memtype_lower.begin(), + [](unsigned char c) { return std::tolower(c); }); + + auto test_name = memtype_lower + "_bandwidth"; + + // Cache lines flushed + const uint32_t NFLUSH = app.get_config(test_name, "nflush"); + // Number of loop unrolls. Changing this value requires an equal change in + // buf_bandwidth.yaml + const uint32_t NUNROLL = app.get_config(test_name, "nunroll"); + // Number of iterations. Increasing this value reduces noise in exchange for + // higher latency. + const uint32_t NITER = app.get_config(test_name, "niter"); + // Vector dimensions (vec4) + const uint32_t VEC_WIDTH = 4; + const uint32_t VEC_SIZE = VEC_WIDTH * sizeof(float); + // Number of vectors that fit in the selected memory space + const uint32_t NVEC = range / VEC_SIZE; + // Number of memory reads per thread + const uint32_t NREAD_PER_THREAD = NUNROLL * NITER; + // Number of threads needed to read all vectors + // For shared memory, the thread count is not divided by the per-thread + // workload because of the limited memory size. + const uint32_t NTHREAD = memtype == "Shared" ?
NVEC : NVEC / NREAD_PER_THREAD; + // Occupy all threads + const uint32_t local_x = app.nthread_logic; + // Ensure that global is a multiple of local, and distribute across all SMs + const uint32_t global_x = + (NTHREAD / local_x * local_x) * app.sm_count * NFLUSH; + + auto bench = [&](uint32_t access_size) { + // Number of vectors that fit in this iteration + const uint32_t nvec_access = access_size / VEC_SIZE; + + // The address mask works as a modulo because x % 2^n == x & (2^n - 1). + // This will help us limit address accessing to a specific set of unique + // addresses depending on the access size we want to measure. + const uint32_t addr_mask = nvec_access - 1; + + // This is to distribute the accesses to unique addresses across the + // workgroups, once the size of the access exceeds the workgroup width. + const uint32_t workgroup_width = local_x * NITER * NUNROLL; + + StorageBuffer in_buf(context(), vkapi::kFloat, range / sizeof(float)); + StorageBuffer out_buf( + context(), vkapi::kFloat, VEC_WIDTH * app.nthread_logic); + vkapi::PipelineBarrier pipeline_barrier{}; + + auto shader_name = "buf_bandwidth_" + memtype_lower; + + auto time = benchmark_on_gpu(shader_name, 10, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + {global_x, 1, 1}, + {local_x, 1, 1}, + {SV(NITER), + SV(nvec_access), + SV(local_x), + SV(addr_mask), + SV(workgroup_width)}, + VK_NULL_HANDLE, + 0, + in_buf.buffer(), + out_buf.buffer()); + }); + + const uint32_t SIZE_TRANS = global_x * NREAD_PER_THREAD * VEC_SIZE; + auto gbps = SIZE_TRANS * 1e-3 / time; + std::cout << memtype << " bandwidth accessing \t" << access_size + << "\tB unique data is \t" << gbps << " \tgbps (\t" << time + << "\tus)" << std::endl; + return gbps; + }; + + double max_bandwidth = 0; + double min_bandwidth = DBL_MAX; + for (uint32_t access_size = VEC_SIZE; access_size < range; access_size *= 2) { + double gbps = bench(access_size); + max_bandwidth = std::max(gbps, max_bandwidth); + min_bandwidth = std::min(gbps, min_bandwidth); + } + + std::cout << "Max" << memtype << "Bandwidth (GB/s)," << max_bandwidth + << std::endl; + std::cout << "Min" << memtype << "Bandwidth (GB/s)," << min_bandwidth + << std::endl; +} + +void buf_bandwidth(const App& app) { + if (!app.enabled("buffer_bandwidth")) { + std::cout << "Skipped Memory Bandwidth" << std::endl; + return; + } + + std::cout << "\n------ Memory Bandwidth ------" << std::endl; + // Maximum memory space read - 128MB + // For regular devices, bandwidth plateaus at less memory than this, so more + // is not needed.
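+ // For reference, the config.json added in this change sets "range" for + // buffer_bandwidth to 134217728 bytes, i.e. 128 MiB.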
+ const uint32_t RANGE = app.get_config("buffer_bandwidth", "range"); + _bandwidth(app, "Buffer", RANGE); +} + +void ubo_bandwidth(const App& app) { + if (!app.enabled("ubo_bandwidth")) { + std::cout << "Skipped UBO Bandwidth" << std::endl; + return; + } + + std::cout << "\n------ UBO Bandwidth ------" << std::endl; + const uint32_t RANGE = app.get_config("ubo_bandwidth", "range"); + _bandwidth(app, "UBO", RANGE); +} + +void shared_mem_bandwidth(const App& app) { + if (!app.enabled("shared_bandwidth")) { + std::cout << "Skipped Shared Memory Bandwidth" << std::endl; + return; + } + + std::cout << "\n------ Shared Bandwidth ------" << std::endl; + const uint32_t RANGE = app.max_shared_mem_size; + _bandwidth(app, "Shared", RANGE); +} +} // namespace gpuinfo diff --git a/backends/vulkan/tools/gpuinfo/include/textures.h b/backends/vulkan/tools/gpuinfo/include/textures.h new file mode 100644 index 0000000000..7679f11b0c --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/include/textures.h @@ -0,0 +1,220 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "app.h" +#include "stats.h" +#include "utils.h" + +namespace gpuinfo { + +// Textures are drastically different from buffers in terms of data layout. +// While buffers are a contiguous range of memory, textures are opaque objects +// defined by the vendor and it is possible that nearby points of data are not +// neighboring in memory. Likewise, data points are accessed in +// multi-dimensional patches instead of simple lines. This makes the stride +// method for figuring out the cache line size not applicable. To work around +// this, this experiment runs an increasing number of threads accessing +// different datapoints in the texture and measures latency. If the cache line +// is big enough to contain all requested data for the number of threads, +// latency will be low. When there are more threads and hence more data than +// what a single cache line can handle, a second line must be fetched, +// increasing latency in a measurable way. +void tex_cacheline_concurr(const App& app) { + if (!app.enabled("tex_cacheline_concurr")) { + std::cout << "Skipped Texture Cacheline Optimal Concurrency" << std::endl; + return; + } + + const uint32_t TEXEL_WIDTH = 4; + const uint32_t TEXEL_SIZE = sizeof(float) * TEXEL_WIDTH; + + const double COMPENSATE = + app.get_config("tex_cacheline_concurr", "compensate"); + const double THRESHOLD = app.get_config("tex_cacheline_concurr", "threshold"); + + for (int dim = 0; dim < 3; ++dim) { + std::cout << std::endl; + std::cout << "------ Texture Cacheline Optimal Concurrency (dim = " << dim + << ") ------" << std::endl; + + uint32_t NITER; + + const uint32_t IMG_OTHER_EDGE = dim == 0 ? app.max_tex_width + : dim == 1 ?
app.max_tex_height + : app.max_tex_depth; + + const uint32_t MAX_NTHREAD = std::min(app.nthread_logic, IMG_OTHER_EDGE); + + auto bench = [&](uint32_t nthread) { + std::vector sizes_whd = { + app.max_tex_width, app.max_tex_height, app.max_tex_depth}; + + auto sizes_nchw = whd_to_nchw(sizes_whd); + + vTensor in_tensor = + api::vTensor(api::context(), sizes_nchw, vkapi::kFloat); + + StorageBuffer out_buf(context(), vkapi::kFloat, TEXEL_WIDTH); + + vkapi::PipelineBarrier pipeline_barrier{}; + + auto shader_name = "tex_cacheline_concurr_" + std::to_string(dim); + + auto time = benchmark_on_gpu(shader_name, 100, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + {nthread, 1, 1}, + {nthread, 1, 1}, + {SV(NITER)}, + VK_NULL_HANDLE, + 0, + in_tensor.image(), + out_buf.buffer()); + }); + return time; + }; + + ensure_min_niter(1000, NITER, [&]() { return bench(1); }); + + DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); + uint32_t nthread = 1; + for (; nthread <= MAX_NTHREAD; ++nthread) { + double time = bench(nthread); + std::cout << "Testing nthread=\t" << nthread << "\t, time=\t" << time + << std::endl; + + if (dj.push(time)) { + auto max_concurrency = nthread - 1; + std::cout << "TextureCachelineConcurrencyDim" << dim << " (B)," + << max_concurrency * TEXEL_SIZE << std::endl; + break; + } + } + if (nthread >= MAX_NTHREAD) { + std::cout + << "Unable to conclude an optimal texture cacheline concurrency for dim " + << dim << std::endl; + }; + } + + // TODO: Use concurrency information to obtain the cache line size for + // textures as done in https://fburl.com/98xiou3g +} + +void tex_bandwidth(const App& app) { + if (!app.enabled("tex_bandwidth")) { + std::cout << "Skipped Texture Bandwidth" << std::endl; + return; + } + + for (int dim = 0; dim < 3; dim++) { + std::cout << "\n------ Texture Bandwidth (Dim = " << dim << ") ------" + << std::endl; + const uint32_t MAX_SIZE = dim == 0 ? app.max_tex_width + : dim == 1 ? app.max_tex_height + : app.max_tex_depth; + + // rgba, float + const uint32_t VEC_WIDTH = 4; + const uint32_t VEC_SIZE = VEC_WIDTH * sizeof(float); + const uint32_t NVEC = MAX_SIZE; + + const uint32_t RANGE = NVEC * VEC_SIZE; + + // Cache lines flushed + const uint32_t NFLUSH = app.get_config("tex_bandwidth", "nflush"); + // Number of loop unrolls. Changing this value requires an equal change in + // tex_bandwidth.yaml + const uint32_t NUNROLL = app.get_config("tex_bandwidth", "nunroll"); + // Number of iterations. Increasing this value reduces noise in exchange + // for higher latency. 
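+ // For reference, the config.json added in this change sets nflush = 4, + // nunroll = 16, and niter = 10 for tex_bandwidth.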
+ const uint32_t NITER = app.get_config("tex_bandwidth", "niter"); + // Number of memory reads per thread + const uint32_t NREAD_PER_THREAD = NUNROLL * NITER; + // Number of threads needed to read all texels + const uint32_t NTHREAD = NVEC; + // Occupy all threads + const uint32_t local_x = app.nthread_logic; + // Ensure that global is a multiple of local, and distribute across all + // SMs + const uint32_t global_x = + (NTHREAD / local_x * local_x) * app.sm_count * NFLUSH; + + auto shader_name = "tex_bandwidth_" + std::to_string(dim); + + std::vector<int64_t> sizes_whd = {MAX_SIZE, 1, 1}; + if (dim == 1) { + sizes_whd = {1, MAX_SIZE, 1}; + } else if (dim == 2) { + sizes_whd = {1, 1, MAX_SIZE}; + } + auto sizes_nchw = whd_to_nchw(sizes_whd); + + vTensor in_tensor = api::vTensor(api::context(), sizes_nchw, vkapi::kFloat); + + auto bench = [&](uint32_t access_size, uint32_t dim) { + // Number of texels that fit in this iteration + const uint32_t ntexel_access = access_size / VEC_SIZE; + + // The address mask works as a modulo because x % 2^n == x & (2^n - 1). + // This will help us limit address accessing to a specific set of unique + // addresses depending on the access size we want to measure. + const uint32_t addr_mask = ntexel_access - 1; + + // This is to distribute the accesses to unique addresses across the + // workgroups, once the size of the access exceeds the workgroup width. + const uint32_t workgroup_width = local_x * NITER * NUNROLL; + + StorageBuffer out_buf( + context(), vkapi::kFloat, VEC_WIDTH * app.nthread_logic); + vkapi::PipelineBarrier pipeline_barrier{}; + + auto time = benchmark_on_gpu(shader_name, 10, [&]() { + context()->submit_compute_job( + VK_KERNEL_FROM_STR(shader_name), + pipeline_barrier, + {global_x, 1, 1}, + {local_x, 1, 1}, + {SV(NITER), + SV(ntexel_access), + SV(local_x), + SV(addr_mask), + SV(workgroup_width)}, + VK_NULL_HANDLE, + 0, + in_tensor.image(), + out_buf.buffer()); + }); + + const uint32_t SIZE_TRANS = global_x * NREAD_PER_THREAD * VEC_SIZE; + double gbps = SIZE_TRANS * 1e-3 / time; + std::cout << "Texture bandwidth accessing \t" << access_size + << "\tB unique data is \t" << gbps << " \tgbps (\t" << time + << "\tus)" << std::endl; + return gbps; + }; + + double max_bandwidth = 0; + double min_bandwidth = DBL_MAX; + for (uint32_t access_size = VEC_SIZE; access_size < RANGE; + access_size *= 2) { + double gbps = bench(access_size, dim); + max_bandwidth = std::max(gbps, max_bandwidth); + min_bandwidth = std::min(gbps, min_bandwidth); + } + + std::cout << "MaxTextureBandwidthDim" << dim << "(GB/s)," << max_bandwidth + << std::endl; + std::cout << "MinTextureBandwidthDim" << dim << "(GB/s)," << min_bandwidth + << std::endl; + } +} +} // namespace gpuinfo diff --git a/backends/vulkan/tools/gpuinfo/include/utils.h b/backends/vulkan/tools/gpuinfo/include/utils.h index 231fb32c5a..887cb443ef 100644 --- a/backends/vulkan/tools/gpuinfo/include/utils.h +++ b/backends/vulkan/tools/gpuinfo/include/utils.h @@ -54,6 +54,15 @@ void ensure_min_niter( } } +std::vector<int64_t> whd_to_nchw(std::vector<int64_t> sizes) { + const int64_t W = sizes[0]; + const int64_t H = sizes[1]; + const int64_t D = sizes[2]; + + // Channels-packed: {W, H, D} = {W, H, (C / 4) * N} + return {1, D * 4, H, W}; +} + cl_platform_id get_cl_platform_id() { cl_uint nplatform_id; clGetPlatformIDs(0, nullptr, &nplatform_id); diff --git a/backends/vulkan/tools/gpuinfo/src/app.cpp b/backends/vulkan/tools/gpuinfo/src/app.cpp deleted file mode 100644 index 8facdb5160..0000000000 ---
a/backends/vulkan/tools/gpuinfo/src/app.cpp +++ /dev/null @@ -1,497 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#include "stats.h" -#include "utils.h" - -using namespace vkapi; - -class App { - private: - size_t buf_cache_size_; - uint32_t max_shared_mem_size_; - uint32_t sm_count_; - uint32_t nthread_logic_; - uint32_t subgroup_size_; - - public: - App() { - context()->initialize_querypool(); - - std::cout << context()->adapter_ptr()->stringize() << std::endl - << std::endl; - - auto cl_device = get_cl_device(); - - sm_count_ = cl_device.getInfo(); - nthread_logic_ = cl_device.getInfo(); - buf_cache_size_ = cl_device.getInfo(); - max_shared_mem_size_ = cl_device.getInfo(); - - VkPhysicalDeviceSubgroupProperties subgroup_props{}; - VkPhysicalDeviceProperties2 props2{}; - - props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2; - props2.pNext = &subgroup_props; - subgroup_props.sType = - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES; - vkGetPhysicalDeviceProperties2( - context()->adapter_ptr()->physical_handle(), &props2); - subgroup_size_ = subgroup_props.subgroupSize; - - std::cout << std::endl; - std::cout << "SM count," << sm_count_ << std::endl; - std::cout << "Logic Thread Count," << nthread_logic_ << std::endl; - std::cout << "Cache Size," << buf_cache_size_ << std::endl; - std::cout << "Shared Memory Size," << max_shared_mem_size_ << std::endl; - std::cout << "SubGroup Size," << subgroup_size_ << std::endl; - } - - void reg_count() { - std::cout << std::endl; - std::cout << "------ Register Count ------" << std::endl; - const uint32_t NREG_MIN = 1; - const uint32_t NREG_MAX = 512; - const uint32_t NREG_STEP = 1; - - // TODO: Make these values configurable - const double COMPENSATE = 0.01; - const double THRESHOLD = 3; - - const uint32_t NGRP_MIN = 1; - const uint32_t NGRP_MAX = 64; - const uint32_t NGRP_STEP = 1; - - uint32_t NITER; - - auto bench = [&](uint32_t ngrp, uint32_t nreg) { - StorageBuffer buffer(context(), vkapi::kFloat, 1); - vkapi::PipelineBarrier pipeline_barrier{}; - - auto shader_name = "reg_count_" + std::to_string(nreg); - - auto time = benchmark_on_gpu(shader_name, 100, [&]() { - context()->submit_compute_job( - VK_KERNEL_FROM_STR(shader_name), - pipeline_barrier, - {1, ngrp, 1}, - {1, 1, 1}, - {SV(NITER)}, - VK_NULL_HANDLE, - 0, - buffer.buffer()); - }); - return time; - }; - - std::cout << "Calculating NITER..." 
<< std::endl; - ensure_min_niter(1000, NITER, [&]() { return bench(1, NREG_MIN); }); - std::cout << "NITER," << NITER << std::endl; - - uint32_t nreg_max; - - DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); - uint32_t nreg = NREG_MIN; - for (; nreg <= NREG_MAX; nreg += NREG_STEP) { - double time = bench(1, nreg); - std::cout << "Testing nreg=\t" << nreg << "\tTime=\t" << time - << std::endl; - if (dj.push(time)) { - nreg -= NREG_STEP; - nreg_max = nreg; - break; - } - } - if (nreg >= NREG_MAX) { - std::cout << "Unable to conclude a maximal register count" << std::endl; - nreg_max = NREG_STEP; - } else { - std::cout << nreg_max << " registers are available at most" << std::endl; - } - - auto find_ngrp_by_nreg = [&](const uint32_t nreg) { - DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); - for (auto ngrp = NGRP_MIN; ngrp <= NGRP_MAX; ngrp += NGRP_STEP) { - auto time = bench(ngrp, nreg); - std::cout << "Testing occupation (nreg=" << nreg << "); ngrp=" << ngrp - << ", time=" << time << " us" << std::endl; - - if (dj.push(time)) { - ngrp -= NGRP_STEP; - std::cout << "Using " << nreg << " registers can have " << ngrp - << " concurrent single-thread workgroups" << std::endl; - return ngrp; - } - } - std::cout - << "Unable to conclude a maximum number of concurrent single-thread workgroups when " - << nreg << " registers are occupied" << std::endl; - return (uint32_t)1; - }; - - uint32_t ngrp_full, ngrp_half; - ngrp_full = find_ngrp_by_nreg(nreg_max); - ngrp_half = find_ngrp_by_nreg(nreg_max / 2); - - std::string reg_ty; - - if (ngrp_full * 1.5 < ngrp_half) { - std::cout << "All physical threads in an sm share " << nreg_max - << " registers" << std::endl; - reg_ty = "Pooled"; - - } else { - std::cout << "Each physical thread has " << nreg_max << " registers" - << std::endl; - reg_ty = "Dedicated"; - } - - std::cout << std::endl << std::endl; - std::cout << "NITER," << NITER << std::endl; - std::cout << "Max registers," << nreg_max << std::endl; - std::cout << "Concurrent full single thread workgroups," << ngrp_full - << std::endl; - std::cout << "Concurrent half single thread workgroups," << ngrp_half - << std::endl; - std::cout << "Register type," << reg_ty << std::endl; - } - - void buf_cacheline_size() { - std::cout << std::endl; - std::cout << "------ Buffer Cacheline Size ------" << std::endl; - - // TODO: Make these values configurable - const double COMPENSATE = 0.01; - const double THRESHOLD = 10; - - const uint32_t PITCH = buf_cache_size_ / nthread_logic_; - const uint32_t BUF_SIZE = buf_cache_size_; - const uint32_t MAX_STRIDE = PITCH; - - uint32_t NITER; - - auto bench = [&](int stride) { - StorageBuffer in_buf(context(), vkapi::kFloat, BUF_SIZE); - StorageBuffer out_buf(context(), vkapi::kFloat, 1); - vkapi::PipelineBarrier pipeline_barrier{}; - - auto shader_name = "buf_cacheline_size"; - - auto time = benchmark_on_gpu(shader_name, 100, [&]() { - context()->submit_compute_job( - VK_KERNEL_FROM_STR(shader_name), - pipeline_barrier, - {nthread_logic_, 1, 1}, - {nthread_logic_, 1, 1}, - {SV(NITER), SV(stride), SV(PITCH)}, - VK_NULL_HANDLE, - 0, - in_buf.buffer(), - out_buf.buffer()); - }); - return time; - }; - - ensure_min_niter(1000, NITER, [&]() { return bench(1); }); - - uint32_t cacheline_size; - - DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); - uint32_t stride = 1; - for (; stride <= MAX_STRIDE; ++stride) { - double time = bench(stride); - std::cout << "Testing stride=\t" << stride << "\t, time=\t" << time - << std::endl; - - if (dj.push(time)) { - cacheline_size = stride * sizeof(float); - 
break; - } - } - if (stride >= MAX_STRIDE) { - std::cout << "Unable to conclude a top level buffer cacheline size." - << std::endl; - cacheline_size = MAX_STRIDE; - } - - std::cout << "BufTopLevelCachelineSize," << cacheline_size << std::endl; - } - - private: - void _bandwidth(std::string memtype, uint32_t range) { - // TODO: Make these values configurable - // Cache lines flushed - const uint32_t NFLUSH = 4; - // Number of loop unrolls. Changing this value requires an equal change in - // buf_bandwidth.yaml - const uint32_t NUNROLL = 16; - // Number of iterations. Increasing this value reduces noise in exchange for - // higher latency. - const uint32_t NITER = 10; - // Vector dimensions (vec4) - const uint32_t VEC_WIDTH = 4; - const uint32_t VEC_SIZE = VEC_WIDTH * sizeof(float); - // Number of vectors that fit in the selected memory space - const uint32_t NVEC = range / VEC_SIZE; - // Number of memory reads per thread - const uint32_t NREAD_PER_THREAD = NUNROLL * NITER; - // Number of threads needed to read al l vectors - // The thread count doesn't divide by thread workload in shared memory - // because of the limited memory size. - const uint32_t NTHREAD = - memtype == "Shared" ? NVEC : NVEC / NREAD_PER_THREAD; - // Occupy all threads - const uint32_t local_x = nthread_logic_; - // Ensure that global is a multiple of local, and distribute across all SMs - const uint32_t global_x = - (NTHREAD / local_x * local_x) * sm_count_ * NFLUSH; - - auto bench = [&](uint32_t access_size) { - // Number of vectors that fit in this iteration - const uint32_t nvec_access = access_size / VEC_SIZE; - - StorageBuffer in_buf(context(), vkapi::kFloat, range / sizeof(float)); - StorageBuffer out_buf( - context(), vkapi::kFloat, VEC_WIDTH * nthread_logic_); - vkapi::PipelineBarrier pipeline_barrier{}; - - auto memtype_lower = memtype; - std::transform( - memtype_lower.begin(), - memtype_lower.end(), - memtype_lower.begin(), - [](unsigned char c) { return std::tolower(c); }); - auto shader_name = "buf_bandwidth_" + memtype_lower; - - auto time = benchmark_on_gpu(shader_name, 10, [&]() { - context()->submit_compute_job( - VK_KERNEL_FROM_STR(shader_name), - pipeline_barrier, - {global_x, 1, 1}, - {local_x, 1, 1}, - {SV(NITER), SV(nvec_access), SV(local_x)}, - VK_NULL_HANDLE, - 0, - in_buf.buffer(), - out_buf.buffer()); - }); - - const uint32_t SIZE_TRANS = global_x * NREAD_PER_THREAD * VEC_SIZE; - auto gbps = SIZE_TRANS * 1e-3 / time; - std::cout << memtype << " bandwidth accessing \t" << access_size - << "\tB unique data is \t" << gbps << " \tgbps (\t" << time - << "\tus)" << std::endl; - return gbps; - }; - - double max_bandwidth = 0; - double min_bandwidth = DBL_MAX; - for (uint32_t access_size = VEC_SIZE; access_size < range; - access_size *= 2) { - double gbps = bench(access_size); - max_bandwidth = std::max(gbps, max_bandwidth); - min_bandwidth = std::min(gbps, min_bandwidth); - } - - std::cout << "Max" << memtype << "Bandwidth (GB/s)," << max_bandwidth - << std::endl; - std::cout << "Min" << memtype << "Bandwidth (GB/s)," << min_bandwidth - << std::endl; - } - - public: - void buf_bandwidth() { - std::cout << "\n------ Memory Bandwidth ------" << std::endl; - // Maximum memory space read - 128MB - // For regular devices, bandwidth plateaus at less memory than this, so more - // is not needed. 
- const uint32_t RANGE = 128 * 1024 * 1024; - _bandwidth("Buffer", RANGE); - } - - void ubo_bandwidth() { - std::cout << "\n------ UBO Bandwidth ------" << std::endl; - const uint32_t RANGE = 128 * 1024 * 1024; - _bandwidth("UBO", RANGE); - } - void shared_mem_bandwidth() { - std::cout << "\n------ Shared Bandwidth ------" << std::endl; - const uint32_t RANGE = max_shared_mem_size_; - _bandwidth("Shared", RANGE); - } - - // Warp size is a difficult metric to obtain because the hardware limitations - // do not always coincide with the way the SM divides the workload. For - // instance, the hardware can have a warp size of 64 threads, but an SM might - // be able to simulate concurrency of 128 threads with a single scheduler. - - // Because of this, it is important to measure the warp size different ways, - // that can evidence both the physical limitations of the hardware, and the - // actual behavior of the driver. - - // Additionally,the SM can behave in two different ways when the assigned - // workload is smaller than the warp size. - - // In Case 1, like ARM Mali, the SM can assign dummy workloads to fill empty - // threads and maintain a uniform workload. - - // In Case 2, like in Adreno, the driver might decide to pack multiple works - // together and dispatch them at once. - void warp_size(bool verbose = false) { - std::cout << "\n------ Warp Size ------" << std::endl; - - // Method A: Stress test with a kernel that uses complex ALU operations like - // integer division to avoid latency hiding. Increase the number of threads - // until a jump in latency is detected. - - // This timing-based method helps us identify physical warp sizes. It also - // helps with Case 2, when threads of multiple warps are managed by the same - // scheduler at the same time. - const double COMPENSATE = 0.01; - const double THRESHOLD = 3; - - uint32_t NITER; - - auto bench = [&](uint32_t nthread) { - StorageBuffer out_buf(context(), vkapi::kInt, nthread_logic_); - vkapi::PipelineBarrier pipeline_barrier{}; - - auto shader_name = "warp_size_physical"; - - auto time = benchmark_on_gpu(shader_name, 10, [&]() { - context()->submit_compute_job( - VK_KERNEL_FROM_STR(shader_name), - pipeline_barrier, - // Large number of work groups selected to potentially saturate all - // ALUs and thus have a better baseline for comparison. - {nthread, 1024, 1}, - {nthread, 1, 1}, - {SV(NITER)}, - VK_NULL_HANDLE, - 0, - out_buf.buffer()); - }); - - return time; - }; - - ensure_min_niter(1000, NITER, [&]() { return bench(1); }); - - uint32_t warp_size = subgroup_size_; - DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); - - // We increase the number of threads until we hit a jump in the data. - uint32_t nthread = 1; - for (; nthread <= nthread_logic_; ++nthread) { - double time = bench(nthread); - std::cout << "nthread=\t" << nthread << "\t(\t" << time << "\tus)" - << std::endl; - if (dj.push(time)) { - warp_size = nthread - 1; - break; - } - } - if (nthread >= nthread_logic_) { - std::cout - << "Unable to conclude a physical warp size. Assuming warp_size == subgroup_size" - << std::endl; - } - - // Method B: Let all the threads in a warp race and atomically fetch-add - // a counter, then store the counter values to the output buffer in the - // scheduling order of these threads. If all the order numbers follow an - // ascending order, then the threads are likely executing within a warp. - // Threads in different warps are not managed by the same scheduler, so they - // would race for a same ID out of order, unaware of each other. 
- - // This method evidences the actual driver behavior when running - // concurrency, regardless of the physical limitations of the hardware. - - // Likewise, this method helps us identify warp sizes when the SM - // sub-divides its ALUs into independent groups, like the three execution - // engines in a Mali G76 core. It helps warp-probing in Case 1 because it - // doesn't depend on kernel timing, so the extra wait time doesn't lead to - // inaccuracy. - auto bench_sm = [&](uint32_t nthread) { - StorageBuffer out_buf(context(), vkapi::kInt, nthread_logic_); - vkapi::PipelineBarrier pipeline_barrier{}; - - auto shader_name = "warp_size_scheduler"; - - benchmark_on_gpu(shader_name, 1, [&]() { - context()->submit_compute_job( - VK_KERNEL_FROM_STR(shader_name), - pipeline_barrier, - {nthread, 1, 1}, - {nthread, 1, 1}, - {}, - VK_NULL_HANDLE, - 0, - out_buf.buffer()); - }); - - std::vector data(nthread_logic_); - copy_staging_to_ptr(out_buf, data.data(), out_buf.nbytes()); - - if (verbose) { - std::stringstream ss; - for (auto j = 0; j < nthread; ++j) { - ss << data[j] << " "; - } - std::cout << ss.str() << std::endl; - } - - // Check until which point is the data in ascending order. - int32_t last = -1; - int32_t j = 0; - for (; j < nthread; ++j) { - if (last >= data[j]) { - break; - } - last = data[j]; - } - - return j; - }; - - // Test increasing sizes until the data is no longer in ascending order. - uint32_t warp_size_scheduler = warp_size; - int i = 1; - for (; i <= nthread_logic_; ++i) { - uint32_t nascend = bench_sm(i); - if (nascend != i) { - warp_size_scheduler = nascend; - break; - } - } - if (i > nthread_logic_) { - std::cout << "Unable to conclude an SM Warp Size." << std::endl; - } - - std::cout << "PhysicalWarpSize," << warp_size << std::endl; - std::cout << "SMWarpSize," << warp_size_scheduler << std::endl; - } -}; - -int main(int argc, const char** argv) { - App app; - - // TODO: Allow user to skip tests - app.reg_count(); - app.buf_cacheline_size(); - app.buf_bandwidth(); - app.ubo_bandwidth(); - app.shared_mem_bandwidth(); - app.warp_size(); - - return 0; -} diff --git a/backends/vulkan/tools/gpuinfo/src/main.cpp b/backends/vulkan/tools/gpuinfo/src/main.cpp new file mode 100644 index 0000000000..f0e29aaf1a --- /dev/null +++ b/backends/vulkan/tools/gpuinfo/src/main.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "app.h" +#include "architecture.h" +#include "buffers.h" +#include "textures.h" + +using namespace vkapi; + +int main(int argc, const char** argv) { + gpuinfo::App app; + + std::string file_path = "config.json"; + if (argc > 1) { + file_path = argv[1]; + }; + app.load_config(file_path); + + // Architecture + gpuinfo::reg_count(app); + gpuinfo::warp_size(app); + + // Buffers + gpuinfo::buf_cacheline_size(app); + gpuinfo::buf_bandwidth(app); + gpuinfo::ubo_bandwidth(app); + gpuinfo::shared_mem_bandwidth(app); + + // Textures + gpuinfo::tex_bandwidth(app); + gpuinfo::tex_cacheline_concurr(app); + + return 0; +} diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 1ac7867f3c..b0b80d8633 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -37,21 +37,10 @@ set(_common_compile_options -Wno-deprecated-declarations -fPIC) set(_xnnpack_schema__include_dir "${CMAKE_BINARY_DIR}/schema/include") # Paths to headers generated from the .fbs files. -set(_xnnpack_flatbuffer__outputs) -foreach(fbs_file ${_xnnpack_schema__srcs}) - string(REGEX REPLACE "([^/]+)[.]fbs$" "\\1_generated.h" generated - "${fbs_file}" - ) - list(APPEND _xnnpack_flatbuffer__outputs - "${_xnnpack_schema__include_dir}/executorch/${generated}" - ) -endforeach() - set(_xnnpack_schema__outputs) foreach(fbs_file ${_xnnpack_schema__srcs}) - string(REGEX REPLACE "runtime_([^/]+)[.]fbs$" "\\1_generated.h" generated - "${fbs_file}" - ) + string(REGEX REPLACE "([^/]+)[.]fbs$" "\\1_generated.h" + generated "${fbs_file}") list(APPEND _xnnpack_schema__outputs "${_xnnpack_schema__include_dir}/executorch/${generated}" ) @@ -64,7 +53,6 @@ add_custom_command( ${FLATC_EXECUTABLE} --cpp --cpp-std c++11 --scoped-enums -o "${_xnnpack_schema__include_dir}/executorch/backends/xnnpack/serialization" ${_xnnpack_schema__srcs} - COMMAND mv ${_xnnpack_flatbuffer__outputs} ${_xnnpack_schema__outputs} WORKING_DIRECTORY ${EXECUTORCH_ROOT} COMMENT "Generating xnnpack_schema headers" VERBATIM diff --git a/backends/xnnpack/partition/graphs/bilinear_2d.py b/backends/xnnpack/partition/graphs/bilinear_2d.py index a971cb9244..0040439f84 100644 --- a/backends/xnnpack/partition/graphs/bilinear_2d.py +++ b/backends/xnnpack/partition/graphs/bilinear_2d.py @@ -37,12 +37,15 @@ def forward(self, x): ] for align_corners in [True, False]: for config in capture_configs: - edge = exir.capture( - bilinear2d(align_corners), sample_inputs, config - ).to_edge( - config=get_xnnpack_edge_compile_config(), - ) - _bilinear2d_graphs[edge.exported_program.graph_module] = align_corners + for skip_dim_order_flag in [True, False]: + edge = exir.capture( + bilinear2d(align_corners), sample_inputs, config + ).to_edge( + config=get_xnnpack_edge_compile_config( + skip_dim_order=skip_dim_order_flag + ) + ) + _bilinear2d_graphs[edge.exported_program.graph_module] = align_corners return _bilinear2d_graphs diff --git a/backends/xnnpack/passes/__init__.py b/backends/xnnpack/passes/__init__.py index 1ca4fe307f..c3a85e4aa8 100644 --- a/backends/xnnpack/passes/__init__.py +++ b/backends/xnnpack/passes/__init__.py @@ -27,6 +27,7 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.const_prop_pass import ConstPropPass +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from executorch.exir.program._program import _transform from torch._export.pass_base import PassType @@ -50,6 +51,8 @@ def __init__( if not passes: # All the XNNPACK passes self.passes = [ + # TODO - remove 
this pass once we have a better support for dim_order ops lowering + DimOrderOpsRevertPass, ConvertToUpsampleBilinear2d, ConvertToLinearPass, ConvertToSDPAPass, diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 8c8db60065..722dabdfbe 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -124,17 +124,9 @@ const uint8_t* getConstantDataPtr( const uint8_t* constant_data_ptr) { auto buffer_idx = tensor_value->constant_buffer_idx(); if (buffer_idx) { - if (!constant_data_ptr) { - // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC - // window - const auto& constant_buffer = *flatbuffer_graph->constant_buffer(); - return constant_buffer[buffer_idx]->storage()->data(); - } else { - const auto& constant_data_offsets = *flatbuffer_graph->constant_data(); - uint64_t constant_data_offset = - constant_data_offsets[buffer_idx]->offset(); - return constant_data_ptr + constant_data_offset; - } + const auto& constant_data_offsets = *flatbuffer_graph->constant_data(); + uint64_t constant_data_offset = constant_data_offsets[buffer_idx]->offset(); + return constant_data_ptr + constant_data_offset; } return nullptr; @@ -194,105 +186,29 @@ Error defineTensor( xnn_status status; // The type we might have to convert to - auto dq_datatype = getDataType(tensor_value->dq_datatype()); - - if (dq_datatype != xnn_datatype::xnn_datatype_invalid) { - if (dq_datatype != xnn_datatype::xnn_datatype_qint8) { - ET_CHECK_OR_RETURN_ERROR( - false, - Internal, - "Only int8_t is supported for dq_datatype for now, got: %d", - dq_datatype); - } else { - ET_CHECK_OR_RETURN_ERROR( - (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT), - Internal, - "Dynamic quantization of tensor is only allowed for the external input tensor value for now! got flags: %u", - tensor_value->flags()); - } - } + auto datatype = getDataType(tensor_value->datatype()); if (qtensor_value == nullptr) { // FP32 tensor - if (!isQuantizedDataType(dq_datatype)) { - // Define non-quantied tensor - status = xnn_define_tensor_value( - /*subgraph=*/subgraph_ptr, - /*datatype=*/getDataType(tensor_value->datatype()), - /*num_dims=*/tensor_value->num_dims(), - /*dims=*/dims_data.data(), - /*data=*/buffer_ptr, - /*external_id=*/tensor_value->external_id(), - /*flags=*/tensor_value->flags(), - /*id_out=*/&id); - } else if (dq_datatype != xnn_datatype::xnn_datatype_invalid) { - ET_CHECK_OR_RETURN_ERROR( - isQuantizedDataType(dq_datatype), - Internal, - "Dynamic quantization can only produce supported quantized dtypes"); - ET_CHECK_OR_RETURN_ERROR( - tensor_value->external_id() != XNN_INVALID_VALUE_ID, - Internal, - "Dynamic quantization can only work with external inputs for now, got an internal ID"); - ET_CHECK_OR_RETURN_ERROR( - buffer_ptr == nullptr, - Internal, - "Dynamic quantization can only work with external inputs for now, got const data"); - - switch (dq_datatype) { - case xnn_datatype::xnn_datatype_qint8: { - // HACK TO Maintain FC/BC for ASR this will be removed after 01/2024 - - // When encountering a dynamically quantized tensor via dq_datatype, - // which is the old flow for serializing dynamically quantized linear. - // We replace the definition of a single tensor with a new dynamic - // Quantization pattern. 
We change the pattern from: - // serialized_qd_input - // to - // (fp32_input --> convert --> qdint8_input) - - status = xnn_define_dynamically_quantized_tensor_value( - /*subgraph=*/subgraph_ptr, - /*datatype=*/xnn_datatype_qdint8, - /*num_dims=*/tensor_value->num_dims(), - /*num_nonbatch_dims=*/1, // always do per token quantization - /*dims=*/dims_data.data(), - /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id - /*flags=*/0, // this is netiher external input or output - /*id_out=*/&id); - - // this is the FP16 or FP32 external value that is being dynamically - // quantized - uint32_t float_id; - enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype()); - status = xnn_define_tensor_value( - /*subgraph=*/subgraph_ptr, - /*datatype=*/fp_datatype, - /*num_dims=*/tensor_value->num_dims(), - /*dims=*/dims_data.data(), - /*data=*/buffer_ptr, - /*external_id=*/tensor_value->external_id(), - /*flags=*/tensor_value->flags(), - /*id_out=*/&float_id); - - // Define dynamic conversion from float to qdint8 - status = xnn_define_convert( - /*subgraph=*/subgraph_ptr, - /*input_id=*/float_id, - /*output_id=*/id, - /*flags=*/0); - break; - } - default: - ET_CHECK_OR_RETURN_ERROR( - false, - NotImplemented, - "Unhandled Dyanmic Quantization dtype: %d", - dq_datatype); - } - } else { - ET_CHECK_OR_RETURN_ERROR(false, NotImplemented, "Unhandled fp32 tensor"); - } + ET_CHECK_OR_RETURN_ERROR( + !isQuantizedDataType(datatype), + Internal, + "xnn_datatype is quantized, but is not quantized tensor value"); + + status = xnn_define_tensor_value( + /*subgraph=*/subgraph_ptr, + /*datatype=*/datatype, + /*num_dims=*/tensor_value->num_dims(), + /*dims=*/dims_data.data(), + /*data=*/buffer_ptr, + /*external_id=*/tensor_value->external_id(), + /*flags=*/tensor_value->flags(), + /*id_out=*/&id); + ET_CHECK_OR_RETURN_ERROR( + xnn_status_success == status, + Internal, + "Failed to define tensor with id %i", + id); } else { // define tensor for quantized switch (qtensor_value->quant_params_type()) { @@ -306,7 +222,7 @@ Error defineTensor( qparams->zero_point()); status = xnn_define_quantized_tensor_value( /*subgraph=*/subgraph_ptr, - /*datatype=*/getDataType(tensor_value->datatype()), + /*datatype=*/datatype, /*zero_point=*/qparams->zero_point(), /*scale=*/qparams->scale(), /*num_dims=*/tensor_value->num_dims(), @@ -319,9 +235,8 @@ Error defineTensor( } case fb_xnnpack::XNNQuantParams::PerChannelQuant: { auto qparams = qtensor_value->quant_params_as_PerChannelQuant(); - enum xnn_datatype dtype = getDataType(tensor_value->datatype()); int32_t zero_point = - (dtype == xnn_datatype::xnn_datatype_qcint4 ? 8 : 0); + (datatype == xnn_datatype::xnn_datatype_qcint4 ? 
8 : 0); ET_LOG( Debug, @@ -329,11 +244,11 @@ Error defineTensor( buffer_ptr, qparams->scale()->size(), qparams->channel_dim(), - dtype, + datatype, zero_point); status = xnn_define_channelwise_quantized_tensor_value_v2( /*subgraph=*/subgraph_ptr, - /*datatype=*/dtype, + /*datatype=*/datatype, /*zero_point=*/zero_point, /*scale=*/qparams->scale()->data(), /*num_dims=*/tensor_value->num_dims(), @@ -346,7 +261,6 @@ Error defineTensor( break; } case fb_xnnpack::XNNQuantParams::PerChannelGroupQuant: { - xnn_datatype datatype = getDataType(tensor_value->datatype()); ET_CHECK_OR_RETURN_ERROR( datatype == xnn_datatype::xnn_datatype_qbint4, Internal, @@ -410,7 +324,7 @@ Error defineTensor( "Dynamically Quantized Tensors currently only support per token quantization"); status = xnn_define_dynamically_quantized_tensor_value( /*subgraph=*/subgraph_ptr, - /*datatype=*/getDataType(tensor_value->datatype()), + /*datatype=*/datatype, /*num_dims=*/tensor_value->num_dims(), /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(), /*dims=*/dims_data.data(), @@ -1594,23 +1508,24 @@ __ET_NODISCARD Error XNNCompiler::compileModel( constant_data = reinterpret_cast(buffer_pointer) + header->constant_data_offset; } else if (header.error() == Error::NotFound) { - flatbuffer_data = reinterpret_cast(buffer_pointer); + ET_LOG( + Error, + "XNNHeader version mismatch: '%.4s' != expected '%.4s'", + // Header Magic and FlatbufferIdentifier are same offset and size + flatbuffers::GetBufferIdentifier(buffer_pointer), + XNNHeader::kMagic); + return header.error(); } else { ET_LOG(Error, "XNNHeader may be corrupt"); return header.error(); } - // Temporarily support identifier XN00 and XN01 - bool is_supported_version = - strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN00", 4) == - 0 || - strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN01", 4) == - 0; ET_CHECK_OR_RETURN_ERROR( - is_supported_version, + fb_xnnpack::XNNGraphBufferHasIdentifier(flatbuffer_data), DelegateInvalidCompatibility, - "XNNPACK Delegate Serialization Format version identifier '%.4s' != expected XN00 or XN01'", - flatbuffers::GetBufferIdentifier(flatbuffer_data)); + "XNNPACK Delegate flatbuffer version mismatch: '%.4s' != expected '%.4s'", + flatbuffers::GetBufferIdentifier(flatbuffer_data), + fb_xnnpack::XNNGraphIdentifier()); auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data); // initialize xnnpack diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs deleted file mode 100644 index 5ace211149..0000000000 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. - -namespace fb_xnnpack; - -// Update after any BC breaking changes -file_identifier "XN00"; - -// datatype for xnn-values -enum XNNDatatype : short { - /// Invalid data type. Valid Values never have this datatype. - xnn_datatype_invalid = 0, - /// IEEE754 single-precision floating-point. - xnn_datatype_fp32 = 1, - /// IEEE754 half-precision floating-point. - xnn_datatype_fp16 = 2, - /// Quantized 8-bit signed integer with shared per-Value quantization parameters. - xnn_datatype_qint8 = 3, - /// Quantized 8-bit unsigned integer with shared per-Value quantization parameters. - xnn_datatype_quint8 = 4, - /// Quantized 32-bit signed integer with shared per-Value quantization parameters. - xnn_datatype_qint32 = 5, - /// Quantized 8-bit signed integer with shared per-channel quantization parameters. 
- xnn_datatype_qcint8 = 6, - /// Quantized 32-bit signed integer with shared per-channel quantization parameters. - xnn_datatype_qcint32 = 7, - /// Quantized 4-bit signed integer with shared per-channel quantization parameters. - xnn_datatype_qcint4 = 8, - /// Dynamically quantized 8-bit signed integer with per-batch quantization parameters. - xnn_datatype_qdint8 = 9, - /// Quantized 4-bit signed integer with shared blockwise quantization parameters. - xnn_datatype_qbint4 = 10, -} - -// type of quantization -union XNNQuantParams { - PerChannelQuant, - PerTensorQuant, - PerTokenDynamicQuant, - PerChannelGroupQuant, -} - -// taken from executorch -// Data buffer abstraction. -table Buffer { - storage:[ubyte] (force_align: 16); -} - -table PerChannelQuant { - scale:[float]; - channel_dim:int; -} - -table PerTokenDynamicQuant { - num_nonbatch_dims:int; -} - -table PerTensorQuant { - scale:float; - zero_point:int; -} - -table PerChannelGroupQuant { - scale:[float]; - channel_dim:int; - group_size:int; -} - -table XNNTensorValue { - // type of the tensor elements. - datatype:XNNDatatype; - // number of dimensions in the shape. - num_dims:uint; - // pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. - // XNNPACK does not keep any pointers to this array after the function returns. - dims:[uint]; - // Index to the program's constant buffer table, value 0 is reserved to indicate non constant - constant_buffer_idx:uint; - // external ID for the Value. The ID must be within the range of reserved Value IDs specified on - // the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be - // created for the Value. - external_id:uint; - // binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT - // and XNN_VALUE_FLAG_EXTERNAL_OUTPUT. - flags:uint; - // pointer to the variable that will be initialized with the Value ID upon successful return. If a - // valid @a external_id was provided, the variable will be initialized with the @a external_id value. - id_out:uint; - // does this value need to be quantized dynamically at runtime? 
- // if we are quantizing at runtime, this field points to a target dtype - dq_datatype:XNNDatatype = xnn_datatype_invalid; -} - -table XNNQuantizedTensorValue { - // Base Tensor Value - tensor_value:XNNTensorValue; - // Quantization parameters - quant_params:XNNQuantParams; -} - -union XNodeUnion { - XNNAdd: _XNNNode2x1, - XNNFullyConnected, - XNNSoftmax: _XNNNode1x1, - XNNSigmoid: _XNNNode1x1, - XNNStaticTranspose, - XNNClamp: _XNNNode1x1, - XNNConv2d: _XNNNodeConv, - XNNDiv: _XNNNode2x1, - XNNStaticResizeBilinear2D, - XNNStaticConstantPad, - XNNAvgPooling2d: _XNNPooling2D, - XNNMinimum: _XNNNode2x1, - XNNDepthwiseConv2d: _XNNNodeConv, - XNNMaxPooling2d: _XNNPooling2D, - XNNMultiply: _XNNNode2x1, - XNNSubtract: _XNNNode2x1, - XNNFloor: _XNNNode1x1, - XNNConvert: _XNNNode1x1, - XNNGlobalAvgPooling2d: _XNNNode1x1, - XNNStaticReshape, - XNNArgMaxPooling2d, - XNNSquareRoot: _XNNNode1x1, - XNNCeiling: _XNNNode1x1, - XNNHardswish: _XNNNode1x1, - XNNLeakyReLU, - XNNMaximum: _XNNNode2x1, - XNNNegate: _XNNNode1x1, - XNNSquare: _XNNNode1x1, - XNNELU, - XNNAbs: _XNNNode1x1, - XNNPReLU: _XNNNode2x1, - XNNConcatenate2: _XNNCat, - XNNConcatenate3: _XNNCat, - XNNConcatenate4: _XNNCat, - XNNStaticSlice, - XNNScaledDotProductAttention, -} - -union XValueUnion { - XNNTensorValue, - XNNQuantizedTensorValue, -} - -table OutputMinMax { - output_min:float; - output_max:float; -} - -table XNode { - xnode_union:XNodeUnion; - // An int which can be linked back to the node in the origin graph - debug_handle:uint; - output_min_max:OutputMinMax; -} - -table XValue { - xvalue_union:XValueUnion; -} - -table XNNStaticTranspose { - num_dims:uint; - perm:[uint]; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNStaticResizeBilinear2D { - new_height:uint; - new_width:uint; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNStaticConstantPad { - pre_paddings:[uint]; - post_paddings:[uint]; - padding_value:float; - input_id:uint; - output_id:uint; - flags:uint; -} - -// A node with two input and one output -// Not meant to be used directly -table _XNNNode2x1 { - input1_id:uint; - input2_id:uint; - output_id:uint; - flags:uint; -} - -// A node with one input and one output -// Not meant to be used directly -table _XNNNode1x1 { - input_id:uint; - output_id:uint; - flags:uint; -} - -table _XNNCat { - axis: uint; - input1_id: uint; - input2_id: uint; - input3_id: uint; - input4_id: uint; - output_id: uint; - flags: uint; -} - -table XNNELU { - alpha:float; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNFullyConnected { - input1_id:uint; - filter_id:uint; - bias_id:uint; - output_id:uint; - flags:uint; -} - -table _XNNNodeConv { - padding_top:uint; - padding_right:uint; - padding_bottom:uint; - padding_left:uint; - kernel_height:uint; - kernel_width:uint; - subsampling_height:uint; - subsampling_width:uint; - dilation_height:uint; - dilation_width:uint; - group_input_channels:uint; - group_output_channels:uint; - groups:uint; - adjustment_height:uint; - adjustment_width:uint; - input1_id:uint; - filter_id:uint; - bias_id:uint; - output_id:uint; - flags:uint; -} - -table _XNNPooling2D { - padding_top: uint; - padding_right: uint; - padding_bottom: uint; - padding_left: uint; - pooling_height: uint; - pooling_width: uint; - stride_height: uint; - stride_width: uint; - dilation_height: uint; - dilation_width: uint; - input_id: uint; - output_id: uint; - flags: uint; -} - -table XNNStaticReshape { - num_dims:uint; - new_shape:[uint]; - input_id: uint; - output_id: uint; - flags: uint; -} 
- -table XNNStaticSlice { - num_dims:uint; - offsets:[uint]; - sizes:[uint]; - input_id:uint; - output_id:uint; - flags:uint; -} - -table XNNScaledDotProductAttention { - query_id:uint; - key_id:uint; - value_id:uint; - scale_id:uint; - mask_id:uint; - output_id:uint; - flags:uint; -} - -table XNNArgMaxPooling2d { - padding_top: uint; - padding_right: uint; - padding_bottom: uint; - padding_left: uint; - pooling_height: uint; - pooling_width: uint; - input_id: uint; - output_value_id: uint; - output_index_id: uint; - flags: uint; -} - -table XNNLeakyReLU { - negative_slope: float; - input_id: uint; - output_id: uint; - flags: uint; -} - -// Describes data offsets for constant data -table ConstantDataOffset { - // Constant data offsets are relative to the constant data base offset provided - // in the XNNPACKHeader. - offset: uint64; - - // The size in bytes of valid data starting at the offset. The constant data - // may be followed by padding before the next piece of constant data - size: uint64; -} - -table XNNGraph { - // Schema version. - version:string; - xnodes:[XNode]; - xvalues:[XValue]; - - // Number of external inputs/outputs - num_externs:uint; - - // Ids of external inputs - input_ids:[uint]; - - // Ids of external outputs - output_ids:[uint]; - - // Tables of constant data, used for constant Values (e.g. - // data field of weight tensors). Each constant is assigned an index into the table - // which are each individually aligned. 0 index is reserved to be pointed to by non-constant - // Tensors. Exactly one of constant_buffer and constant_data must be non-empty - constant_buffer:[Buffer]; - - // the list index is memory buffer id, the value is the memory buffer size. - mem_buffer_sizes: [uint]; - - // List of the constant data that follows the XNNGraph in this file. Each constant data is assigned an index into - // the table. 0 index is reserved to be pointed to by non-constant Tensor. Exactly one of constant_buffer and - // constant_data must be non-empty - constant_data:[ConstantDataOffset]; -} - -root_type XNNGraph; diff --git a/examples/models/llama2/custom_ops/__init__.py b/backends/xnnpack/serialization/schema_version_history.txt similarity index 100% rename from examples/models/llama2/custom_ops/__init__.py rename to backends/xnnpack/serialization/schema_version_history.txt diff --git a/backends/xnnpack/serialization/targets.bzl b/backends/xnnpack/serialization/targets.bzl index 5eeab3de2b..05e6b89f12 100644 --- a/backends/xnnpack/serialization/targets.bzl +++ b/backends/xnnpack/serialization/targets.bzl @@ -4,14 +4,14 @@ def define_common_targets(): runtime.genrule( name = "gen_xnnpack_schema", srcs = [ - "runtime_schema.fbs", + "schema.fbs", ], # We're only generating a single file, so it seems like we could use # `out`, but `flatc` takes a directory as a parameter, not a single # file. Use `outs` so that `${OUT}` is expanded as the containing # directory instead of the file itself. 
outs = { - "schema_generated.h": ["runtime_schema_generated.h"], + "schema_generated.h": ["schema_generated.h"], }, cmd = " ".join([ "$(exe {})".format(runtime.external_dep_location("flatc")), diff --git a/backends/xnnpack/test/ops/bilinear2d.py b/backends/xnnpack/test/ops/bilinear2d.py index ab9d3d3c11..d3c8535069 100644 --- a/backends/xnnpack/test/ops/bilinear2d.py +++ b/backends/xnnpack/test/ops/bilinear2d.py @@ -65,12 +65,15 @@ def forward(self, x): ) return a + # Since we may or may not enable dim order, use these ops only for + # check_not since we have `to_copy` and `to_dim_order_copy` in the list. ops = { "executorch_exir_dialects_edge__ops_aten_sub_Tensor", "executorch_exir_dialects_edge__ops_aten_mul_Tensor", "executorch_exir_dialects_edge__ops_aten_index_Tensor", "executorch_exir_dialects_edge__ops_aten_arange_start_step", "executorch_exir_dialects_edge__ops_aten__to_copy_default", + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default", "executorch_exir_dialects_edge__ops_aten_add_Tensor", "executorch_exir_dialects_edge__ops_aten_clamp_default", } @@ -81,7 +84,6 @@ def test_fp32_static_resize_bilinear2d(self): Tester(self.StaticResizeBilinear2dModule(), example_inputs) .export() .to_edge() - .check(self.ops) .partition() .check_not(self.ops) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -90,13 +92,12 @@ def test_fp32_static_resize_bilinear2d(self): .run_method_and_compare_outputs() ) - def test_fp32_static_resize_bilinear2d_with_align_cornesr(self): + def test_fp32_static_resize_bilinear2d_with_align_corners(self): example_inputs = (torch.randn(2, 3, 4, 5),) ( Tester(self.StaticResizeBilinear2dModuleWithAlignCorners(), example_inputs) .export() .to_edge() - .check(self.ops) .partition() .check_not(self.ops) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py index 3fe290606c..9dda84c5e5 100644 --- a/backends/xnnpack/utils/configs.py +++ b/backends/xnnpack/utils/configs.py @@ -12,8 +12,12 @@ ### XNNPACK Configs ### -def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig: - return exir.EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True) +def get_xnnpack_edge_compile_config( + skip_dim_order: bool = True, +) -> exir.EdgeCompileConfig: + return exir.EdgeCompileConfig( + _check_ir_validity=False, _skip_dim_order=skip_dim_order + ) def get_transform_passes(additional_passes=None) -> List[PassType]: diff --git a/build/cmake_deps.toml b/build/cmake_deps.toml index 80abd46409..2a94a32fa4 100644 --- a/build/cmake_deps.toml +++ b/build/cmake_deps.toml @@ -282,7 +282,7 @@ filters = [ # ---------------------------------- LLama start ---------------------------------- [targets.custom_ops] buck_targets = [ - "//examples/models/llama2/custom_ops:custom_ops", + "//extension/llm/custom_ops:custom_ops", ] filters = [ ".cpp$", diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 916a766f7c..f854a081fa 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -11,7 +11,8 @@ import logging import torch -from executorch.backends.arm.arm_backend import generate_ethosu_compile_spec + +from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder from executorch.backends.arm.arm_partitioner import ArmPartitioner from executorch.backends.arm.quantizer.arm_quantizer import ( ArmQuantizer, @@ -212,12 +213,13 @@ def forward(self, x): if args.delegate is True: edge = 
edge.to_backend( ArmPartitioner( - generate_ethosu_compile_spec( - "ethos-u55-128", - permute_memory_to_nhwc=args.model_name - in MODEL_NAME_TO_MODEL.keys(), - quantize_io=True, + ArmCompileSpecBuilder() + .ethosu_compile_spec("ethos-u55-128") + .set_permute_memory_format( + args.model_name in MODEL_NAME_TO_MODEL.keys() ) + .set_quantize_io(True) + .build() ) ) logging.debug(f"Lowered graph:\n{edge.exported_program().graph}") diff --git a/examples/cadence/models/babyllama.py b/examples/cadence/models/babyllama.py new file mode 100644 index 0000000000..603eb5f3d9 --- /dev/null +++ b/examples/cadence/models/babyllama.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch + +from executorch.backends.cadence.aot.export_example import export_model + +from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + args = ModelArgs( + dim=512, + vocab_size=512, + hidden_dim=1024, + n_heads=8, + # use_kv_cache=True, + n_layers=1, + ) + seq = 64 + b = 1 + model = Transformer(args) + example_inputs = (torch.randint(0, 10, [b, seq], dtype=torch.int64),) + + export_model(model, example_inputs) + + +if __name__ == "__main__": + main() diff --git a/examples/cadence/models/wav2vec2.py b/examples/cadence/models/wav2vec2.py new file mode 100644 index 0000000000..5db9ea2a6d --- /dev/null +++ b/examples/cadence/models/wav2vec2.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting simple models to flatbuffer + +import logging + +from executorch.backends.cadence.aot.ops_registrations import * # noqa + +import torch + +from executorch.backends.cadence.aot.export_example import export_model +from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + # The wrapper is needed to avoid issues with the optional second arguments + # of Wav2Vec2Models. 
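+ # It does this by returning only the first element of the model's (output, lengths) tuple.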
+ class Wav2Vec2ModelWrapper(torch.nn.Module): + def __init__(self, model: Wav2Vec2Model): + super().__init__() + self.model = model + + def forward(self, x): + out, _ = self.model(x) + return out + + _model = wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=0.1, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=0.1, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=0.0, + encoder_dropout=0.1, + encoder_layer_norm_first=False, + encoder_layer_drop=0.1, + aux_num_out=None, + ) + _model.eval() + + model = Wav2Vec2ModelWrapper(_model) + model.eval() + + # test input + audio_len = 1680 + example_inputs = (torch.rand(1, audio_len),) + + export_model(model, example_inputs) + + +if __name__ == "__main__": + main() diff --git a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts index 3c168689f7..37c8cbf0ba 100644 --- a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts +++ b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts @@ -17,7 +17,7 @@ android { defaultConfig { applicationId = "com.example.executorchllamademo" - minSdk = 24 + minSdk = 28 targetSdk = 33 versionCode = 1 versionName = "1.0" @@ -56,7 +56,10 @@ dependencies { implementation("androidx.camera:camera-core:1.3.0-rc02") implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12") implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation("com.google.code.gson:gson:2.8.6") implementation(files("libs/executorch-llama.aar")) + implementation("com.google.android.material:material:1.12.0") + implementation("androidx.activity:activity:1.9.0") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml index 3eaf301b5a..bb231420df 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml @@ -3,32 +3,44 @@ xmlns:tools="http://schemas.android.com/tools" package="com.example.executorchllamademo"> - + + + + + + + - + + android:theme="@style/Theme.AppCompat.Light.NoActionBar"> diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java new file mode 100644 index 0000000000..36d0741938 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; + +public class AppLog { + private final Long timestamp; + private final String message; + + public AppLog(String message) { + this.timestamp = getCurrentTimeStamp(); + this.message = message; + } + + public Long getTimestamp() { + return timestamp; + } + + public String getMessage() { + return message; + } + + public String getFormattedLog() { + return "[" + getFormattedTimeStamp() + "] " + message; + } + + private Long getCurrentTimeStamp() { + return System.currentTimeMillis(); + } + + private String getFormattedTimeStamp() { + return formatDate(timestamp); + } + + private String formatDate(long milliseconds) { + SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); + Date date = new Date(milliseconds); + return formatter.format(date); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java new file mode 100644 index 0000000000..99a94c00eb --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.content.Context; +import android.content.SharedPreferences; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.util.ArrayList; + +public class DemoSharedPreferences { + Context context; + SharedPreferences sharedPreferences; + + public DemoSharedPreferences(Context context) { + this.context = context; + this.sharedPreferences = getSharedPrefs(); + } + + private SharedPreferences getSharedPrefs() { + return context.getSharedPreferences( + context.getString(R.string.demo_pref_file_key), Context.MODE_PRIVATE); + } + + public String getSavedMessages() { + return sharedPreferences.getString(context.getString(R.string.saved_messages_json_key), ""); + } + + public void addMessages(MessageAdapter messageAdapter) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String msgJSON = gson.toJson(messageAdapter.getSavedMessages()); + editor.putString(context.getString(R.string.saved_messages_json_key), msgJSON); + editor.apply(); + } + + public void removeExistingMessages() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.remove(context.getString(R.string.saved_messages_json_key)); + editor.apply(); + } + + public void addSettings(SettingsFields settingsFields) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String settingsJSON = gson.toJson(settingsFields); + editor.putString(context.getString(R.string.settings_json_key), settingsJSON); + editor.apply(); + } + + public String getSettings() { + return sharedPreferences.getString(context.getString(R.string.settings_json_key), ""); + } + + public void saveLogs() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String msgJSON = gson.toJson(ETLogging.getInstance().getLogs()); + 
editor.putString(context.getString(R.string.logs_json_key), msgJSON); + editor.apply(); + } + + public void removeExistingLogs() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.remove(context.getString(R.string.logs_json_key)); + editor.apply(); + } + + public ArrayList getSavedLogs() { + String logsJSONString = + sharedPreferences.getString(context.getString(R.string.logs_json_key), null); + if (logsJSONString == null || logsJSONString.isEmpty()) { + return new ArrayList<>(); + } + Gson gson = new Gson(); + Type type = new TypeToken>() {}.getType(); + ArrayList appLogs = gson.fromJson(logsJSONString, type); + if (appLogs == null) { + return new ArrayList<>(); + } + return appLogs; + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java new file mode 100644 index 0000000000..cf3c3e5f0a --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.content.ContentResolver; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Color; +import android.net.Uri; +import androidx.annotation.Nullable; +import java.io.FileNotFoundException; +import java.io.InputStream; + +public class ETImage { + private int width; + private int height; + private final byte[] bytes; + private final Uri uri; + private final ContentResolver contentResolver; + + ETImage(ContentResolver contentResolver, Uri uri) { + this.contentResolver = contentResolver; + this.uri = uri; + bytes = getBytesFromImageURI(uri); + } + + public int getWidth() { + return width; + } + + public int getHeight() { + return height; + } + + public Uri getUri() { + return uri; + } + + public byte[] getBytes() { + return bytes; + } + + private byte[] getBytesFromImageURI(Uri uri) { + try { + int RESIZED_IMAGE_WIDTH = 336; + Bitmap bitmap = resizeImage(uri, RESIZED_IMAGE_WIDTH); + + if (bitmap == null) { + ETLogging.getInstance().log("Unable to get bytes from Image URI. 
Bitmap is null"); + return new byte[0]; + } + + width = bitmap.getWidth(); + height = bitmap.getHeight(); + + byte[] rgbValues = new byte[width * height * 3]; + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + // Get the color of the current pixel + int color = bitmap.getPixel(x, y); + + // Extract the RGB values from the color + int red = Color.red(color); + int green = Color.green(color); + int blue = Color.blue(color); + + // Store the RGB values in the byte array + rgbValues[(y * width + x) * 3] = (byte) red; + rgbValues[(y * width + x) * 3 + 1] = (byte) green; + rgbValues[(y * width + x) * 3 + 2] = (byte) blue; + } + } + return rgbValues; + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } + + @Nullable + private Bitmap resizeImage(Uri uri, int maxLength) throws FileNotFoundException { + InputStream inputStream = contentResolver.openInputStream(uri); + if (inputStream == null) { + ETLogging.getInstance().log("Unable to resize image, input streams is null"); + return null; + } + Bitmap bitmap = BitmapFactory.decodeStream(inputStream); + if (bitmap == null) { + ETLogging.getInstance().log("Unable to resize image, bitmap during decode stream is null"); + return null; + } + + float aspectRatio; + int finalWidth, finalHeight; + + if (bitmap.getWidth() > bitmap.getHeight()) { + // width > height --> width = maxLength, height scale with aspect ratio + aspectRatio = bitmap.getWidth() / (float) bitmap.getHeight(); + finalWidth = maxLength; + finalHeight = Math.round(maxLength / aspectRatio); + } else { + // height >= width --> height = maxLength, width scale with aspect ratio + aspectRatio = bitmap.getHeight() / (float) bitmap.getWidth(); + finalHeight = maxLength; + finalWidth = Math.round(maxLength / aspectRatio); + } + + return Bitmap.createScaledBitmap(bitmap, finalWidth, finalHeight, false); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java new file mode 100644 index 0000000000..e595348945 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.app.Application; +import android.util.Log; +import java.util.ArrayList; + +public class ETLogging extends Application { + private static ETLogging singleton; + + private ArrayList logs; + private DemoSharedPreferences mDemoSharedPreferences; + + @Override + public void onCreate() { + super.onCreate(); + singleton = this; + mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); + logs = mDemoSharedPreferences.getSavedLogs(); + if (logs == null) { // We don't have existing sharedPreference stored + logs = new ArrayList<>(); + } + } + + public static ETLogging getInstance() { + return singleton; + } + + public void log(String message) { + AppLog appLog = new AppLog(message); + logs.add(appLog); + Log.d("ETLogging", appLog.getMessage()); + } + + public ArrayList getLogs() { + return logs; + } + + public void clearLogs() { + logs.clear(); + mDemoSharedPreferences.removeExistingLogs(); + } + + public void saveLogs() { + mDemoSharedPreferences.saveLogs(); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java new file mode 100644 index 0000000000..8700528d44 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.app.AlertDialog; +import android.content.DialogInterface; +import android.os.Bundle; +import android.widget.ImageButton; +import android.widget.ListView; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.graphics.Insets; +import androidx.core.view.ViewCompat; +import androidx.core.view.WindowInsetsCompat; + +public class LogsActivity extends AppCompatActivity { + + private LogsAdapter mLogsAdapter; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_logs); + ViewCompat.setOnApplyWindowInsetsListener( + requireViewById(R.id.main), + (v, insets) -> { + Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); + v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); + return insets; + }); + + setupLogs(); + setupClearLogsButton(); + } + + @Override + public void onResume() { + super.onResume(); + mLogsAdapter.clear(); + mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); + mLogsAdapter.notifyDataSetChanged(); + } + + private void setupLogs() { + ListView mLogsListView = requireViewById(R.id.logsListView); + mLogsAdapter = new LogsAdapter(this, R.layout.logs_message); + + mLogsListView.setAdapter(mLogsAdapter); + mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); + mLogsAdapter.notifyDataSetChanged(); + } + + private void setupClearLogsButton() { + ImageButton clearLogsButton = requireViewById(R.id.clearLogsButton); + clearLogsButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Delete Logs History") + .setMessage("Do you really want to delete logs history?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new 
DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + ETLogging.getInstance().clearLogs(); + mLogsAdapter.clear(); + mLogsAdapter.notifyDataSetChanged(); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + @Override + protected void onDestroy() { + super.onDestroy(); + ETLogging.getInstance().saveLogs(); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java new file mode 100644 index 0000000000..76c6a1aa1b --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.ArrayAdapter; +import android.widget.TextView; +import androidx.annotation.NonNull; +import java.util.Objects; + +public class LogsAdapter extends ArrayAdapter { + public LogsAdapter(android.content.Context context, int resource) { + super(context, resource); + } + + static class ViewHolder { + private TextView logTextView; + } + + @NonNull + @Override + public View getView(int position, View convertView, @NonNull ViewGroup parent) { + ViewHolder mViewHolder = null; + + String logMessage = Objects.requireNonNull(getItem(position)).getFormattedLog(); + + if (convertView == null || convertView.getTag() == null) { + mViewHolder = new ViewHolder(); + convertView = LayoutInflater.from(getContext()).inflate(R.layout.logs_message, parent, false); + mViewHolder.logTextView = convertView.requireViewById(R.id.logsTextView); + } else { + mViewHolder = (ViewHolder) convertView.getTag(); + } + mViewHolder.logTextView.setText(logMessage); + return convertView; + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 2c94c242ed..44d310231a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -8,32 +8,72 @@ package com.example.executorchllamademo; -import android.app.Activity; +import android.Manifest; import android.app.ActivityManager; import android.app.AlertDialog; -import android.content.Context; +import android.content.ContentResolver; +import android.content.ContentValues; +import android.content.Intent; +import android.content.pm.PackageManager; +import android.net.Uri; import android.os.Bundle; +import android.os.Handler; +import android.os.Looper; +import android.provider.MediaStore; import android.system.ErrnoException; import android.system.Os; -import android.widget.Button; +import android.text.InputType; +import android.util.Log; +import android.view.View; import android.widget.EditText; import android.widget.ImageButton; +import android.widget.ImageView; +import android.widget.LinearLayout; import 
android.widget.ListView; -import java.io.File; +import android.widget.TextView; +import android.widget.Toast; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.PickVisualMediaRequest; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.annotation.NonNull; +import androidx.appcompat.app.AppCompatActivity; +import androidx.constraintlayout.widget.ConstraintLayout; +import androidx.core.app.ActivityCompat; +import androidx.core.content.ContextCompat; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; import org.pytorch.executorch.LlamaCallback; import org.pytorch.executorch.LlamaModule; -public class MainActivity extends Activity implements Runnable, LlamaCallback { +public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback { private EditText mEditTextMessage; - private Button mSendButton; - private ImageButton mModelButton; + private ImageButton mSendButton; + private ImageButton mGalleryButton; + private ImageButton mCameraButton; private ListView mMessagesView; private MessageAdapter mMessageAdapter; private LlamaModule mModule = null; private Message mResultMessage = null; - - private String mModelFilePath = ""; - private String mTokenizerFilePath = ""; + private ImageButton mSettingsButton; + private TextView mMemoryView; + private ActivityResultLauncher mPickGallery; + private ActivityResultLauncher mCameraRoll; + private List mSelectedImageUri; + private ConstraintLayout mMediaPreviewConstraintLayout; + private LinearLayout mAddMediaLayout; + private static final int MAX_NUM_OF_IMAGES = 5; + private static final int REQUEST_IMAGE_CAPTURE = 1; + private Uri cameraImageUri; + private DemoSharedPreferences mDemoSharedPreferences; + private SettingsFields mCurrentSettingsFields; + private Handler mMemoryUpdateHandler; + private Runnable memoryUpdater; + // UI Specific to user using INSTRUCT_MODE + private boolean INSTRUCT_MODE = false; + private String INSTRUCT_INSTRUCTION = "In Instruct Mode. 
Press SEND"; @Override public void onResult(String result) { @@ -52,23 +92,13 @@ public void onStats(float tps) { }); } - private static String[] listLocalFile(String path, String suffix) { - File directory = new File(path); - if (directory.exists() && directory.isDirectory()) { - File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix)); - String[] result = new String[files.length]; - for (int i = 0; i < files.length; i++) { - if (files[i].isFile() && files[i].getName().endsWith(suffix)) { - result[i] = files[i].getAbsolutePath(); - } - } - return result; + private void setLocalModel(String modelPath, String tokenizerPath, float temperature) { + if (mModule != null) { + mModule.resetNative(); + mModule = null; } - return new String[0]; - } - - private void setLocalModel(String modelPath, String tokenizerPath) { - Message modelLoadingMessage = new Message("Loading model...", false); + Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0); + ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath); runOnUiThread( () -> { mSendButton.setEnabled(false); @@ -76,9 +106,15 @@ private void setLocalModel(String modelPath, String tokenizerPath) { mMessageAdapter.notifyDataSetChanged(); }); long runStartTime = System.currentTimeMillis(); - mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f); + mModule = new LlamaModule(modelPath, tokenizerPath, temperature); int loadResult = mModule.load(); + long loadDuration = System.currentTimeMillis() - runStartTime; + String modelLoadError = ""; + String modelInfo = ""; if (loadResult != 0) { + // TODO: Map the error code to a reason to let the user know why model loading failed + modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n"; + loadDuration = 0; AlertDialog.Builder builder = new AlertDialog.Builder(this); builder.setTitle("Load failed: " + loadResult); runOnUiThread( @@ -86,18 +122,37 @@ private void setLocalModel(String modelPath, String tokenizerPath) { AlertDialog alert = builder.create(); alert.show(); }); + } else { + String[] segments = modelPath.split("/"); + String pteName = segments[segments.length - 1]; + segments = tokenizerPath.split("/"); + String tokenizerName = segments[segments.length - 1]; + modelInfo = + "Successfully loaded model. " + + pteName + + " and tokenizer " + + tokenizerName + + " in " + + (float) loadDuration / 1000 + + " sec." + + " You can send text or image for inference"; } - long loadDuration = System.currentTimeMillis() - runStartTime; - String modelInfo = - "Model path: " + Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); + + String modelLoggingInfo = + modelLoadError + + "Model path: " + modelPath + "\nTokenizer path: " + tokenizerPath + + "\nTemperature: " + + temperature + "\nModel loaded time: " + loadDuration + " ms"; - Message modelLoadedMessage = new Message(modelInfo, false); + ETLogging.getInstance().log("Load complete. " + modelLoggingInfo); + runOnUiThread( () -> { mSendButton.setEnabled(true); @@ -107,55 +162,26 @@ private void setLocalModel(String modelPath, String tokenizerPath) { }); } - private String memoryInfo() { - final ActivityManager am = (ActivityManager) getSystemService(Context.ACTIVITY_SERVICE); - ActivityManager.MemoryInfo memInfo = new ActivityManager.MemoryInfo(); - am.getMemoryInfo(memInfo); - return "Total RAM: " - + Math.floorDiv(memInfo.totalMem, 1000000) - + " MB. 
Available RAM: " - + Math.floorDiv(memInfo.availMem, 1000000) - + " MB."; - } - - private void modelDialog() { - String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte"); - String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin"); - String[] modelFiles = listLocalFile("/data/local/tmp/llama/", ".model"); - String[] tokenizerFiles = new String[binFiles.length + modelFiles.length]; - System.arraycopy(binFiles, 0, tokenizerFiles, 0, binFiles.length); - System.arraycopy(modelFiles, 0, tokenizerFiles, binFiles.length, modelFiles.length); - AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); - modelPathBuilder.setTitle("Select model path"); - AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); - tokenizerPathBuilder.setTitle("Select tokenizer path"); - modelPathBuilder.setSingleChoiceItems( - pteFiles, - -1, - (dialog, item) -> { - mModelFilePath = pteFiles[item]; - mEditTextMessage.setText(""); - dialog.dismiss(); - tokenizerPathBuilder.create().show(); - }); - - tokenizerPathBuilder.setSingleChoiceItems( - tokenizerFiles, - -1, - (dialog, item) -> { - mTokenizerFilePath = tokenizerFiles[item]; - Runnable runnable = - new Runnable() { - @Override - public void run() { - setLocalModel(mModelFilePath, mTokenizerFilePath); - } - }; - new Thread(runnable).start(); - dialog.dismiss(); - }); + private void loadLocalModelAndParameters( + String modelFilePath, String tokenizerFilePath, float temperature) { + Runnable runnable = + new Runnable() { + @Override + public void run() { + setLocalModel(modelFilePath, tokenizerFilePath, temperature); + } + }; + new Thread(runnable).start(); + } - modelPathBuilder.create().show(); + private void populateExistingMessages(String existingMsgJSON) { + Gson gson = new Gson(); + Type type = new TypeToken>() {}.getType(); + ArrayList savedMessages = gson.fromJson(existingMsgJSON, type); + for (Message msg : savedMessages) { + mMessageAdapter.add(msg); + } + mMessageAdapter.notifyDataSetChanged(); } @Override @@ -169,27 +195,379 @@ protected void onCreate(Bundle savedInstanceState) { finish(); } - mEditTextMessage = findViewById(R.id.editTextMessage); - mSendButton = findViewById(R.id.sendButton); + mEditTextMessage = requireViewById(R.id.editTextMessage); + mSendButton = requireViewById(R.id.sendButton); mSendButton.setEnabled(false); - mModelButton = findViewById(R.id.modelButton); - mMessagesView = findViewById(R.id.messages_view); - mMessageAdapter = new MessageAdapter(this, R.layout.sent_message); + mMessagesView = requireViewById(R.id.messages_view); + mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList()); mMessagesView.setAdapter(mMessageAdapter); - mModelButton.setOnClickListener( + mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); + String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); + if (!existingMsgJSON.isEmpty()) { + populateExistingMessages(existingMsgJSON); + } + mSettingsButton = requireViewById(R.id.settings); + mSettingsButton.setOnClickListener( view -> { - mModule.stop(); - mMessageAdapter.clear(); - mMessageAdapter.notifyDataSetChanged(); - modelDialog(); + Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class); + MainActivity.this.startActivity(myIntent); }); + mCurrentSettingsFields = new SettingsFields(); + mMemoryUpdateHandler = new Handler(Looper.getMainLooper()); onModelRunStopped(); - modelDialog(); + setupMediaButton(); + setupGalleryPicker(); + setupCameraRoll(); + startMemoryUpdate(); 
+ setupShowLogsButton(); + } + + @Override + protected void onPause() { + super.onPause(); + mDemoSharedPreferences.addMessages(mMessageAdapter); + } + + @Override + protected void onResume() { + super.onResume(); + // Check for if settings parameters have changed + Gson gson = new Gson(); + String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); + if (!settingsFieldsJSON.isEmpty()) { + SettingsFields updatedSettingsFields = + gson.fromJson(settingsFieldsJSON, SettingsFields.class); + if (updatedSettingsFields == null) { + // Added this check, because gson.fromJson can return null + askUserToSelectModel(); + return; + } + boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields); + boolean isLoadModel = updatedSettingsFields.getIsLoadModel(); + if (isUpdated) { + if (isLoadModel) { + // If users change the model file, but not pressing loadModelButton, we won't load the new + // model + checkForUpdateAndReloadModel(updatedSettingsFields); + } else { + askUserToSelectModel(); + } + checkForPromptChange(updatedSettingsFields); + checkForClearChatHistory(updatedSettingsFields); + // Update current to point to the latest + mCurrentSettingsFields = new SettingsFields(updatedSettingsFields); + } + } else { + askUserToSelectModel(); + } + } + + private void checkForClearChatHistory(SettingsFields updatedSettingsFields) { + if (updatedSettingsFields.getIsClearChatHistory()) { + mMessageAdapter.clear(); + mMessageAdapter.notifyDataSetChanged(); + mDemoSharedPreferences.removeExistingMessages(); + // changing to false since chat history has been cleared. + updatedSettingsFields.saveIsClearChatHistory(false); + mDemoSharedPreferences.addSettings(updatedSettingsFields); + } + } + + private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) { + // TODO need to add 'load model' in settings and queue loading based on that + String modelPath = updatedSettingsFields.getModelFilePath(); + String tokenizerPath = updatedSettingsFields.getTokenizerFilePath(); + double temperature = updatedSettingsFields.getTemperature(); + if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) { + if (updatedSettingsFields.getIsLoadModel() + || !modelPath.equals(mCurrentSettingsFields.getModelFilePath()) + || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath()) + || temperature != mCurrentSettingsFields.getTemperature()) { + loadLocalModelAndParameters( + updatedSettingsFields.getModelFilePath(), + updatedSettingsFields.getTokenizerFilePath(), + (float) updatedSettingsFields.getTemperature()); + updatedSettingsFields.saveLoadModelAction(false); + mDemoSharedPreferences.addSettings(updatedSettingsFields); + } + } else { + askUserToSelectModel(); + } + } + + private void checkForPromptChange(SettingsFields updatedSettingsFields) { + if (updatedSettingsFields.isSystemPromptChanged() + || updatedSettingsFields.isUserPromptChanged()) { + enableInstructMode(); + } else { + disableInstructMode(); + } + } + + private void enableInstructMode() { + INSTRUCT_MODE = true; + mEditTextMessage.setText(INSTRUCT_INSTRUCTION); + mEditTextMessage.setInputType(InputType.TYPE_NULL); + mEditTextMessage.clearFocus(); + } + + private void disableInstructMode() { + INSTRUCT_MODE = false; + mEditTextMessage.setText(""); + mEditTextMessage.setInputType(InputType.TYPE_CLASS_TEXT); + mEditTextMessage.clearFocus(); + } + + private void askUserToSelectModel() { + String askLoadModel = + "To get started, select your desired model and tokenizer " + "from the top right corner"; + Message 
askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0); + ETLogging.getInstance().log(askLoadModel); + runOnUiThread( + () -> { + mMessageAdapter.add(askLoadModelMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + } + + private void setupShowLogsButton() { + ImageButton showLogsButton = requireViewById(R.id.showLogsButton); + showLogsButton.setOnClickListener( + view -> { + Intent myIntent = new Intent(MainActivity.this, LogsActivity.class); + MainActivity.this.startActivity(myIntent); + }); + } + + private void setupMediaButton() { + mAddMediaLayout = requireViewById(R.id.addMediaLayout); + mAddMediaLayout.setVisibility(View.GONE); // We hide this initially + + ImageButton addMediaButton = requireViewById(R.id.addMediaButton); + addMediaButton.setOnClickListener( + view -> { + mAddMediaLayout.setVisibility(View.VISIBLE); + }); + + mGalleryButton = requireViewById(R.id.galleryButton); + mGalleryButton.setOnClickListener( + view -> { + // Launch the photo picker and let the user choose only images. + mPickGallery.launch( + new PickVisualMediaRequest.Builder() + .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE) + .build()); + }); + mCameraButton = requireViewById(R.id.cameraButton); + mCameraButton.setOnClickListener( + view -> { + Log.d("CameraRoll", "Check permission"); + if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA) + != PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions( + MainActivity.this, + new String[] {Manifest.permission.CAMERA}, + REQUEST_IMAGE_CAPTURE); + } else { + launchCamera(); + } + }); + } + + private void setupCameraRoll() { + // Registers a camera roll activity launcher. + mCameraRoll = + registerForActivityResult( + new ActivityResultContracts.TakePicture(), + result -> { + if (result && cameraImageUri != null) { + Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri); + mAddMediaLayout.setVisibility(View.GONE); + List uris = new ArrayList<>(); + uris.add(cameraImageUri); + showMediaPreview(uris); + } else { + // Delete the temp image file based on the url since the photo is not successfully + // taken + if (cameraImageUri != null) { + ContentResolver contentResolver = MainActivity.this.getContentResolver(); + contentResolver.delete(cameraImageUri, null, null); + Log.d("CameraRoll", "No photo taken. 
Delete temp uri"); + } + } + }); + mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); + ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); + mediaPreviewCloseButton.setOnClickListener( + view -> { + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mSelectedImageUri = null; + }); + + ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); + addMoreImageButton.setOnClickListener( + view -> { + Log.d("addMore", "clicked"); + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + // Direct user to select type of input + mCameraButton.callOnClick(); + }); + } + + private String updateMemoryUsage() { + ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo(); + ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE); + if (activityManager == null) { + return "---"; + } + activityManager.getMemoryInfo(memoryInfo); + long totalMem = memoryInfo.totalMem / (1024 * 1024); + long availableMem = memoryInfo.availMem / (1024 * 1024); + long usedMem = totalMem - availableMem; + return usedMem + "MB"; + } + + private void startMemoryUpdate() { + mMemoryView = requireViewById(R.id.ram_usage_live); + memoryUpdater = + new Runnable() { + @Override + public void run() { + mMemoryView.setText(updateMemoryUsage()); + mMemoryUpdateHandler.postDelayed(this, 1000); + } + }; + mMemoryUpdateHandler.post(memoryUpdater); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) { + if (grantResults[0] == PackageManager.PERMISSION_GRANTED) { + launchCamera(); + } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) { + Log.d("CameraRoll", "Permission denied"); + } + } + } + + private void launchCamera() { + ContentValues values = new ContentValues(); + values.put(MediaStore.Images.Media.TITLE, "New Picture"); + values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera"); + values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/"); + cameraImageUri = + MainActivity.this + .getContentResolver() + .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values); + mCameraRoll.launch(cameraImageUri); + } + + private void setupGalleryPicker() { + // Registers a photo picker activity launcher in single-select mode. 
+ mPickGallery = + registerForActivityResult( + new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES), + uris -> { + if (!uris.isEmpty()) { + Log.d("PhotoPicker", "Selected URIs: " + uris); + mAddMediaLayout.setVisibility(View.GONE); + for (Uri uri : uris) { + MainActivity.this + .getContentResolver() + .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION); + } + showMediaPreview(uris); + } else { + Log.d("PhotoPicker", "No media selected"); + } + }); + + mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); + ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); + mediaPreviewCloseButton.setOnClickListener( + view -> { + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mSelectedImageUri = null; + }); + + ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); + addMoreImageButton.setOnClickListener( + view -> { + Log.d("addMore", "clicked"); + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mGalleryButton.callOnClick(); + }); + } + + private List getProcessedImagesForModel(List uris) { + List imageList = new ArrayList<>(); + if (uris != null) { + uris.forEach( + (uri) -> { + imageList.add(new ETImage(this.getContentResolver(), uri)); + }); + } + return imageList; + } + + private void showMediaPreview(List uris) { + if (mSelectedImageUri == null) { + mSelectedImageUri = uris; + } else { + mSelectedImageUri.addAll(uris); + } + + if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) { + mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES); + Toast.makeText( + this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT) + .show(); + } + Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri); + + mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE); + + List imageViews = new ArrayList(); + + // Pre-populate all the image views that are available from the layout (currently max 5) + imageViews.add(requireViewById(R.id.mediaPreviewImageView1)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView2)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView3)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView4)); + imageViews.add(requireViewById(R.id.mediaPreviewImageView5)); + + // Hide all the image views (reset state) + for (int i = 0; i < imageViews.size(); i++) { + imageViews.get(i).setVisibility(View.GONE); + } + + // Only show/render those that have proper Image URIs + for (int i = 0; i < mSelectedImageUri.size(); i++) { + imageViews.get(i).setVisibility(View.VISIBLE); + imageViews.get(i).setImageURI(mSelectedImageUri.get(i)); + } + } + + private void addSelectedImagesToChatThread(List selectedImageUri) { + if (selectedImageUri == null) { + return; + } + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + for (int i = 0; i < selectedImageUri.size(); i++) { + Uri imageURI = selectedImageUri.get(i); + Log.d("image uri ", "test " + imageURI.getPath()); + mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0)); + } + mMessageAdapter.notifyDataSetChanged(); } private void onModelRunStarted() { - mSendButton.setText("Stop"); + mSendButton.setClickable(false); + mSendButton.setImageResource(R.drawable.baseline_stop_24); mSendButton.setOnClickListener( view -> { mModule.stop(); @@ -197,16 +575,49 @@ private void onModelRunStarted() { } private void onModelRunStopped() { - setTitle(memoryInfo()); - mSendButton.setText("Generate"); + 
mSendButton.setClickable(true); + mSendButton.setImageResource(R.drawable.baseline_send_24); mSendButton.setOnClickListener( view -> { - String prompt = mEditTextMessage.getText().toString(); - mMessageAdapter.add(new Message(prompt, true)); + addSelectedImagesToChatThread(mSelectedImageUri); + // TODO: When ET supports multimodal, this is where we will add the images as part of the + // prompt. + List processedImageList = getProcessedImagesForModel(mSelectedImageUri); + processedImageList.forEach( + image -> { + ETLogging.getInstance() + .log( + "Image preprocessed:" + + " uri = " + + image.getUri().getLastPathSegment() + + "," + + " width = " + + image.getWidth() + + "," + + " height = " + + image.getHeight() + + "," + + " bytes size = " + + image.getBytes().length); + }); + String prompt; + if (INSTRUCT_MODE) { + prompt = mCurrentSettingsFields.getEntirePrompt(); + mEditTextMessage.setText(INSTRUCT_INSTRUCTION); + } else { + prompt = mEditTextMessage.getText().toString(); + mEditTextMessage.setText(""); + } + mMessageAdapter.add(new Message(prompt, true, MessageType.TEXT, 0)); mMessageAdapter.notifyDataSetChanged(); mEditTextMessage.setText(""); - mResultMessage = new Message("", false); + mResultMessage = new Message("", false, MessageType.TEXT, 0); mMessageAdapter.add(mResultMessage); + // Scroll to bottom of the list + mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); + // After images are added to prompt and chat thread, we clear the imageURI list + // Note: This has to be done after imageURIs are no longer needed by LlamaModule + mSelectedImageUri = null; Runnable runnable = new Runnable() { @Override @@ -218,9 +629,11 @@ public void run() { onModelRunStarted(); } }); - + ETLogging.getInstance().log("Running inference.. prompt=" + prompt); + long generateStartTime = System.currentTimeMillis(); mModule.generate(prompt, MainActivity.this); - + long generateDuration = System.currentTimeMillis() - generateStartTime; + mResultMessage.setTotalGenerationTime(generateDuration); runOnUiThread( new Runnable() { @Override @@ -228,6 +641,7 @@ public void run() { onModelRunStopped(); } }); + ETLogging.getInstance().log("Inference completed"); } }; new Thread(runnable).start(); @@ -242,8 +656,27 @@ public void run() { @Override public void run() { mMessageAdapter.notifyDataSetChanged(); - setTitle(memoryInfo()); } }); } + + @Override + public void onBackPressed() { + super.onBackPressed(); + if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) { + mAddMediaLayout.setVisibility(View.GONE); + } else { + // Default behavior of back button + finish(); + } + } + + @Override + protected void onDestroy() { + super.onDestroy(); + mMemoryUpdateHandler.removeCallbacks(memoryUpdater); + // This is to cover the case where the app is shutdown when user is on MainActivity but + // never clicked on the logsActivity + ETLogging.getInstance().saveLogs(); + } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java index 81b77b1aba..b2e5380e2a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java @@ -8,14 +8,50 @@ package com.example.executorchllamademo; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; + public 
class Message { private String text; - private boolean isSent; + private final boolean isSent; private float tokensPerSecond; + private long totalGenerationTime; + private final long timestamp; + private final MessageType messageType; + private String imagePath; + private final int promptID; + + private static final String TIMESTAMP_FORMAT = "hh:mm a"; // example: 2:23 PM - public Message(String text, boolean isSent) { - this.text = text; + public Message(String text, boolean isSent, MessageType messageType, int promptID) { this.isSent = isSent; + this.messageType = messageType; + this.promptID = promptID; + + if (messageType == MessageType.IMAGE) { + this.imagePath = text; + } else { + this.text = text; + } + + if (messageType != MessageType.SYSTEM) { + this.timestamp = System.currentTimeMillis(); + } else { + this.timestamp = (long) 0; + } + } + + public int getPromptID() { + return promptID; + } + + public MessageType getMessageType() { + return messageType; + } + + public String getImagePath() { + return imagePath; } public String getText() { @@ -34,7 +70,25 @@ public void setTokensPerSecond(float tokensPerSecond) { this.tokensPerSecond = tokensPerSecond; } + public void setTotalGenerationTime(long totalGenerationTime) { + this.totalGenerationTime = totalGenerationTime; + } + public float getTokensPerSecond() { return tokensPerSecond; } + + public long getTotalGenerationTime() { + return totalGenerationTime; + } + + public long getTimestamp() { + return timestamp; + } + + public String getFormattedTimestamp() { + SimpleDateFormat formatter = new SimpleDateFormat(TIMESTAMP_FORMAT, Locale.getDefault()); + Date date = new Date(timestamp); + return formatter.format(date); + } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java index 656da1967d..d9cbd95a1a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java @@ -8,33 +8,86 @@ package com.example.executorchllamademo; +import android.net.Uri; import android.view.LayoutInflater; import android.view.View; import android.view.ViewGroup; import android.widget.ArrayAdapter; +import android.widget.ImageView; import android.widget.TextView; +import java.util.ArrayList; public class MessageAdapter extends ArrayAdapter { - public MessageAdapter(android.content.Context context, int resource) { + + private final ArrayList savedMessages; + + public MessageAdapter( + android.content.Context context, int resource, ArrayList savedMessages) { super(context, resource); + this.savedMessages = savedMessages; } @Override public View getView(int position, View convertView, ViewGroup parent) { Message currentMessage = getItem(position); + int layoutIdForListItem; - int layoutIdForListItem = - currentMessage.getIsSent() ? R.layout.sent_message : R.layout.received_message; + if (currentMessage.getMessageType() == MessageType.SYSTEM) { + layoutIdForListItem = R.layout.system_message; + } else { + layoutIdForListItem = + currentMessage.getIsSent() ? 
R.layout.sent_message : R.layout.received_message; + } View listItemView = LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); - TextView messageTextView = listItemView.findViewById(R.id.message_text); - messageTextView.setText(currentMessage.getText()); + if (currentMessage.getMessageType() == MessageType.IMAGE) { + ImageView messageImageView = listItemView.requireViewById(R.id.message_image); + messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath())); + TextView messageTextView = listItemView.requireViewById(R.id.message_text); + messageTextView.setVisibility(View.GONE); + } else { + TextView messageTextView = listItemView.requireViewById(R.id.message_text); + messageTextView.setText(currentMessage.getText()); + } + String metrics = ""; + TextView tokensView; if (currentMessage.getTokensPerSecond() > 0) { - TextView tokensView = listItemView.findViewById(R.id.tokens_per_second); - tokensView.setText("" + currentMessage.getTokensPerSecond() + " t/s"); + metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s "; + } + + if (currentMessage.getTotalGenerationTime() > 0) { + metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s "; + } + + if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) { + tokensView = listItemView.requireViewById(R.id.generation_metrics); + tokensView.setText(metrics); + TextView separatorView = listItemView.requireViewById(R.id.bar); + separatorView.setVisibility(View.VISIBLE); + } + + if (currentMessage.getTimestamp() > 0) { + TextView timestampView = listItemView.requireViewById(R.id.timestamp); + timestampView.setText(currentMessage.getFormattedTimestamp()); } return listItemView; } + + @Override + public void add(Message msg) { + super.add(msg); + savedMessages.add(msg); + } + + @Override + public void clear() { + super.clear(); + savedMessages.clear(); + } + + public ArrayList getSavedMessages() { + return savedMessages; + } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java new file mode 100644 index 0000000000..6042acb572 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java @@ -0,0 +1,15 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public enum MessageType { + TEXT, + IMAGE, + SYSTEM +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java new file mode 100644 index 0000000000..1d109e0195 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java @@ -0,0 +1,325 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.app.AlertDialog; +import android.content.DialogInterface; +import android.os.Bundle; +import android.text.Editable; +import android.text.TextWatcher; +import android.widget.Button; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.TextView; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.graphics.Insets; +import androidx.core.view.ViewCompat; +import androidx.core.view.WindowInsetsCompat; +import com.google.gson.Gson; +import java.io.File; + +public class SettingsActivity extends AppCompatActivity { + + private String mModelFilePath = ""; + private String mTokenizerFilePath = ""; + private TextView mModelTextView; + private TextView mTokenizerTextView; + private ImageButton mModelImageButton; + private ImageButton mTokenizerImageButton; + private EditText mSystemPromptEditText; + private EditText mUserPromptEditText; + private Button mLoadModelButton; + private double mSetTemperature; + private String mSystemPrompt; + private String mUserPrompt; + + public SettingsFields mSettingsFields; + + private DemoSharedPreferences mDemoSharedPreferences; + public static double TEMPERATURE_MIN_VALUE = 0.1; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_settings); + ViewCompat.setOnApplyWindowInsetsListener( + requireViewById(R.id.main), + (v, insets) -> { + Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); + v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); + return insets; + }); + mDemoSharedPreferences = new DemoSharedPreferences(getBaseContext()); + mSettingsFields = new SettingsFields(); + setupSettings(); + } + + private void setupSettings() { + mModelTextView = requireViewById(R.id.modelTextView); + mTokenizerTextView = requireViewById(R.id.tokenizerTextView); + mModelImageButton = requireViewById(R.id.modelImageButton); + mTokenizerImageButton = requireViewById(R.id.tokenizerImageButton); + mSystemPromptEditText = requireViewById(R.id.systemPromptText); + mUserPromptEditText = requireViewById(R.id.userPromptText); + loadSettings(); + + // TODO: The two setOnClickListeners will be removed after file path issue is resolved + mModelImageButton.setOnClickListener( + view -> { + setupModelSelectorDialog(); + }); + mTokenizerImageButton.setOnClickListener( + view -> { + setupTokenizerSelectorDialog(); + }); + mModelFilePath = mSettingsFields.getModelFilePath(); + if (!mModelFilePath.isEmpty()) { + mModelTextView.setText(getFilenameFromPath(mModelFilePath)); + } + mTokenizerFilePath = mSettingsFields.getTokenizerFilePath(); + if (!mTokenizerFilePath.isEmpty()) { + mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + } + + setupParameterSettings(); + setupPromptSettings(); + setupClearChatHistoryButton(); + setupLoadModelButton(); + } + + private void setupLoadModelButton() { + mLoadModelButton = requireViewById(R.id.loadModelButton); + mLoadModelButton.setEnabled(true); + mLoadModelButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Load Model") + .setMessage("Do you really want to load the new model?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + mSettingsFields.saveLoadModelAction(true); + 
mLoadModelButton.setEnabled(false); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupClearChatHistoryButton() { + Button clearChatButton = requireViewById(R.id.clearChatButton); + clearChatButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Delete Chat History") + .setMessage("Do you really want to delete chat history?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + mSettingsFields.saveIsClearChatHistory(true); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupParameterSettings() { + setupTemperatureSettings(); + } + + private void setupTemperatureSettings() { + mSetTemperature = mSettingsFields.getTemperature(); + EditText temperatureEditText = requireViewById(R.id.temperatureEditText); + temperatureEditText.setText(String.valueOf(mSetTemperature)); + temperatureEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + mSetTemperature = Double.parseDouble(s.toString()); + // This is needed because temperature is changed together with model loading + // Once temperature is no longer in LlamaModule constructor, we can remove this + mSettingsFields.saveLoadModelAction(true); + saveSettings(); + } + }); + } + + private void setupPromptSettings() { + setupSystemPromptSettings(); + setupUserPromptSettings(); + } + + private void setupSystemPromptSettings() { + mSystemPrompt = mSettingsFields.getSystemPrompt(); + mSystemPromptEditText.setText(mSystemPrompt); + mSystemPromptEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + mSystemPrompt = s.toString(); + } + }); + + ImageButton resetSystemPrompt = requireViewById(R.id.resetSystemPrompt); + resetSystemPrompt.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Reset System Prompt") + .setMessage("Do you really want to reset system prompt?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + mSystemPromptEditText.setText(mSettingsFields.getSystemPromptTemplate()); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupUserPromptSettings() { + mUserPrompt = mSettingsFields.getUserPrompt(); + mUserPromptEditText.setText(mUserPrompt); + mUserPromptEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + mUserPrompt = s.toString(); + } + }); + + ImageButton resetUserPrompt = requireViewById(R.id.resetUserPrompt); + resetUserPrompt.setOnClickListener( + view -> { 
+ new AlertDialog.Builder(this) + .setTitle("Reset Prompt Template") + .setMessage("Do you really want to reset the prompt template?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + mUserPromptEditText.setText(mSettingsFields.getUserPromptTemplate()); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupModelSelectorDialog() { + String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte"); + AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); + modelPathBuilder.setTitle("Select model path"); + + modelPathBuilder.setSingleChoiceItems( + pteFiles, + -1, + (dialog, item) -> { + mModelFilePath = pteFiles[item]; + mModelTextView.setText(getFilenameFromPath(mModelFilePath)); + mLoadModelButton.setEnabled(true); + dialog.dismiss(); + }); + + modelPathBuilder.create().show(); + } + + private static String[] listLocalFile(String path, String suffix) { + File directory = new File(path); + if (directory.exists() && directory.isDirectory()) { + File[] files = directory.listFiles((dir, name) -> name.toLowerCase().endsWith(suffix)); + String[] result = new String[files.length]; + for (int i = 0; i < files.length; i++) { + if (files[i].isFile() && files[i].getName().endsWith(suffix)) { + result[i] = files[i].getAbsolutePath(); + } + } + return result; + } + return null; + } + + private void setupTokenizerSelectorDialog() { + String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin"); + String[] tokenizerFiles = new String[binFiles.length]; + System.arraycopy(binFiles, 0, tokenizerFiles, 0, binFiles.length); + AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); + tokenizerPathBuilder.setTitle("Select tokenizer path"); + tokenizerPathBuilder.setSingleChoiceItems( + tokenizerFiles, + -1, + (dialog, item) -> { + mTokenizerFilePath = tokenizerFiles[item]; + mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + mLoadModelButton.setEnabled(true); + dialog.dismiss(); + }); + + tokenizerPathBuilder.create().show(); + } + + private String getFilenameFromPath(String uriFilePath) { + String[] segments = uriFilePath.split("/"); + if (segments.length > 0) { + return segments[segments.length - 1]; // get last element (aka filename) + } + return ""; + } + + private void loadSettings() { + Gson gson = new Gson(); + String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); + if (!settingsFieldsJSON.isEmpty()) { + mSettingsFields = gson.fromJson(settingsFieldsJSON, SettingsFields.class); + } + } + + private void saveSettings() { + mSettingsFields.saveModelPath(mModelFilePath); + mSettingsFields.saveTokenizerPath(mTokenizerFilePath); + mSettingsFields.saveParameters(mSetTemperature); + mSettingsFields.savePrompts(mSystemPrompt, mUserPrompt); + mDemoSharedPreferences.addSettings(mSettingsFields); + } + + @Override + public void onBackPressed() { + super.onBackPressed(); + saveSettings(); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java new file mode 100644 index 0000000000..d42a241293 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java @@ 
-0,0 +1,135 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public class SettingsFields { + private static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; + private static final String USER_PLACEHOLDER = "{{ user_prompt }}"; + private static String SYSTEM_PROMPT_TEMPLATE = + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" + + SYSTEM_PLACEHOLDER + + "<|eot_id|>"; + private static String USER_PROMPT_TEMPLATE = + "<|start_header_id|>user<|end_header_id|>\n" + + USER_PLACEHOLDER + + "<|eot_id|>\n" + + "<|start_header_id|>assistant<|end_header_id|>"; + + public String getModelFilePath() { + return modelFilePath; + } + + public String getTokenizerFilePath() { + return tokenizerFilePath; + } + + public double getTemperature() { + return temperature; + } + + public String getSystemPrompt() { + return systemPrompt; + } + + public String getUserPrompt() { + return userPrompt; + } + + public String getEntirePrompt() { + return systemPrompt + userPrompt; + } + + public String getSystemPromptTemplate() { + return SYSTEM_PROMPT_TEMPLATE; + } + + public String getUserPromptTemplate() { + return USER_PROMPT_TEMPLATE; + } + + public boolean getIsClearChatHistory() { + return isClearChatHistory; + } + + public boolean getIsLoadModel() { + return isLoadModel; + } + + private String modelFilePath; + private String tokenizerFilePath; + private double temperature; + private String systemPrompt; + private String userPrompt; + private boolean isClearChatHistory; + private boolean isLoadModel; + + public SettingsFields() { + modelFilePath = ""; + tokenizerFilePath = ""; + temperature = SettingsActivity.TEMPERATURE_MIN_VALUE; + systemPrompt = SYSTEM_PROMPT_TEMPLATE; + userPrompt = USER_PROMPT_TEMPLATE; + isClearChatHistory = false; + isLoadModel = false; + } + + public SettingsFields(SettingsFields settingsFields) { + this.modelFilePath = settingsFields.modelFilePath; + this.tokenizerFilePath = settingsFields.tokenizerFilePath; + this.temperature = settingsFields.temperature; + this.systemPrompt = settingsFields.getSystemPrompt(); + this.userPrompt = settingsFields.getUserPrompt(); + this.isClearChatHistory = settingsFields.getIsClearChatHistory(); + this.isLoadModel = settingsFields.getIsLoadModel(); + } + + public void saveModelPath(String modelFilePath) { + this.modelFilePath = modelFilePath; + } + + public void saveTokenizerPath(String tokenizerFilePath) { + this.tokenizerFilePath = tokenizerFilePath; + } + + public void saveParameters(Double temperature) { + this.temperature = temperature; + } + + public void savePrompts(String systemPrompt, String userPrompt) { + this.systemPrompt = systemPrompt; + this.userPrompt = userPrompt; + } + + public void saveIsClearChatHistory(boolean needToClear) { + this.isClearChatHistory = needToClear; + } + + public void saveLoadModelAction(boolean shouldLoadModel) { + this.isLoadModel = shouldLoadModel; + } + + public boolean equals(SettingsFields anotherSettingsFields) { + if (this == anotherSettingsFields) return true; + return modelFilePath.equals(anotherSettingsFields.modelFilePath) + && tokenizerFilePath.equals(anotherSettingsFields.tokenizerFilePath) + && temperature == anotherSettingsFields.temperature + && systemPrompt.equals(anotherSettingsFields.systemPrompt) + && userPrompt.equals(anotherSettingsFields.userPrompt) + 
&& isClearChatHistory == anotherSettingsFields.isClearChatHistory + && isLoadModel == anotherSettingsFields.isLoadModel; + } + + public boolean isSystemPromptChanged() { + return !systemPrompt.contains(SYSTEM_PLACEHOLDER); + } + + public boolean isUserPromptChanged() { + return !userPrompt.contains(USER_PLACEHOLDER); + } +} diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml new file mode 100644 index 0000000000..70f251ee64 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml new file mode 100644 index 0000000000..9f83b8fbe7 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml new file mode 100644 index 0000000000..d710d27110 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml new file mode 100644 index 0000000000..30d5d26b98 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml new file mode 100644 index 0000000000..f8ca0c64b9 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml new file mode 100644 index 0000000000..2c71fc6e56 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml new file mode 100644 index 0000000000..9285db079a --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml @@ -0,0 +1,6 @@ + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml new file mode 100644 index 0000000000..3abc6cb33b --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml @@ -0,0 +1,5 @@ + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml new file mode 
100644 index 0000000000..42593b298e --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml @@ -0,0 +1,10 @@ + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml new file mode 100644 index 0000000000..817d57b76a --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/btn.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/btn.xml new file mode 100644 index 0000000000..ceb3ac56c9 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/btn.xml @@ -0,0 +1,8 @@ + + + + + + + \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml new file mode 100644 index 0000000000..87c82d2a38 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml new file mode 100644 index 0000000000..15c404c60d --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml @@ -0,0 +1,10 @@ + + + + + + + \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/logo.png b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/logo.png new file mode 100644 index 0000000000..60e3e5174e Binary files /dev/null and b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/logo.png differ diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml new file mode 100644 index 0000000000..c8b2c96d58 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_arrow_drop_down_circle_24.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_arrow_drop_down_circle_24.xml new file mode 100644 index 0000000000..a8c859d8b3 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_arrow_drop_down_circle_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml new file mode 100644 index 0000000000..c7b4b2e4a1 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml new file mode 100644 index 0000000000..a8bb4b2f64 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git 
a/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml new file mode 100644 index 0000000000..1627ed98c0 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml new file mode 100644 index 0000000000..b327a544f2 --- /dev/null +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml index 089acb572b..ec215e63ba 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml @@ -1,44 +1,237 @@ - - + + + + + + + + + - + + + + + + -