-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored codebase to break functionality into separate files.
- Loading branch information
JohnMark Taylor
committed
Jan 4, 2025
1 parent
bcc80f0
commit e07e784
Showing
10 changed files
with
7,073 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import warnings | ||
|
||
import torch | ||
|
||
from .constants import MODEL_HISTORY_FIELD_ORDER | ||
from .helper_funcs import remove_entry_from_list | ||
from .tensor_log import TensorLogEntry | ||
|
||
|
||
def cleanup(self):
    """Deletes all log entries in the model and clears all ModelHistory fields.

    Iterates over a snapshot of the entries (``list(self)``) because
    ``_remove_log_entry(..., remove_references=True)`` mutates the very
    structures that iterating ``self`` walks over; iterating the live
    object while removing entries risks skipped entries or runtime errors.
    """
    for tensor_log_entry in list(self):
        self._remove_log_entry(tensor_log_entry, remove_references=True)
    # Drop every ModelHistory field so the object holds no stale state.
    for attr in MODEL_HISTORY_FIELD_ORDER:
        delattr(self, attr)
    # Release any GPU memory held by the just-deleted tensors.
    torch.cuda.empty_cache()
|
||
|
||
def _remove_log_entry(
    self, log_entry: TensorLogEntry, remove_references: bool = True
):
    """Destroy a TensorLogEntry and, optionally, all references to it.

    Args:
        log_entry: Tensor log entry to remove.
        remove_references: Whether to also remove references to the log entry.
    """
    # After the forward pass, entries are addressed by their final layer
    # label; during the pass, by the raw tensor label.
    entry_label = (
        log_entry.layer_label if self._pass_finished else log_entry.tensor_label_raw
    )
    # Strip every public, non-callable attribute so anything the entry
    # holds (e.g. saved tensors) becomes garbage-collectable.
    for attr_name in dir(log_entry):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            is_public_data = (not attr_name.startswith("_")) and not callable(
                getattr(log_entry, attr_name)
            )
            if is_public_data:
                delattr(log_entry, attr_name)
    del log_entry
    if remove_references:
        _remove_log_entry_references(self, entry_label)
|
||
|
||
def _remove_log_entry_references(self, layer_to_remove: str):
    """Removes all references to a given TensorLogEntry in the ModelHistory object.

    Args:
        layer_to_remove: Label of the log entry whose references should be removed.
    """
    # Clear any flat list fields in ModelHistory referring to the entry.
    remove_entry_from_list(self.input_layers, layer_to_remove)
    remove_entry_from_list(self.output_layers, layer_to_remove)
    remove_entry_from_list(self.buffer_layers, layer_to_remove)
    remove_entry_from_list(self.internally_initialized_layers, layer_to_remove)
    remove_entry_from_list(self.internally_terminated_layers, layer_to_remove)
    remove_entry_from_list(self.internally_terminated_bool_layers, layer_to_remove)
    remove_entry_from_list(self.layers_with_saved_activations, layer_to_remove)
    remove_entry_from_list(self.layers_with_saved_gradients, layer_to_remove)
    remove_entry_from_list(
        self._layers_where_internal_branches_merge_with_input, layer_to_remove
    )

    # Drop any conditional-branch edges that touch the removed layer.
    self.conditional_branch_edges = [
        tup for tup in self.conditional_branch_edges if layer_to_remove not in tup
    ]

    # Now any nested fields. The same pruning applies to three different
    # grouping dicts, so it is factored into one local helper rather than
    # three copy-pasted loops.
    def _prune_groups(groups):
        # Remove layer_to_remove from each group, then discard empty groups.
        for group_tensors in groups.values():
            if layer_to_remove in group_tensors:
                group_tensors.remove(layer_to_remove)
        return {k: v for k, v in groups.items() if len(v) > 0}

    self.layers_computed_with_params = _prune_groups(self.layers_computed_with_params)
    self.equivalent_operations = _prune_groups(self.equivalent_operations)
    self.same_layer_operations = _prune_groups(self.same_layer_operations)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
