Possible PyTorch implementation of WL kernel #153

Closed
wants to merge 62 commits into from
62 commits
36fc3bd
Add a PyTorch implementation of WL kernel
vladislavalerievich Oct 28, 2024
b0d3842
Fix imports
vladislavalerievich Oct 29, 2024
f87abd6
Remove redundant copy
vladislavalerievich Oct 29, 2024
358fbb7
Increase precision for allclose
vladislavalerievich Oct 29, 2024
de140b6
Fix calculation for graphs with reordered edges
vladislavalerievich Oct 29, 2024
08c7aea
Increase test coverage
vladislavalerievich Oct 29, 2024
6f07858
Improve readability of TorchWLKernel
vladislavalerievich Oct 30, 2024
896f461
Add additional comments to TorchWLKernel
vladislavalerievich Oct 30, 2024
383e924
Add MixedSingleTaskGP to process graphs
vladislavalerievich Nov 8, 2024
65666a3
Refactor WLKernelWrapper into a standalone WLKernel class.
vladislavalerievich Nov 20, 2024
7fa9432
Update tests
vladislavalerievich Nov 20, 2024
4227f22
Add a check for empty inputs
vladislavalerievich Nov 20, 2024
f194bd2
Improve and combine tests
vladislavalerievich Nov 20, 2024
a104840
Update WLKernel
vladislavalerievich Nov 21, 2024
246f9f6
Add acquisition function with graph sampling
vladislavalerievich Nov 21, 2024
770c626
Add a custom __call__ method to pass graphs during optimization
vladislavalerievich Nov 21, 2024
8bf7ea7
Update MixedSingleTaskGP
vladislavalerievich Dec 7, 2024
84d0104
Remove not used argument
vladislavalerievich Dec 7, 2024
d63239a
Update sample_graphs
vladislavalerievich Dec 7, 2024
3db3f89
Handle different batch dimensions
vladislavalerievich Dec 7, 2024
f69ddbe
Set num_restarts=10
vladislavalerievich Dec 7, 2024
1c4cc83
Add acquisition function
vladislavalerievich Dec 7, 2024
dab9a8c
Update WLKernel
vladislavalerievich Dec 7, 2024
2999582
Make train_inputs private
vladislavalerievich Dec 7, 2024
ad55030
Update tests
vladislavalerievich Dec 7, 2024
8093d31
fix: Implement graph acquisition
eddiebergman Dec 16, 2024
9f978d6
fix: Implement graph acquisition (#164)
vladislavalerievich Dec 24, 2024
a1a29a8
Delete unused MixedSingleTaskGP
vladislavalerievich Dec 24, 2024
046ad66
Add seed_all and min_max_scale
vladislavalerievich Dec 24, 2024
0a609f7
Refactor optimize.py
vladislavalerievich Dec 24, 2024
5486dcc
Speed up WL kernel computations
vladislavalerievich Dec 24, 2024
f140c56
Process wl iterations in batches
vladislavalerievich Dec 24, 2024
371b530
Use CSR
vladislavalerievich Dec 25, 2024
1478fd9
Implement caching
vladislavalerievich Jan 16, 2025
a4ffaaf
Clean up __init__ methods
vladislavalerievich Jan 16, 2025
2ec7d5b
Split _compute_kernel logic into smaller methods
vladislavalerievich Jan 16, 2025
8d6b63b
Rename kernel to BoTorchWLKernel
vladislavalerievich Jan 16, 2025
f18642b
Move GraphDataset class into utils.py
vladislavalerievich Jan 16, 2025
bb92de4
Delete GraphDataset
vladislavalerievich Jan 19, 2025
e409798
Update tests
vladislavalerievich Jan 20, 2025
51e6ae4
Simplify TorchWLKernel
vladislavalerievich Jan 20, 2025
bdd32db
Remove torch_wl_usage_example.py
vladislavalerievich Jan 21, 2025
7747e49
Update grakel_wl_usage_example.py
vladislavalerievich Jan 23, 2025
21b32c8
Update TestTorchWLKernel
vladislavalerievich Jan 23, 2025
dabf4f0
Create graphs_to_tensors function
vladislavalerievich Jan 23, 2025
22cf6d5
Add docstring to BoTorchWLKernel
vladislavalerievich Jan 23, 2025
52b3b14
Add tests for the BoTorchWLKernel
vladislavalerievich Jan 23, 2025
fe79d63
Move redundant files to examples directory
vladislavalerievich Jan 23, 2025
7729d2c
Combine set_graph_lookup context managers into one
vladislavalerievich Jan 23, 2025
1cecacc
Update comments for the optimize_acqf_graph function
vladislavalerievich Jan 23, 2025
ab730d3
Move sample_graphs into utils.py
vladislavalerievich Jan 23, 2025
3eb793d
Rename mixed_single_task_gp_usage_example.py
vladislavalerievich Jan 23, 2025
0cfae28
Add comments
vladislavalerievich Jan 23, 2025
88ddfe1
Move set_graph_lookup into its own file
vladislavalerievich Jan 23, 2025
6d9ea56
Update imports
vladislavalerievich Jan 23, 2025
be04ad2
Print results
vladislavalerievich Jan 23, 2025
458d420
Provide better file names
vladislavalerievich Jan 23, 2025
f7922db
Organize imports
vladislavalerievich Jan 23, 2025
4cc0b29
Use lru_cache instead of simple dict cache
vladislavalerievich Jan 23, 2025
4e8bdad
Improve tests
vladislavalerievich Jan 23, 2025
ea77e44
Fix ruff and mypy complaints
vladislavalerievich Jan 23, 2025
5e2a33b
Improve kernels
vladislavalerievich Jan 24, 2025
Empty file added neps/graphs/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions neps/graphs/context_managers.py
@@ -0,0 +1,58 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

from botorch.models import SingleTaskGP

from neps.graphs.kernels import BoTorchWLKernel

if TYPE_CHECKING:
    import networkx as nx
    from botorch.models.gp_regression_mixed import Kernel


@contextmanager
def set_graph_lookup(
    kernel_or_gp: Kernel | SingleTaskGP,
    new_graphs: list[nx.Graph],
    *,
    append: bool = True,
) -> Iterator[None]:
    """Context manager to temporarily set the graph lookup for a kernel or GP model.

    Args:
        kernel_or_gp (Kernel | SingleTaskGP): The kernel or GP model whose graph lookup is
            to be set.
        new_graphs (list[nx.Graph]): The new graphs to set in the graph lookup.
        append (bool, optional): Whether to append the new graphs to the existing graph
            lookup. Defaults to True.
    """
    kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []

    # Determine the modules to update based on the input type
    if isinstance(kernel_or_gp, SingleTaskGP):
        modules = [k for k in kernel_or_gp.covar_module.sub_kernels() if
                   isinstance(k, BoTorchWLKernel)]
    elif isinstance(kernel_or_gp, BoTorchWLKernel):
        modules = [kernel_or_gp]
    else:
        assert hasattr(kernel_or_gp,
                       "sub_kernels"), "Kernel module must have sub_kernels method."
        modules = [k for k in kernel_or_gp.sub_kernels() if
                   isinstance(k, BoTorchWLKernel)]

    # Save the current graph lookup and set the new graph lookup
    for kern in modules:
        kernel_prev_graphs.append((kern, kern.graph_lookup))
        if append:
            kern.set_graph_lookup([*kern.graph_lookup, *new_graphs])
        else:
            kern.set_graph_lookup(new_graphs)

    yield

    # Restore the original graph lookup after the context manager exits
    for kern, prev_graphs in kernel_prev_graphs:
        kern.set_graph_lookup(prev_graphs)
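Review note on the context manager above: the restore loop sits after a bare `yield`, so it is skipped if the body of the `with` block raises. A minimal sketch of an exception-safe variant follows. `FakeKernel` and `set_graph_lookup_safe` are hypothetical names used only for illustration; the sketch assumes nothing about `BoTorchWLKernel` beyond the `graph_lookup` attribute and `set_graph_lookup` method used in the diff.

```python
from collections.abc import Iterator
from contextlib import contextmanager


class FakeKernel:
    """Hypothetical stand-in exposing the two members the diff relies on."""

    def __init__(self, graph_lookup: list[str]) -> None:
        self.graph_lookup = graph_lookup

    def set_graph_lookup(self, graphs: list[str]) -> None:
        self.graph_lookup = graphs


@contextmanager
def set_graph_lookup_safe(
    kernel: FakeKernel,
    new_graphs: list[str],
    *,
    append: bool = True,
) -> Iterator[None]:
    """Temporarily swap the lookup and restore it even if the body raises."""
    prev_graphs = kernel.graph_lookup
    kernel.set_graph_lookup([*prev_graphs, *new_graphs] if append else new_graphs)
    try:
        yield
    finally:
        # Runs on normal exit and when the with-body raises.
        kernel.set_graph_lookup(prev_graphs)


kernel = FakeKernel(["g0", "g1"])
try:
    with set_graph_lookup_safe(kernel, ["g2"]):
        inside = list(kernel.graph_lookup)
        raise RuntimeError("simulated failure during prediction")
except RuntimeError:
    pass
restored = kernel.graph_lookup
```

With `try`/`finally`, a failed posterior evaluation or acquisition step would not leave the kernel pointing at the temporary graph list.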
52 changes: 52 additions & 0 deletions neps/graphs/examples/grakel_wl_usage_example.py
@@ -0,0 +1,52 @@
from __future__ import annotations

import matplotlib.pyplot as plt
import networkx as nx
from grakel import WeisfeilerLehman, graph_from_networkx


def visualize_graph(G):
    """Visualize the NetworkX graph."""
    pos = nx.spring_layout(G)
    nx.draw(G, pos, with_labels=True, node_size=700, node_color="lightblue")
    plt.show()


def add_labels(G):
    """Add labels to the nodes of the graph."""
    for node in G.nodes():
        G.nodes[node]["label"] = str(node)


# Create graphs
G1 = nx.Graph()
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])
add_labels(G1)

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
add_labels(G2)

G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
add_labels(G3)

# Visualize the graphs
visualize_graph(G1)
visualize_graph(G2)
visualize_graph(G3)

# Convert NetworkX graphs to Grakel format using graph_from_networkx
graph_list = list(
    graph_from_networkx([G1, G2, G3], node_labels_tag="label", as_Graph=True)
)

# Initialize the Weisfeiler-Lehman kernel
wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False)

# Compute the kernel matrix
K = wl_kernel.fit_transform(graph_list)

# Display the kernel matrix
print("Fit and Transform on Kernel matrix (pairwise similarities):")
print(K)
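For readers unfamiliar with what the kernel matrix above measures, here is a dependency-free sketch of an unnormalized WL subtree kernel. It is not GraKeL's implementation and not the PR's `TorchWLKernel`; the helpers `adjacency`, `wl_features` and `wl_kernel_matrix` are hypothetical names. It reuses the three edge lists and the node-id-as-label convention from the example.

```python
from collections import Counter

Adjacency = dict[int, list[int]]


def adjacency(edges: list[tuple[int, int]]) -> Adjacency:
    """Build an undirected adjacency list from an edge list."""
    adj: Adjacency = {}
    for u, v in edges:
        adj.setdefault(u, []).append(v)
        adj.setdefault(v, []).append(u)
    return adj


def wl_features(graphs: list[Adjacency], n_iter: int) -> list[Counter]:
    """Count the labels each graph produces over n_iter rounds of relabeling."""
    # Initial label of a node is its id as a string, as in add_labels above.
    current = [{node: str(node) for node in adj} for adj in graphs]
    features = [Counter(labels.values()) for labels in current]
    for it in range(1, n_iter + 1):
        # One compression table shared by all graphs, so equal neighborhoods
        # receive the same new label regardless of which graph they occur in.
        compressed: dict[tuple[str, tuple[str, ...]], str] = {}
        updated = []
        for adj, labels in zip(graphs, current):
            new_labels = {}
            for node, neighbors in adj.items():
                signature = (labels[node], tuple(sorted(labels[n] for n in neighbors)))
                new_labels[node] = compressed.setdefault(
                    signature, f"wl{it}_{len(compressed)}"
                )
            updated.append(new_labels)
        current = updated
        for feat, labels in zip(features, current):
            feat.update(labels.values())
    return features


def wl_kernel_matrix(features: list[Counter]) -> list[list[int]]:
    """Kernel entry = dot product of the two label-count vectors."""
    return [[sum(fa[k] * fb[k] for k in fa) for fb in features] for fa in features]


graphs = [
    adjacency([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]),
    adjacency([(0, 1), (1, 2), (2, 3), (3, 4)]),
    adjacency([(0, 1), (1, 3), (3, 2)]),
]
K_sketch = wl_kernel_matrix(wl_features(graphs, n_iter=1))
for row in K_sketch:
    print(row)
```

Each entry counts label matches accumulated over all relabeling rounds, so the diagonal grows with both graph size and `n_iter`. That growth is presumably why the graph-aware example later in this diff passes `normalize=True`.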
136 changes: 136 additions & 0 deletions neps/graphs/examples/single_task_gp_usage_example.py
@@ -0,0 +1,136 @@
from __future__ import annotations

from itertools import product
from typing import TYPE_CHECKING

import torch
from botorch import fit_gpytorch_mll
from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from botorch.optim import optimize_acqf_mixed
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.kernels import AdditiveKernel, MaternKernel

if TYPE_CHECKING:
    from gpytorch.distributions.multivariate_normal import MultivariateNormal

TRAIN_CONFIGS = 10
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3

kernels = []

# Create some random encoded hyperparameter configurations
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
if N_NUMERICAL > 0:
    X[:, :N_NUMERICAL] = torch.rand(
        size=(TOTAL_CONFIGS, N_NUMERICAL),
        dtype=torch.float64,
    )

if N_CATEGORICAL > 0:
    X[:, N_NUMERICAL:] = torch.randint(
        0,
        N_CATEGORICAL_VALUES_PER_CATEGORY,
        size=(TOTAL_CONFIGS, N_CATEGORICAL),
        dtype=torch.float64,
    )

y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)

if N_NUMERICAL > 0:
    matern = ScaleKernel(
        MaternKernel(
            nu=2.5,
            ard_num_dims=N_NUMERICAL,
            active_dims=tuple(range(N_NUMERICAL)),
        ),
    )
    kernels.append(matern)

if N_CATEGORICAL > 0:
    hamming = ScaleKernel(
        CategoricalKernel(
            ard_num_dims=N_CATEGORICAL,
            active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)),
        ),
    )
    kernels.append(hamming)

combined_num_cat_kernel = AdditiveKernel(*kernels)

train_x = X[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS]

test_x = X[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:]

K_matrix = combined_num_cat_kernel.forward(train_x, train_x)

train_y = train_y.unsqueeze(-1)
test_y = test_y.unsqueeze(-1)

gp = SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    covar_module=combined_num_cat_kernel,
)

multivariate_normal: MultivariateNormal = gp.forward(train_x)

# =============== Fitting the GP using botorch ===============

print("\nFitting the GP model using botorch...")

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

acq_function = qLogNoisyExpectedImprovement(
    model=gp,
    X_baseline=train_x,
    objective=LinearMCObjective(weights=torch.tensor([-1.0])),
    prune_baseline=True,
)

# Define bounds
bounds = torch.tensor(
    [
        [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL,
        [1.0] * N_NUMERICAL + [
            float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
    ]
)

# Setup categorical feature optimization
cats_per_column: dict[int, list[float]] = {
    column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)]
    for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)
}

# Generate fixed categorical features
fixed_cats: list[dict[int, float]]
if len(cats_per_column) == 1:
    col, choice_indices = next(iter(cats_per_column.items()))
    fixed_cats = [{col: i} for i in choice_indices]
else:
    fixed_cats = [
        dict(zip(cats_per_column.keys(), combo, strict=False))
        for combo in product(*cats_per_column.values())
    ]

best_candidate, best_score = optimize_acqf_mixed(
    acq_function=acq_function,
    bounds=bounds,
    fixed_features_list=fixed_cats,
    num_restarts=10,
    raw_samples=10,
    q=1,
)

print("Best candidate:", best_candidate)
print("Acquisition score:", best_score)
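To make the `fixed_features_list` argument concrete, the snippet below reproduces only the enumeration step from the example, using the same constants and no torch or botorch imports. It shows that the list contains one dictionary per combination of categorical values, keyed by column index.

```python
from itertools import product

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3

# Map each categorical column index to its allowed (float-encoded) values.
cats_per_column = {
    column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)]
    for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)
}

# One dict per element of the Cartesian product of the value lists.
fixed_cats = [
    dict(zip(cats_per_column.keys(), combo))
    for combo in product(*cats_per_column.values())
]

print(len(fixed_cats))  # 9
print(fixed_cats[0])    # {2: 0.0, 3: 0.0}
print(fixed_cats[-1])   # {2: 2.0, 3: 2.0}
```

The list has `N_CATEGORICAL_VALUES_PER_CATEGORY ** N_CATEGORICAL` entries, so it grows exponentially with the number of categorical columns. As far as I understand `optimize_acqf_mixed`, it runs a separate continuous optimization for each entry, which is worth keeping in mind before scaling this example up.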
130 changes: 130 additions & 0 deletions neps/graphs/graph_aware_gp_optimization_example.py
@@ -0,0 +1,130 @@
from __future__ import annotations

import time
from itertools import product
from typing import TYPE_CHECKING

import networkx as nx
import torch
from botorch import fit_gpytorch_mll, settings
from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.kernels import AdditiveKernel, MaternKernel

from neps.graphs.context_managers import set_graph_lookup
from neps.graphs.kernels import BoTorchWLKernel, TorchWLKernel
from neps.graphs.optimization import optimize_acqf_graph
from neps.graphs.utils import min_max_scale, seed_all

if TYPE_CHECKING:
    from gpytorch.distributions.multivariate_normal import MultivariateNormal

start_time = time.time()
settings.debug._set_state(True)
seed_all()

TRAIN_CONFIGS = 50
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 1
N_CATEGORICAL_VALUES_PER_CATEGORY = 2
N_GRAPH = 1

assert N_GRAPH == 1, "This example only supports a single graph feature"

# Generate random data; the last column holds each configuration's graph index
X = torch.cat([
    torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64),
    torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL),
                  dtype=torch.float64),
    torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1)
], dim=1)

# Generate random graphs
graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)]

# Generate random target values
y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5

# Split into train and test sets
train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:]
train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:]
train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1)

train_x, test_x = min_max_scale(train_x), min_max_scale(test_x)

kernels = [
    ScaleKernel(
        MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))),
    ScaleKernel(CategoricalKernel(
        ard_num_dims=N_CATEGORICAL,
        active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))),
    ScaleKernel(BoTorchWLKernel(
        graph_lookup=train_graphs, n_iter=5, normalize=True,
        active_dims=(X.shape[1] - 1,)))
]

# Create the Gaussian Process model
gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels))

# Evaluate the GP prior at the training inputs (forward() does not condition on data)
multivariate_normal: MultivariateNormal = gp.forward(train_x)

# Evaluate the not-yet-fitted model on the test inputs, with train and test graphs in the lookup
with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs, append=False):
    posterior = gp.forward(test_x)
    predictions = posterior.mean
    uncertainties = posterior.variance.sqrt()
    covar = posterior.covariance_matrix

# Fit the GP hyperparameters by maximizing the exact marginal log likelihood
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

# Define the acquisition function
acq_function = qLogNoisyExpectedImprovement(
    model=gp,
    X_baseline=train_x,
    objective=LinearMCObjective(weights=torch.tensor([-1.0])),
    prune_baseline=True,
)

# Define the bounds for optimization
bounds = torch.tensor([
    [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH,
    [1.0] * N_NUMERICAL + [
        float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [
        len(X) - 1] * N_GRAPH,
])

# Define fixed categorical features
cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in
                   range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)}
fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in
              product(*cats_per_column.values())]

# Optimize the acquisition function with graph sampling
best_candidate, best_score = optimize_acqf_graph(
    acq_function=acq_function,
    bounds=bounds,
    fixed_features_list=fixed_cats,
    train_graphs=train_graphs,
    num_graph_samples=2,
    num_restarts=2,
    raw_samples=16,
    q=1,
)

# Print the results
print(f"Best candidate: {best_candidate}")
print(f"Best score: {best_score}")
print(f"Elapsed time: {time.time() - start_time} seconds")

# Clear caches after optimization to avoid memory leaks or unexpected behavior
BoTorchWLKernel._compute_kernel.cache_clear()
TorchWLKernel._get_node_neighbors.cache_clear()
TorchWLKernel._wl_iteration.cache_clear()
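The three `cache_clear()` calls above follow from the commit "Use lru_cache instead of simple dict cache". How the PR's kernel methods are decorated is not shown in this part of the diff, so the snippet below only illustrates the standard-library behavior those calls rely on, using a hypothetical `expensive_kernel_entry` function.

```python
from functools import lru_cache

calls = 0


@lru_cache(maxsize=128)
def expensive_kernel_entry(graph_a: int, graph_b: int) -> int:
    """Hypothetical stand-in for a cached kernel computation."""
    global calls
    calls += 1
    return graph_a * graph_b


expensive_kernel_entry(2, 3)  # computed
expensive_kernel_entry(2, 3)  # served from the cache
hits_before = expensive_kernel_entry.cache_info().hits

expensive_kernel_entry.cache_clear()
size_after = expensive_kernel_entry.cache_info().currsize

expensive_kernel_entry(2, 3)  # computed again after the clear
print(hits_before, size_after, calls)  # 1 0 2
```

An `lru_cache` holds strong references to its arguments and results until they are evicted or cleared. If a cached method receives `self` or large tensors as arguments, those objects stay alive as long as the cache entry does, which is a plausible reason for clearing the caches once optimization finishes.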