
Debugged issues with saving and visualizing gradients: Issue #33, Issue #34, Issue #35
JohnMark Taylor committed Jan 22, 2025
1 parent 6745a86 commit 2782bf4
Showing 2 changed files with 29 additions and 14 deletions.
39 changes: 27 additions & 12 deletions torchlens/logging_funcs.py
@@ -14,6 +14,7 @@
 
 if TYPE_CHECKING:
     from .model_history import ModelHistory
+    from .tensor_log import TensorLogEntry
 
 
 def save_new_activations(
@@ -275,6 +276,11 @@ def log_source_tensor_exhaustive(
         self.buffer_layers.append(tensor_label)
         self.internally_initialized_layers.append(tensor_label)
 
+    # Make it track gradients if relevant
+    if self.save_gradients:
+        _add_backward_hook(self, t, t.tl_tensor_label_raw)
+
 
 def log_source_tensor_fast(self, t: torch.Tensor, source: str):
     """NOTES TO SELF--fields to change are:
@@ -774,12 +780,11 @@ def _add_backward_hook(self, t: torch.Tensor, tensor_label):
         Nothing; it changes the tensor in place.
     """
 
-    # Define the decorator
-    def log_grad_to_model_history(_, g_out):
-        self._log_tensor_grad(g_out, tensor_label)
+    def log_grad_to_model_history(grad):
+        _log_tensor_grad(self, grad, tensor_label)
 
-    if t.grad_fn is not None:
-        t.grad_fn.register_hook(log_grad_to_model_history)
+    if (t.grad_fn is not None) or t.requires_grad:
+        t.register_hook(log_grad_to_model_history)
 
 
 def _log_info_specific_to_single_function_output_tensor(
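Note: two details of this fix are worth spelling out. A hook registered on grad_fn is a Node hook and receives tuples of gradients (hence the old (_, g_out) signature), while Tensor.register_hook receives the gradient tensor itself; and a leaf tensor that requires grad has no grad_fn at all, so the old `t.grad_fn is not None` guard silently skipped it. A standalone PyTorch sketch of both points, with illustrative names:

import torch

captured = {}

def save_grad(name):
    def hook(grad):  # Tensor.register_hook passes a single Tensor
        captured[name] = grad.detach().clone()
    return hook

leaf = torch.ones(3, requires_grad=True)
print(leaf.grad_fn)  # None: the old guard would have skipped this tensor

leaf.register_hook(save_grad("leaf"))
mid = leaf * 2                      # non-leaf: grad_fn is MulBackward0
mid.register_hook(save_grad("mid"))

mid.sum().backward()
print(captured["mid"])   # tensor([1., 1., 1.])
print(captured["leaf"])  # tensor([2., 2., 2.])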
@@ -948,14 +953,24 @@ def _log_tensor_grad(self, grad: torch.Tensor, tensor_label_raw: str):
     """
     self.has_saved_gradients = True
     tensor_label = self._raw_to_final_layer_labels[tensor_label_raw]
-    if tensor_label not in self.layers_with_saved_gradients:
-        self.layers_with_saved_gradients.append(tensor_label)
-        layer_order = {layer: i for i, layer in enumerate(self.layer_labels)}
-        self.layers_with_saved_gradients = sorted(
-            self.layers_with_saved_gradients, key=lambda x: layer_order[x]
-        )
     tensor_log_entry = self[tensor_label]
-    tensor_log_entry.log_tensor_grad(grad)
+    layers_to_update = [tensor_label]
+    if tensor_log_entry.is_output_parent:  # also update any linked outputs
+        for child_layer in tensor_log_entry.child_layers:
+            if self[child_layer].is_output_layer:
+                layers_to_update.append(child_layer)
+
+    for layer_label in layers_to_update:
+        layer = self[layer_label]
+        if layer_label not in self.layers_with_saved_gradients:
+            self.layers_with_saved_gradients.append(layer_label)
+
+        layer.log_tensor_grad(grad)
+
+    layer_order = {layer: i for i, layer in enumerate(self.layer_labels)}
+    self.layers_with_saved_gradients = sorted(
+        self.layers_with_saved_gradients, key=lambda x: layer_order[x]
+    )
 
 
 def _check_if_tensor_arg(arg: Any) -> bool:
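Note: the _log_tensor_grad rewrite does three things: it fans the incoming gradient out to any output layers linked to the parent layer, dedupes before appending, and re-sorts layers_with_saved_gradients into overall layer order on every call (the old code only sorted when the label was new). A minimal sketch of the same bookkeeping, using hypothetical stand-in data rather than torchlens' ModelHistory:

layer_labels = ["conv_1", "relu_1", "output_1"]           # overall layer order
child_layers = {"relu_1": ["output_1"]}                   # parent -> children
is_output_layer = {"conv_1": False, "relu_1": False, "output_1": True}

layers_with_saved_gradients = []
saved_grads = {}

def log_grad(label, grad):
    # Fan out to any child layers that are outputs, mirroring is_output_parent.
    layers_to_update = [label]
    for child in child_layers.get(label, []):
        if is_output_layer[child]:
            layers_to_update.append(child)
    for lbl in layers_to_update:
        if lbl not in layers_with_saved_gradients:
            layers_with_saved_gradients.append(lbl)
        saved_grads[lbl] = grad
    # Keep the saved-gradient list in overall layer order.
    order = {layer: i for i, layer in enumerate(layer_labels)}
    layers_with_saved_gradients.sort(key=lambda x: order[x])

log_grad("relu_1", "grad_placeholder")
print(layers_with_saved_gradients)  # ['relu_1', 'output_1']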
4 changes: 2 additions & 2 deletions torchlens/vis.py
@@ -918,8 +918,8 @@ def _add_gradient_edge(
         edge_dict = {
             "tail_name": child_layer.layer_label.replace(":", "pass"),
             "head_name": parent_layer.layer_label.replace(":", "pass"),
-            "color": self.GRADIENT_ARROW_COLOR,
-            "fontcolor": self.GRADIENT_ARROW_COLOR,
+            "color": GRADIENT_ARROW_COLOR,
+            "fontcolor": GRADIENT_ARROW_COLOR,
             "style": edge_style,
             "arrowsize": ".7",
             "labelfontsize": "8",
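Note: this vis.py fix references GRADIENT_ARROW_COLOR as the module-level constant it is, rather than as an instance attribute (self.GRADIENT_ARROW_COLOR would raise AttributeError when drawing gradient edges). A sketch of how such an edge dict feeds graphviz, assuming the python graphviz package; the color value and layer labels here are illustrative assumptions:

from graphviz import Digraph

GRADIENT_ARROW_COLOR = "blue"  # assumed value, for illustration only

dot = Digraph()
edge_dict = {
    "tail_name": "relu_1pass1",   # hypothetical labels, ":" replaced by "pass"
    "head_name": "conv_1pass1",
    "color": GRADIENT_ARROW_COLOR,
    "fontcolor": GRADIENT_ARROW_COLOR,
    "style": "dashed",
    "arrowsize": ".7",
    "labelfontsize": "8",
}
dot.edge(**edge_dict)  # Digraph.edge(tail_name, head_name, **attrs)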
