2024-09-04 nightly release (a4092c5)
pytorchbot committed Sep 4, 2024
1 parent c36acd4 commit 5793e3d
Showing 94 changed files with 2,170 additions and 297 deletions.
3 changes: 3 additions & 0 deletions .ci/scripts/test.sh
@@ -170,6 +170,9 @@ test_model_with_qnn() {
   elif [[ "${MODEL_NAME}" == "ic3" ]]; then
     EXPORT_SCRIPT=inception_v3
     EXPORTED_MODEL_NAME=ic3_qnn.pte
+  elif [[ "${MODEL_NAME}" == "vit" ]]; then
+    EXPORT_SCRIPT=torchvision_vit
+    EXPORTED_MODEL_NAME=vit_qnn.pte
   fi

   "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m SM8550 --compile_only
2 changes: 1 addition & 1 deletion .github/workflows/android-perf.yml
@@ -84,7 +84,7 @@ jobs:
           # Separate default values from the workflow dispatch. To ensure defaults are accessible
           # during scheduled runs and to provide flexibility for different defaults between
           # on-demand and periodic benchmarking.
-          CRON_DEFAULT_MODELS: "stories110M,dl3,mv3,mv2,ic4,ic3"
+          CRON_DEFAULT_MODELS: "stories110M,dl3,mv3,mv2,ic4,ic3,vit"
           CRON_DEFAULT_DEVICES: "samsung_galaxy_s2x"
           CRON_DEFAULT_DELEGATES: "xnnpack,qnn"
         run: |
2 changes: 1 addition & 1 deletion .github/workflows/trunk.yml
@@ -305,7 +305,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        model: [dl3, mv3, mv2, ic4, ic3]
+        model: [dl3, mv3, mv2, ic4, ic3, vit]
       fail-fast: false
     with:
       runner: linux.2xlarge
2 changes: 2 additions & 0 deletions backends/arm/arm_partitioner.py
@@ -45,6 +45,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.convolution.default,
             exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten.exp.default,
+            exir_ops.edge.aten.log.default,
             exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten.mul.Tensor,
2 changes: 2 additions & 0 deletions backends/arm/operators/__init__.py
@@ -14,9 +14,11 @@
     op_conv2d,
     op_dequant,
     op_div,
+    op_exp,
     op_full,
     op_get_item,
     op_hardtanh,
+    op_log,
     op_mean_dim,
     op_mm,
     op_mul,
2 changes: 1 addition & 1 deletion backends/arm/operators/op_conv2d.py
@@ -40,7 +40,7 @@ def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):

         if mod_remainder > pad:
             raise RuntimeError(
-                f"ignoring input element is not currently supported, got a large stride {stride}"
+                "This case should be handled by the SizeAdjustConv2d pass, is it enabled?"
             )
         return pad - mod_remainder

81 changes: 81 additions & 0 deletions backends/arm/operators/op_exp.py
@@ -0,0 +1,81 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_node_args,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class ExpVisitor(NodeVisitor):
    target = "aten.exp.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:

        assert len(node.all_input_nodes) == 1
        assert len(node.users) == 1

        if is_quant_node:
            # Assume quantized input is 8 bit.

            # Create attribute for 8 bit table lookup.
            input_node = node.all_input_nodes[0]
            in_quantargs = get_quant_node_args(input_node)
            output_node = list(node.users)[0]
            out_quantargs = get_quant_node_args(output_node)

            table = exp_table_8bit(in_quantargs, out_quantargs)
            table_attr = ts.TosaSerializerAttribute()
            table_attr.TableAttribute(table)

            tosa_graph.addOperator(
                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
            )
        else:
            tosa_graph.addOperator(TosaOp.Op().EXP, [inputs[0].name], [output.name])


def exp_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Returns a table mapping 256 entries to exp([qmin,qmax])
    """

    def exp(x):
        # Convert quantized input to floating point exp input space.
        v = dequantize_value(x, in_quantargs)
        # Compute exp.
        v = np.exp(v)
        # Convert exp output back to quantized space.
        return quantize_value(v, out_quantargs)

    return [
        exp(x)
        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
    ]
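For context: in the quantized path above, exp is lowered to a single TOSA TABLE lookup, since an int8 input can only take 256 values. Below is a minimal self-contained sketch of that idea (reused for log in the next file), assuming plain affine int8 quantization; the scale and zero-point values are made up for illustration, whereas the real ones come from the surrounding quantize/dequantize nodes via get_quant_node_args:

```python
import numpy as np

# Hypothetical affine int8 quantization parameters (scale, zero_point);
# in the backend these are read from the q/dq nodes around aten.exp.
in_scale, in_zp = 0.05, 0
out_scale, out_zp = 0.1, -128

def dequant(q):
    # int8 -> float
    return (float(q) - in_zp) * in_scale

def quant(v):
    # float -> int8, rounded and clamped to the int8 range
    return int(np.clip(round(v / out_scale) + out_zp, -128, 127))

# One table entry per possible int8 input value; at runtime the TOSA
# TABLE op turns every quantized exp into a single lookup.
qmin, qmax = -128, 127
table = [quant(np.exp(dequant(x))) for x in range(qmin, qmax + 1)]
assert len(table) == 256
```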
81 changes: 81 additions & 0 deletions backends/arm/operators/op_log.py
@@ -0,0 +1,81 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_node_args,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class LogVisitor(NodeVisitor):
    target = "aten.log.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:

        assert len(node.all_input_nodes) == 1
        assert len(node.users) == 1

        if is_quant_node:
            # Assume quantized input is 8 bit.

            # Create attribute for 8 bit table lookup.
            input_node = node.all_input_nodes[0]
            in_quantargs = get_quant_node_args(input_node)
            output_node = list(node.users)[0]
            out_quantargs = get_quant_node_args(output_node)

            table = log_table_8bit(in_quantargs, out_quantargs)
            table_attr = ts.TosaSerializerAttribute()
            table_attr.TableAttribute(table)

            tosa_graph.addOperator(
                TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
            )
        else:
            tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name])


def log_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Returns a table mapping 256 entries to log([qmin,qmax])
    """

    def log(x):
        # Convert quantized input to floating point log input space.
        v = dequantize_value(x, in_quantargs)
        # Compute log.
        v = np.log(v)
        # Convert log output back to quantized space.
        return quantize_value(v, out_quantargs)

    return [
        log(x)
        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
    ]
4 changes: 3 additions & 1 deletion backends/arm/passes/annotate_channels_last_dim_order_pass.py
@@ -46,7 +46,9 @@ def call(self, graph_module: torch.fx.GraphModule):
         NHWC_Order = (0, 2, 3, 1)
         HWCM_Order = (2, 3, 0, 1)
         for node in graph_module.graph.nodes:
-            if isinstance(node.meta["val"], tuple):
+            if isinstance(
+                node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+            ):
                 node_data = node.meta["val"][0].data
             else:
                 node_data = node.meta["val"].data
2 changes: 2 additions & 0 deletions backends/arm/passes/arm_pass_manager.py
@@ -16,6 +16,7 @@
     ConvertSplitToSlicePass,
 )
 from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
+from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager

@@ -29,6 +30,7 @@ def transform_to_backend_pipeline(
         self, graph_module: torch.fx.Graph, compile_spec: CompileSpec
     ):
         """Apply passes before transforming program to backend"""
+        self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertSplitToSlicePass())
129 changes: 129 additions & 0 deletions backends/arm/passes/size_adjust_conv2d_pass.py
@@ -0,0 +1,129 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Optional

import torch.fx
from executorch.backends.arm.tosa_quant_utils import is_quant_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload


def conv_remainder(input_length, pad, dilation, weight, stride):
    """
    Returns the number of trailing input elements along a dimension that the
    convolution does not consume, given the padding, dilation, kernel (weight)
    size and stride.
    """
    return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
):
    with graph.inserting_after(anchor):
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # We add the argument last
        )
        q.meta = anchor.meta

    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q,) + q_params,
        )
        dq.meta = q.meta

    anchor.replace_all_uses_with(dq)
    # We set q's arguments last so that the replace_all_uses_with call above
    # does not also replace q's own use of the anchor node.
    q.args = (anchor,) + q_params
    return dq


def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
):
    return graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=kwargs or {},
    )


class SizeAdjustConv2DPass(ExportPass):
    """
    Adjust the convolution input size to match perfectly with the
    weight size, padding, stride and dilation parameters.
    This is done by inserting a slice op to remove the uneven end of the input.
    """

    conv2d_op = exir_ops.edge.aten.convolution.default
    slice_op = exir_ops.edge.aten.slice_copy.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        modified_graph = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != self.conv2d_op:
                continue

            conv_node = cast(torch.fx.Node, node)
            input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
                conv_node.args
            )
            weight_shape = weight.meta["val"].shape
            input_shape = input_node.meta["val"].shape

            slice_args = []
            for stride, pad, dilation, dim in zip(
                cast(list, stride_hw),
                cast(list, pad_hw),
                cast(list, dilation_hw),
                (2, 3),
            ):
                remainder = conv_remainder(
                    input_shape[dim], pad, dilation, weight_shape[dim], stride
                )
                if remainder > pad:
                    adjustment = remainder - pad
                    args = (dim, 0, input_shape[dim] - adjustment)
                    slice_args.append(args)
            if len(slice_args) == 0:
                continue

            with graph_module.graph.inserting_before(node):
                last_node = cast(torch.fx.Node, input_node)
                for args in slice_args:
                    slice_node = graph.create_node(
                        "call_function", self.slice_op, (last_node,) + args
                    )
                    if is_quant_node(last_node):
                        q_params = last_node.args[1:]
                        dq_node = insert_q_dq_pair(
                            graph_module.graph, slice_node, q_params
                        )
                        last_node = dq_node
                    else:
                        last_node = slice_node
                conv_node.replace_input_with(input_node, last_node)
                modified_graph = True

        if modified_graph:
            graph_module = super().call(graph_module).graph_module
            graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, True)
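For intuition, a small worked example of what conv_remainder computes and when the pass slices; the concrete sizes below are illustrative and not taken from the commit:

```python
# A convolution's output positions cover input indices up to
# (n_out - 1) * stride + dilation * (kernel - 1); any trailing input
# beyond that is never read, and the pass slices it off.
def conv_remainder(input_length, pad, dilation, weight, stride):
    return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride

# input_length=8, pad=0, dilation=1, kernel=3, stride=2:
# (8 + 0 - 2 - 1) % 2 = 5 % 2 = 1, and 1 > pad, so the pass inserts
# slice_copy(input, dim, 0, 8 - 1) before the convolution.
assert conv_remainder(8, 0, 1, 3, 2) == 1

# input_length=7 fits perfectly: (7 - 2 - 1) % 2 = 0, so no slice is needed.
assert conv_remainder(7, 0, 1, 3, 2) == 0
```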
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -268,6 +268,7 @@ class ArmQuantizer(Quantizer):
         "sigmoid",
         "mm",
         "cat",
+        "one_to_one",
     ]

     def __init__(self) -> None:
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotation/__init__.py
@@ -55,6 +55,7 @@ def decorator(annotator: AnnotatorType):
     max_pool2d_annotator,
     mm_annotator,
     mul_annotator,
+    one_to_one_annotator,
     sigmoid_annotator,
     sub_annotator,
 )