diff --git a/source/python.js b/source/python.js index bdbdc9c9f4..b327f75756 100644 --- a/source/python.js +++ b/source/python.js @@ -6349,20 +6349,39 @@ python.Execution = class { } }); this.registerType('torch.TupleType', class extends torch.Type { - constructor(elements) { - super('TupleType'); + constructor(elements, annotation_str, schema) { + super('TupleType', annotation_str); this._elements = elements; + this._schema = schema; } static get(elements) { return new torch.TupleType(elements); } + static createNamed(qualified_name, field_names, field_types /*, field_defaults */) { + const args = []; + for (let i = 0; i < field_names.length; i++) { + const arg = new torch.Argument(field_names[i], field_types[i], field_types[i]); + args.push(arg); + } + const schema = new torch.FunctionSchema(qualified_name, args); + return new torch.TupleType(field_types, qualified_name, schema); + } elements() { return this._elements; } + schema() { + return this._schema; + } str() { + if (this._schema) { + return `NamedTuple(...)`; + } return `(${this.elements().map((elem) => elem.str()).join(', ')})`; } __str__() { + if (this.annotation_str) { + return this.annotation_str; + } return `Tuple[${this.elements().map((elem) => elem.__str__()).join(', ')}]`; } }); @@ -7074,11 +7093,11 @@ python.Execution = class { const index = name.indexOf('('); if (index === -1) { this._name = name; - this._overload_name = overload_name; - this._arguments = args; - this._returns = returns; - this._is_vararg = is_vararg; - this._is_varret = is_varret; + this._overload_name = overload_name || ''; + this._arguments = args || []; + this._returns = returns || []; + this._is_vararg = is_vararg || false; + this._is_varret = is_varret || false; } else { const value = name.substring(0, index).trim(); const dot = value.indexOf('.'); @@ -7689,22 +7708,32 @@ python.Execution = class { this.register('torch.jit._script'); this.register('torch.jit._trace'); this.registerType('torch.jit.Source', class { - constructor(text) { - this._text = text; + constructor(text_view, filename) { + this._text_view = text_view; + this._filename = filename; + } + text_str() { + return this._text_view; + } + filename() { + return this._filename; } }); - this.registerType('torch.jit.SourceLoader', class { - constructor(reader, code_prefix) { - this._reader = reader; - this._code_prefix = code_prefix; + this.registerType('torch.jit.QualifiedName', class { + constructor(name) { + const index = name.lastIndexOf('.'); + this._qualifiedName = name; + this._prefix = index === -1 ? '' : name.substring(0, index); + this._name = index === -1 ? name : name.substring(index + 1); } - loadSource(qualifier) { - const path = `${this._code_prefix}/${qualifier}.py`; - if (this._reader.has_record(path)) { - const data = this._reader.get_record(path); - return new torch.jit.Source(data); - } - return null; + qualifiedName() { + return this._qualifiedName; // "foo.bar.baz" + } + prefix() { + return this._prefix; // "foo.bar" + } + name() { + return this._name; // "baz" } }); this.registerType('torch.jit.SourceImporter', class { @@ -7713,17 +7742,93 @@ python.Execution = class { this._constant_table = constant_table; this._source_loader = source_loader; this._version = version; + this._loaded_sources = new Set(); + this._to_be_defined = new Map(); } loadType(/* name */) { // } resolveType(name) { - return this.findNamedType(new torch.jit.QualifiedName(name)); + name = new torch.jit.QualifiedName(name); + return this.findNamedType(name); } findNamedType(name) { + // if (auto custom_class = getCustomClass(name.qualifiedName())) { + // return custom_class; + // } this.parseSourceIfNeeded(name.prefix()); + const key = name.qualifiedName(); + const it = this._to_be_defined.get(name.qualifiedName()); + if (it && it.type === 'class') { + this._to_be_defined.delete(key); + this.importNamedType(name.prefix(), it); + } + return this._cu.get_type(name); + } + importNamedType(qualifier, class_def) { + const qualified_name = new torch.jit.QualifiedName(`${qualifier}.${class_def.name}`); + if (class_def.bases.length === 0) { + return this.importClass(qualified_name, class_def, false); + } + const superclass_name = class_def.bases[0].value; + if (superclass_name === 'Module') { + return this.importClass(qualified_name, class_def, true); + } else if (superclass_name === 'NamedTuple') { + return this.importNamedTuple(qualified_name, class_def); + } else if (superclass_name === 'Interface') { + // cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=false); + return null; + } else if (superclass_name === 'ModuleInterface') { + // cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=true); + return null; + } else if (superclass_name === 'Enum') { + // importEnum(qualified_name, class_def); + return null; + } + throw new python.Error('TorchScript does not support class inheritance.'); + } + importClass(/* qualified_name, class_def, is_module */) { + return null; + } + importNamedTuple(qualified_name, named_tuple_def) { + const field_names = []; + const field_types = []; + const field_defaults = []; + for (const statement of named_tuple_def.body.statements) { + if (statement.type !== 'var') { + throw new python.Error('Unexpected statement in NamedTuple body.'); + } + field_names.push(statement.name); + field_types.push(this._cu.execution.type(statement.variableType)); + } + const tt = torch.TupleType.createNamed(qualified_name.qualifiedName(), field_names, field_types, field_defaults); + this._cu.register_type(tt); } - parseSourceIfNeeded(/* qualifier */) { + parseSourceIfNeeded(qualifier) { + if (!qualifier || this._loaded_sources.has(qualifier)) { + return; + } + this._loaded_sources.add(qualifier); + const src = this._source_loader(qualifier); + if (!src) { + return; + } + const program = this._cu.execution.parse(src.filename(), src.text_str(), null); + for (const statement of program.body) { + switch (statement.type) { + case 'def': { + break; + } + case 'class': { + const name = `${qualifier}.${statement.name}`; + this._to_be_defined.set(name, statement); + break; + } + default: { + break; + } + } + } } }); this.registerType('torch.jit.ScriptModuleDeserializer', class { @@ -7734,12 +7839,15 @@ python.Execution = class { this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/'; this._pickle_dir_prefix = pickle_dir_prefix || ''; this._tensor_dir_prefix = tensor_dir_prefix || ''; + const SourceLoader = (qualifier) => { + return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier); + }; this._source_importer = new torch.jit.SourceImporter( - this._compilation_unit, this._constants_table, - new torch.jit.SourceLoader(this._reader, this._code_prefix), reader.version()); + this._compilation_unit, this._constants_table, SourceLoader, reader.version()); } deserialize() { const execution = this._compilation_unit.execution; + execution._resolver = this._source_importer; const code_prefix = this._code_prefix; for (const name of this._reader.get_all_records()) { if (name.startsWith(code_prefix) && name.endsWith('.py')) { @@ -7914,6 +8022,17 @@ python.Execution = class { }; return unpickler.load(); } + qualifierToArchivePath(qualifier, export_prefix) { + return `${export_prefix}${qualifier.replace(/\./g, '/')}.py`; + } + findSourceInArchiveFromQualifier(reader, export_prefix, qualifier) { + const path = this.qualifierToArchivePath(qualifier, export_prefix); + if (!reader.has_record(path)) { + return null; + } + const data = reader.get_record(path); + return new torch.jit.Source(data.peek(), path); + } }); this.registerType('torch.package.PackageImporter', class { constructor(reader) { @@ -8215,6 +8334,9 @@ python.Execution = class { this._functions = new Map(); this._classes = new Map(); } + register_type(namedType) { + this._classes.set(namedType.annotation_str, namedType); + } register_function(fn) { this._functions.set(fn.name, fn); } @@ -8228,14 +8350,11 @@ python.Execution = class { } } get_type(name) { - return this._classes.get(name); + return this._classes.get(name.qualifiedName()); } get_class(name) { return this.get_type(name); } - register_type(name, cls) { - this._classes.set(name, cls); - } }); this.registerType('torch.jit._script.ScriptModule', class extends torch.nn.modules.module.Module {}); this.registerType('torch.jit._trace.TracedModule', class extends torch.jit._script.ScriptModule {}); @@ -8399,7 +8518,7 @@ python.Execution = class { if (!cls) { const name = obj_type.type_name; if (name.startsWith('__torch__') || name.startsWith('torch.jit')) { - cls = this._cu.get_class(name); + cls = this._cu.get_class(new torch.jit.QualifiedName(name)); if (!cls) { const torch = this._torch; cls = new torch.ClassType(name, this._cu, true); @@ -10247,13 +10366,6 @@ python.Execution = class { return this._builtins; } - source(file) { - return this._sources.has(file) ? this._sources.get(file) : null; - } - - debug(/* file */) { - } - exec(code , context) { const reader = new python.Parser(code, '', null); const program = reader.parse(); @@ -10263,21 +10375,35 @@ python.Execution = class { this.block(program.body, context); } - parse(file) { + debug(/* file */) { + } + + source(file) { + if (this._sources.has(file)) { + return this._sources.get(file); + } + return null; + } + + read(file) { const buffer = this.source(file); if (buffer) { const debug = this.debug(file); - const code = this._utf8Decoder.decode(buffer); - const parser = new python.Parser(code, file, debug); - const program = parser.parse(); - if (!program) { - throw new python.Error(`Module '${file}' parse error.`); - } - return program; + return this.parse(file, buffer, debug); } return null; } + parse(file, buffer, debug) { + const code = this._utf8Decoder.decode(buffer); + const parser = new python.Parser(code, file, debug); + const program = parser.parse(); + if (!program) { + throw new python.Error(`Module '${file}' parse error.`); + } + return program; + } + import(name, current, level) { if (level) { let bits = current.split('.'); @@ -10303,7 +10429,7 @@ python.Execution = class { const path = name.split('.').join('/'); module.__path__ = [path]; const file = `${path}.py`; - const program = this.parse(file); + const program = this.read(file); if (program) { module.__file__ = file; for (const [name, value] of Object.entries(this.builtins)) { diff --git a/source/pytorch.js b/source/pytorch.js index 1b8e72ed74..f7ca46ba90 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -2639,6 +2639,24 @@ pytorch.Execution = class extends python.Execution { } return node.addOutput(); } + const prefix = pytorch.Utility.target(target); + if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) { + const identifier = `${prefix}.${name}`; + const type = this._resolver.resolveType(identifier); + if (type instanceof torch.TupleType) { + const node = this._graph.create('prim::TupleConstruct'); + node.setSourceRange(location); + this.graph.insertNode(node); + const evalArgs = args.map((expression) => this.expression(expression, context)); + for (const arg of evalArgs) { + const value = this.variable(arg); + node.addInput(value); + } + const output = node.addOutput(); + output.setType(type); + return output; + } + } return super.call(target, name, args, context); } const [schema, evalArgs] = overload; diff --git a/test/models.json b/test/models.json index bcd608aa91..8daca558f1 100644 --- a/test/models.json +++ b/test/models.json @@ -5866,7 +5866,7 @@ "type": "pytorch", "target": "pyg_model.pt", "source": "https://github.com/lutzroeder/netron/files/10369483/pyg_model.zip[pyg_model.pt]", - "error": "Unknown function 'aten::linear'.", + "error": "Cannot read properties of undefined (reading 'str')", "link": "https://github.com/lutzroeder/netron/issues/546" }, {