From e91dcab6ae12ad552557a44d243aad87b10aa93e Mon Sep 17 00:00:00 2001
From: Nathan Painchaud <23144457+nathanpainchaud@users.noreply.github.com>
Date: Mon, 30 Oct 2023 16:46:36 +0100
Subject: [PATCH] Improve `LayersHistogramsLogger` to support inspecting multiple explicitly named submodules (#177)

+ Fix outdated references to previous implementation in the docstring
---
 vital/callbacks/debug.py | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/vital/callbacks/debug.py b/vital/callbacks/debug.py
index b65e40d2..dc3d3cd0 100644
--- a/vital/callbacks/debug.py
+++ b/vital/callbacks/debug.py
@@ -62,7 +62,7 @@ class LayersHistogramsLogger(Callback):
     def __init__(
         self,
         layer_types: Sequence[Union[str, Type[nn.Module]]],
-        submodule: str = None,
+        submodules: Sequence[str] = None,
         log_every_n_steps: int = 50,
         include_weight: bool = True,
         include_grad: bool = True,
@@ -71,17 +71,18 @@ def __init__(
 
         Args:
             layer_types: Types or classpaths of layers for which to log histograms.
-            submodule: Name of the module (e.g. 'encoder', 'classifier.', etc.) inside which to search for matching
-                layers. If none is provided, the Lightning module will be inspected starting from its root.
+            submodules: Names of the submodules (e.g. 'encoder', 'classifier', etc.) inside which to search for
+                matching layers. If none is provided, the Lightning module will be inspected starting from its
+                root.
             log_every_n_steps: Frequency at which to log the attention weights computed during the forward pass.
-            include_weight: Whether to log the layers' weights alongside the attention weights w.r.t the input tokens.
-            include_grad: Whether to log the layers' gradients alongside the attention weights w.r.t the input tokens.
+            include_weight: Whether to log the layers' weights.
+            include_grad: Whether to log the layers' gradients.
         """
         self.layer_types = [
             layer_type if not isinstance(layer_type, str) else import_from_module(layer_type)
             for layer_type in layer_types
         ]
-        self.submodule_name = submodule
+        self.submodule_names = submodules
         self.log_every_n_steps = log_every_n_steps
 
         self.train_step = 0
@@ -96,22 +97,27 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
             pl_module: `LightningModule` used in the experiment.
             stage: Current stage (e.g. fit, test, etc.) of the experiment.
""" - # Extract the requested submodule from the root module - module = pl_module - if self.submodule_name: - for submodule_name in self.submodule_name.split("."): - module = getattr(module, submodule_name) def _convert_camel_case_to_snake_case(string: str) -> str: return re.sub("(?!^)([A-Z]+)", r"_\1", string).lower() - # Get references to the specific layers to watch + # Extract the requested submodule from the root module + submodules_to_inspect = {"self": pl_module} + if self.submodule_names: + submodules_to_inspect = {} + for submodule_name in self.submodule_names: + # For each submodule, (recursively) follow the chain of attributes to get the actual submodule + module = pl_module + for submodule_name in submodule_name.split("."): + module = getattr(module, submodule_name) + submodules_to_inspect[submodule_name] = module + + # Get references to the specific layers to watch, prepending the submodule they're from to the layer name self.layers_to_log = { - f"{self.submodule_name}.{_convert_camel_case_to_snake_case(layer.__class__.__name__)}_{layer_idx}": layer + f"{submodule_name}.{_convert_camel_case_to_snake_case(layer.__class__.__name__)}_{layer_idx}": layer + for submodule_name, submodule in submodules_to_inspect.items() for layer_type in self.layer_types - for layer_idx, layer in enumerate( - submodule for submodule in module.modules() if isinstance(submodule, layer_type) - ) + for layer_idx, layer in enumerate(layer for layer in submodule.modules() if isinstance(layer, layer_type)) } def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: