Skip to content

Commit

Permalink
Update tensor formatter (#961)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 10, 2022
1 parent ab16ea9 commit 47a4d61
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 56 deletions.
55 changes: 34 additions & 21 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,38 @@ pytorch.Tensor = class {
}
return this._data instanceof Uint8Array ? this._data : this._data.peek();
}

decode() {
if (this._layout !== '<' && this._layout !== '>') {
throw new pytorch.Error("Tensor layout '" + this._layout + "' not implemented.");
}
const littleEndian = this._littleEndian;
const type = this._type;
const data = this.values;
const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
switch (type.dataType) {
case 'int16': {
const array = new Uint16Array(data.length >> 1);
for (let i = 0; i < array.length; i++) {
array[i] = view.getInt16(i << 1, littleEndian);
}
return array;
}
case 'int64': {
const array = new Uint32Array(data.length >> 3);
for (let i = 0; i < array.length; i++) {
array[i] = view.getUint32(i << 3, littleEndian);
if (view.getUint32((i << 3) + 4, littleEndian) !== 0) {
throw new pytorch.Error('Signed 64-bit value exceeds 32-bit range.');
}
}
return array;
}
default: {
throw new pytorch.Error("Tensor data type '" + type.dataType + "' not implemented.");
}
}
}
};

pytorch.TensorType = class {
Expand Down Expand Up @@ -751,7 +783,7 @@ pytorch.Execution = class extends python.Execution {
const tensors = state[1];
const opt_tensors = state[2];
const packed_config_tensor = new pytorch.Tensor('', tensors[0], true);
const packed_config = pytorch.Utility.values(packed_config_tensor);
const packed_config = packed_config_tensor.decode();
this.weight = tensors[1];
this.bias = opt_tensors[0];
this.stride = [ packed_config[1], packed_config[2] ];
Expand All @@ -770,7 +802,7 @@ pytorch.Execution = class extends python.Execution {
const tensors = state[1];
const opt_tensors = state[2];
const packed_config_tensor = new pytorch.Tensor('', tensors[0], true);
const packed_config = pytorch.Utility.values(packed_config_tensor);
const packed_config = packed_config_tensor.decode();
this.weight = tensors[1];
this.bias = opt_tensors[0];
this.stride = [ packed_config[1], packed_config[2] ];
Expand Down Expand Up @@ -2566,25 +2598,6 @@ pytorch.Utility = class {
return null;
}

static values(tensor) {
const type = tensor.type;
const data = tensor.values;
if (type && data) {
switch (type.dataType) {
case 'int16': {
if (tensor.layout === '<') {
return new Uint16Array(data);
}
break;
}
default: {
break;
}
}
}
throw new pytorch.Error("Tensor data type '" + type.dataType + "' not implemented.");
}

static isTensor(obj) {
const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
switch (name) {
Expand Down
97 changes: 63 additions & 34 deletions source/view-sidebar.js
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ sidebar.ValueView = class extends sidebar.Control {
this._bold('layout', layouts.get(layout));
}
}
if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse') {
if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo') {
contentLine.innerHTML = "Tensor layout '" + tensor.layout + "' is not implemented.";
}
else if (tensor.empty) {
Expand Down Expand Up @@ -1341,6 +1341,12 @@ sidebar.Tensor = class {
this._layout = 'sparse';
break;
}
case 'sparse.coo': {
this._indices = this._tensor.indices;
this._values = this._tensor.values;
this._layout = 'sparse.coo';
break;
}
default: {
this._layout = tensor.layout;
break;
Expand Down Expand Up @@ -1378,7 +1384,8 @@ sidebar.Tensor = class {
case '|': {
return !(Array.isArray(this._values) || ArrayBuffer.isView(this._values)) || this._values.length === 0;
}
case 'sparse': {
case 'sparse':
case 'sparse.coo': {
return !this._values || this.indices || this._values.values.length === 0;
}
default: {
Expand Down Expand Up @@ -1424,19 +1431,19 @@ sidebar.Tensor = class {
}

_context() {
if (this._layout !== '<' && this._layout !== '>' && this._layout !== '|' && this._layout !== 'sparse') {
if (this._layout !== '<' && this._layout !== '>' && this._layout !== '|' && this._layout !== 'sparse' && this._layout !== 'sparse.coo') {
throw new Error("Tensor layout '" + this._layout + "' is not supported.");
}
const dataType = this._type.dataType;
const context = {};
context.layout = this._layout;
context.dimensions = this._type.shape.dimensions;
const dataType = this._type.dataType;
const size = this._type.shape.dimensions.reduce((a, b) => a * b, 1);
context.dataType = dataType;
const size = context.dimensions.reduce((a, b) => a * b, 1);
switch (this._layout) {
case '<':
case '>': {
context.data = (this._data instanceof Uint8Array || this._data instanceof Int8Array) ? this._data : this._data.peek();
context.dataType = dataType;
context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
if (sidebar.Tensor.dataTypes.has(dataType)) {
const itemsize = sidebar.Tensor.dataTypes.get(dataType);
Expand All @@ -1459,7 +1466,6 @@ sidebar.Tensor = class {
}
case '|': {
context.data = this._values;
context.dataType = dataType;
if (!sidebar.Tensor.dataTypes.has(dataType) && dataType !== 'string' && dataType !== 'object') {
throw new Error("Tensor data type '" + dataType + "' is not implemented.");
}
Expand All @@ -1469,36 +1475,32 @@ sidebar.Tensor = class {
break;
}
case 'sparse': {
context.dataType = dataType;
const size = context.dimensions.reduce((a, b) => a * b, 1);
const indices = this._indices.values;
const values = this._values.values;
const array = new values.constructor(size);
switch (context.dataType) {
case 'boolean':
array.fill(false);
break;
case 'int64':
case 'uint64':
break;
default:
break;
}
if (indices.length > 0) {
if (Object.prototype.hasOwnProperty.call(indices[0], 'low')) {
for (let i = 0; i < indices.length; i++) {
const index = indices[i];
array[index.high === 0 ? index.low : index.toNumber()] = values[i];
}
}
else {
for (let i = 0; i < indices.length; i++) {
array[indices[i]] = values[i];
}
const indices = new sidebar.Tensor(this._indices).value;
const values = new sidebar.Tensor(this._values).value;
context.data = this._decodeSparse(dataType, context.dimensions, indices, values);
context.layout = '|';
break;
}
case 'sparse.coo': {
const values = new sidebar.Tensor(this._values).value;
const data = new sidebar.Tensor(this._indices).value;
const dimensions = context.dimensions.length;
let stride = 1;
const strides = context.dimensions.slice().reverse().map((dim) => {
const value = stride;
stride *= dim;
return value;
}).reverse();
const indices = new Uint32Array(values.length);
for (let i = 0; i < dimensions; i++) {
const stride = strides[i];
const dimension = data[i];
for (let i = 0; i < indices.length; i++) {
indices[i] += dimension[i] * stride;
}
}
context.data = this._decodeSparse(dataType, context.dimensions, indices, values);
context.layout = '|';
context.data = array;
break;
}
default: {
Expand All @@ -1510,6 +1512,33 @@ sidebar.Tensor = class {
return context;
}

_decodeSparse(dataType, dimensions, indices, values) {
const size = dimensions.reduce((a, b) => a * b, 1);
const array = new Array(size);
switch (dataType) {
case 'boolean':
array.fill(false);
break;
default:
array.fill(0);
break;
}
if (indices.length > 0) {
if (Object.prototype.hasOwnProperty.call(indices[0], 'low')) {
for (let i = 0; i < indices.length; i++) {
const index = indices[i];
array[index.high === 0 ? index.low : index.toNumber()] = values[i];
}
}
else {
for (let i = 0; i < indices.length; i++) {
array[indices[i]] = values[i];
}
}
}
return array;
}

_decodeData(context, dimension) {
const results = [];
const dimensions = (context.dimensions.length == 0) ? [ 1 ] : context.dimensions;
Expand Down
2 changes: 1 addition & 1 deletion test/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ const loadModel = (target, item) => {
// console.log(' ' + message);
};
const tensor = new sidebar.Tensor(argument.initializer);
if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse') {
if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo') {
log("Tensor layout '" + tensor.layout + "' is not implemented.");
}
else if (tensor.empty) {
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4823,6 +4823,13 @@
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
"type": "pytorch",
"target": "sparse_coo.pth",
"source": "https://github.com/lutzroeder/netron/files/9541426/sparse_coo.pth.zip[sparse_coo.pth]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "sparsified.pth",
Expand Down

0 comments on commit 47a4d61

Please sign in to comment.