-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ports the graph tracing to torch.fx #28
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
from .from_nir import load # noqa F401 | ||
from .graph import extract_torch_graph # noqa F401 | ||
from .to_nir import extract_nir_graph # noqa F401 | ||
from .to_nir import extract_nir_graph, trace_nir_graph # noqa F401 | ||
|
||
__version__ = version = "1.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
from typing import Any, Callable, Dict, Set, Tuple, TypeAlias, Optional | ||
import operator | ||
|
||
import numpy as np | ||
from torch.nn.modules import Module | ||
|
||
import nir | ||
import torch | ||
from torch.fx import Graph, GraphModule, Node, Tracer, Transformer | ||
from torch.fx.passes import shape_prop | ||
|
||
|
||
class NIRTorchTracer(Tracer):
    """Symbolic tracer that stops at user-specified "leaf" module types.

    Extends PyTorch's default symbolic tracing so that module classes which
    have a direct NIR equivalent are kept as single ``call_module`` nodes
    instead of being traced into.
    """

    def __init__(self, custom_leaf_modules: Tuple[torch.nn.Module] = None, **kwargs):
        """Extends PyTorch's default symbolic tracing with a set of custom leaf nodes.

        Args:
            custom_leaf_modules: An iterable of module *types* to treat as
                leaves. Any non-tuple iterable is normalized to a tuple so it
                can be passed straight to ``isinstance``.
            **kwargs: Forwarded verbatim to ``torch.fx.Tracer``.
        """
        super().__init__(**kwargs)
        needs_conversion = custom_leaf_modules is not None and not isinstance(
            custom_leaf_modules, tuple
        )
        if needs_conversion:
            custom_leaf_modules = tuple(custom_leaf_modules)
        self.custom_leaf_modules = custom_leaf_modules

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.
        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules. All other modules are traced through and
        their constituent ops are recorded, unless specified otherwise
        via this parameter.

        Args:
            m (Module): The module being queried about
            module_qualified_name (str): The path to root of this module. For example,
                if you have a module hierarchy where submodule ``foo`` contains
                submodule ``bar``, which contains submodule ``baz``, that module will
                appear with the qualified name ``foo.bar.baz`` here.
        """
        # Custom leaf types registered at construction time take precedence
        if self.custom_leaf_modules and isinstance(m, self.custom_leaf_modules):
            return True

        # Modules may also opt in explicitly via an `_is_leaf_module` attribute
        if hasattr(m, "_is_leaf_module") and m._is_leaf_module:
            return True

        # Fall back to torch.fx's default policy (torch.nn modules are leaves)
        return super().is_leaf_module(m, module_qualified_name)
|
||
|
||
class NIRTorchTransformer(Transformer):
    """A ``torch.fx.Transformer`` hook-point for NIR conversion.

    Currently all three overrides simply delegate to the parent class; they
    exist as explicit extension points for future node rewriting.
    The leftover debug ``print`` statements from development were removed.
    """

    def call_function(self, target: str, args: Tuple, kwargs: Dict) -> Any:
        # Delegate unchanged; override here to rewrite function calls
        return super().call_function(target, args, kwargs)

    def call_method(self, target: str, args: Tuple, kwargs: Dict) -> Any:
        # Delegate unchanged; override here to rewrite method calls
        return super().call_method(target, args, kwargs)

    def call_module(self, target, args, kwargs):
        # Delegate unchanged; override here to rewrite module calls
        return super().call_module(target, args, kwargs)
|
||
|
||
def trace_torch_graph(
    module: torch.nn.Module,
    module_map: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]],
    default_dict: Optional[
        Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]
    ] = None,
) -> nir.NIRGraph:
    """
    Traces a PyTorch module and converts it to a NIR graph using the specified module map.

    Args:
        module (torch.nn.Module): The module of interest
        module_map (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary that maps
            a given module type to a function that can convert the model to an NIRNode type
        default_dict (Optional[Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]]): An optional dictionary that maps
            a given module type to a function that can convert the model to an NIRNode type. This dictionary is merged
            with the module_map dictionary.

    Returns:
        nir.NIRGraph: The NIR graph (or a single NIRNode if ``module`` itself maps to one).

    Raises:
        ValueError: If the module is an unmapped leaf that cannot be traced further,
            or if the traced graph contains an unsupported function, method, or op.
    """
    # Merge the default dictionary, if it exists
    if default_dict is not None:
        module_map = module_map | default_dict

    # Cover the edge case that the incoming module is itself a mapped leaf node
    if module.__class__ in module_map:
        return module_map[module.__class__](module)

    # Trace the graph, treating all mapped module types as leaf nodes
    tracer = NIRTorchTracer(module_map.keys())
    traced = tracer.trace(module)

    # A trace of just a placeholder + output with no child modules means the
    # module is an opaque leaf we have no mapping for
    if len(traced.nodes) == 2 and len(list(tracer.root.children())) == 0:
        raise ValueError(
            "The module is a leaf node, but does not appear in the module map. We cannot trace it further"
        )

    graph_module = GraphModule(tracer.root, traced)

    # Create NIR nodes
    nodes = {}  # node name -> nir.NIRNode
    edges = []  # (source name, target name) pairs
    ignored_nodes = set()  # fx nodes dropped entirely (e.g. extra placeholders)
    skipped_nodes = set()  # fx nodes bypassed by wiring inputs straight to users

    def _find_users(node: Node) -> Set[Node]:
        """
        Finds all the users (outputs) of a given node, recursively if the node is registered as a skipped node
        """
        # NOTE: named `users` (not `nodes`) to avoid shadowing the outer NIR node dict
        users = set()
        for user in node.users:
            if user in ignored_nodes:
                continue
            elif user in skipped_nodes:
                users |= _find_users(user)
            else:
                users.add(user)
        return users

    def _find_inputs(node: Node) -> Set[Node]:
        """
        Finds all the producers (inputs) of a given node, recursively if the node is registered as a skipped node
        """
        inputs = set()
        for in_node in node.all_input_nodes:
            if in_node in ignored_nodes:
                continue
            elif in_node in skipped_nodes:
                inputs |= _find_inputs(in_node)
            else:
                inputs.add(in_node)
        return inputs

    for node in traced.nodes:
        # Add Node
        if node.op == "placeholder":
            # Only the primary input placeholder becomes a NIR Input; any
            # additional placeholders (e.g. state arguments) are ignored
            if node.target == "input" or node.prev.op == "root":
                nodes[str(node.name)] = nir.Input(np.array([1]))
            else:
                ignored_nodes.add(node)
                continue
        elif node.op == "output":
            nodes[str(node.name)] = nir.Output(np.array([1]))
        elif node.op == "call_function":
            # Ensure that we skip add nodes: NIR expresses addition implicitly
            # as multiple incoming edges on the downstream node
            # TODO: Consider using transformations for this
            # https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer
            if node.target == operator.add:
                skipped_nodes.add(node)
            # Raise a warning if we encounter other methods than addition
            else:
                raise ValueError(
                    "The only supported function is addition. Please modify your model or raise an issue on GitHub"
                )
        elif node.op == "call_method":
            # Skip add methods
            if node.target == "add":
                skipped_nodes.add(node)
            else:
                raise ValueError(
                    "The only supported method is addition. Please modify your model or raise an issue on GitHub"
                )
        elif node.op == "call_module":
            # Convert each leaf submodule to a NIR node via the module map
            torch_module = graph_module.get_submodule(node.target)
            nir_module = module_map[torch_module.__class__](torch_module)
            nodes[str(node.name)] = nir_module
        elif node.op == "get_attr":
            # Skip attribute accesses; their inputs are forwarded to users
            skipped_nodes.add(node)
        else:
            raise ValueError(
                f"Unsupported operation {node.op}. Please modify your model or raise an issue on GitHub"
            )

    # Create edges
    # - This is done in a separate loop to ensure that we correctly ignore the edges in case the nodes
    #   are ignored out-of-order
    for node in traced.nodes:
        if node in ignored_nodes:
            continue

        # Add edges
        for in_node in node.all_input_nodes:
            # Skipped/ignored producers never emit edges directly; skipped
            # producers are bridged when their own users are visited below.
            # (The original code had a further `elif` testing for add-function
            # inputs that was unreachable because of this `continue`.)
            if in_node in ignored_nodes or in_node in skipped_nodes:
                continue
            # If the node is set to be skipped, we forward the input to all the
            # (transitive, non-skipped) users of the skipped node
            if node in skipped_nodes:
                for next_node in _find_users(node):
                    edges.append((in_node.name, next_node.name))
            # Otherwise, add an edge
            else:
                edges.append((in_node.name, node.name))
    graph = nir.NIRGraph(nodes=nodes, edges=edges)
    graph.infer_types()
    return graph
|
||
|
||
if __name__ == "__main__":
    module = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1))
    # trace_torch_graph requires a module_map; the original call omitted it and
    # raised a TypeError. Map Linear layers to NIR Affine nodes.
    module_map = {
        torch.nn.Linear: lambda m: nir.Affine(
            m.weight.detach().numpy(), m.bias.detach().numpy()
        )
    }
    graph = trace_torch_graph(module, module_map)

    import pprint

    pprint.pprint(graph)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to check for `torch.nn.Sequential` here as well?