Skip to content

Commit

Permalink
[onnx] Fix test generator for better type information (#316)
Browse files Browse the repository at this point in the history
Test generator does not handle all type imports. Fixed the test
generator so that we actually convert the types correctly.
  • Loading branch information
rsuderman authored Aug 5, 2024
1 parent 31261b3 commit 987f05e
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 26 deletions.
3 changes: 1 addition & 2 deletions iree_tests/configs/onnx_cpu_llvm_sync.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"onnx/node/generated/test_group_normalization_epsilon",
"onnx/node/generated/test_group_normalization_example",
"onnx/node/generated/test_group_normalization_example_expanded",
"onnx/node/generated/test_group_normalization_epsilon_expanded",
"onnx/node/generated/test_group_normalization_epsilon_expanded"
],
"skip_run_tests": [],
"expected_compile_failures": [
Expand Down Expand Up @@ -491,7 +491,6 @@
"onnx/node/generated/test_wrap_pad"
],
"expected_run_failures": [
"onnx/node/generated/test_asin",
"onnx/node/generated/test_averagepool_3d_dilations_large_count_include_pad_is_0_ceil_mode_is_True",
"onnx/node/generated/test_bernoulli",
"onnx/node/generated/test_bernoulli_double",
Expand Down
20 changes: 6 additions & 14 deletions iree_tests/onnx/import_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import subprocess
import numpy as np
import sys
from import_tests_utils import get_shape_string, write_io_bin
from import_tests_utils import get_io_proto_type, write_io_bin

THIS_DIR = Path(__file__).parent
REPO_ROOT = THIS_DIR.parent.parent
Expand Down Expand Up @@ -140,32 +140,26 @@ def import_onnx_files(test_dir_path, imported_dir_path):
for i in range(len(test_inputs)):
test_input = test_inputs[i]
t = convert_io_proto(test_input, model.graph.input[i].type)
ty = get_io_proto_type(model.graph.input[i].type)
if t is None:
return False
input_path = (imported_dir_path / test_input.stem).with_suffix(".npy")
np.save(input_path, t) # Only for ref, actual comparison with .bin
ss = get_shape_string(t)
input_path_bin = (imported_dir_path / test_input.stem).with_suffix(".bin")
write_io_bin(t, input_path_bin)
test_data_flagfile_lines.append(
f"--input={ss}=@{input_path_bin.name}\n"
)
test_data_flagfile_lines.append(f"--input={ty}=@{input_path_bin.name}\n")
for i in range(len(test_outputs)):
test_output = test_outputs[i]
t = convert_io_proto(test_output, model.graph.output[i].type)
ty = get_io_proto_type(model.graph.output[i].type)
if t is None:
return False
output_path = (imported_dir_path / test_output.stem).with_suffix(".npy")
np.save(output_path, t) # Only for ref, actual comparison with .bin
ss = get_shape_string(t)
# required for signless output comparison
if "xsi" in ss or "xui" in ss:
ss = ss.replace("xsi", "xi")
ss = ss.replace("xui", "xi")
output_path_bin = (imported_dir_path / test_output.stem).with_suffix(".bin")
write_io_bin(t, output_path_bin)
test_data_flagfile_lines.append(
f"--expected_output={ss}=@{output_path_bin.name}\n"
f"--expected_output={ty}=@{output_path_bin.name}\n"
)

with open(test_data_flagfile_path, "wt") as f:
Expand Down Expand Up @@ -197,9 +191,7 @@ def import_onnx_files(test_dir_path, imported_dir_path):
passed_imports = []
failed_imports = []
with Pool(args.jobs) as pool:
results = pool.imap_unordered(
import_onnx_files_with_cleanup, test_dir_paths
)
results = pool.imap_unordered(import_onnx_files_with_cleanup, test_dir_paths)
for result in results:
if result[1]:
passed_imports.append(result[0])
Expand Down
68 changes: 58 additions & 10 deletions iree_tests/onnx/import_tests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import struct
import numpy as np
import onnx

# map numpy dtype -> (iree dtype, struct.pack format str)
dtype_map = {
Expand All @@ -24,17 +25,64 @@
}


def get_shape_string(torchtensor):
inputshape = list(torchtensor.shape)
inputshapestring = "x".join([str(item) for item in inputshape])
dtype = torchtensor.dtype
if dtype in dtype_map:
inputshapestring += f"x{dtype_map[dtype][0]}"
def convert_proto_etype(etype):
    """Map an ONNX TensorProto element type to its IREE/MLIR type string.

    Unsigned integer widths collapse to the signless "i<N>" form (e.g. both
    UINT32 and INT32 become "i32"), matching how outputs are compared.
    Returns "" for any element type not in the table.
    """
    etype_to_iree = {
        onnx.TensorProto.FLOAT: "f32",
        onnx.TensorProto.UINT8: "i8",
        onnx.TensorProto.INT8: "i8",
        onnx.TensorProto.UINT16: "i16",
        onnx.TensorProto.INT16: "i16",
        onnx.TensorProto.INT32: "i32",
        onnx.TensorProto.INT64: "i64",
        onnx.TensorProto.BOOL: "i1",
        onnx.TensorProto.FLOAT16: "f16",
        onnx.TensorProto.DOUBLE: "f64",
        onnx.TensorProto.UINT32: "i32",
        onnx.TensorProto.UINT64: "i64",
        onnx.TensorProto.COMPLEX64: "complex<f32>",
        onnx.TensorProto.COMPLEX128: "complex<f64>",
        onnx.TensorProto.BFLOAT16: "bf16",
        onnx.TensorProto.FLOAT8E4M3FN: "f8e4m3fn",
        onnx.TensorProto.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
        onnx.TensorProto.FLOAT8E5M2: "f8e5m2",
        onnx.TensorProto.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
        onnx.TensorProto.UINT4: "i4",
        onnx.TensorProto.INT4: "i4",
    }
    return etype_to_iree.get(etype, "")


def get_io_proto_type(type_proto):
    """Build the IREE type string (e.g. "2x3xf32") for an ONNX I/O TypeProto.

    Args:
        type_proto: an onnx.TypeProto describing one model input or output.

    Returns:
        "<d0>x<d1>x...x<dtype>" for tensor types ("<dtype>" alone for rank-0
        tensors), or None when the proto is not a tensor type.
    """
    if type_proto.HasField("tensor_type"):
        tensor_type = type_proto.tensor_type
        # Join the static dim_values into "d0xd1x...". NOTE(review): symbolic
        # dims would read as dim_value == 0 here — assumes the generated test
        # protos carry fully static shapes; confirm against the generator.
        shape = "x".join([str(d.dim_value) for d in tensor_type.shape.dim])
        dtype = convert_proto_etype(tensor_type.elem_type)
        # Rank-0 tensors have no dims, so the type string is just the dtype.
        if shape == "":
            return dtype
        return f"{shape}x{dtype}"
    # Non-tensor protos (sequence/map/optional) are unsupported; callers
    # treat None as "skip this test case".
    print(f"Unsupported proto type: {type_proto}")
    return None


def pack_np_arr(arr):
Expand Down

0 comments on commit 987f05e

Please sign in to comment.