diff --git a/source/pytorch.js b/source/pytorch.js
index 825c1cfebe..1b0f097de4 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -107,16 +107,13 @@ pytorch.Graph = class {
                 const queue = [module.data];
                 while (queue.length > 0) {
                     const module = queue.shift();
-                    if (pytorch.Utility.isInstance(module, '__torch__.torch.classes._nnapi.Compilation')) {
-                        continue;
-                    }
                     for (const [key, obj] of Object.entries(module)) {
                         if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') {
                             if (!Array.isArray(obj) && obj === Object(obj)) {
                                 if (pytorch.Utility.isTensor(obj)) {
                                     const parameter = obj;
                                     parameter.__parent__ = module;
-                                    if (parameter.storage()) {
+                                    if (parameter.storage() && !parameter.__origin__) {
                                         if (parameter.__count__ === undefined || parameter.__count__ === 1) {
                                             initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter));
                                         }
@@ -161,7 +158,7 @@ pytorch.Graph = class {
                     node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
                     continue;
                 }
-                this.nodes.push(new pytorch.Node(metadata, node, initializers, values));
+                this.nodes.push(new pytorch.Node(metadata, null, null, node, initializers, values));
             }
             if (module) {
                 const queue = [module.data];
@@ -169,8 +166,12 @@ pytorch.Graph = class {
                     const module = queue.pop();
                     if (module && !pytorch.Utility.isObject(module)) {
                         if (!module.__hide__ && pytorch.Graph._getParameters(module).size > 0) {
-                            const item = { module };
-                            const node = new pytorch.Node(metadata, item, initializers, values);
+                            for (const [name, obj] of Object.entries(module)) {
+                                if ((obj && obj.__hide__) || (obj !== null && !pytorch.Utility.isTensor(obj)) && typeof obj !== 'boolean' && typeof obj !== 'number' && typeof obj !== 'string') {
+                                    delete module[name];
+                                }
+                            }
+                            const node = new pytorch.Node(metadata, null, null, module, initializers, values);
                             this.nodes.push(node);
                         }
                         const modules = [];
@@ -186,32 +187,20 @@ pytorch.Graph = class {
                     }
                 }
             }
         } else if (pytorch.Utility.isTensor(module)) {
-            const item = { type, obj: { value: module } };
-            const node = new pytorch.Node(metadata, item);
+            const node = new pytorch.Node(metadata, null, type, { value: module });
             this.nodes.push(node);
         } else {
             const weights = this.type === 'weights' ? module : pytorch.Utility.weights(module);
             if (weights) {
                 for (const [name, module] of weights) {
-                    const item = { name, type: 'Weights', obj: module };
-                    const node = new pytorch.Node(metadata, item);
+                    const node = new pytorch.Node(metadata, name, 'Weights', module);
                     this.nodes.push(node);
                 }
             } else {
                 const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
                 for (const module of modules) {
-                    let type = module.__class__ && module.__class__.__module__ && module.__class__.__name__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : null;
-                    if (type === 'torch.jit._script.RecursiveScriptModule' && module._c && module._c.qualified_name) {
-                        type = module._c.qualified_name;
-                    }
-                    if (!type) {
-                        type = this.type === 'weights' ? 'Weights' : 'builtins.dict';
-                    }
-                    const item = {
-                        type,
-                        obj: module
-                    };
-                    const node = new pytorch.Node(metadata, item, {}, values);
+                    const type = this.type === 'weights' ? 'Weights' : null;
+                    const node = new pytorch.Node(metadata, null, type, module, null, values);
                     this.nodes.push(node);
                 }
             }
@@ -260,13 +249,13 @@ pytorch.Value = class {
 
 pytorch.Node = class {
 
-    constructor(metadata, item, initializers, values, stack) {
-        this.name = item.name || '';
+    constructor(metadata, name, type, obj, initializers, values, stack) {
+        this.name = name || '';
         this.nodes = [];
         this.attributes = [];
         this.inputs = [];
         this.outputs = [];
-        const type = (metadata, name) => {
+        const createType = (metadata, name) => {
             if (name instanceof pytorch.nnapi.Graph) {
                 return name;
             }
@@ -305,9 +294,9 @@ pytorch.Node = class {
             return new pytorch.Argument(name, value, type, visible);
         };
         let module = null;
-        if (pytorch.Utility.isInstance(item, 'torch.Node')) {
-            const node = item;
-            this.type = type(metadata, node.kind());
+        if (pytorch.Utility.isInstance(obj, 'torch.Node')) {
+            const node = obj;
+            this.type = createType(metadata, node.kind());
             let match = true;
             let count = 0;
             for (const input of node.inputs()) {
@@ -366,10 +355,10 @@ pytorch.Node = class {
                     if (pytorch.Utility.isObjectType(type)) {
                         const obj = input.value;
                         if (!array && initializers.has(obj)) {
-                            const node = new pytorch.Node(metadata, { name, type, obj }, initializers, values);
+                            const node = new pytorch.Node(metadata, name, type, obj, initializers, values);
                             argument = new pytorch.Argument(name, node, 'object');
                         } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
-                            const node = obj.map((obj) => new pytorch.Node(metadata, { name, type, obj }, initializers, values));
+                            const node = obj.map((obj) => new pytorch.Node(metadata, name, type, obj, initializers, values));
                             argument = new pytorch.Argument(name, node, 'object[]');
                         } else {
                             const identifier = input.unique().toString();
@@ -425,30 +414,20 @@ pytorch.Node = class {
                 const argument = new pytorch.Argument(name, args);
                 this.outputs.push(argument);
             }
-        } else if (item.module) {
-            module = item.module;
-            const type = module.__class__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : 'torch.nn.modules.module.Module';
-            this.type = { name: type };
-            for (const [name, tensor] of pytorch.Graph._getParameters(module)) {
-                const initializer = initializers.get(tensor) || (tensor ? new pytorch.Tensor('', tensor) : null);
-                const value = values.map('', null, initializer || null);
-                this.inputs.push(new pytorch.Argument(name, [value]));
-                if (tensor.__variable__) {
-                    const value = values.map(tensor.__variable__);
-                    const argument = new pytorch.Argument(name, [value]);
-                    this.outputs.push(argument);
+        } else {
+            if (!type) {
+                if (pytorch.Utility.isInstance(obj, 'torch.jit._script.RecursiveScriptModule') && obj._c && obj._c.qualified_name) {
+                    type = obj._c.qualified_name;
+                } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) {
+                    type = `${obj.__module__}.${obj.__name__}`;
+                    obj = {};
+                } else if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
+                    type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
+                } else {
+                    type = 'builtins.object';
                 }
             }
-        } else {
-            this.type = type(metadata, item.type);
-            let obj = item.obj;
-            if (pytorch.Utility.isInstance(obj, 'builtins.function')) {
-                this.type = { name: `${obj.__module__}.${obj.__name__}` };
-                obj = {};
-            }
-            const parameters = new Map();
-            const entries = [];
-            const attributes = new Map();
+            this.type = createType(metadata, type);
             stack = stack || new Set();
             const weights = pytorch.Utility.weights(obj);
             if (weights) {
@@ -457,124 +436,117 @@
                 this.type.name = type;
             } else if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) {
                 // continue
-            } else if (obj && item.type === 'builtins.bytearray') {
+            } else if (obj && pytorch.Utility.isInstance(obj, '__torch__.torch.classes._nnapi.Compilation')) {
+                // continue
+            } else if (obj && type === 'builtins.bytearray') {
                 const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]');
                 this.inputs.push(argument);
             } else if (obj) {
+                const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []);
+                const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
                 for (const [name, value] of list) {
-                    if (name === '__class__' || name === '__hide__') {
+                    if (name === '__class__' || name === '__parent__' || name === '__name__') {
                         continue;
-                    } else if (name === '_parameters' && value instanceof Map) {
-                        for (const [name, parameter] of Array.from(value)) {
-                            parameters.set(name, parameter);
-                        }
-                    } else if (name === '_buffers' && value instanceof Map) {
-                        for (const [name, buffer] of Array.from(value)) {
-                            parameters.set(name, buffer);
+                    } else if (pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && value.size === 0) {
+                        continue;
+                    } else if (pytorch.Utility.isInstance(value, 'builtins.set') && value instanceof Set && value.size === 0) {
+                        continue;
+                    } else if (pytorch.Utility.isInstance(value, 'builtins.list') && Array.isArray(value) && value.length === 0) {
+                        continue;
+                    }
+                    const parameters = new Map();
+                    if ((name === '_parameters' || name === '_buffers') && value instanceof Map && value.size > 0) {
+                        for (const [name, obj] of Array.from(value)) {
+                            parameters.set(name, obj);
                         }
                     } else if (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor))) {
                         parameters.set(name, value);
                     } else if (pytorch.Utility.isTensor(value)) {
                         parameters.set(name, value);
-                    } else if (value && value.__class__ && value.__class__.__module__ === 'collections' && value.__class__.__name__ === 'OrderedDict' &&
-                        value instanceof Map && value.size === 0) {
-                        continue;
-                    } else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'set' &&
-                        value instanceof Set && value.size === 0) {
-                        continue;
-                    } else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'list' &&
-                        Array.isArray(value) && value.length === 0) {
+                    }
+                    if (parameters.size > 0) {
+                        for (const [name, value] of parameters) {
+                            const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)];
+                            const visible = inputs.has(name) ? inputs.get(name).visible || true : true;
+                            const args = list.filter((value) => value !== null && !value.__origin__).map((value) => {
+                                const name = value && value.name ? value.name : '';
+                                const identifier = list.length === 1 && value && value.__name__ ? value.__name__ : name;
+                                let tensor = null;
+                                if (initializers && initializers.has(value)) {
+                                    tensor = initializers.get(value);
+                                } else {
+                                    value = value.__source__ ? value.__source__ : value;
+                                    tensor = value ? new pytorch.Tensor(identifier, value) : null;
+                                }
+                                return new pytorch.Value(identifier, null, null, tensor);
+                            });
+                            const argument = new pytorch.Argument(name, args, null, visible);
+                            this.inputs.push(argument);
+                            if (value && value.__variable__) {
+                                const argument = new pytorch.Argument(name, [values.map(value.__variable__)]);
+                                this.outputs.push(argument);
+                            }
+                        }
                         continue;
-                    } else {
-                        entries.push([name, value]);
                     }
-                }
-            }
-            for (const [name, value] of entries) {
-                if (!parameters.has(name)) {
-                    attributes.set(name, value);
-                }
-            }
-            const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []);
-            for (const [name, value] of parameters) {
-                const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)];
-                const visible = inputs.has(name) ? inputs.get(name).visible || true : true;
-                const values = list.filter((value) => value !== null).map((value) => {
-                    const name = value && value.name ? value.name : '';
-                    const identifier = list.length === 1 && value && value.__name__ ? value.__name__ : name;
-                    const tensor = value ? new pytorch.Tensor(identifier, value) : null;
-                    return new pytorch.Value(identifier, null, null, tensor);
-                });
-                const argument = new pytorch.Argument(name, values, null, visible);
-                this.inputs.push(argument);
-            }
-            for (const [name, value] of attributes) {
-                const type = this.type.identifier;
-                if (pytorch.Utility.isTensor(value)) {
-                    const tensor = new pytorch.Tensor('', value);
-                    const argument = new pytorch.Argument(name, tensor, 'tensor');
-                    this.inputs.push(argument);
-                } else if (value && pytorch.Utility.isInstance(value, 'torch.dtype')) {
-                    const node = new pytorch.Node(metadata, { type: value.toString() });
-                    const argument = new pytorch.Argument(name, node, 'object');
-                    this.inputs.push(argument);
-                } else if (Array.isArray(value) && value.some((value) => pytorch.Utility.isTensor(value)) && value.every((value) => pytorch.Utility.isTensor(value) || value === null)) {
-                    const tensors = value.map((value) => value === null ? value : new pytorch.Tensor('', value));
-                    const argument = new pytorch.Argument(name, tensors, 'tensor[]');
-                    this.inputs.push(argument);
-                } else if (pytorch.Utility.isInstance(value, 'numpy.ndarray') || pytorch.Utility.isInstance(value, 'numpy.matrix')) {
-                    const tensor = new numpy.Tensor(value);
-                    const argument = new pytorch.Argument(name, tensor, 'tensor');
-                    this.inputs.push(argument);
-                } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
-                    const argument = new pytorch.Argument(name, value, 'string[]');
-                    this.inputs.push(argument);
-                } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
-                    const argument = new pytorch.Argument(name, value, 'attribute');
-                    this.inputs.push(argument);
-                } else if (name === '_modules' && value && value.__class__ && value.__class__.__module__ === 'collections' && value.__class__.__name__ === 'OrderedDict' &&
-                    value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
-                    const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
-                        stack.add(value);
-                        const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
-                        const node = new pytorch.Node(metadata, { name, type, obj });
-                        stack.delete(value);
-                        return node;
-                    });
-                    const argument = new pytorch.Argument(name, values, 'object[]');
-                    this.inputs.push(argument);
-                } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) {
-                    const argument = new pytorch.Argument(name, value, 'attribute');
-                    this.inputs.push(argument);
-                } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
-                    const values = value.filter((value) => !stack.has(value));
-                    const nodes = values.map((value) => {
+                    const type = this.type.identifier;
+                    if (pytorch.Utility.isTensor(value)) {
+                        const tensor = new pytorch.Tensor('', value);
+                        const argument = new pytorch.Argument(name, tensor, 'tensor');
+                        this.inputs.push(argument);
+                    } else if (value && pytorch.Utility.isInstance(value, 'torch.dtype')) {
+                        const node = new pytorch.Node(metadata, null, value.toString(), {});
+                        const argument = new pytorch.Argument(name, node, 'object');
+                        this.inputs.push(argument);
+                    } else if (Array.isArray(value) && value.some((value) => pytorch.Utility.isTensor(value)) && value.every((value) => pytorch.Utility.isTensor(value) || value === null)) {
+                        const tensors = value.map((value) => value === null ? value : new pytorch.Tensor('', value));
+                        const argument = new pytorch.Argument(name, tensors, 'tensor[]');
+                        this.inputs.push(argument);
+                    } else if (pytorch.Utility.isInstance(value, 'numpy.ndarray') || pytorch.Utility.isInstance(value, 'numpy.matrix')) {
+                        const tensor = new numpy.Tensor(value);
+                        const argument = new pytorch.Argument(name, tensor, 'tensor');
+                        this.inputs.push(argument);
+                    } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
+                        const argument = new pytorch.Argument(name, value, 'string[]');
+                        this.inputs.push(argument);
+                    } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
+                        const argument = new pytorch.Argument(name, value, 'attribute');
+                        this.inputs.push(argument);
+                    } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') &&
+                        value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
+                        const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
+                            stack.add(value);
+                            const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
+                            const node = new pytorch.Node(metadata, name, type, obj);
+                            stack.delete(value);
+                            return node;
+                        });
+                        const argument = new pytorch.Argument(name, values, 'object[]');
+                        this.inputs.push(argument);
+                    } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) {
+                        const argument = new pytorch.Argument(name, value, 'attribute');
+                        this.inputs.push(argument);
+                    } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
+                        const list = value.filter((value) => !stack.has(value));
+                        const nodes = list.map((value) => {
+                            stack.add(value);
+                            const node = new pytorch.Node(metadata, null, null, value, initializers, values, stack);
+                            stack.delete(value);
+                            return node;
+                        });
+                        const argument = new pytorch.Argument(name, nodes, 'object[]');
+                        this.inputs.push(argument);
+                    } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
                         stack.add(value);
