Skip to content

Commit

Permalink
Allow parsing input tensors from torch ScriptModules (#57)
Browse files Browse the repository at this point in the history
* add ability to parse inputs from script module

* update torch dep to ^2.0 for consistency with ml4gw

* add onnx dep

* remove both tests

* remove both tests
  • Loading branch information
EthanMarx authored Feb 6, 2024
1 parent 59ff33d commit af7aa3b
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 40 deletions.
21 changes: 16 additions & 5 deletions hermes/quiver/exporters/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
from hermes.quiver.exporters import Exporter


def get_input_names_from_script_module(m):
graph = m.graph
input_names = [node.debugName().split(".")[0] for node in graph.inputs()]
if "self" in input_names:
input_names.remove("self")
return OrderedDict({name: name for name in input_names})


class TorchOnnxMeta(abc.ABCMeta):
@property
def handles(self):
Expand Down Expand Up @@ -55,11 +63,14 @@ def _get_output_shapes(self, model_fn, output_names):
# generate an input array of random data
input_tensors[input.name] = self._get_tensor(input.dims)

# use function signature from module.forward
# to figure out in which order to pass input
# tensors to the model_fn
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)
# parse script module to figure out in which order
# to pass input tensors to the model_fn
if isinstance(model_fn, torch.jit.ScriptModule):
parameters = get_input_names_from_script_module(model_fn)
# otherwise use function signature from module.forward
else:
signature = inspect.signature(model_fn.forward)
parameters = OrderedDict(signature.parameters)

# make sure the number of inputs to
# the model_fn matches the number of
Expand Down
Loading

0 comments on commit af7aa3b

Please sign in to comment.