
Request the decomposition for gatherElements, scatterElements and scatterND #767

Open
fujunwei opened this issue Oct 12, 2024 · 3 comments


@fujunwei

| Operator | Notes | TFLite |
| --- | --- | --- |
| gatherElements | Gathers values from the input using indices along the given axis.<br>Related: mil.ops.defs.iOS15.scatter_gather.gather_along_axis, tf.experimental.numpy.take_along_axis?, DML_GATHER_ELEMENTS<br>Decomposition: NA<br>Data Types: input (*), indices (int32, uint32, int64) | |
| scatterElements | Corollary for pair-completeness with gatherElements.<br>Related: mil.ops.defs.iOS17.scatter_gather.scatter_along_axis, TF=?, DML_SCATTER_ELEMENTS<br>Decomposition: NA<br>Data Types: input (*), updates (same as input), indices (int32, uint32, int64) | |
| scatterND | Scatter values using multidimensional indices. This is also useful for improving performance of ferrying partial MLBuffer transputs between iterations (transformer key-value reuse).<br>Related: tf.scatter_nd, mil.ops.defs.iOS15.scatter_gather.scatter_nd, DML_SCATTER_ND<br>Decomposition: NA<br>Data Types: input (*), updates (same as input), indices (int32, uint32, int64) | |
tf.scatter_nd only scatters into a zero-initialized tensor: its third argument specifies only the output shape, not initial values like WebNN's input operand.
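For reference, the semantics of the three requested operators can be pinned down in a few lines of NumPy. This is a minimal sketch, not a WebNN or TensorFlow implementation, and it assumes non-negative, in-range indices without duplicates. `np.take_along_axis` and `np.put_along_axis` use the same per-axis indexing as gatherElements and scatterElements when the non-axis dimensions of indices match the input; `scatter_nd` is a hypothetical helper that, unlike tf.scatter_nd, starts from the input instead of zeros.

```python
import numpy as np

def gather_elements(input, indices, axis):
    # For axis 0 (3-D): output[i][j][k] = input[indices[i][j][k]][j][k];
    # other axes work the same way. output.shape == indices.shape.
    return np.take_along_axis(input, indices, axis=axis)

def scatter_elements(input, indices, updates, axis):
    # Pair of gather_elements: output starts as a copy of input, then for axis 0
    # updates[i][j][k] is written to output[indices[i][j][k]][j][k].
    output = input.copy()
    np.put_along_axis(output, indices, updates, axis=axis)
    return output

def scatter_nd(input, indices, updates):
    # Hypothetical helper: each row of indices (last dimension k) addresses the
    # slice output[i0, ..., ik-1], which is overwritten by the matching updates
    # slice. The output starts from input, not from a zero tensor.
    output = input.copy()
    k = indices.shape[-1]
    locations = indices.reshape(-1, k)
    slices = updates.reshape((locations.shape[0],) + input.shape[k:])
    for location, value in zip(locations, slices):
        output[tuple(location)] = value
    return output

data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx = np.array([[1, 0], [2, 1], [0, 2]])
gathered = gather_elements(data, idx, axis=1)
# [[2, 1], [6, 5], [7, 9]]
scattered = scatter_elements(data, idx, np.array([[10, 20], [30, 40], [50, 60]]), axis=1)
# [[20, 10, 3], [4, 40, 30], [50, 8, 60]]
scattered_nd = scatter_nd(data, np.array([[0], [2]]), np.array([[0, 0, 0], [9, 9, 9]]))
# [[0, 0, 0], [4, 5, 6], [9, 9, 9]]
```

Row 1 of scattered_nd keeps its input values, which is exactly what a zero-initialized tf.scatter_nd cannot express on its own.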
fujunwei changed the title TFLite requires the decomposition for gatherElements, scatterElements and gatherND Request the decomposition for gatherElements, scatterElements and gatherND Oct 12, 2024
@huningxin
Contributor

One possible emulation of scatterND by using tf.scatter_nd:

import tensorflow as tf

# Make an all-True tensor with updates.shape; it is scattered to build the condition tensor.
trues = tf.ones(updates.shape, tf.dtypes.bool)
# Scatter the True values into a zero (False) initialized tensor according to indices.
condition = tf.scatter_nd(indices, trues, input.shape)
# Scatter the values of updates into another zero-initialized tensor according to indices.
scatter = tf.scatter_nd(indices, updates, input.shape)
# Select the scattered value where condition is True, otherwise the input value.
output = tf.where(condition, scatter, input)

Test case

>>> input = tf.constant([[[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
                         [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
                         [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
                         [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]])
>>> indices = tf.constant([[0], [2]])
>>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]])
# Run the emulation above
>>> print(output)
tf.Tensor(
[[[5 5 5 5]
  [6 6 6 6]
  [7 7 7 7]
  [8 8 8 8]]

 [[1 2 3 4]
  [5 6 7 8]
  [8 7 6 5]
  [4 3 2 1]]

 [[1 1 1 1]
  [2 2 2 2]
  [3 3 3 3]
  [4 4 4 4]]

 [[8 7 6 5]
  [4 3 2 1]
  [1 2 3 4]
  [5 6 7 8]]], shape=(4, 4, 4), dtype=int32)
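The same mask-and-select idea can be checked without TensorFlow. The sketch below is a NumPy transcription using the test case tensors above. zero_init_scatter is a hypothetical stand-in that mimics tf.scatter_nd only for unique indices (tf.scatter_nd sums duplicate updates, so with duplicates the emulation would not be a plain overwrite).

```python
import numpy as np

def zero_init_scatter(indices, updates, shape):
    # Hypothetical stand-in for tf.scatter_nd: write each updates slice into a
    # zero-initialized tensor of `shape`. Assumes unique indices.
    output = np.zeros(shape, dtype=updates.dtype)
    for location, value in zip(indices, updates):
        output[tuple(location)] = value
    return output

def emulated_scatter_nd(input, indices, updates):
    # Mark the scattered positions with True; everything else stays False.
    condition = zero_init_scatter(indices, np.ones(updates.shape, dtype=bool), input.shape)
    # Scatter the update values into another zero-initialized tensor.
    scattered = zero_init_scatter(indices, updates, input.shape)
    # Take the scattered value where marked, otherwise keep the input value.
    return np.where(condition, scattered, input)

data = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
                 [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
                 [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
                 [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]])
indices = np.array([[0], [2]])
updates = np.array([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
                    [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]])
output = emulated_scatter_nd(data, indices, updates)

# Direct reference: copy the input and overwrite slices 0 and 2.
expected = data.copy()
expected[0], expected[2] = updates[0], updates[1]
assert np.array_equal(output, expected)
```

Slices 1 and 3 come straight from the input, matching the TensorFlow output printed above.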

anssiko changed the title Request the decomposition for gatherElements, scatterElements and gatherND Request the decomposition for gatherElements, scatterElements and scatterND Oct 31, 2024
@fujunwei
Author

fujunwei commented Nov 6, 2024

Maybe gatherElements can be supported with tfl.gather_nd. Its test cases and documentation show that gather_nd can gather not only slices but also individual elements. However, gather_nd has no axis argument, so the indices need to be converted. For example, with axis = 1:

indices = [[1, 0],
           [2, 1],
           [0, 2]]  with shape (3, 2)

is converted to

indices = [[0, 1],
           [0, 0],
           [1, 2],
           [1, 1],
           [2, 0],
           [2, 2]]  with shape (6, 2)

So the indices must be a constant operand, and the converted locations are inserted with a loop. At the current stage, the input is required to be two-dimensional.
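The conversion in the example is a short double loop. A minimal pure-Python sketch, assuming a 2-D constant indices tensor, axis = 1 as in the example, and non-negative in-range values; to_gather_nd_indices and gather_nd_elements are hypothetical helpers, not TFLite APIs.

```python
def to_gather_nd_indices(indices, axis):
    # Each element indices[i][j] becomes one gather_nd location: the element's
    # own position [i, j] with the coordinate on `axis` replaced by its value.
    locations = []
    for i, row in enumerate(indices):
        for j, value in enumerate(row):
            location = [i, j]
            location[axis] = value
            locations.append(location)
    return locations

def gather_nd_elements(data, locations):
    # Hypothetical stand-in for gather_nd when every location is a full
    # (row, column) pair, i.e. it selects single elements rather than slices.
    return [data[row][col] for row, col in locations]

indices = [[1, 0], [2, 1], [0, 2]]
locations = to_gather_nd_indices(indices, axis=1)
# [[0, 1], [0, 0], [1, 2], [1, 1], [2, 0], [2, 2]], shape (6, 2)

data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flat = gather_nd_elements(data, locations)  # [2, 1, 6, 5, 7, 9]
# Reshape back to the shape of indices, (3, 2), as gatherElements requires.
cols = len(indices[0])
output = [flat[r * cols:(r + 1) * cols] for r in range(len(indices))]
# [[2, 1], [6, 5], [7, 9]]
```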

@huningxin
Contributor

huningxin commented Nov 15, 2024

@fujunwei, thanks for sharing the idea of emulating gatherElements by gatherND!

> the two dimensions of input is required at current stage.

I suppose the emulation could support an N-dimensional input as follows:

// Generate gatherND's indicesND from gatherElements' indices and axis.
// indicesND.rank == 2 which can be treated as an array of locations.
// indicesND.shape[0] == the number of elements of indices. indicesND.shape[1] == indices.rank. 
let indicesND = [];

for (let i = 0; i < indices.numberOfElements(); ++i) {
  let location = indices.getLocationFromIndex(i);
  location[axis] = indices.getValueByLocation(location);
  indicesND.push(location);
}

// For gatherElements, output.shape == indices.shape
reshape(gatherND(input, indicesND), indices.shape);

webnn-baseline has a reference implementation of getLocationFromIndex() and getValueByLocation().

Please note this emulation only works when indices is a constant operand.
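The N-dimensional loop can be made runnable with NumPy and checked against np.take_along_axis. This is a sketch under assumptions: indices are constant, non-negative, in range, and match the input on every non-axis dimension. np.ndindex (row-major iteration) and ordinary indexing play the roles of webnn-baseline's getLocationFromIndex() and getValueByLocation(), NumPy advanced indexing stands in for gatherND, and gather_elements_via_gather_nd is a hypothetical name.

```python
import numpy as np

def gather_elements_via_gather_nd(input, indices, axis):
    # Build indicesND: one row per element of `indices`, visited in row-major order.
    indices_nd = []
    for location in np.ndindex(*indices.shape):
        value = int(indices[location])      # value stored at this location
        location = list(location)
        location[axis] = value              # replace the coordinate on `axis`
        indices_nd.append(location)
    # indicesND.shape == (number of elements of indices, indices.rank)
    indices_nd = np.array(indices_nd, dtype=np.int64).reshape(-1, indices.ndim)
    # gatherND with full-rank locations selects single elements.
    gathered = input[tuple(indices_nd.T)]
    # For gatherElements, output.shape == indices.shape.
    return gathered.reshape(indices.shape)

rng = np.random.default_rng(0)
data = rng.integers(0, 100, size=(2, 3, 4))
for axis in range(data.ndim):
    shape = list(data.shape)
    shape[axis] = 5  # indices may differ from the input only along `axis`
    idx = rng.integers(0, data.shape[axis], size=shape)
    assert np.array_equal(gather_elements_via_gather_nd(data, idx, axis),
                          np.take_along_axis(data, idx, axis=axis))
```

Because indicesND is computed from the values of indices ahead of time, this only works when indices is a constant operand, as noted above.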

Revisions:
11/18/2024: provided a non-nested loop version
