diff --git a/e2eshark/onnx/operators/expand/model.py b/e2eshark/onnx/operators/expand/model.py
new file mode 100644
index 000000000..95c190ffa
--- /dev/null
+++ b/e2eshark/onnx/operators/expand/model.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Advanced Micro Devices
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# run.py creates runmodel.py by concatenating this file model.py
+# and tools/stubs/onnxmodel.py
+# Description: testing Expand
+# See https://onnx.ai/onnx/intro/python.html for intro on creating
+# onnx model using python onnx API
+import numpy, torch, sys
+import onnxruntime
+from onnx import numpy_helper, TensorProto, save_model
+from onnx.helper import (
+    make_model,
+    make_node,
+    make_graph,
+    make_tensor_value_info,
+    make_tensor,
+)
+from onnx.checker import check_model
+
+# import from e2eshark/tools to allow running in current dir, for run through
+# run.pl, commonutils is symbolically linked to allow any rundir to work
+sys.path.insert(0, "../../../tools/stubs")
+from commonutils import E2ESHARK_CHECK_DEF
+
+# Create an instance of it for this test
+E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF)
+
+# Create an input (ValueInfoProto)
+X = make_tensor_value_info("X", TensorProto.FLOAT, [4, 5])
+
+shape_tensor = make_tensor(
+    name="shape",
+    data_type=TensorProto.INT64,
+    dims=(3,),
+    vals=[3, 4, 5],
+)
+
+shape_node = make_node(
+    "Constant",
+    inputs=[],
+    outputs=["shape"],
+    value=shape_tensor,
+)
+
+# Create an output (ValueInfoProto)
+Z = make_tensor_value_info("Z", TensorProto.FLOAT, [3, 4, 5])
+
+# Create a node (NodeProto)
+expand_node = make_node(
+    "Expand",  # op type
+    ["X", "shape"],  # inputs
+    ["Z"],  # outputs
+    "expand_node",  # node name
+)
+
+# Create the graph (GraphProto)
+graph = make_graph(
+    [shape_node, expand_node],
+    "expand_graph",
+    [X],
+    [Z],
+)
+
+# Create the model (ModelProto)
+onnx_model = make_model(graph)
+onnx_model.opset_import[0].version = 19
+
+# Save the model
+with open("model.onnx", "wb") as f:
+    f.write(onnx_model.SerializeToString())
+
+session = onnxruntime.InferenceSession("model.onnx", None)
+model_input_X = numpy.random.randn(4, 5).astype(numpy.float32)
+inputs = session.get_inputs()
+outputs = session.get_outputs()
+
+model_output = session.run(
+    [outputs[0].name],
+    {inputs[0].name: model_input_X},
+)
+
+print("Input shape:", model_input_X.shape)
+print("Output shape:", numpy.array(model_output[0]).shape)
+
+# Moving to torch to handle bfloat16 as numpy does not support bfloat16
+E2ESHARK_CHECK["input"] = [torch.from_numpy(model_input_X)]
+E2ESHARK_CHECK["output"] = [torch.from_numpy(arr) for arr in model_output]
+
+print("Input:", E2ESHARK_CHECK["input"])
+print("Output:", E2ESHARK_CHECK["output"])
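
For reference, ONNX Expand implements numpy-style broadcasting: the [4, 5] input is replicated along the new leading axis to yield a [3, 4, 5] result, so every slice Z[i] equals X. A minimal standalone sanity check of the same semantics using numpy.broadcast_to (an illustrative sketch, not part of the file above):

    # Hypothetical standalone check mirroring the shapes used in model.py:
    # Expand with target shape [3, 4, 5] on a [4, 5] input behaves like
    # numpy broadcasting -- the input is repeated along the new leading axis.
    import numpy

    x = numpy.random.randn(4, 5).astype(numpy.float32)
    expected = numpy.broadcast_to(x, (3, 4, 5))
    assert expected.shape == (3, 4, 5)
    # Each slice along the leading axis is an exact copy of x.
    for i in range(3):
        assert numpy.array_equal(expected[i], x)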