Skip to content

Commit

Permalink
Fix pending point extraction in gen_multi_from_multi (#2914)
Browse files Browse the repository at this point in the history
Summary:

Earlier today danielcohenlive was having trouble using gs and was getting a weird modeling error. Lena did some investigation an found we were using the wrong method to grab pending points from the experiment. See her notebook in the test plan for more details. to quote "Basically in some settings it’s not safe to deem a point pending based on the trial status; generally no point for which we have data, we want to consider pending"

I just am putting up the fix with her logic here

Differential Revision: D64563496
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 17, 2024
1 parent 849a7ec commit ffb8b4b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
10 changes: 2 additions & 8 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.utils import (
extend_pending_observations,
get_pending_observation_features_based_on_trial_status,
)
from ax.core.utils import extend_pending_observations, extract_pending_observations
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
Expand Down Expand Up @@ -549,10 +546,7 @@ def gen_for_multiple_trials_with_multiple_models(
"""
trial_grs = []
pending_observations = (
get_pending_observation_features_based_on_trial_status(
experiment=experiment
)
or {}
extract_pending_observations(experiment=experiment) or {}
if pending_observations is None
else deepcopy(pending_observations)
)
Expand Down
9 changes: 6 additions & 3 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.utils import get_pending_observation_features_based_on_trial_status
from ax.core.utils import (
extract_pending_observations,
get_pending_observation_features_based_on_trial_status,
)
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import AxGenerationException
Expand Down Expand Up @@ -1783,9 +1786,9 @@ def test_batch_trial(self, status_quo_weight: float = 0.0) -> None:
with patch( # Record calls to functions, but still execute them.
(
f"{self.PENDING_FEATURES_CALL_LOCATION_BATCH}."
"get_pending_observation_features_based_on_trial_status"
"extract_pending_observations"
),
side_effect=get_pending_observation_features_based_on_trial_status,
side_effect=extract_pending_observations,
) as mock_get_pending, patch.object(
scheduler.generation_strategy,
"gen_for_multiple_trials_with_multiple_models",
Expand Down

0 comments on commit ffb8b4b

Please sign in to comment.