Skip to content

Commit

Permalink
[alt] adds a numerics failure reproducer (#351)
Browse files Browse the repository at this point in the history
This particular Conv op signature was found to be causing numerics
failures in a few models. E.g., "maxvit_rmlp_base_rw_224.sw_in12k"
  • Loading branch information
zjgarvey authored Sep 25, 2024
1 parent 34dc084 commit eafb816
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions alt_e2eshark/onnx_tests/operators/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@
from ..helper_classes import BuildAModel
from e2e_testing.registry import register_with_name, register_test

class ConvRepro(BuildAModel):
def construct_i_o_value_info(self):
self.input_vi = [
make_tensor_value_info("X", TensorProto.FLOAT, [1,256,112,112]),
make_tensor_value_info("W", TensorProto.FLOAT, [256,1,3,3]),
make_tensor_value_info("B", TensorProto.FLOAT, [256]),
]
self.output_vi = [make_tensor_value_info("Y", TensorProto.FLOAT, [1,256,56,56])]

def construct_nodes(self):
app_node = self.get_app_node()
app_node("Conv",["X","W","B"],["Y"],group=256,kernel_shape=[3,3],pads=[1,1,1,1],strides=[2,2])

register_test(ConvRepro, "conv_depthwise_stride_2")

class QConvModelBase(BuildAModel):
def __init__(self, specs, *args, **kwargs):
(self.N, self.Cin, self.Hin, self.Win, self.Cout, self.groups, self.Hker, self.Wker, self.pads, self.dilations, self.strides) = specs
Expand Down

0 comments on commit eafb816

Please sign in to comment.