Skip to content

Commit

Permalink
Update python.js (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 4, 2024
1 parent 0a90132 commit 9463df3
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -6080,6 +6080,13 @@ python.Execution = class {
}
throw new python.Error("Unsupported 'torch.sub' expression type.");
});
this.registerFunction('torch.sym_int');
this.registerFunction('torch.sym_ite');
this.registerFunction('torch.sym_max');
this.registerFunction('torch.sym_min');
this.registerFunction('torch.sym_not');
this.registerFunction('torch.sym_sqrt');
this.registerFunction('torch.sym_sqrt');
this.registerFunction('torch.functional.einsum');
this.registerFunction('torch.functional.norm');
this.registerFunction('torch.functional.split');
Expand Down Expand Up @@ -6433,6 +6440,14 @@ python.Execution = class {
this.example_inputs = example_inputs;
}
});
torch._export.serde.serialize._SYM_INT_OPS = new Set([
operator.mul, operator.add, operator.sub, operator.floordiv, operator.mod,
torch.sym_sqrt, torch.sym_int, torch.sym_ite, torch.sym_max, torch.sym_min, torch.sym_sqrt
]);
torch._export.serde.serialize._SYM_BOOL_OPS = new Set([
operator.eq, operator.ne, operator.le, operator.ge, operator.lt, operator.gt,
torch.sym_not
]);
this.registerType('torch._export.serde.union._Union', class {
constructor(obj) {
if (obj.$type) {
Expand Down Expand Up @@ -6818,14 +6833,6 @@ python.Execution = class {
this.serialized_name_to_meta = new Map();
this.graph = new torch.fx.Graph();
this.module = new torch.nn.Module();
this._SYM_INT_OPS = new Set([
operator.mul, operator.add, operator.sub, operator.floordiv, operator.mod,
torch.sym_sqrt, torch.sym_int, torch.sym_ite, torch.sym_max, torch.sym_min, torch.sym_sqrt
]);
this._SYM_BOOL_OPS = new Set([
operator.eq, operator.ne, operator.le, operator.ge, operator.lt, operator.gt,
torch.sym_not
]);
}
deserialize_graph_output(output) {
if (output.type === 'as_tensor') {
Expand Down Expand Up @@ -6919,7 +6926,7 @@ python.Execution = class {
}
deserialize_node(serialized_node, target) {
let fx_node = null;
if (this._SYM_BOOL_OPS.has(target) || this._SYM_INT_OPS.has(target)) {
if (torch._export.serde.serialize._SYM_BOOL_OPS.has(target) || torch._export.serde.serialize._SYM_INT_OPS.has(target)) {
const name = serialized_node.outputs[0].value.as_name;
const args = this.deserialize_sym_op_inputs(serialized_node.inputs);
fx_node = this.graph.create_node('call_function', target, args, null, name);
Expand All @@ -6942,10 +6949,11 @@ python.Execution = class {
fx_node = this.graph.create_node('call_function', target, args, kwargs, name);
this.deserialize_outputs(serialized_node, fx_node);
} else {
// throw new python.Error(`Unsupported target type '${target}'.`);
throw new python.Error(`Unsupported node target type '${target}'.`);
}
if (fx_node) {
Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata));
fx_node.meta.update(this.deserialize_metadata(serialized_node.metadata));
if (fx_node.op !== 'placeholder' && fx_node.op !== 'output' && !fx_node.meta.has('nn_module_stack')) {
fx_node.meta.set('nn_module_stack', new builtins.dict());
}
}
deserialize_input_spec(i) {
Expand Down Expand Up @@ -7263,7 +7271,7 @@ python.Execution = class {
}
deserialize_metadata(metadata) {
const ret = new builtins.dict();
const stack_trace = metadata.stack_trace;
const stack_trace = metadata.get('stack_trace');
if (stack_trace) {
ret.set('stack_trace', stack_trace);
}
Expand All @@ -7288,15 +7296,15 @@ python.Execution = class {
}
return target;
};
const nn_module_stack_str = metadata.nn_module_stack;
const nn_module_stack_str = metadata.get('nn_module_stack');
if (nn_module_stack_str) {
const import_nn_module_stack = (key, path, ty) => {
return [key, [path, ty]];
};
const nn_module_stack = new Map(nn_module_stack_str.split(';').map((item) => import_nn_module_stack(...item.split(','))));
ret.set('nn_module_stack', nn_module_stack);
}
const source_fn_st_str = metadata.source_fn_stack;
const source_fn_st_str = metadata.get('source_fn_stack');
if (source_fn_st_str) {
const source_fn_st = [];
for (const source_fn_str of source_fn_st_str.split(';')) {
Expand All @@ -7305,6 +7313,14 @@ python.Execution = class {
}
ret.set('source_fn_stack', source_fn_st);
}
const torch_fn = metadata.get('torch_fn');
if (torch_fn) {
ret.set('torch_fn', new builtins.tuple(torch_fn.split(';')));
}
const custom_str = metadata.get('custom');
if (custom_str) {
ret.set('custom', JSON.parse(custom_str));
}
return ret;
}
deserialize_argument_spec(x) {
Expand Down

0 comments on commit 9463df3

Please sign in to comment.