Skip to content

Commit

Permalink
[WebNN EP] Support axes and fix some validation for Resize (microsoft…
Browse files Browse the repository at this point in the history
…#21952)

- Supports arbitrary axes for Resize opset 18+
- Check all inputs and attributes more carefully

---------

Co-authored-by: Dwayne Robinson <fdwr@hotmail.com>
  • Loading branch information
2 people authored and Ishwar Raut committed Nov 19, 2024
1 parent 1ebb908 commit cecce2d
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 109 deletions.
2 changes: 1 addition & 1 deletion js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare ||| Input 'axes' if present should be a constant |
| Relu | ai.onnx(7-12, 13, 14+) | relu ||| |
| Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape ||| Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported |
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d ||| Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, 'linear' and 'nearest' modes |
| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d ||| Only supports 4-D input, antialias == 0, coordinate_transformation_mode == 'half_pixel', exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant |
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |
| Softplus | ai.onnx(7+) | softplus ||| |
Expand Down
36 changes: 36 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,31 @@ WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type);
// Collects all the initializer tensors in the subGraph and its ancestor graphs.
InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer);

inline std::vector<int64_t> convertAxesFromNCHWtoNHWC(const std::vector<int64_t>& axes) {
constexpr std::array<int64_t, 4> nchw_to_nhwc = {0, 3, 1, 2};
std::vector<int64_t> new_axes;
new_axes.reserve(axes.size());
for (int64_t axis : axes) {
if (axis >= nchw_to_nhwc.size()) {
ORT_THROW("Invalid axis value: ", axis);
}
new_axes.push_back(nchw_to_nhwc[static_cast<size_t>(axis)]);
}
return new_axes;
}

inline std::vector<int64_t> HandleNegativeAxes(const std::vector<int64_t>& axes, size_t input_size) {
std::vector<int64_t> new_axes(axes.size());
for (size_t i = 0; i < axes.size(); ++i) {
new_axes[i] = HandleNegativeAxis(axes[i], input_size);
}
return new_axes;
}

inline std::vector<int64_t> GetResolvedAxes(const NodeAttrHelper& helper, size_t input_size) {
return HandleNegativeAxes(helper.Get("axes", std::vector<int64_t>{}), input_size);
}

bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const logging::Logger& logger);

template <typename T>
Expand Down Expand Up @@ -144,6 +169,17 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
return true;
}

inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::string& name) {
if (name.empty() || !Contains(initializers, name)) {
return true;
}

const auto& tensor = *initializers.at(name);
const auto dims = tensor.dims();
// An empty tensor contains a 0 in the dimensions list.
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
}

bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger);

// Get a list of groups of supported nodes, each group represents a subgraph supported by WebNN EP.
Expand Down
Loading

0 comments on commit cecce2d

Please sign in to comment.