-                        const item = {
-                            type: value.__class__ ? `${value.__class__.__module__}.${value.__class__.__name__}` : 'builtins.object',
-                            obj: value
-                        };
-                        const node = new pytorch.Node(metadata, item, initializers, values, stack);
                         stack.delete(value);
-                        return node;
-                    });
-                    const argument = new pytorch.Argument(name, nodes, 'object[]');
-                    this.inputs.push(argument);
-                } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
-                    stack.add(value);
-                    const item = {
-                        type: value.__class__ ? `${value.__class__.__module__}.${value.__class__.__name__}` : 'builtins.object',
-                        obj: value
-                    };
-                    const node = new pytorch.Node(metadata, item, initializers, values, stack);
-                    stack.delete(value);
-                    const visible = name === '_metadata' && pytorch.Utility.isMetadataObject(value) ? false : true;
+                        const argument = new pytorch.Argument(name, node, 'object', visible);
+                        this.inputs.push(argument);
+                    } else {
+                        const argument = createAttribute(metadata.attribute(type, name), name, value);
+                        this.inputs.push(argument);
+                    }
                 }
             }
         }
@@ -2941,7 +2913,7 @@ pytorch.jit.ScriptModuleDeserializer = class {
             execution.builtins.CONSTANTS[`c${i}`] = constants[i];
         }
         const module = this.readArchive('data');
