Skip to content

Commit

Permalink
Utilize existing model from GS in BestPointMixin._get_hypervolume (#3285
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: #3285

This was previously re-constructing the model using `get_model_from_generator_run`, which is a helper that I want to deprecate. Since the GS is readily available, we can utilize the model from the GS rather than re-constructing & fitting it from scratch.

Reviewed By: esantorella

Differential Revision: D68836172

fbshipit-source-id: af8300f6a80f6802977d6bdb482eab2dc982421e
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 29, 2025
1 parent 9be04d6 commit d059120
Showing 1 changed file with 6 additions and 26 deletions.
32 changes: 6 additions & 26 deletions ax/service/utils/best_point_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
predicted_hypervolume,
validate_and_apply_final_transform,
)
from ax.modelbridge.registry import get_model_from_generator_run, ModelRegistryBase
from ax.modelbridge.registry import ModelRegistryBase
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.models.torch.botorch_moo_defaults import (
Expand Down Expand Up @@ -387,35 +387,15 @@ def _get_hypervolume(
)

if use_model_predictions:
current_model = generation_strategy._curr.model_spec_to_gen_from.model_enum
# Cover for the case where source of `self._curr.model` was not a `Models`
# enum but a factory function, in which case we cannot do
# `get_model_from_generator_run` (since we don't have model type and inputs
# recorded on the generator run.
models_enum = (
current_model.__class__
if isinstance(current_model, ModelRegistryBase)
else None
)

if models_enum is None:
raise ValueError(
f"Model {current_model} is not in the ModelRegistry, cannot "
"calculate predicted hypervolume."
)

model = get_model_from_generator_run(
generator_run=none_throws(generation_strategy.last_generator_run),
experiment=experiment,
data=experiment.fetch_data(trial_indices=trial_indices),
models_enum=models_enum,
)
# Make sure that the model is fitted. If model is fitted already,
# this should be a no-op.
generation_strategy._fit_current_model(data=None)
model = generation_strategy.model
if not isinstance(model, TorchModelBridge):
raise ValueError(
f"Model {current_model} is not of type TorchModelBridge, cannot "
f"Model {model} is not of type TorchModelBridge, cannot "
"calculate predicted hypervolume."
)

return predicted_hypervolume(
modelbridge=model, optimization_config=optimization_config
)
Expand Down

0 comments on commit d059120

Please sign in to comment.