import inspect | ||
import time | ||
import types | ||
import warnings | ||
from functools import wraps | ||
from typing import Callable, Dict, List, TYPE_CHECKING, Tuple | ||
|
||
import torch | ||
|
||
from .constants import ORIG_TORCH_FUNCS | ||
from .helper_funcs import (clean_to, get_vars_of_type_from_obj, identity, log_current_rng_states, make_random_barcode, | ||
nested_getattr, print_override, safe_copy) | ||
from .logging_funcs import log_function_output_tensors, log_source_tensor | ||
|
||
if TYPE_CHECKING: | ||
from model_history import ModelHistory | ||
|
||
funcs_not_to_log = ["numpy", "__array__", "size", "dim"] | ||
print_funcs = ["__repr__", "__str__", "_str"] | ||
|
||
|
||
def torch_func_decorator(self, func: Callable):
    """Wrap a torch function so every call is logged to ``self`` (a ModelHistory).

    The wrapper records timing, RNG state, and output tensors, and detects
    whether the call is "bottom-level" (i.e. not nested inside another
    decorated torch call).

    Args:
        func: The original torch function to wrap.

    Returns:
        The wrapped function, suitable for monkey-patching into torch.
    """
    @wraps(func)
    def wrapped_func(*args, **kwargs):
        # Initial bookkeeping; check if it's a special function, organize the arguments.
        # Resetting the barcode here marks the start of a (possibly nested) call.
        self.current_function_call_barcode = 0
        func_name = func.__name__
        # Skip logging entirely for excluded functions, or when tensor
        # tracking is off or logging is paused.
        if (
            (func_name in funcs_not_to_log)
            or (not self._track_tensors)
            or self._pause_logging
        ):
            out = func(*args, **kwargs)
            return out
        all_args = list(args) + list(kwargs.values())
        arg_tensorlike = get_vars_of_type_from_obj(all_args, torch.Tensor)

        # Register any buffer tensors in the arguments.

        for t in arg_tensorlike:
            if hasattr(t, 'tl_buffer_address'):
                log_source_tensor(self, t, 'buffer', getattr(t, 'tl_buffer_address'))

        # Print/repr functions get a special override so decorated tensors
        # still display normally.
        if (func_name in print_funcs) and (len(arg_tensorlike) > 0):
            out = print_override(args[0], func_name)
            return out

        # Copy the args and kwargs in case they change in-place:
        if self.save_function_args:
            arg_copies = tuple([safe_copy(arg) for arg in args])
            kwarg_copies = {k: safe_copy(v) for k, v in kwargs.items()}
        else:
            arg_copies = args
            kwarg_copies = kwargs

        # Call the function, tracking the timing, rng states, and whether it's a nested function
        func_call_barcode = make_random_barcode()
        self.current_function_call_barcode = func_call_barcode
        start_time = time.time()
        func_rng_states = log_current_rng_states()
        out_orig = func(*args, **kwargs)
        func_time_elapsed = time.time() - start_time
        # If a nested decorated call ran, it overwrote the barcode above, so
        # a still-matching barcode means this call was bottom-level.
        is_bottom_level_func = (
            self.current_function_call_barcode == func_call_barcode
        )

        # In-place mutators return their (mutated) first argument rather
        # than whatever the raw call returned.
        if func_name in ["__setitem__", "zero_", "__delitem__"]:
            out_orig = args[0]

        # Log all output tensors
        output_tensors = get_vars_of_type_from_obj(
            out_orig,
            which_type=torch.Tensor,
            subclass_exceptions=[torch.nn.Parameter],
        )

        if len(output_tensors) > 0:
            log_function_output_tensors(
                self,
                func,
                args,
                kwargs,
                arg_copies,
                kwarg_copies,
                out_orig,
                func_time_elapsed,
                func_rng_states,
                is_bottom_level_func,
            )

        return out_orig

    return wrapped_func
|
||
|
||
def decorate_pytorch(
    self: "ModelHistory", torch_module: types.ModuleType, orig_func_defs: List[Tuple]
) -> Dict[Callable, Callable]:
    """Mutates all PyTorch functions (TEMPORARILY!) to save the outputs of any functions
    that return Tensors, along with marking them with metadata.

    Args:
        torch_module: The top-level torch module (i.e., from "import torch").
            This is supplied as an argument on the off-chance that the user has imported torch
            and done their own monkey-patching.
        orig_func_defs: Supply a list from outside to guarantee it can be cleaned up properly;
            it is mutated in place (via collect_orig_func_defs) to record
            (namespace_name, func_name, orig_func) tuples, sufficient to return
            torch to normal when finished.

    Returns:
        Dict mapping each decorated function to its original and each original to its
        decorated version.
    """

    # Do a pass to save the original func defs.
    collect_orig_func_defs(torch_module, orig_func_defs)
    decorated_func_mapper = {}

    for namespace_name, func_name in ORIG_TORCH_FUNCS:
        namespace_name_notorch = namespace_name.replace("torch.", "")
        local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
        if not hasattr(local_func_namespace, func_name):
            continue
        orig_func = getattr(local_func_namespace, func_name)
        # Record the function's argument names the first time it is seen.
        if func_name not in self.func_argnames:
            get_func_argnames(self, orig_func, func_name)
        # Already decorated (e.g. repeated call): don't double-wrap.
        if getattr(orig_func, "__name__", False) == "wrapped_func":
            continue
        new_func = torch_func_decorator(self, orig_func)
        try:
            # Some namespaces refuse attribute assignment; skip them silently.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(local_func_namespace, func_name, new_func)
        except (AttributeError, TypeError) as _:
            pass
        new_func.tl_is_decorated_function = True
        decorated_func_mapper[new_func] = orig_func
        decorated_func_mapper[orig_func] = new_func

    # Bolt on the identity function
    new_identity = torch_func_decorator(self, identity)
    torch.identity = new_identity

    return decorated_func_mapper
