From 09ea70c161eba01463ddc1d10df6f6d672656e32 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Wed, 17 Apr 2024 17:49:51 +0000 Subject: [PATCH 1/2] add onnx if operator test and modify run.py to support bool types --- e2eshark/onnx/operators/If/model.py | 110 ++++++++++++++++++++++++++++ e2eshark/run.py | 4 + 2 files changed, 114 insertions(+) create mode 100644 e2eshark/onnx/operators/If/model.py diff --git a/e2eshark/onnx/operators/If/model.py b/e2eshark/onnx/operators/If/model.py new file mode 100644 index 000000000..849052857 --- /dev/null +++ b/e2eshark/onnx/operators/If/model.py @@ -0,0 +1,110 @@ +# Copyright 2024 Advanced Micro Devices +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# run.py creates runmodel.py by concatenating this file model.py +# and tools/stubs/onnxmodel.py +# Description: testing add +# See https://onnx.ai/onnx/intro/python.html for intro on creating +# onnx model using python onnx API +import numpy, torch, sys +import onnxruntime +from onnx import numpy_helper, TensorProto, save_model +from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info +from onnx.checker import check_model + +# import from e2eshark/tools to allow running in current dir, for run through +# run.pl, commutils is symbolically linked to allow any rundir to work +sys.path.insert(0, "../../../tools/stubs") +from commonutils import E2ESHARK_CHECK_DEF + +# Create an instance of it for this test +E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF) + + +# condition has to be a float tensor +condition = make_tensor_value_info('condition', TensorProto.BOOL, [1]) +input1 = make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +input2 = make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +output = make_tensor_value_info('output', TensorProto.FLOAT, [1]) +output_then = make_tensor_value_info('output_then', TensorProto.FLOAT, [1]) +output_else = make_tensor_value_info('output_else', TensorProto.FLOAT, [1]) + +then_branch = make_graph( + nodes=[ + make_node('Add', ['input1', 'input2'], ['output_then']) + ], + name='then_branch', + inputs=[], + outputs=[output_then] +) + +else_branch = make_graph( + nodes=[ + make_node('Sub', ['input1', 'input2'], ['output_else']) + ], + name='else_branch', + inputs=[], + outputs=[output_else] +) + +graph = make_graph( + nodes=[ + make_node('If', ['condition'], ['output'], then_branch=then_branch, else_branch=else_branch) + ], + name='if_example', + inputs=[condition, input1, input2], + outputs=[output] +) + +onnx_model = make_model(graph, producer_name='conditional_example') + +onnx_model.opset_import[0].version = 19 + +# Save the model +# save_model(onnx_model, "model.onnx") +with open("model.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) + + +session = onnxruntime.InferenceSession("model.onnx", None) +# gets X in inputs[0] and Y in inputs[1] +inputs = session.get_inputs() +# gets Z in outputs[0] +outputs = session.get_outputs() + +def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg): + if node.type == "tensor(float)": + return numpy.random.randn(*node.shape).astype(numpy.float32) + if node.type == "tensor(int)": + return numpy.random.randint(0, 10000, size=node.shape).astype(numpy.int32) + if node.type == "tensor(bool)": + return numpy.random.randint(0, 2, size=node.shape).astype(bool) + +input_dict = { + node.name: generate_input_from_node(node) + for node in inputs +} + +output_list = [ + node.name + for node in outputs +] + +model_output = session.run( + output_list, + input_dict, +) + +# Moving to torch to handle bfloat16 as numpy does not support bfloat16 +E2ESHARK_CHECK["input"] = [ + torch.from_numpy(input_dict["condition"]), + torch.from_numpy(input_dict["input1"]), + torch.from_numpy(input_dict["input2"]), +] +E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output] + +print("Input:", E2ESHARK_CHECK["input"]) +print("Output:", E2ESHARK_CHECK["output"]) diff --git a/e2eshark/run.py b/e2eshark/run.py index dbb6f28be..84cfdfa23 100644 --- a/e2eshark/run.py +++ b/e2eshark/run.py @@ -202,6 +202,8 @@ def getShapeString(torchtensor): inputshapestring += "xbf16" elif dtype == torch.int8: inputshapestring += "xi8" + elif dtype == torch.bool: + inputshapestring += "xi1" else: print("In getShapeString, found an unsupported data type", dtype) return inputshapestring @@ -271,6 +273,8 @@ def packTensor(modelinput): bytearr = struct.pack("%sh" % len(mylist), *mylist) elif dtype == torch.int8: bytearr = struct.pack("%sb" % len(mylist), *mylist) + elif dtype == torch.bool: + bytearr = struct.pack("%s?" % len(mylist), *mylist) else: print("In packTensor, found an unsupported data type", dtype) return bytearr From 2d0cc307f9a19a8f275880b3d389a64715e4aaac Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 18 Apr 2024 16:43:56 +0000 Subject: [PATCH 2/2] add onnx.If operator test --- e2eshark/onnx/operators/If/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/e2eshark/onnx/operators/If/model.py b/e2eshark/onnx/operators/If/model.py index 849052857..abddd5743 100644 --- a/e2eshark/onnx/operators/If/model.py +++ b/e2eshark/onnx/operators/If/model.py @@ -26,11 +26,11 @@ # condition has to be a float tensor condition = make_tensor_value_info('condition', TensorProto.BOOL, [1]) -input1 = make_tensor_value_info('input1', TensorProto.FLOAT, [1]) -input2 = make_tensor_value_info('input2', TensorProto.FLOAT, [1]) -output = make_tensor_value_info('output', TensorProto.FLOAT, [1]) -output_then = make_tensor_value_info('output_then', TensorProto.FLOAT, [1]) -output_else = make_tensor_value_info('output_else', TensorProto.FLOAT, [1]) +input1 = make_tensor_value_info('input1', TensorProto.FLOAT, [2,3]) +input2 = make_tensor_value_info('input2', TensorProto.FLOAT, [2,3]) +output = make_tensor_value_info('output', TensorProto.FLOAT, [2,3]) +output_then = make_tensor_value_info('output_then', TensorProto.FLOAT, [2,3]) +output_else = make_tensor_value_info('output_else', TensorProto.FLOAT, [2,3]) then_branch = make_graph( nodes=[