Skip to content

Commit

Permalink
Merge pull request #71 from JDAI-CV/fix_exception_in_ort
Browse files Browse the repository at this point in the history
catch exceptions in GetSupportedNodes
  • Loading branch information
daquexian authored Jan 28, 2020
2 parents bac3cce + 4efa112 commit e17f11e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 26 deletions.
2 changes: 1 addition & 1 deletion ci/onnxruntime_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pool:
steps:
- checkout: self
submodules: true
- script: git clone --recursive --branch fix_android_build https://github.com/daquexian/onnxruntime $(Agent.HomeDirectory)/onnxruntime
- script: git clone --recursive --branch android https://github.com/daquexian/onnxruntime $(Agent.HomeDirectory)/onnxruntime
displayName: Clone ONNX Runtime
- script: rm -rf $(Agent.HomeDirectory)/onnxruntime/cmake/external/DNNLibrary && cp -r $(Build.SourcesDirectory) $(Agent.HomeDirectory)/onnxruntime/cmake/external/DNNLibrary
displayName: Copy latest DNNLibrary
Expand Down
2 changes: 1 addition & 1 deletion include/tools/onnx2daq/OnnxConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class OnnxConverter {
void Clear();

public:
std::vector<std::vector<int>> GetSupportedNodes(
expected<std::vector<std::vector<int>>, std::string> GetSupportedNodes(
ONNX_NAMESPACE::ModelProto model_proto);
void Convert(const std::string &model_str, const std::string &filepath,
const std::string &table_file = "");
Expand Down
12 changes: 10 additions & 2 deletions tools/getsupportednodes/getsupportednodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ int main(int argc, char *argv[])
// FIXME: Handle the return value
model_proto.ParseFromString(ss.str());
dnn::OnnxConverter converter;
PNT(converter.GetSupportedNodes(model_proto));
return 0;
const auto nodes = converter.GetSupportedNodes(model_proto);
if (nodes) {
const auto &supported_ops = nodes.value();
PNT(supported_ops);
return 0;
} else {
const auto &error = nodes.error();
PNT(error);
return 1;
}
}
50 changes: 28 additions & 22 deletions tools/onnx2daq/OnnxConverter.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <common/data_types.h>
#include <common/Shaper.h>
#include <common/StrKeyMap.h>
#include <common/helper.h>
Expand Down Expand Up @@ -271,8 +272,9 @@ void OnnxConverter::HandleInitializer() {
ONNX_NAMESPACE::TensorProto_DataType_INT64) {
// TODO: shape of reshape layer
} else {
PNT(tensor.name(), tensor.data_type());
DNN_ASSERT(false, "");
DNN_ASSERT(false, "The data type \"" + std::to_string(tensor.data_type()) +
"\" of tensor \"" +
tensor.name() + "\" is not supported");
}
operands_.push_back(name);
}
Expand Down Expand Up @@ -630,34 +632,38 @@ bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
return false;
}

std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
expected<std::vector<std::vector<int>>, std::string> OnnxConverter::GetSupportedNodes(
ONNX_NAMESPACE::ModelProto model_proto) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
ONNX_NAMESPACE::shape_inference::InferShapes(model_proto);
model_proto_ = model_proto;
HandleInitializer();
try {
HandleInitializer();

std::vector<std::vector<int>> supported_node_vecs;
std::vector<int> supported_node_vec;
for (int i = 0; i < model_proto.graph().node_size(); i++) {
bool supported;
std::string error_msg;
std::tie(supported, error_msg) =
IsNodeSupported(model_proto, model_proto.graph().node(i));
if (supported) {
supported_node_vec.push_back(i);
} else {
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
supported_node_vec.clear();
std::vector<std::vector<int>> supported_node_vecs;
std::vector<int> supported_node_vec;
for (int i = 0; i < model_proto.graph().node_size(); i++) {
bool supported;
std::string error_msg;
std::tie(supported, error_msg) =
IsNodeSupported(model_proto, model_proto.graph().node(i));
if (supported) {
supported_node_vec.push_back(i);
} else {
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
supported_node_vec.clear();
}
}
}
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
}
Clear();
return supported_node_vecs;
} catch (std::exception &e) {
return make_unexpected(e.what());
}
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
}
Clear();
return supported_node_vecs;
}

void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
Expand Down

0 comments on commit e17f11e

Please sign in to comment.