Skip to content

Commit

Permalink
Update mxnet.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 8, 2024
1 parent 995b1f2 commit 73069a9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 29 deletions.
10 changes: 5 additions & 5 deletions source/mxnet-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
{
"name": "_Plus",
"inputs": [
{ "name": "inputs", "option": "variadic" }
{ "name": "inputs", "type": "Tensor[]" }
],
"outputs": [
{ "name": "output" }
Expand Down Expand Up @@ -191,7 +191,7 @@
{ "visible": false, "name": "num_args" }
],
"inputs": [
{ "name": "inputs", "option": "variadic" }
{ "name": "inputs", "type": "Tensor[]" }
],
"outputs": [
{ "name": "output" }
Expand All @@ -215,7 +215,7 @@
"inputs": [
{ "name": "input" },
{ "name": "weight" },
{ "name": "bias", "option": "optional" }
{ "name": "bias", "optional": true }
],
"outputs": [
{ "name": "output" }
Expand Down Expand Up @@ -267,7 +267,7 @@
"name": "ElementWiseSum",
"category": "Normalization",
"inputs": [
{ "name": "inputs", "option": "variadic" }
{ "name": "inputs", "type": "Tensor[]" }
],
"outputs": [
{ "name": "output" }
Expand Down Expand Up @@ -504,7 +504,7 @@
{ "name": "inputs" }
],
"outputs": [
{ "name": "outputs", "option": "variadic" }
{ "name": "outputs", "type": "Tensor[]" }
]
},
{
Expand Down
42 changes: 18 additions & 24 deletions source/mxnet.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ mxnet.ModelFactory = class {
const extension = identifier.split('.').pop().toLowerCase();
if (extension === 'json') {
const obj = context.peek('json');
if (obj && Array.isArray(obj.nodes) && Array.isArray(obj.arg_nodes) && Array.isArray(obj.heads) /* &&
!obj.nodes.some((node) => node && node.op === 'tvm_op') */) {
if (obj && Array.isArray(obj.nodes) && Array.isArray(obj.arg_nodes) && Array.isArray(obj.heads) &&
!obj.nodes.some((node) => node && node.op === 'tvm_op')) {
context.type = 'mxnet.json';
context.target = obj;
return;
Expand Down Expand Up @@ -302,12 +302,12 @@ mxnet.Graph = class {
if (symbol) {
const nodes = symbol.nodes;
const inputs = {};
const outputs = {};
if (manifest && manifest.signature && manifest.signature.inputs) {
for (const input of manifest.signature.inputs) {
inputs[input.data_name] = input;
}
}
const outputs = {};
if (manifest && manifest.signature && manifest.signature.outputs) {
for (const output of manifest.signature.outputs) {
outputs[output.data_name] = output;
Expand All @@ -320,17 +320,11 @@ mxnet.Graph = class {
node.inputs = node.inputs || [];
node.inputs = node.inputs.map((input) => updateOutput(nodes, input));
}
const outputCountMap = {};
for (const node of nodes) {
for (const output of node.outputs) {
outputCountMap[output] = (outputCountMap[output] || 0) + 1;
}
}
const arg_nodes = new Map(symbol.arg_nodes.map((index) => [index, index < nodes.length ? nodes[index] : null]));
for (let i = 0; i < symbol.heads.length; i++) {
const head = symbol.heads[i];
const identifier = updateOutput(nodes, head);
const name = nodes[identifier[0]] ? nodes[identifier[0]].name : (`output${(i === 0) ? '' : (i + 1)}`);
const name = `output${(i === 0) ? '' : (i + 1)}`;
const signature = outputs[name];
const type = signature && signature.data_shape ? new mxnet.TensorType(-1, new mxnet.TensorShape(signature.data_shape)) : null;
const value = values.map(`[${identifier.join(',')}]`, type);
Expand Down Expand Up @@ -564,13 +558,13 @@ mxnet.Node = class {
let inputIndex = 0;
if (this.type && this.type.inputs) {
for (const inputDef of this.type.inputs) {
if (inputIndex < inputs.length || inputDef.option !== 'optional') {
const count = (inputDef.option === 'variadic') ? (inputs.length - inputIndex) : 1;
if (inputIndex < inputs.length || inputDef.optional !== true) {
const count = (inputDef.type === 'Tensor[]') ? (inputs.length - inputIndex) : 1;
const list = [];
for (const input of inputs.slice(inputIndex, inputIndex + count)) {
const identifier = `[${input.join(',')}]`;
if (identifier !== '' || inputDef.option !== 'optional') {
const value = values.map(identifier, inputDef.type, initializers.get(identifier));
if (identifier !== '' || (inputDef.optional !== true || inputDef.type === 'Tensor[]')) {
const value = values.map(identifier, null, initializers.get(identifier));
list.push(value);
}
}
Expand All @@ -594,9 +588,9 @@ mxnet.Node = class {
let outputIndex = 0;
if (this.type && this.type.outputs) {
for (const outputDef of this.type.outputs) {
if (outputIndex < outputs.length || outputDef.option !== 'optional') {
if (outputIndex < outputs.length || outputDef.optional !== true) {
const list = [];
const count = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
const count = (outputDef.type === 'Tensor[]') ? (outputs.length - outputIndex) : 1;
for (const output of outputs.slice(outputIndex, outputIndex + count)) {
const value = values.map(`[${output.join(',')}]`);
list.push(value);
Expand Down Expand Up @@ -678,17 +672,17 @@ mxnet.ndarray = class {

static load(reader) {
// NDArray::Load(dmlc::Stream* fi, std::vector<NDArray>* data, std::vector<std::string>* keys)
const map = new Map();
const params = new Map();
reader = new mxnet.BinaryReader(reader);
if (reader.uint64().toNumber() !== 0x112) { // kMXAPINDArrayListMagic
throw new mxnet.Error('Invalid signature.');
}
if (reader.uint64().toNumber() !== 0) {
throw new mxnet.Error('Invalid reserved block.');
}
const data = new Array(reader.uint64().toNumber());
for (let i = 0; i < data.length; i++) {
data[i] = new mxnet.ndarray.NDArray(reader);
const values = new Array(reader.uint64().toNumber());
for (let i = 0; i < values.length; i++) {
values[i] = new mxnet.ndarray.NDArray(reader);
}
const decoder = new TextDecoder('ascii');
const names = new Array(reader.uint64().toNumber());
Expand All @@ -697,13 +691,13 @@ mxnet.ndarray = class {
const buffer = reader.read(size);
names[i] = decoder.decode(buffer);
}
if (names.length !== data.length) {
throw new mxnet.Error('Label count mismatch.');
if (names.length !== values.length) {
throw new mxnet.Error('Invalid parameters.');
}
for (let i = 0; i < names.length; i++) {
map.set(names[i], data[i]);
params.set(names[i], values[i]);
}
return map;
return params;
}
};

Expand Down

0 comments on commit 73069a9

Please sign in to comment.