From 9463df3fa3c62ab448d68327faa6af7d3c503490 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 3 Oct 2024 18:22:44 -0700 Subject: [PATCH] Update python.js (#1211) --- source/python.js | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/source/python.js b/source/python.js index 6fd9618eb9..8ebced31f6 100644 --- a/source/python.js +++ b/source/python.js @@ -6080,6 +6080,13 @@ python.Execution = class { } throw new python.Error("Unsupported 'torch.sub' expression type."); }); + this.registerFunction('torch.sym_int'); + this.registerFunction('torch.sym_ite'); + this.registerFunction('torch.sym_max'); + this.registerFunction('torch.sym_min'); + this.registerFunction('torch.sym_not'); + this.registerFunction('torch.sym_sqrt'); + this.registerFunction('torch.sym_sqrt'); this.registerFunction('torch.functional.einsum'); this.registerFunction('torch.functional.norm'); this.registerFunction('torch.functional.split'); @@ -6433,6 +6440,14 @@ python.Execution = class { this.example_inputs = example_inputs; } }); + torch._export.serde.serialize._SYM_INT_OPS = new Set([ + operator.mul, operator.add, operator.sub, operator.floordiv, operator.mod, + torch.sym_sqrt, torch.sym_int, torch.sym_ite, torch.sym_max, torch.sym_min, torch.sym_sqrt + ]); + torch._export.serde.serialize._SYM_BOOL_OPS = new Set([ + operator.eq, operator.ne, operator.le, operator.ge, operator.lt, operator.gt, + torch.sym_not + ]); this.registerType('torch._export.serde.union._Union', class { constructor(obj) { if (obj.$type) { @@ -6818,14 +6833,6 @@ python.Execution = class { this.serialized_name_to_meta = new Map(); this.graph = new torch.fx.Graph(); this.module = new torch.nn.Module(); - this._SYM_INT_OPS = new Set([ - operator.mul, operator.add, operator.sub, operator.floordiv, operator.mod, - torch.sym_sqrt, torch.sym_int, torch.sym_ite, torch.sym_max, torch.sym_min, torch.sym_sqrt - ]); - this._SYM_BOOL_OPS = new Set([ - operator.eq, operator.ne, operator.le, operator.ge, operator.lt, operator.gt, - torch.sym_not - ]); } deserialize_graph_output(output) { if (output.type === 'as_tensor') { @@ -6919,7 +6926,7 @@ python.Execution = class { } deserialize_node(serialized_node, target) { let fx_node = null; - if (this._SYM_BOOL_OPS.has(target) || this._SYM_INT_OPS.has(target)) { + if (torch._export.serde.serialize._SYM_BOOL_OPS.has(target) || torch._export.serde.serialize._SYM_INT_OPS.has(target)) { const name = serialized_node.outputs[0].value.as_name; const args = this.deserialize_sym_op_inputs(serialized_node.inputs); fx_node = this.graph.create_node('call_function', target, args, null, name); @@ -6942,10 +6949,11 @@ python.Execution = class { fx_node = this.graph.create_node('call_function', target, args, kwargs, name); this.deserialize_outputs(serialized_node, fx_node); } else { - // throw new python.Error(`Unsupported target type '${target}'.`); + throw new python.Error(`Unsupported node target type '${target}'.`); } - if (fx_node) { - Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata)); + fx_node.meta.update(this.deserialize_metadata(serialized_node.metadata)); + if (fx_node.op !== 'placeholder' && fx_node.op !== 'output' && !fx_node.meta.has('nn_module_stack')) { + fx_node.meta.set('nn_module_stack', new builtins.dict()); } } deserialize_input_spec(i) { @@ -7263,7 +7271,7 @@ python.Execution = class { } deserialize_metadata(metadata) { const ret = new builtins.dict(); - const stack_trace = metadata.stack_trace; + const stack_trace = metadata.get('stack_trace'); if (stack_trace) { ret.set('stack_trace', stack_trace); } @@ -7288,7 +7296,7 @@ python.Execution = class { } return target; }; - const nn_module_stack_str = metadata.nn_module_stack; + const nn_module_stack_str = metadata.get('nn_module_stack'); if (nn_module_stack_str) { const import_nn_module_stack = (key, path, ty) => { return [key, [path, ty]]; @@ -7296,7 +7304,7 @@ python.Execution = class { const nn_module_stack = new Map(nn_module_stack_str.split(';').map((item) => import_nn_module_stack(...item.split(',')))); ret.set('nn_module_stack', nn_module_stack); } - const source_fn_st_str = metadata.source_fn_stack; + const source_fn_st_str = metadata.get('source_fn_stack'); if (source_fn_st_str) { const source_fn_st = []; for (const source_fn_str of source_fn_st_str.split(';')) { @@ -7305,6 +7313,14 @@ python.Execution = class { } ret.set('source_fn_stack', source_fn_st); } + const torch_fn = metadata.get('torch_fn'); + if (torch_fn) { + ret.set('torch_fn', new builtins.tuple(torch_fn.split(';'))); + } + const custom_str = metadata.get('custom'); + if (custom_str) { + ret.set('custom', JSON.parse(custom_str)); + } return ret; } deserialize_argument_spec(x) {