Use lru_cache instead of simple dict cache
vladislavalerievich committed Jan 23, 2025
1 parent f7922db commit 4cc0b29
Showing 2 changed files with 16 additions and 50 deletions.
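
For context, this commit swaps a hand-rolled dict cache for functools.lru_cache. The sketch below is a minimal illustration of that general pattern, not code from this repository; the class and method names are hypothetical. The key constraint is that lru_cache hashes every argument, including self, so list arguments have to be converted to tuples before the cached call.

from functools import lru_cache


class KernelLike:
    """Hypothetical class showing the dict-cache to lru_cache migration."""

    # Before: a manual cache keyed by explicitly built tuples, roughly:
    #     self.cache: dict[tuple, float] = {}
    #     key = tuple(indices)
    #     if key in self.cache:
    #         return self.cache[key]
    #     ... compute, store under key, return ...

    # After: functools handles lookup, storage, and LRU eviction.
    @lru_cache(maxsize=128)
    def compute(self, indices: tuple[int, ...]) -> float:
        # Stand-in for an expensive kernel computation.
        return float(sum(i * i for i in indices))

Because the decorator wraps methods, the cache is attached to the class-level function object and holds strong references to every self it has seen, which is why the example script below clears the caches explicitly once optimization finishes.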
7 changes: 6 additions & 1 deletion grakel_replace/graph_aware_gp_optimization_example.py
@@ -13,7 +13,7 @@
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.kernels import AdditiveKernel, MaternKernel
from grakel_replace.context_managers import set_graph_lookup
-from grakel_replace.kernels import BoTorchWLKernel
+from grakel_replace.kernels import BoTorchWLKernel, TorchWLKernel
from grakel_replace.optimization import optimize_acqf_graph
from grakel_replace.utils import min_max_scale, seed_all

@@ -122,3 +122,8 @@
print(f"Best candidate: {best_candidate}")
print(f"Best score: {best_score}")
print(f"Elapsed time: {time.time() - start_time} seconds")
+
+# Clear caches after optimization to avoid memory leaks or unexpected behavior
+BoTorchWLKernel._compute_kernel.cache_clear()
+TorchWLKernel._get_node_neighbors.cache_clear()
+TorchWLKernel._wl_iteration.cache_clear()
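
A note on the cache_clear() calls above: since lru_cache wraps the methods at class level, one cache is shared by all instances, and its entries keep the kernel objects and their tensors alive until they are evicted or cleared. Clearing through the class attribute resets that shared cache. The snippet below is a small, hypothetical illustration of the standard functools bookkeeping API, not repository code.

from functools import lru_cache

@lru_cache(maxsize=128)
def slow_square(x: int) -> int:
    # Stand-in for an expensive computation.
    return x * x

slow_square(3)                   # miss: computed and stored
slow_square(3)                   # hit: returned from the cache
print(slow_square.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
slow_square.cache_clear()        # drop all entries and the references they hold
print(slow_square.cache_info())  # CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)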
59 changes: 10 additions & 49 deletions grakel_replace/kernels.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+from functools import lru_cache
from typing import TYPE_CHECKING, Any

import torch
@@ -32,7 +33,6 @@ class BoTorchWLKernel(Kernel):
        graph_lookup (list[nx.Graph]): List of graphs used for kernel computation.
        n_iter (int): Number of WL iterations.
        normalize (bool): Whether to normalize the kernel matrix.
-        cache (dict[tuple, Tensor]): Cache for storing precomputed kernel matrices.
        adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs.
        label_cache (list[Tensor]): Cached initial node labels of the graphs.
    """
@@ -51,7 +51,6 @@ def __init__(
        self.graph_lookup = graph_lookup
        self.n_iter = n_iter
        self.normalize = normalize
-        self.cache: dict[tuple, Tensor] = {}
        self._precompute_graph_data()

    def _precompute_graph_data(self) -> None:
@@ -75,34 +74,15 @@ def forward(
        **params: Any,
    ) -> Tensor:
        """Compute kernel matrix containing pairwise similarities between graphs."""
-        # last_dim_is_batch is for compatibility with base Kernel class.
        if last_dim_is_batch:
            raise NotImplementedError("Batch dimension handling is not implemented.")

-        x1_is_x2 = torch.equal(x1, x2)
-        indices = tuple(x1.flatten().tolist()) if x1_is_x2 else (
-            tuple(x1.flatten().tolist()), tuple(x2.flatten().tolist()))
-
-        if indices in self.cache:
-            return self.cache[indices]
-
-        # Compute kernel matrix if not cached
-        K = self._compute_kernel(x1, x2, diag=diag)
-        self.cache[indices] = K
-        return K
-
-    def _compute_kernel(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:
-        """Compute the kernel matrix."""
        if x1.ndim == 3:
            return self._handle_batched_input(x1, x2, diag)

        indices1, indices2 = self._prepare_indices(x1, x2)

-        # Check if we're computing self-similarity or cross-similarity
-        if torch.equal(x1, x2):
-            return self._compute_self_kernel(indices1, diag)
-        else:
-            return self._compute_cross_kernel(indices1, indices2, diag)
+        return self._compute_kernel(tuple(indices1), tuple(indices2), diag)

    def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:
        """Handle computation for batched input tensors."""
@@ -111,7 +91,7 @@ def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:

        out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
        for q in range(q_dim_size):
-            out[q] = self._compute_kernel(x1[q], x2[q], diag=diag)
+            out[q] = self.forward(x1[q], x2[q], diag=diag)
        return out

    def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]:
@@ -127,34 +107,14 @@ def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]:

        return indices1, indices2

-    def _compute_self_kernel(self, indices: list[int], diag: bool) -> Tensor:
-        """Compute kernel matrix for self-similarity case."""
-        indices_tuple = tuple(indices)
-        if indices_tuple in self.cache:
-            return self.cache[indices_tuple]
-
-        adj_matrices = [self.adjacency_cache[i] for i in indices]
-        label_tensors = [self.label_cache[i] for i in indices]
-
-        # Compute kernel matrix
-        K = self._compute_base_kernel(adj_matrices, label_tensors)
-        if diag:
-            K = torch.diag(K)
-
-        self.cache[indices_tuple] = K
-        return K
-
-    def _compute_cross_kernel(
+    @lru_cache(maxsize=128)
+    def _compute_kernel(
        self,
-        indices1: list[int],
-        indices2: list[int],
+        indices1: tuple[int],
+        indices2: tuple[int],
        diag: bool,
    ) -> Tensor:
-        """Compute kernel matrix for cross-similarity case."""
-        cache_key = (tuple(indices1), tuple(indices2))
-        if cache_key in self.cache:
-            return self.cache[cache_key]
-
+        """Compute the kernel matrix."""
        all_graphs = list(set(indices1 + indices2))
        adj_matrices = [self.adjacency_cache[i] for i in all_graphs]
        label_tensors = [self.label_cache[i] for i in all_graphs]
@@ -168,7 +128,6 @@ def _compute_cross_kernel(
        if diag:
            K = torch.diag(K)

-        self.cache[cache_key] = K
        return K

    def _compute_base_kernel(
@@ -206,6 +165,7 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
        self.label_dict = {}
        self.label_counter = 0

+    @lru_cache(maxsize=128)
    def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]:
        """Extract neighborhood information from adjacency matrix."""
        if adj.layout == torch.sparse_csr:
@@ -221,6 +181,7 @@ def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]:

        return neighbors

+    @lru_cache(maxsize=128)
    def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor:
        """Perform one WL iteration."""
        if not self.label_dict:
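
One caveat on the two TorchWLKernel methods decorated above: their arguments are tensors, and torch.Tensor hashes by object identity rather than by value, so lru_cache only hits when the exact same tensor objects are passed again. That works in this kernel because the adjacency matrices and initial label tensors are precomputed once and reused across calls. A minimal sketch of that behavior, using a hypothetical function:

import torch
from functools import lru_cache

@lru_cache(maxsize=128)
def row_sums(adj: torch.Tensor) -> torch.Tensor:
    print("computing")
    return adj.sum(dim=1)

a = torch.eye(3)
row_sums(a)             # prints "computing": first call is a miss
row_sums(a)             # silent: same object, cache hit
row_sums(torch.eye(3))  # prints "computing": equal values, different object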
