diff --git a/setup.py b/setup.py
index acfa65a..fe7c5ae 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@
 
 setup(
     name="torchlens",
-    version="0.1.29",
+    version="0.1.30",
     description="A package for extracting activations from PyTorch models",
     long_description="A package for extracting activations from PyTorch models. Contains functionality for "
     "extracting model activations, visualizing a model's computational graph, and "
diff --git a/tests/example_models.py b/tests/example_models.py
index 4c75574..9ab4a17 100644
--- a/tests/example_models.py
+++ b/tests/example_models.py
@@ -1240,3 +1240,22 @@ def forward(self, x):
         w1 = x ** 3
         x = torch.sum(torch.stack([y4, z3, w1]))
         return x
+
+
+class PropertyModel(nn.Module):
+    def __init__(self):
+        """Toy model exercising the tensor properties real, imag, data, T, and mT."""
+        super().__init__()
+
+    def forward(self, x):
+        r = x.real
+        i = x.imag
+        t = torch.rand(4, 4)
+        t = t * 3
+        t = t.data
+        t2 = t.T
+        m = torch.rand(4, 4, 4)
+        m = m ** 2
+        m2 = m.mT.mean()
+        out = r * i / m2 + t2.mean()
+        return out
diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py
index 878b8df..f360b68 100644
--- a/tests/test_validation_and_visuals.py
+++ b/tests/test_validation_and_visuals.py
@@ -93,6 +93,11 @@ def input_2d():
     return torch.rand(5, 5)
 
 
+@pytest.fixture
+def input_complex():
+    return torch.complex(torch.rand(3, 3), torch.rand(3, 3)),
+
+
 # Test different operations
 
 
@@ -1187,6 +1192,18 @@ def test_module_looping_clash3(default_input1):
     )
 
 
+def test_propertymodel(input_complex):
+    model = example_models.PropertyModel()
+    assert validate_saved_activations(model, input_complex)
+    show_model_graph(
+        model,
+        input_complex,
+        save_only=True,
+        vis_opt="unrolled",
+        vis_outpath=opj("visualization_outputs", "toy-networks", "propertymodel"),
+    )
+
+
 def test_ubermodel1(input_2d):
     model = example_models.UberModel1()
     assert validate_saved_activations(model, [[input_2d, input_2d * 2, input_2d * 3]])
@@ -3044,6 +3061,40 @@ def test_dimenet():
     assert validate_saved_activations(model, model_inputs)
 
 
+# Quantum machine-learning model
+
+def test_qml():
+    import pennylane as qml
+    n_qubits = 2
+    dev = qml.device("default.qubit", wires=n_qubits)
+
+    @qml.qnode(dev, diff_method="backprop")
+    def qnode(inputs, weights):
+        # print(inputs)
+        qml.RX(inputs[0][0], wires=0)
+        qml.RY(weights, wires=0)
+        return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))]
+
+    weight_shapes = {"weights": 1}
+    qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)
+
+    clayer_1 = torch.nn.Linear(2, 2)
+    clayer_2 = torch.nn.Linear(2, 2)
+    softmax = torch.nn.Softmax(dim=1)
+    layers = [clayer_1, qlayer, clayer_2, softmax]
+    model = torch.nn.Sequential(*layers)
+    model_inputs = torch.rand(1, 2, requires_grad=False)
+
+    show_model_graph(
+        model,
+        model_inputs,
+        save_only=True,
+        vis_opt="unrolled",
+        vis_outpath=opj("visualization_outputs", "quantum", "qml"),
+    )
+    assert validate_saved_activations(model, model_inputs)
+
+
 # Lightning modules
 
 def test_lightning():
diff --git a/torchlens/constants.py b/torchlens/constants.py
index 4a034a2..d752045 100644
--- a/torchlens/constants.py
+++ b/torchlens/constants.py
@@ -330,6 +330,11 @@
     ("torch.Tensor", "_make_subclass"),
     ("torch.Tensor", "solve"),
     ("torch.Tensor", "unflatten"),
+    ("torch.Tensor", "real"),
+    ("torch.Tensor", "imag"),
+    ("torch.Tensor", "T"),
+    ("torch.Tensor", "mT"),
+    ("torch.Tensor", "H"),
 ]
 
 
diff --git a/torchlens/decorate_torch.py b/torchlens/decorate_torch.py
index 132806b..0af8437 100644
--- a/torchlens/decorate_torch.py
+++ b/torchlens/decorate_torch.py
@@ -19,12 +19,11 @@
 print_funcs = ["__repr__", "__str__", "_str"]
 
 
-def torch_func_decorator(self, func: Callable):
+def torch_func_decorator(self, func: Callable, func_name: str):
     @wraps(func)
     def wrapped_func(*args, **kwargs):
         # Initial bookkeeping; check if it's a special function, organize the arguments.
         self.current_function_call_barcode = 0
-        func_name = func.__name__
         if (
             (func_name in funcs_not_to_log)
             or (not self._track_tensors)
@@ -78,6 +77,7 @@
         log_function_output_tensors(
             self,
             func,
+            func_name,
             args,
             kwargs,
             arg_copies,
@@ -118,6 +118,13 @@ def decorate_pytorch(
     collect_orig_func_defs(torch_module, orig_func_defs)
     decorated_func_mapper = {}
 
+    # Get references to the function classes.
+    function_class = type(lambda: 0)
+    builtin_class = type(torch.mean)
+    method_class = type(torch.Tensor.__add__)
+    wrapper_class = type(torch.Tensor.__getitem__)
+    getset_class = type(torch.Tensor.real)
+
     for namespace_name, func_name in ORIG_TORCH_FUNCS:
         namespace_name_notorch = namespace_name.replace("torch.", "")
         local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
@@ -128,19 +135,39 @@ def decorate_pytorch(
         get_func_argnames(self, orig_func, func_name)
         if getattr(orig_func, "__name__", False) == "wrapped_func":
            continue
-        new_func = torch_func_decorator(self, orig_func)
-        try:
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                setattr(local_func_namespace, func_name, new_func)
-        except (AttributeError, TypeError) as _:
-            pass
-        new_func.tl_is_decorated_function = True
-        decorated_func_mapper[new_func] = orig_func
-        decorated_func_mapper[orig_func] = new_func
+
+        if type(orig_func) in [function_class, builtin_class, method_class, wrapper_class]:
+            new_func = torch_func_decorator(self, orig_func, func_name)
+            try:
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    setattr(local_func_namespace, func_name, new_func)
+            except (AttributeError, TypeError) as _:
+                pass
+            new_func.tl_is_decorated_function = True
+            decorated_func_mapper[new_func] = orig_func
+            decorated_func_mapper[orig_func] = new_func
+
+        elif type(orig_func) == getset_class:
+            getter_orig, setter_orig, deleter_orig = orig_func.__get__, orig_func.__set__, orig_func.__delete__
+            getter_dec, setter_dec, deleter_dec = (torch_func_decorator(self, getter_orig, func_name),
+                                                   torch_func_decorator(self, setter_orig, func_name),
+                                                   torch_func_decorator(self, deleter_orig, func_name))
+            getter_dec.tl_is_decorated_function = True
+            setter_dec.tl_is_decorated_function = True
+            deleter_dec.tl_is_decorated_function = True
+            new_property = property(getter_dec, setter_dec, deleter_dec, doc=func_name)
+            try:
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    setattr(local_func_namespace, func_name, new_property)
+            except (AttributeError, TypeError) as _:
+                pass
+            decorated_func_mapper[new_property] = orig_func
+            decorated_func_mapper[orig_func] = new_property
 
     # Bolt on the identity function
-    new_identity = torch_func_decorator(self, identity)
+    new_identity = torch_func_decorator(self, identity, 'identity')
     torch.identity = new_identity
 
     return decorated_func_mapper
@@ -225,6 +252,9 @@ def get_func_argnames(self, orig_func: Callable, func_name: str):
     """Attempts to get the argument names for a function, first by checking the signature, then by
     checking the documentation.
     Adds these names to func_argnames if it can find them, doesn't do anything if it can't."""
+    if func_name in ['real', 'imag', 'T', 'mT', 'data', 'H']:
+        return
+
     try:
         argnames = list(inspect.signature(orig_func).parameters.keys())
         argnames = tuple([arg.replace('*', '') for arg in argnames if arg not in ['cls', 'self']])
diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py
index ac153b4..1a1e520 100644
--- a/torchlens/helper_funcs.py
+++ b/torchlens/helper_funcs.py
@@ -403,7 +403,9 @@ def extend_search_stack_from_item(
     )
 
     for attr_name in dir(item):
-        if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name):
+        if ((attr_name.startswith("__")) or
+                (attr_name in ['T', 'mT', 'real', 'imag']) or
+                ('grad' in attr_name)):
             continue
         try:
             with warnings.catch_warnings():
diff --git a/torchlens/logging_funcs.py b/torchlens/logging_funcs.py
index 5f7a227..7cabd63 100644
--- a/torchlens/logging_funcs.py
+++ b/torchlens/logging_funcs.py
@@ -322,6 +322,7 @@ def log_source_tensor_fast(self, t: torch.Tensor, source: str):
 
 def log_function_output_tensors(
     self,
     func: Callable,
+    func_name: str,
     args: Tuple[Any],
     kwargs: Dict[str, Any],
     arg_copies: Tuple[Any],
@@ -335,6 +336,7 @@
         log_function_output_tensors_exhaustive(
             self,
             func,
+            func_name,
             args,
             kwargs,
             arg_copies,
@@ -347,7 +349,7 @@
     elif self.logging_mode == "fast":
         log_function_output_tensors_fast(
             self,
-            func,
+            func_name,
             args,
             kwargs,
             arg_copies,
@@ -362,6 +364,7 @@
 def log_function_output_tensors_exhaustive(
     self,
     func: Callable,
+    func_name: str,
     args: Tuple[Any],
     kwargs: Dict[str, Any],
     arg_copies: Tuple[Any],
@@ -385,7 +388,6 @@
         is_bottom_level_func: whether the function is at the bottom-level of function nesting
     """
     # Unpacking and reformatting:
-    func_name = func.__name__
     layer_type = func_name.lower().replace("_", "")
 
     all_args = list(args) + list(kwargs.values())
@@ -645,7 +647,7 @@ def _get_parent_contents(
 
 def log_function_output_tensors_fast(
     self,
-    func: Callable,
+    func_name: str,
     args: Tuple[Any],
     kwargs: Dict[str, Any],
     arg_copies: Tuple[Any],
@@ -656,7 +658,6 @@
     is_bottom_level_func: bool,
 ):
     # Collect information.
-    func_name = func.__name__
     layer_type = func_name.lower().replace("_", "")
     all_args = list(args) + list(kwargs.values())
     non_tensor_args = [arg for arg in args if not _check_if_tensor_arg(arg)]
@@ -770,6 +771,13 @@ def _output_should_be_logged(out: Any, is_bottom_level_func: bool) -> bool:
     return False
 
 
+def _get_funcname(self, f):
+    if f in self._funcname_overrides:
+        return self._funcname_overrides[f]
+    else:
+        return f.__name__
+
+
 def _add_backward_hook(self, t: torch.Tensor, tensor_label):
     """Adds a backward hook to the tensor that saves the gradients to ModelHistory if specified.
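
Note on the decorate_torch.py change above: attributes like torch.Tensor.real, .imag, .T, .mT, and .H are C-level getset descriptors rather than callables, so the existing callable-wrapping path cannot intercept them; the patch instead wraps the descriptor's __get__/__set__/__delete__ and rebinds the attribute as a Python property. This is also presumably why func_name is now threaded through explicitly: a wrapped bound __get__ reports __name__ == '__get__', so func.__name__ would no longer recover names like 'real'. Below is a minimal, self-contained sketch of the same rebinding technique on a toy class; the names make_logged_getter, Demo, and access_log are illustrative only, not part of torchlens.

from functools import wraps


def make_logged_getter(getter, name, log):
    """Wrap a descriptor's __get__ so that every attribute access is recorded."""
    @wraps(getter)
    def wrapped(*args, **kwargs):
        out = getter(*args, **kwargs)  # call the original descriptor's __get__
        log.append(name)               # record the access under its property name
        return out
    return wrapped


class Demo:
    def __init__(self, x):
        self._x = x

    @property
    def double(self):
        return self._x * 2


access_log = []
orig_descriptor = Demo.__dict__["double"]
# Rebind the attribute as a new property whose getter is the wrapped original
# __get__, mirroring how decorate_pytorch re-exposes Tensor.real and friends.
Demo.double = property(make_logged_getter(orig_descriptor.__get__, "double", access_log))

d = Demo(3)
assert d.double == 6             # behavior is unchanged
assert access_log == ["double"]  # but the access was logged

The try/except around setattr in the patch covers namespaces that reject attribute rebinding; in that case the original descriptor is simply left in place.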