From 97fba27d982eee34045440ac545152edb261d31b Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Wed, 16 Oct 2024 10:15:54 -0700 Subject: [PATCH] Fix the bug for truncated model --- alt_e2eshark/e2e_testing/storage.py | 6 +++++- alt_e2eshark/onnx_tests/helper_classes.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/alt_e2eshark/e2e_testing/storage.py b/alt_e2eshark/e2e_testing/storage.py index 943511d9..ef95d37c 100644 --- a/alt_e2eshark/e2e_testing/storage.py +++ b/alt_e2eshark/e2e_testing/storage.py @@ -17,6 +17,8 @@ def get_shape_string(torch_tensor): dtype = torch_tensor.dtype if dtype == torch.int64: input_shape_string += "xi64" + if dtype == torch.int32: + input_shape_string += "xi32" elif dtype == torch.float32 or dtype == torch.float: input_shape_string += "xf32" elif dtype == torch.bfloat16 or dtype == torch.float16 or dtype == torch.int16: @@ -142,7 +144,9 @@ class TestTensors: def __init__(self, data: Tuple): self.data = data - self.type = type(self.data[0]) + self.type = None + if len(data) > 0 : + self.type = type(self.data[0]) if not all([type(d) == self.type for d in data]): self.type == None diff --git a/alt_e2eshark/onnx_tests/helper_classes.py b/alt_e2eshark/onnx_tests/helper_classes.py index 13e82c28..84cac68a 100644 --- a/alt_e2eshark/onnx_tests/helper_classes.py +++ b/alt_e2eshark/onnx_tests/helper_classes.py @@ -75,6 +75,7 @@ def __init__(self, og_model_info_class: type, og_name: str, *args, **kwargs): run_dir = Path(self.model).parents[1] og_model_path = os.path.join(run_dir, og_name) self.sibling_inst = og_model_info_class(og_name, og_model_path) + self.opset_version = self.sibling_inst.opset_version def construct_model(self): if not os.path.exists(self.sibling_inst.model):