diff --git a/source/python.js b/source/python.js index fa8ada0216..f0167d7cc7 100644 --- a/source/python.js +++ b/source/python.js @@ -8428,7 +8428,7 @@ python.Execution = class { insertNode(node) { return node.insertBefore(this._insert_before); } - insertConstant(val) { + insertConstant(val, loc) { const n = this.create('prim::Constant'); this.insertNode(n); let type = null; @@ -8442,22 +8442,29 @@ python.Execution = class { n.ss_('value', val); type = torch.ListType.create(torch.StringType.get()); } else if (typeof val === 'boolean') { - // return value; n.i_('value', val === true ? 1 : 0); type = torch.BoolType.get(); } else if (Number.isInteger(val)) { n.i_('value', val); type = torch.IntType.get(); } else if (typeof val === 'number') { - // return value; n.f_('value', val); type = torch.FloatType.get(); + } else if (val instanceof torch.Tensor) { + n.t_('value', val); + type = torch.TensorType.get(); + } else if (val instanceof torch.ScriptObject) { + n.ival_('value', val); + type = val.type(); } else { throw new python.Error(`Unsupported value type '${typeof value}'.`); } if (type) { n.output().setType(type); } + if (loc) { + n.setSourceRange(loc); + } return n.output(); } insertMethodCall(method_name, matched) { @@ -8768,6 +8775,12 @@ python.Execution = class { f(name) { return this._values.get(name)[0]; } + t_(name, value) { + this._values.set(name, [value, 't']); + } + t(name) { + return this._values.get(name)[0]; + } tys_(name, value) { this._values.set(name, [value, 'tys']); } @@ -8860,9 +8873,10 @@ python.Execution = class { this._type = type; } set value(value) { // remove - if (value instanceof torch.Value === false) { - this._value = value; + if (value instanceof torch.Value) { + throw new python.Error('Value cannot be a value.'); } + this._value = value; } get value() { // remove return this._value; diff --git a/source/pytorch.js b/source/pytorch.js index b15764a337..fe67c477e3 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -2914,9 +2914,13 @@ pytorch.Execution = class extends python.Execution { } throw new pytorch.Error(); } - const value = this.variable(v); - value.value = v; - node.addInput(value); + if (v instanceof torch.Value) { + node.addInput(v); + } else { + const value = this.variable(v); + value.value = v; + node.addInput(value); + } } } for (const arg of schema.returns) {