Skip to content

Commit

Permalink
Address comments of adding more tests and enhancing readability
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Dec 13, 2023
1 parent cf4cae7 commit ff919e5
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 70 deletions.
2 changes: 2 additions & 0 deletions src/arg_max_min.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export function argMaxMin(
keepDimensions = false,
selectLastIndex = false,
} = {}) {
// If axes doesn't present (defaulting to null), all dimensions are reduced.
// See https://webmachinelearning.github.io/webnn/#dom-mlargminmaxoptions-axes.
const inpAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const outputShape = input.shape.slice();

Expand Down
94 changes: 94 additions & 0 deletions test/arg_max_min_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,53 @@ describe('test argMax and argMin', function() {
});
});

it('argMax 3D discontinuous axes=[0, 2]', function() {
testArgMaxMin(
{
shape: [2, 3, 4],
value: [
1, 2, 3, 3,
3, 4, 4, 9,
5, 9, 5, 0,
8, 7, 8, 6,
9, 4, 5, 2,
1, 9, 4, 3,
],
},
{
shape: [3],
value: [4, 3, 1],
},
argMax,
{
axes: [0, 2],
});
});

it('argMax 3D discontinuous axes=[2, 0] selectLastIndex=true', function() {
testArgMaxMin(
{
shape: [2, 3, 4],
value: [
1, 2, 3, 3,
3, 4, 4, 9,
5, 9, 5, 0,
8, 7, 8, 6,
9, 4, 5, 2,
1, 9, 4, 3,
],
},
{
shape: [3],
value: [6, 4, 5],
},
argMax,
{
axes: [2, 0],
selectLastIndex: true,
});
});

it('argMax 3D axes=[2, 1] selectLastIndex=true', function() {
testArgMaxMin(
{
Expand Down Expand Up @@ -477,6 +524,53 @@ describe('test argMax and argMin', function() {
});
});

it('argMin 3D discontinuous axes=[0, 2]', function() {
testArgMaxMin(
{
shape: [2, 3, 4],
value: [
1, 2, 3, 3,
3, 4, 2, 9,
5, 9, 5, 0,
8, 7, 1, 6,
9, 3, 5, 2,
1, 9, 0, 4,
],
},
{
shape: [3],
value: [0, 2, 3],
},
argMin,
{
axes: [0, 2],
});
});

it('argMin 3D discontinuous axes=[2, 0] selectLastIndex=true', function() {
testArgMaxMin(
{
shape: [2, 3, 4],
value: [
1, 2, 3, 3,
3, 4, 2, 9,
5, 9, 5, 0,
8, 7, 1, 6,
9, 3, 5, 2,
1, 9, 0, 4,
],
},
{
shape: [3],
value: [6, 7, 6],
},
argMin,
{
axes: [2, 0],
selectLastIndex: true,
});
});

it('argMin 3D axes=[1, 2]', function() {
testArgMaxMin(
{
Expand Down
Loading

0 comments on commit ff919e5

Please sign in to comment.