From 44abcdc28a2a20da81278802a96c86805b582519 Mon Sep 17 00:00:00 2001
From: Lutz Roeder
Date: Wed, 13 Jul 2022 20:22:32 -0700
Subject: [PATCH] Add PyTorch test file (#720)

---
 source/pytorch.js | 107 ++++++++++++++++++++++++++++++++--------------
 test/models.json  |   7 +++
 2 files changed, 83 insertions(+), 31 deletions(-)

diff --git a/source/pytorch.js b/source/pytorch.js
index 86f0989eed..17caa37458 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -678,9 +678,11 @@ pytorch.Tensor = class {
             case 'float32':
             case 'float64':
             case 'bfloat16':
+            case 'complex64':
+            case 'complex128':
                 break;
             default:
-                context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
+                context.state = "Tensor data type '" + this._type.dataType + "' is not implemented.";
                 return context;
         }
         if (!this._type.shape) {
@@ -702,7 +704,7 @@ pytorch.Tensor = class {
         context.dataType = this._type.dataType;
         context.dimensions = this._type.shape.dimensions;
-        context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
+        context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
         return context;
     }
@@ -718,56 +720,66 @@ pytorch.Tensor = class {
                 }
                 switch (context.dataType) {
                     case 'boolean':
-                        results.push(context.dataView.getUint8(context.index) === 0 ? false : true);
+                        results.push(context.view.getUint8(context.index) === 0 ? false : true);
                         context.index++;
                         context.count++;
                         break;
                     case 'uint8':
-                        results.push(context.dataView.getUint8(context.index));
+                        results.push(context.view.getUint8(context.index));
                         context.index++;
                         context.count++;
                         break;
                     case 'qint8':
                     case 'int8':
-                        results.push(context.dataView.getInt8(context.index));
+                        results.push(context.view.getInt8(context.index));
                         context.index++;
                         context.count++;
                         break;
                     case 'int16':
-                        results.push(context.dataView.getInt16(context.index, this._littleEndian));
+                        results.push(context.view.getInt16(context.index, this._littleEndian));
                         context.index += 2;
                         context.count++;
                         break;
                     case 'int32':
-                        results.push(context.dataView.getInt32(context.index, this._littleEndian));
+                        results.push(context.view.getInt32(context.index, this._littleEndian));
                         context.index += 4;
                         context.count++;
                         break;
                     case 'int64':
-                        results.push(context.dataView.getInt64(context.index, this._littleEndian));
+                        results.push(context.view.getInt64(context.index, this._littleEndian));
                         context.index += 8;
                         context.count++;
                         break;
                     case 'float16':
-                        results.push(context.dataView.getFloat16(context.index, this._littleEndian));
+                        results.push(context.view.getFloat16(context.index, this._littleEndian));
                         context.index += 2;
                         context.count++;
                         break;
                     case 'float32':
-                        results.push(context.dataView.getFloat32(context.index, this._littleEndian));
+                        results.push(context.view.getFloat32(context.index, this._littleEndian));
                         context.index += 4;
                         context.count++;
                         break;
                     case 'float64':
-                        results.push(context.dataView.getFloat64(context.index, this._littleEndian));
+                        results.push(context.view.getFloat64(context.index, this._littleEndian));
                         context.index += 8;
                         context.count++;
                         break;
                     case 'bfloat16':
-                        results.push(context.dataView.getBfloat16(context.index, this._littleEndian));
+                        results.push(context.view.getBfloat16(context.index, this._littleEndian));
                         context.index += 2;
                         context.count++;
                         break;
+                    case 'complex64':
+                        results.push(context.view.getComplex64(i << 3, this._littleEndian));
+                        context.index += 8;
+                        context.count++;
+                        break;
+                    case 'complex128':
+                        results.push(context.view.getComplex128(i << 4, this._littleEndian));
+                        context.index += 16;
+                        context.count++;
+                        break;
                     default:
                         throw new pytorch.Error("Unsupported tensor data type '" + context.dataType + "'.");
                 }
@@ -799,22 +811,26 @@ pytorch.Tensor = class {
             result.push(indentation + ']');
             return result.join('\n');
         }
-        if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
-            return indentation + value.toString();
-        }
-        if (typeof value == 'string') {
-            return indentation + value;
-        }
-        if (value == Infinity) {
-            return indentation + 'Infinity';
-        }
-        if (value == -Infinity) {
-            return indentation + '-Infinity';
-        }
-        if (isNaN(value)) {
-            return indentation + 'NaN';
+        switch (typeof value) {
+            case 'string':
+                return indentation + value;
+            case 'number':
+                if (value == Infinity) {
+                    return indentation + 'Infinity';
+                }
+                if (value == -Infinity) {
+                    return indentation + '-Infinity';
+                }
+                if (isNaN(value)) {
+                    return indentation + 'NaN';
+                }
+                return indentation + value.toString();
+            default:
+                if (value && value.toString) {
+                    return indentation + value.toString();
+                }
+                return indentation + '(undefined)';
         }
-        return indentation + value.toString();
     }
 };
