From 7c771f8f3b7aff6384e43684d506643b7794e6f0 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 23 Mar 2024 10:10:26 -0700 Subject: [PATCH] Add PyTorch test files (#720) --- source/pytorch.js | 15 +++++++-------- test/models.json | 22 ++++++++++++++++++---- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/source/pytorch.js b/source/pytorch.js index 085c1cfa5c..923e2ebef0 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -895,15 +895,18 @@ pytorch.Container.data_pkl = class extends pytorch.Container { if (pytorch.Utility.isTensor(obj)) { return new pytorch.Container.data_pkl('tensor', obj); } + if (Array.isArray(obj) && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) { + return new pytorch.Container.data_pkl('tensor', obj); + } if (obj instanceof Map) { const entries = Array.from(obj).filter(([name, value]) => name === '_metadata' || pytorch.Utility.isTensor(value)); if (entries.length > 0) { - return new pytorch.Container.data_pkl('tensor<>', obj); + return new pytorch.Container.data_pkl('tensor', obj); } } else if (!Array.isArray(obj)) { const entries = Object.entries(obj).filter(([name, value]) => name === '_metadata' || pytorch.Utility.isTensor(value)); if (entries.length > 0) { - return new pytorch.Container.data_pkl('tensor<>', obj); + return new pytorch.Container.data_pkl('tensor', obj); } } for (const key of ['', 'model', 'net']) { @@ -924,12 +927,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container { } get format() { - switch (this._type) { - case 'module': return 'PyTorch'; - case 'tensor': return 'PyTorch Tensor'; - case 'tensor<>': return 'PyTorch Pickle Weights'; - default: return 'PyTorch Pickle'; - } + return 'PyTorch Pickle'; } get modules() { @@ -945,6 +943,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container { return this._modules; } case 'tensor': + case 'tensor[]': case 'tensor<>': { if (this._data) { this._modules = pytorch.Utility.findWeights(this._data); diff --git a/test/models.json b/test/models.json index 8b4fcf9c3a..e631c47606 100644 --- a/test/models.json +++ b/test/models.json @@ -4758,7 +4758,7 @@ "type": "pytorch", "target": "densenet.data.pkl", "source": "https://github.com/lutzroeder/netron/files/13064609/densenet.data.pkl.zip[densenet.data.pkl]", - "format": "PyTorch", + "format": "PyTorch Pickle", "link": "https://github.com/lutzroeder/netron/issues/720" }, { @@ -4821,7 +4821,7 @@ "type": "pytorch", "target": "fast.ai.data.pkl", "source": "https://github.com/lutzroeder/netron/files/13064775/fast.ai.data.pkl.zip[fast.ai.data.pkl]", - "format": "PyTorch", + "format": "PyTorch Pickle", "link": "https://github.com/lutzroeder/netron/issues/720" }, { @@ -5593,11 +5593,25 @@ "format": "TorchScript v1.0", "link": "https://github.com/KinglittleQ/SuperPoint_SLAM" }, + { + "type": "pytorch", + "target": "tensor.pkl", + "source": "https://github.com/lutzroeder/netron/files/14733020/tensor.pkl.zip[tensor.pkl]", + "format": "PyTorch Pickle", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, + { + "type": "pytorch", + "target": "tensors.pkl", + "source": "https://github.com/lutzroeder/netron/files/14733021/tensors.pkl.zip[tensors.pkl]", + "format": "PyTorch Pickle", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "tensors.data.pkl", "source": "https://github.com/lutzroeder/netron/files/13061412/tensors.data.pkl.zip[tensors.data.pkl]", - "format": "PyTorch Pickle Weights", + "format": "PyTorch Pickle", "link": "https://github.com/lutzroeder/netron/issues/720" }, { @@ -5836,7 +5850,7 @@ "type": "pytorch", "target": "yolov5n.tensor.data.pkl", "source": "https://github.com/lutzroeder/netron/files/13064842/yolov5n.tensor.data.pkl.zip[yolov5n.tensor.data.pkl]", - "format": "PyTorch Pickle Weights", + "format": "PyTorch Pickle", "link": "https://github.com/lutzroeder/netron/issues/720" }, {