From ca23af430048cad87d0fc1da219d20191ce43479 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Wed, 16 Oct 2024 10:36:53 -0700 Subject: [PATCH] Fix the bug for truncated model (#372) Find the bug when debug https://github.com/nod-ai/SHARK-TestSuite/pull/371. --- alt_e2eshark/e2e_testing/storage.py | 6 +++++- alt_e2eshark/onnx_tests/helper_classes.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/alt_e2eshark/e2e_testing/storage.py b/alt_e2eshark/e2e_testing/storage.py index 943511d9..ef95d37c 100644 --- a/alt_e2eshark/e2e_testing/storage.py +++ b/alt_e2eshark/e2e_testing/storage.py @@ -17,6 +17,8 @@ def get_shape_string(torch_tensor): dtype = torch_tensor.dtype if dtype == torch.int64: input_shape_string += "xi64" + if dtype == torch.int32: + input_shape_string += "xi32" elif dtype == torch.float32 or dtype == torch.float: input_shape_string += "xf32" elif dtype == torch.bfloat16 or dtype == torch.float16 or dtype == torch.int16: @@ -142,7 +144,9 @@ class TestTensors: def __init__(self, data: Tuple): self.data = data - self.type = type(self.data[0]) + self.type = None + if len(data) > 0 : + self.type = type(self.data[0]) if not all([type(d) == self.type for d in data]): self.type == None diff --git a/alt_e2eshark/onnx_tests/helper_classes.py b/alt_e2eshark/onnx_tests/helper_classes.py index 13e82c28..84cac68a 100644 --- a/alt_e2eshark/onnx_tests/helper_classes.py +++ b/alt_e2eshark/onnx_tests/helper_classes.py @@ -75,6 +75,7 @@ def __init__(self, og_model_info_class: type, og_name: str, *args, **kwargs): run_dir = Path(self.model).parents[1] og_model_path = os.path.join(run_dir, og_name) self.sibling_inst = og_model_info_class(og_name, og_model_path) + self.opset_version = self.sibling_inst.opset_version def construct_model(self): if not os.path.exists(self.sibling_inst.model):