feat: Graph-Aware Bayesian Optimization #179

Open · wants to merge 6 commits into master
65 changes: 65 additions & 0 deletions neps/optimizers/models/graphs/context_managers.py
@@ -0,0 +1,65 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from botorch.models import SingleTaskGP

from neps.optimizers.models.graphs.kernels import BoTorchWLKernel, compute_kernel

if TYPE_CHECKING:
import networkx as nx
from botorch.models.gp_regression_mixed import Kernel


@contextmanager
def set_graph_lookup(
kernel_or_gp: Kernel | SingleTaskGP,
new_graphs: list[nx.Graph],
*,
append: bool = True,
) -> Iterator[None]:
"""Context manager to temporarily set the graph lookup for a kernel or GP model.

Args:
kernel_or_gp (Kernel | SingleTaskGP): The kernel or GP model whose graph lookup is
to be set.
new_graphs (list[nx.Graph]): The new graphs to set in the graph lookup.
append (bool, optional): Whether to append the new graphs to the existing graph
lookup. Defaults to True.
"""
kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []

# Determine the modules to update based on the input type
if isinstance(kernel_or_gp, SingleTaskGP):
modules = [
k
for k in kernel_or_gp.covar_module.sub_kernels()
if isinstance(k, BoTorchWLKernel)
]
elif isinstance(kernel_or_gp, BoTorchWLKernel):
modules = [kernel_or_gp]
else:
assert hasattr(kernel_or_gp, "sub_kernels"), (
"Kernel module must have sub_kernels method."
)
modules = [
k for k in kernel_or_gp.sub_kernels() if isinstance(k, BoTorchWLKernel)
]

# Save the current graph lookup and set the new graph lookup
for kern in modules:
compute_kernel.cache_clear()

kernel_prev_graphs.append((kern, kern.graph_lookup))
if append:
kern.set_graph_lookup([*kern.graph_lookup, *new_graphs])
else:
kern.set_graph_lookup(new_graphs)

yield

# Restore the original graph lookup after the context manager exits
for kern, prev_graphs in kernel_prev_graphs:
kern.set_graph_lookup(prev_graphs)
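
A minimal usage sketch of the context manager (illustrative, not part of the diff; `gp`, `train_graphs`, `candidate_graphs`, and `candidate_features` are hypothetical names):

from neps.optimizers.models.graphs.context_managers import set_graph_lookup

# `gp` is assumed to be a fitted SingleTaskGP whose covariance module contains a
# BoTorchWLKernel built over `train_graphs`; `candidate_graphs` are new nx.Graph objects.
with set_graph_lookup(gp, candidate_graphs, append=True):
    # Inside the context the kernel sees train_graphs + candidate_graphs, so
    # indices >= len(train_graphs) refer to the appended candidates.
    posterior = gp.posterior(candidate_features)
# On exit the kernel's graph_lookup is restored to train_graphs only.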
304 changes: 304 additions & 0 deletions neps/optimizers/models/graphs/kernels.py
@@ -0,0 +1,304 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any

import torch
from botorch.models.gp_regression_mixed import Kernel
from torch import Tensor
from torch.nn import Module

from neps.optimizers.models.graphs.utils import graphs_to_tensors

if TYPE_CHECKING:
import networkx as nx


@lru_cache(maxsize=128)
def compute_kernel(
adjacency_cache: tuple[Tensor, ...],
label_cache: tuple[Tensor, ...],
indices1: tuple[int, ...],
indices2: tuple[int, ...],
n_iter: int,
*,
diag: bool,
normalize: bool,
) -> Tensor:
"""Compute the kernel matrix.

This function is defined outside the class to leverage the `lru_cache` decorator,
which caches the results of expensive function calls and reuses them when the same
inputs occur again.

Args:
adjacency_cache: Tuple of adjacency matrices for the graphs.
label_cache: Tuple of initial node labels for the graphs.
indices1: Tuple of indices for the first set of graphs.
indices2: Tuple of indices for the second set of graphs.
n_iter: Number of WL iterations.
diag: Whether to return only the diagonal of the kernel matrix.
normalize: Whether to normalize the kernel matrix.

Returns:
A Tensor representing the kernel matrix.
"""
all_graphs = list(set(indices1).union(indices2))
adj_matrices = [adjacency_cache[i] for i in all_graphs]
label_tensors = [label_cache[i] for i in all_graphs]

# Compute full kernel matrix
_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize)
K_full = _kernel(adj_matrices, label_tensors)

# Map indices to their positions in all_graphs
idx1 = [all_graphs.index(i) for i in indices1]
idx2 = [all_graphs.index(i) for i in indices2]

# Extract the relevant submatrix
K = K_full[idx1][:, idx2]

# Return the diagonal if requested
if diag:
return torch.diag(K)

return K
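

# Illustrative call (not part of the module): `lru_cache` requires hashable
# arguments, which is why callers pass tuples of tensors and tuples of ints
# rather than lists, e.g.
#
#     K = compute_kernel(
#         adjacency_cache=tuple(adj_tensors),
#         label_cache=tuple(label_tensors),
#         indices1=(0, 1),
#         indices2=(0, 1, 2),
#         n_iter=5,
#         diag=False,
#         normalize=True,
#     )
#
# Because the argument tuples contain the same Tensor objects between calls,
# repeated lookups over the same graph data hit the cache; `set_graph_lookup`
# calls `compute_kernel.cache_clear()` whenever the underlying graph data changes.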


class BoTorchWLKernel(Kernel):
"""A custom kernel for Gaussian Processes using the Weisfeiler-Lehman (WL) algorithm.

This kernel computes similarities between graphs based on their structural properties
using the WL algorithm. It is designed to be used with BoTorch and GPyTorch for
Gaussian Process regression.

Args:
graph_lookup (list[nx.Graph]): List of NetworkX graphs.
n_iter (int, optional): Number of WL iterations to perform. Default is 5.
normalize (bool, optional): Whether to normalize the kernel matrix.
Default is True.
active_dims (tuple[int, ...]): Dimensions of the input to consider.
Not used in this kernel but included for compatibility with the base Kernel class.
**kwargs (Any): Additional arguments for the base Kernel class.

Attributes:
graph_lookup (list[nx.Graph]): List of graphs used for kernel computation.
n_iter (int): Number of WL iterations.
normalize (bool): Whether to normalize the kernel matrix.
adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs.
label_cache (list[Tensor]): Cached initial node labels of the graphs.
"""

has_lengthscale = False

def __init__(
self,
graph_lookup: list[nx.Graph],
n_iter: int = 5,
*,
normalize: bool = True,
active_dims: tuple[int, ...],
**kwargs: Any,
) -> None:
super().__init__(active_dims=active_dims, **kwargs)
self.graph_lookup = graph_lookup
self.n_iter = n_iter
self.normalize = normalize
self._precompute_graph_data()

def _precompute_graph_data(self) -> None:
"""Precompute and cache adjacency matrices and initial node labels."""
self.adjacency_cache, self.label_cache = graphs_to_tensors(
self.graph_lookup, device=self.device
)

def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None:
"""Update the graph lookup and refresh the cached data."""
self.graph_lookup = graph_lookup
self._precompute_graph_data()

def forward(
self,
x1: Tensor,
x2: Tensor,
*,
diag: bool = False,
last_dim_is_batch: bool = False,
**params: Any,
) -> Tensor:
"""Compute kernel matrix containing pairwise similarities between graphs."""
if last_dim_is_batch:
raise NotImplementedError("Batch dimension handling is not implemented.")

if x1.ndim == 3:
return self._handle_batched_input(x1=x1, x2=x2, diag=diag)

indices1, indices2 = self._prepare_indices(x1, x2)

return compute_kernel(
adjacency_cache=tuple(self.adjacency_cache),
label_cache=tuple(self.label_cache),
indices1=tuple(indices1),
indices2=tuple(indices2),
n_iter=self.n_iter,
diag=diag,
normalize=self.normalize,
)

def _handle_batched_input(self, x1: Tensor, x2: Tensor, *, diag: bool) -> Tensor:
"""Handle computation for batched input tensors."""
q_dim_size = x1.shape[0]
assert x2.shape[0] == q_dim_size

out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
for q in range(q_dim_size):
out[q] = self.forward(x1[q], x2[q], diag=diag)
return out

def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]:
"""Convert tensor indices to integer lists."""
indices1 = x1.flatten().to(torch.int64).tolist()
indices2 = x2.flatten().to(torch.int64).tolist()

# Check for missing graph indices (-1) and handle them
# Explanation: The index `-1` is used as a placeholder for "missing" or "invalid"
# graphs. This can occur when a graph feature is missing or undefined, such as
# during the exploration of new candidates where no corresponding graph is
# available in the `graph_lookup`. The kernel expects non-negative indices, so we
# need to convert `-1` to the index of the last graph in the lookup.

# Use the last graph in the lookup as a placeholder
last_graph_idx = len(self.graph_lookup) - 1

if -1 in indices1:
# Replace any `-1` indices with the index of the last graph.
indices1 = [last_graph_idx if i == -1 else i for i in indices1]

if -1 in indices2:
# Replace any `-1` indices with the index of the last graph.
indices2 = [last_graph_idx if i == -1 else i for i in indices2]

return indices1, indices2
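

# Illustrative construction (not part of the module): the WL kernel is typically
# combined with numeric kernels inside a SingleTaskGP, with its active dimension
# holding each row's integer index into `graph_lookup`. The names below
# (`train_graphs`, `graph_dim`) are hypothetical:
#
#     wl_kernel = BoTorchWLKernel(
#         graph_lookup=train_graphs,   # list[nx.Graph]
#         n_iter=5,
#         normalize=True,
#         active_dims=(graph_dim,),    # input column holding the graph index
#     )
#
# Rows whose graph column is -1 are mapped by `_prepare_indices` to the last
# entry of `graph_lookup`, which is where `set_graph_lookup` places newly
# appended candidate graphs.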


class TorchWLKernel(Module):
"""A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch.

The WL Kernel is a graph kernel that measures similarity between graphs based on
their structural properties. It works by iteratively updating node labels based on
their neighborhoods and computing feature vectors from label distributions.

Args:
n_iter: Number of WL iterations to perform. Defaults to 5.
normalize: Whether to normalize the kernel matrix. Defaults to True.

Attributes:
device: torch.device for computation (CPU/GPU)
label_dict: Mapping from node labels to numerical indices
label_counter: Counter for generating new label indices
"""

def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
super().__init__()
self.n_iter = n_iter
self.normalize = normalize
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep track of labels across iterations
self.label_dict: dict[str, int] = {}
self.label_counter: int = 0

def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]:
"""Extract neighborhood information from adjacency matrix."""
if adj.layout == torch.sparse_csr:
adj = adj.to_sparse_coo()

adj = adj.coalesce()
rows, cols = adj.indices()
num_nodes = adj.size(0)

neighbors: list[list[int]] = [[] for _ in range(num_nodes)]
for row, col in zip(rows.tolist(), cols.tolist(), strict=False):
neighbors[row].append(col)

return neighbors

def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor:
"""Perform one WL iteration."""
if not self.label_dict:
# Start new labels after initial ones
self.label_counter = int(labels.max().item()) + 1

num_nodes = labels.size(0)
new_labels: list[int] = []
neighbors = self._get_node_neighbors(adj)

for node_idx in range(num_nodes):
# Get current node label
node_label = int(labels[node_idx].item())
neighbor_labels = sorted([int(labels[n].item()) for n in neighbors[node_idx]])

credential = f"{node_label},{neighbor_labels}"

# Update label dictionary
new_labels.append(
self.label_dict.setdefault(credential, len(self.label_dict))
)

return torch.tensor(new_labels, dtype=torch.int64, device=self.device)
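
# Worked example (illustrative): one relabeling step on a path graph a-b-c with
# initial labels [0, 0, 0]. The end nodes a and c each see sorted neighbour
# labels [0], giving the credential "0,[0]", while b sees [0, 0], giving
# "0,[0, 0]". Distinct credentials receive distinct indices from `label_dict`,
# so the new labels are [0, 1, 0].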

def _compute_feature_vector(self, all_labels: list[list[Tensor]]) -> Tensor:
"""Compute the histogram feature vector for all graphs."""
batch_size = len(all_labels[0])
features: list[Tensor] = []

for iteration_labels in all_labels:
# Find maximum label value across all graphs in this iteration
max_label = int(max(label.max().item() for label in iteration_labels)) + 1

iter_features = torch.zeros((batch_size, max_label), device=self.device)

# Compute label frequencies
for graph_idx, labels in enumerate(iteration_labels):
counts = torch.bincount(labels, minlength=max_label)
iter_features[graph_idx] = counts

features.append(iter_features)

return torch.cat(features, dim=1)

def forward(self, adj_matrices: list[Tensor], label_tensors: list[Tensor]) -> Tensor:
"""Compute WL kernel matrix for a list of graphs.

Args:
adj_matrices: Precomputed sparse adjacency matrices for graphs.
label_tensors: Precomputed node label tensors for graphs.

Returns:
Kernel matrix containing pairwise graph similarities.
"""
if len(adj_matrices) != len(label_tensors):
raise ValueError("Mismatch between adjacency matrices and label tensors.")

# Reset label dictionary for new computation
self.label_dict = {}
# Store all label iterations
all_labels: list[list[Tensor]] = [label_tensors]

# Perform WL iterations
for _ in range(self.n_iter):
new_labels = [
self._wl_iteration(adj, labels)
for adj, labels in zip(adj_matrices, all_labels[-1], strict=False)
]
all_labels.append(new_labels)

# Compute feature vectors and kernel matrix (similarity matrix)
final_features = self._compute_feature_vector(all_labels)
kernel_matrix = torch.mm(final_features, final_features.t())

if self.normalize:
diag = torch.sqrt(torch.diag(kernel_matrix))
kernel_matrix /= torch.outer(diag, diag)

return kernel_matrix
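
For reference, a small end-to-end sketch of the standalone TorchWLKernel (illustrative only; the to_tensors helper below is a stand-in for what graphs_to_tensors is assumed to produce, namely one sparse adjacency matrix and one integer label tensor per graph):

import networkx as nx
import torch

from neps.optimizers.models.graphs.kernels import TorchWLKernel

def to_tensors(g: nx.Graph) -> tuple[torch.Tensor, torch.Tensor]:
    # Sparse COO adjacency; every node starts with the same initial label 0.
    adj = torch.tensor(nx.to_numpy_array(g), dtype=torch.float32).to_sparse_coo()
    labels = torch.zeros(g.number_of_nodes(), dtype=torch.int64)
    return adj, labels

graphs = [nx.path_graph(4), nx.cycle_graph(4), nx.star_graph(3)]
adjs, labels = zip(*(to_tensors(g) for g in graphs))

kernel = TorchWLKernel(n_iter=3, normalize=True)
K = kernel(list(adjs), list(labels))  # 3x3 similarity matrix; unit diagonal when normalized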