Skip to content

Commit

Permalink
Update pytorch.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 26, 2024
1 parent 7015d48 commit e59686b
Showing 1 changed file with 45 additions and 50 deletions.
95 changes: 45 additions & 50 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -359,18 +359,6 @@ pytorch.Node = class {
this.inputs = [];
this.outputs = [];
this.metadata = [];
const createType = (metadata, name) => {
if (name instanceof pytorch.nnapi.Graph) {
return name;
}
const key = name.startsWith('__torch__.') ? name.substring(10) : name;
const value = metadata.type(key);
const type = value ? { ...value } : { name };
type.identifier = name;
[name] = type.name.split('(');
type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0];
return type;
};
let module = null;
if (pytorch.Utility.isInstance(obj, 'torch.Node')) {
const node = obj;
Expand Down Expand Up @@ -590,16 +578,23 @@ pytorch.Node = class {
}
} else if (pytorch.Utility.isInstance(obj, 'torch.fx.node.Node')) {
if (obj.op === 'call_function') {
this.type = createType(metadata, obj.target.name);
const name = obj.target.name;
this.type = {
identifier: name,
name: name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]
};
const schema = obj.target._schema;
if (schema && schema.category) {
this.type.category = schema.category;
}
let args = obj.args.map((arg, index) => {
const name = schema && Array.isArray(schema.arguments) ? schema.arguments[index].name : '';
return [name, arg];
});
const inputs = new Map((this.type.inputs || []).map((input) => [input.name, input]));
const inputs = new Map((schema ? schema.arguments : []).map((arg) => [arg.name, arg]));
args = args.concat(Array.from(obj.kwargs));
for (const [name, arg] of args) {
const type = inputs.has(name) ? inputs.get(name).type : null;
const type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null;
if (pytorch.Utility.isInstance(arg, 'torch.fx.node.Node')) {
const value = values.map(arg);
const argument = new pytorch.Argument(name, [value]);
Expand Down Expand Up @@ -662,7 +657,7 @@ pytorch.Node = class {
}
}
} else if (obj.op === 'placeholder') {
this.type = createType(metadata, 'placeholder');
this.type = { name: obj.op };
const value = values.map(obj);
const argument = new pytorch.Argument('value', [value]);
this.inputs.push(argument);
Expand All @@ -682,7 +677,22 @@ pytorch.Node = class {
type = 'builtins.object';
}
}
this.type = createType(metadata, type);
if (type instanceof pytorch.nnapi.Graph) {
this.type = type;
} else {
const key = type.startsWith('__torch__.') ? type.substring(10) : type;
const value = metadata.type(key);
this.type = value ? { ...value } : { name: type };
this.type.identifier = type;
if (this.type.name.indexOf('(') !== -1) {
throw new Error();
}
if (this.type.name.indexOf('::') !== -1) {
throw new Error();
}
// [name] = this.type.name.split('(');
// this.type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0];
}
stack = stack || new Set();
const weights = pytorch.Utility.weights(obj);
if (weights) {
Expand Down Expand Up @@ -3497,40 +3507,25 @@ pytorch.Utility = class {
}

static toType(type) {
if (pytorch.Utility.isInstance(type, 'torch.OptionalType')) {
return `${pytorch.Utility.toType(type.getElementType())}?`;
}
if (pytorch.Utility.isInstance(type, 'torch.ListType')) {
return `${pytorch.Utility.toType(type.getElementType())}[]`;
}
if (pytorch.Utility.isInstance(type, 'torch.IntType')) {
return `int64`;
}
if (pytorch.Utility.isInstance(type, 'torch.FloatType')) {
return `float32`;
}
if (pytorch.Utility.isInstance(type, 'torch.StringType')) {
return `string`;
}
if (pytorch.Utility.isInstance(type, 'torch.ComplexType')) {
return `complex`;
}
if (pytorch.Utility.isInstance(type, 'torch.BoolType')) {
return `boolean`;
}
if (pytorch.Utility.isInstance(type, 'torch.TensorType')) {
return `tensor`;
}
if (pytorch.Utility.isInstance(type, 'torch.TupleType')) {
return `(${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')})`;
}
if (pytorch.Utility.isInstance(type, 'torch.DeviceObjType')) {
return `Device`;
}
if (pytorch.Utility.isInstance(type, 'torch.DictType')) {
return `{}`;
switch (type.kind()) {
case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`;
case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`;
case 'BoolType': return `boolean`;
case 'IntType': return `int64`;
case 'FloatType': return `float32`;
case 'StringType': return `string`;
case 'ComplexType': return `complex`;
case 'NumberType': return `scalar`;
case 'TensorType': return `tensor`;
case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`;
case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`;
case 'DeviceObjType': return `device`;
case 'SymIntType': return `SymInt`;
case 'ScalarTypeType': return `ScalarType`;
case 'MemoryFormat': return `MemoryFormat`;
case 'Layout': return `Layout`;
default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`);
}
throw new pytorch.Error(`Unsupported type '${type.kind()}'.`);
}

static isObjectType(type) {
Expand Down

0 comments on commit e59686b

Please sign in to comment.