From b8fc1774baae8c4101e473d3ac81ea6d95537240 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Fri, 18 Oct 2024 08:13:35 -0700 Subject: [PATCH] Fix pending point extraction in gen_multi_from_multi (#2914) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Earlier today danielcohenlive was having trouble using gs and was getting a weird modeling error. Lena did some investigation an found we were using the wrong method to grab pending points from the experiment. See her notebook in the test plan for more details. to quote "Basically in some settings it’s not safe to deem a point pending based on the trial status; generally no point for which we have data, we want to consider pending" I just am putting up the fix with her logic here Reviewed By: lena-kashtelyan, Cesar-Cardoso Differential Revision: D64563496 --- ax/modelbridge/generation_strategy.py | 10 ++-------- ax/service/tests/scheduler_test_utils.py | 13 +++++-------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 814f1ddc85e..357fc786cb8 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -21,10 +21,7 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures -from ax.core.utils import ( - extend_pending_observations, - get_pending_observation_features_based_on_trial_status, -) +from ax.core.utils import extend_pending_observations, extract_pending_observations from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import ( GenerationStrategyCompleted, @@ -549,10 +546,7 @@ def gen_for_multiple_trials_with_multiple_models( """ trial_grs = [] pending_observations = ( - get_pending_observation_features_based_on_trial_status( - experiment=experiment - ) - or {} + extract_pending_observations(experiment=experiment) or {} if pending_observations is None else deepcopy(pending_observations) ) diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 42466418b3c..fa9fe56b474 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -31,7 +31,7 @@ from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig from ax.core.runner import Runner -from ax.core.utils import get_pending_observation_features_based_on_trial_status +from ax.core.utils import extract_pending_observations from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import AxGenerationException @@ -583,11 +583,8 @@ def test_run_all_trials_using_runner_and_metrics(self) -> None: ) with patch( # Record calls to function, but still execute it. - ( - f"{self.PENDING_FEATURES_CALL_LOCATION}." - "get_pending_observation_features_based_on_trial_status" - ), - side_effect=get_pending_observation_features_based_on_trial_status, + (f"{self.PENDING_FEATURES_CALL_LOCATION}." "extract_pending_observations"), + side_effect=extract_pending_observations, ) as mock_get_pending: scheduler.run_all_trials() # Check that we got pending feat. at least 8 times (1 for each new trial and @@ -1783,9 +1780,9 @@ def test_batch_trial(self, status_quo_weight: float = 0.0) -> None: with patch( # Record calls to functions, but still execute them. ( f"{self.PENDING_FEATURES_CALL_LOCATION_BATCH}." - "get_pending_observation_features_based_on_trial_status" + "extract_pending_observations" ), - side_effect=get_pending_observation_features_based_on_trial_status, + side_effect=extract_pending_observations, ) as mock_get_pending, patch.object( scheduler.generation_strategy, "gen_for_multiple_trials_with_multiple_models",