|
||
|
||
def undecorate_pytorch(
    torch_module, orig_func_defs: List[Tuple], input_tensors: List[torch.Tensor]
):
    """
    Returns all PyTorch functions back to the definitions they had when mutate_pytorch was called.
    Also removes the bolted-on identity function and strips torchlens labels from the
    input tensors, to remove any references to the old ModelHistory object.

    Args:
        torch_module: The torch module object.
        orig_func_defs: List of tuples consisting of [namespace_name, func_name, orig_func], sufficient
            to regenerate the original functions.
        input_tensors: List of input tensors whose functions will be undecorated.
    """
    for namespace_name, func_name, orig_func in orig_func_defs:
        namespace_name_notorch = namespace_name.replace("torch.", "")
        local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
        try:
            # Some namespaces (e.g. C-level types) refuse attribute
            # assignment; skip those rather than failing the restoration.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                setattr(local_func_namespace, func_name, orig_func)
        except (AttributeError, TypeError) as _:
            continue
    # Remove the identity function bolted onto torch by decorate_pytorch.
    delattr(torch, "identity")
    for input_tensor in input_tensors:
        if hasattr(input_tensor, "tl_tensor_label_raw"):
            delattr(input_tensor, "tl_tensor_label_raw")
||
|
||
def undecorate_tensor(t, device: str = "cpu"):
    """Convenience function to replace the tensor with an unmutated version of itself, keeping the same data.

    Args:
        t: tensor or parameter object
        device: device to move the tensor to

    Returns:
        Unmutated tensor.
    """
    # Only plain tensors / parameters get copied; other objects pass through.
    is_plain_tensor = type(t) in [torch.Tensor, torch.nn.Parameter]
    new_t = safe_copy(t) if is_plain_tensor else t
    del t
    # Strip every torchlens bookkeeping attribute (prefixed "tl_").
    tl_attrs = [attr for attr in dir(new_t) if attr.startswith("tl_")]
    for attr in tl_attrs:
        delattr(new_t, attr)
    return clean_to(new_t, device)
|
||
|
||
def collect_orig_func_defs(
    torch_module: types.ModuleType, orig_func_defs: List[Tuple]
):
    """Collects the original torch function definitions, so they can be restored after the logging is done.

    Args:
        torch_module: The top-level torch module.
        orig_func_defs: List of tuples keeping track of the original function definitions;
            mutated in place.
    """
    for namespace_name, func_name in ORIG_TORCH_FUNCS:
        # Resolve the namespace path relative to the supplied torch module.
        subpath = namespace_name.replace("torch.", "")
        namespace_obj = nested_getattr(torch_module, subpath)
        if hasattr(namespace_obj, func_name):
            orig_func_defs.append(
                (namespace_name, func_name, getattr(namespace_obj, func_name))
            )
|
||
|
||
# TODO: hard-code some of the arg names; for example truediv, getitem, etc. Can crawl through and see what isn't working | ||
def get_func_argnames(self, orig_func: Callable, func_name: str): | ||
"""Attempts to get the argument names for a function, first by checking the signature, then | ||
by checking the documentation. Adds these names to func_argnames if it can find them, | ||
doesn't do anything if it can't.""" | ||
try: | ||
argnames = list(inspect.signature(orig_func).parameters.keys()) | ||
argnames = tuple([arg.replace('*', '') for arg in argnames if arg not in ['cls', 'self']]) | ||
self.func_argnames[func_name] = argnames | ||
return | ||
except ValueError: | ||
pass | ||
|
||
docstring = orig_func.__doc__ | ||
if (type(docstring) is not str) or (len(docstring) == 0): # if docstring missing, skip it | ||
return | ||
|
||
open_ind, close_ind = docstring.find('('), docstring.find(')') | ||
argstring = docstring[open_ind + 1: close_ind] | ||
arg_list = argstring.split(',') | ||
arg_list = [arg.strip(' ') for arg in arg_list] | ||
argnames = [] | ||
for arg in arg_list: | ||
argname = arg.split('=')[0] | ||
if argname in ['*', '/', '//', '']: | ||
continue | ||
argname = argname.replace('*', '') | ||
argnames.append(argname) | ||
argnames = tuple([arg for arg in argnames if arg not in ['self', 'cls']]) | ||
self.func_argnames[func_name] = argnames | ||
return |
Oops, something went wrong.