diff --git a/source/pytorch.js b/source/pytorch.js index b64376422b..37767f51ef 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -359,18 +359,6 @@ pytorch.Node = class { this.inputs = []; this.outputs = []; this.metadata = []; - const createType = (metadata, name) => { - if (name instanceof pytorch.nnapi.Graph) { - return name; - } - const key = name.startsWith('__torch__.') ? name.substring(10) : name; - const value = metadata.type(key); - const type = value ? { ...value } : { name }; - type.identifier = name; - [name] = type.name.split('('); - type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]; - return type; - }; let module = null; if (pytorch.Utility.isInstance(obj, 'torch.Node')) { const node = obj; @@ -590,16 +578,23 @@ pytorch.Node = class { } } else if (pytorch.Utility.isInstance(obj, 'torch.fx.node.Node')) { if (obj.op === 'call_function') { - this.type = createType(metadata, obj.target.name); + const name = obj.target.name; + this.type = { + identifier: name, + name: name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0] + }; const schema = obj.target._schema; + if (schema && schema.category) { + this.type.category = schema.category; + } let args = obj.args.map((arg, index) => { const name = schema && Array.isArray(schema.arguments) ? schema.arguments[index].name : ''; return [name, arg]; }); - const inputs = new Map((this.type.inputs || []).map((input) => [input.name, input])); + const inputs = new Map((schema ? schema.arguments : []).map((arg) => [arg.name, arg])); args = args.concat(Array.from(obj.kwargs)); for (const [name, arg] of args) { - const type = inputs.has(name) ? inputs.get(name).type : null; + const type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null; if (pytorch.Utility.isInstance(arg, 'torch.fx.node.Node')) { const value = values.map(arg); const argument = new pytorch.Argument(name, [value]); @@ -662,7 +657,7 @@ pytorch.Node = class { } } } else if (obj.op === 'placeholder') { - this.type = createType(metadata, 'placeholder'); + this.type = { name: obj.op }; const value = values.map(obj); const argument = new pytorch.Argument('value', [value]); this.inputs.push(argument); @@ -682,7 +677,22 @@ pytorch.Node = class { type = 'builtins.object'; } } - this.type = createType(metadata, type); + if (type instanceof pytorch.nnapi.Graph) { + this.type = type; + } else { + const key = type.startsWith('__torch__.') ? type.substring(10) : type; + const value = metadata.type(key); + this.type = value ? { ...value } : { name: type }; + this.type.identifier = type; + if (this.type.name.indexOf('(') !== -1) { + throw new Error(); + } + if (this.type.name.indexOf('::') !== -1) { + throw new Error(); + } + // [name] = this.type.name.split('('); + // this.type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]; + } stack = stack || new Set(); const weights = pytorch.Utility.weights(obj); if (weights) { @@ -3497,40 +3507,25 @@ pytorch.Utility = class { } static toType(type) { - if (pytorch.Utility.isInstance(type, 'torch.OptionalType')) { - return `${pytorch.Utility.toType(type.getElementType())}?`; - } - if (pytorch.Utility.isInstance(type, 'torch.ListType')) { - return `${pytorch.Utility.toType(type.getElementType())}[]`; - } - if (pytorch.Utility.isInstance(type, 'torch.IntType')) { - return `int64`; - } - if (pytorch.Utility.isInstance(type, 'torch.FloatType')) { - return `float32`; - } - if (pytorch.Utility.isInstance(type, 'torch.StringType')) { - return `string`; - } - if (pytorch.Utility.isInstance(type, 'torch.ComplexType')) { - return `complex`; - } - if (pytorch.Utility.isInstance(type, 'torch.BoolType')) { - return `boolean`; - } - if (pytorch.Utility.isInstance(type, 'torch.TensorType')) { - return `tensor`; - } - if (pytorch.Utility.isInstance(type, 'torch.TupleType')) { - return `(${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')})`; - } - if (pytorch.Utility.isInstance(type, 'torch.DeviceObjType')) { - return `Device`; - } - if (pytorch.Utility.isInstance(type, 'torch.DictType')) { - return `{}`; + switch (type.kind()) { + case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`; + case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`; + case 'BoolType': return `boolean`; + case 'IntType': return `int64`; + case 'FloatType': return `float32`; + case 'StringType': return `string`; + case 'ComplexType': return `complex`; + case 'NumberType': return `scalar`; + case 'TensorType': return `tensor`; + case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`; + case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`; + case 'DeviceObjType': return `device`; + case 'SymIntType': return `SymInt`; + case 'ScalarTypeType': return `ScalarType`; + case 'MemoryFormat': return `MemoryFormat`; + case 'Layout': return `Layout`; + default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } - throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } static isObjectType(type) {