-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WebNN: Add IDL and mojo definitions for
scatterND
operator
The `scatterND` operator is proposed by WebML WG [1] for improving performance of passing partial MLBuffer/MLTensor between transformers decoder iterations for key-value reuse. This CL also implements data type limits, inputs validation and adds validation tests for `scatterND` operator. [1]: webmachinelearning/webnn#375 (comment) Bug: 363677531 Change-Id: Ib68db63c7b51b99f9976b1eb3c06f8b7f0de9f97 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5829103 Reviewed-by: Alex Gough <[email protected]> Reviewed-by: Weizhong Xia <[email protected]> Reviewed-by: Reilly Grant <[email protected]> Commit-Queue: ningxin hu <[email protected]> Cr-Commit-Position: refs/heads/main@{#1353781}
- Loading branch information
1 parent
e943689
commit 49e4115
Showing
1 changed file
with
102 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// META: title=validation tests for WebNN API scatterND operation | ||
// META: global=window,dedicatedworker | ||
// META: variant=?cpu | ||
// META: variant=?gpu | ||
// META: variant=?npu | ||
// META: script=../resources/utils_validation.js | ||
|
||
'use strict'; | ||
|
||
const tests = [ | ||
{ | ||
name: '[scatterND] Test scatterND with valid tensors', | ||
input: {dataType: 'float32', dimensions: [4, 4, 4]}, | ||
indices: {dataType: 'int32', dimensions: [2, 1]}, | ||
updates: {dataType: 'float32', dimensions: [2, 4, 4]}, | ||
output: {dataType: 'float32', dimensions: [4, 4, 4]} | ||
}, | ||
{ | ||
name: | ||
'[scatterND] Throw if updates tensor data type is not the same as input data type', | ||
input: {dataType: 'float32', dimensions: [4, 4, 4]}, | ||
indices: {dataType: 'int32', dimensions: [2, 1]}, | ||
updates: {dataType: 'float16', dimensions: [2, 4, 4]}, | ||
}, | ||
{ | ||
name: '[scatterND] Throw if input is a scalar', | ||
input: {dataType: 'float32', dimensions: []}, | ||
indices: {dataType: 'int32', dimensions: [2, 1]}, | ||
updates: {dataType: 'float32', dimensions: [2, 4, 4]}, | ||
}, | ||
{ | ||
name: '[scatterND] Throw if indices is a scalar', | ||
input: {dataType: 'float32', dimensions: [4, 4, 4]}, | ||
indices: {dataType: 'int32', dimensions: []}, | ||
updates: {dataType: 'float32', dimensions: [2, 4, 4]}, | ||
}, | ||
{ | ||
name: | ||
'[scatterND] Throw if the size of last dimension of indices tensor is greater than input rank', | ||
input: {dataType: 'float32', dimensions: [4, 4, 4]}, | ||
indices: {dataType: 'int32', dimensions: [2, 4]}, | ||
updates: {dataType: 'float32', dimensions: [2, 4, 4]}, | ||
}, | ||
{ | ||
name: '[scatterND] Throw if updates tensor shape is invalid.', | ||
input: {dataType: 'float32', dimensions: [4, 4, 4]}, | ||
indices: {dataType: 'int32', dimensions: [2, 1]}, | ||
// Updates tensor shape should be [2, 4, 4]. | ||
updates: {dataType: 'float32', dimensions: [2, 3, 4]}, | ||
} | ||
]; | ||
|
||
tests.forEach(test => promise_test(async t => { | ||
const builder = new MLGraphBuilder(context); | ||
const input = builder.input('input', test.input); | ||
const indices = builder.input('indices', test.indices); | ||
const updates = builder.input('updates', test.updates); | ||
|
||
if (test.output) { | ||
const output = builder.scatterND(input, indices, updates); | ||
assert_equals(output.dataType(), test.output.dataType); | ||
assert_array_equals(output.shape(), test.output.dimensions); | ||
} else { | ||
const label = 'a_scatter_nd' | ||
const options = {label}; | ||
const regexp = new RegExp('\\[' + label + '\\]'); | ||
assert_throws_with_label( | ||
() => builder.scatterND(input, indices, updates, options), | ||
regexp); | ||
} | ||
}, test.name)); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = | ||
otherBuilder.input('input', {dataType: 'float32', dimensions: [8]}); | ||
const indices = | ||
builder.input('indices', {dataType: 'int32', dimensions: [4, 1]}); | ||
const updates = | ||
builder.input('indices', {dataType: 'int32', dimensions: [4]}); | ||
|
||
assert_throws_js(TypeError, () => builder.scatterND(input, indices, updates)); | ||
}, '[scatterND] Throw if input is from another builder'); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = builder.input('input', {dataType: 'float32', dimensions: [8]}); | ||
const indices = | ||
otherBuilder.input('indices', {dataType: 'int32', dimensions: [4, 1]}); | ||
const updates = | ||
builder.input('indices', {dataType: 'int32', dimensions: [4]}); | ||
|
||
assert_throws_js(TypeError, () => builder.scatterND(input, indices, updates)); | ||
}, '[scatterND] Throw if indcies is from another builder'); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = builder.input('input', {dataType: 'float32', dimensions: [8]}); | ||
const indices = | ||
builder.input('indices', {dataType: 'int32', dimensions: [4, 1]}); | ||
const updates = | ||
otherBuilder.input('indices', {dataType: 'int32', dimensions: [4]}); | ||
|
||
assert_throws_js(TypeError, () => builder.scatterND(input, indices, updates)); | ||
}, '[scatterND] Throw if updates is from another builder'); |