Skip to content

Commit

Permalink
Add PyTorch test file (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 14, 2022
1 parent d3b635f commit 44abcdc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 31 deletions.
107 changes: 76 additions & 31 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,11 @@ pytorch.Tensor = class {
case 'float32':
case 'float64':
case 'bfloat16':
case 'complex64':
case 'complex128':
break;
default:
context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
context.state = "Tensor data type '" + this._type.dataType + "' is not implemented.";
return context;
}
if (!this._type.shape) {
Expand All @@ -702,7 +704,7 @@ pytorch.Tensor = class {

context.dataType = this._type.dataType;
context.dimensions = this._type.shape.dimensions;
context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
return context;
}

Expand All @@ -718,56 +720,66 @@ pytorch.Tensor = class {
}
switch (context.dataType) {
case 'boolean':
results.push(context.dataView.getUint8(context.index) === 0 ? false : true);
results.push(context.view.getUint8(context.index) === 0 ? false : true);
context.index++;
context.count++;
break;
case 'uint8':
results.push(context.dataView.getUint8(context.index));
results.push(context.view.getUint8(context.index));
context.index++;
context.count++;
break;
case 'qint8':
case 'int8':
results.push(context.dataView.getInt8(context.index));
results.push(context.view.getInt8(context.index));
context.index++;
context.count++;
break;
case 'int16':
results.push(context.dataView.getInt16(context.index, this._littleEndian));
results.push(context.view.getInt16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'int32':
results.push(context.dataView.getInt32(context.index, this._littleEndian));
results.push(context.view.getInt32(context.index, this._littleEndian));
context.index += 4;
context.count++;
break;
case 'int64':
results.push(context.dataView.getInt64(context.index, this._littleEndian));
results.push(context.view.getInt64(context.index, this._littleEndian));
context.index += 8;
context.count++;
break;
case 'float16':
results.push(context.dataView.getFloat16(context.index, this._littleEndian));
results.push(context.view.getFloat16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'float32':
results.push(context.dataView.getFloat32(context.index, this._littleEndian));
results.push(context.view.getFloat32(context.index, this._littleEndian));
context.index += 4;
context.count++;
break;
case 'float64':
results.push(context.dataView.getFloat64(context.index, this._littleEndian));
results.push(context.view.getFloat64(context.index, this._littleEndian));
context.index += 8;
context.count++;
break;
case 'bfloat16':
results.push(context.dataView.getBfloat16(context.index, this._littleEndian));
results.push(context.view.getBfloat16(context.index, this._littleEndian));
context.index += 2;
context.count++;
break;
case 'complex64':
results.push(context.view.getComplex64(i << 3, this._littleEndian));
context.index += 8;
context.count++;
break;
case 'complex128':
results.push(context.view.getComplex128(i << 4, this._littleEndian));
context.index += 16;
context.count++;
break;
default:
throw new pytorch.Error("Unsupported tensor data type '" + context.dataType + "'.");
}
Expand Down Expand Up @@ -799,22 +811,26 @@ pytorch.Tensor = class {
result.push(indentation + ']');
return result.join('\n');
}
if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
return indentation + value.toString();
}
if (typeof value == 'string') {
return indentation + value;
}
if (value == Infinity) {
return indentation + 'Infinity';
}
if (value == -Infinity) {
return indentation + '-Infinity';
}
if (isNaN(value)) {
return indentation + 'NaN';
switch (typeof value) {
case 'string':
return indentation + value;
case 'number':
if (value == Infinity) {
return indentation + 'Infinity';
}
if (value == -Infinity) {
return indentation + '-Infinity';
}
if (isNaN(value)) {
return indentation + 'NaN';
}
return indentation + value.toString();
default:
if (value && value.toString) {
return indentation + value.toString();
}
return indentation + '(undefined)';
}
return indentation + value.toString();
}
};

Expand Down Expand Up @@ -1833,20 +1849,20 @@ pytorch.Execution = class extends python.Execution {
this._device = null;
}
get device() {
return null;
return this._device;
}
get dtype() {
return this._dtype;
}
get data() {
return this._cdata;
}
element_size() {
return this._dtype.element_size;
}
size() {
return this._size;
}
get data() {
return this._cdata;
}
_set_cdata(data) {
const length = this.size() * this.dtype.itemsize();
if (length !== data.length) {
Expand Down Expand Up @@ -1876,6 +1892,33 @@ pytorch.Execution = class extends python.Execution {
return storage;
}
});
this.registerType('torch.storage._UntypedStorage', class extends torch_storage._StorageBase {
constructor() {
super();
throw new python.Error('_UntypedStorage not implemented.');
}
});
this.registerType('torch.storage._TypedStorage', class {
constructor() {
throw new python.Error('_TypedStorage not implemented.');
}
});
this.registerType('torch.storage._LegacyStorage', class extends torch_storage._TypedStorage {
constructor() {
super();
throw new python.Error('_LegacyStorage not implemented.');
}
});
this.registerType('torch.ComplexFloatStorage', class extends torch_storage._StorageBase {
constructor(size) {
super(size, torch.complex64);
}
});
this.registerType('torch.ComplexDoubleStorage', class extends torch_storage._StorageBase {
constructor(size) {
super(size, torch.complex128);
}
});
this.registerType('torch.BoolStorage', class extends torch_storage._StorageBase {
constructor(size) {
super(size, torch.bool);
Expand Down Expand Up @@ -2058,6 +2101,8 @@ pytorch.Execution = class extends python.Execution {
this.registerType('torch.HalfTensor', class extends torch.Tensor {});
this.registerType('torch.FloatTensor', class extends torch.Tensor {});
this.registerType('torch.DoubleTensor', class extends torch.Tensor {});
this.registerType('torch.ComplexFloatTensor', class extends torch.Tensor {});
this.registerType('torch.ComplexDoubleTensor', class extends torch.Tensor {});
this.registerType('torch.QInt8Tensor', class extends torch.Tensor {});
this.registerType('torch.QUInt8Tensor', class extends torch.Tensor {});
this.registerType('torch.QInt32Tensor', class extends torch.Tensor {});
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4155,6 +4155,13 @@
"format": "TorchScript v1.0",
"link": "https://github.com/ApolloAuto/apollo"
},
{
"type": "pytorch",
"target": "complex_tensor.pt",
"source": "https://github.com/lutzroeder/netron/files/9108149/complex_tensor.pt.zip[complex_tensor.pt]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "d2go.pt",
Expand Down

0 comments on commit 44abcdc

Please sign in to comment.