Skip to content

Commit 4ccb00e

Browse files
authored
Add 'OneHot' op. (#36)
* Add relu op. * Add multiply. * Change type op attr helper method to dtype. * Cleanup. * basic math. * use expectArraysClose() * Add reverse() * Add neg - add comment about supporting concat. * sum() * More ops. * Fixup concat() * wip * Some cleanup in the backend. * min and minimum() * Add basic pow and TODO for handling tensor upcasting of types. * More single-input ops. * wip * cleanup
1 parent 84e50f7 commit 4ccb00e

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

src/nodejs_kernel_backend.ts

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,8 @@ export class NodeJSKernelBackend implements KernelBackend {
127127

128128
slice<T extends Tensor<Rank>>(x: T, begin: number[], size: number[]): T {
129129
const opAttrs = [
130-
this.createTypeOpAttr('T', x.dtype), {
131-
name: 'Index',
132-
type: this.binding.TF_ATTR_TYPE,
133-
value: this.binding.TF_INT32
134-
}
130+
this.createTypeOpAttr('T', x.dtype),
131+
this.createTypeOpAttr('Index', 'int32')
135132
];
136133

137134
// Bind tensor values
@@ -570,18 +567,15 @@ export class NodeJSKernelBackend implements KernelBackend {
570567
}
571568
pad<T extends Tensor<Rank>>(
572569
x: T, paddings: Array<[number, number]>, constantValue: number): T {
573-
const opAttrs = [
574-
this.createTypeOpAttr('T', x.dtype), {
575-
name: 'Tpaddings',
576-
type: this.binding.TF_ATTR_TYPE,
577-
value: this.binding.TF_INT32
578-
}
579-
];
580-
581570
// Bind tensor values
582571
const paddingsTensor = tensor2d(paddings, [2, 2], 'int32');
583572
const constantTensor = scalar(constantValue, x.dtype);
584573

574+
const opAttrs = [
575+
this.createTypeOpAttr('T', x.dtype),
576+
this.createTypeOpAttr('Tpaddings', paddingsTensor.dtype)
577+
];
578+
585579
return this.execute(
586580
'PadV2', opAttrs, [x, paddingsTensor, constantTensor]) as T;
587581
}
@@ -613,7 +607,19 @@ export class NodeJSKernelBackend implements KernelBackend {
613607
}
614608
oneHot(indices: Tensor1D, depth: number, onValue: number, offValue: number):
615609
Tensor2D {
616-
throw new Error('Method not implemented.');
610+
const depthTensor = scalar(depth, 'int32');
611+
const onValueTensor = scalar(onValue, 'int32');
612+
const offValueTensor = scalar(offValue, 'int32');
613+
614+
const opAttrs = [
615+
{name: 'axis', type: this.binding.TF_ATTR_INT, value: -1},
616+
this.createTypeOpAttr('T', indices.dtype),
617+
this.createTypeOpAttr('TI', indices.dtype)
618+
];
619+
620+
return this.execute('OneHot', opAttrs, [
621+
indices, depthTensor, onValueTensor, offValueTensor
622+
]) as Tensor2D;
617623
}
618624
dispose(): void {
619625
throw new Error('Method not implemented.');

src/nodejs_kernel_backend_test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,13 @@ describe('tanh', () => {
489489
expectArraysClose(result, expected);
490490
});
491491
});
492+
493+
describe('oneHot', () => {
494+
it('should work', () => {
495+
const indices = dl.tensor1d([0, 1], 'int32');
496+
const res = dl.oneHot(indices, 2);
497+
498+
expect(res.shape).toEqual([2, 2]);
499+
expectArraysClose(res, [1, 0, 0, 1]);
500+
});
501+
});

0 commit comments

Comments
 (0)