Skip to content

Commit

Permalink
Add PyTorch test file (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Aug 26, 2022
1 parent eecd2a2 commit 65bcf81
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 11 deletions.
82 changes: 76 additions & 6 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4083,6 +4083,20 @@ python.Execution = class {
}
throw new python.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')');
});
this.registerFunction('torch._utils._rebuild_sparse_tensor', function(layout, data) {
if (layout === torch.sparse_coo) {
return self.invoke('torch._sparse_coo_tensor_unsafe', data);
}
throw new python.Error("Unsupported sparse tensor layout '" + (layout ? layout.__str__() : '') + "'.");
});
this.registerFunction('torch._sparse_coo_tensor_unsafe', function(indices, values, size) {
const tensor = self.invoke('torch.Tensor', []);
tensor._layout = torch.sparse_coo;
tensor._indices = indices;
tensor._values = values;
tensor._shape = size;
return tensor;
});
this.registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
const name = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
const tensor = self.invoke(name, []);
Expand Down Expand Up @@ -4298,6 +4312,11 @@ python.Execution = class {
});
this.registerFunction('torch.set_grad_enabled', function(/* value */) {
});

this.registerFunction('torch.serialization._get_layout', function(name) {
const value = name.startsWith('torch.') ? torch[name.split('.')[1]] : null;
return value instanceof torch.layout ? value : null;
});
this.registerFunction('torch.jit._pickle.build_boollist', function(data) {
return data;
});
Expand Down Expand Up @@ -4514,8 +4533,32 @@ python.Execution = class {
__str__() {
return 'torch.' + this._data.name;
}
toString() {
return this.__str__();
}
});
this.registerType('torch.layout', class {
constructor(name) {
this._name = name;
}
__str__() {
return this._name;
}
toString() {
return this.__str__();
}
});
this.registerType('torch.qscheme', class {
constructor(name) {
this._name = name;
}
__str__() {
return this._name;
}
toString() {
return this.__str__();
}
});
this.registerType('torch.qscheme', class {});
this.registerType('torch.utils.hooks.RemovableHandle', class {
__setstate__(state) {
this.hooks_dict_ref = state[0] || new Map();
Expand Down Expand Up @@ -4680,16 +4723,36 @@ python.Execution = class {
});
this.registerType('torch.Tensor', class {
constructor() {
this._layout = torch.strided;
}
get device() {
return this.storage().device;
}
get dtype() {
if (this._layout === torch.sparse_coo) {
return this._values.dtype();
}
return this.storage().dtype;
}
get shape() {
return this._shape;
}
get layout() {
return this._layout;
}
get values() {
if (this._layout === torch.sparse_coo) {
return this._values;
}
throw new python.Error("Unsupported values in layout'" + this._layout.__str__() + "'.");
}
get indices() {
if (this._indices === torch.sparse_coo) {
return this._indices;
}
throw new python.Error("Unsupported indices in layout'" + this._indices.__str__() + "'.");
}

size() {
return this._shape;
}
Expand Down Expand Up @@ -4809,11 +4872,18 @@ python.Execution = class {
torch.qint32 = torch.QInt32Storage.dtype = new torch.dtype({ type: 14, name: 'qint32', itemsize: 4 });
torch.bfloat16 = torch.BFloat16Storage.dtype = new torch.dtype({ type: 15, name: 'bfloat16', itemsize: 2 });
torch.quint4x2 = new torch.dtype({ type: 16, name: 'quint4x2' });
torch.per_tensor_affine = new torch.qscheme();
torch.per_channel_affine = new torch.qscheme();
torch.per_tensor_symmetric = new torch.qscheme();
torch.per_channel_symmetric = new torch.qscheme();
torch.per_channel_affine_float_qparams = new torch.qscheme();
torch.strided = new torch.layout('torch.strided');
torch.sparse_coo = new torch.layout('torch.sparse_coo');
torch.sparse_csr = new torch.layout('torch.sparse_csr');
torch.sparse_csc = new torch.layout('torch.sparse_csc');
torch.sparse_bsr = new torch.layout('torch.sparse_bsr');
torch.sparse_bsc = new torch.layout('torch.sparse_bsc');
torch._mkldnn = new torch.layout('torch._mkldnn');
torch.per_tensor_affine = new torch.qscheme('torch.per_tensor_affine');
torch.per_channel_affine = new torch.qscheme('torch.per_channel_affine');
torch.per_tensor_symmetric = new torch.qscheme('torch.per_tensor_symmetric');
torch.per_channel_symmetric = new torch.qscheme('torch.per_channel_symmetric');
torch.per_channel_affine_float_qparams = new torch.qscheme('torch.per_channel_affine_float_qparams');
torch.inf = this.register('math').inf;
}

Expand Down
17 changes: 12 additions & 5 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -614,14 +614,18 @@ pytorch.Attribute = class {

pytorch.Tensor = class {

constructor(name, type, data, littleEndian) {
constructor(name, type, layout, data, littleEndian) {
this._name = name || '';
this._type = type;
this._layout = layout;
this._data = data;
this._littleEndian = littleEndian;
}

get kind() {
if (this._layout === 'torch.sparse_coo') {
return 'Sparse Tensor';
}
return 'Tensor';
}

Expand Down Expand Up @@ -662,6 +666,10 @@ pytorch.Tensor = class {
context.index = 0;
context.count = 0;

if (this._layout !== null && this._layout !== 'torch.strided') {
context.state = "Tensor layout '" + this._layout + "' is not supported.";
return context;
}
if (!this._type.dataType) {
context.state = 'Tensor has no data type.';
return context;
Expand Down Expand Up @@ -693,15 +701,13 @@ pytorch.Tensor = class {
context.state = 'Tensor data is empty.';
return context;
}

try {
context.data = this._data instanceof Uint8Array ? this._data : this._data.peek();
}
catch (err) {
context.state = err.message;
return context;
}

context.dataType = this._type.dataType;
context.dimensions = this._type.shape.dimensions;
context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
Expand Down Expand Up @@ -2765,7 +2771,8 @@ pytorch.Utility = class {
const storage = tensor.storage();
const size = tensor.size();
const type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size));
return new pytorch.Tensor(name || '', type, storage.data, littleEndian);
const layout = tensor.layout ? tensor.layout.__str__() : null;
return new pytorch.Tensor(name || '', type, layout, storage.data, littleEndian);
}

static getType(value) {
Expand Down Expand Up @@ -3529,7 +3536,7 @@ pytorch.nnapi.Argument = class {
this._name = operand.index.toString();
const shape = new pytorch.TensorShape(operand.dimensions);
this._type = new pytorch.TensorType(operand.data_type.replace('[]', ''), shape);
this._initializer = operand.data ? new pytorch.Tensor(this._name, this._type, operand.data, true) : null;
this._initializer = operand.data ? new pytorch.Tensor(this._name, this._type, null, operand.data, true) : null;
this._scale = operand.scale;
this._zeroPoint = operand.zero_point;
}
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4813,6 +4813,13 @@
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
"type": "pytorch",
"target": "sparsified.pth",
"source": "https://github.com/lutzroeder/netron/files/9433521/sparsified.pth.zip[sparsified.pth]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "superpoint_v1.pth",
Expand Down

0 comments on commit 65bcf81

Please sign in to comment.