Skip to content

Commit

Permalink
Fix pending point extraction in gen_multi_from_multi (facebook#2914)
Browse files Browse the repository at this point in the history
Summary:

Earlier today danielcohenlive was having trouble using gs and was getting a weird modeling error. Lena did some investigation an found we were using the wrong method to grab pending points from the experiment. See her notebook in the test plan for more details. to quote "Basically in some settings it’s not safe to deem a point pending based on the trial status; generally no point for which we have data, we want to consider pending"

I just am putting up the fix with her logic here

Reviewed By: lena-kashtelyan, Cesar-Cardoso

Differential Revision: D64563496
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 18, 2024
1 parent 202dc33 commit b8fc177
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
10 changes: 2 additions & 8 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.utils import (
extend_pending_observations,
get_pending_observation_features_based_on_trial_status,
)
from ax.core.utils import extend_pending_observations, extract_pending_observations
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
Expand Down Expand Up @@ -549,10 +546,7 @@ def gen_for_multiple_trials_with_multiple_models(
"""
trial_grs = []
pending_observations = (
get_pending_observation_features_based_on_trial_status(
experiment=experiment
)
or {}
extract_pending_observations(experiment=experiment) or {}
if pending_observations is None
else deepcopy(pending_observations)
)
Expand Down
13 changes: 5 additions & 8 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.runner import Runner
from ax.core.utils import get_pending_observation_features_based_on_trial_status
from ax.core.utils import extract_pending_observations
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError
from ax.exceptions.generation_strategy import AxGenerationException
Expand Down Expand Up @@ -583,11 +583,8 @@ def test_run_all_trials_using_runner_and_metrics(self) -> None:
)
with patch(
# Record calls to function, but still execute it.
(
f"{self.PENDING_FEATURES_CALL_LOCATION}."
"get_pending_observation_features_based_on_trial_status"
),
side_effect=get_pending_observation_features_based_on_trial_status,
(f"{self.PENDING_FEATURES_CALL_LOCATION}." "extract_pending_observations"),
side_effect=extract_pending_observations,
) as mock_get_pending:
scheduler.run_all_trials()
# Check that we got pending feat. at least 8 times (1 for each new trial and
Expand Down Expand Up @@ -1783,9 +1780,9 @@ def test_batch_trial(self, status_quo_weight: float = 0.0) -> None:
with patch( # Record calls to functions, but still execute them.
(
f"{self.PENDING_FEATURES_CALL_LOCATION_BATCH}."
"get_pending_observation_features_based_on_trial_status"
"extract_pending_observations"
),
side_effect=get_pending_observation_features_based_on_trial_status,
side_effect=extract_pending_observations,
) as mock_get_pending, patch.object(
scheduler.generation_strategy,
"gen_for_multiple_trials_with_multiple_models",
Expand Down

0 comments on commit b8fc177

Please sign in to comment.