Added compatibility for complex numbers (enables quantum neural networks), and allows functions called as tensor properties to be tracked (real, imag, T, mT, H).
JohnMark Taylor committed Feb 11, 2025
1 parent 94e3d2b commit 42e93e1
Showing 7 changed files with 134 additions and 19 deletions.
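
For context, here is a minimal sketch (not part of this diff) of what the change enables: tracing a model whose forward pass uses complex inputs and tensor properties. It assumes show_model_graph and validate_saved_activations are importable from the torchlens package, as the tests below suggest; the model class and output path are hypothetical.

import torch
import torch.nn as nn
from torchlens import show_model_graph, validate_saved_activations

class ComplexPropsModel(nn.Module):  # hypothetical toy model for illustration
    def forward(self, x):
        # Tensor properties (real, imag, T) are now tracked as graph operations.
        return (x.real * x.imag).T.mean()

model = ComplexPropsModel()
x = torch.complex(torch.rand(3, 3), torch.rand(3, 3))  # complex inputs now supported
assert validate_saved_activations(model, x)
show_model_graph(model, x, save_only=True, vis_opt="unrolled",
                 vis_outpath="visualization_outputs/complex-props")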
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@

setup(
name="torchlens",
version="0.1.29",
version="0.1.30",
description="A package for extracting activations from PyTorch models",
long_description="A package for extracting activations from PyTorch models. Contains functionality for "
"extracting model activations, visualizing a model's computational graph, and "
19 changes: 19 additions & 0 deletions tests/example_models.py
@@ -1240,3 +1240,22 @@ def forward(self, x):
w1 = x ** 3
x = torch.sum(torch.stack([y4, z3, w1]))
return x


class PropertyModel(nn.Module):
def __init__(self):
"""Conv, relu, pool, fc, output."""
super().__init__()

def forward(self, x):
r = x.real
i = x.imag
t = torch.rand(4, 4)
t = t * 3
t = t.data
t2 = t.T
m = torch.rand(4, 4, 4)
m = m ** 2
m2 = m.mT.mean()
out = r * i / m2 + t2.mean()
return out
51 changes: 51 additions & 0 deletions tests/test_validation_and_visuals.py
@@ -93,6 +93,11 @@ def input_2d():
return torch.rand(5, 5)


@pytest.fixture
def input_complex():
return torch.complex(torch.rand(3, 3), torch.rand(3, 3)),


# Test different operations


@@ -1187,6 +1192,18 @@ def test_module_looping_clash3(default_input1):
)


def test_propertymodel(input_complex):
model = example_models.PropertyModel()
assert validate_saved_activations(model, input_complex)
show_model_graph(
model,
input_complex,
save_only=True,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "toy-networks", "propertymodel"),
)


def test_ubermodel1(input_2d):
model = example_models.UberModel1()
assert validate_saved_activations(model, [[input_2d, input_2d * 2, input_2d * 3]])
@@ -3044,6 +3061,40 @@ def test_dimenet():
assert validate_saved_activations(model, model_inputs)


# Quantum machine-learning model

def test_qml():
import pennylane as qml
n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, diff_method="backprop")
def qnode(inputs, weights):
# print(inputs)
qml.RX(inputs[0][0], wires=0)
qml.RY(weights, wires=0)
return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))]

weight_shapes = {"weights": 1}
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

clayer_1 = torch.nn.Linear(2, 2)
clayer_2 = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)
layers = [clayer_1, qlayer, clayer_2, softmax]
model = torch.nn.Sequential(*layers)
model_inputs = torch.rand(1, 2, requires_grad=False)

show_model_graph(
model,
model_inputs,
save_only=True,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "quantum", "qml"),
)
assert validate_saved_activations(model, model_inputs)


# Lightning modules

def test_lightning():
5 changes: 5 additions & 0 deletions torchlens/constants.py
@@ -330,6 +330,11 @@
("torch.Tensor", "_make_subclass"),
("torch.Tensor", "solve"),
("torch.Tensor", "unflatten"),
("torch.Tensor", "real"),
("torch.Tensor", "imag"),
("torch.Tensor", "T"),
("torch.Tensor", "mT"),
("torch.Tensor", "H")
]


56 changes: 43 additions & 13 deletions torchlens/decorate_torch.py
@@ -19,12 +19,11 @@
print_funcs = ["__repr__", "__str__", "_str"]


def torch_func_decorator(self, func: Callable):
def torch_func_decorator(self, func: Callable, func_name: str):
@wraps(func)
def wrapped_func(*args, **kwargs):
# Initial bookkeeping; check if it's a special function, organize the arguments.
self.current_function_call_barcode = 0
func_name = func.__name__
if (
(func_name in funcs_not_to_log)
or (not self._track_tensors)
@@ -78,6 +77,7 @@ def wrapped_func(*args, **kwargs):
log_function_output_tensors(
self,
func,
func_name,
args,
kwargs,
arg_copies,
@@ -118,6 +118,13 @@ def decorate_pytorch(
collect_orig_func_defs(torch_module, orig_func_defs)
decorated_func_mapper = {}

# Get references to the function classes.
function_class = type(lambda: 0)
builtin_class = type(torch.mean)
method_class = type(torch.Tensor.__add__)
wrapper_class = type(torch.Tensor.__getitem__)
getset_class = type(torch.Tensor.real)

for namespace_name, func_name in ORIG_TORCH_FUNCS:
namespace_name_notorch = namespace_name.replace("torch.", "")
local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
@@ -128,19 +135,39 @@
get_func_argnames(self, orig_func, func_name)
if getattr(orig_func, "__name__", False) == "wrapped_func":
continue
new_func = torch_func_decorator(self, orig_func)
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(local_func_namespace, func_name, new_func)
except (AttributeError, TypeError) as _:
pass
new_func.tl_is_decorated_function = True
decorated_func_mapper[new_func] = orig_func
decorated_func_mapper[orig_func] = new_func

if type(orig_func) in [function_class, builtin_class, method_class, wrapper_class]:
new_func = torch_func_decorator(self, orig_func, func_name)
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(local_func_namespace, func_name, new_func)
except (AttributeError, TypeError) as _:
pass
new_func.tl_is_decorated_function = True
decorated_func_mapper[new_func] = orig_func
decorated_func_mapper[orig_func] = new_func

elif type(orig_func) == getset_class:
getter_orig, setter_orig, deleter_orig = orig_func.__get__, orig_func.__set__, orig_func.__delete__
getter_dec, setter_dec, deleter_dec = (torch_func_decorator(self, getter_orig, func_name),
torch_func_decorator(self, setter_orig, func_name),
torch_func_decorator(self, deleter_orig, func_name))
getter_dec.tl_is_decorated_function = True
setter_dec.tl_is_decorated_function = True
deleter_dec.tl_is_decorated_function = True
new_property = property(getter_dec, setter_dec, deleter_dec, doc=func_name)
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(local_func_namespace, func_name, new_property)
except (AttributeError, TypeError) as _:
pass
decorated_func_mapper[new_property] = orig_func
decorated_func_mapper[orig_func] = new_property

# Bolt on the identity function
new_identity = torch_func_decorator(self, identity)
new_identity = torch_func_decorator(self, identity, 'identity')
torch.identity = new_identity

return decorated_func_mapper
@@ -225,6 +252,9 @@ def get_func_argnames(self, orig_func: Callable, func_name: str):
"""Attempts to get the argument names for a function, first by checking the signature, then
by checking the documentation. Adds these names to func_argnames if it can find them,
doesn't do anything if it can't."""
if func_name in ['real', 'imag', 'T', 'mT', 'data', 'H']:
return

try:
argnames = list(inspect.signature(orig_func).parameters.keys())
argnames = tuple([arg.replace('*', '') for arg in argnames if arg not in ['cls', 'self']])
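
The core trick in the property branch above is that attributes like torch.Tensor.real are getset descriptors implemented in C, so they cannot be wrapped like ordinary functions; instead their __get__/__set__/__delete__ methods are decorated and reinstalled as a Python property. A standalone sketch of that pattern (illustrative only, not the torchlens implementation; it assumes torch.Tensor accepts attribute reassignment, which the setattr call above relies on):

import torch

orig_real = torch.Tensor.real  # getset_descriptor defined in C

def logged_getter(obj):
    print("accessed property: real")  # stand-in for torchlens' logging
    return orig_real.__get__(obj)     # delegate to the original descriptor

torch.Tensor.real = property(logged_getter, doc="real")  # shadow the C descriptor

z = torch.complex(torch.rand(2), torch.rand(2))
_ = z.real                     # prints "accessed property: real"
torch.Tensor.real = orig_real  # restore the original descriptor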
4 changes: 3 additions & 1 deletion torchlens/helper_funcs.py
@@ -403,7 +403,9 @@ def extend_search_stack_from_item(
)

for attr_name in dir(item):
if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name):
if ((attr_name.startswith("__")) or
(attr_name in ['T', 'mT', 'real', 'imag']) or
('grad' in attr_name)):
continue
try:
with warnings.catch_warnings():
16 changes: 12 additions & 4 deletions torchlens/logging_funcs.py
@@ -322,6 +322,7 @@ def log_source_tensor_fast(self, t: torch.Tensor, source: str):
def log_function_output_tensors(
self,
func: Callable,
func_name: str,
args: Tuple[Any],
kwargs: Dict[str, Any],
arg_copies: Tuple[Any],
@@ -335,6 +336,7 @@ def log_function_output_tensors(
log_function_output_tensors_exhaustive(
self,
func,
func_name,
args,
kwargs,
arg_copies,
@@ -347,7 +349,7 @@
elif self.logging_mode == "fast":
log_function_output_tensors_fast(
self,
func,
func_name,
args,
kwargs,
arg_copies,
@@ -362,6 +364,7 @@
def log_function_output_tensors_exhaustive(
self,
func: Callable,
func_name: str,
args: Tuple[Any],
kwargs: Dict[str, Any],
arg_copies: Tuple[Any],
@@ -385,7 +388,6 @@ def log_function_output_tensors_exhaustive(
is_bottom_level_func: whether the function is at the bottom-level of function nesting
"""
# Unpacking and reformatting:
func_name = func.__name__
layer_type = func_name.lower().replace("_", "")
all_args = list(args) + list(kwargs.values())

@@ -645,7 +647,7 @@

def log_function_output_tensors_fast(
self,
func: Callable,
func_name: str,
args: Tuple[Any],
kwargs: Dict[str, Any],
arg_copies: Tuple[Any],
@@ -656,7 +658,6 @@
is_bottom_level_func: bool,
):
# Collect information.
func_name = func.__name__
layer_type = func_name.lower().replace("_", "")
all_args = list(args) + list(kwargs.values())
non_tensor_args = [arg for arg in args if not _check_if_tensor_arg(arg)]
@@ -770,6 +771,13 @@ def _output_should_be_logged(out: Any, is_bottom_level_func: bool) -> bool:
return False


def _get_funcname(self, f):
if f in self._funcname_overrides:
return self._funcname_overrides[f]
else:
return f.__name__


def _add_backward_hook(self, t: torch.Tensor, tensor_label):
"""Adds a backward hook to the tensor that saves the gradients to ModelHistory if specified.
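
A likely reason func_name is now passed explicitly instead of being read from func.__name__: for the newly decorated property accessors, the underlying callable is a descriptor method whose __name__ is just "__get__", so the attribute name would otherwise be lost. A quick illustration (not part of the diff):

import torch
print(type(torch.Tensor.real).__name__)    # 'getset_descriptor'
print(torch.Tensor.real.__get__.__name__)  # '__get__', not 'real'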
