From 982ff6079c98a50366cad1421ec4c569b65dd99e Mon Sep 17 00:00:00 2001
From: Lutz Roeder
Date: Sat, 14 Dec 2024 17:19:47 -0500
Subject: [PATCH] Update python.js (#1061)

---
 source/python.js  | 157 +++++++++++++++++++++++++++++++++++++++++-----
 source/pytorch.js |  29 +++++----
 2 files changed, 157 insertions(+), 29 deletions(-)

diff --git a/source/python.js b/source/python.js
index 19c33d8710..d8f00cb305 100644
--- a/source/python.js
+++ b/source/python.js
@@ -125,6 +125,7 @@ python.Execution = class {
         this.register('argparse');
         this._enum = this.register('enum');
         this.register('collections');
+        const copy = this.register('copy');
         this.register('copy_reg');
         const ast = this.register('ast');
         this.ast = ast;
@@ -4330,6 +4331,7 @@ python.Execution = class {
         this.registerFunction('collections.defaultdict', (/* default_factory */) => {
             return {};
         });
+        this.registerFunction('copy.deepcopy');
         this.registerFunction('copy_reg._reconstructor', (cls, base, state) => {
             // copyreg._reconstructor in Python 3
             if (base === '__builtin__.object' || base === builtins.object) {
@@ -6980,13 +6982,16 @@ python.Execution = class {
         });
         this.registerType('torch.ClassType', class extends torch.Type {
             constructor(qualified_name, cu, is_module) {
-                super('ClassType', qualified_name);
+                super('ClassType', typeof qualified_name === 'string' ? qualified_name : qualified_name.qualifiedName());
                 this._is_module = is_module;
                 this._attributes = new Map();
                 this._methods = new Map();
                 this._staticmethods = new Map();
                 this._constants = new Map();
             }
+            static create(qualifiedName, cu, is_module /*, doc_string, unresolved_class_attributes */) {
+                return new torch.ClassType(qualifiedName, cu, is_module);
+            }
             qualified_name() {
                 return this.annotation_str;
             }
@@ -7655,7 +7660,7 @@ python.Execution = class {
                     while (L.eat('.')) {
                         name = `${name}.${L.expect('id')}`;
                     }
-                    real_value = new torch.ClassType(name); // getCustomClass
+                    real_value = torch.ClassType.create(name); // getCustomClass
                     fake_value = real_value;
                 } else {
                     real_value = this.parseBaseType();
@@ -8805,6 +8810,8 @@ python.Execution = class {
                 let name = null;
                 if (args.length === 1 && typeof args[0] === 'string') {
                     [name] = args;
+                } else if (args.length === 1 && Array.isArray(args[0]) && args[0].every((arg) => typeof arg === 'string')) {
+                    name = args[0].join('.');
                 } else {
                     name = `${args[0].qualifiedName()}.${args[1]}`;
                 }
@@ -8822,6 +8829,9 @@ python.Execution = class {
             name() {
                 return this._name; // "baz"
             }
+            atoms() {
+                return this._qualifiedName.split('.');
+            }
         });
         this.registerType('torch.jit.SourceImporter', class {
             constructor(cu, constant_table, source_loader, version) {
@@ -8887,7 +8897,7 @@ python.Execution = class {
                 const pre_hook_def_map = new Map();
                 const hook_names = new Set();
                 const hook_def_map = new Map();
-                const class_type = new torch.ClassType(qualified_classname.qualifiedName(), this._cu, is_module);
+                const class_type = torch.ClassType.create(qualified_classname.qualifiedName(), this._cu, is_module);
                 for (const stmt of class_def.body) {
                     if (stmt instanceof ast.Assign || stmt instanceof ast.AnnAssign) {
                         let target = null;
@@ -8940,8 +8950,9 @@ python.Execution = class {
                                     break;
                                 }
                             }
-                        } else if (target instanceof ast.Subscript) {
-                            // not implemented
+                        } else if (target instanceof ast.Subscript && target.value instanceof ast.Name && target.value.id === '__annotations__') {
+                            const name = target.slice.elts[0].value;
+                            attributes.push({ name, value, annotation: stmt.value });
                             continue;
                         } else {
                             throw new python.Error('Unexpected statement kind in module metadata.');
                         }
@@ -9020,11 +9031,11 @@ python.Execution = class {
                 this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
                 this._pickle_dir_prefix = pickle_dir_prefix || '';
                 this._tensor_dir_prefix = tensor_dir_prefix || '';
+                this._constant_table = [];
                 const SourceLoader = (qualifier) => {
                     return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
                 };
-                this._source_importer = new torch.jit.SourceImporter(
-                    this._compilation_unit, this._constants_table, SourceLoader, reader.version());
+                this._source_importer = new torch.jit.SourceImporter(this._compilation_unit, this._constant_table, SourceLoader, reader.version());
             }
             deserialize() {
                 const execution = this._compilation_unit.execution;
@@ -9061,7 +9072,7 @@ python.Execution = class {
                 ];
                 for (const known_type of known_types) {
                     const prefix = new torch.jit.QualifiedName(known_type.name);
-                    const type = new torch.ClassType(known_type.name, this._compilation_unit, false);
+                    const type = torch.ClassType.create(known_type.name, this._compilation_unit, false);
                     for (const known_method of known_type.methods || []) {
                         const schema = new torch.FunctionSchema(known_method);
                         const name = new torch.jit.QualifiedName(prefix, schema.name);
@@ -9078,7 +9089,8 @@ python.Execution = class {
                     execution.builtins.CONSTANTS[`c${i}`] = constants[i];
                 }
                 const module = this.readArchive('data');
-                const type = new torch.ClassType(`${module.__class__.__module__}.${module.__class__.__name__}`, null, true);
+                const name = `${module.__class__.__module__}.${module.__class__.__name__}`;
+                const type = torch.ClassType.create(name, null, true);
                 const result = new torch.ScriptModule(type);
                 result.data = module;
                 return result;
@@ -9101,7 +9113,7 @@ python.Execution = class {
                     ['INT32', 'Int'],
                     ['INT64', 'Long']
                 ]);
-                const constants = (model.tensors || []).map((constant) => {
+                const tensor_table = (model.tensors || []).map((constant) => {
                    const key = constant.data.key;
                     if (!tensorTypeMap.has(constant.dataType)) {
                         throw new python.Error(`Unsupported tensor data type '${constant.dataType}'.`);
@@ -9126,8 +9138,8 @@ python.Execution = class {
                     return tensor;
                 });
                 execution.builtins.CONSTANTS = {};
-                for (let i = 0; i < constants.length; i++) {
-                    execution.builtins.CONSTANTS[`c${i}`] = constants[i];
+                for (let i = 0; i < tensor_table.length; i++) {
+                    execution.builtins.CONSTANTS[`c${i}`] = tensor_table[i];
                 }
                 const attributes = [];
                 if (this._reader.has_record('attributes.pkl')) {
@@ -9137,6 +9149,14 @@ python.Execution = class {
                     const obj = unpickler.load();
                     attributes.push(...obj);
                 }
+
+                this._LEGACY_moduleStack = ['__torch__'];
+                // const module_def = model.mainModule;
+                for (const tensor of tensor_table) {
+                    this._constant_table.push(tensor);
+                }
+                // this.LEGACY_convertModule(module_def);
+
                 while (queue.length > 0) {
                     const module = queue.shift();
                     module.__class__ = module.__class__ || { __module__: 'torch.nn.modules.module', __name__: 'Module' };
@@ -9161,7 +9181,7 @@ python.Execution = class {
                         delete module.arguments;
                     }
                     for (const parameter of parameters) {
-                        const tensor = constants[parameter.tensorId];
+                        const tensor = tensor_table[parameter.tensorId];
                         module[parameter.name] = tensor;
                         parameter.__class__ = parameter.__class__ || { __module__: 'torch', __name__: 'Tensor' };
                     }
@@ -9182,11 +9202,71 @@ python.Execution = class {
                         data.forward = module.forward;
                     }
                 }
-                const class_type = new torch.ClassType(data.name);
+                const class_type = torch.ClassType.create(data.name);
                 const result = new torch.ScriptModule(class_type);
                 result.data = data;
                 return result;
             }
+            LEGACY_convertModule(module_def) {
+                const atoms = new torch.jit.QualifiedName(module_def.name).atoms();
+                const numPushed = atoms.length;
+                for (const atom of atoms) {
+                    const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom;
+                    this._LEGACY_moduleStack.push(sanitized);
+                }
+                const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit);
+                for (const sub_def of module_def.submodules || []) {
+                    const submodule = this.LEGACY_convertModule(sub_def);
+                    module.register_module(sub_def.name, submodule);
+                }
+                for (const param_def of module_def.parameters || []) {
+                    const tensor = this._constant_table[Number(param_def.tensorId)];
+                    if (param_def.isBuffer) {
+                        module.register_buffer(param_def.name, tensor);
+                    } else {
+                        module.register_parameter(param_def.name, tensor, false);
+                    }
+                }
+                // const typeParser = new torch.jit.ScriptTypeParser(this._source_importer);
+                for (const attr_def of module_def.attributes || []) {
+                    if (module.hasattr(attr_def.name)) {
+                        continue;
+                    }
+                    // IValue ivalue;
+                    // if (attr_def.id() >= 0) {
+                    //     ivalue = LEGACY_pickled_ivalues_.at(attr_def.id());
+                    // }
+                    // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
+                }
+                /*
+                std::shared_ptr gen_ranges = nullptr;
+                if (module_def.has_torchscript_debug_arena()) {
+                    auto [data, size] = reader_->getRecord(module_def.torchscript_debug_arena().key());
+                    gen_ranges = std::make_shared(std::move(data), size);
+                }
+                if (module_def.has_torchscript_arena()) {
+                    auto [data, size] =
+                        reader_->getRecord(module_def.torchscript_arena().key());
+                    std::string data_str(static_cast(data.get()), size);
+                    auto src = std::make_shared(std::string(static_cast(data.get()), size), module_def.torchscript_arena().key(), 1, std::move(gen_ranges));
+                    source_importer_.LEGACY_import_methods(module, src);
+                }
+                if (module_def.has_get_state_attribute_id()) {
+                    LEGACY_moduleSetState(module, LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
+                }
+                const ClassTypePtr& module_type = module._ivalue()->type();
+                for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
+                    const IValue& v = module._ivalue()->getSlot(i);
+                    if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
+                        TORCH_CHECK(!v.isNone(), "The field '", module_type->getAttributeName(i), "' was left unitialized after __setstate__, but expected a ", "value of type '", v.type()->repr_str(), "'");
+                    }
+                }
+                */
+                for (let i = 0; i < numPushed; i++) {
+                    this._LEGACY_moduleStack.pop();
+                }
+                return module;
+            }
             readArchive(archive_name) {
                 const type_resolver = null;
                 const obj_loader = null;
@@ -9478,6 +9558,14 @@ python.Execution = class {
             }
         });
         this.registerType('torch.ScriptModule', class extends torch.ScriptObject {
+            constructor(...args) {
+                if (args[0] instanceof torch.jit.QualifiedName && args[1] instanceof torch.jit.CompilationUnit) {
+                    const [class_name, cu, shouldMangle] = args;
+                    super(...torch.ScriptModule.create_module_object(class_name, cu, shouldMangle));
+                } else {
+                    super(...args);
+                }
+            }
             get qualified_name() {
                 return this._type.qualified_name();
             }
@@ -9579,6 +9667,24 @@ python.Execution = class {
                 }
                 return this._graph;
             }
+            static create_module_object(class_name, cu, shouldMangle) {
+                shouldMangle = shouldMangle || false;
+                if (!class_name.prefix()) {
+                    class_name = new torch.jit.QualifiedName('__torch__', class_name.name());
+                }
+                if (shouldMangle && cu.get_class(class_name)) {
+                    class_name = cu.mangle(class_name);
+                }
+                const cls = torch.ClassType.create(class_name, cu, true);
+                cu.register_type(cls);
+                return [cls, cu];
+            }
+            register_module(/* name, module */) {
+            }
+            register_buffer(/* name, buffer */) {
+            }
+            register_parameter(/* name, parameter, is_buffer */) {
+            }
         });
         this.registerType('torch.ModuleDict', class {
             constructor(module) {
@@ -9659,6 +9765,7 @@ python.Execution = class {
                 const graph = new torch.Graph();
                 graph.set_op_version(operator_set_version);
                 const fn = new torch.jit.GraphFunction(name, graph, creator);
+                fn.__ast__ = def;
                 if (shouldMangle && this.find_function(name)) {
                     // name = mangle(name);
                 }
@@ -9845,7 +9952,7 @@ python.Execution = class {
                     cls = this._cu.get_class(new torch.jit.QualifiedName(name));
                     if (!cls) {
                         const torch = this._torch;
-                        cls = new torch.ClassType(name, this._cu, true);
+                        cls = torch.ClassType.create(name, this._cu, true);
                         this._cu.register_type(cls);
                     }
                 } else {
@@ -9862,7 +9969,20 @@ python.Execution = class {
                 return cls;
             }
        });
-        this.registerType('torch.export.unflatten.UnflattenedModule', class extends torch.nn.modules.module.Module {});
+        this.registerType('torch.export.UnflattenedModule', class extends torch.nn.modules.module.Module {
+            constructor(export_module, flat_args_adapter) {
+                super();
+                const export_graph = copy.deepcopy(export_module.graph);
+                this.graph_signature = copy.deepcopy(export_module.graph_signature);
+                this.graph = new torch.fx.Graph();
+                this.graph.owning_module = this;
+                this.module_call_graph = copy.deepcopy(export_module.module_call_graph);
+                this.flat_args_adapter = flat_args_adapter;
+                this.adapted = false;
+                // this._run_with_interpreter = RUN_WITH_INTERPRETER
+                this._inplace_buffer_mutations(export_graph, this.graph_signature);
+            }
+        });
         this.registerType('torch.export.graph_signature.ExportGraphSignature', class {
             constructor(input_specs, output_specs) {
                 this.input_specs = input_specs;
                 this.output_specs = output_specs;
@@ -10017,7 +10137,10 @@ python.Execution = class {
        });
        this.registerType('torch.export.exported_program.ModuleCallEntry', class {});
        this.registerType('torch.export.exported_program.ModuleCallSignature', class {});
-        this.registerFunction('torch.export.unflatten');
+        this.registerFunction('torch.export.unflatten', (module, flat_args_adapter) => {
+            module = torch.export._remove_effect_tokens(module);
+            return new torch.export.UnflattenedModule(module, flat_args_adapter);
+        });
         this.registerFunction('torch._export.exported_program._create_graph_module_for_export', (root, graph) => {
             return new torch.fx.graph_module.GraphModule(root, graph);
         });
diff --git a/source/pytorch.js b/source/pytorch.js
index d828eb9228..6d2848f581 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -1792,6 +1792,12 @@ pytorch.Execution = class extends python.Execution {
         const ast = this.ast;
         const torch = this.torch;
         switch (expr.__class__.__name__) {
+            case 'Name': {
+                if (this.traceAttr && expr.id === 'self') {
+                    return context.get('self');
+                }
+                break;
+            }
             case 'Constant': {
                 if (expr.value === true || expr.value === false) {
                     return this._graph.insertConstant(expr.value);
@@ -2530,7 +2536,7 @@ pytorch.Execution = class extends python.Execution {
                 super.statement(stmt, context);
                 /*
                 const value = context.get(stmt.name);
-                const type = new torch.ClassType(`${value.__module__}.${value.__name__}`);
+                const type = torch.ClassType.create(`${value.__module__}.${value.__name__}`);
                 for (const entry of stmt.body) {
                     if (entry instanceof ast.AnnAssign) {
                         const target = this.identifier(entry.target);
@@ -2660,23 +2666,22 @@ pytorch.Execution = class extends python.Execution {
                 const moduleTarget = this.target(target, context);
                 if (moduleTarget instanceof torch.Value && moduleTarget.type() instanceof torch.ClassType) {
                     const class_type = moduleTarget.type().expect(torch.ClassType);
-                    const method_name = name;
-                    const method = class_type.getMethod(method_name);
-                    const return_type = method.getSchema().returns[0].real_type;
-                    const node = this._graph.create('prim::CallMethod');
-                    this._graph.insertNode(node);
-                    node.s_('name', name);
-                    const inputs = [];
+                    const method = class_type.getMethod(name);
                     const evalArgs = args.map((expression) => this.expression(expression, context));
+                    if (this.traceAttr && method.__ast__) {
+                        return this.apply(method.__ast__, [moduleTarget].concat(evalArgs), context);
+                    }
+                    const schema = method.getSchema();
+                    const return_field_names = [schema.returns[0].name];
+                    const return_types = [schema.returns[0].real_type];
+                    const inputs = [moduleTarget];
                     for (const arg of evalArgs) {
                         const value = this.variable(arg);
                         inputs.push(value);
-                        node.addInput(value);
                     }
-                    node.output().setType(return_type);
+                    const matchedSchema = new torch.jit.MatchedSchema(inputs, return_types, return_field_names, name);
+                    const node = this._graph.insertMethodCall(name, matchedSchema);
                     return node.output();
-                    // const matchedSchema = new torch.jit.MatchedSchema(inputs, return_types, return_field_names, schema_name)
-                    // return this._graph.insertMethodCall(name, matchedSchema);
                 }
                 const prefix = this.identifier(target);
                 if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) {