Skip to content

Commit

Permalink
Add torch.export test file (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 2, 2024
1 parent 00f6c80 commit fb14c2e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
8 changes: 8 additions & 0 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -8471,6 +8471,11 @@ python.Execution = class {
.filter((s) => s.kind === torch.export.graph_signature.InputKind.BUFFER && s.arg instanceof torch.export.graph_signature.TensorArgument && typeof s.target === 'string')
.map((s) => [s.arg.name, s.target]));
}
inputs_to_lifted_tensor_constants() {
return new Map(this.input_specs
.filter((s) => s.kind === torch.export.graph_signature.InputKind.CONSTANT_TENSOR && s.arg instanceof torch.export.graph_signature.TensorArgument && typeof s.target === 'string')
.map((s) => [s.arg.name, s.target]));
}
});
torch.export.graph_signature.InputKind = {
USER_INPUT: 0,
Expand Down Expand Up @@ -8563,6 +8568,9 @@ python.Execution = class {
get state_dict() {
return this._state_dict;
}
get constants() {
return this._constants;
}
});
this.registerType('torch.export.exported_program.ModuleCallEntry', class {});
this.registerType('torch.export.exported_program.ModuleCallSignature', class {});
Expand Down
12 changes: 10 additions & 2 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ pytorch.Graph = class {
const graph = exported_program.graph;
const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters();
const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers();
const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants();
const values = new Map();
values.map = (obj) => {
if (!values.has(obj)) {
Expand All @@ -241,15 +242,22 @@ pytorch.Graph = class {
const value = new pytorch.Value(key, null, null, tensor);
values.set(node, value);
}
}
if (inputs_to_buffers.has(node.name)) {
} else if (inputs_to_buffers.has(node.name)) {
const key = inputs_to_buffers.get(node.name);
const buffer = exported_program.state_dict.get(key);
if (buffer) {
const tensor = new pytorch.Tensor(key, buffer);
const value = new pytorch.Value(key, null, null, tensor);
values.set(node, value);
}
} else if (inputs_to_lifted_tensor_constants.has(node.name)) {
const key = inputs_to_lifted_tensor_constants.get(node.name);
const constant = exported_program.constants.get(key);
if (exported_program) {
const tensor = new pytorch.Tensor(key, constant);
const value = new pytorch.Value(key, null, null, tensor);
values.set(node, value);
}
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5163,6 +5163,13 @@
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/1067"
},
{
"type": "pytorch",
"target": "feature_embedding.pt2",
"source": "https://github.com/user-attachments/files/17608076/chai-1.zip[feature_embedding.pt2]",
"format": "PyTorch Export v5.1",
"link": "https://github.com/lutzroeder/netron/issues/1211"
},
{
"type": "pytorch",
"target": "checkpoints_auc_classifier.pth",
Expand Down

0 comments on commit fb14c2e

Please sign in to comment.