diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 9590ae86583..09be6bf0df7 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -50,6 +50,7 @@ OptimizationShouldStop, UnsupportedError, UnsupportedPlotError, + UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy @@ -1118,6 +1119,11 @@ def load_experiment_from_database( ) logger.info(f"Loaded {experiment}.") if generation_strategy is None: + if choose_generation_strategy_kwargs is None: + raise UserInputError( + f"No generation strategy was found for {experiment}. Please " + "pass `choose_generation_strategy_kwargs` to load it with one." + ) self._set_generation_strategy( choose_generation_strategy_kwargs=choose_generation_strategy_kwargs ) diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index ae4a56c2e3c..3d97d6145cc 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -69,13 +69,17 @@ from ax.storage.sqa_store.db import init_test_engine_and_session_factory from ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.encoder import Encoder +from ax.storage.sqa_store.save import save_experiment from ax.storage.sqa_store.sqa_config import SQAConfig from ax.storage.sqa_store.structs import DBSettings from ax.utils.common.random import with_rng_seed from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast, not_none from ax.utils.measurement.synthetic_functions import Branin -from ax.utils.testing.core_stubs import DummyEarlyStoppingStrategy +from ax.utils.testing.core_stubs import ( + DummyEarlyStoppingStrategy, + get_branin_experiment, +) from ax.utils.testing.mock import fast_botorch_optimize from ax.utils.testing.modeling_stubs import get_observation1, get_observation1trans from botorch.test_functions.multi_objective import BraninCurrin @@ -625,6 +629,26 @@ def test_save_and_load_generation_strategy(self) -> None: second_client.load_experiment_from_database("unique_test_experiment") self.assertEqual(second_client.generation_strategy, generation_strategy) + def test_save_and_load_no_generation_strategy(self) -> None: + init_test_engine_and_session_factory(force_init=True) + config = SQAConfig() + encoder = Encoder(config=config) + decoder = Decoder(config=config) + db_settings = DBSettings(encoder=encoder, decoder=decoder) + experiment = get_branin_experiment(named=True) + save_experiment(experiment=experiment, config=config) + client = AxClient(db_settings=db_settings) + with self.assertRaisesRegex( + UserInputError, "choose_generation_strategy_kwargs" + ): + client.load_experiment_from_database(experiment.name) + + client = AxClient(db_settings=db_settings) + client.load_experiment_from_database( + experiment_name=experiment.name, choose_generation_strategy_kwargs={} + ) + self.assertIsNotNone(client.generation_strategy) + @patch( f"{AxClient.__module__}.AxClient._save_experiment_to_db_if_possible", side_effect=Exception("patched db exception"),