Ports the graph tracing to torch.fx #28

Draft: wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion nirtorch/__init__.py
@@ -1,5 +1,5 @@
from .from_nir import load # noqa F401
from .graph import extract_torch_graph # noqa F401
-from .to_nir import extract_nir_graph # noqa F401
+from .to_nir import extract_nir_graph, trace_nir_graph # noqa F401

__version__ = version = "1.0"
207 changes: 207 additions & 0 deletions nirtorch/graph_fx.py
@@ -0,0 +1,207 @@
from typing import Any, Callable, Dict, Set, Tuple, TypeAlias, Optional

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:1:53: `typing.TypeAlias` imported but unused
import operator

import numpy as np
from torch.nn.modules import Module

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:5:30: `torch.nn.modules.Module` imported but unused

import nir
import torch
from torch.fx import Graph, GraphModule, Node, Tracer, Transformer

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:9:22: `torch.fx.Graph` imported but unused
from torch.fx.passes import shape_prop

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/graph_fx.py:10:29: `torch.fx.passes.shape_prop` imported but unused


class NIRTorchTracer(Tracer):

def __init__(self, custom_leaf_modules: Tuple[torch.nn.Module] = None, **kwargs):
"""Extends PyTorch's default symbolic tracing with a set of custom leaf nodes"""
super().__init__(**kwargs)
if custom_leaf_modules is not None and not isinstance(
custom_leaf_modules, tuple
):
custom_leaf_modules = tuple(custom_leaf_modules)
self.custom_leaf_modules = custom_leaf_modules

def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
"""
A method to specify whether a given ``nn.Module`` is a "leaf" module.
Leaf modules are the atomic units that appear in
the IR, referenced by ``call_module`` calls. By default,
Modules in the PyTorch standard library namespace (torch.nn)
are leaf modules. All other modules are traced through and
their constituent ops are recorded, unless specified otherwise
via this parameter.
Args:
m (Module): The module being queried about
module_qualified_name (str): The path to root of this module. For example,
if you have a module hierarchy where submodule ``foo`` contains
submodule ``bar``, which contains submodule ``baz``, that module will
appear with the qualified name ``foo.bar.baz`` here.
"""
# Tests that the module is in the list of custom leaves
if self.custom_leaf_modules and isinstance(m, self.custom_leaf_modules):

Review comment: Do we need to check for torch.nn.Sequential here?

return True

if hasattr(m, "_is_leaf_module") and m._is_leaf_module:
return True

return super().is_leaf_module(m, module_qualified_name)
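
Not part of the diff: a minimal sketch of how the two leaf mechanisms above might be exercised, assuming this branch is installed. MyStatefulNeuron and its `_is_leaf_module` flag are hypothetical.

import torch

from nirtorch.graph_fx import NIRTorchTracer

class MyStatefulNeuron(torch.nn.Module):
    # Flagged explicitly, so the tracer keeps it as an opaque call_module node
    _is_leaf_module = True

    def forward(self, x):
        return x

# Alternatively (or additionally), pass leaf classes when constructing the tracer
tracer = NIRTorchTracer(custom_leaf_modules=(torch.nn.Linear,))
fx_graph = tracer.trace(
    torch.nn.Sequential(torch.nn.Linear(2, 2), MyStatefulNeuron())
)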


class NIRTorchTransformer(Transformer):
def call_function(self, target: str, args: Tuple, kwargs: Dict) -> Any:
print("sup", target)
return super().call_function(target, args, kwargs)

def call_method(self, target: str, args: Tuple, kwargs: Dict) -> Any:
return super().call_method(target, args, kwargs)

def call_module(self, target, args, kwargs):
print("mod", target)
return super().call_module(target, args, kwargs)


def trace_torch_graph(
module: torch.nn.Module,
module_map: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]],
default_dict: Optional[
Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]
] = None,
) -> nir.NIRGraph:
"""
Traces a PyTorch module and converts it to a NIR graph using the specified module map.

Args:
module (torch.nn.Module): The module of interest
module_map (Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]): A dictionary that maps
a given module type to a function that can convert the model to an NIRNode type
default_dict (Optional[Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]]): An optional dictionary that maps
a given module type to a function that can convert the model to an NIRNode type. This dictionary is merged
with the module_map dictionary.
"""
# Merge the default dictionary, if it exists
if default_dict is not None:
module_map = module_map | default_dict

# Cover the edge case that the incoming module is a leaf node
if module.__class__ in module_map:
return module_map[module.__class__](module)

# Trace the graph
tracer = NIRTorchTracer(module_map.keys())
traced = tracer.trace(module)

if len(traced.nodes) == 2 and len(list(tracer.root.children())) == 0:
raise ValueError(
"The module is a leaf node, but does not appear in the module map. We cannot trace it further"
)

graph_module = GraphModule(tracer.root, traced)

# Create NIR nodes
nodes = {}
edges = []
ignored_nodes = set()
skipped_nodes = set()

Review comment: Does skipped_nodes mean nodes we still need to process?


def _find_users(node: Node) -> Set[Node]:
"""
Finds all the users (outputs) of a given node, recursively if the node is registered as a skipped node
"""
nodes = set()
for user in node.users:
if user in ignored_nodes:
continue
elif user in skipped_nodes:
nodes |= _find_users(user)
else:
nodes.add(user)
return nodes

def _find_inputs(node: Node) -> Set[Node]:
"""
Finds all the inputs (sources) of a given node, recursively if the node is registered as a skipped node
"""
nodes = set()
for in_node in node.all_input_nodes:
if in_node in ignored_nodes:
continue
elif in_node in skipped_nodes:
nodes |= _find_inputs(in_node)
else:
nodes.add(in_node)
return nodes

for node in traced.nodes:
# Add Node
if node.op == "placeholder":
if node.target == "input" or node.prev.op == "root":
nodes[str(node.name)] = nir.Input(np.array([1]))
else:
ignored_nodes.add(node)
continue
elif node.op == "output":
nodes[str(node.name)] = nir.Output(np.array([1]))
elif node.op == "call_function":

Review comment: Should we add a check for other non-allowed functions (multiplies or divisions) and add an exception?

# Ensure that we skip add nodes
# TODO: Consider using transformations for this
# https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer
if node.target == operator.add:
skipped_nodes.add(node)
# Raise an error if we encounter functions other than addition
else:
raise ValueError(
"The only supported function is addition. Please modify your model or raise an issue on GitHub"
)
elif node.op == "call_method":
# Skip add methods
if node.target == "add":
skipped_nodes.add(node)
else:
raise ValueError(
"The only supported method is addition. Please modify your model or raise an issue on GitHub"
)
elif node.op == "call_module":
torch_module = graph_module.get_submodule(node.target)
nir_module = module_map[torch_module.__class__](torch_module)
nodes[str(node.name)] = nir_module
elif node.op == "get_attr":
# Skip attribute
skipped_nodes.add(node)
else:
raise ValueError(
f"Unsupported operation {node.op}. Please modify your model or raise an issue on GitHub"
)

# Create edges
# - This is done in a separate loop to ensure that we correctly ignore the edges in case the nodes
# are ignored out-of-order
for node in traced.nodes:
if node in ignored_nodes:
continue

# Add edges
for in_node in node.all_input_nodes:
if in_node in ignored_nodes or in_node in skipped_nodes:
continue
# If the function is set to be skipped, we simply forward the input to all the outputs
if node in skipped_nodes:
for next_node in _find_users(node):
edges.append((in_node.name, next_node.name))
# Ignore additions as incoming edges
elif in_node.op == "call_function" and in_node.target == operator.add:
break
# Otherwise, add an edge
elif in_node not in ignored_nodes:
edges.append((in_node.name, node.name))
graph = nir.NIRGraph(nodes=nodes, edges=edges)
graph.infer_types()
return graph


if __name__ == "__main__":
module = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1))
graph = trace_torch_graph(module)

import pprint

pprint.pprint(graph)
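
Not part of the diff: a minimal sketch of how trace_torch_graph might be called from user code, assuming this branch is installed. The module_map name and its nn.Linear entry are illustrative and mirror the DEFAULT_MAPS added to to_nir.py below.

import nir
import torch

from nirtorch.graph_fx import trace_torch_graph

# Illustrative mapping: convert every nn.Linear into a nir.Affine node
module_map = {
    torch.nn.Linear: lambda m: nir.Affine(
        m.weight.detach().numpy(), m.bias.detach().numpy()
    )
}

model = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1))
nir_graph = trace_torch_graph(model, module_map)  # returns a nir.NIRGraph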
44 changes: 43 additions & 1 deletion nirtorch/to_nir.py
@@ -1,11 +1,22 @@
import logging
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence
import warnings

import nir
import numpy as np
import torch.nn as nn

from .graph import extract_torch_graph
from .graph_fx import trace_torch_graph

Check failure (GitHub Actions / Build on ubuntu-latest): Ruff (F401) nirtorch/to_nir.py:10:23: `.graph_fx.trace_torch_graph` imported but unused


DEFAULT_MAPS: Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]] = {
nn.Linear: (
lambda module: nir.Affine(
module.weight.detach().numpy(), module.bias.detach().numpy()
)
)
}


def extract_nir_graph(
@@ -39,6 +50,11 @@
Returns:
nir.NIR: Returns the generated NIR graph representation.
"""
warnings.warn(
"extract_nir_graph is deprecated, use trace_nir_graph instead",
DeprecationWarning,
stacklevel=2,
)

if len(list(model.children())):
# If the model has submodules, ignore the top level module
@@ -152,3 +168,29 @@
nir_edges.remove(edge)

return nir.NIRGraph(nir_nodes, nir_edges)


def trace_nir_graph(
model: nn.Module,
model_map: Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]],
default_map: Optional[Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]] = DEFAULT_MAPS,
model_name: Optional[str] = "model",
ignore_submodules_of=None,
model_fwd_args=[],
ignore_dims: Optional[Sequence[int]] = None,
) -> nir.NIRNode:
"""
Given a PyTorch `model`, we trace it and recreate a NIR graph using the specified `model_map`.

Args:
model (nn.Module): The model of interest
model_map (Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]): A dictionary that maps
a given module type to a function that can convert the model to an NIRNode type
model_name (Optional[str], optional): The name of the top level module.
Defaults to "model".
ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified,
the corresponding module's children will not be traversed for the graph.
ignore_dims (Optional[Sequence[int]]): Dimensions of data to be ignored for
type/shape inference. Typically the dimensions that you will want to ignore
are for batch and time.
"""