From fb14c2ef94fa4d41ac03a0d52b9505c1e3cef7fd Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 2 Nov 2024 12:21:05 -0700 Subject: [PATCH] Add torch.export test file (#1211) --- source/python.js | 8 ++++++++ source/pytorch.js | 12 ++++++++++-- test/models.json | 7 +++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/source/python.js b/source/python.js index 8ee86a7d56..bdbdc9c9f4 100644 --- a/source/python.js +++ b/source/python.js @@ -8471,6 +8471,11 @@ python.Execution = class { .filter((s) => s.kind === torch.export.graph_signature.InputKind.BUFFER && s.arg instanceof torch.export.graph_signature.TensorArgument && typeof s.target === 'string') .map((s) => [s.arg.name, s.target])); } + inputs_to_lifted_tensor_constants() { + return new Map(this.input_specs + .filter((s) => s.kind === torch.export.graph_signature.InputKind.CONSTANT_TENSOR && s.arg instanceof torch.export.graph_signature.TensorArgument && typeof s.target === 'string') + .map((s) => [s.arg.name, s.target])); + } }); torch.export.graph_signature.InputKind = { USER_INPUT: 0, @@ -8563,6 +8568,9 @@ python.Execution = class { get state_dict() { return this._state_dict; } + get constants() { + return this._constants; + } }); this.registerType('torch.export.exported_program.ModuleCallEntry', class {}); this.registerType('torch.export.exported_program.ModuleCallSignature', class {}); diff --git a/source/pytorch.js b/source/pytorch.js index 3ff7e1fc08..1b8e72ed74 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -215,6 +215,7 @@ pytorch.Graph = class { const graph = exported_program.graph; const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters(); const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers(); + const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants(); const values = new Map(); values.map = (obj) => { if (!values.has(obj)) { @@ -241,8 +242,7 @@ pytorch.Graph = class { const value = new pytorch.Value(key, null, null, tensor); values.set(node, value); } - } - if (inputs_to_buffers.has(node.name)) { + } else if (inputs_to_buffers.has(node.name)) { const key = inputs_to_buffers.get(node.name); const buffer = exported_program.state_dict.get(key); if (buffer) { @@ -250,6 +250,14 @@ pytorch.Graph = class { const value = new pytorch.Value(key, null, null, tensor); values.set(node, value); } + } else if (inputs_to_lifted_tensor_constants.has(node.name)) { + const key = inputs_to_lifted_tensor_constants.get(node.name); + const constant = exported_program.constants.get(key); + if (exported_program) { + const tensor = new pytorch.Tensor(key, constant); + const value = new pytorch.Value(key, null, null, tensor); + values.set(node, value); + } } } } diff --git a/test/models.json b/test/models.json index 56b0e68bbe..bcd608aa91 100644 --- a/test/models.json +++ b/test/models.json @@ -5163,6 +5163,13 @@ "format": "TorchScript v1.6", "link": "https://github.com/lutzroeder/netron/issues/1067" }, + { + "type": "pytorch", + "target": "feature_embedding.pt2", + "source": "https://github.com/user-attachments/files/17608076/chai-1.zip[feature_embedding.pt2]", + "format": "PyTorch Export v5.1", + "link": "https://github.com/lutzroeder/netron/issues/1211" + }, { "type": "pytorch", "target": "checkpoints_auc_classifier.pth",