From 3556f090908767215efd65f2c4171ed2d85fa9f2 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 17 Jun 2018 02:26:00 -0700 Subject: [PATCH] ONNX loader shared edge objects (#71) --- src/onnx-model.js | 85 ++++++++++++++++++++++---------------------- tools/onnx-converter | 3 ++ 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/src/onnx-model.js b/src/onnx-model.js index 54f9861bb90..95baf57297a 100644 --- a/src/onnx-model.js +++ b/src/onnx-model.js @@ -87,7 +87,7 @@ class OnnxModel { if (this._modelVersion) { results.push({ name: 'Version', value: this._modelVersion }); } - if (this._model.docString) { + if (this._docString) { results.push({ name: 'Description', value: this._docString }); } var metadata = {}; @@ -140,8 +140,8 @@ class OnnxGraph { this._nodes = []; if (graph) { - var initializerMap = []; - var valueInfoMap = []; + this._initializerMap = {}; + this._connectionMap = {}; this._name = graph.name || ('(' + index.toString() + ')'); this._description = graph.docString || ''; @@ -160,7 +160,7 @@ class OnnxGraph { if (outputCountMap[name] == 1) { var attribute = node.attribute.find((attribute) => { return attribute.name == 'value' && attribute.t; }); if (attribute) { - initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant'); + this._initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant'); initializerNode = true; } } @@ -171,33 +171,25 @@ class OnnxGraph { }); graph.initializer.forEach((tensor) => { - initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer'); + this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer'); }); graph.valueInfo.forEach((valueInfo) => { - valueInfoMap[valueInfo.name] = valueInfo; + this._connection(valueInfo.name, valueInfo.type, valueInfo.docString); }); this._inputs = []; graph.input.forEach((valueInfo) => { - if (!initializerMap[valueInfo.name]) { - this._inputs.push({ - id: valueInfo.name, - name: valueInfo.name, - description: valueInfo.docString, - type: OnnxTensor.formatType(valueInfo.type) - }); - valueInfoMap[valueInfo.name] = valueInfo; + if (!this._initializerMap[valueInfo.name]) { + var connection = this._connection(valueInfo.name, valueInfo.type, valueInfo.docString); + connection.name = valueInfo.name; + this._inputs.push(connection); } }); - - this._outputs = graph.output.map((valueInfo) => { - valueInfoMap[valueInfo.name] = valueInfo; - return { - id: valueInfo.name, - name: valueInfo.name, - description: valueInfo.docString, - type: OnnxTensor.formatType(valueInfo.type) - }; + this._outputs = []; + graph.output.map((valueInfo) => { + var connection = this._connection(valueInfo.name, valueInfo.type, valueInfo.docString); + connection.name = valueInfo.name; + this._outputs.push(connection); }); nodes.forEach((node) => { @@ -205,19 +197,8 @@ class OnnxGraph { if (node.input) { inputs = this._metadata.getInputs(node.opType, node.input); inputs.forEach((input) => { - input.connections.forEach((connection) => { - var initializer = initializerMap[connection.id]; - if (initializer) { - connection.initializer = initializer; - connection.type = initializer.type; - } - else { - var valueInfo = valueInfoMap[connection.id]; - if (valueInfo) { - connection.type = OnnxTensor.formatType(valueInfo.type); - connection.description = valueInfo.docString; - } - } + input.connections = input.connections.map((connection) => { + return this._connection(connection.id); }); }); } @@ -225,17 +206,16 @@ class OnnxGraph { if (node.output) { outputs = this._metadata.getOutputs(node.opType, node.output); outputs.forEach((output) => { - output.connections.forEach((connection) => { - var valueInfo = valueInfoMap[connection.id]; - if (valueInfo) { - connection.type = OnnxTensor.formatType(valueInfo.type); - connection.description = valueInfo.docString; - } + output.connections = output.connections.map((connection) => { + return this._connection(connection.id); }); }); } this._nodes.push(new OnnxNode(this, node.opType, node.domain, node.name, node.docString, node.attribute, inputs, outputs)); }); + + delete this._initializerMap; + delete this._connectionMap; } } @@ -266,6 +246,27 @@ class OnnxGraph { get metadata() { return this._metadata; } + + _connection(name, type, docString) { + var connection = this._connectionMap[name]; + if (!connection) { + connection = {}; + connection.id = name; + var initializer = this._initializerMap[name]; + if (initializer) { + connection.initializer = initializer; + connection.type = initializer.type; + } + if (type) { + connection.type = OnnxTensor.formatType(type); + } + if (docString) { + connection.description = docString; + } + this._connectionMap[name] = connection; + } + return connection; + } } class OnnxNode { diff --git a/tools/onnx-converter b/tools/onnx-converter index 8e638dc3e26..6d4d89d29b2 100755 --- a/tools/onnx-converter +++ b/tools/onnx-converter @@ -23,6 +23,9 @@ fi export PYTHONUSERBASE=${build}/third_party/pypi/${identifier} export PATH=$PATH:${PYTHONUSERBASE}/bin rm -rf ${PYTHONUSERBASE} +pip install --quiet --user coremltools +pip install --quiet --user tensorflow +pip install --quiet --user keras pip install --quiet --user ${third_party}/${identifier} python ${tools}/onnx-converter.py $@