Improve LayersHistogramsLogger to support inspecting multiple explicitly named submodules (#177)

+ Fix outdated references to previous implementation in the docstring
nathanpainchaud authored Oct 30, 2023
1 parent 58e47bd commit e91dcab
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions vital/callbacks/debug.py
@@ -62,7 +62,7 @@ class LayersHistogramsLogger(Callback):
     def __init__(
         self,
         layer_types: Sequence[Union[str, Type[nn.Module]]],
-        submodule: str = None,
+        submodules: Sequence[str] = None,
         log_every_n_steps: int = 50,
         include_weight: bool = True,
         include_grad: bool = True,
@@ -71,17 +71,18 @@ def __init__(
         Args:
             layer_types: Types or classpaths of layers for which to log histograms.
-            submodule: Name of the module (e.g. 'encoder', 'classifier.', etc.) inside which to search for matching
-                layers. If none is provided, the Lightning module will be inspected starting from its root.
+            submodules: Names of the submodules (e.g. 'encoder', 'classifier', etc.) inside which to search for
+                matching layers. If none is provided, the Lightning module will be inspected starting from its root.
             log_every_n_steps: Frequency at which to log the layers' histograms computed during the forward pass.
-            include_weight: Whether to log the layers' weights alongside the attention weights w.r.t the input tokens.
-            include_grad: Whether to log the layers' gradients alongside the attention weights w.r.t the input tokens.
+            include_weight: Whether to log the layers' weights.
+            include_grad: Whether to log the layers' gradients.
         """
         self.layer_types = [
             layer_type if not isinstance(layer_type, str) else import_from_module(layer_type)
             for layer_type in layer_types
         ]
-        self.submodule_name = submodule
+        self.submodule_names = submodules
 
         self.log_every_n_steps = log_every_n_steps
         self.train_step = 0
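
For context, a minimal usage sketch of the updated constructor follows. The LightningModule attribute names ('encoder', 'classifier.head') and the trainer wiring are hypothetical, not part of this commit; the dotted name relies on the attribute-by-attribute resolution performed in setup() below.

import torch.nn as nn

from vital.callbacks.debug import LayersHistogramsLogger

# Hypothetical configuration: watch every Linear and LayerNorm layer inside two
# explicitly named submodules of the LightningModule, one of them nested ("dotted").
# Per the docstring, classpath strings are also accepted in place of class objects.
histograms_logger = LayersHistogramsLogger(
    layer_types=[nn.Linear, nn.LayerNorm],
    submodules=["encoder", "classifier.head"],  # plural argument introduced by this commit
    log_every_n_steps=100,
)

# The callback would then be registered like any other Lightning callback, e.g.
# trainer = pl.Trainer(callbacks=[histograms_logger])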
@@ -96,22 +96,27 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
             pl_module: `LightningModule` used in the experiment.
             stage: Current stage (e.g. fit, test, etc.) of the experiment.
         """
-        # Extract the requested submodule from the root module
-        module = pl_module
-        if self.submodule_name:
-            for submodule_name in self.submodule_name.split("."):
-                module = getattr(module, submodule_name)
 
         def _convert_camel_case_to_snake_case(string: str) -> str:
             return re.sub("(?!^)([A-Z]+)", r"_\1", string).lower()
 
-        # Get references to the specific layers to watch
+        # Extract the requested submodules from the root module
+        submodules_to_inspect = {"self": pl_module}
+        if self.submodule_names:
+            submodules_to_inspect = {}
+            for submodule_name in self.submodule_names:
+                # For each submodule, (recursively) follow the chain of attributes to get the actual submodule
+                module = pl_module
+                for attr_name in submodule_name.split("."):
+                    module = getattr(module, attr_name)
+                submodules_to_inspect[submodule_name] = module
+
+        # Get references to the specific layers to watch, prepending the submodule they're from to the layer name
         self.layers_to_log = {
-            f"{self.submodule_name}.{_convert_camel_case_to_snake_case(layer.__class__.__name__)}_{layer_idx}": layer
+            f"{submodule_name}.{_convert_camel_case_to_snake_case(layer.__class__.__name__)}_{layer_idx}": layer
+            for submodule_name, submodule in submodules_to_inspect.items()
             for layer_type in self.layer_types
-            for layer_idx, layer in enumerate(
-                submodule for submodule in module.modules() if isinstance(submodule, layer_type)
-            )
+            for layer_idx, layer in enumerate(layer for layer in submodule.modules() if isinstance(layer, layer_type))
         }

     def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
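
To illustrate the naming scheme built in setup() above, here is a small standalone sketch, independent of Lightning and with made-up submodules, that mirrors the dictionary comprehension and shows the kind of keys layers_to_log ends up with:

import re

import torch.nn as nn

def _convert_camel_case_to_snake_case(string: str) -> str:
    return re.sub("(?!^)([A-Z]+)", r"_\1", string).lower()

# Made-up model with an 'encoder' and a 'classifier' submodule
model = nn.ModuleDict(
    {
        "encoder": nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)),
        "classifier": nn.Sequential(nn.Linear(4, 2)),
    }
)

layer_types = [nn.Linear]
submodules_to_inspect = {name: model[name] for name in ["encoder", "classifier"]}

# Same comprehension as in the callback's setup(), outside of Lightning
layers_to_log = {
    f"{submodule_name}.{_convert_camel_case_to_snake_case(layer.__class__.__name__)}_{layer_idx}": layer
    for submodule_name, submodule in submodules_to_inspect.items()
    for layer_type in layer_types
    for layer_idx, layer in enumerate(l for l in submodule.modules() if isinstance(l, layer_type))
}

print(list(layers_to_log))
# ['encoder.linear_0', 'encoder.linear_1', 'classifier.linear_0']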
