Skip to content

Commit

Permalink
Add PyTorch test files (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Mar 23, 2024
1 parent dcbb4a1 commit 7c771f8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
15 changes: 7 additions & 8 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -895,15 +895,18 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
if (pytorch.Utility.isTensor(obj)) {
return new pytorch.Container.data_pkl('tensor', obj);
}
if (Array.isArray(obj) && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) {
return new pytorch.Container.data_pkl('tensor', obj);
}
if (obj instanceof Map) {
const entries = Array.from(obj).filter(([name, value]) => name === '_metadata' || pytorch.Utility.isTensor(value));
if (entries.length > 0) {
return new pytorch.Container.data_pkl('tensor<>', obj);
return new pytorch.Container.data_pkl('tensor', obj);
}
} else if (!Array.isArray(obj)) {
const entries = Object.entries(obj).filter(([name, value]) => name === '_metadata' || pytorch.Utility.isTensor(value));
if (entries.length > 0) {
return new pytorch.Container.data_pkl('tensor<>', obj);
return new pytorch.Container.data_pkl('tensor', obj);
}
}
for (const key of ['', 'model', 'net']) {
Expand All @@ -924,12 +927,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
}

get format() {
switch (this._type) {
case 'module': return 'PyTorch';
case 'tensor': return 'PyTorch Tensor';
case 'tensor<>': return 'PyTorch Pickle Weights';
default: return 'PyTorch Pickle';
}
return 'PyTorch Pickle';
}

get modules() {
Expand All @@ -945,6 +943,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
return this._modules;
}
case 'tensor':
case 'tensor[]':
case 'tensor<>': {
if (this._data) {
this._modules = pytorch.Utility.findWeights(this._data);
Expand Down
22 changes: 18 additions & 4 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4758,7 +4758,7 @@
"type": "pytorch",
"target": "densenet.data.pkl",
"source": "https://github.com/lutzroeder/netron/files/13064609/densenet.data.pkl.zip[densenet.data.pkl]",
"format": "PyTorch",
"format": "PyTorch Pickle",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
Expand Down Expand Up @@ -4821,7 +4821,7 @@
"type": "pytorch",
"target": "fast.ai.data.pkl",
"source": "https://github.com/lutzroeder/netron/files/13064775/fast.ai.data.pkl.zip[fast.ai.data.pkl]",
"format": "PyTorch",
"format": "PyTorch Pickle",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
Expand Down Expand Up @@ -5593,11 +5593,25 @@
"format": "TorchScript v1.0",
"link": "https://github.com/KinglittleQ/SuperPoint_SLAM"
},
{
"type": "pytorch",
"target": "tensor.pkl",
"source": "https://github.com/lutzroeder/netron/files/14733020/tensor.pkl.zip[tensor.pkl]",
"format": "PyTorch Pickle",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "tensors.pkl",
"source": "https://github.com/lutzroeder/netron/files/14733021/tensors.pkl.zip[tensors.pkl]",
"format": "PyTorch Pickle",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "tensors.data.pkl",
"source": "https://github.com/lutzroeder/netron/files/13061412/tensors.data.pkl.zip[tensors.data.pkl]",
"format": "PyTorch Pickle Weights",
"format": "PyTorch Pickle",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
Expand Down Expand Up @@ -5836,7 +5850,7 @@
"type": "pytorch",
"target": "yolov5n.tensor.data.pkl",
"source": "https://github.com/lutzroeder/netron/files/13064842/yolov5n.tensor.data.pkl.zip[yolov5n.tensor.data.pkl]",
"format": "PyTorch Pickle Weights",
"format": "PyTorch Pickle",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
Expand Down

0 comments on commit 7c771f8

Please sign in to comment.