Skip to content

Commit

Permalink
Don't silently set a generation strategy on load (#2937)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2937

Reviewed By: Cesar-Cardoso

Differential Revision: D64786290
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 23, 2024
1 parent 02c0ce5 commit d253afa
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
6 changes: 6 additions & 0 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
OptimizationShouldStop,
UnsupportedError,
UnsupportedPlotError,
UserInputError,
)
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
Expand Down Expand Up @@ -1118,6 +1119,11 @@ def load_experiment_from_database(
)
logger.info(f"Loaded {experiment}.")
if generation_strategy is None:
if choose_generation_strategy_kwargs is None:
raise UserInputError(
f"No generation strategy was found for {experiment}. Please "
"pass `choose_generation_strategy_kwargs` to load it with one."
)
self._set_generation_strategy(
choose_generation_strategy_kwargs=choose_generation_strategy_kwargs
)
Expand Down
26 changes: 25 additions & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,17 @@
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.save import save_experiment
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.sqa_store.structs import DBSettings
from ax.utils.common.random import with_rng_seed
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.measurement.synthetic_functions import Branin
from ax.utils.testing.core_stubs import DummyEarlyStoppingStrategy
from ax.utils.testing.core_stubs import (
DummyEarlyStoppingStrategy,
get_branin_experiment,
)
from ax.utils.testing.mock import fast_botorch_optimize
from ax.utils.testing.modeling_stubs import get_observation1, get_observation1trans
from botorch.test_functions.multi_objective import BraninCurrin
Expand Down Expand Up @@ -625,6 +629,26 @@ def test_save_and_load_generation_strategy(self) -> None:
second_client.load_experiment_from_database("unique_test_experiment")
self.assertEqual(second_client.generation_strategy, generation_strategy)

def test_save_and_load_no_generation_strategy(self) -> None:
init_test_engine_and_session_factory(force_init=True)
config = SQAConfig()
encoder = Encoder(config=config)
decoder = Decoder(config=config)
db_settings = DBSettings(encoder=encoder, decoder=decoder)
experiment = get_branin_experiment(named=True)
save_experiment(experiment=experiment, config=config)
client = AxClient(db_settings=db_settings)
with self.assertRaisesRegex(
UserInputError, "choose_generation_strategy_kwargs"
):
client.load_experiment_from_database(experiment.name)

client = AxClient(db_settings=db_settings)
client.load_experiment_from_database(
experiment_name=experiment.name, choose_generation_strategy_kwargs={}
)
self.assertIsNotNone(client.generation_strategy)

@patch(
f"{AxClient.__module__}.AxClient._save_experiment_to_db_if_possible",
side_effect=Exception("patched db exception"),
Expand Down

0 comments on commit d253afa

Please sign in to comment.