Skip to content

Commit

Permalink
Update python.js (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Sep 28, 2024
1 parent e7f61ca commit a916328
Show file tree
Hide file tree
Showing 4 changed files with 560 additions and 544 deletions.
33 changes: 18 additions & 15 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -4872,6 +4872,9 @@ python.Execution = class {
get args() {
return this._args;
}
get kwargs() {
return this._kwargs;
}
get next() {
return this._next;
}
Expand Down Expand Up @@ -4951,7 +4954,8 @@ python.Execution = class {
}
placeholder(name, type_expr /*, default_value */) {
const args = []; // () if default_value is inspect.Signature.empty else (default_value,)
return this.create_node('placeholder', name, args, type_expr);
const kwargs = new builtins.dict();
return this.create_node('placeholder', name, args, kwargs, type_expr);
}
create_node(op, target, args, kwargs, name, type_expr) {
args = args || new builtins.tuple();
Expand Down Expand Up @@ -6155,16 +6159,16 @@ python.Execution = class {
this.registerType('torch.TensorType', class extends torch.Type {});
this.registerType('torch.IntType', class extends torch.Type {});
this.registerType('torch.Argument', class {
constructor(name, type, real_type, N, default_value /*, alias_info, is_type_dispatched */) {
constructor(name, type, real_type, N, default_value, kwarg_only, alias_info) {
// torch/aten/src/ATen/core/function_schema.h
this.name = name;
this.type = type;
this.real_type = real_type;
this.N = N;
this.default_value = default_value;
// kwarg_only: bool
// is_out: bool
// alias_info: Optional[AliasInfo]
this.kwarg_only = kwarg_only;
const is_alias = alias_info && alias_info.isWrite();
this.is_out = this.kwarg_only && is_alias;
}
has_default_value() {
return this.default_value !== undefined;
Expand Down Expand Up @@ -6852,12 +6856,10 @@ python.Execution = class {
deserialize_node(serialized_node, target) {
let fx_node = null;
if (this._SYM_BOOL_OPS.has(target) || this._SYM_INT_OPS.has(target)) {
/*
const name = serialized_node.outputs[0].value.as_name;
const args = self.deserialize_sym_op_inputs(serialized_node.inputs);
fx_node = self.graph.create_node('call_function', target, args, {}, name);
self.deserialize_sym_op_outputs(serialized_node, fx_node);
*/
const args = this.deserialize_sym_op_inputs(serialized_node.inputs);
fx_node = this.graph.create_node('call_function', target, args, null, name);
this.deserialize_sym_op_outputs(serialized_node, fx_node);
} else if (builtins.isinstance(target, torch._ops.HigherOrderOperator)) {
// assert(len(serialized_node.outputs) === 1 && serialized_node.outputs[0].type in ('as_tensors', 'as_tensor')), 'Only single tensor output or list of tensor output is supported for higher order operators.')
const [output] = serialized_node.outputs;
Expand Down Expand Up @@ -7016,16 +7018,16 @@ python.Execution = class {
return inputs.map((input) => this.deserialize_input(input.arg));
}
deserialize_inputs(target, serialized_node) {
const schema_args = target._schema.arguments;
const schema_args = this._get_schema_from_target(target).arguments;
const actual_args = new Map(serialized_node.inputs.map((input) => [input.name, this.deserialize_input(input.arg)]));
const args = [];
const kwargs = {};
const args = new builtins.list();
const kwargs = new builtins.dict();
for (const schema_arg of schema_args) {
const is_positional = !schema_arg.has_default_value() && !schema_arg.kwarg_only;
if (is_positional) {
args.push(actual_args.get(schema_arg.name));
} else if (actual_args.has(schema_arg.name)) {
kwargs[schema_arg.name] = actual_args.get(schema_arg.name);
kwargs.set(schema_arg.name, actual_args.get(schema_arg.name));
}
}
return [args, kwargs];
Expand Down Expand Up @@ -7154,6 +7156,7 @@ python.Execution = class {
'call_function',
operator.getitem,
new builtins.tuple([fx_node, idx]),
null,
name,
);
this.sync_fx_node(name, individual_output);
Expand Down Expand Up @@ -7302,7 +7305,7 @@ python.Execution = class {
if (target instanceof torch._ops.OpOverload) {
return target._schema;
}
throw new python.Error(`Cannot find schema for ${target.name}`);
throw new python.Error(`Unsupported schema '${target.name}'.`);
}
_is_single_tensor_return(target) {
const schema = this._get_schema_from_target(target);
Expand Down
Loading

0 comments on commit a916328

Please sign in to comment.