Skip to content

Commit

Permalink
Buckify backends/arm for meta internal use.
Browse files Browse the repository at this point in the history
Differential Revision: D62062674

Pull Request resolved: #5023
  • Loading branch information
hsharma35 authored Sep 5, 2024
1 parent cd1c833 commit a8c592e
Show file tree
Hide file tree
Showing 21 changed files with 232 additions and 42 deletions.
83 changes: 83 additions & 0 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "arm_partitioner",
srcs = [
"arm_partitioner.py",
],
typing = True,
deps = [
":arm_backend",
"//executorch/backends/arm/passes:passes",
"//executorch/exir:lib",
],
)

python_library(
name = "arm_backend",
srcs = [
"arm_backend.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
"fbsource//third-party/serialization_lib/python/serializer:serializer",
"fbsource//third-party/serialization_lib/python/tosa:tosa",
":arm_vela",
"//executorch/backends/arm/operators:lib",
"//executorch/backends/arm/operators:node_visitor",
"//executorch/backends/arm/passes:passes",
],
)

python_library(
name = "arm_vela",
srcs = [
"arm_vela.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
],
)

python_library(
name = "tosa_mapping",
srcs = [
"tosa_mapping.py",
],
typing = True,
deps = [
"fbsource//third-party/serialization_lib/python/serializer:serializer",
"//caffe2:torch",
],
)

python_library(
name = "tosa_quant_utils",
srcs = [
"tosa_quant_utils.py",
],
typing = True,
deps = [
"fbsource//third-party/pypi/numpy:numpy",
"fbsource//third-party/serialization_lib/python/serializer:serializer",
"fbsource//third-party/serialization_lib/python/tosa:tosa",
":tosa_mapping",
"//executorch/exir/dialects:lib",
],
)

python_library(
name = "tosa_utils",
srcs = [
"tosa_utils.py",
],
typing = True,
deps = [
"fbsource//third-party/serialization_lib/python/serializer:serializer",
":tosa_quant_utils",
"//executorch/backends/arm/operators:node_visitor",
],
)
2 changes: 1 addition & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
return False


def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
for spec in compile_spec:
if spec.key == "debug_artifact_path":
return spec.value.decode()
Expand Down
22 changes: 9 additions & 13 deletions backends/arm/arm_vela.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

import os
import struct
import subprocess
import tempfile

from typing import List

import numpy as np
from ethosu.vela import vela


# Pack either input or output tensor block, compose the related arrays into
Expand Down Expand Up @@ -38,21 +38,17 @@ def vela_compile(tosa_graph, args: List[str]):
with tempfile.TemporaryDirectory() as tmpdir:
tosaname = "out.tosa"
flatbuffer = tosa_graph.serialize()
with open(os.path.join(tmpdir, tosaname), "wb") as f:
tosa_path = os.path.join(tmpdir, tosaname)
with open(tosa_path, "wb") as f:
f.write(flatbuffer)

# invoke vela
vela_command = f"cd {tmpdir}; vela {' '.join(args)} {tosaname}"
try:
subprocess.run([vela_command], shell=True, check=True, capture_output=True)
except subprocess.CalledProcessError as process_error:
raise RuntimeError(
f"Vela compiler ('{vela_command}') failed with error:\n \
{process_error.stderr.decode()}\n \
Stdout:\n{process_error.stdout.decode()}"
)

np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
output_dir = os.path.join(tmpdir, "output")
args.append(f"--output-dir={output_dir}")
args.append(tosa_path)
vela.main(" ".join(args).split(" "))

np_path = os.path.join(output_dir, "out_sg0_vela.npz")
blocks = b""

with np.load(np_path, allow_pickle=False) as data:
Expand Down
34 changes: 34 additions & 0 deletions backends/arm/operators/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "node_visitor",
srcs = ["node_visitor.py"],
typing = True,
deps = [
"//executorch/backends/arm:tosa_mapping",
],
)

python_library(
name = "ops",
srcs = glob(["op_*.py"]),
typing = True,
deps = [
"fbsource//third-party/serialization_lib/python/tosa:tosa",
":node_visitor",
"//executorch/backends/arm:tosa_mapping",
"//executorch/backends/arm:tosa_quant_utils",
"//executorch/backends/arm:tosa_utils",
"//executorch/exir:lib",
],
)

