Qualcomm AI Engine Direct - Optimization and fix mutable buffer issue #5072

Merged
35 changes: 0 additions & 35 deletions .github/workflows/pull.yml
@@ -372,38 +372,3 @@ jobs:

# Run pytest with coverage
pytest -c /dev/null -v -n auto --cov=./ --cov-report=xml backends/arm/test


test-llama-runner-qnn-linux:
name: test-llama-runner-qnn-linux
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
strategy:
matrix:
dtype: [fp32]
build-tool: [cmake]
mode: [qnn]
fail-fast: false
with:
runner: linux.2xlarge
docker-image: executorch-ubuntu-22.04-clang12-android
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 900
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"

DTYPE=${{ matrix.dtype }}
BUILD_TOOL=${{ matrix.build-tool }}
MODE=${{ matrix.mode }}

PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh

# Setup executorch
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh buck2
# Install requirements for export_llama
PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh
# Test llama2
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -38,6 +38,7 @@
op_quantize,
op_relu,
op_reshape,
op_rms_norm,
op_rsqrt,
op_select_copy,
op_sigmoid,
@@ -92,6 +93,7 @@
op_quantize,
op_relu,
op_reshape,
op_rms_norm,
op_rsqrt,
op_select_copy,
op_sigmoid,
2 changes: 1 addition & 1 deletion backends/qualcomm/builders/node_visitor.py
@@ -202,7 +202,7 @@ def get_quant_tensor_value(

dtype = quant_configs[QCOM_DTYPE]

tensor = tensor.div(scale + 1e-6).add(zero_point).round().to(dtype)
tensor = tensor.div(scale).add(zero_point).round().to(dtype)
# Make the backends access data correctly
if quant_configs.get(QCOM_BITWIDTH) == 4:
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
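For context, a standalone sketch of the affine quantization this line performs, with made-up values (not the backend's API). Dividing by scale + 1e-6 instead of scale shrinks every quantized value toward the zero point, and the skew grows as the scale gets small:

import torch

# Affine quantization as in get_quant_tensor_value, on assumed values.
# The old epsilon fudge divides by a slightly larger number, so every
# quantized value lands slightly low; with a small scale the error is large.
scale, zero_point = 1e-5, 0
x = torch.tensor([1e-4, 2e-4])
skewed = x.div(scale + 1e-6).add(zero_point).round()  # tensor([ 9., 18.])
exact = x.div(scale).add(zero_point).round()          # tensor([10., 20.])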
86 changes: 24 additions & 62 deletions backends/qualcomm/builders/op_conv2d.py
@@ -10,16 +10,7 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import (
QCOM_DATA,
QCOM_DTYPE,
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_SCALE,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import (
@@ -94,52 +85,6 @@ def _add_conv_op_parameter(

return conv_op

def _get_bias_tensor(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
num_output_channel: int,
) -> PyQnnWrapper.PyQnnOpWrapper:
# build dummy node if bias is not given
bias_node = (
node.args[2]
if node.args[2] is not None
else torch.fx.Node(
node.graph,
node.name + "_runtime_bias",
"call_function",
exir_ops.edge.aten.full.default,
(), # args
{}, # kwargs
)
)
# zeros tensor to meet HTP constraint if bias is not given
bias_tensor = (
get_parameter(bias_node, self.edge_program)
if node.args[2] is not None
else torch.zeros(num_output_channel)
)
# insert quant attribute to meet HTP constraint if bias is not given
if (
node.args[2] is None
and (bias_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS)) is not None
):
quant_attrs = bias_quant_attrs.copy()
quant_attrs[QCOM_ZERO_POINT] = 0
quant_attrs[QCOM_SCALE] = 0
quant_attrs[QCOM_DTYPE] = torch.int32
quant_attrs[QCOM_QUANT_MAX] = torch.iinfo(torch.int32).max
quant_attrs[QCOM_QUANT_MIN] = torch.iinfo(torch.int32).min + 1
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs

return self.define_tensor(
bias_node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)

def _define_conv1d(
self,
node: torch.fx.Node,
@@ -204,9 +149,17 @@ def _define_conv1d(
is_input_tensor=False,
)
conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper]
conv_input_tensors.append(
self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
)
if node.args[2] is not None:
bias_node = node.args[2]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)
conv_input_tensors.append(bias_tensor_wrapper)

stride = [1] + cast(List[int], node.args[3])
padding = [0] + cast(List[int], node.args[4])
@@ -312,9 +265,18 @@ def define_node(
is_input_tensor=False,
)
conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
conv_input_tensors.append(
self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
)

if node.args[2] is not None:
bias_node = node.args[2]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)
conv_input_tensors.append(bias_tensor_wrapper)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
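As a sanity check on dropping the dummy-bias helper, a standalone sketch with hypothetical shapes: a zero bias and no bias are numerically identical, so appending a bias wrapper only when args[2] is present preserves results while sparing QNN a dummy static tensor:

import torch

# The aten convolution node indexed above is laid out as
# args = (input, weight, bias, stride, padding, ...), where bias may be None.
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
with_zero_bias = torch.nn.functional.conv2d(x, w, bias=torch.zeros(4))
without_bias = torch.nn.functional.conv2d(x, w, bias=None)
assert torch.allclose(with_zero_bias, without_bias)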
127 changes: 127 additions & 0 deletions backends/qualcomm/builders/op_rms_norm.py
@@ -0,0 +1,127 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpRmsNorm, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class RmsNormVisitor(NodeVisitor):
target = ["aten.rms_norm.default"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
# args of node : ['input', 'normalized_shape', 'weight', 'eps']
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)

# should be an immutable list
normalized_shapes = node.args[1]
if (
len(normalized_shapes) != 1
or normalized_shapes[0] != input_tensor.shape[-1]
):
print("Only normalization over the last input dimension is supported")
return
axes = [node.args[0].meta["val"].dim() - 1]
axes_shape = [len(axes)]

weight_node = node.args[2]
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)

# Fake node; the nn module seems to be inconsistent with the documentation
bias_tensor = torch.zeros(weight_tensor.shape)
bias_node = torch.fx.Node(
node.graph,
node.name + "_runtime_bias",
"call_function",
exir_ops.edge.aten.tensor.default,
(), # args
{}, # kwargs
)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
bias_tensor_wrapper = self.define_tensor(
bias_node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)

epsilon = node.args[3]
if isinstance(epsilon, torch.fx.Node):
epsilon = get_parameter(epsilon, self.edge_program)
epsilon = (
epsilon
if isinstance(epsilon, float)
else torch.finfo(epsilon.dtype).eps
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)

rms_norm_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpRmsNorm.op_name,
)

rms_norm_op.AddInputTensors(
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
)
rms_norm_op.AddOutputTensors([output_tensor_wrapper])
rms_norm_op.AddScalarParam(
OpRmsNorm.param_epsilon,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
{QCOM_DATA: np.float32(epsilon)},
)
rms_norm_op.AddTensorParam(
OpRmsNorm.param_axes,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(axes_shape),
axes_shape,
np.array(axes, dtype=np.uint32),
True,
)

return rms_norm_op
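For reference, a hedged sketch of the semantics this visitor maps onto QNN's RmsNorm (assuming PyTorch >= 2.4 for torch.nn.functional.rms_norm): RMS-normalize the last dimension and scale by weight; the fabricated bias contributes nothing since it is all zeros:

import torch

# Reference semantics, not the QNN kernel: normalize by the root mean
# square over the last dimension, then scale by weight. The zero bias the
# visitor fabricates is omitted here because it adds nothing.
def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

x = torch.randn(2, 8)
w = torch.randn(8)
assert torch.allclose(rms_norm_ref(x, w, 1e-6),
                      torch.nn.functional.rms_norm(x, (8,), w, 1e-6))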
7 changes: 7 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -278,6 +278,13 @@ class OpResizeNearestNeighbor:
param_half_pixel_centers: str = "half_pixel_centers"


@dataclass(init=False, frozen=True)
class OpRmsNorm:
op_name: str = "RmsNorm"
param_epsilon: str = "epsilon"
param_axes: str = "axes"


@dataclass(init=False, frozen=True)
class OpScatterNd:
op_name: str = "ScatterNd"
1 change: 1 addition & 0 deletions backends/qualcomm/passes/annotate_and_quant_scalar.py
@@ -78,6 +78,7 @@ def _annotate_scalar_node(
float,
torch.float32,
torch.int32,
torch.int64,
]:
return

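As context for the new entry, a hedged sketch of where int64 scalars come from: a bare Python literal captured at export time carries int64 metadata, which the allow-list above previously rejected, so such scalars were skipped rather than annotated:

import torch

# Hypothetical module: the literal 1 is captured as an integer scalar whose
# value metadata is int64; with torch.int64 allowed, the pass can annotate
# and quantize it instead of returning early.
class AddOne(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

ep = torch.export.export(AddOne(), (torch.randn(4),))
print(ep.graph)  # the scalar 1 appears as an operand of aten.add.Tensor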
24 changes: 24 additions & 0 deletions backends/qualcomm/passes/i64_to_i32.py
@@ -5,7 +5,9 @@
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._subclasses.fake_tensor import FakeTensor


class I64toI32(ExportPass):
@@ -16,6 +18,8 @@ class I64toI32(ExportPass):
def __init__(self, edge_program: torch.export.ExportedProgram):
super(I64toI32, self).__init__()
self.edge_program = edge_program
# pyre-ignore[4]
self.copy_op = exir_ops.edge.aten._to_copy.default

def _update_meta(self, node: torch.fx.node) -> None:
meta_val = node.meta["val"]
@@ -32,13 +36,33 @@ def _update_meta(self, node: torch.fx.node) -> None:
if meta_val.dtype == torch.int64:
node.meta["val"] = meta_val.to(torch.float)

# pyre-ignore[2]
def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
for n in graph_module.graph.nodes:
if is_constant(n, self.edge_program):
param = get_parameter(n, self.edge_program)
if param.dtype == torch.int64:
# QNN does not support int64
self._update_meta(n)
elif n.op == "placeholder":
node_val = n.meta["val"]
if self._is_tensor_of_dtype(node_val, torch.int64):
with graph_module.graph.inserting_after(n):
args = (n,)
to_dst_node = graph_module.graph.create_node(
"call_function",
self.copy_op,
args,
{"dtype": torch.int32},
)
to_dst_node.meta["val"] = node_val.to(torch.int32)

# Replace usage of the src dtype result with the dst dtype result.
n.replace_all_uses_with(to_dst_node)
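# replace_all_uses_with also rewires the cast's own input, so restore it.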
to_dst_node.args = (n,)

def call(self, graph_module: torch.fx.GraphModule):
self._cast_to_int32(graph_module)
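To illustrate the new placeholder handling, a hedged sketch (hypothetical module, not from this PR) of the rewrite the pass performs, since QNN has no int64 support:

import torch

# An int64 graph input such as token ids. Before the pass: ids -> add.
# After the pass, every original user reads from an int32 cast instead:
# ids -> _to_copy(dtype=int32) -> add.
class TokenShift(torch.nn.Module):
    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return ids + 1

ep = torch.export.export(TokenShift(), (torch.arange(4),))  # arange yields int64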