From b19fcc55ac52a3773003262c53d1b08cfb7fa348 Mon Sep 17 00:00:00 2001 From: Hao Yunfei Date: Tue, 9 Mar 2021 13:27:27 +0800 Subject: [PATCH] [webgpu] Add support for OnesLike (#4788) FEATURE --- tfjs-backend-webgpu/src/kernels/OnesLike.ts | 58 +++++++++++++++++++ .../src/register_all_kernels.ts | 2 + tfjs-backend-webgpu/src/setup_test.ts | 8 +++ 3 files changed, 68 insertions(+) create mode 100644 tfjs-backend-webgpu/src/kernels/OnesLike.ts diff --git a/tfjs-backend-webgpu/src/kernels/OnesLike.ts b/tfjs-backend-webgpu/src/kernels/OnesLike.ts new file mode 100644 index 00000000000..5c17436af44 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/OnesLike.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, OnesLike, OnesLikeInputs, TensorInfo} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; + +import {complex} from './Complex'; +import {fill} from './Fill'; +import {imag} from './Imag'; +import {real} from './Real'; +import {zerosLike} from './ZerosLike'; + +export function onesLike( + args: {inputs: OnesLikeInputs, backend: WebGPUBackend}): TensorInfo { + const {inputs, backend} = args; + const {x} = inputs; + + if (x.dtype === 'string') { + throw new Error('onesLike is not supported under string dtype'); + } else if (x.dtype === 'complex64') { + const realPart = real({inputs: {input: x}, backend}); + const r = onesLike({inputs: {x: realPart}, backend}); + const imagPart = imag({inputs: {input: x}, backend}); + const i = zerosLike({inputs: {x: imagPart}, backend}); + + const result = complex({inputs: {real: r, imag: i}, backend}); + + backend.disposeData(realPart.dataId); + backend.disposeData(r.dataId); + backend.disposeData(imagPart.dataId); + backend.disposeData(i.dataId); + + return result; + } else { + return fill({attrs: {shape: x.shape, dtype: x.dtype, value: 1}, backend}); + } +} + +export const onesLikeConfig: KernelConfig = { + kernelName: OnesLike, + backendName: 'webgpu', + kernelFunc: onesLike as {} as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 10f284f9a08..e692c9d2109 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -62,6 +62,7 @@ import {negConfig} from './kernels/Neg'; import {nonMaxSuppressionV3Config} from './kernels/NonMaxSuppressionV3'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {notEqualConfig} from './kernels/NotEqual'; +import {onesLikeConfig} from './kernels/OnesLike'; import {packConfig} from './kernels/Pack'; import {padV2Config} from './kernels/PadV2'; import {preluConfig} from './kernels/Prelu'; @@ -134,6 +135,7 @@ const kernelConfigs: KernelConfig[] = [ nonMaxSuppressionV3Config, nonMaxSuppressionV5Config, notEqualConfig, + onesLikeConfig, packConfig, padV2Config, preluConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 3568f4bbfb0..630d97e870d 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -541,6 +541,14 @@ const TEST_FILTERS: TestFilter[] = [ '6D', // rank 6 is not yet supported. 'gradient' // gradient function not found. ] + }, + { + include: 'onesLike', + excludes: [ + '5D', // rank 5 is not yet supported. + '6D', // rank 6 is not yet supported. + 'gradient' // gradient function not found. + ] } ];