-        const result = new torch.ScriptModule();
+        const result = new torch.ScriptModule(`${module.__class__.__module__}.${module.__class__.__name__}`);
         result.data = module;
         return result;
     }
diff --git a/test/models.json b/test/models.json
index 6cced7ce04..189a7863eb 100644
--- a/test/models.json
+++ b/test/models.json
@@ -4923,6 +4923,7 @@
         "target": "bad-base_libri.pt",
         "source": "https://github.com/user-attachments/files/16401716/base_libri.pt.zip[base_libri.pt]",
         "format": "PyTorch v1.6",
+        "assert": "model.graphs[0].nodes[0].inputs[7].value.inputs[0].value.type.name == 'builtins.object'",
         "link": "https://github.com/lutzroeder/netron/issues/543"
     },
     {
@@ -5372,7 +5373,7 @@
         "target": "mnist_linear_dynamic_quantized.pt",
         "source": "https://github.com/lutzroeder/netron/files/4774023/mnist_linear_dynamic_quantized.zip[mnist_linear_dynamic_quantized.pt]",
         "format": "PyTorch v0.1.10",
-        "assert": "model.graphs[0].nodes[0].inputs[3].value.type.name == 'torch.qint8'",
+        "assert": "model.graphs[0].nodes[0].inputs[2].value.type.name == 'torch.qint8'",
         "link": "https://github.com/lutzroeder/netron/issues/519"
     },
     {
@@ -5879,7 +5880,7 @@
         "target": "rng_state.pth",
        "source": "https://github.com/user-attachments/files/16401709/rng_state.pth.zip[rng_state.pth]",
         "format": "PyTorch v1.6",
-        "assert": "model.graphs[0].nodes[0].inputs[3].name == 'numpy'",
+        "assert": "model.graphs[0].nodes[0].inputs[1].name == 'numpy'",
         "link": "https://github.com/lutzroeder/netron/issues/543"
     },
     {
@@ -6266,7 +6267,6 @@
         "target": "v1_lj_8000.jit",
         "source": "https://github.com/user-attachments/files/16041474/v1_lj_8000.jit.zip[v1_lj_8000.jit]",
         "format": "TorchScript v1.6",
-        "error": "Cannot read properties of undefined (reading 'position')",
         "link": "https://github.com/lutzroeder/netron/issues/1061"
     },
     {