diff --git a/torchlens/constants.py b/torchlens/constants.py
index bcde820..278e8fb 100644
--- a/torchlens/constants.py
+++ b/torchlens/constants.py
@@ -105,6 +105,8 @@
"elapsed_time_total",
"elapsed_time_function_calls",
"elapsed_time_torchlens_logging",
+ # Lookup info
+ "func_argnames"
]
TENSOR_LOG_ENTRY_FIELD_ORDER = [
@@ -158,6 +160,7 @@
"func_call_stack",
"func_time_elapsed",
"func_rng_states",
+ "func_argnames",
"num_func_args_total",
"num_position_args",
"num_keyword_args",
diff --git a/torchlens/model_history.py b/torchlens/model_history.py
index b0583d5..35d2ea6 100644
--- a/torchlens/model_history.py
+++ b/torchlens/model_history.py
@@ -137,6 +137,7 @@ def __init__(self, fields_dict: Dict):
self.func_call_stack = fields_dict["func_call_stack"]
self.func_time_elapsed = fields_dict["func_time_elapsed"]
self.func_rng_states = fields_dict["func_rng_states"]
+ self.func_argnames = fields_dict["func_argnames"]
self.num_func_args_total = fields_dict["num_func_args_total"]
self.num_position_args = fields_dict["num_position_args"]
self.num_keyword_args = fields_dict["num_keyword_args"]
@@ -862,6 +863,9 @@ def __init__(
self.elapsed_time_function_calls: float = 0
self.elapsed_time_torchlens_logging: float = 0
+ # Reference info
+ self.func_argnames: Dict[str, tuple] = defaultdict(lambda: tuple([]))
+
# ********************************************
# ********** User-Facing Functions ***********
# ********************************************
@@ -1092,6 +1096,8 @@ def decorate_pytorch(
if not hasattr(local_func_namespace, func_name):
continue
orig_func = getattr(local_func_namespace, func_name)
+ if func_name not in self.func_argnames:
+ self.get_func_argnames(orig_func, func_name)
if getattr(orig_func, "__name__", False) == "wrapped_func":
continue
new_func = self.torch_func_decorator(orig_func)
@@ -1185,6 +1191,38 @@ def collect_orig_func_defs(
orig_func = getattr(local_func_namespace, func_name)
orig_func_defs.append((namespace_name, func_name, orig_func))
+ # TODO: hard-code some of the arg names; for example truediv, getitem, etc. Can crawl through and see what isn't working
+    def get_func_argnames(self, orig_func: Callable, func_name: str) -> None:
+        """Record the argument names of ``orig_func`` in ``self.func_argnames``.
+
+        The names are first read from ``inspect.signature``. If no signature can
+        be obtained, they are parsed from the first parenthesised span of the
+        docstring (e.g. ``"abs(input, *, out=None) -> Tensor"``). When neither
+        source yields an argument list, ``self.func_argnames`` is left untouched.
+
+        ``self`` and ``cls`` are never recorded. In the docstring fallback the
+        span ends at the first ``)`` after the opening parenthesis, so a default
+        value that itself contains parentheses truncates the list at that point.
+
+        Args:
+            orig_func: The function whose argument names are wanted.
+            func_name: Key under which the resulting tuple of names is stored.
+        """
+        ignored_names = ("self", "cls")
+
+        try:
+            parameters = inspect.signature(orig_func).parameters
+        except (ValueError, TypeError):
+            # ValueError: no signature could be provided for this callable;
+            # TypeError: this kind of object is not supported by inspect.
+            parameters = None
+
+        if parameters is not None:
+            self.func_argnames[func_name] = tuple(
+                name for name in parameters if name not in ignored_names
+            )
+            return
+
+        docstring = getattr(orig_func, "__doc__", None)
+        if not isinstance(docstring, str) or not docstring:
+            return
+
+        # Without a "(...)" span there is nothing to parse; bail out instead of
+        # slicing with the -1 that str.find returns on failure.
+        open_ind = docstring.find("(")
+        if open_ind == -1:
+            return
+        close_ind = docstring.find(")", open_ind + 1)
+        if close_ind == -1:
+            return
+
+        argnames = []
+        for arg in docstring[open_ind + 1:close_ind].split(","):
+            # Keep only the bare name: drop default values and annotations.
+            argname = arg.split("=")[0].split(":")[0].strip()
+            # "*args" / "**kwargs" are stored without their stars; a bare "*"
+            # becomes empty and is skipped below together with "/".
+            argname = argname.replace("*", "")
+            if not argname or argname == "/" or argname in ignored_names:
+                continue
+            argnames.append(argname)
+        self.func_argnames[func_name] = tuple(argnames)
+
###########################
##### Model Functions #####
###########################
@@ -1786,17 +1824,27 @@ def run_and_log_inputs_through_model(
time.time() - self.pass_start_time - self.elapsed_time_setup
)
self.track_tensors = False
- output_tensors_w_addresses = get_vars_of_type_from_obj(
+ output_tensors_w_addresses_all = get_vars_of_type_from_obj(
outputs,
torch.Tensor,
search_depth=5,
return_addresses=True,
allow_repeats=True,
)
+ # Remove duplicate addresses
+ addresses_used = []
+ output_tensors_w_addresses = []
+ for entry in output_tensors_w_addresses_all:
+ if entry[1] in addresses_used:
+ continue
+ output_tensors_w_addresses.append(entry)
+ addresses_used.append(entry[1])
+
output_tensors = [t for t, _, _ in output_tensors_w_addresses]
output_tensor_addresses = [
addr for _, addr, _ in output_tensors_w_addresses
]
+
for t in output_tensors:
self.output_layers.append(t.tl_tensor_label_raw)
self.raw_tensor_dict[t.tl_tensor_label_raw].is_output_parent = True
@@ -2016,6 +2064,7 @@ def log_source_tensor_exhaustive(
"func_call_stack": self._get_call_stack_dicts(),
"func_time_elapsed": 0,
"func_rng_states": log_current_rng_states(),
+ "func_argnames": tuple([]),
"num_func_args_total": 0,
"num_position_args": 0,
"num_keyword_args": 0,
@@ -2254,6 +2303,7 @@ def log_function_output_tensors_exhaustive(
fields_dict["func_call_stack"] = self._get_call_stack_dicts()
fields_dict["func_time_elapsed"] = func_time_elapsed
fields_dict["func_rng_states"] = func_rng_states
+ fields_dict["func_argnames"] = self.func_argnames[func_name.strip('_')]
fields_dict["num_func_args_total"] = len(args) + len(kwargs)
fields_dict["num_position_args"] = len(args)
fields_dict["num_keyword_args"] = len(kwargs)
@@ -3304,6 +3354,7 @@ def _add_output_layers(
new_output_node.func_call_stack = self._get_call_stack_dicts()
new_output_node.func_time_elapsed = 0
new_output_node.func_rng_states = log_current_rng_states()
+ new_output_node.func_argnames = tuple([])
new_output_node.num_func_args_total = 0
new_output_node.num_position_args = 0
new_output_node.num_keyword_args = 0
@@ -4906,6 +4957,12 @@ def render_graph(
vis_opt: str = "unrolled",
vis_nesting_depth: int = 1000,
vis_outpath: str = "modelgraph",
+ vis_graph_overrides: Dict = None,
+ vis_node_overrides: Dict = None,
+ vis_nested_node_overrides: Dict = None,
+ vis_edge_overrides: Dict = None,
+ vis_gradient_edge_overrides: Dict = None,
+ vis_module_overrides: Dict = None,
save_only: bool = False,
vis_fileformat: str = "pdf",
show_buffer_layers: bool = False,
@@ -4924,6 +4981,19 @@ def render_graph(
direction: which way the graph should go: either 'bottomup', 'topdown', or 'leftright'
"""
+ if vis_graph_overrides is None:
+ vis_graph_overrides = {}
+ if vis_node_overrides is None:
+ vis_node_overrides = {}
+ if vis_nested_node_overrides is None:
+ vis_nested_node_overrides = {}
+ if vis_edge_overrides is None:
+ vis_edge_overrides = {}
+ if vis_gradient_edge_overrides is None:
+ vis_gradient_edge_overrides = {}
+ if vis_module_overrides is None:
+ vis_module_overrides = {}
+
if not self.all_layers_logged:
raise ValueError(
"Must have all layers logged in order to render the graph; either save all layers,"
@@ -4976,16 +5046,21 @@ def render_graph(
comment="Computational graph for the feedforward sweep",
format=vis_fileformat,
)
- dot.graph_attr.update(
- {
- "rankdir": rankdir,
- "label": graph_caption,
- "labelloc": "t",
- "labeljust": "left",
- "ordering": "out",
- }
- )
- dot.node_attr.update({"shape": "box", "ordering": "out"})
+
+ graph_args = {'rankdir': rankdir,
+ 'label': graph_caption,
+ 'labelloc': 't',
+ 'labeljust': 'left',
+ 'ordering': 'out'}
+
+ for arg_name, arg_val in vis_graph_overrides.items():
+ if callable(arg_val):
+ graph_args[arg_name] = str(arg_val(self))
+ else:
+ graph_args[arg_name] = str(arg_val)
+
+ dot.graph_attr.update(graph_args)
+ dot.node_attr.update({"ordering": "out"})
# list of edges for each subgraph; subgraphs will be created at the end.
module_cluster_dict = defaultdict(
@@ -5006,16 +5081,19 @@ def render_graph(
collapsed_modules,
vis_nesting_depth,
show_buffer_layers,
+ vis_node_overrides,
+ vis_nested_node_overrides,
+ vis_edge_overrides,
+ vis_gradient_edge_overrides
)
# Finally, set up the subgraphs.
- self._set_up_subgraphs(dot, vis_opt, module_cluster_dict)
+ self._set_up_subgraphs(dot, vis_opt, module_cluster_dict, vis_module_overrides)
if in_notebook() and not save_only:
display(dot)
dot.render(vis_outpath, view=(not save_only))
- os.remove(vis_outpath)
def _add_node_to_graphviz(
self,
@@ -5027,6 +5105,10 @@ def _add_node_to_graphviz(
collapsed_modules: Set,
vis_nesting_depth: int = 1000,
show_buffer_layers: bool = False,
+ vis_node_overrides: Dict = None,
+ vis_collapsed_node_overrides: Dict = None,
+ vis_edge_overrides: Dict = None,
+ vis_gradient_edge_overrides: Dict = None
):
"""Addes a node and its relevant edges to the graphviz figure.
@@ -5043,12 +5125,12 @@ def _add_node_to_graphviz(
if is_collapsed_module:
self._construct_collapsed_module_node(
- node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth
+ node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth, vis_collapsed_node_overrides
)
node_color = "black"
else:
node_color = self._construct_layer_node(
- node, graphviz_graph, show_buffer_layers, vis_opt
+ node, graphviz_graph, show_buffer_layers, vis_opt, vis_node_overrides
)
self._add_edges_for_node(
@@ -5061,6 +5143,8 @@ def _add_node_to_graphviz(
graphviz_graph,
vis_opt,
show_buffer_layers,
+ vis_edge_overrides,
+ vis_gradient_edge_overrides
)
@staticmethod
@@ -5074,7 +5158,7 @@ def _check_if_collapsed_module(node, vis_nesting_depth):
else:
return False
- def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_opt):
+ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_opt, vis_node_overrides):
# Get the address, shape, color, and line style:
node_address, node_shape, node_color = self._get_node_address_shape_color(
@@ -5090,16 +5174,23 @@ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_op
# Get the text for the node label:
node_label = self._make_node_label(node, node_address, vis_opt)
- graphviz_graph.node(
- name=node.layer_label.replace(":", "pass"),
- label=node_label,
- fontcolor=node_color,
- color=node_color,
- style=f"filled,{line_style}",
- fillcolor=node_bg_color,
- shape=node_shape,
- ordering="out",
- )
+
+ node_args = {'name': node.layer_label.replace(":", "pass"),
+ 'label': node_label,
+ 'fontcolor': node_color,
+ 'color': node_color,
+ 'style': f'filled,{line_style}',
+ 'fillcolor': node_bg_color,
+ 'shape': node_shape,
+ 'ordering': 'out'
+ }
+ for arg_name, arg_val in vis_node_overrides.items():
+ if callable(arg_val):
+ node_args[arg_name] = str(arg_val(self, node))
+ else:
+ node_args[arg_name] = str(arg_val)
+
+ graphviz_graph.node(**node_args)
if node.is_last_output_layer:
with graphviz_graph.subgraph() as s:
@@ -5109,7 +5200,7 @@ def _construct_layer_node(self, node, graphviz_graph, show_buffer_layers, vis_op
return node_color
def _construct_collapsed_module_node(
- self, node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth
+ self, node, graphviz_graph, collapsed_modules, vis_opt, vis_nesting_depth, vis_collapsed_node_overrides
):
module_address_w_pass = node.containing_modules_origin_nested[
vis_nesting_depth - 1
@@ -5177,16 +5268,23 @@ def _construct_collapsed_module_node(
f"{module_nparams} parameters>"
)
- graphviz_graph.node(
- name=node_name,
- label=node_label,
- fontcolor="black",
- color="black",
- style=f"filled,{line_style}",
- fillcolor=bg_color,
- shape="box3d",
- ordering="out",
- )
+ node_args = {'name': node_name,
+ 'label': node_label,
+ 'fontcolor': 'black',
+ 'color': 'black',
+ 'style': f'filled,{line_style}',
+ 'fillcolor': bg_color,
+ 'shape': 'box3d',
+ 'ordering': 'out'
+ }
+
+ for arg_name, arg_val in vis_collapsed_node_overrides.items():
+ if callable(arg_val):
+ node_args[arg_name] = str(arg_val(self, node))
+ else:
+ node_args[arg_name] = str(arg_val)
+
+ graphviz_graph.node(**node_args)
def _get_node_address_shape_color(
self,
@@ -5380,6 +5478,8 @@ def _add_edges_for_node(
graphviz_graph,
vis_opt: str = "unrolled",
show_buffer_layers: bool = False,
+ vis_edge_overrides: Dict = None,
+ vis_gradient_edge_overrides: Dict = None
):
"""Add the rolled-up edges for a node, marking for the edge which passes it happened for.
@@ -5491,6 +5591,12 @@ def _add_edges_for_node(
if vis_opt == "rolled":
self._label_rolled_pass_nums(child_node, parent_node, edge_dict)
+ for arg_name, arg_val in vis_edge_overrides.items():
+ if callable(arg_val):
+ edge_dict[arg_name] = str(arg_val(self, parent_node, child_node))
+ else:
+ edge_dict[arg_name] = str(arg_val)
+
# Add it to the appropriate module cluster (most nested one containing both nodes)
containing_module = self._get_lowest_containing_module_for_two_nodes(
parent_node, child_node, both_nodes_collapsed_modules, vis_nesting_depth
@@ -5519,6 +5625,7 @@ def _add_edges_for_node(
containing_module,
module_edge_dict,
graphviz_graph,
+ vis_gradient_edge_overrides
)
def _label_node_arguments_if_needed(
@@ -5707,6 +5814,7 @@ def _add_gradient_edge(
containing_module,
module_edge_dict,
graphviz_graph,
+ vis_gradient_edge_overrides
):
"""Adds a backwards edge if both layers have saved gradients, showing the backward pass."""
if parent_layer.has_saved_grad and child_layer.has_saved_grad:
@@ -5719,13 +5827,19 @@ def _add_gradient_edge(
"arrowsize": ".7",
"labelfontsize": "8",
}
+ for arg_name, arg_val in vis_gradient_edge_overrides.items():
+ if callable(arg_val):
+ edge_dict[arg_name] = str(arg_val(self, parent_layer, child_layer))
+ else:
+ edge_dict[arg_name] = str(arg_val)
+
if containing_module != -1:
module_edge_dict[containing_module]["edges"].append(edge_dict)
else:
graphviz_graph.edge(**edge_dict)
def _set_up_subgraphs(
- self, graphviz_graph, vis_opt: str, module_edge_dict: Dict[str, List]
+ self, graphviz_graph, vis_opt: str, module_edge_dict: Dict[str, List], vis_module_overrides: Dict = None
):
"""Given a dictionary specifying the edges in each cluster and the graphviz graph object,
set up the nested subgraphs and the nodes that should go inside each of them. There will be some tricky
@@ -5763,6 +5877,7 @@ def _set_up_subgraphs(
nesting_depth,
max_nesting_depth,
vis_opt,
+ vis_module_overrides
)
def _setup_subgraphs_recurse(
@@ -5775,6 +5890,7 @@ def _setup_subgraphs_recurse(
nesting_depth,
max_nesting_depth,
vis_opt,
+ vis_module_overrides
):
"""Utility function to crawl down several layers deep into nested subgraphs.
@@ -5821,6 +5937,7 @@ def _setup_subgraphs_recurse(
nesting_depth + 1,
max_nesting_depth,
vis_opt,
+ vis_module_overrides
)
else: # we made it, make the subgraph and add all edges.
@@ -5835,13 +5952,20 @@ def _setup_subgraphs_recurse(
line_style = "solid"
else:
line_style = "dashed"
- s.attr(
- label=f"<@{subgraph_title}
({module_type})
>",
- labelloc="b",
- style=f"filled,{line_style}",
- fillcolor="white",
- penwidth=str(pen_width),
- )
+
+ module_args = {
+ 'label': f"<@{subgraph_title}
({module_type})
>",
+ 'labelloc': 'b',
+ 'style': f'filled,{line_style}',
+ 'fillcolor': 'white',
+ 'penwidth': str(pen_width)}
+
+ for arg_name, arg_val in vis_module_overrides.items():
+ if callable(arg_val):
+ module_args[arg_name] = str(arg_val(self, subgraph_name))
+ else:
+ module_args[arg_name] = str(arg_val)
+ s.attr(**module_args)
subgraph_edges = module_edge_dict[subgraph_name]["edges"]
for edge_dict in subgraph_edges:
s.edge(**edge_dict)