diff --git a/nirtorch/__init__.py b/nirtorch/__init__.py
index a896416..fd5e954 100644
--- a/nirtorch/__init__.py
+++ b/nirtorch/__init__.py
@@ -1,5 +1,5 @@
 from .from_nir import load # noqa F401
 from .graph import extract_torch_graph # noqa F401
-from .to_nir import extract_nir_graph # noqa F401
+from .to_nir import extract_nir_graph, trace_nir_graph # noqa F401
 
 __version__ = version = "1.0"
diff --git a/nirtorch/graph_fx.py b/nirtorch/graph_fx.py
new file mode 100644
index 0000000..2e76c07
--- /dev/null
+++ b/nirtorch/graph_fx.py
@@ -0,0 +1,207 @@
+from typing import Any, Callable, Dict, Optional, Set, Tuple
+import operator
+
+import numpy as np
+
+import nir
+import torch
+from torch.fx import GraphModule, Node, Tracer, Transformer
+
+
+class NIRTorchTracer(Tracer):
+
+    def __init__(
+        self,
+        custom_leaf_modules: Optional[Tuple[torch.nn.Module, ...]] = None,
+        **kwargs,
+    ):
+        """Extends PyTorch's default symbolic tracing with a set of custom leaf modules.
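+
+        A sketch of intended usage, where ``MyNeuron`` stands in for a
+        hypothetical module that should be kept opaque during tracing::
+
+            tracer = NIRTorchTracer(custom_leaf_modules=(MyNeuron,))
+            fx_graph = tracer.trace(model)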
+        """
+        super().__init__(**kwargs)
+        if custom_leaf_modules is not None and not isinstance(
+            custom_leaf_modules, tuple
+        ):
+            custom_leaf_modules = tuple(custom_leaf_modules)
+        self.custom_leaf_modules = custom_leaf_modules
+
+    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+        """
+        A method to specify whether a given ``nn.Module`` is a "leaf" module.
+
+        Leaf modules are the atomic units that appear in the IR, referenced by
+        ``call_module`` calls. By default, modules in the PyTorch standard
+        library namespace (torch.nn) are leaf modules. All other modules are
+        traced through and their constituent ops are recorded, unless they are
+        registered as custom leaves.
+
+        Args:
+            m (Module): The module being queried about
+            module_qualified_name (str): The path to root of this module. For
+                example, if you have a module hierarchy where submodule ``foo``
+                contains submodule ``bar``, which contains submodule ``baz``,
+                that module will appear with the qualified name ``foo.bar.baz``
+                here.
+        """
+        # Test whether the module is in the set of custom leaves
+        if self.custom_leaf_modules and isinstance(m, self.custom_leaf_modules):
+            return True
+
+        if hasattr(m, "_is_leaf_module") and m._is_leaf_module:
+            return True
+
+        return super().is_leaf_module(m, module_qualified_name)
+
+
+class NIRTorchTransformer(Transformer):
+    """A pass-through transformer, kept as a hook for future NIR-specific rewrites."""
+
+    def call_function(self, target: str, args: Tuple, kwargs: Dict) -> Any:
+        return super().call_function(target, args, kwargs)
+
+    def call_method(self, target: str, args: Tuple, kwargs: Dict) -> Any:
+        return super().call_method(target, args, kwargs)
+
+    def call_module(self, target, args, kwargs):
+        return super().call_module(target, args, kwargs)
+
+
+def trace_torch_graph(
+    module: torch.nn.Module,
+    module_map: Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]],
+    default_dict: Optional[
+        Dict[torch.nn.Module, Callable[[torch.nn.Module], nir.NIRNode]]
+    ] = None,
+) -> nir.NIRGraph:
+    """
+    Traces a PyTorch module and converts it to a NIR graph using the specified
+    module map.
+
+    Args:
+        module (torch.nn.Module): The module of interest
+        module_map (Dict[torch.nn.Module, Callable[[torch.nn.Module],
+            nir.NIRNode]]): A dictionary that maps a given module type to a
+            function that can convert the module to an NIRNode type
+        default_dict (Optional[Dict[torch.nn.Module, Callable[[torch.nn.Module],
+            nir.NIRNode]]]): An optional dictionary that maps a given module
+            type to a function that can convert the module to an NIRNode type.
+            This dictionary is merged with the module_map dictionary.
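+
+    Example (a sketch; assumes ``model`` is the module to convert and that
+    ``torch.nn.Linear`` layers should map to ``nir.Affine`` nodes)::
+
+        module_map = {
+            torch.nn.Linear: lambda m: nir.Affine(
+                m.weight.detach().numpy(), m.bias.detach().numpy()
+            )
+        }
+        graph = trace_torch_graph(model, module_map)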
+    """
+    # Merge the default dictionary, if it exists
+    if default_dict is not None:
+        module_map = module_map | default_dict
+
+    # Cover the edge case that the incoming module is a leaf node
+    if module.__class__ in module_map:
+        return module_map[module.__class__](module)
+
+    # Trace the graph
+    tracer = NIRTorchTracer(module_map.keys())
+    traced = tracer.trace(module)
+
+    if len(traced.nodes) == 2 and len(list(tracer.root.children())) == 0:
+        raise ValueError(
+            "The module is a leaf node, but does not appear in the module map, "
+            "so it cannot be traced further"
+        )
+
+    graph_module = GraphModule(tracer.root, traced)
+
+    # Create NIR nodes
+    nodes = {}
+    edges = []
+    ignored_nodes = set()
+    skipped_nodes = set()
+
+    def _find_users(node: Node) -> Set[Node]:
+        """Finds all the users (outputs) of a given node, recursing through
+        nodes registered as skipped."""
+        nodes = set()
+        for user in node.users:
+            if user in ignored_nodes:
+                continue
+            elif user in skipped_nodes:
+                nodes |= _find_users(user)
+            else:
+                nodes.add(user)
+        return nodes
+
+    def _find_inputs(node: Node) -> Set[Node]:
+        """Finds all the inputs of a given node, recursing through nodes
+        registered as skipped."""
+        nodes = set()
+        for in_node in node.all_input_nodes:
+            if in_node in ignored_nodes:
+                continue
+            elif in_node in skipped_nodes:
+                nodes |= _find_inputs(in_node)
+            else:
+                nodes.add(in_node)
+        return nodes
+
+    for node in traced.nodes:
+        # Add the node
+        if node.op == "placeholder":
+            if node.target == "input" or node.prev.op == "root":
+                nodes[str(node.name)] = nir.Input(np.array([1]))
+            else:
+                ignored_nodes.add(node)
+                continue
+        elif node.op == "output":
+            nodes[str(node.name)] = nir.Output(np.array([1]))
+        elif node.op == "call_function":
+            # Ensure that we skip add nodes
+            # TODO: Consider using transformations for this
+            # https://pytorch.org/docs/stable/fx.html#torch.fx.Transformer
+            if node.target == operator.add:
+                skipped_nodes.add(node)
+            # Raise an error if we encounter functions other than addition
+            else:
+                raise ValueError(
+                    "The only supported function is addition. Please modify "
+                    "your model or raise an issue on GitHub"
+                )
+        elif node.op == "call_method":
+            # Skip add methods
+            if node.target == "add":
+                skipped_nodes.add(node)
+            else:
+                raise ValueError(
+                    "The only supported method is addition. Please modify "
+                    "your model or raise an issue on GitHub"
+                )
+        elif node.op == "call_module":
+            torch_module = graph_module.get_submodule(node.target)
+            nir_module = module_map[torch_module.__class__](torch_module)
+            nodes[str(node.name)] = nir_module
+        elif node.op == "get_attr":
+            # Skip attributes
+            skipped_nodes.add(node)
+        else:
+            raise ValueError(
+                f"Unsupported operation {node.op}. Please modify your model "
+                "or raise an issue on GitHub"
+            )
+
+    # Create edges
+    # - This is done in a separate loop to ensure that edges are correctly
+    #   ignored, even when the nodes are ignored out of order
+    for node in traced.nodes:
+        if node in ignored_nodes:
+            continue
+
+        # Add edges
+        for in_node in node.all_input_nodes:
+            if in_node in ignored_nodes or in_node in skipped_nodes:
+                continue
+            # If the node is set to be skipped, we simply forward its input to
+            # all its users
+            if node in skipped_nodes:
+                for next_node in _find_users(node):
+                    edges.append((in_node.name, next_node.name))
+            # Otherwise, add a regular edge
+            else:
+                edges.append((in_node.name, node.name))
+    graph = nir.NIRGraph(nodes=nodes, edges=edges)
+    graph.infer_types()
+    return graph
+
+
+if __name__ == "__main__":
+    module = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1))
+    # Map torch.nn.Linear modules to nir.Affine nodes for the demo
+    module_map = {
+        torch.nn.Linear: lambda m: nir.Affine(
+            m.weight.detach().numpy(), m.bias.detach().numpy()
+        )
+    }
+    graph = trace_torch_graph(module, module_map)
+
+    import pprint
+
+    pprint.pprint(graph)
diff --git a/nirtorch/to_nir.py b/nirtorch/to_nir.py
index 6a6de28..e37edb9 100644
--- a/nirtorch/to_nir.py
+++ b/nirtorch/to_nir.py
@@ -1,11 +1,22 @@
 import logging
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence
+import warnings
 
 import nir
 import numpy as np
 import torch.nn as nn
 
 from .graph import extract_torch_graph
+from .graph_fx import trace_torch_graph
+
+
+DEFAULT_MAPS: Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]] = {
+    nn.Linear: (
+        lambda module: nir.Affine(
+            module.weight.detach().numpy(), module.bias.detach().numpy()
+        )
+    )
+}
 
 
 def extract_nir_graph(
@@ -39,6 +50,11 @@ def extract_nir_graph(
     Returns:
         nir.NIR: Returns the generated NIR graph representation.
     """
+    warnings.warn(
+        "extract_nir_graph is deprecated, use trace_nir_graph instead",
+        DeprecationWarning,
+        stacklevel=2,
+    )
 
     if len(list(model.children())):
         # If the model has submodules, ignore the top level module
@@ -152,3 +168,29 @@ def extract_nir_graph(
             nir_edges.remove(edge)
 
     return nir.NIRGraph(nir_nodes, nir_edges)
+
+
+def trace_nir_graph(
+    model: nn.Module,
+    model_map: Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]],
+    default_map: Optional[
+        Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]
+    ] = DEFAULT_MAPS,
+    model_name: Optional[str] = "model",
+    ignore_submodules_of=None,
+    model_fwd_args: Sequence[Any] = (),
+    ignore_dims: Optional[Sequence[int]] = None,
+) -> nir.NIRNode:
+    """
+    Given a PyTorch `model`, we trace it and recreate a NIR graph using the
+    specified `model_map`.
+
+    Args:
+        model (nn.Module): The model of interest
+        model_map (Dict[nn.Module, Callable[[nn.Module], nir.NIRNode]]): A
+            dictionary that maps a given module type to a function that can
+            convert the module to an NIRNode type
+        default_map (Optional[Dict[nn.Module, Callable[[nn.Module],
+            nir.NIRNode]]]): An optional dictionary of default mappings that
+            is merged with `model_map`. Defaults to `DEFAULT_MAPS`.
+        model_name (Optional[str], optional): The name of the top level
+            module. Defaults to "model".
+        ignore_submodules_of (Optional[Sequence[nn.Module]]): If specified,
+            the corresponding modules' children will not be traversed when
+            building the graph.
+        model_fwd_args (Sequence[Any]): Extra positional arguments to pass to
+            the model's forward method. Defaults to the empty sequence.
+        ignore_dims (Optional[Sequence[int]]): Dimensions of data to be
+            ignored for type/shape inference. Typically the dimensions that
+            you will want to ignore are for batch and time.
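+
+    Example (a sketch; relies on the ``DEFAULT_MAPS`` entry that converts
+    ``torch.nn.Linear`` into ``nir.Affine``)::
+
+        model = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1))
+        graph = trace_nir_graph(model, model_map={})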
+ """ diff --git a/tests/test_graph_fx.py b/tests/test_graph_fx.py new file mode 100644 index 0000000..c30f831 --- /dev/null +++ b/tests/test_graph_fx.py @@ -0,0 +1,190 @@ +import pytest + +import nir +import numpy as np +import torch + +from nirtorch.to_nir import DEFAULT_MAPS +from nirtorch.graph_fx import trace_torch_graph + + +def _filter_edges(graph, t1, t2): + return [ + e + for e in graph.edges + if graph.nodes[e[0]].__class__ == t1 and graph.nodes[e[1]].__class__ == t2 + ] + + +def _filter_nodes(graph, t): + return {k: v for k, v in graph.nodes.items() if v.__class__ == t} + + +def test_trace_unknown_leaf(): + class MyModule(torch.nn.Module): + def forward(self, x): + return x + + model = MyModule() + with pytest.raises(ValueError): + trace_torch_graph(model, {}) + + +def test_trace_default_linear(): + model = torch.nn.Linear(1, 1) + graph = trace_torch_graph(model, DEFAULT_MAPS) + assert graph.__class__ == nir.Affine + + +def test_trace_mapped_module(): + class MyModule(torch.nn.Module): + def forward(self, x): + return x + + def map_my_module(module): + return nir.Linear(np.array([[1]])) + + model = MyModule() + graph = trace_torch_graph(model, {MyModule: map_my_module}) + assert graph.__class__ == nir.Linear + + +def test_trace_mapped_module_stateless(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(1, 1) + + def forward(self, x, state): + return self.lin(x) + state + + model = MyModule() + graph = trace_torch_graph(model, DEFAULT_MAPS) + assert len(graph.nodes) == 3 + assert len(graph.edges) == 2 + assert len(_filter_edges(graph, nir.Input, nir.Affine)) == 1 + assert len(_filter_edges(graph, nir.Affine, nir.Output)) == 1 + + +def test_trace_mapped_module_stateful(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(1, 1) + self.state = torch.tensor([1.0]) + + def forward(self, x): + return self.lin(x) + self.state + + model = MyModule() + graph = trace_torch_graph(model, DEFAULT_MAPS) + assert len(graph.nodes) == 3 + assert len(graph.edges) == 2 + assert len(_filter_edges(graph, nir.Input, nir.Affine)) == 1 + assert len(_filter_edges(graph, nir.Affine, nir.Output)) == 1 + + +def test_trace_addition(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(1, 1) + + def forward(self, x): + return self.lin(x) + self.lin(x) + + model = MyModule() + graph = trace_torch_graph(model, DEFAULT_MAPS) + assert len(graph.nodes) == 4 + assert len(graph.edges) == 4 + assert len(_filter_edges(graph, nir.Input, nir.Affine)) == 2 + assert len(_filter_edges(graph, nir.Affine, nir.Output)) == 2 + + +def test_trace_sequential(): + model = torch.nn.Sequential(torch.nn.Linear(2, 1), torch.nn.Linear(1, 1)) + graph = trace_torch_graph(model, DEFAULT_MAPS) + assert graph.__class__ == nir.NIRGraph + assert len(graph.nodes) == 4 + assert len(graph.edges) == 3 + nodes = list(graph.nodes.items()) + assert len(_filter_nodes(graph, nir.Input)) == 1, "We require one input node" + assert len(_filter_nodes(graph, nir.Output)) == 1, "We require one output node" + assert len(_filter_nodes(graph, nir.Affine)) == 2, "We require two affine nodes" + + assert len(_filter_edges(graph, nir.Input, nir.Affine)) == 1 + assert len(_filter_edges(graph, nir.Affine, nir.Affine)) == 1 + assert len(_filter_edges(graph, nir.Affine, nir.Output)) == 1 + + +def test_trace_submodule(): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, 
+            self.linear = torch.nn.Linear(1, 1)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    model = MyModule()
+    graph = trace_torch_graph(model, DEFAULT_MAPS)
+    assert graph.__class__ == nir.NIRGraph
+    assert len(graph.nodes) == 3
+    assert len(graph.edges) == 2
+    assert len(_filter_nodes(graph, nir.Input)) == 1, "We require one input node"
+    assert len(_filter_nodes(graph, nir.Output)) == 1, "We require one output node"
+    assert len(_filter_nodes(graph, nir.Affine)) == 1, "We require one affine node"
+
+    assert len(_filter_edges(graph, nir.Input, nir.Affine)) == 1
+    assert len(_filter_edges(graph, nir.Affine, nir.Output)) == 1
+
+
+def test_trace_nested_submodule():
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super(MyModule, self).__init__()
+            self.linear = torch.nn.Linear(1, 1)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    class MyModule2(torch.nn.Module):
+        def __init__(self):
+            super(MyModule2, self).__init__()
+            self.module = MyModule()
+            self.linear = torch.nn.Linear(1, 1)
+
+        def forward(self, x):
+            return self.linear(self.module(x))
+
+    model = MyModule2()
+    graph = trace_torch_graph(model, DEFAULT_MAPS)
+    assert graph.__class__ == nir.NIRGraph
+    assert len(graph.nodes) == 4
+    assert len(graph.edges) == 3
+
+    assert len(_filter_nodes(graph, nir.Input)) == 1, "We require one input node"
+    assert len(_filter_nodes(graph, nir.Output)) == 1, "We require one output node"
+    assert len(_filter_nodes(graph, nir.Affine)) == 2, "We require two affine nodes"
+
+    assert len(_filter_edges(graph, nir.Input, nir.Affine)) == 1
+    assert len(_filter_edges(graph, nir.Affine, nir.Affine)) == 1
+    assert len(_filter_edges(graph, nir.Affine, nir.Output)) == 1
+
+
+def test_recursive_stateful():
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(1, 1)
+            self.state = torch.tensor([1.0])
+
+        def forward(self, x):
+            out = self.linear(x + self.state)
+            self.state = out
+            return out
+
+    model = MyModule()
+    graph = trace_torch_graph(model, DEFAULT_MAPS)
+    assert graph.__class__ == nir.NIRGraph