Skip to content

Commit

Permalink
WebNN: Implement scatterElements operator in DirectML backend
Browse files Browse the repository at this point in the history
The `scatterElements` operator is proposed by WebML WG [1] for
supporting popular transformer-based models.

This CL adds the IDL and mojo definitions of scatterElements, and
implements it in the DirectML backend by mapping to
`DML_OPERATOR_SCATTER` [2].

This CL also adds the `scatterElements` validation and conformance tests
into WPT.

[1]: webmachinelearning/webnn#375 (comment)
[2]: https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc

Bug: 370536101,370538328
Change-Id: Ifb73bed5eb05cb919b106b4aaea5127ec099edb2
Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5921136
Reviewed-by: Alex Gough <ajgo@chromium.org>
Reviewed-by: Weizhong Xia <weizhong@google.com>
Auto-Submit: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Weizhong Xia <weizhong@google.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Austin Sullivan <asully@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1368312}
  • Loading branch information
huningxin authored and Chromium LUCI CQ committed Oct 14, 2024
1 parent a01467c commit 48716cb
Show file tree
Hide file tree
Showing 44 changed files with 921 additions and 1 deletion.
4 changes: 4 additions & 0 deletions services/webnn/coreml/graph_builder_coreml.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,9 @@ ContextProperties GraphBuilderCoreml::GetContextProperties() {
// corresponding BOOL type. See docs here:
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.reshape
/*reshape_input=*/kFloatsAndInt32,
// TODO(crbug.com/370535834): Implement ScatterElements.
/*scatter_elements_input=*/{},
/*scatter_elements_indices=*/{},
// TODO(crbug.com/363544348): Implement ScatterND.
/*scatter_nd_input=*/{},
/*scatter_nd_indices=*/{},
Expand Down Expand Up @@ -1042,6 +1045,7 @@ GraphBuilderCoreml::BuildCoreMLModel() {
case mojom::Operation::Tag::kLstmCell:
case mojom::Operation::Tag::kPrelu:
case mojom::Operation::Tag::kQuantizeLinear:
case mojom::Operation::Tag::kScatterElements:
case mojom::Operation::Tag::kScatterNd:
case mojom::Operation::Tag::kTriangular:
return NewNotSupportedError(NotSupportedOperatorError(*operation));
Expand Down
6 changes: 6 additions & 0 deletions services/webnn/dml/context_impl_dml.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ ContextProperties ContextImplDml::GetProperties(
// Reshape is emulated by identity.
/*reshape_input=*/kFloat16To32Ints8To32,

// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc#tensor-support
/*scatter_elements_input=*/kFloat16To32Ints8To32,
/*scatter_elements_indices=*/kGatherScatterIndicesSupportedDataTypes,

// https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_nd_operator_desc#tensor-support
/*scatter_nd_input=*/kFloat16To32Ints8To32,
/*scatter_nd_indices=*/kGatherScatterIndicesSupportedDataTypes,
Expand Down Expand Up @@ -372,6 +376,8 @@ ContextProperties ContextImplDml::GetProperties(
SupportedDataTypes::All();
properties.data_type_limits.gather_nd_input = SupportedDataTypes::All();
properties.data_type_limits.reshape_input = SupportedDataTypes::All();
properties.data_type_limits.scatter_elements_input =
SupportedDataTypes::All();
properties.data_type_limits.scatter_nd_input = SupportedDataTypes::All();
properties.data_type_limits.sign_input =
DataTypeConstraint::kFloat16To32Int8To64;
Expand Down
61 changes: 61 additions & 0 deletions services/webnn/dml/graph_impl_dml.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,14 @@ void RetrieveOperationConnectivity(
output_ids = {reshape->output_operand_id};
break;
}
case Operation::Tag::kScatterElements: {
const auto& scatter_elements = operation->get_scatter_elements();
input_ids = {scatter_elements->input_operand_id,
scatter_elements->indices_operand_id,
scatter_elements->updates_operand_id};
output_ids = {scatter_elements->output_operand_id};
break;
}
case Operation::Tag::kScatterNd: {
const auto& scatter_nd = operation->get_scatter_nd();
input_ids = {scatter_nd->input_operand_id, scatter_nd->indices_operand_id,
Expand Down Expand Up @@ -2326,6 +2334,52 @@ void CreateOperatorNodeForPrelu(const ContextProperties context_properties,
CHECK(id_to_node_output_map.try_emplace(output_id, node_output).second);
}

void CreateOperatorNodeForScatterElements(
const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ScatterElementsPtr& scatter_elements,
GraphBuilderDml& graph_builder,
IdToNodeOutputMap& id_to_node_output_map) {
const NodeOutput* input = GetNodeOutputForOperand(
id_to_node_output_map, scatter_elements->input_operand_id);
TensorDesc input_tensor_desc = input->GetTensorDesc();
CHECK(context_properties.data_type_limits.scatter_elements_input.Has(
DmlDataTypeToOperand(input_tensor_desc.GetDataType())));

const NodeOutput* indices = GetNodeOutputForOperand(
id_to_node_output_map, scatter_elements->indices_operand_id);
TensorDesc indices_tensor_desc = indices->GetTensorDesc();
CHECK(context_properties.data_type_limits.scatter_elements_indices.Has(
DmlDataTypeToOperand(indices_tensor_desc.GetDataType())));

const NodeOutput* updates = GetNodeOutputForOperand(
id_to_node_output_map, scatter_elements->updates_operand_id);
TensorDesc updates_tensor_desc = updates->GetTensorDesc();
CHECK(context_properties.data_type_limits.scatter_elements_input.Has(
DmlDataTypeToOperand(updates_tensor_desc.GetDataType())));

uint64_t output_id = scatter_elements->output_operand_id;
const TensorDesc output_tensor_desc =
CreateOutputTensorDesc(id_to_operand_map, output_id);

DML_SCATTER_OPERATOR_DESC scatter_elements_desc{
.InputTensor = &input_tensor_desc.GetDMLTensorDesc(),
.IndicesTensor = &indices_tensor_desc.GetDMLTensorDesc(),
.UpdatesTensor = &updates_tensor_desc.GetDMLTensorDesc(),
.OutputTensor = &output_tensor_desc.GetDMLTensorDesc(),
.Axis = scatter_elements->axis};

std::array<const NodeOutput*, 3> inputs = {input, indices, updates};
const GraphNode* node = graph_builder.CreateOperatorNode(
DML_OPERATOR_SCATTER, &scatter_elements_desc, inputs,
scatter_elements->label);

const NodeOutput* output =
graph_builder.CreateNodeOutput(node, std::move(output_tensor_desc), 0);
// The output id must be unique in the map.
CHECK(id_to_node_output_map.try_emplace(output_id, output).second);
}

void CreateOperatorNodeForScatterND(const ContextProperties& context_properties,
const IdToOperandMap& id_to_operand_map,
const mojom::ScatterNDPtr& scatter_nd,
Expand Down Expand Up @@ -6236,6 +6290,13 @@ base::expected<void, mojom::ErrorPtr> GraphImplDml::CreateAndBuildInternal(
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kScatterElements: {
CreateOperatorNodeForScatterElements(
context_properties, id_to_operand_map,
operation->get_scatter_elements(), graph_builder,
id_to_node_output_map);
break;
}
case mojom::Operation::Tag::kScatterNd: {
CreateOperatorNodeForScatterND(context_properties, id_to_operand_map,
operation->get_scatter_nd(),
Expand Down
4 changes: 4 additions & 0 deletions services/webnn/public/cpp/data_type_limits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ DataTypeLimits::DataTypeLimits(SupportedDataTypes input,
SupportedDataTypes relu_input,
SupportedDataTypes resample2d_input,
SupportedDataTypes reshape_input,
SupportedDataTypes scatter_elements_input,
SupportedDataTypes scatter_elements_indices,
SupportedDataTypes scatter_nd_input,
SupportedDataTypes scatter_nd_indices,
SupportedDataTypes sigmoid_input,
Expand Down Expand Up @@ -192,6 +194,8 @@ DataTypeLimits::DataTypeLimits(SupportedDataTypes input,
relu_input(relu_input),
resample2d_input(resample2d_input),
reshape_input(reshape_input),
scatter_elements_input(scatter_elements_input),
scatter_elements_indices(scatter_elements_indices),
scatter_nd_input(scatter_nd_input),
scatter_nd_indices(scatter_nd_indices),
sigmoid_input(sigmoid_input),
Expand Down
6 changes: 6 additions & 0 deletions services/webnn/public/cpp/data_type_limits.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) DataTypeLimits {
SupportedDataTypes relu_input,
SupportedDataTypes resample2d_input,
SupportedDataTypes reshape_input,
SupportedDataTypes scatter_elements_input,
SupportedDataTypes scatter_elements_indices,
SupportedDataTypes scatter_nd_input,
SupportedDataTypes scatter_nd_indices,
SupportedDataTypes sigmoid_input,
Expand Down Expand Up @@ -208,6 +210,8 @@ struct COMPONENT_EXPORT(WEBNN_PUBLIC_CPP) DataTypeLimits {
SupportedDataTypes relu_input;
SupportedDataTypes resample2d_input;
SupportedDataTypes reshape_input;
SupportedDataTypes scatter_elements_input;
SupportedDataTypes scatter_elements_indices;
SupportedDataTypes scatter_nd_input;
SupportedDataTypes scatter_nd_indices;
SupportedDataTypes sigmoid_input;
Expand Down Expand Up @@ -311,6 +315,8 @@ inline bool operator==(const DataTypeLimits& lhs, const DataTypeLimits& rhs) {
lhs.relu_input == rhs.relu_input &&
lhs.resample2d_input == rhs.resample2d_input &&
lhs.reshape_input == rhs.reshape_input &&
lhs.scatter_elements_input == rhs.scatter_elements_input &&
lhs.scatter_elements_indices == rhs.scatter_elements_indices &&
lhs.scatter_nd_input == rhs.scatter_nd_input &&
lhs.scatter_nd_indices == rhs.scatter_nd_indices &&
lhs.sigmoid_input == rhs.sigmoid_input &&
Expand Down
72 changes: 72 additions & 0 deletions services/webnn/public/cpp/graph_validation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,78 @@ base::expected<OperandDescriptor, std::string> ValidateReduceAndInferOutput(
return OperandDescriptor::Create(input.data_type(), output_shape);
}

base::expected<OperandDescriptor, std::string>
ValidateScatterElementsAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
const OperandDescriptor& updates,
const uint32_t axis,
std::string_view label) {
if (!context_properties.data_type_limits.scatter_elements_input.Has(
input.data_type())) {
return base::unexpected(ErrorWithLabel(
label,
NotSupportedInputArgumentTypeError(
input.data_type(),
context_properties.data_type_limits.scatter_elements_input)));
}

static constexpr char kIndicesParam[] = "indices";
if (!context_properties.data_type_limits.scatter_elements_indices.Has(
indices.data_type())) {
return base::unexpected(ErrorWithLabel(
label,
NotSupportedArgumentTypeError(
kIndicesParam, indices.data_type(),
context_properties.data_type_limits.scatter_elements_indices)));
}

if (input.data_type() != updates.data_type()) {
return base::unexpected(
ErrorWithLabel(label,
"The updates tensor data type should be the same as "
"input data type."));
}

if (input.Rank() == 0) {
return base::unexpected(
ErrorWithLabel(label, "The input should not be a scalar."));
}

if (input.Rank() <= axis) {
return base::unexpected(ErrorWithLabel(
label,
"The axis must be in the range [0, N-1] where N is the rank of input "
"tensor."));
}

if (indices.Rank() != input.Rank()) {
return base::unexpected(ErrorWithLabel(
label, "The indices and input tensors should have the same rank."));
}

for (uint32_t i = 0; i < input.Rank(); ++i) {
if (i == axis) {
continue;
}
if (input.shape()[i] != indices.shape()[i]) {
return base::unexpected(
ErrorWithLabel(label,
"Except on the axis dimension, the input and indices "
"tensor must have the same dimension size."));
}
}

if (indices.shape() != updates.shape()) {
return base::unexpected(ErrorWithLabel(
label, "The updates and indices tensors should have the same shape."));
}

// The output tensor has the same data type and shape as input's.
return input;
}

base::expected<OperandDescriptor, std::string> ValidateScatterNDAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
Expand Down
12 changes: 12 additions & 0 deletions services/webnn/public/cpp/graph_validation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,18 @@ base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
base::span<const uint32_t> axes,
bool keepDimensions = false);

// Validate and infer output information of scatterElements operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-scatterelements
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
WEBNN_PUBLIC_CPP)
ValidateScatterElementsAndInferOutput(
const ContextProperties& context_properties,
const OperandDescriptor& input,
const OperandDescriptor& indices,
const OperandDescriptor& updates,
uint32_t axis,
std::string_view label);

// Validate and infer output information of scatterND operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-scatternd
base::expected<OperandDescriptor, std::string> COMPONENT_EXPORT(
Expand Down
1 change: 1 addition & 0 deletions services/webnn/public/cpp/webnn_errors.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ inline constexpr char kQuantizeLinear[] = "quantizeLinear";
inline constexpr char kRelu[] = "relu";
inline constexpr char kResample2d[] = "resample2d";
inline constexpr char kReshape[] = "reshape";
inline constexpr char kScatterElements[] = "scatterElements";
inline constexpr char kScatterND[] = "scatterND";
inline constexpr char kSigmoid[] = "sigmoid";
inline constexpr char kSlice[] = "slice";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ TEST(ContextPropertiesMojomTraitsTest, Basic) {
webnn::SupportedDataTypes::All(),
{webnn::OperandDataType::kFloat16, webnn::OperandDataType::kFloat32},
{webnn::OperandDataType::kInt32, webnn::OperandDataType::kUint32},
{webnn::OperandDataType::kFloat16, webnn::OperandDataType::kFloat32},
{webnn::OperandDataType::kInt32, webnn::OperandDataType::kUint32},
webnn::SupportedDataTypes::All(),
webnn::SupportedDataTypes::All(),
{webnn::OperandDataType::kFloat16, webnn::OperandDataType::kUint8},
Expand All @@ -121,7 +123,7 @@ TEST(ContextPropertiesMojomTraitsTest, Basic) {
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {},
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}});
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}});

EXPECT_TRUE(
mojo::test::SerializeAndDeserialize<webnn::mojom::ContextProperties>(
Expand Down
10 changes: 10 additions & 0 deletions services/webnn/public/mojom/data_type_limits_mojom_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ struct StructTraits<webnn::mojom::DataTypeLimitsDataView,
const webnn::DataTypeLimits& data_type_limits) {
return data_type_limits.reshape_input;
}
static webnn::SupportedDataTypes scatter_elements_input(
const webnn::DataTypeLimits& data_type_limits) {
return data_type_limits.scatter_elements_input;
}
static webnn::SupportedDataTypes scatter_elements_indices(
const webnn::DataTypeLimits& data_type_limits) {
return data_type_limits.scatter_elements_indices;
}
static webnn::SupportedDataTypes scatter_nd_input(
const webnn::DataTypeLimits& data_type_limits) {
return data_type_limits.scatter_nd_input;
Expand Down Expand Up @@ -498,6 +506,8 @@ struct StructTraits<webnn::mojom::DataTypeLimitsDataView,
data.ReadReluInput(&out->relu_input) &&
data.ReadResample2dInput(&out->resample2d_input) &&
data.ReadReshapeInput(&out->reshape_input) &&
data.ReadScatterElementsInput(&out->scatter_elements_input) &&
data.ReadScatterElementsIndices(&out->scatter_elements_indices) &&
data.ReadScatterNdInput(&out->scatter_nd_input) &&
data.ReadScatterNdIndices(&out->scatter_nd_indices) &&
data.ReadSigmoidInput(&out->sigmoid_input) &&
Expand Down
4 changes: 4 additions & 0 deletions services/webnn/public/mojom/webnn_context_properties.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ struct DataTypeLimits {
SupportedDataTypes resample2d_input;
SupportedDataTypes reshape_input;

// ScatterElements.
SupportedDataTypes scatter_elements_input;
SupportedDataTypes scatter_elements_indices;

// ScatterND.
SupportedDataTypes scatter_nd_input;
SupportedDataTypes scatter_nd_indices;
Expand Down
60 changes: 60 additions & 0 deletions services/webnn/public/mojom/webnn_graph.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,65 @@ struct Reshape {
string label;
};

// Represents ScatterElements operation that first copies the values of `input`
// tensor to `output` tensor, and then overwrites the values of `output` tensor
// to values specified by `updates` tensor at specific index positions specified
// by `indices` tensor along `axis` dimension.
//
// Its calculation follows the pseudocode:
// ```
// output = input
// // `input`, `indices` and `output` have the same rank, denoted as `r`.
// output[i_{0}, ..., i_{axis-1},
// indices[i_{0}, ..., i_{r-1}],
// i_{axis+1}, ..., i_{r-1}] = updates[i_{0}, ..., i_{r-1}]
// ```
// The index value must be within the bounds [`-output.shape[axis]`,
// `output.shape[axis]` - 1] and a negative index means indexing from the end of
// the dimension.
//
// Example 1: Scatter elements along axis 0
// input = [[0.0, 0.0, 0.0],
// [0.0, 0.0, 0.0],
// [0.0, 0.0, 0.0]]
// indices = [[1, 0, 2],
// [0, 2, 1]]
// updates = [[1.0, 1.1, 1.2],
// [2.0, 2.1, 2.2]]
// output = [[2.0, 1.1, 0.0]
// [1.0, 0.0, 2.2]
// [0.0, 2.1, 1.2]]
//
// Example 2: Scatter elements along axis 1
// input = [[1.0, 2.0, 3.0, 4.0, 5.0]]
// indices = [[1, 3]]
// updates = [[1.1, 2.1]]
// output = [[1.0, 1.1, 3.0, 2.1, 5.0]]
//
// The values of `indices` tensor aren't known until graph execution, and may
// cause out-of-bounds write issue. A backend implementation must guarantee the
// index values do not cause invalid writes outside the output tensor, either by
// the platform API (e.g. DML) or by introducing a `clamp` operator in the
// compiled graph.
// The out-of-bounds `indices` handling is being discussed in WG:
// https://github.com/webmachinelearning/webnn/issues/486
struct ScatterElements {
// The operand id is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
// The `input`, `indices` and `updates` are ScatterElements operator input
// operands.
// Their id must be distinct from the output operand id.
uint64 input_operand_id;
uint64 indices_operand_id;
uint64 updates_operand_id;
uint64 output_operand_id;

// The axis of the output tensor to scatter on.
uint32 axis = 0;
// User defined label from MLOperatorOptions.
string label;
};

// Represents ScatterND operation that first copies the values of `input` to
// `output` tensor, and then overwrites the values of `output` tensor to values
// specified by `updates` tensor at specific index positions specified by
Expand Down Expand Up @@ -1317,6 +1376,7 @@ union Operation {
Relu relu;
Resample2d resample2d;
Reshape reshape;
ScatterElements scatter_elements;
ScatterND scatter_nd;
Sigmoid sigmoid;
Slice slice;
Expand Down
Loading

0 comments on commit 48716cb

Please sign in to comment.