python_library(
name = "lib",
srcs = ["__init__.py"],
typing = True,
deps = [
":node_visitor",
":ops",
],
)
1 change: 1 addition & 0 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def define_node(
build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
input_node=bmm_result,
output_name=output.name,
output_type=ts.DType.INT8,
Expand Down
7 changes: 4 additions & 3 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from typing import cast, List

import serializer.tosa_serializer as ts
import torch
Expand Down Expand Up @@ -156,11 +156,12 @@ def define_node(
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
# Get scale_factor from input, weight, and output.
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
build_rescale_conv_output(
tosa_graph,
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
conv2d_res,
output.name,
actual_out_type,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/op_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def define_node(
build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
# pyre-ignore[61]: Uninitialized local [61]: Local variable `reshape_intermediate` is undefined, or not always defined.
input_node=reshape_intermediate,
output_name=output.name,
output_type=ts.DType.INT8,
Expand Down
10 changes: 7 additions & 3 deletions backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List
from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
Expand Down Expand Up @@ -35,8 +35,12 @@ def define_node(
if is_quant_node:
input_A = inputs[0]
input_B = inputs[1]
input_A_qargs = tqutils.get_quant_node_args(node.args[0])
input_B_qargs = tqutils.get_quant_node_args(node.args[1])
input_A_qargs = tqutils.get_quant_node_args(
cast(torch.fx.Node, node.args[0])
)
input_B_qargs = tqutils.get_quant_node_args(
cast(torch.fx.Node, node.args[1])
)

input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/op_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import serializer.tosa_serializer as ts
import torch

Expand All @@ -11,7 +13,7 @@ def process_output(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
):
for output in node.args[0]:
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
tosa_graph.addOutputTensor(
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import serializer.tosa_serializer as ts
import torch
import tosa.Op as TosaOp

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@register_node_visitor
Expand Down
12 changes: 12 additions & 0 deletions backends/arm/passes/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "passes",
srcs = glob(["*.py"]),
typing = True,
deps = [
"//executorch/backends/arm:tosa_quant_utils",
"//executorch/backends/arm:tosa_utils",
"//executorch/exir:lib",
],
)
4 changes: 3 additions & 1 deletion backends/arm/passes/annotate_channels_last_dim_order_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
Expand All @@ -28,7 +30,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
if node.target != dq_op:
return False
prev_node = node.args[0]
if prev_node.op != "placeholder":
if cast(torch.fx.Node, prev_node).op != "placeholder":
return False
return is_consumer_node_depthwise_conv2d(node)
elif node.op == "placeholder":
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@

class ArmPassManager(PassManager):

def _transform(self, graph_module: torch.fx.Graph):
def _transform(self, graph_module: torch.fx.GraphModule):
return self(graph_module).graph_module

def transform_to_backend_pipeline(
self, graph_module: torch.fx.Graph, compile_spec: CompileSpec
self, graph_module: torch.fx.GraphModule, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
self.add_pass(SizeAdjustConv2DPass())
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch.fx
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -31,7 +33,7 @@ def call(self, graph_module: torch.fx.GraphModule):

expand_node = src_partition.nodes[0]
_, shape, _ = extract_tensor_meta(expand_node.all_input_nodes[0].meta)
multiples = expand_node.args[1]
multiples = cast(tuple[int], expand_node.args[1])
expanded_rank = len(multiples)

# Expanded shape is 'shape' front-padded with ones.
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/passes/size_adjust_conv2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def call(self, graph_module: torch.fx.GraphModule):
input_node, weight, _, stride_hw, pad_hw, dilation_hw, _, _, _ = (
conv_node.args
)
weight_shape = weight.meta["val"].shape
input_shape = input_node.meta["val"].shape
weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

slice_args = []
for stride, pad, dilation, dim in zip(
Expand Down Expand Up @@ -119,7 +119,7 @@ def call(self, graph_module: torch.fx.GraphModule):
last_node = dq_node
else:
last_node = slice_node
conv_node.replace_input_with(input_node, last_node)
conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
modified_graph = True

if modified_graph:
Expand Down
31 changes: 31 additions & 0 deletions backends/arm/quantizer/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

python_library(
name = "arm_quantizer",
srcs = ["arm_quantizer.py"],
typing = True,
deps = [
":arm_quantizer_utils",
"//caffe2:torch",
"//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation",
"//executorch/exir:lib",
],
)

python_library(
name = "quantization_config",
srcs = ["quantization_config.py"],
typing = True,
deps = [
"//caffe2:torch",
],
)

python_library(
name = "arm_quantizer_utils",
srcs = ["arm_quantizer_utils.py"],
typing = True,
deps = [
":quantization_config",
],
)
Loading

0 comments on commit a8c592e

Please sign in to comment.