Fix networkx adapter #160

Merged: 4 commits, Aug 4, 2023
2 changes: 1 addition & 1 deletion examples/molecule_search/mol_adapter.py
@@ -36,7 +36,7 @@ def store_edges_params_in_nodes(graph: nx.DiGraph, opt_graph: OptGraph) -> OptGraph:
         edges_params = {}
         for predecessor in graph.predecessors(node):
             edges_params.update({str(predecessor): graph.get_edge_data(predecessor, node)})
-        opt_graph.get_node_by_uid(str(node)).content.update({'edges_params': edges_params})
+        opt_graph.get_node_by_uid(str(node)).parameters.update({'edges_params': edges_params})
     return opt_graph


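Note on the fix: OptNode keeps hyperparameters nested under content['params'], so updating the raw content dict put edges_params at the wrong level; the parameters accessor targets the nested dict, which is what the mol adapter needs. A minimal sketch of the difference, assuming the content layout used by this PR in nx_adapter.py below:

# Illustrative sketch, not part of the PR.
from golem.core.optimisers.graph import OptNode

node = OptNode(content={'name': 'c', 'params': {'alpha': 0.5}})
node.parameters.update({'edges_params': {'a': {'weight': 1}}})
# The update lands in the nested params dict, not at the top level of content:
assert 'edges_params' in node.content['params']
assert 'edges_params' not in node.content
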
4 changes: 3 additions & 1 deletion golem/core/adapter/adapter.py
@@ -81,7 +81,7 @@ def adapt(self, item: Union[DomainStructureType, Sequence[DomainStructureType]])
         else:
             return item

-    def restore(self, item: Union[Graph, Individual, PopulationT]) \
+    def restore(self, item: Union[Graph, Individual, PopulationT, Sequence[Graph]]) \
             -> Union[DomainStructureType, Sequence[DomainStructureType]]:
         """Maps graphs from internal representation to domain graphs.
         Performs mapping only if argument has a type of internal representation.
@@ -98,6 +98,8 @@ def restore(self, item: Union[Graph, Individual, PopulationT]) \
             return self._restore(item.graph, item.metadata)
         elif isinstance(item, Sequence) and isinstance(item[0], Individual):
             return [self._restore(ind.graph, ind.metadata) for ind in item]
+        elif isinstance(item, Sequence) and isinstance(item[0], self.opt_graph_class):
+            return [self._restore(graph) for graph in item]
         else:
             return item

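Note: with the new Sequence branch, restore also maps a plain list of optimization graphs back to domain graphs, not only single graphs, Individuals, or populations. A usage sketch with the networkx adapter from this PR (the input graphs are illustrative):

# Illustrative sketch, not part of the PR.
import networkx as nx
from golem.core.adapter.nx_adapter import BaseNetworkxAdapter

domain_graphs = [nx.DiGraph([('a', 'c'), ('b', 'c')]) for _ in range(3)]
adapter = BaseNetworkxAdapter()
opt_graphs = [adapter.adapt(g) for g in domain_graphs]
restored = adapter.restore(opt_graphs)  # hits the new Sequence[Graph] branch
assert all(isinstance(g, nx.DiGraph) for g in restored)
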
11 changes: 8 additions & 3 deletions golem/core/adapter/nx_adapter.py
@@ -22,15 +22,20 @@ def __init__(self):
     def _node_restore(self, node: GraphNode) -> Dict:
         """Transforms GraphNode to dict of NetworkX node attributes.
         Override for custom behavior."""
-        if hasattr(node, 'content'):
-            return deepcopy(node.content)
+        if hasattr(node, 'parameters'):
+            parameters = node.parameters
+            if node.name:
+                parameters['name'] = node.name
+            return deepcopy(parameters)
         else:
             return {}

     def _node_adapt(self, data: Dict) -> OptNode:
         """Transforms a dict of NetworkX node attributes to GraphNode.
         Override for custom behavior."""
-        return OptNode(content=deepcopy(data))
+        data = deepcopy(data)
+        name = data.pop('name', None)
+        return OptNode(content={'name': name, 'params': data})

     def _adapt(self, adaptee: nx.DiGraph) -> OptGraph:
         mapped_nodes = {}
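Note: _node_restore and _node_adapt are now inverses: the node name is lifted into the attribute dict on restore and popped back out on adapt, while the remaining attributes travel as parameters instead of being dumped into content wholesale. A round-trip sketch mirroring the new test fixture added below:

# Illustrative sketch, not part of the PR.
import networkx as nx
from golem.core.adapter.nx_adapter import BaseNetworkxAdapter

g = nx.DiGraph()
g.add_node('a')
g.add_node('b')
g.add_node('c', alpha=0.5)
g.add_edges_from([('a', 'c'), ('b', 'c')])

adapter = BaseNetworkxAdapter()
opt_graph = adapter.adapt(g)           # 'alpha' lands in the OptNode's parameters
assert opt_graph.root_node.parameters['alpha'] == 0.5
restored = adapter.restore(opt_graph)  # ...and comes back as a node attribute
assert restored.nodes['c']['alpha'] == 0.5
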
21 changes: 11 additions & 10 deletions golem/core/tuning/optuna_tuner.py
@@ -33,6 +33,7 @@ def __init__(self, objective_evaluate: ObjectiveFunction,
                          n_jobs,
                          deviation)
         self.objectives_number = objectives_number
+        self.study = None

     def tune(self, graph: DomainGraphForTune, show_progress: bool = True) -> \
             Union[DomainGraphForTune, Sequence[DomainGraphForTune]]:
@@ -42,7 +43,7 @@ def tune(self, graph: DomainGraphForTune, show_progress: bool = True) -> \

         self.init_check(graph)

-        study = optuna.create_study(directions=['minimize'] * self.objectives_number)
+        self.study = optuna.create_study(directions=['minimize'] * self.objectives_number)

         init_parameters, has_parameters_to_optimize = self._get_initial_point(graph)
         if not has_parameters_to_optimize:
@@ -51,22 +52,22 @@ def tune(self, graph: DomainGraphForTune, show_progress: bool = True) -> \
         else:
             # Enqueue initial point to try
             if init_parameters:
-                study.enqueue_trial(init_parameters)
+                self.study.enqueue_trial(init_parameters)

-            study.optimize(predefined_objective,
-                           n_trials=self.iterations,
-                           n_jobs=self.n_jobs,
-                           timeout=self.timeout.seconds,
-                           callbacks=[self.early_stopping_callback],
-                           show_progress_bar=show_progress)
+            self.study.optimize(predefined_objective,
+                                n_trials=self.iterations,
+                                n_jobs=self.n_jobs,
+                                timeout=self.timeout.seconds,
+                                callbacks=[self.early_stopping_callback],
+                                show_progress_bar=show_progress)

             if not is_multi_objective:
-                best_parameters = study.best_trials[0].params
+                best_parameters = self.study.best_trials[0].params
                 tuned_graphs = self.set_arg_graph(graph, best_parameters)
                 self.was_tuned = True
             else:
                 tuned_graphs = []
-                for best_trial in study.best_trials:
+                for best_trial in self.study.best_trials:
                     best_parameters = best_trial.params
                     tuned_graph = self.set_arg_graph(deepcopy(graph), best_parameters)
                     tuned_graphs.append(tuned_graph)
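Note: storing the study on the instance means it outlives tune, so trial history stays inspectable after tuning finishes. The same pattern in plain optuna (illustrative, not GOLEM code):

# Illustrative sketch, not part of the PR.
import optuna

class Holder:
    def __init__(self):
        self.study = None

    def run(self):
        self.study = optuna.create_study(directions=['minimize'])
        self.study.optimize(lambda t: t.suggest_float('x', -1, 1) ** 2, n_trials=5)

h = Holder()
h.run()
print(len(h.study.trials), h.study.best_trials[0].params)  # history survives run()
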
1 change: 1 addition & 0 deletions golem/core/tuning/tuner_interface.py
@@ -136,6 +136,7 @@ def _single_obj_final_check(self, tuned_graph: OptGraph):
                              f'worse than initial (+ {self.deviation}% deviation) {abs(init_metric):.3f}')
             final_graph = self.init_graph
             final_metric = self.init_metric
+            self.obtained_metric = final_metric
         self.log.message(f'Final graph: {graph_structure(final_graph)}')
         if final_metric is not None:
             self.log.message(f'Final metric: {abs(final_metric):.3f}')
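Note: without this line, a run that falls back to the initial graph would keep reporting the rejected tuned metric; the assignment keeps obtained_metric in sync with the graph actually returned. The bug pattern in isolation (illustrative, not GOLEM code):

# Illustrative sketch, not part of the PR.
class Tuner:
    def __init__(self, init_metric):
        self.init_metric = init_metric
        self.obtained_metric = None

    def final_check(self, tuned_metric, tuned_is_worse):
        self.obtained_metric = tuned_metric
        if tuned_is_worse:
            # Fall back to the initial graph; without the next line the tuner
            # would keep reporting the rejected tuned metric.
            self.obtained_metric = self.init_metric
        return self.obtained_metric

assert Tuner(init_metric=0.5).final_check(tuned_metric=0.9, tuned_is_worse=True) == 0.5
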
11 changes: 11 additions & 0 deletions test/unit/adapter/graph_data.py
@@ -1,3 +1,5 @@
+import networkx as nx
+
 from golem.core.adapter.nx_adapter import DumbNetworkxAdapter, BaseNetworkxAdapter
 from test.unit.mocks.common_mocks import MockNode, MockDomainStructure, MockAdapter

@@ -50,6 +52,15 @@ def graph_with_custom_parameters(alpha_value):
     return graph


+def networkx_graph_with_parameters(alpha_value):
+    graph = nx.DiGraph()
+    graph.add_node('a')
+    graph.add_node('b')
+    graph.add_node('c', alpha=alpha_value)
+    graph.add_edges_from([('a', 'c'), ('b', 'c')])
+    return graph
+
+
 def get_complex_graph():
     node_a = MockNode('a')
     node_b = MockNode('b', nodes_from=[node_a])
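Note: this fixture is the networkx counterpart of graph_with_custom_parameters: the same a/b -> c shape, with the hyperparameter stored as a plain node attribute. What it builds, spelled out:

# Illustrative sketch, not part of the PR.
from test.unit.adapter.graph_data import networkx_graph_with_parameters

graph = networkx_graph_with_parameters(12.1)
assert set(graph.nodes) == {'a', 'b', 'c'}
assert set(graph.predecessors('c')) == {'a', 'b'}  # 'c' is the root/sink node
assert graph.nodes['c']['alpha'] == 12.1
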
17 changes: 12 additions & 5 deletions test/unit/adapter/test_adapter.py
@@ -4,30 +4,37 @@
 import numpy as np
 import pytest

+from golem.core.adapter.nx_adapter import BaseNetworkxAdapter
+from golem.core.dag.graph import Graph
 from golem.core.dag.graph_node import GraphNode
 from golem.core.dag.graph_verifier import GraphVerifier
 from golem.core.dag.verification_rules import DEFAULT_DAG_RULES
 from golem.core.optimisers.graph import OptNode
 from test.unit.adapter.graph_data import get_graphs, graph_with_custom_parameters, get_complex_graph, get_adapters, \
-    get_optgraphs
+    get_optgraphs, networkx_graph_with_parameters
 from test.unit.mocks.common_mocks import MockNode, MockAdapter
 from test.unit.utils import find_first


-def test_adapters_params_correct():
+@pytest.mark.parametrize('adapter, graph_with_params', [(MockAdapter(), graph_with_custom_parameters),
+                                                        (BaseNetworkxAdapter(), networkx_graph_with_parameters)])
+def test_adapters_params_correct(adapter, graph_with_params):
     """ Checking the correct conversion of hyperparameters in nodes when nodes
     are passing through adapter
     """
     init_alpha = 12.1
-    graph = graph_with_custom_parameters(init_alpha)
+    graph = graph_with_params(init_alpha)

     # Convert into OptGraph object
-    adapter = MockAdapter()
     opt_graph = adapter.adapt(graph)
     assert np.isclose(init_alpha, opt_graph.root_node.parameters['alpha'])
     # Get graph object back
     restored_graph = adapter.restore(opt_graph)
     # Get hyperparameter value after graph restoration
-    restored_alpha = restored_graph.root_node.content['params']['alpha']
+    if isinstance(graph, Graph):
+        restored_alpha = restored_graph.root_node.content['params']['alpha']
+    else:
+        restored_alpha = restored_graph.nodes['c']['alpha']
     assert np.isclose(init_alpha, restored_alpha)


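Note: the branch at the end is needed because the two adapters restore to different domain types: MockAdapter yields a GOLEM Graph with a root_node, while BaseNetworkxAdapter yields an nx.DiGraph, which exposes attributes only through its nodes mapping. A helper equivalent in spirit to the test's branch:

# Illustrative sketch, not part of the PR.
from golem.core.dag.graph import Graph

def read_alpha(restored_graph):
    # GOLEM domain graphs carry params on their nodes; networkx graphs carry
    # plain node attributes.
    if isinstance(restored_graph, Graph):
        return restored_graph.root_node.content['params']['alpha']
    return restored_graph.nodes['c']['alpha']
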
9 changes: 5 additions & 4 deletions test/unit/tuning/test_tuning.py
@@ -110,7 +110,7 @@ def test_node_tuning(search_space, graph):


 @pytest.mark.parametrize('tuner_cls', [OptunaTuner])
-@pytest.mark.parametrize('graph, adapter, obj_eval',
+@pytest.mark.parametrize('init_graph, adapter, obj_eval',
                          [(mock_graph_with_params(), MockAdapter(),
                            MockObjectiveEvaluate(Objective({'sum_metric': ParamsSumMetric.get_value,
                                                             'prod_metric': ParamsProductMetric.get_value},
@@ -119,11 +119,12 @@ def test_node_tuning(search_space, graph):
                           ObjectiveEvaluate(Objective({'sum_metric': ParamsSumMetric.get_value,
                                                        'prod_metric': ParamsProductMetric.get_value},
                                                       is_multi_objective=True)))])
-def test_multi_objective_tuning(search_space, tuner_cls, graph, adapter, obj_eval):
-    init_metric = obj_eval.evaluate(graph)
+def test_multi_objective_tuning(search_space, tuner_cls, init_graph, adapter, obj_eval):
+    init_metric = obj_eval.evaluate(init_graph)
     tuner = tuner_cls(obj_eval, search_space, adapter, iterations=20, objectives_number=2)
-    tuned_graphs = tuner.tune(deepcopy(graph), show_progress=False)
+    tuned_graphs = tuner.tune(deepcopy(init_graph), show_progress=False)
     for graph in tuned_graphs:
+        assert type(graph) == type(init_graph)
         final_metric = obj_eval.evaluate(graph)
         assert final_metric is not None
         assert not init_metric.dominates(final_metric)
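
Note: the new type assertion pins down the contract that the adapter changes above enable: tuner.tune must return graphs of the same domain type it received, even on the multi-objective path that returns several graphs. With BaseNetworkxAdapter this is exactly where the Sequence[Graph] branch added to restore prevents OptGraph instances from leaking back to the caller.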