Skip to content

Commit

Permalink
[webgpu] Add unpack (#4789)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
axinging authored Mar 9, 2021
1 parent ed1b9ae commit 85c9ae0
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tfjs-backend-webgpu/src/kernels/Pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export function pack(

const result = concat({inputs: expandedTensors, backend, attrs: {axis}});

intermediateTensorInfos.forEach(t => backend.disposeData(t));
intermediateTensorInfos.forEach(t => backend.disposeData(t.dataId));

return result;
}
Expand Down
73 changes: 73 additions & 0 deletions tfjs-backend-webgpu/src/kernels/Unpack.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, TensorInfo, Unpack, UnpackAttrs, UnpackInputs} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';

import {reshape} from './Reshape';
import {slice} from './Slice';

export function unpack(
args:
{inputs: UnpackInputs, backend: WebGPUBackend, attrs: UnpackAttrs}):
TensorInfo[] {
const {inputs, backend, attrs} = args;
const {value} = inputs;
let {axis} = attrs;

if (axis < 0) {
axis += value.shape.length;
}

const x = value;
const xRank = x.shape.length;

const num = value.shape[axis];
const outShape: number[] = new Array(xRank - 1);
let outIndex = 0;
for (let i = 0; i < xRank; i++) {
if (i !== axis) {
outShape[outIndex++] = x.shape[i];
}
}

const toDispose = [];

const begin = new Array(xRank).fill(0);
const size = x.shape.slice();
size[axis] = 1;
const res: TensorInfo[] = new Array(num);
for (let i = 0; i < res.length; i++) {
begin[axis] = i;
const sliced = slice({inputs: {x}, backend, attrs: {begin, size}});
const reshaped =
reshape({inputs: {x: sliced}, backend, attrs: {shape: outShape}});
res[i] = reshaped;

toDispose.push(sliced);
}

toDispose.forEach(t => backend.disposeData(t.dataId));
return res;
}

export const unpackConfig: KernelConfig = {
kernelName: Unpack,
backendName: 'webgpu',
kernelFunc: unpack as {} as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ import {subConfig} from './kernels/Sub';
import {sumConfig} from './kernels/Sum';
import {tanhConfig} from './kernels/Tanh';
import {transposeConfig} from './kernels/Transpose';
import {unpackConfig} from './kernels/Unpack';
import {zerosLikeConfig} from './kernels/ZerosLike';

// List all kernel configs here
Expand Down Expand Up @@ -158,6 +159,7 @@ const kernelConfigs: KernelConfig[] = [
sumConfig,
tanhConfig,
transposeConfig,
unpackConfig,
zerosLikeConfig
];

Expand Down
13 changes: 12 additions & 1 deletion tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,18 @@ const TEST_FILTERS: TestFilter[] = [
include: 'stack',
excludes: [
'accepts string',
'unstack',
'grad of unstack axis=0', // Remove this when grad is fixed in unstack.
'gradient with clones', // Remove this when grad is fixed in unstack.
'grad of unstack axis=1', // Remove this when grad is fixed in unstack.
]
},
{
include: 'unstack',
excludes: [
'accepts string',
'grad of unstack axis=0',
'gradient with clones',
'grad of unstack axis=1',
]
},
{
Expand Down

0 comments on commit 85c9ae0

Please sign in to comment.