diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index af7348dba532f..4bcfae183ffcb 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -89,9 +89,10 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | ✗ | ✓ | Only supports 'reduction' == 'none' | | ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | ✗ | ✓ | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | -| SimplifiedLayerNormalization | ai.onnx(1+) | pow + reduceMean + add + sqrt + div + mul | ✓ | ✓ | | +| SimplifiedLayerNormalization | ai.onnx(1+) | pow, reduceMean, add, sqrt, div, mul | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Sign | ai.onnx(9-12, 13+) | sign | ✓ | ✓ | | +| SkipSimplifiedLayerNormalization | com.microsoft(1+) | pow, reduceMean, add, sqrt, div, mul | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✓ | ✓ | | | Softsign | ai.onnx(7+) | softsign | ✓ | ✓ | | | Sin | ai.onnx(7+) | sin | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index a06f46f1bdf0a..cf80eeef3418b 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -181,6 +181,10 @@ inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::s return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } +inline bool TensorExists(const ConstPointerContainer>& defs, size_t tensor_index) noexcept { + return tensor_index < defs.size() && defs[tensor_index]->Exists(); +} + bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name, const logging::Logger& logger, bool allow_empty_input = false); @@ -278,6 +282,7 @@ static const InlinedHashMap op_map = { {"Softplus", "softplus"}, {"Softsign", "softsign"}, {"Sin", "sin"}, + {"SkipSimplifiedLayerNormalization", "layerNormalization"}, {"Slice", "slice"}, {"Softmax", "softmax"}, {"Split", "split"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 81e688ea4f8ea..548e718b8774e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -405,8 +405,8 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initia int32_t input1_type; // weight data type int32_t input2_type; // bias or x_zero_point data type int32_t input3_type; // w_zero_point data type - bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); - bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input2 = TensorExists(input_defs, 2); + bool has_input3 = TensorExists(input_defs, 3); if (!GetType(*input_defs[0], input0_type, logger) || !GetType(*input_defs[1], input1_type, logger) || diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index ef713f48b8135..e16ecad99b993 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -742,7 +742,7 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init const auto& op_type = node.OpType(); int32_t input0_type; int32_t input1_type; - bool has_input1 = input_defs.size() > 1 && input_defs[1]->Exists(); + bool has_input1 = TensorExists(input_defs, 1); if (!GetType(*input_defs[0], input0_type, logger) || (has_input1 && !GetType(*input_defs[1], input1_type, logger))) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 5f4e6de8fda98..b0ddd97443984 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -223,8 +223,8 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initia int32_t input1_type; // B data type int32_t input2_type; // C or a_zero_point data type int32_t input3_type; // b_zero_point data type - bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); - bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input2 = TensorExists(input_defs, 2); + bool has_input3 = TensorExists(input_defs, 3); if (!GetType(*input_defs[0], input0_type, logger) || !GetType(*input_defs[1], input1_type, logger) || diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index b240e30d38b22..6b3b5b4b48add 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -33,7 +33,7 @@ class GruOpBuilder : public BaseOpBuilder { }; void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - if (node.InputDefs().size() > 4 && node.InputDefs()[4]->Exists()) { + if (TensorExists(node.InputDefs(), 4)) { model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); // sequence_lens model_builder.AddInputToSkip(node.InputDefs()[4]->Name()); } @@ -56,7 +56,7 @@ Status GruOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No options.set("label", node.Name()); options.set("layout", emscripten::val("zrn")); - if (input_defs.size() > 3 && input_defs[3]->Exists()) { + if (TensorExists(input_defs, 3)) { emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); emscripten::val split_options = emscripten::val::object(); split_options.set("label", node.Name() + "_split"); @@ -68,7 +68,7 @@ Status GruOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No options.set("recurrentBias", splitted_biases[1]); } - if (input_defs.size() > 5 && input_defs[5]->Exists()) { + if (TensorExists(input_defs, 5)) { options.set("initialHiddenState", model_builder.GetOperand(input_defs[5]->Name())); } @@ -76,8 +76,8 @@ Status GruOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No options.set("resetAfter", linear_before_reset); const auto& output_defs = node.OutputDefs(); - bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); - bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + bool has_Y = TensorExists(output_defs, 0); + bool has_Y_h = TensorExists(output_defs, 1); options.set("returnSequence", has_Y); std::string direction = helper.Get("direction", "forward"); @@ -134,7 +134,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c } int32_t steps = static_cast(input_shape[0]); - if (input_defs.size() > 4 && input_defs[4]->Exists()) { + if (TensorExists(input_defs, 4)) { if (!Contains(initializers, input_defs[4]->Name())) { LOGS(logger, ERROR) << "GRU: sequence_lens must be constant"; return false; @@ -196,8 +196,8 @@ bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initial int32_t input_R_type = 0; // recurrent weight data type int32_t input_B_type = 0; // bias data type int32_t input_initial_h_type = 0; // initial hidden state data type - bool has_input_B = input_defs.size() > 3 && input_defs[3]->Exists(); - bool has_input_initial_h = input_defs.size() > 5 && input_defs[5]->Exists(); + bool has_input_B = TensorExists(input_defs, 3); + bool has_input_initial_h = TensorExists(input_defs, 5); if (!GetType(*input_defs[0], input_X_type, logger) || !GetType(*input_defs[1], input_W_type, logger) || @@ -229,8 +229,8 @@ bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, const auto& op_type = node.OpType(); int32_t Y_type = 0; int32_t Y_h_type = 0; - bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); - bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + bool has_Y = TensorExists(output_defs, 0); + bool has_Y_h = TensorExists(output_defs, 1); bool Y_supported = has_Y && GetType(*output_defs[0], Y_type, logger); bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 33ba22ac3fb5b..30abbd117fb66 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -32,7 +32,7 @@ class LstmOpBuilder : public BaseOpBuilder { }; void LstmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - if (node.InputDefs().size() > 4 && node.InputDefs()[4]->Exists()) { + if (TensorExists(node.InputDefs(), 4)) { model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); // sequence_lens model_builder.AddInputToSkip(node.InputDefs()[4]->Name()); } @@ -56,7 +56,7 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N options.set("label", node.Name()); options.set("layout", emscripten::val("iofg")); - if (input_defs.size() > 3 && input_defs[3]->Exists()) { + if (TensorExists(input_defs, 3)) { emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); emscripten::val split_options = emscripten::val::object(); split_options.set("axis", 1); @@ -67,13 +67,13 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N options.set("bias", splitted_biases[0]); options.set("recurrentBias", splitted_biases[1]); } - if (input_defs.size() > 5 && input_defs[5]->Exists()) { + if (TensorExists(input_defs, 5)) { options.set("initialHiddenState", model_builder.GetOperand(input_defs[5]->Name())); } - if (input_defs.size() > 6 && input_defs[6]->Exists()) { + if (TensorExists(input_defs, 6)) { options.set("initialCellState", model_builder.GetOperand(input_defs[6]->Name())); } - if (input_defs.size() > 7 && input_defs[7]->Exists()) { + if (TensorExists(input_defs, 7)) { options.set("peepholeWeight", model_builder.GetOperand(input_defs[7]->Name())); } @@ -87,9 +87,9 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } const auto& output_defs = node.OutputDefs(); - bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); - bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); - bool has_Y_c = output_defs.size() > 2 && output_defs[2]->Exists(); + bool has_Y = TensorExists(output_defs, 0); + bool has_Y_h = TensorExists(output_defs, 1); + bool has_Y_c = TensorExists(output_defs, 2); options.set("returnSequence", has_Y); if (helper.HasAttr("activations")) { @@ -140,7 +140,7 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } int32_t steps = static_cast(input_shape[0]); - if (input_defs.size() > 4 && input_defs[4]->Exists()) { + if (TensorExists(input_defs, 4)) { if (!Contains(initializers, input_defs[4]->Name())) { LOGS(logger, ERROR) << "LSTM: sequence_lens must be constant"; return false; @@ -210,10 +210,10 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initia int32_t input5_type = 0; // initialHiddenState data type int32_t input6_type = 0; // initialCellState data type int32_t input7_type = 0; // peepholeWeight data type - bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); - bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists(); - bool has_input6 = input_defs.size() > 6 && input_defs[6]->Exists(); - bool has_input7 = input_defs.size() > 7 && input_defs[7]->Exists(); + bool has_input3 = TensorExists(input_defs, 3); + bool has_input5 = TensorExists(input_defs, 5); + bool has_input6 = TensorExists(input_defs, 6); + bool has_input7 = TensorExists(input_defs, 7); if (!GetType(*input_defs[0], input0_type, logger) || !GetType(*input_defs[1], input1_type, logger) || @@ -253,9 +253,9 @@ bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, int32_t Y_type = 0; int32_t Y_h_type = 0; int32_t Y_c_type = 0; - bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); - bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); - bool has_Y_c = output_defs.size() > 2 && output_defs[2]->Exists(); + bool has_Y = TensorExists(output_defs, 0); + bool has_Y_h = TensorExists(output_defs, 1); + bool has_Y_c = TensorExists(output_defs, 2); if (has_Y && GetType(*output_defs[0], Y_type, logger)) { return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 50e49884bdfa9..d1c0f598b79f4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -34,6 +34,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); ORT_RETURN_IF_NOT(input_defs.size() >= 2, op_type, " requires at least two inputs."); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); @@ -45,7 +46,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder options.set("label", node.Name()); std::vector scale_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); + const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; + ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); const auto scale_size = scale_shape.size(); // Except LayerNormalization, other normalization ops' scale input should be 1-D. if (op_type == "LayerNormalization") { @@ -55,19 +57,17 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); } - if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { + emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); + options.set("scale", scale); + + const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; + emscripten::val bias = emscripten::val::undefined(); + if (TensorExists(input_defs, bias_input_index)) { // Bias input exists, and bias's shape should be the same as scale's shape. std::vector bias_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[2], bias_shape, logger), "Cannot get bias shape"); + ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); - } - - emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); - options.set("scale", scale); - - if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) { - // Bias input exists, and bias's shape is the same as scale's shape. - emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name()); + bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -76,6 +76,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder options.set("epsilon", epsilon); emscripten::val output = emscripten::val::undefined(); + // SkipSimplifiedLayerNormalization's output: input_skip_bias_sum. + emscripten::val input_skip_bias_sum = emscripten::val::undefined(); if (op_type == "BatchNormalization") { ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs."); emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name()); @@ -85,7 +87,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); - } else if (op_type == "LayerNormalization" || op_type == "SimplifiedLayerNormalization") { + } else if (op_type == "LayerNormalization" || + op_type == "SimplifiedLayerNormalization" || + op_type == "SkipSimplifiedLayerNormalization") { int64_t axis = helper.Get("axis", -1); axis = HandleNegativeAxis(axis, rank); std::vector axes(rank - SafeInt(axis)); @@ -94,13 +98,17 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder if (op_type == "LayerNormalization") { options.set("axes", emscripten::val::array(axes)); output = model_builder.GetBuilder().call("layerNormalization", input, options); - } else { // SimplifiedLayerNormalization + } else { // SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization /** - WebNN doesn't support SimplifiedLayerNormalization. So decompose it into a series of ops: - X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul - ^ ^ ^ ^ ^ - | | | | | - Y:2 axis B:epsilon A:X A:scale + WebNN doesn't support SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization. + So decompose it into a series of ops: + X --> Pow --> ReduceMean --> Add --> Sqrt --> Div -> Mul -> Add (optional) + ^ ^ ^ ^ ^ ^ + | | | | | | + Y:2 axis B:epsilon A:X A:scale B:bias + + If it is SkipSimplifiedLayerNormalization and its output input_skip_bias_sum exists, + input_skip_bias_sum = X + skip + bias (if it exists) */ int32_t input_type; @@ -137,6 +145,25 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder // Mul common_options.set("label", node.Name() + "_mul"); output = model_builder.GetBuilder().call("mul", scale, div, common_options); + + // Add (if bias exits) + if (!bias.isUndefined()) { + common_options.set("label", node.Name() + "_add_bias"); + output = model_builder.GetBuilder().call("add", output, bias, common_options); + } + + // SkipSimplifiedLayerNormalization's output input_skip_bias_sum is the sum of input, skip, and bias. + if (op_type == "SkipSimplifiedLayerNormalization" && TensorExists(output_defs, 3)) { + emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name()); + common_options.set("label", node.Name() + "_add_skip"); + input_skip_bias_sum = model_builder.GetBuilder().call("add", input, skip, common_options); + if (!bias.isUndefined()) { + common_options.set("label", node.Name() + "_add_skip_bias"); + input_skip_bias_sum = model_builder.GetBuilder().call( + "add", input_skip_bias_sum, bias, common_options); + } + model_builder.AddOperand(output_defs[3]->Name(), std::move(input_skip_bias_sum)); + } } } else if (op_type == "InstanceNormalization") { // WebNN spec only supports 4D input for instanceNormalization. @@ -188,7 +215,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); } - model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); return Status::OK(); } @@ -215,9 +242,21 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi } const auto& output_defs = node.OutputDefs(); - if (output_defs.size() != 1) { - LOGS(logger, VERBOSE) << op_type << " output count must be one."; - return false; + if (op_type == "SkipSimplifiedLayerNormalization") { + if (output_defs.size() > 4) { + LOGS(logger, VERBOSE) << "SkipSimplifiedLayerNormalization output count must not exceed 4."; + return false; + } + if (TensorExists(output_defs, 1) || TensorExists(output_defs, 2)) { + // Output mean and inv_std_var are used for training mode, which is not supported. + LOGS(logger, VERBOSE) << "SkipSimplifiedLayerNormalization's output mean and inv_std_var are not supported."; + return false; + } + } else { + if (output_defs.size() != 1) { + LOGS(logger, VERBOSE) << op_type << " output count must be one."; + return false; + } } if (op_type == "BatchNormalization" && helper.Get("training_mode", 0)) { @@ -238,9 +277,9 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& int32_t input2_type; // B data type int32_t input3_type; // mean data type int32_t input4_type; // var data type - bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); - bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); - bool has_input4 = input_defs.size() > 3 && input_defs[4]->Exists(); + bool has_input2 = TensorExists(input_defs, 2); + bool has_input3 = TensorExists(input_defs, 3); + bool has_input4 = TensorExists(input_defs, 4); if (!GetType(*input_defs[0], input0_type, logger) || !GetType(*input_defs[1], input1_type, logger) || @@ -277,6 +316,7 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat "InstanceNormalization", "LayerNormalization", "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index bd7c23d75eba4..ca7b1c3f08a71 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -51,7 +51,7 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); emscripten::val zero_point = emscripten::val::null(); - if (input_defs.size() == 3 && input_defs[2]->Exists()) { + if (TensorExists(input_defs, 2)) { zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); has_zero_point = true; } else { @@ -159,7 +159,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initial int32_t input0_type = 0; // input data type int32_t input1_type = 0; // x_scale data type int32_t input2_type = 0; // x_zero_point data type - bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); + bool has_input2 = TensorExists(input_defs, 2); if (!GetType(*input_defs[0], input0_type, logger) || !GetType(*input_defs[1], input1_type, logger) || diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 41c66038c2694..44f929a9f1ac0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -173,7 +173,7 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initiali return false; // If there is step < 0, check data type support of reverse. - if (input_defs.size() > 4 && input_defs[4]->Exists()) { + if (TensorExists(input_defs, 4)) { std::vector steps; if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger)) return false; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 6d1c572128b93..e0ca50a36dbf9 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -159,6 +159,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateNormalizationOpBuilder("InstanceNormalization", op_registrations); CreateNormalizationOpBuilder("LayerNormalization", op_registrations); CreateNormalizationOpBuilder("SimplifiedLayerNormalization", op_registrations); + CreateNormalizationOpBuilder("SkipSimplifiedLayerNormalization", op_registrations); } { // Pad