Skip to content

Commit

Permalink
[webgpu] Add support for OnesLike (#4788)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
haoyunfeix authored Mar 9, 2021
1 parent f21938f commit b19fcc5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
58 changes: 58 additions & 0 deletions tfjs-backend-webgpu/src/kernels/OnesLike.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, OnesLike, OnesLikeInputs, TensorInfo} from '@tensorflow/tfjs-core';

import {WebGPUBackend} from '../backend_webgpu';

import {complex} from './Complex';
import {fill} from './Fill';
import {imag} from './Imag';
import {real} from './Real';
import {zerosLike} from './ZerosLike';

export function onesLike(
args: {inputs: OnesLikeInputs, backend: WebGPUBackend}): TensorInfo {
const {inputs, backend} = args;
const {x} = inputs;

if (x.dtype === 'string') {
throw new Error('onesLike is not supported under string dtype');
} else if (x.dtype === 'complex64') {
const realPart = real({inputs: {input: x}, backend});
const r = onesLike({inputs: {x: realPart}, backend});
const imagPart = imag({inputs: {input: x}, backend});
const i = zerosLike({inputs: {x: imagPart}, backend});

const result = complex({inputs: {real: r, imag: i}, backend});

backend.disposeData(realPart.dataId);
backend.disposeData(r.dataId);
backend.disposeData(imagPart.dataId);
backend.disposeData(i.dataId);

return result;
} else {
return fill({attrs: {shape: x.shape, dtype: x.dtype, value: 1}, backend});
}
}

export const onesLikeConfig: KernelConfig = {
kernelName: OnesLike,
backendName: 'webgpu',
kernelFunc: onesLike as {} as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-webgpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import {negConfig} from './kernels/Neg';
import {nonMaxSuppressionV3Config} from './kernels/NonMaxSuppressionV3';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {notEqualConfig} from './kernels/NotEqual';
import {onesLikeConfig} from './kernels/OnesLike';
import {packConfig} from './kernels/Pack';
import {padV2Config} from './kernels/PadV2';
import {preluConfig} from './kernels/Prelu';
Expand Down Expand Up @@ -134,6 +135,7 @@ const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV3Config,
nonMaxSuppressionV5Config,
notEqualConfig,
onesLikeConfig,
packConfig,
padV2Config,
preluConfig,
Expand Down
8 changes: 8 additions & 0 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,14 @@ const TEST_FILTERS: TestFilter[] = [
'6D', // rank 6 is not yet supported.
'gradient' // gradient function not found.
]
},
{
include: 'onesLike',
excludes: [
'5D', // rank 5 is not yet supported.
'6D', // rank 6 is not yet supported.
'gradient' // gradient function not found.
]
}
];

Expand Down

0 comments on commit b19fcc5

Please sign in to comment.