Skip to content

Commit

Permalink
Expose fit_out_of_design
Browse files Browse the repository at this point in the history
Reviewed By: bernardbeckerman

Differential Revision: D54006405
  • Loading branch information
David Eriksson authored and facebook-github-bot committed Mar 20, 2024
1 parent 2480064 commit 4216138
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
11 changes: 10 additions & 1 deletion ax/modelbridge/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def _make_botorch_step(
disable_progbar: Optional[bool] = None,
jit_compile: Optional[bool] = None,
derelativize_with_raw_status_quo: bool = False,
fit_out_of_design: bool = False,
) -> GenerationStep:
"""Shortcut for creating a BayesOpt generation step."""
model_kwargs = model_kwargs or {}
Expand All @@ -98,6 +99,7 @@ def _make_botorch_step(
model_kwargs["transform_configs"][
"Derelativize"
] = derelativization_transform_config
model_kwargs["fit_out_of_design"] = fit_out_of_design

if not no_winsorization:
_, default_bridge_kwargs = model.view_defaults()
Expand Down Expand Up @@ -312,6 +314,7 @@ def choose_generation_strategy(
jit_compile: Optional[bool] = None,
experiment: Optional[Experiment] = None,
suggested_model_override: Optional[ModelRegistryBase] = None,
fit_out_of_design: bool = False,
) -> GenerationStrategy:
"""Select an appropriate generation strategy based on the properties of
the search space and expected settings of the experiment, such as number of
Expand Down Expand Up @@ -412,6 +415,7 @@ def choose_generation_strategy(
provided as an arg to this function.
suggested_model_override: If specified, this model will be used for the GP
step and automatic selection will be skipped.
fit_out_of_design: Whether to include out-of-design points in the model.
"""
if experiment is not None and optimization_config is None:
optimization_config = experiment.optimization_config
Expand Down Expand Up @@ -507,6 +511,11 @@ def choose_generation_strategy(
)
jit_compile = None

model_kwargs: Dict[str, Any] = {
"torch_device": torch_device,
"fit_out_of_design": fit_out_of_design,
}

# Create `generation_strategy`, adding first Sobol step
# if `num_remaining_initialization_trials` is > 0.
if num_remaining_initialization_trials > 0:
Expand All @@ -527,7 +536,7 @@ def choose_generation_strategy(
derelativize_with_raw_status_quo=derelativize_with_raw_status_quo,
no_winsorization=no_winsorization,
max_parallelism=bo_parallelism,
model_kwargs={"torch_device": torch_device},
model_kwargs=model_kwargs,
should_deduplicate=should_deduplicate,
verbose=verbose,
disable_progbar=disable_progbar,
Expand Down
17 changes: 15 additions & 2 deletions ax/modelbridge/tests/test_dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device": None,
"transforms": expected_transforms,
"transform_configs": expected_transform_configs,
"fit_out_of_design": False,
}
self.assertEqual(sobol_gpei._steps[1].model_kwargs, expected_model_kwargs)
device = torch.device("cpu")
Expand Down Expand Up @@ -121,7 +122,12 @@ def test_choose_generation_strategy(self) -> None:
model_kwargs = not_none(sobol_gpei._steps[1].model_kwargs)
self.assertEqual(
set(model_kwargs.keys()),
{"torch_device", "transforms", "transform_configs"},
{
"torch_device",
"transforms",
"transform_configs",
"fit_out_of_design",
},
)
self.assertGreater(len(model_kwargs["transforms"]), 0)
with self.subTest("Sobol (we can try every option)"):
Expand Down Expand Up @@ -198,6 +204,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device": None,
"transforms": [Winsorize] + Mixed_transforms + Y_trans,
"transform_configs": expected_transform_configs,
"fit_out_of_design": False,
}
self.assertEqual(bo_mixed._steps[1].model_kwargs, expected_model_kwargs)
with self.subTest("BO_MIXED (mixed search space)"):
Expand All @@ -212,6 +219,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device": None,
"transforms": [Winsorize] + Mixed_transforms + Y_trans,
"transform_configs": expected_transform_configs,
"fit_out_of_design": False,
}
self.assertEqual(bo_mixed._steps[1].model_kwargs, expected_model_kwargs)
with self.subTest("BO_MIXED (mixed multi-objective optimization)"):
Expand All @@ -229,7 +237,12 @@ def test_choose_generation_strategy(self) -> None:
model_kwargs = not_none(moo_mixed._steps[1].model_kwargs)
self.assertEqual(
set(model_kwargs.keys()),
{"torch_device", "transforms", "transform_configs"},
{
"torch_device",
"transforms",
"transform_configs",
"fit_out_of_design",
},
)
self.assertGreater(len(model_kwargs["transforms"]), 0)
with self.subTest("SAASBO"):
Expand Down

0 comments on commit 4216138

Please sign in to comment.