Skip to content

Commit

Permalink
Deprecate use of factory functions in generation strategies
Browse files Browse the repository at this point in the history
Differential Revision: D54696356
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Mar 8, 2024
1 parent 6bb6935 commit a38c64b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 60 deletions.
72 changes: 30 additions & 42 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ class GenerationStrategy(GenerationStrategyInterface):

_nodes: List[GenerationNode]
_curr: GenerationNode # Current node in the strategy.
# Whether all models in this GS are in Models registry enum.
_uses_registered_models: bool
# All generator runs created through this generation strategy, in chronological
# order.
_generator_runs: List[GeneratorRun]
Expand Down Expand Up @@ -142,21 +140,21 @@ def __init__(
+ "GenerationNode list."
)

# Log warning if the GS uses a non-registered (factory function) model.
self._uses_registered_models = not any(
isinstance(ms, FactoryFunctionModelSpec)
for node in self._nodes
for ms in node.model_specs
)
if not self._uses_registered_models:
logger.info(
"Using model via callable function, "
"so optimization is not resumable if interrupted."
)
# Log warning if the GS uses a non-registered (factory function) model
for node in self._nodes:
for model_spec in node.model_specs:
if isinstance(model_spec, FactoryFunctionModelSpec):
raise GenerationStrategyMisconfiguredException(
"Use of model factory functions is no longer supported in "
"Ax `GenerationStrategy`. "
f"Encountered: {model_spec.factory_function} in node {node}."
)

# Initialize required attributes
if name is None:
name = "+".join(n.model_spec_to_gen_from.model_key for n in self._nodes)
super().__init__(name=name)
self._generator_runs = []
# Set name to an explicit value ahead of time to avoid
# adding properties during equality checks
super().__init__(name=name or self._make_default_name())

@property
def is_node_based(self) -> bool:
Expand Down Expand Up @@ -185,20 +183,17 @@ def model_transitions(self) -> List[int]:
another."""
# TODO @mgarrard to support GenerationNodes here, which is non-trival
# since nodes are dynamic and may only support past model_transitions
gen_changes = []
for node in self._nodes:
for criterion in node.transition_criteria:
if (
isinstance(criterion, TrialBasedCriterion)
and criterion.criterion_class == "MaxTrials"
):
gen_changes.append(criterion.threshold)

# if the last node has unlimited generation, do not remeove the last
# transition point in the list
if self._nodes[-1].gen_unlimited_trials:
return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))]
return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))][:-1]
# @no-commit NOTE from @drfreund: I think we can remove this method
# altogether. It's used exclusively to draw a vertical line in plots
# to show at what time the model changed. The same can be accomplished
# by just checking model keys on generator runs, so I think this is
# only added complexity.
# Query for where this is used:
# https://www.internalfb.com/code/search?q=repo%3Afbcode%20model_transitions
raise DeprecationWarning(
"`model_transitions` is no longer supported. Please refer to `model_key` "
"field on generator runs for similar information if needed."
)

@property
def current_step(self) -> GenerationStep:
Expand Down Expand Up @@ -250,7 +245,7 @@ def model(self) -> Optional[ModelBridge]:
def experiment(self) -> Experiment:
"""Experiment, currently set on this generation strategy."""
if self._experiment is None:
raise ValueError("No experiment set on generation strategy.")
raise AxError("No experiment set on generation strategy.")
return not_none(self._experiment)

@experiment.setter
Expand All @@ -263,7 +258,7 @@ def experiment(self, experiment: Experiment) -> None:
if self._experiment is None or experiment._name == self.experiment._name:
self._experiment = experiment
else:
raise ValueError(
raise AxError(
"This generation strategy has been used for experiment "
f"{self.experiment._name} so far; cannot reset experiment"
f" to {experiment._name}. If this is a new optimization, "
Expand All @@ -278,12 +273,6 @@ def last_generator_run(self) -> Optional[GeneratorRun]:
# Used to restore current model when decoding a serialized GS.
return self._generator_runs[-1] if self._generator_runs else None

@property
def uses_non_registered_models(self) -> bool:
"""Whether this generation strategy involves models that are not
registered and therefore cannot be stored."""
return not self._uses_registered_models

@property
def trials_as_df(self) -> Optional[pd.DataFrame]:
"""Puts information on individual trials into a data frame for easy
Expand All @@ -293,6 +282,8 @@ def trials_as_df(self) -> Optional[pd.DataFrame]:
Gen. Step | Model | Trial Index | Trial Status | Arm Parameterizations
0 | Sobol | 0 | RUNNING | {"0_0":{"x":9.17...}}
"""

# TODO: reap this method or replace it to return `exp_to_df`?
logger.info(
"Note that parameter values in dataframe are rounded to 2 decimal "
"points; the values in the dataframe are thus not the exact ones "
Expand Down Expand Up @@ -617,10 +608,7 @@ def _make_default_name(self) -> str:
"Cannot make a default name for a generation strategy with no nodes "
"set yet."
)
factory_names = (node.model_spec_to_gen_from.model_key for node in self._nodes)
# Trim the "get_" beginning of the factory function if it's there.
factory_names = (n[4:] if n[:4] == "get_" else n for n in factory_names)
return "+".join(factory_names)
return "+".join(node.model_spec_to_gen_from.model_key for node in self._nodes)

def __repr__(self) -> str:
"""String representation of this generation strategy."""
Expand Down
15 changes: 2 additions & 13 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ def test_validation(self) -> None:
GenerationStep(model=Models.THOMPSON, num_trials=2),
]
)
self.assertTrue(factorial_thompson_generation_strategy._uses_registered_models)
self.assertFalse(
factorial_thompson_generation_strategy.uses_non_registered_models
)
with self.assertRaises(ValueError):
factorial_thompson_generation_strategy.gen(exp)
self.assertEqual(GenerationStep(model=sum, num_trials=1).model_name, "sum")
Expand All @@ -255,14 +251,8 @@ def test_validation(self) -> None:
)

def test_custom_callables_for_models(self) -> None:
exp = get_branin_experiment()
sobol_factory_generation_strategy = GenerationStrategy(
steps=[GenerationStep(model=get_sobol, num_trials=-1)]
)
self.assertFalse(sobol_factory_generation_strategy._uses_registered_models)
self.assertTrue(sobol_factory_generation_strategy.uses_non_registered_models)
gr = sobol_factory_generation_strategy.gen(experiment=exp, n=1)
self.assertEqual(len(gr.arms), 1)
with self.assertRaises(GenerationStrategyMisconfiguredException):
GenerationStrategy(steps=[GenerationStep(model=get_sobol, num_trials=-1)])

def test_string_representation(self) -> None:
gs1 = GenerationStrategy(
Expand Down Expand Up @@ -340,7 +330,6 @@ def test_min_observed(self) -> None:
GenerationStep(model=Models.GPEI, num_trials=1),
]
)
self.assertFalse(gs.uses_non_registered_models)
for _ in range(5):
exp.new_trial(gs.gen(exp))
with self.assertRaises(DataRequiredError):
Expand Down
5 changes: 0 additions & 5 deletions ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,11 +502,6 @@ def generation_strategy_to_dict(
generation_strategy: GenerationStrategy,
) -> Dict[str, Any]:
"""Converts Ax generation strategy to a dictionary."""
if generation_strategy.uses_non_registered_models:
raise ValueError(
"Generation strategies that use custom models provided through "
"callables cannot be serialized and stored."
)
node_based_gs = generation_strategy.is_node_based
return {
"__type": generation_strategy.__class__.__name__,
Expand Down

0 comments on commit a38c64b

Please sign in to comment.