Skip to content

Commit

Permalink
Update Pickle test files (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Aug 29, 2024
1 parent fb810eb commit cc7c063
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 10 deletions.
53 changes: 46 additions & 7 deletions source/pickle.js
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,48 @@ pickle.Value = class {

pickle.Tensor = class {

constructor(array) {
this.type = new pickle.TensorType(array.dtype.__name__, new pickle.TensorShape(array.shape));
this.stride = Array.isArray(array.strides) ? array.strides.map((stride) => stride / array.itemsize) : null;
this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
constructor(obj) {
if (obj.__class__ && (obj.__class__.__module__ === 'torch' || obj.__class__.__module__ === 'torch.nn.parameter')) {
// PyTorch tensor
const tensor = obj.__class__.__module__ === 'torch.nn.parameter' && obj.__class__.__name__ === 'Parameter' ? obj.data : obj;
const layout = tensor.layout ? tensor.layout.__str__() : null;
const storage = tensor.storage();
const size = tensor.size() || [];
if (!layout || layout === 'torch.strided') {
this.type = new pickle.TensorType(storage.dtype.__reduce__(), new pickle.TensorShape(size));
this.values = storage.data;
this.encoding = '<';
this.indices = null;
this.stride = tensor.stride();
const stride = this.stride;
const offset = tensor.storage_offset();
let length = 0;
if (!Array.isArray(stride)) {
length = storage.size();
} else if (size.every((v) => v !== 0)) {
length = size.reduce((a, v, i) => a + stride[i] * (v - 1), 1);
}
if (offset !== 0 || length !== storage.size()) {
const itemsize = storage.dtype.itemsize();
const stream = this.values;
const position = stream.position;
stream.seek(itemsize * offset);
this.values = stream.peek(itemsize * length);
stream.seek(position);
} else {
this.values = this.values.peek();
}
} else {
throw new pickle.Error(`Unsupported tensor layout '${layout}'.`);
}
} else {
// NumPy array
const array = obj;
this.type = new pickle.TensorType(array.dtype.__name__, new pickle.TensorShape(array.shape));
this.stride = Array.isArray(array.strides) ? array.strides.map((stride) => stride / array.itemsize) : null;
this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
}
}
};

Expand Down Expand Up @@ -264,10 +301,12 @@ pickle.Utility = class {
}

static isTensor(obj) {
return obj && obj.__class__ &&
return obj && obj.__class__ && obj.__class__.__name__ &&
((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
(obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'matrix') ||
(obj.__class__.__module__ === 'jax' && obj.__class__.__name__ === 'Array'));
(obj.__class__.__module__ === 'jax' && obj.__class__.__name__ === 'Array') ||
(obj.__class__.__module__ === 'torch.nn.parameter' && obj.__class__.__name__ === 'Parameter') ||
(obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')));
}

static weights(obj) {
Expand Down
13 changes: 10 additions & 3 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4793,6 +4793,13 @@
"error": "Unknown type name '__builtin__.__main__'.\nUnsupported Pickle type '__main__.CustomClass'.",
"link": "https://github.com/lutzroeder/netron/issues/901"
},
{
"type": "pickle",
"target": "easy-khair-180-gpc0.8-trans10-025000.pkl.zip",
"source": "https://github.com/user-attachments/files/16783181/easy-khair-180-gpc0.8-trans10-025000.pkl.zip",
"format": "Pickle",
"link": "https://github.com/lutzroeder/netron/issues/901"
},
{
"type": "pickle",
"target": "file.pickle",
Expand Down Expand Up @@ -4855,11 +4862,11 @@
},
{
"type": "pickle",
"target": "R-50.pkl",
"source": "https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/MSRA/R-50.pkl",
"target": "R-50.pkl.zip",
"source": "https://github.com/user-attachments/files/16783222/R-50.pkl.zip",
"format": "Pickle",
"assert": "model.graphs[0].nodes.length == 107",
"link": "https://github.com/facebookresearch/Detectron"
"link": "https://github.com/lutzroeder/netron/issues/901"
},
{
"type": "pickle",
Expand Down

0 comments on commit cc7c063

Please sign in to comment.