From 47a4d61cb8e659672a68d45d6004d825487df8ed Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 10 Sep 2022 16:04:16 -0700 Subject: [PATCH] Update tensor formatter (#961) --- source/pytorch.js | 55 +++++++++++++++--------- source/view-sidebar.js | 97 +++++++++++++++++++++++++++--------------- test/models.js | 2 +- test/models.json | 7 +++ 4 files changed, 105 insertions(+), 56 deletions(-) diff --git a/source/pytorch.js b/source/pytorch.js index b7227e2cab..d2f08f3c73 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -663,6 +663,38 @@ pytorch.Tensor = class { } return this._data instanceof Uint8Array ? this._data : this._data.peek(); } + + decode() { + if (this._layout !== '<' && this._layout !== '>') { + throw new pytorch.Error("Tensor layout '" + this._layout + "' not implemented."); + } + const littleEndian = this._littleEndian; + const type = this._type; + const data = this.values; + const view = new DataView(data.buffer, data.byteOffset, data.byteLength); + switch (type.dataType) { + case 'int16': { + const array = new Uint16Array(data.length >> 1); + for (let i = 0; i < array.length; i++) { + array[i] = view.getInt16(i << 1, littleEndian); + } + return array; + } + case 'int64': { + const array = new Uint32Array(data.length >> 3); + for (let i = 0; i < array.length; i++) { + array[i] = view.getUint32(i << 3, littleEndian); + if (view.getUint32((i << 3) + 4, littleEndian) !== 0) { + throw new pytorch.Error('Signed 64-bit value exceeds 32-bit range.'); + } + } + return array; + } + default: { + throw new pytorch.Error("Tensor data type '" + type.dataType + "' not implemented."); + } + } + } }; pytorch.TensorType = class { @@ -751,7 +783,7 @@ pytorch.Execution = class extends python.Execution { const tensors = state[1]; const opt_tensors = state[2]; const packed_config_tensor = new pytorch.Tensor('', tensors[0], true); - const packed_config = pytorch.Utility.values(packed_config_tensor); + const packed_config = packed_config_tensor.decode(); this.weight = tensors[1]; this.bias = opt_tensors[0]; this.stride = [ packed_config[1], packed_config[2] ]; @@ -770,7 +802,7 @@ pytorch.Execution = class extends python.Execution { const tensors = state[1]; const opt_tensors = state[2]; const packed_config_tensor = new pytorch.Tensor('', tensors[0], true); - const packed_config = pytorch.Utility.values(packed_config_tensor); + const packed_config = packed_config_tensor.decode(); this.weight = tensors[1]; this.bias = opt_tensors[0]; this.stride = [ packed_config[1], packed_config[2] ]; @@ -2566,25 +2598,6 @@ pytorch.Utility = class { return null; } - static values(tensor) { - const type = tensor.type; - const data = tensor.values; - if (type && data) { - switch (type.dataType) { - case 'int16': { - if (tensor.layout === '<') { - return new Uint16Array(data); - } - break; - } - default: { - break; - } - } - } - throw new pytorch.Error("Tensor data type '" + type.dataType + "' not implemented."); - } - static isTensor(obj) { const name = obj && obj.__class__ ? obj.__class__.__module__ : null; switch (name) { diff --git a/source/view-sidebar.js b/source/view-sidebar.js index ef0a304b69..31fc6d6d2f 100644 --- a/source/view-sidebar.js +++ b/source/view-sidebar.js @@ -423,7 +423,7 @@ sidebar.ValueView = class extends sidebar.Control { this._bold('layout', layouts.get(layout)); } } - if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse') { + if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo') { contentLine.innerHTML = "Tensor layout '" + tensor.layout + "' is not implemented."; } else if (tensor.empty) { @@ -1341,6 +1341,12 @@ sidebar.Tensor = class { this._layout = 'sparse'; break; } + case 'sparse.coo': { + this._indices = this._tensor.indices; + this._values = this._tensor.values; + this._layout = 'sparse.coo'; + break; + } default: { this._layout = tensor.layout; break; @@ -1378,7 +1384,8 @@ sidebar.Tensor = class { case '|': { return !(Array.isArray(this._values) || ArrayBuffer.isView(this._values)) || this._values.length === 0; } - case 'sparse': { + case 'sparse': + case 'sparse.coo': { return !this._values || this.indices || this._values.values.length === 0; } default: { @@ -1424,19 +1431,19 @@ sidebar.Tensor = class { } _context() { - if (this._layout !== '<' && this._layout !== '>' && this._layout !== '|' && this._layout !== 'sparse') { + if (this._layout !== '<' && this._layout !== '>' && this._layout !== '|' && this._layout !== 'sparse' && this._layout !== 'sparse.coo') { throw new Error("Tensor layout '" + this._layout + "' is not supported."); } + const dataType = this._type.dataType; const context = {}; context.layout = this._layout; context.dimensions = this._type.shape.dimensions; - const dataType = this._type.dataType; - const size = this._type.shape.dimensions.reduce((a, b) => a * b, 1); + context.dataType = dataType; + const size = context.dimensions.reduce((a, b) => a * b, 1); switch (this._layout) { case '<': case '>': { context.data = (this._data instanceof Uint8Array || this._data instanceof Int8Array) ? this._data : this._data.peek(); - context.dataType = dataType; context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength); if (sidebar.Tensor.dataTypes.has(dataType)) { const itemsize = sidebar.Tensor.dataTypes.get(dataType); @@ -1459,7 +1466,6 @@ sidebar.Tensor = class { } case '|': { context.data = this._values; - context.dataType = dataType; if (!sidebar.Tensor.dataTypes.has(dataType) && dataType !== 'string' && dataType !== 'object') { throw new Error("Tensor data type '" + dataType + "' is not implemented."); } @@ -1469,36 +1475,32 @@ sidebar.Tensor = class { break; } case 'sparse': { - context.dataType = dataType; - const size = context.dimensions.reduce((a, b) => a * b, 1); - const indices = this._indices.values; - const values = this._values.values; - const array = new values.constructor(size); - switch (context.dataType) { - case 'boolean': - array.fill(false); - break; - case 'int64': - case 'uint64': - break; - default: - break; - } - if (indices.length > 0) { - if (Object.prototype.hasOwnProperty.call(indices[0], 'low')) { - for (let i = 0; i < indices.length; i++) { - const index = indices[i]; - array[index.high === 0 ? index.low : index.toNumber()] = values[i]; - } - } - else { - for (let i = 0; i < indices.length; i++) { - array[indices[i]] = values[i]; - } + const indices = new sidebar.Tensor(this._indices).value; + const values = new sidebar.Tensor(this._values).value; + context.data = this._decodeSparse(dataType, context.dimensions, indices, values); + context.layout = '|'; + break; + } + case 'sparse.coo': { + const values = new sidebar.Tensor(this._values).value; + const data = new sidebar.Tensor(this._indices).value; + const dimensions = context.dimensions.length; + let stride = 1; + const strides = context.dimensions.slice().reverse().map((dim) => { + const value = stride; + stride *= dim; + return value; + }).reverse(); + const indices = new Uint32Array(values.length); + for (let i = 0; i < dimensions; i++) { + const stride = strides[i]; + const dimension = data[i]; + for (let i = 0; i < indices.length; i++) { + indices[i] += dimension[i] * stride; } } + context.data = this._decodeSparse(dataType, context.dimensions, indices, values); context.layout = '|'; - context.data = array; break; } default: { @@ -1510,6 +1512,33 @@ sidebar.Tensor = class { return context; } + _decodeSparse(dataType, dimensions, indices, values) { + const size = dimensions.reduce((a, b) => a * b, 1); + const array = new Array(size); + switch (dataType) { + case 'boolean': + array.fill(false); + break; + default: + array.fill(0); + break; + } + if (indices.length > 0) { + if (Object.prototype.hasOwnProperty.call(indices[0], 'low')) { + for (let i = 0; i < indices.length; i++) { + const index = indices[i]; + array[index.high === 0 ? index.low : index.toNumber()] = values[i]; + } + } + else { + for (let i = 0; i < indices.length; i++) { + array[indices[i]] = values[i]; + } + } + } + return array; + } + _decodeData(context, dimension) { const results = []; const dimensions = (context.dimensions.length == 0) ? [ 1 ] : context.dimensions; diff --git a/test/models.js b/test/models.js index bc37edc857..d0bfecfff5 100755 --- a/test/models.js +++ b/test/models.js @@ -671,7 +671,7 @@ const loadModel = (target, item) => { // console.log(' ' + message); }; const tensor = new sidebar.Tensor(argument.initializer); - if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse') { + if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo') { log("Tensor layout '" + tensor.layout + "' is not implemented."); } else if (tensor.empty) { diff --git a/test/models.json b/test/models.json index 4a006b4b9e..3c279b7327 100644 --- a/test/models.json +++ b/test/models.json @@ -4823,6 +4823,13 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/842" }, + { + "type": "pytorch", + "target": "sparse_coo.pth", + "source": "https://github.com/lutzroeder/netron/files/9541426/sparse_coo.pth.zip[sparse_coo.pth]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "sparsified.pth",