diff --git a/tfjs-backend-webgpu/src/kernels/Pack.ts b/tfjs-backend-webgpu/src/kernels/Pack.ts index 9ebe618d51a..8e6c7250f67 100644 --- a/tfjs-backend-webgpu/src/kernels/Pack.ts +++ b/tfjs-backend-webgpu/src/kernels/Pack.ts @@ -54,7 +54,7 @@ export function pack( const result = concat({inputs: expandedTensors, backend, attrs: {axis}}); - intermediateTensorInfos.forEach(t => backend.disposeData(t)); + intermediateTensorInfos.forEach(t => backend.disposeData(t.dataId)); return result; } diff --git a/tfjs-backend-webgpu/src/kernels/Unpack.ts b/tfjs-backend-webgpu/src/kernels/Unpack.ts new file mode 100644 index 00000000000..130bc5219b5 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Unpack.ts @@ -0,0 +1,73 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, TensorInfo, Unpack, UnpackAttrs, UnpackInputs} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; + +import {reshape} from './Reshape'; +import {slice} from './Slice'; + +export function unpack( + args: + {inputs: UnpackInputs, backend: WebGPUBackend, attrs: UnpackAttrs}): + TensorInfo[] { + const {inputs, backend, attrs} = args; + const {value} = inputs; + let {axis} = attrs; + + if (axis < 0) { + axis += value.shape.length; + } + + const x = value; + const xRank = x.shape.length; + + const num = value.shape[axis]; + const outShape: number[] = new Array(xRank - 1); + let outIndex = 0; + for (let i = 0; i < xRank; i++) { + if (i !== axis) { + outShape[outIndex++] = x.shape[i]; + } + } + + const toDispose = []; + + const begin = new Array(xRank).fill(0); + const size = x.shape.slice(); + size[axis] = 1; + const res: TensorInfo[] = new Array(num); + for (let i = 0; i < res.length; i++) { + begin[axis] = i; + const sliced = slice({inputs: {x}, backend, attrs: {begin, size}}); + const reshaped = + reshape({inputs: {x: sliced}, backend, attrs: {shape: outShape}}); + res[i] = reshaped; + + toDispose.push(sliced); + } + + toDispose.forEach(t => backend.disposeData(t.dataId)); + return res; +} + +export const unpackConfig: KernelConfig = { + kernelName: Unpack, + backendName: 'webgpu', + kernelFunc: unpack as {} as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index e692c9d2109..3fe1c64d86c 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -85,6 +85,7 @@ import {subConfig} from './kernels/Sub'; import {sumConfig} from './kernels/Sum'; import {tanhConfig} from './kernels/Tanh'; import {transposeConfig} from './kernels/Transpose'; +import {unpackConfig} from './kernels/Unpack'; import {zerosLikeConfig} from './kernels/ZerosLike'; // List all kernel configs here @@ -158,6 +159,7 @@ const kernelConfigs: KernelConfig[] = [ sumConfig, tanhConfig, transposeConfig, + unpackConfig, zerosLikeConfig ]; diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 630d97e870d..c73a97c8f3c 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -523,7 +523,18 @@ const TEST_FILTERS: TestFilter[] = [ include: 'stack', excludes: [ 'accepts string', - 'unstack', + 'grad of unstack axis=0', // Remove this when grad is fixed in unstack. + 'gradient with clones', // Remove this when grad is fixed in unstack. + 'grad of unstack axis=1', // Remove this when grad is fixed in unstack. + ] + }, + { + include: 'unstack', + excludes: [ + 'accepts string', + 'grad of unstack axis=0', + 'gradient with clones', + 'grad of unstack axis=1', ] }, {