From 37cd97efdba700f3c5ea0be7f3f4d120675364f0 Mon Sep 17 00:00:00 2001
From: Lutz Roeder
Date: Sat, 26 Oct 2024 17:11:40 -0700
Subject: [PATCH] Update pytorch.js (#1061)

---
 source/pytorch.js | 643 ++++------------------------------------------
 1 file changed, 54 insertions(+), 589 deletions(-)

diff --git a/source/pytorch.js b/source/pytorch.js
index 37767f51ef..001a94f1c4 100644
--- a/source/pytorch.js
+++ b/source/pytorch.js
@@ -1926,49 +1926,53 @@ pytorch.Execution = class extends python.Execution {
                 case '[]': {
                     if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) {
                         const target = this.expression(expression.target, context);
-                        if (target instanceof torch.Value && target.type() instanceof torch.ListType) {
-                            let index = this.expression(expression.arguments.value[0], context);
-                            const node = this._graph.create('aten::__getitem__.t');
-                            node.addInput(target);
-                            if (Number.isInteger(index)) {
-                                index = this.constant(index);
+                        if (target instanceof torch.Value) {
+                            let type = target.type();
+                            if (type instanceof torch.OptionalType) {
+                                type = type.getElementType();
                             }
-                            node.addInput(index);
-                            const value = node.addOutput();
-                            value.setType(target.type().getElementType());
-                            return value;
-                        }
-                        if (target instanceof torch.Value && target.type() instanceof torch.DictType) {
-                            let key = this.expression(expression.arguments.value[0], context);
-                            const node = this._graph.create('aten::__getitem__.t');
-                            node.addInput(target);
-                            if (target.type().getKeyType() instanceof torch.StringType && typeof key === 'string') {
-                                const value = new torch.Value(node);
-                                value.value = key;
-                                key = value;
-                            } else if (target.type().getKeyType() instanceof torch.StringType && key.type() instanceof torch.StringType) {
-                                // continue
-                            } else {
-                                throw new pytorch.Error(`Unsupported dictionary key type.`);
+                            if (type instanceof torch.ListType) {
+                                let index = this.expression(expression.arguments.value[0], context);
+                                const node = this._graph.create('aten::__getitem__.t');
+                                node.addInput(target);
+                                if (Number.isInteger(index)) {
+                                    index = this.constant(index);
+                                }
+                                node.addInput(index);
+                                const value = node.addOutput();
+                                value.setType(type.getElementType());
+                                return value;
                             }
-                            node.addInput(key);
-                            const value = node.addOutput();
-                            value.setType(target.type().getValueType());
-                            return value;
-                        }
-                        if (target instanceof torch.Value && target.type() instanceof torch.TupleType) {
-                            let index = this.expression(expression.arguments.value[0], context);
-                            const node = this._graph.create('prim::TupleIndex');
-                            const value = node.addOutput();
-                            value.setType(target.type().elements()[index]);
-                            node.addInput(target);
-                            if (Number.isInteger(index)) {
-                                const value = this.invoke('torch.Value', [node]);
-                                value.value = index;
-                                index = value;
+                            if (type instanceof torch.DictType) {
+                                let key = this.expression(expression.arguments.value[0], context);
+                                const node = this._graph.create('aten::__getitem__.t');
+                                node.addInput(target);
+                                if (type.getKeyType() instanceof torch.StringType && typeof key === 'string') {
+                                    const value = new torch.Value(node);
+                                    value.value = key;
+                                    key = value;
+                                } else if (type.getKeyType() instanceof torch.StringType && key.type() instanceof torch.StringType) {
+                                    // continue
+                                } else {
+                                    throw new pytorch.Error(`Unsupported dictionary key type.`);
+                                }
+                                node.addInput(key);
+                                const value = node.addOutput();
+                                value.setType(type.getValueType());
+                                return value;
+                            }
+                            if (type instanceof torch.TupleType) {
+                                let index = this.expression(expression.arguments.value[0], context);
+                                const node = this._graph.create('prim::TupleIndex');
+                                const value = node.addOutput();
+                                value.setType(type.elements()[index]);
+                                node.addInput(target);
+                                if (Number.isInteger(index)) {
+                                    index = this.constant(index);
+                                }
+                                node.addInput(index);
+                                return value;
+                            }
-                            }
-                            node.addInput(index);
-                            return value;
                         }
                     }
                     break;
@@ -2118,11 +2122,7 @@ pytorch.Execution = class extends python.Execution {
                 }
                 case 'if': {
                     const test = this.expression(statement.test, context);
-                    if (test instanceof torch.Value) {
-                        const node = this._graph.create('prim::If');
-                        node.addInput(test);
-                    }
-                    if (test === true || test) {
+                    if (test === true || (!this.traceIf && test)) {
                         const value = this.block(statement.body.statements, context);
                         if (value !== undefined) {
                             return value;
@@ -2136,6 +2136,10 @@ pytorch.Execution = class extends python.Execution {
                            }
                        }
                        return undefined;
+                    } else if (test instanceof torch.Value && test.type() instanceof torch.BoolType) {
+                        const node = this._graph.create('prim::If');
+                        node.addInput(test);
+                        return undefined;
                    }
                    throw new python.Error("Unsupported condition.");
                }
@@ -2420,40 +2424,13 @@ pytorch.Execution = class extends python.Execution {
                 const type = arg.real_type;
                 switch (type.str()) {
                     case 'Tensor': {
-                        const output = this.createTensorOutput(schema.name, evalArgs, i);
+                        const output = new torch.Tensor();
                         output.__origin__ = schema.name;
                         this.variable(output, node);
                         result.push(output);
                         break;
                     }
                     case 'Tensor[]': {
-                        let count = 1;
-                        switch (schema.name) {
-                            case 'aten::chunk':
-                                count = node.inputs()[1].value;
-                                break;
-                            case 'aten::meshgrid': {
-                                const list = node.inputs()[0].node();
-                                if (list.kind() === 'prim::ListConstruct') {
-                                    count = list.inputs().length;
-                                }
-                                break;
-                            }
-                            case 'aten::unbind':
-                            case 'aten::unbind.int':
-                                count = args[0].__tuple__ || count;
-                                break;
-                            case 'aten::broadcast_tensors':
-                            case 'aten::split':
-                            case 'aten::split.Tensor':
-                            case 'aten::split_with_sizes':
-                                if (context.target.length > 0) {
-                                    count = context.target[context.target.length - 1].length;
-                                }
-                                break;
-                            default:
-                                break;
-                        }
                         const value = node.addOutput();
                         value.setType(torch.ListType.get(torch.TensorType.get()));
                         result.push(value);
                         break;
@@ -2471,50 +2448,21 @@ pytorch.Execution = class extends python.Execution {
                         result.push(value);
                         break;
                     }
-                    case 'int': {
-                        const value = this.variable(null, node);
-                        value.__origin__ = schema.name;
-                        value.setType(torch.IntType.get());
-                        switch (schema.name) {
-                            case 'aten::div.int': value.value = torch.div(evalArgs[0], evalArgs[1]); break;
-                            case 'aten::dim': value.value = torch.dim(evalArgs[0]); break;
-                            case 'aten::len.t': value.value = torch.len(evalArgs[0]); break;
-                            // case 'aten::size.int': value.value = torch.size(evalArgs[0], evalArgs[1]); break;
-                            default: break;
-                        }
-                        result.push(value);
-                        break;
-                    }
-                    case 'int[]': {
-                        const value = this.variable(null, node);
-                        value.__origin__ = schema.name;
-                        value.setType(torch.ListType.get(torch.IntType.get()));
-                        switch (schema.name) {
-                            // case 'aten::size': value.value = torch.size(evalArgs[0], evalArgs[1]); break;
-                            default: break;
-                        }
-                        result.push(value);
-                        break;
-                    }
                     case 'Scalar':
                     case 'Dict(str, Tensor)':
+                    case 'int':
+                    case 'int[]':
                     case 'str':
                     case 'str[]':
                     case 'float':
                     case 'float[]':
                     case 'complex':
                     case 'bool':
-                    case 'bool[]': {
-                        const value = this.variable(null, node);
-                        value.__origin__ = schema.name;
-                        value.setType(type);
-                        result.push(value);
-                        break;
-                    }
+                    case 'bool[]':
                     case 'Device': {
                         const value = this.variable(null, node);
                         value.__origin__ = schema.name;
-                        value.setType(torch.DeviceObjType.get());
+                        value.setType(type);
                         result.push(value);
                         break;
                     }
@@ -2577,207 +2525,6 @@ pytorch.Execution = class extends python.Execution {
         return result[0];
     }
 
-    createTensorOutput(op_name, evalArgs, i) {
-        const torch = this.torch;
-        const output = new torch.Tensor();
-        if (i === 0) {
-            switch (op_name) {
-                case 'aten::conv1d':
-                case 'aten::embedding': {
-                    output.resize_([NaN, NaN, NaN]);
-                    break;
-                }
-                case 'aten::cat':
-                case 'aten::conv2d':
-                case 'aten::dropout':
-                case 'aten::flatten':
-                case 'aten::flatten.named_out_dim':
-                case 'aten::max_pool2d':
-                case 'aten::adaptive_avg_pool2d':
-                case 'aten::avg_pool2d':
-                case 'aten::quantize_per_tensor':
-                case 'aten::relu_':
-                case 'aten::prelu':
-                case 'aten::hardtanh_':
-                case 'aten::upsample_bilinear2d':
-                case 'prepacked::conv2d_clamp_run': {
-                    const [input] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && input.size() === undefined) {
-                        input.resize_([NaN, NaN, NaN, NaN]);
-                    }
-                    output.resize_([NaN, NaN, NaN, NaN]);
-                    break;
-                }
-                case 'aten::slice':
-                case 'aten::slice.Tensor': {
-                    const [input] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        const size = input.size();
-                        output.resize_(size);
-                    }
-                    break;
-                }
-                case 'aten::to':
-                case 'aten::to.device':
-                case 'aten::to.dtype':
-                case 'aten::to.dtype_layout': {
-                    const [input] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        const size = input.size();
-                        output.resize_(size);
-                    }
-                    break;
-                }
-                case 'aten::conv3d': {
-                    output.resize_([NaN, NaN, NaN, NaN, NaN]);
-                    break;
-                }
-                case 'aten::roll':
-                case 'aten::detach':
-                case 'aten::mean':
-                case 'aten::mul':
-                case 'aten::mul.Scalar':
-                case 'aten::div':
-                case 'aten::div.Scalar':
-                case 'aten::batch_norm':
-                case 'aten::gelu':
-                case 'aten::relu':
-                case 'aten::clamp':
-                case 'aten::clamp_':
-                case 'aten::_add_relu_':
-                case 'aten::hardswish_': {
-                    const [input] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        output.resize_(input.size());
-                    }
-                    break;
-                }
-                case 'aten::add':
-                case 'aten::add.Scalar':
-                case 'aten::sub':
-                case 'aten::sub.Scalar': {
-                    const [input] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        output.resize_(input.size());
-                    } else {
-                        const [, other] = evalArgs;
-                        if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) {
-                            output.resize_(other.size());
-                        }
-                    }
-                    break;
-                }
-                case 'aten::select':
-                case 'aten::select.int': {
-                    const [input] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        output.resize_(Array(input.size().length - 1).fill(NaN));
-                    }
-                    break;
-                }
-                case 'aten::layer_norm': {
-                    const [input, normalized_shape] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        const shape = input.size();
-                        if (Array.isArray(normalized_shape) && normalized_shape.length === 1) {
-                            const [value] = normalized_shape;
-                            shape[shape.length - 1] = value;
-                        }
-                        output.resize_(shape);
-                    }
-                    break;
-                }
-                case 'aten::empty':
-                case 'aten::ones':
-                case 'aten::zeros':
-                case 'aten::zeros_like': {
-                    output.resize_(evalArgs[0]);
-                    break;
-                }
-                case 'aten::view':
-                case 'aten::reshape':
-                case 'aten::new_full': {
-                    output.resize_(evalArgs[1]);
-                    break;
-                }
-                case 'aten::squeeze':
-                case 'aten::squeeze.dim': {
-                    const [input] = evalArgs;
-                    if (input instanceof torch.Value === false) {
-                        const size = input.size();
-                        if (Array.isArray(size)) {
-                            switch (evalArgs.length) {
-                                case 1: {
-                                    output.resize_(size.filter((value) => value !== 1));
-                                    break;
-                                }
-                                case 2: {
-                                    const [, dim] = evalArgs;
-                                    output.resize_(size.filter((value, index) => (value !== 1 && !isNaN(value)) || index !== dim));
-                                    break;
-                                }
-                                default: {
-                                    break;
-                                }
-                            }
-                        }
-                    }
-                    break;
-                }
-                case 'aten::unsqueeze': {
-                    const [input, dim] = evalArgs;
-                    if (pytorch.Utility.isTensor(input)) {
-                        const size = input.size();
-                        if (Array.isArray(size) && dim !== undefined) {
-                            const shape = size.slice();
-                            shape.splice(dim, 0, 1);
-                            output.resize_(shape);
-                        } else {
-                            output.resize_([NaN, NaN, NaN, NaN]);
-                        }
-                    }
-                    break;
-                }
-                case 'aten::transpose':
-                case 'aten::transpose.int': {
-                    const [input, dim0, dim1] = evalArgs;
-                    if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
-                        const size = input.size().slice();
-                        const d0 = dim0 >= 0 ? dim0 : size.length + dim0;
-                        const d1 = dim1 >= 0 ? dim1 : size.length + dim1;
-                        const value = size[dim0];
-                        /* eslint-disable prefer-destructuring */
-                        size[d0] = size[1];
-                        /* eslint-enable prefer-destructuring */
-                        size[d1] = value;
-                        output.resize_(size);
-                    }
-                    break;
-                }
-                case 'aten::contiguous': {
-                    const [source] = evalArgs;
-                    output.__source__ = source;
-                    break;
-                }
-                case 'quantized::cat':
-                case 'quantized::cat_relu':
-                case 'quantized::linear':
-                case 'quantized::conv2d':
-                case 'quantized::conv2d.new':
-                case 'quantized::conv2d_relu':
-                case 'quantized::conv2d_relu.new':
-                case 'quantized::add':
-                case 'quantized::add_relu':
-                    output.resize_([NaN, NaN, NaN, NaN]);
-                    output.__quantized__ = true;
-                    break;
-                default:
-                    break;
-            }
-        }
-        return output;
-    }
-
     isType(obj, type, N) {
         const torch = this.torch;
         switch (type.str()) {
@@ -3109,288 +2856,6 @@ pytorch.Execution = class extends python.Execution {
         }
         return [matches[0], evalArgs];
     }
-
-    block(statements, context) {
-        if (!this.traceIf) {
-            return super.block(statements, context);
-        }
-        statements = Array.prototype.slice.call(statements);
-        while (statements.length > 0) {
-            if (statements.length > 1) {
-                const [assign, condition] = statements;
-                // _x = torch.ne(torch.len(torch.size(input)), 5)
-                // if _x:
-                //   ops.prim.RaiseException(...)
-                if (assign.type === '=' &&
-                    condition.type === 'if' &&
-                    pytorch.Utility.isEqual(assign.target, condition.test) &&
-                    pytorch.Utility.isCall(assign.expression, 'torch.ne', 2) &&
-                    pytorch.Utility.isCall(assign.expression.args[0], 'torch.len', 1) &&
-                    pytorch.Utility.isCall(assign.expression.args[0].args[0], 'torch.size', 1) &&
-                    condition.body.statements.length === 1 &&
-                    pytorch.Utility.isCall(condition.body.statements[0], 'ops.prim.RaiseException', 1)) {
-                    const tensor = this.expression(assign.expression.args[0].args[0].args[0], context);
-                    if (pytorch.Utility.isTensor(tensor) && tensor.size) {
-                        const number = this.expression(assign.expression.args[1], context);
-                        const size = tensor.size();
-                        if (number >= 3 && number <= 5) {
-                            if (!Array.isArray(size) || size.length !== number) {
-                                tensor.resize_(Array(number).fill(NaN));
-                            }
-                        }
-                    }
-                }
-                // _x = torch.ne(torch.dim(input), 5)
-                // if _x:
-                //   ops.prim.RaiseException(...)
-                if (assign.type === '=' &&
-                    condition.type === 'if' &&
-                    pytorch.Utility.isEqual(assign.target, condition.test) &&
-                    pytorch.Utility.isCall(assign.expression, 'torch.ne', 2) &&
-                    pytorch.Utility.isCall(assign.expression.args[0], 'torch.dim', 1) &&
-                    condition.body.statements.length > 0 &&
-                    pytorch.Utility.isCall(condition.body.statements[condition.body.statements.length - 1], 'ops.prim.RaiseException', 1)) {
-                    const tensor = this.expression(assign.expression.args[0].args[0], context);
-                    if (pytorch.Utility.isTensor(tensor)) {
-                        const size = this.expression(assign.expression.args[1], context);
-                        tensor.resize_(Array(size).fill(NaN));
-                    }
-                }
-                // _0 = torch.eq(torch.len(torch.size(x)), 2)
-                // if _0:
-                //   pass
-                // else:
-                //   ops.prim.RaiseException("AssertionError: ")
-                if (assign.type === '=' &&
-                    condition.type === 'if' &&
-                    pytorch.Utility.isEqual(assign.target, condition.test) &&
-                    pytorch.Utility.isCall(assign.expression, 'torch.eq', 2) &&
-                    pytorch.Utility.isCall(assign.expression.args[0], 'torch.len', 1) &&
-                    pytorch.Utility.isCall(assign.expression.args[0].args[0], 'torch.size', 1) &&
-                    condition.orelse.statements.length === 1 &&
-                    pytorch.Utility.isCall(condition.orelse.statements[0], 'ops.prim.RaiseException', 1)) {
-                    const tensor = this.expression(assign.expression.args[0].args[0].args[0], context);
-                    if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
-                        const number = this.expression(assign.expression.args[1], context);
-                        tensor.resize_(Array(number).fill(NaN));
-                    }
-                }
-                // val = torch.slice(torch.size(img), -2)
-                // if torch.eq(torch.len(val), 2):
-                //   pass
-                // else:
-                //   ops.prim.RaiseException("AssertionError: ")
-                if (assign.type === '=' &&
-                    condition.type === 'if' &&
-                    pytorch.Utility.isCall(assign.expression, 'torch.slice', 2) &&
-                    pytorch.Utility.isCall(assign.expression.args[0], 'torch.size', 1) &&
-                    pytorch.Utility.isCall(condition.test, 'torch.eq', 2) &&
-                    pytorch.Utility.isCall(condition.test.args[0], 'torch.len', 1) &&
-                    pytorch.Utility.isEqual(condition.test.args[0].args[0], assign.target) &&
-                    condition.orelse.statements.length === 1 &&
-                    pytorch.Utility.isCall(condition.orelse.statements[0], 'ops.prim.RaiseException', 1)) {
-                    const tensor = this.expression(assign.expression.args[0].args[0], context);
-                    if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
-                        const start = this.expression(assign.expression.args[1], context);
-                        const value = this.expression(condition.test.args[1], context);
-                        if (Number.isInteger(start) && start < 0 && Number.isInteger(value) && value > 0) {
-                            tensor.resize_(Array(value - start).fill(NaN));
-                        }
-                    }
-                }
-            }
-            if (statements.length > 1) {
-                // getattr_1 = torch.size(x)
-                // getitem = torch.slice(getattr_1, -2, 9223372036854775807, 1)
-                const [size, statement] = statements;
-                if (size.type === '=' && statement.type === '=' &&
-                    size.target.type === 'id' &&
-                    pytorch.Utility.isCall(size.expression, 'torch.size', 1) &&
-                    pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
-                    statement.expression.arguments[0].type === 'id' && size.target.value === statement.expression.arguments[0].value) {
-                    const tensor = this.expression(size.expression.arguments[0], context);
-                    if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
-                        tensor.resize_([1, 3, 299, 299]);
-                    }
-                }
-            }
-            if (statements.length > 1) {
-                // _0 = torch.split_with_sizes(...)
-                // a, a_1, a_2, = _0
-                const [statement, tuple] = statements;
-                if (statement.type === '=' && statement.target.type === 'id' && statement.expression.type === 'call' &&
-                    tuple.type === '=' && tuple.target.type === 'tuple' &&
-                    tuple.target.value.every((item) => item.type === 'id') &&
-                    tuple.expression.value === statement.target.value) {
-                    const containsVariableReference = (queue, value) => {
-                        while (queue.length > 0) {
-                            const obj = queue.shift();
-                            if (obj && obj.type === 'id' && obj.value === value) {
-                                return true;
-                            } else if (Array.isArray(obj)) {
-                                for (const item of obj) {
-                                    if (Array.isArray(item) || (Object(item) === item && item.type)) {
-                                        queue.push(item);
-                                    }
-                                }
-                            } else if (Object(obj) === obj) {
-                                for (const [key, value] of Object.entries(obj)) {
-                                    if (key !== 'identifier') {
-                                        if (Array.isArray(value)) {
-                                            for (const item of value) {
-                                                if (Array.isArray(item) || (Object(item) === item && item.type)) {
-                                                    queue.push(item);
-                                                }
-                                            }
-                                        } else if (Object(value) === value && value.type) {
-                                            queue.push(value);
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        return false;
-                    };
-                    if (!containsVariableReference(statements.slice(2, statements.length - 1), statement.target.value)) {
-                        statements[0] = { ...statement };
-                        statements[0].target = tuple.target;
-                        statements.splice(1, 1);
-                    }
-                }
-            }
-            const statement = statements.shift();
-            // input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
-            if (statement.type === '=' &&
-                pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
-                pytorch.Utility.isCall(statement.expression.args[0], 'torch.size', 1)) {
-                const tensor = this.expression(statement.expression.args[0].args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
-                    tensor.resize_([1, 3, 299, 299]);
-                }
-            }
-            // torch.slice(ops.prim.shape(input), 0, 2, 1)
-            if (statement.type === '=' &&
-                pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
-                pytorch.Utility.isCall(statement.expression.args[0], 'ops.prim.shape', 1)) {
-                const tensor = this.expression(statement.expression.args[0].args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
-                    tensor.resize_([NaN, NaN, NaN, NaN]);
-                }
-            }
-            // _3 = torch.le(xxxx, torch.dim(f0))
-            if (statement.type === '=' &&
-                pytorch.Utility.isCall(statement.expression, 'torch.le', 2) &&
-                pytorch.Utility.isCall(statement.expression.args[1], 'torch.dim', 1)) {
-                const tensor = this.expression(statement.expression.args[1].args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
-                    tensor.resize_([NaN, NaN, NaN, NaN]);
-                }
-            }
-            // if torch.ne(torch.dim(image), 3):
-            //   xxxx
-            //   ops.prim.RaiseException(_7)
-            if (statement.type === 'if' &&
-                pytorch.Utility.isCall(statement.test, 'torch.ne', 2) &&
-                pytorch.Utility.isCall(statement.test.args[0], 'torch.dim', 1) &&
-                statement.body.statements.length > 0 &&
-                pytorch.Utility.isCall(statement.body.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
-                const tensor = this.expression(statement.test.args[0].args[0], context);
-                const size = this.expression(statement.test.args[1], context);
-                if (pytorch.Utility.isTensor(tensor) && Number.isInteger(size) && size < 10) {
-                    tensor.resize_(Array.isArray(tensor.shape) && tensor.shape.length > size ? tensor.shape.slice(-size) : Array(size).fill(NaN));
-                }
-            }
-            // if torch.gt(torch.dim(x), 1):
-            //   xxxx
-            //   ops.prim.RaiseException(...)
-            if (statement.type === 'if' &&
-                pytorch.Utility.isCall(statement.test, 'torch.gt', 2) &&
-                pytorch.Utility.isCall(statement.test.args[0], 'torch.dim', 1) &&
-                statement.body.statements.length > 0 &&
-                pytorch.Utility.isCall(statement.body.statements.slice(-1).pop(), 'ops.prim.RaiseException')) {
-                const tensor = this.expression(statement.test.args[0].args[0], context);
-                const size = this.expression(statement.test.args[1], context);
-                if (pytorch.Utility.isTensor(tensor) && Number.isInteger(size) && size < 10) {
-                    tensor.resize_(Array.isArray(tensor.shape) && tensor.shape.length > size ? tensor.shape.slice(-size) : Array(size).fill(NaN));
-                }
-            }
-            // if bool(...):
-            //   ops.prim.RaiseException(torch.format(_1, dtype))
-            // else:
-            //   pass
-            if (statement.type === 'if' &&
-                pytorch.Utility.isCall(statement.test, 'bool', 1) &&
-                statement.body.statements.length > 0 &&
-                pytorch.Utility.isCall(statement.body.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
-                statement.test = { type: 'id', value: 'False' };
-            }
-            // dim = torch.sub(torch.dim(input), 2)
-            if (statement.type === '=' &&
-                statement.target.type === 'id' && statement.target.value === 'dim' &&
-                pytorch.Utility.isCall(statement.expression, 'torch.sub', 2) &&
-                pytorch.Utility.isCall(statement.expression.args[0], 'torch.dim', 1)) {
-                const tensor = this.expression(statement.expression.args[0].args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
-                    tensor.resize_([NaN, NaN, NaN, NaN]);
-                }
-            }
-            // a, b = torch.unbind(size, 0)
-            if (statement.type === '=' &&
-                statement.target.type === 'tuple' &&
-                (pytorch.Utility.isCall(statement.expression, 'torch.unbind', 1) ||
-                 pytorch.Utility.isCall(statement.expression, 'torch.unbind', 2))) {
-                statement.expression.args[0].__tuple__ = statement.target.value.length;
-            }
-            // a, b, c = torch.size(input)
-            if (statement.type === '=' &&
-                statement.target.type === 'tuple' &&
-                pytorch.Utility.isCall(statement.expression, 'torch.size', 1)) {
-                const tensor = this.expression(statement.expression.args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
-                    const dim = statement.target.value.length;
-                    tensor.resize_(Array(dim).fill(NaN));
-                }
-            }
-            // x = torch.len(input)
-            if (statement.type === '=' &&
-                statement.target.type === 'id' &&
-                pytorch.Utility.isCall(statement.expression, 'torch.len', 1)) {
-                const tensor = this.expression(statement.expression.args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
-                    tensor.resize_([NaN, NaN, NaN, NaN]);
-                }
-            }
-            // x = _(torch.size(foo ,2))
-            if (statement.type === '=' &&
-                statement.expression.type === 'call' && statement.expression.args.length > 0 &&
-                pytorch.Utility.isCall(statement.expression.args[0], 'torch.size', 2)) {
-                const tensor = this.expression(statement.expression.args[0].args[0], context);
-                const dim = this.expression(statement.expression.args[0].args[1], context);
-                if (pytorch.Utility.isTensor(tensor) && Number.isInteger(dim) && dim >= 0) {
-                    if (tensor.shape === undefined) {
-                        tensor.resize_(Array(dim + 1).fill(NaN));
-                    } else if (Array.isArray(tensor.shape) && tensor.shape.length <= dim) {
-                        tensor.resize_(tensor.shape.concat(Array(dim + 1 - tensor.shape.length).fill(NaN)));
-                    }
-                }
-            }
-            if (statement.type === '=' && statement.target.type === 'tuple' &&
-                statement.expression.type === 'call' && statement.expression.args.length > 0 &&
-                pytorch.Utility.isCall(statement.expression, 'torch.size', 1)) {
-                const tensor = this.expression(statement.expression.args[0], context);
-                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input') {
-                    if (tensor.shape === undefined) {
-                        tensor.resize_(Array(statement.target.value.length).fill(NaN));
-                    }
-                }
-            }
-            const value = this.statement(statement, context);
-            if (value !== undefined) {
-                return value;
-            }
-        }
-        return undefined;
-    }
 };
 
 pytorch.Container.Package = class extends pytorch.Container {