From 00f527c2579187535c8c077f1e935f987fd2ad8c Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 11 Sep 2022 11:23:27 -0700 Subject: [PATCH] Update tensor formatter (#961) --- source/view-sidebar.js | 210 +++++++++++++++++++++-------------------- 1 file changed, 108 insertions(+), 102 deletions(-) diff --git a/source/view-sidebar.js b/source/view-sidebar.js index 30dc41dc2fb..2602802b970 100644 --- a/source/view-sidebar.js +++ b/source/view-sidebar.js @@ -1437,7 +1437,7 @@ sidebar.Tensor = class { const dataType = this._type.dataType; const context = {}; context.layout = this._layout; - context.dimensions = this._type.shape.dimensions; + context.dimensions = this._type.shape.dimensions.map((value) => !Number.isInteger(value) && value.toNumber ? value.toNumber() : value); context.dataType = dataType; const size = context.dimensions.reduce((a, b) => a * b, 1); switch (this._layout) { @@ -1446,18 +1446,20 @@ sidebar.Tensor = class { context.data = (this._data instanceof Uint8Array || this._data instanceof Int8Array) ? this._data : this._data.peek(); context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength); if (sidebar.Tensor.dataTypes.has(dataType)) { - const itemsize = sidebar.Tensor.dataTypes.get(dataType); - if (this._data.length < (itemsize * size)) { + context.itemsize = sidebar.Tensor.dataTypes.get(dataType); + if (this._data.length < (context.itemsize * size)) { throw new Error('Invalid tensor data size.'); } } else if (dataType.startsWith('uint') && !isNaN(parseInt(dataType.substring(4), 10))) { context.dataType = 'uint'; context.bits = parseInt(dataType.substring(4), 10); + context.itemsize = 1; } else if (dataType.startsWith('int') && !isNaN(parseInt(dataType.substring(3), 10))) { context.dataType = 'int'; context.bits = parseInt(dataType.substring(3), 10); + context.itemsize = 1; } else { throw new Error("Tensor data type '" + dataType + "' is not implemented."); @@ -1546,106 +1548,110 @@ sidebar.Tensor = class { const dataType = context.dataType; const view = context.view; if (dimension == dimensions.length - 1) { - for (let i = 0; i < size; i++) { - if (context.count > context.limit) { - results.push('...'); - return results; - } - switch (dataType) { - case 'boolean': - results.push(view.getUint8(context.index) === 0 ? false : true); - context.index++; - context.count++; - break; - case 'qint8': - case 'int8': - results.push(view.getInt8(context.index)); - context.index++; - context.count++; - break; - case 'qint16': - case 'int16': - results.push(view.getInt16(context.index, this._littleEndian)); - context.index += 2; - context.count++; - break; - case 'qint32': - case 'int32': - results.push(view.getInt32(context.index, this._littleEndian)); - context.index += 4; - context.count++; - break; - case 'int64': - results.push(view.getInt64(context.index, this._littleEndian)); - context.index += 8; - context.count++; - break; - case 'int': - results.push(view.getIntBits(context.index, context.bits)); - context.index++; - context.count++; - break; - case 'quint8': - case 'uint8': - results.push(view.getUint8(context.index)); - context.index++; - context.count++; - break; - case 'quint16': - case 'uint16': - results.push(view.getUint16(context.index, true)); - context.index += 2; - context.count++; - break; - case 'quint32': - case 'uint32': - results.push(view.getUint32(context.index, true)); - context.index += 4; - context.count++; - break; - case 'uint64': - results.push(view.getUint64(context.index, true)); - context.index += 8; - context.count++; - break; - case 'uint': - results.push(view.getUintBits(context.index, context.bits)); - context.index++; - context.count++; - break; - case 'float16': - results.push(view.getFloat16(context.index, this._littleEndian)); - context.index += 2; - context.count++; - break; - case 'float32': - results.push(view.getFloat32(context.index, this._littleEndian)); - context.index += 4; - context.count++; - break; - case 'float64': - results.push(view.getFloat64(context.index, this._littleEndian)); - context.index += 8; - context.count++; - break; - case 'bfloat16': - results.push(view.getBfloat16(context.index, this._littleEndian)); - context.index += 2; - context.count++; - break; - case 'complex64': - results.push(view.getComplex64(i << 3, this._littleEndian)); + const ellipsis = (context.count + size) > context.limit; + const length = ellipsis ? context.limit - context.count : size; + let i = context.index; + const max = i + (length * context.itemsize); + switch (dataType) { + case 'boolean': + for (; i < max; i += 1) { + results.push(view.getUint8(i) === 0 ? false : true); + } + break; + case 'qint8': + case 'int8': + for (; i < max; i++) { + results.push(view.getInt8(i)); + } + break; + case 'qint16': + case 'int16': + for (; i < max; i += 2) { + results.push(view.getInt16(i, this._littleEndian)); + } + break; + case 'qint32': + case 'int32': + for (; i < max; i += 4) { + results.push(view.getInt32(i, this._littleEndian)); + } + break; + case 'int64': + for (; i < max; i += 8) { + results.push(view.getInt64(i, this._littleEndian)); + } + break; + case 'int': + for (; i < size; i++) { + results.push(view.getIntBits(i, context.bits)); + } + break; + case 'quint8': + case 'uint8': + for (; i < max; i++) { + results.push(view.getUint8(i)); + } + break; + case 'quint16': + case 'uint16': + for (; i < max; i += 2) { + results.push(view.getUint16(i, true)); + } + break; + case 'quint32': + case 'uint32': + for (; i < max; i += 4) { + results.push(view.getUint32(i, true)); + } + break; + case 'uint64': + for (; i < max; i += 8) { + results.push(view.getUint64(i, true)); + } + break; + case 'uint': + for (; i < max; i++) { + results.push(view.getUintBits(i, context.bits)); + } + break; + case 'float16': + for (; i < max; i += 2) { + results.push(view.getFloat16(i, this._littleEndian)); + } + break; + case 'float32': + for (; i < max; i += 4) { + results.push(view.getFloat32(i, this._littleEndian)); + } + break; + case 'float64': + for (; i < max; i += 8) { + results.push(view.getFloat64(i, this._littleEndian)); + } + break; + case 'bfloat16': + for (; i < max; i += 2) { + results.push(view.getBfloat16(i, this._littleEndian)); + } + break; + case 'complex64': + for (; i < max; i += 8) { + results.push(view.getComplex64(i, this._littleEndian)); context.index += 8; - context.count++; - break; - case 'complex128': - results.push(view.getComplex128(i << 4, this._littleEndian)); - context.index += 16; - context.count++; - break; - default: - throw new Error("Unsupported tensor data type '" + dataType + "'."); - } + } + break; + case 'complex128': + for (; i < size; i += 16) { + results.push(view.getComplex128(i, this._littleEndian)); + } + break; + default: + throw new Error("Unsupported tensor data type '" + dataType + "'."); + } + context.index = i; + context.count += length; + if (ellipsis) { + results.push('...'); } } else {