Skip to content

Commit

Permalink
Add 'OneHot' op. (#36)
Browse files Browse the repository at this point in the history
* Add relu op.

* Add multiply.

* Change type op attr helper method to dtype.

* Cleanup.

* basic math.

* use expectArraysClose()

* Add reverse()

* Add neg - add comment about supporting concat.

* sum()

* More ops.

* Fixup concat()

* wip

* Some cleanup in the backend.

* min and minimum()

* Add basic pow and TODO for handling tensor upcasting of types.

* More single-input ops.

* wip

* cleanup
  • Loading branch information
nkreeger authored Mar 21, 2018
1 parent 84e50f7 commit 4ccb00e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/nodejs_kernel_backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,8 @@ export class NodeJSKernelBackend implements KernelBackend {

slice<T extends Tensor<Rank>>(x: T, begin: number[], size: number[]): T {
const opAttrs = [
this.createTypeOpAttr('T', x.dtype), {
name: 'Index',
type: this.binding.TF_ATTR_TYPE,
value: this.binding.TF_INT32
}
this.createTypeOpAttr('T', x.dtype),
this.createTypeOpAttr('Index', 'int32')
];

// Bind tensor values
Expand Down Expand Up @@ -570,18 +567,15 @@ export class NodeJSKernelBackend implements KernelBackend {
}
pad<T extends Tensor<Rank>>(
x: T, paddings: Array<[number, number]>, constantValue: number): T {
const opAttrs = [
this.createTypeOpAttr('T', x.dtype), {
name: 'Tpaddings',
type: this.binding.TF_ATTR_TYPE,
value: this.binding.TF_INT32
}
];

// Bind tensor values
const paddingsTensor = tensor2d(paddings, [2, 2], 'int32');
const constantTensor = scalar(constantValue, x.dtype);

const opAttrs = [
this.createTypeOpAttr('T', x.dtype),
this.createTypeOpAttr('Tpaddings', paddingsTensor.dtype)
];

return this.execute(
'PadV2', opAttrs, [x, paddingsTensor, constantTensor]) as T;
}
Expand Down Expand Up @@ -613,7 +607,19 @@ export class NodeJSKernelBackend implements KernelBackend {
}
oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number):
Tensor2D {
throw new Error('Method not implemented.');
const depthTensor = scalar(depth, 'int32');
const onValueTensor = scalar(onValue, 'int32');
const offValueTensor = scalar(offValue, 'int32');

const opAttrs = [
{name: 'axis', type: this.binding.TF_ATTR_INT, value: -1},
this.createTypeOpAttr('T', indices.dtype),
this.createTypeOpAttr('TI', indices.dtype)
];

return this.execute('OneHot', opAttrs, [
indices, depthTensor, onValueTensor, offValueTensor
]) as Tensor2D;
}
dispose(): void {
throw new Error('Method not implemented.');
Expand Down
10 changes: 10 additions & 0 deletions src/nodejs_kernel_backend_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,13 @@ describe('tanh', () => {
expectArraysClose(result, expected);
});
});

describe('oneHot', () => {
it('should work', () => {
const indices = dl.tensor1d([0, 1], 'int32');
const res = dl.oneHot(indices, 2);

expect(res.shape).toEqual([2, 2]);
expectArraysClose(res, [1, 0, 0, 1]);
});
});

0 comments on commit 4ccb00e

Please sign in to comment.