diff --git a/source/pytorch.js b/source/pytorch.js index fd367e9796..3ff7e1fc08 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -695,14 +695,6 @@ pytorch.Node = class { const value = metadata.type(key); this.type = value ? { ...value } : { name: type }; this.type.identifier = type; - if (this.type.name.indexOf('(') !== -1) { - throw new Error(); - } - if (this.type.name.indexOf('::') !== -1) { - throw new Error(); - } - // [name] = this.type.name.split('('); - // this.type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]; } stack = stack || new Set(); const weights = pytorch.Utility.weights(obj); @@ -729,6 +721,8 @@ pytorch.Node = class { continue; } else if (pytorch.Utility.isInstance(value, 'builtins.list') && Array.isArray(value) && value.length === 0) { continue; + } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) { + continue; } const parameters = new Map(); if ((name === '_parameters' || name === '_buffers') && value instanceof Map && value.size > 0) { @@ -3492,9 +3486,9 @@ pytorch.Utility = class { if (obj instanceof Map) { const entries = Array.from(obj).filter(([name]) => name !== '_metadata'); const names = entries.filter(([name]) => typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)); - if (names.length > 1 && - (names.length / entries.length) >= 0.8 && - entries.every(([, value]) => !pytorch.Utility.isInstance(value, 'builtins.dict') || Array.from(value.values()).every((value) => !pytorch.Utility.isTensor(value)))) { + if (names.length > 1 && (names.length / entries.length) >= 0.8 && + (entries.every(([, value]) => !pytorch.Utility.isInstance(value, 'builtins.dict') || Array.from(value.values()).every((value) => !pytorch.Utility.isTensor(value)))) && + (!entries.every(([, value]) => Array.isArray(value)))) { const modules = new Map(); for (const [name, value] of entries) { const separator = name.indexOf('.') === -1 && name.indexOf('|') !== -1 ? '|' : '.'; diff --git a/test/models.json b/test/models.json index 55824a7cb9..56b0e68bbe 100644 --- a/test/models.json +++ b/test/models.json @@ -5429,6 +5429,13 @@ "format": "PyTorch v1.6", "link": "https://github.com/lutzroeder/netron/issues/133" }, + { + "type": "pytorch", + "target": "InternVideo2-stage2_1b-224p-f4.pt", + "source": "https://github.com/user-attachments/files/17607734/InternVideo2-stage2_1b-224p-f4.pt.zip[InternVideo2-stage2_1b-224p-f4.pt]", + "format": "PyTorch v1.6", + "link": "https://github.com/lutzroeder/netron/issues/720" + }, { "type": "pytorch", "target": "iv3_pertensor.pt",