diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index d3739ac152d..1377d0b9e52 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -534,8 +534,10 @@ def _check_existing_and_name_arm(self, arm: Arm) -> None: experiment, uses the existing arm name. """ proposed_name = self._get_default_name() + + # Arm could already be in experiment, replacement is okay. self.experiment._name_and_store_arm_if_not_exists( - arm=arm, proposed_name=proposed_name + arm=arm, proposed_name=proposed_name, replace=True ) # If arm was named using given name, incremement the count if arm.name == proposed_name: diff --git a/ax/core/batch_trial.py b/ax/core/batch_trial.py index ab37b7d800a..6475e43283b 100644 --- a/ax/core/batch_trial.py +++ b/ax/core/batch_trial.py @@ -366,7 +366,9 @@ def set_status_quo_with_weight( status_quo.parameters, raise_error=True ) self.experiment._name_and_store_arm_if_not_exists( - arm=status_quo, proposed_name="status_quo_" + str(self.index) + arm=status_quo, + proposed_name="status_quo_" + str(self.index), + replace=True, ) self._status_quo = status_quo.clone() if status_quo is not None else None self._status_quo_weight_override = weight diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 0c48c6e5c94..86bca8f8cf5 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -37,7 +37,7 @@ from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.core.trial import Trial from ax.core.types import ComparisonOp, TParameterization -from ax.exceptions.core import UnsupportedError, UserInputError +from ax.exceptions.core import AxError, UnsupportedError, UserInputError from ax.utils.common.base import Base from ax.utils.common.constants import EXPERIMENT_IS_TEST_WARNING, Keys from ax.utils.common.docutils import copy_doc @@ -1350,7 +1350,9 @@ def warm_start_from_old_experiment( return copied_trials - def _name_and_store_arm_if_not_exists(self, arm: Arm, proposed_name: str) -> None: + def _name_and_store_arm_if_not_exists( + self, arm: Arm, proposed_name: str, replace: bool = False + ) -> None: """Tries to lookup arm with same signature, otherwise names and stores it. - Looks up if arm already exists on experiment @@ -1360,6 +1362,8 @@ def _name_and_store_arm_if_not_exists(self, arm: Arm, proposed_name: str) -> Non Args: arm: The arm object to name. proposed_name: The name to assign if it doesn't have one already. + replace: If true, override arm w/ same name and different signature. + If false, raise an error if this conflict occurs. """ # If arm is identical to an existing arm, return that @@ -1377,6 +1381,22 @@ def _name_and_store_arm_if_not_exists(self, arm: Arm, proposed_name: str) -> Non else: if not arm.has_name: arm.name = proposed_name + + # Check for signature conflict by arm name/proposed name + if ( + arm.name in self.arms_by_name + and arm.signature != self.arms_by_name[arm.name].signature + ): + error_msg = ( + f"Arm with name {arm.name} already exists on experiment " + + "with different signature." + ) + if replace: + logger.warning(f"{error_msg} Replacing the existing arm. ") + else: + raise AxError(error_msg) + + # Add the new arm self._register_arm(arm) def _register_arm(self, arm: Arm) -> None: diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index cd285fe9260..d245e2d99b2 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -34,7 +34,7 @@ ) from ax.core.search_space import SearchSpace from ax.core.types import ComparisonOp -from ax.exceptions.core import UnsupportedError +from ax.exceptions.core import AxError, UnsupportedError from ax.metrics.branin import BraninMetric from ax.modelbridge.registry import Models from ax.runners.synthetic import SyntheticRunner @@ -1541,3 +1541,43 @@ class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose): ], }, ) + + def test_name_and_store_arm_if_not_exists_same_name_different_signature( + self, + ) -> None: + experiment = self.experiment + shared_name = "shared_name" + + arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name) + arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1}) + self.assertNotEqual(arm_1.signature, arm_2.signature) + + experiment._register_arm(arm=arm_1) + with self.assertRaisesRegex( + AxError, + f"Arm with name {shared_name} already exists on experiment " + f"with different signature.", + ): + experiment._name_and_store_arm_if_not_exists( + arm=arm_2, proposed_name=shared_name + ) + + def test_name_and_store_arm_if_not_exists_same_proposed_name_different_signature( + self, + ) -> None: + experiment = self.experiment + shared_name = "shared_name" + + arm_1 = Arm({"x1": -1.0, "x2": 1.0}, name=shared_name) + arm_2 = Arm({"x1": -1.7, "x2": 0.2, "x3": 1}, name=shared_name) + self.assertNotEqual(arm_1.signature, arm_2.signature) + + experiment._register_arm(arm=arm_1) + with self.assertRaisesRegex( + AxError, + f"Arm with name {shared_name} already exists on experiment " + f"with different signature.", + ): + experiment._name_and_store_arm_if_not_exists( + arm=arm_2, proposed_name="different proposed name" + )