From eafb81664602a0ec514319e8a38483428eca1c13 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:42:45 -0700 Subject: [PATCH] [alt] adds a numerics failure reproducer (#351) This particular Conv op signature was found to be causing numerics failures in a few models. E.g., "maxvit_rmlp_base_rw_224.sw_in12k" --- alt_e2eshark/onnx_tests/operators/conv.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/alt_e2eshark/onnx_tests/operators/conv.py b/alt_e2eshark/onnx_tests/operators/conv.py index 6c412bda..7f94084c 100644 --- a/alt_e2eshark/onnx_tests/operators/conv.py +++ b/alt_e2eshark/onnx_tests/operators/conv.py @@ -9,6 +9,21 @@ from ..helper_classes import BuildAModel from e2e_testing.registry import register_with_name, register_test +class ConvRepro(BuildAModel): + def construct_i_o_value_info(self): + self.input_vi = [ + make_tensor_value_info("X", TensorProto.FLOAT, [1,256,112,112]), + make_tensor_value_info("W", TensorProto.FLOAT, [256,1,3,3]), + make_tensor_value_info("B", TensorProto.FLOAT, [256]), + ] + self.output_vi = [make_tensor_value_info("Y", TensorProto.FLOAT, [1,256,56,56])] + + def construct_nodes(self): + app_node = self.get_app_node() + app_node("Conv",["X","W","B"],["Y"],group=256,kernel_shape=[3,3],pads=[1,1,1,1],strides=[2,2]) + +register_test(ConvRepro, "conv_depthwise_stride_2") + class QConvModelBase(BuildAModel): def __init__(self, specs, *args, **kwargs): (self.N, self.Cin, self.Hin, self.Win, self.Cout, self.groups, self.Hker, self.Wker, self.pads, self.dilations, self.strides) = specs