Skip to content

Commit

Permalink
Onnx.if op test (#176)
Browse files Browse the repository at this point in the history
Simple op test for if. Also includes a nifty little helper function that
would reduce error & duplication by avoiding re-specifying dtypes and
shapes

```python
inputs = session.get_inputs()
# gets Z in outputs[0]
outputs = session.get_outputs()

def generate_input_from_node(node: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg):
    if node.type == "tensor(float)":
        return numpy.random.randn(*node.shape).astype(numpy.float32)
    if node.type == "tensor(int)":
        return numpy.random.randint(0, 10000, size=node.shape).astype(numpy.int32)
    if node.type == "tensor(bool)":
        return numpy.random.randint(0, 2, size=node.shape).astype(bool)
    
input_dict = {
    node.name: generate_input_from_node(node)
    for node in inputs
}

output_list = [
    node.name
    for node in outputs
]
```
  • Loading branch information
renxida authored Apr 18, 2024
2 parents 3e78a90 + f924a75 commit 1d77dcf
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions e2eshark/onnx/operators/If/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@

# condition has to be a float tensor
condition = make_tensor_value_info('condition', TensorProto.BOOL, [1])
input1 = make_tensor_value_info('input1', TensorProto.FLOAT, [1])
input2 = make_tensor_value_info('input2', TensorProto.FLOAT, [1])
output = make_tensor_value_info('output', TensorProto.FLOAT, [1])
output_then = make_tensor_value_info('output_then', TensorProto.FLOAT, [1])
output_else = make_tensor_value_info('output_else', TensorProto.FLOAT, [1])
input1 = make_tensor_value_info('input1', TensorProto.FLOAT, [2,3])
input2 = make_tensor_value_info('input2', TensorProto.FLOAT, [2,3])
output = make_tensor_value_info('output', TensorProto.FLOAT, [2,3])
output_then = make_tensor_value_info('output_then', TensorProto.FLOAT, [2,3])
output_else = make_tensor_value_info('output_else', TensorProto.FLOAT, [2,3])

then_branch = make_graph(
nodes=[
Expand Down

0 comments on commit 1d77dcf

Please sign in to comment.