Skip to content

Commit

Permalink
Update pytorch.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 11, 2024
1 parent 1e1018d commit 6217307
Showing 1 changed file with 136 additions and 66 deletions.
202 changes: 136 additions & 66 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,24 @@ python.Execution = class {
if (decorator_list) {
decorator_list = Array.from(decorator_list);
}
const bases = this._tokenizer.peek().type === '(' ? this._arguments() : [];
const bases = [];
if (this._tokenizer.eat('(')) {
while (!this._tokenizer.eat(')')) {
if (this._tokenizer.eat('\n')) {
continue;
}
const expression = this._expression(-1, [], false);
if (expression === null) {
throw new python.Error(`Expected expression ${this._tokenizer.location()}`);
}
bases.push(expression);
if (!this._tokenizer.eat(',')) {
this._tokenizer.eat('\n');
this._tokenizer.expect(')');
break;
}
}
}
this._tokenizer.expect(':');
const body = this._suite();
const node = new ast.ClassDef(name.id, bases, null, body, decorator_list, null);
Expand All @@ -776,7 +793,7 @@ python.Execution = class {
decorator_list = Array.from(decorator_list);
}
this._tokenizer.expect('(');
const args = this._parameters(')');
const args = this._arguments(')');
let returns = null;
if (this._tokenizer.eat('->')) {
returns = this._type();
Expand Down Expand Up @@ -1153,7 +1170,7 @@ python.Execution = class {
}
position = this._eat('id', 'lambda');
if (position) {
const args = this._parameters(':');
const args = this._arguments(':');
const body = this._expression(-1, terminal, false);
const node = new ast.Lambda(args, body);
this._mark(node, position);
Expand Down Expand Up @@ -1195,24 +1212,37 @@ python.Execution = class {
continue;
}
if (this._tokenizer.peek().type === '(') {
const position = this._position();
const args = [];
this._tokenizer.expect('(');
while (!this._tokenizer.eat(')')) {
if (this._tokenizer.eat('\n')) {
continue;
}
const expression = this._expression(-1, [], false);
if (expression === null) {
throw new python.Error(`Expected expression ${this._tokenizer.location()}`);
}
args.push(expression);
if (!this._tokenizer.eat(',')) {
this._tokenizer.eat('\n');
this._tokenizer.expect(')');
break;
}
}
if (stack.length === 0) {
const position = this._position();
const args = this._arguments();
if (args.length === 1) {
stack.push(args[0]);
[node] = args;
} else {
node = new ast.Tuple(args);
node = this._mark(node, position);
stack.push(node);
this._mark(node, position);
}
} else {
const location = this._tokenizer.location();
const func = stack.pop();
const args = this._arguments();
node = new ast.Call(func, args);
node.location = location;
stack.push(node);
this._mark(node, position);
}
stack.push(node);
continue;
}
if (this._tokenizer.peek().type === '[') {
Expand Down Expand Up @@ -1483,64 +1513,97 @@ python.Execution = class {
}
return null;
}
_parameters(terminal) {
const list = [];
_arguments(terminal) {
let posonlyargs = [];
let args = [];
let vararg = null;
const kwonlyargs = [];
const kw_defaults = [];
let kwarg = null;
const defaults = [];
let is_slash = false;
let is_vararg = false; // '*'
let is_kwarg = false; // '**'
while (!this._tokenizer.eat(terminal)) {
this._tokenizer.eat('\n');
if (this._tokenizer.eat('(')) {
list.push(this._parameters(')'));
} else {
const node = this._node('parameter');
if (this._tokenizer.eat('/')) {
node.name = '/';
} else {
if (this._tokenizer.eat('**')) {
node.annotation = '**';
}
if (this._tokenizer.eat('*')) {
node.annotation = '*';
}
const name = this._name();
if (name === null) {
throw new python.Error(`Expected parameter ${this._tokenizer.location()}`);
}
node.name = name.id;
if (terminal !== ':' && this._tokenizer.eat(':')) {
node.annotation = this._type();
if (this._tokenizer.eat('/')) {
if (is_slash || is_vararg || is_kwarg) {
throw new python.Error(`Invalid '/' in arguments ${this._tokenizer.location()}`);
}
is_slash = true;
if (!this._tokenizer.eat(',')) {
this._tokenizer.expect(terminal);
break;
}
continue;
}
if (this._tokenizer.eat('**')) {
if (is_kwarg) {
throw new python.Error(`Multiple '**' arguments ${this._tokenizer.location()}`);
}
is_kwarg = true;
const name = this._name(true);
const annotation = terminal !== ':' && this._tokenizer.eat(':') ? this._type() : null;
kwarg = new ast.arg(name.id, annotation, null);
if (!this._tokenizer.eat(',')) {
this._tokenizer.expect(terminal);
break;
}
continue;
}
if (this._tokenizer.eat('*')) {
if (is_vararg) {
throw new python.Error(`Multiple '*' arguments ${this._tokenizer.location()}`);
}
is_vararg = true;
const name = this._name(false);
if (name) {
const annotation = terminal !== ':' && this._tokenizer.eat(':') ? this._type() : null;
vararg = new ast.arg(name.id, annotation, null);
}
if (!this._tokenizer.eat(',')) {
this._tokenizer.expect(terminal);
break;
}
continue;
}
const name = this._name();
if (!name) {
this._tokenizer.expect(terminal);
break;
}
const annotation = terminal !== ':' && this._tokenizer.eat(':') ? this._type() : null;
const arg = new ast.arg(name.id, annotation, null);
const default_value = this._tokenizer.eat('=') ? this._expression() : null;
if (!is_vararg && !is_kwarg) {
if (is_slash) {
args.push(arg);
if (default_value !== null) {
defaults.push(default_value);
}
if (this._tokenizer.eat('=')) {
node.initializer = this._expression();
} else {
posonlyargs.push(arg);
if (default_value !== null) {
defaults.push(default_value);
}
}
list.push(node);
} else if (is_vararg && !is_kwarg) {
kwonlyargs.push(arg);
kw_defaults.push(default_value);
} else {
throw new python.Error(`Argument after '**' parameter ${this._tokenizer.location()}`);
}
this._tokenizer.eat('\n');
if (!this._tokenizer.eat(',')) {
this._tokenizer.expect(terminal);
break;
}
}
return list;
}
_arguments() {
const list = [];
this._tokenizer.expect('(');
while (!this._tokenizer.eat(')')) {
if (this._tokenizer.eat('\n')) {
continue;
}
const expression = this._expression(-1, [], false);
if (expression === null) {
throw new python.Error(`Expected expression ${this._tokenizer.location()}`);
}
list.push(expression);
if (!this._tokenizer.eat(',')) {
this._tokenizer.eat('\n');
this._tokenizer.expect(')');
break;
}
if (!is_slash) {
args = posonlyargs.concat(args);
posonlyargs = [];
}
return list;
return new ast.arguments(posonlyargs, args, vararg, kwonlyargs, kw_defaults, kwarg, defaults);
}
_node(type) {
const node = {};
Expand Down Expand Up @@ -8009,7 +8072,7 @@ python.Execution = class {
}
parseArgsFromDecl(decl, skip_self) {
const retval = [];
const params = skip_self ? decl.args.slice(1) : decl.args.slice();
const params = skip_self ? decl.args.args.slice(1) : decl.args.args.slice();
for (const decl_arg of params) {
const N = null;
const default_value = null;
Expand Down Expand Up @@ -9453,9 +9516,9 @@ python.Execution = class {
args.push(this.data); // self
}
if (this.data.forward.__code__ && this.data.forward.__code__.args) {
for (const arg of this.data.forward.__code__.args) {
if (execution.traceAttr || arg.name !== 'self') {
const value = execution.graph.addInput(arg.name);
for (const arg of this.data.forward.__code__.args.args) {
if (execution.traceAttr || arg.arg !== 'self') {
const value = execution.graph.addInput(arg.arg);
value.setType(execution.type(arg.annotation));
if (isTensor(value)) {
value.__variable__ = arg.name;
Expand Down Expand Up @@ -11932,12 +11995,19 @@ python.Execution = class {
apply(method, args, context) {
const locals = Array.prototype.slice.call(args);
context = new python.Execution.Context(context.globals, {});
for (const argument of method.args) {
let value = locals.shift();
if (value === undefined && argument.initializer) {
value = this.expression(argument.initializer, context);
args = method.args.posonlyargs.concat(method.args.args);
const default_pos = args.length - method.args.defaults.length;
for (let i = 0; i < method.args.args.length; i++) {
const arg = method.args.args[i];
let value = null;
if (locals.length > 0) {
value = locals.shift();
} else if (i >= default_pos) {
value = this.expression(method.args.defaults[i - default_pos], context);
} else {
throw new python.Error('Missing required positional argument.');
}
context.set(argument.name, value);
context.set(arg.arg, value);
}
return this.block(method.body, context);
}
Expand Down

0 comments on commit 6217307

Please sign in to comment.