@@ -1833,20 +1849,20 @@ pytorch.Execution = class extends python.Execution {
                 this._device = null;
             }
             get device() {
-                return null;
+                return this._device;
             }
             get dtype() {
                 return this._dtype;
             }
-            get data() {
-                return this._cdata;
-            }
             element_size() {
                 return this._dtype.element_size;
             }
             size() {
                 return this._size;
             }
+            get data() {
+                return this._cdata;
+            }
             _set_cdata(data) {
                 const length = this.size() * this.dtype.itemsize();
                 if (length !== data.length) {
@@ -1876,6 +1892,33 @@ pytorch.Execution = class extends python.Execution {
                 return storage;
             }
         });
+        this.registerType('torch.storage._UntypedStorage', class extends torch_storage._StorageBase {
+            constructor() {
+                super();
+                throw new python.Error('_UntypedStorage not implemented.');
+            }
+        });
+        this.registerType('torch.storage._TypedStorage', class {
+            constructor() {
+                throw new python.Error('_TypedStorage not implemented.');
+            }
+        });
+        this.registerType('torch.storage._LegacyStorage', class extends torch_storage._TypedStorage {
+            constructor() {
+                super();
+                throw new python.Error('_LegacyStorage not implemented.');
+            }
+        });
+        this.registerType('torch.ComplexFloatStorage', class extends torch_storage._StorageBase {
+            constructor(size) {
+                super(size, torch.complex64);
+            }
+        });
+        this.registerType('torch.ComplexDoubleStorage', class extends torch_storage._StorageBase {
+            constructor(size) {
+                super(size, torch.complex128);
+            }
+        });
         this.registerType('torch.BoolStorage', class extends torch_storage._StorageBase {
             constructor(size) {
                 super(size, torch.bool);
@@ -2058,6 +2101,8 @@ pytorch.Execution = class extends python.Execution {
         this.registerType('torch.HalfTensor', class extends torch.Tensor {});
         this.registerType('torch.FloatTensor', class extends torch.Tensor {});
         this.registerType('torch.DoubleTensor', class extends torch.Tensor {});
+        this.registerType('torch.ComplexFloatTensor', class extends torch.Tensor {});
+        this.registerType('torch.ComplexDoubleTensor', class extends torch.Tensor {});
         this.registerType('torch.QInt8Tensor', class extends torch.Tensor {});
         this.registerType('torch.QUInt8Tensor', class extends torch.Tensor {});
         this.registerType('torch.QInt32Tensor', class extends torch.Tensor {});
diff --git a/test/models.json b/test/models.json
index 85e1a7251c..33c14a9434 100644
--- a/test/models.json
+++ b/test/models.json
@@ -4155,6 +4155,13 @@
         "format": "TorchScript v1.0",
         "link": "https://github.com/ApolloAuto/apollo"
     },
+    {
+        "type": "pytorch",
+        "target": "complex_tensor.pt",
+        "source": "https://github.com/lutzroeder/netron/files/9108149/complex_tensor.pt.zip[complex_tensor.pt]",
+        "format": "PyTorch v1.6",
+        "link": "https://github.com/lutzroeder/netron/issues/720"
+    },
     {
         "type": "pytorch",
         "target": "d2go.pt",