allow candidate_generator_function to stop interactive_optimize #2950

Closed · wants to merge 1 commit
9 changes: 8 additions & 1 deletion ax/service/interactive_loop.py
@@ -103,6 +103,12 @@ def interactive_optimize(
for _i in range(num_trials):
candidate_item = candidate_queue.get()

if candidate_item is None:
# if candidate_item is None,
# it means the candidate generator has failed and stopped
optimization_completed = False
break

response = elicitation_function(
candidate_item, **(elicitation_function_kwargs or {})
)
@@ -111,7 +117,8 @@
if response is not None:
data_queue.put(response)
else:
# if response is None, abort the optimization
# if response is None, it means the user has stopped
# abort the optimization
optimization_completed = False
break

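With this change, a candidate_generator_function can stop the loop by enqueueing a None sentinel: the consumer in interactive_optimize sees None, marks the run as not completed, and breaks out of the trial loop. A minimal sketch of such a generator, modelled on the ax_client_candidate_generator helper in the tests below (the name failing_candidate_generator is illustrative):

```python
from queue import Queue
from threading import Event, Lock

from ax.core.types import TParameterization


def failing_candidate_generator(
    queue: Queue[tuple[TParameterization, int] | None],
    stop_event: Event,
    num_trials: int,
    lock: Lock,
) -> None:
    # Instead of generating candidates, enqueue the None sentinel and set the
    # stop event; the consumer loop sees None, sets optimization_completed to
    # False, and breaks out of the trial loop.
    with lock:
        queue.put(None)
        stop_event.set()
```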
126 changes: 70 additions & 56 deletions ax/service/tests/test_interactive_loop.py
@@ -10,43 +10,32 @@
import functools
import time
from logging import WARN
from queue import Queue
from threading import Event, Lock

import numpy as np
from ax.core.types import TEvaluationOutcome

from ax.core.types import TEvaluationOutcome, TParameterization
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient, TParameterization
from ax.service.interactive_loop import interactive_optimize_with_client
from ax.service.ax_client import AxClient
from ax.service.interactive_loop import (
ax_client_data_attacher,
interactive_optimize,
interactive_optimize_with_client,
)
from ax.service.utils.instantiation import ObjectiveProperties
from ax.utils.common.testutils import TestCase
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.testing.mock import fast_botorch_optimize


class TestInteractiveLoop(TestCase):
@fast_botorch_optimize
def test_interactive_loop(self) -> None:
def _elicit(
parameterization_with_trial_index: tuple[TParameterization, int],
) -> tuple[int, TEvaluationOutcome] | None:
parameterization, trial_index = parameterization_with_trial_index
x = np.array([parameterization.get(f"x{i+1}") for i in range(6)])

return (
trial_index,
{
"hartmann6": (hartmann6(x), 0.0),
"l2norm": (np.sqrt((x**2).sum()), 0.0),
},
)

def _aborted_elicit(
parameterization_with_trial_index: tuple[TParameterization, int],
) -> tuple[int, TEvaluationOutcome] | None:
return None

ax_client = AxClient()
ax_client.create_experiment(
def setUp(self) -> None:
generation_strategy = GenerationStrategy(
steps=[GenerationStep(model=Models.SOBOL, max_parallelism=1, num_trials=-1)]
)
self.ax_client = AxClient(generation_strategy=generation_strategy)
self.ax_client.create_experiment(
name="hartmann_test_experiment",
# pyre-fixme[6]
parameters=[
@@ -61,35 +50,80 @@ def _aborted_elicit(
tracking_metric_names=["l2norm"],
)

def _elicit(
self,
parameterization_with_trial_index: tuple[TParameterization, int],
) -> tuple[int, TEvaluationOutcome] | None:
parameterization, trial_index = parameterization_with_trial_index
x = np.array([parameterization.get(f"x{i + 1}") for i in range(6)])

return (
trial_index,
{
"hartmann6": (hartmann6(x), 0.0),
"l2norm": (np.sqrt((x**2).sum()), 0.0),
},
)

def test_interactive_loop(self) -> None:
optimization_completed = interactive_optimize_with_client(
ax_client=ax_client,
ax_client=self.ax_client,
num_trials=15,
candidate_queue_maxsize=3,
# pyre-fixme[6]
elicitation_function=_elicit,
elicitation_function=self._elicit,
)

self.assertTrue(optimization_completed)
self.assertEqual(len(ax_client.experiment.trials), 15)
self.assertEqual(len(self.ax_client.experiment.trials), 15)

def test_interactive_loop_aborted(self) -> None:
# Abort from elicitation function
def _aborted_elicit(
parameterization_with_trial_index: tuple[TParameterization, int],
) -> tuple[int, TEvaluationOutcome] | None:
return None

# test failed experiment
optimization_completed = interactive_optimize_with_client(
ax_client=ax_client,
ax_client=self.ax_client,
num_trials=15,
candidate_queue_maxsize=3,
# pyre-fixme[6]
elicitation_function=_aborted_elicit,
)
self.assertFalse(optimization_completed)

# Abort from candidate_generator
def ax_client_candidate_generator(
queue: Queue[tuple[TParameterization, int] | None],
stop_event: Event,
num_trials: int,
lock: Lock,
) -> None:
with lock:
queue.put(None)
stop_event.set()

ax_client_lock = Lock()
optimization_completed = interactive_optimize(
num_trials=15,
candidate_queue_maxsize=3,
candidate_generator_function=ax_client_candidate_generator,
candidate_generator_kwargs={"lock": ax_client_lock},
data_attacher_function=ax_client_data_attacher,
data_attacher_kwargs={"ax_client": self.ax_client, "lock": ax_client_lock},
elicitation_function=self._elicit,
)
self.assertFalse(optimization_completed)

def test_candidate_pregeneration_errors_raised(self) -> None:
def _elicit(
def _sleep_elicit(
parameterization_with_trial_index: tuple[TParameterization, int],
) -> tuple[int, TEvaluationOutcome]:
parameterization, trial_index = parameterization_with_trial_index
time.sleep(0.15) # Sleep to induce MaxParallelismException in loop

x = np.array([parameterization.get(f"x{i+1}") for i in range(6)])
x = np.array([parameterization.get(f"x{i + 1}") for i in range(6)])

return (
trial_index,
@@ -99,33 +133,13 @@ def _elicit(
},
)

# GS with low max parallelism to induce MaxParallelismException:
generation_strategy = GenerationStrategy(
steps=[GenerationStep(model=Models.SOBOL, max_parallelism=1, num_trials=-1)]
)
ax_client = AxClient(generation_strategy=generation_strategy)
ax_client.create_experiment(
name="hartmann_test_experiment",
# pyre-fixme[6]
parameters=[
{
"name": f"x{i}",
"type": "range",
"bounds": [0.0, 1.0],
}
for i in range(1, 7)
],
objectives={"hartmann6": ObjectiveProperties(minimize=True)},
tracking_metric_names=["l2norm"],
)

with self.assertLogs(logger="ax", level=WARN) as logger:
interactive_optimize_with_client(
ax_client=ax_client,
ax_client=self.ax_client,
num_trials=3,
candidate_queue_maxsize=3,
# pyre-fixme[6]
elicitation_function=_elicit,
elicitation_function=_sleep_elicit,
)

# Assert sleep and retry warning is somewhere in the logs
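For comparison, the pre-existing abort path goes through the elicitation function: returning None instead of a (trial_index, raw_data) tuple also ends the run with optimization_completed set to False. A minimal sketch, mirroring the _aborted_elicit helper above (the name aborting_elicit is illustrative):

```python
from ax.core.types import TEvaluationOutcome, TParameterization


def aborting_elicit(
    parameterization_with_trial_index: tuple[TParameterization, int],
) -> tuple[int, TEvaluationOutcome] | None:
    # Returning None signals that the user stopped providing feedback,
    # so interactive_optimize aborts instead of attaching new data.
    return None
```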