From 73069a934a12690d7b7d7a6a3e99d9140c7a7b6a Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 8 Sep 2024 09:36:00 -0700 Subject: [PATCH] Update mxnet.js --- source/mxnet-metadata.json | 10 ++++----- source/mxnet.js | 42 ++++++++++++++++---------------------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/source/mxnet-metadata.json b/source/mxnet-metadata.json index 51be355051..37e001f87e 100644 --- a/source/mxnet-metadata.json +++ b/source/mxnet-metadata.json @@ -75,7 +75,7 @@ { "name": "_Plus", "inputs": [ - { "name": "inputs", "option": "variadic" } + { "name": "inputs", "type": "Tensor[]" } ], "outputs": [ { "name": "output" } @@ -191,7 +191,7 @@ { "visible": false, "name": "num_args" } ], "inputs": [ - { "name": "inputs", "option": "variadic" } + { "name": "inputs", "type": "Tensor[]" } ], "outputs": [ { "name": "output" } @@ -215,7 +215,7 @@ "inputs": [ { "name": "input" }, { "name": "weight" }, - { "name": "bias", "option": "optional" } + { "name": "bias", "optional": true } ], "outputs": [ { "name": "output" } @@ -267,7 +267,7 @@ "name": "ElementWiseSum", "category": "Normalization", "inputs": [ - { "name": "inputs", "option": "variadic" } + { "name": "inputs", "type": "Tensor[]" } ], "outputs": [ { "name": "output" } @@ -504,7 +504,7 @@ { "name": "inputs" } ], "outputs": [ - { "name": "outputs", "option": "variadic" } + { "name": "outputs", "type": "Tensor[]" } ] }, { diff --git a/source/mxnet.js b/source/mxnet.js index 58e16ce03f..42364af15d 100644 --- a/source/mxnet.js +++ b/source/mxnet.js @@ -10,8 +10,8 @@ mxnet.ModelFactory = class { const extension = identifier.split('.').pop().toLowerCase(); if (extension === 'json') { const obj = context.peek('json'); - if (obj && Array.isArray(obj.nodes) && Array.isArray(obj.arg_nodes) && Array.isArray(obj.heads) /* && - !obj.nodes.some((node) => node && node.op === 'tvm_op') */) { + if (obj && Array.isArray(obj.nodes) && Array.isArray(obj.arg_nodes) && Array.isArray(obj.heads) && + !obj.nodes.some((node) => node && node.op === 'tvm_op')) { context.type = 'mxnet.json'; context.target = obj; return; @@ -302,12 +302,12 @@ mxnet.Graph = class { if (symbol) { const nodes = symbol.nodes; const inputs = {}; + const outputs = {}; if (manifest && manifest.signature && manifest.signature.inputs) { for (const input of manifest.signature.inputs) { inputs[input.data_name] = input; } } - const outputs = {}; if (manifest && manifest.signature && manifest.signature.outputs) { for (const output of manifest.signature.outputs) { outputs[output.data_name] = output; @@ -320,17 +320,11 @@ mxnet.Graph = class { node.inputs = node.inputs || []; node.inputs = node.inputs.map((input) => updateOutput(nodes, input)); } - const outputCountMap = {}; - for (const node of nodes) { - for (const output of node.outputs) { - outputCountMap[output] = (outputCountMap[output] || 0) + 1; - } - } const arg_nodes = new Map(symbol.arg_nodes.map((index) => [index, index < nodes.length ? nodes[index] : null])); for (let i = 0; i < symbol.heads.length; i++) { const head = symbol.heads[i]; const identifier = updateOutput(nodes, head); - const name = nodes[identifier[0]] ? nodes[identifier[0]].name : (`output${(i === 0) ? '' : (i + 1)}`); + const name = `output${(i === 0) ? '' : (i + 1)}`; const signature = outputs[name]; const type = signature && signature.data_shape ? new mxnet.TensorType(-1, new mxnet.TensorShape(signature.data_shape)) : null; const value = values.map(`[${identifier.join(',')}]`, type); @@ -564,13 +558,13 @@ mxnet.Node = class { let inputIndex = 0; if (this.type && this.type.inputs) { for (const inputDef of this.type.inputs) { - if (inputIndex < inputs.length || inputDef.option !== 'optional') { - const count = (inputDef.option === 'variadic') ? (inputs.length - inputIndex) : 1; + if (inputIndex < inputs.length || inputDef.optional !== true) { + const count = (inputDef.type === 'Tensor[]') ? (inputs.length - inputIndex) : 1; const list = []; for (const input of inputs.slice(inputIndex, inputIndex + count)) { const identifier = `[${input.join(',')}]`; - if (identifier !== '' || inputDef.option !== 'optional') { - const value = values.map(identifier, inputDef.type, initializers.get(identifier)); + if (identifier !== '' || (inputDef.optional !== true || inputDef.type === 'Tensor[]')) { + const value = values.map(identifier, null, initializers.get(identifier)); list.push(value); } } @@ -594,9 +588,9 @@ mxnet.Node = class { let outputIndex = 0; if (this.type && this.type.outputs) { for (const outputDef of this.type.outputs) { - if (outputIndex < outputs.length || outputDef.option !== 'optional') { + if (outputIndex < outputs.length || outputDef.optional !== true) { const list = []; - const count = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1; + const count = (outputDef.type === 'Tensor[]') ? (outputs.length - outputIndex) : 1; for (const output of outputs.slice(outputIndex, outputIndex + count)) { const value = values.map(`[${output.join(',')}]`); list.push(value); @@ -678,7 +672,7 @@ mxnet.ndarray = class { static load(reader) { // NDArray::Load(dmlc::Stream* fi, std::vector* data, std::vector* keys) - const map = new Map(); + const params = new Map(); reader = new mxnet.BinaryReader(reader); if (reader.uint64().toNumber() !== 0x112) { // kMXAPINDArrayListMagic throw new mxnet.Error('Invalid signature.'); @@ -686,9 +680,9 @@ mxnet.ndarray = class { if (reader.uint64().toNumber() !== 0) { throw new mxnet.Error('Invalid reserved block.'); } - const data = new Array(reader.uint64().toNumber()); - for (let i = 0; i < data.length; i++) { - data[i] = new mxnet.ndarray.NDArray(reader); + const values = new Array(reader.uint64().toNumber()); + for (let i = 0; i < values.length; i++) { + values[i] = new mxnet.ndarray.NDArray(reader); } const decoder = new TextDecoder('ascii'); const names = new Array(reader.uint64().toNumber()); @@ -697,13 +691,13 @@ mxnet.ndarray = class { const buffer = reader.read(size); names[i] = decoder.decode(buffer); } - if (names.length !== data.length) { - throw new mxnet.Error('Label count mismatch.'); + if (names.length !== values.length) { + throw new mxnet.Error('Invalid parameters.'); } for (let i = 0; i < names.length; i++) { - map.set(names[i], data[i]); + params.set(names[i], values[i]); } - return map; + return params; } };