diff --git a/tfjs-backend-webgpu/src/kernels/Elu.ts b/tfjs-backend-webgpu/src/kernels/Elu.ts new file mode 100644 index 00000000000..3aab05bd278 --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Elu.ts @@ -0,0 +1,28 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Elu, KernelConfig} from '@tensorflow/tfjs-core'; +import {unaryKernelFunc} from '../kernel_utils/kernel_funcs_utils'; +import {ELU} from './unary_op_webgpu'; + +export const elu = unaryKernelFunc({opSnippet: ELU}); + +export const eluConfig: KernelConfig = { + kernelName: Elu, + backendName: 'webgpu', + kernelFunc: elu +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index 199bc1d59ad..10f284f9a08 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -33,6 +33,7 @@ import {concatConfig} from './kernels/Concat'; import {conv2DConfig} from './kernels/Conv2D'; import {cropAndResizeConfig} from './kernels/CropAndResize'; import {depthwiseConv2dNativeConfig} from './kernels/DepthwiseConv2dNative'; +import {eluConfig} from './kernels/Elu'; import {expConfig} from './kernels/Exp'; import {expandDimsConfig} from './kernels/ExpandDims'; import {expm1Config} from './kernels/Expm1'; @@ -104,6 +105,7 @@ const kernelConfigs: KernelConfig[] = [ conv2DConfig, cropAndResizeConfig, depthwiseConv2dNativeConfig, + eluConfig, expandDimsConfig, expConfig, expm1Config, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index b8e5bf99fd1..78d48dd060a 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -206,6 +206,14 @@ const TEST_FILTERS: TestFilter[] = [ 'leakyrelu' // Not yet implemented. ] }, + { + include: 'elu', + excludes: [ + 'selu', // Not yet implemented. + 'derivative', // gradient function not found. + 'gradient' // gradient function not found. + ] + }, { include: 'resizeBilinear', excludes: [