Skip to content

Commit

Permalink
Update python.js (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Dec 20, 2024
1 parent d774780 commit d66c5bf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
24 changes: 19 additions & 5 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -8428,7 +8428,7 @@ python.Execution = class {
insertNode(node) {
return node.insertBefore(this._insert_before);
}
insertConstant(val) {
insertConstant(val, loc) {
const n = this.create('prim::Constant');
this.insertNode(n);
let type = null;
Expand All @@ -8442,22 +8442,29 @@ python.Execution = class {
n.ss_('value', val);
type = torch.ListType.create(torch.StringType.get());
} else if (typeof val === 'boolean') {
// return value;
n.i_('value', val === true ? 1 : 0);
type = torch.BoolType.get();
} else if (Number.isInteger(val)) {
n.i_('value', val);
type = torch.IntType.get();
} else if (typeof val === 'number') {
// return value;
n.f_('value', val);
type = torch.FloatType.get();
} else if (val instanceof torch.Tensor) {
n.t_('value', val);
type = torch.TensorType.get();
} else if (val instanceof torch.ScriptObject) {
n.ival_('value', val);
type = val.type();
} else {
throw new python.Error(`Unsupported value type '${typeof value}'.`);
}
if (type) {
n.output().setType(type);
}
if (loc) {
n.setSourceRange(loc);
}
return n.output();
}
insertMethodCall(method_name, matched) {
Expand Down Expand Up @@ -8768,6 +8775,12 @@ python.Execution = class {
f(name) {
return this._values.get(name)[0];
}
t_(name, value) {
this._values.set(name, [value, 't']);
}
t(name) {
return this._values.get(name)[0];
}
tys_(name, value) {
this._values.set(name, [value, 'tys']);
}
Expand Down Expand Up @@ -8860,9 +8873,10 @@ python.Execution = class {
this._type = type;
}
set value(value) { // remove
if (value instanceof torch.Value === false) {
this._value = value;
if (value instanceof torch.Value) {
throw new python.Error('Value cannot be a value.');
}
this._value = value;
}
get value() { // remove
return this._value;
Expand Down
10 changes: 7 additions & 3 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -2914,9 +2914,13 @@ pytorch.Execution = class extends python.Execution {
}
throw new pytorch.Error();
}
const value = this.variable(v);
value.value = v;
node.addInput(value);
if (v instanceof torch.Value) {
node.addInput(v);
} else {
const value = this.variable(v);
value.value = v;
node.addInput(value);
}
}
}
for (const arg of schema.returns) {
Expand Down

0 comments on commit d66c5bf

Please sign in to comment.