Skip to content

Commit

Permalink
Modularize OneHot. (#4142)
Browse files Browse the repository at this point in the history
DEV
  • Loading branch information
lina128 authored Oct 27, 2020
1 parent 282ac81 commit af6e1d6
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 16 deletions.
16 changes: 0 additions & 16 deletions tfjs-backend-cpu/src/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1532,22 +1532,6 @@ export class MathBackendCPU extends KernelBackend {
return res;
}

oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number):
Tensor2D {
assertNotComplex(indices, 'oneHot');

const res = new Float32Array(indices.size * depth);
res.fill(offValue);
const indicesVal = this.readSync(indices.dataId) as TypedArray;

for (let event = 0; event < indices.size; ++event) {
if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
res[event * depth + indicesVal[event]] = onValue;
}
}
return tf.tensor2d(res, [indices.size, depth], 'int32');
}

nonMaxSuppression(
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
iouThreshold: number, scoreThreshold: number): Tensor1D {
Expand Down
51 changes: 51 additions & 0 deletions tfjs-backend-cpu/src/kernels/OneHot.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelConfig, KernelFunc, OneHot, OneHotAttrs, OneHotInputs, TensorInfo, TypedArray, util} from '@tensorflow/tfjs-core';

import {MathBackendCPU} from '../backend_cpu';
import {assertNotComplex} from '../cpu_util';

export function oneHot(
args: {inputs: OneHotInputs, backend: MathBackendCPU, attrs: OneHotAttrs}):
TensorInfo {
const {inputs, backend, attrs} = args;
const {indices} = inputs;
const {depth, onValue, offValue} = attrs;

assertNotComplex(indices, 'oneHot');

const indicesSize = util.sizeFromShape(indices.shape);

const res = new Float32Array(indicesSize * depth);
res.fill(offValue);
const indicesVal = backend.data.get(indices.dataId).values as TypedArray;

for (let event = 0; event < indicesSize; ++event) {
if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
res[event * depth + indicesVal[event]] = onValue;
}
}

return backend.makeTensorInfo([...indices.shape, depth], 'int32', res);
}

export const oneHotConfig: KernelConfig = {
kernelName: OneHot,
backendName: 'cpu',
kernelFunc: oneHot as {} as KernelFunc
};
2 changes: 2 additions & 0 deletions tfjs-backend-cpu/src/register_all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ import {multiplyConfig} from './kernels/Multiply';
import {nonMaxSuppressionV4Config} from './kernels/NonMaxSuppressionV4';
import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5';
import {notEqualConfig} from './kernels/NotEqual';
import {oneHotConfig} from './kernels/OneHot';
import {padV2Config} from './kernels/PadV2';
import {preluConfig} from './kernels/Prelu';
import {realConfig} from './kernels/Real';
Expand Down Expand Up @@ -171,6 +172,7 @@ const kernelConfigs: KernelConfig[] = [
nonMaxSuppressionV4Config,
nonMaxSuppressionV5Config,
notEqualConfig,
oneHotConfig,
padV2Config,
preluConfig,
realConfig,
Expand Down

0 comments on commit af6e1d6

Please sign in to comment.