From 10b9a97e9af682dbfaf3361c9995317db446f3a8 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 2 Jul 2024 21:36:51 +0530 Subject: [PATCH] Add test case for onnx.ScatterND and onnx.Scan Signed-Off-by: Gaurav Shukla --- e2eshark/onnx/operators/Scan/model.py | 109 +++++++++++++++++++++ e2eshark/onnx/operators/ScatterND/model.py | 79 +++++++++++++++ 2 files changed, 188 insertions(+) create mode 100644 e2eshark/onnx/operators/Scan/model.py create mode 100644 e2eshark/onnx/operators/ScatterND/model.py diff --git a/e2eshark/onnx/operators/Scan/model.py b/e2eshark/onnx/operators/Scan/model.py new file mode 100644 index 000000000..d8792c837 --- /dev/null +++ b/e2eshark/onnx/operators/Scan/model.py @@ -0,0 +1,109 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# run.py creates runmodel.py by concatenating this file model.py +# and tools/stubs/onnxmodel.py +# Description: testing GatherND +# See https://onnx.ai/onnx/intro/python.html for intro on creating +# onnx model using python onnx API +import numpy, torch, sys +import onnxruntime +from onnx import numpy_helper, TensorProto, save_model +from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info +from onnx.checker import check_model + +# import from e2eshark/tools to allow running in current dir, for run through +# run.pl, commutils is symbolically linked to allow any rundir to work +sys.path.insert(0, "../../../tools/stubs") +from commonutils import E2ESHARK_CHECK_DEF + +# Create an instance of it for this test +E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF) + + # Given an input sequence [x1, ..., xN], sum up its elements using a scan +# returning the final state (x1+x2+...+xN) as well the scan_output +# [x1, x1+x2, ..., x1+x2+...+xN] +# +# create graph to represent scan body +sum_in = make_tensor_value_info( + "sum_in", TensorProto.FLOAT, [2] +) +next = make_tensor_value_info( # noqa: A001 + "next", TensorProto.FLOAT, [2] +) +sum_out = make_tensor_value_info( + "sum_out", TensorProto.FLOAT, [2] +) +scan_out = make_tensor_value_info( + "scan_out", TensorProto.FLOAT, [2] +) +add_node = make_node( + "Add", inputs=["sum_in", "next"], outputs=["sum_out"] +) +id_node = make_node( + "Identity", inputs=["sum_out"], outputs=["scan_out"] +) +scan_body = make_graph( + [add_node, id_node], "scan_body", [sum_in, next], [sum_out, scan_out] +) +# create scan op node +INIT = make_tensor_value_info("INIT", TensorProto.FLOAT, [2]) +X = make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) +Y = make_tensor_value_info("Y", TensorProto.FLOAT, [2]) +Z = make_tensor_value_info("Z", TensorProto.FLOAT, [3, 2]) +# no_sequence_lens = "" # optional input, not supplied +scan_node = make_node( + "Scan", + inputs=["INIT", "X"], + outputs=["Y", "Z"], + num_scan_inputs=1, + body=scan_body, +) + +# Create the graph (GraphProto) +graph = make_graph( + [scan_node], + "scan_graph", + [INIT, X], + [Y, Z], +) +# Create the model (ModelProto) +onnx_model = make_model(graph) +onnx_model.opset_import[0].version = 16 + +# Save the model +# save_model(onnx_model, "model.onnx") +with open("model.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) +session = onnxruntime.InferenceSession("model.onnx", None) +# create inputs for batch-size 1, sequence-length 3, inner dimension 2 +model_initial = numpy.array([0, 0]).astype(numpy.float32).reshape((2)) +model_x = numpy.array([1, 2, 3, 4, 5, 6]).astype(numpy.float32).reshape((3, 2)) +# final state computed = [1 + 3 + 5, 2 + 4 + 6] +# model_y = numpy.array([9, 12]).astype(np.float32).reshape((1, 2)) +# scan-output computed +# model_z = numpy.array([1, 2, 4, 6, 9, 12]).astype(np.float32).reshape((1, 3, 2)) + +# gets D in inputs[0] and I in inputs[1] +inputs = session.get_inputs() +# gets Z in outputs[0] +outputs = session.get_outputs() + +model_output = session.run( + [outputs[0].name, outputs[1].name], + {inputs[0].name: model_initial, + inputs[1].name: model_x}, +) + +# Moving to torch to handle bfloat16 as numpy does not support bfloat16 +E2ESHARK_CHECK["input"] = [ + torch.from_numpy(model_initial), + torch.from_numpy(model_x), +] +E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output] + +print("Input:", E2ESHARK_CHECK["input"]) +print("Output:", E2ESHARK_CHECK["output"]) diff --git a/e2eshark/onnx/operators/ScatterND/model.py b/e2eshark/onnx/operators/ScatterND/model.py new file mode 100644 index 000000000..d5402b365 --- /dev/null +++ b/e2eshark/onnx/operators/ScatterND/model.py @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# run.py creates runmodel.py by concatenating this file model.py +# and tools/stubs/onnxmodel.py +# Description: testing ScatterND +# See https://onnx.ai/onnx/intro/python.html for intro on creating +# onnx model using python onnx API +import numpy, torch, sys +import onnxruntime +from onnx import numpy_helper, TensorProto, save_model +from onnx.helper import make_model, make_node, make_graph, make_tensor_value_info +from onnx.checker import check_model + +# import from e2eshark/tools to allow running in current dir, for run through +# run.pl, commutils is symbolically linked to allow any rundir to work +sys.path.insert(0, "../../../tools/stubs") +from commonutils import E2ESHARK_CHECK_DEF + +# Create an instance of it for this test +E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF) + +# Create an input (ValueInfoProto) +D = make_tensor_value_info("D", TensorProto.FLOAT, [4, 4, 4]) +I = make_tensor_value_info("I", TensorProto.INT64, [2, 1]) +U = make_tensor_value_info("U", TensorProto.FLOAT, [2, 4, 4]) + +# Create an output +Z = make_tensor_value_info("Z", TensorProto.FLOAT, [4, 4, 4]) + +# Create a node (NodeProto) +scatter_nd_node = make_node( + "ScatterND", ["D", "I", "U"], ["Z"], "scatter_nd_node" + #op_type # # inputs # outputs #node_name +) + +# Create the graph (GraphProto) +graph = make_graph( + [scatter_nd_node], + "scatter_nd_graph", + [D, I, U], + [Z], +) + +# Create the model (ModelProto) +onnx_model = make_model(graph) +onnx_model.opset_import[0].version = 16 + +# Save the model +with open("model.onnx", "wb") as f: + f.write(onnx_model.SerializeToString()) + + +session = onnxruntime.InferenceSession("model.onnx", None) +model_input_D = numpy.random.randn(4, 4, 4).astype(numpy.float32) +model_input_I = numpy.random.randint(2, size=(2, 1)).astype(numpy.int64) +model_input_U = numpy.random.randn(2, 4, 4).astype(numpy.float32) +inputs = session.get_inputs() +outputs = session.get_outputs() + +model_output = session.run( + [outputs[0].name], + {inputs[0].name: model_input_D, + inputs[1].name: model_input_I, + inputs[2].name: model_input_U}, +) + +E2ESHARK_CHECK["input"] = [ + torch.from_numpy(model_input_D), + torch.from_numpy(model_input_I), + torch.from_numpy(model_input_U), +] +E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output] + +print("Input:", E2ESHARK_CHECK["input"]) +print("Output:", E2ESHARK_CHECK["output"])