-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Remove parameter constraints that can be trivially converted into an updated lower/upper bound Reviewed By: SebastianAment Differential Revision: D55718753
- Loading branch information
1 parent
05eb25f
commit 08be821
Showing
3 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
69 changes: 69 additions & 0 deletions
69
ax/modelbridge/transforms/simplify_parameter_constraints.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
import math | ||
from typing import List, TYPE_CHECKING | ||
|
||
from ax.core.parameter import FixedParameter, ParameterType, RangeParameter | ||
from ax.core.parameter_constraint import ParameterConstraint | ||
from ax.core.search_space import SearchSpace | ||
from ax.modelbridge.transforms.base import Transform | ||
from ax.utils.common.typeutils import checked_cast | ||
|
||
if TYPE_CHECKING: | ||
# import as module to make sphinx-autodoc-typehints happy | ||
from ax import modelbridge as modelbridge_module # noqa F401 | ||
|
||
|
||
class SimplifyParameterConstraints(Transform): | ||
"""Convert parameter constraints on one parameter to an updated bound. | ||
This transform converts parameter constraints on only one parameter into an updated | ||
upper or lower bound. Note that this transform will convert parameters that can only | ||
take on one value into a `FixedParameter`. Make sure this transform is applied | ||
before `RemoveFixed` if you want to remove all fixed parameters. | ||
""" | ||
|
||
def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: | ||
# keeps track of the constraints that cannot be converted to bounds | ||
nontrivial_constraints: List[ParameterConstraint] = [] | ||
for pc in search_space.parameter_constraints: | ||
if len(pc.constraint_dict) == 1: | ||
# This can be turned into an updated bound since only one variable is | ||
# involved in the constraint. | ||
[(p_name, weight)] = pc.constraint_dict.items() | ||
# NOTE: We only allow parameter constraints on range parameters | ||
p = checked_cast(RangeParameter, search_space.parameters[p_name]) | ||
lb, ub = p.lower, p.upper | ||
if weight == 0 and pc.bound < 0: # Cannot be satisfied | ||
raise ValueError( | ||
"Parameter constraint cannot be satisfied since the weight " | ||
"is zero and the bound is negative." | ||
) | ||
elif weight == 0: # Constraint is always satisfied | ||
continue | ||
elif weight > 0: # New upper bound | ||
ub = float(pc.bound) / float(weight) | ||
if p.parameter_type == ParameterType.INT: | ||
ub = math.floor(ub) # Round down | ||
else: # New lower bound | ||
lb = float(pc.bound) / float(weight) | ||
if p.parameter_type == ParameterType.INT: | ||
lb = math.ceil(lb) # Round up | ||
|
||
if lb == ub: # Need to turn this into a fixed parameter | ||
search_space.parameters[p_name] = FixedParameter( | ||
name=p_name, parameter_type=p.parameter_type, value=lb | ||
) | ||
elif weight > 0: | ||
p._upper = ub | ||
else: | ||
p._lower = lb | ||
else: | ||
nontrivial_constraints.append(pc) | ||
search_space.set_parameter_constraints(nontrivial_constraints) | ||
return search_space |
122 changes: 122 additions & 0 deletions
122
ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
from copy import deepcopy | ||
from typing import List | ||
|
||
from ax.core.observation import ObservationFeatures | ||
from ax.core.parameter import ( | ||
ChoiceParameter, | ||
FixedParameter, | ||
Parameter, | ||
ParameterType, | ||
RangeParameter, | ||
) | ||
from ax.core.parameter_constraint import ParameterConstraint | ||
from ax.core.search_space import SearchSpace | ||
from ax.modelbridge.transforms.simplify_parameter_constraints import ( | ||
SimplifyParameterConstraints, | ||
) | ||
from ax.utils.common.testutils import TestCase | ||
|
||
|
||
class SimplifyParameterConstraintsTest(TestCase): | ||
def setUp(self) -> None: | ||
self.parameters: List[Parameter] = [ | ||
RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.FLOAT), | ||
RangeParameter("y", lower=2, upper=5, parameter_type=ParameterType.INT), | ||
ChoiceParameter( | ||
"z", parameter_type=ParameterType.STRING, values=["a", "b", "c"] | ||
), | ||
] | ||
self.observation_features = [ | ||
ObservationFeatures(parameters={"x": 2, "y": 2, "z": "b"}) | ||
] | ||
|
||
def test_transform_no_constraints(self) -> None: | ||
t = SimplifyParameterConstraints() | ||
ss = SearchSpace(parameters=self.parameters) | ||
ss_transformed = t.transform_search_space(search_space=ss) | ||
self.assertEqual(ss, ss_transformed) | ||
self.assertEqual( | ||
self.observation_features, | ||
t.transform_observation_features(self.observation_features), | ||
) | ||
|
||
def test_transform_weight_zero(self) -> None: | ||
t = SimplifyParameterConstraints() | ||
ss = SearchSpace( | ||
parameters=self.parameters, | ||
parameter_constraints=[ | ||
ParameterConstraint(constraint_dict={"x": 0}, bound=1) | ||
], | ||
) | ||
ss_transformed = t.transform_search_space(search_space=deepcopy(ss)) | ||
self.assertEqual(ss_transformed.parameter_constraints, []) | ||
self.assertEqual(ss.parameters, ss_transformed.parameters) | ||
ss_raises = SearchSpace( | ||
parameters=self.parameters, | ||
parameter_constraints=[ | ||
ParameterConstraint(constraint_dict={"x": 0}, bound=-1) | ||
], | ||
) | ||
with self.assertRaisesRegex( | ||
ValueError, "Parameter constraint cannot be satisfied since the weight" | ||
): | ||
ss_transformed = t.transform_search_space(search_space=deepcopy(ss_raises)) | ||
|
||
def test_transform_search_space(self) -> None: | ||
t = SimplifyParameterConstraints() | ||
ss = SearchSpace( | ||
parameters=self.parameters, | ||
parameter_constraints=[ | ||
ParameterConstraint(constraint_dict={"x": 1}, bound=2), # x <= 2 | ||
ParameterConstraint(constraint_dict={"y": -1}, bound=-4), # y => 4 | ||
], | ||
) | ||
ss_transformed = t.transform_search_space(search_space=deepcopy(ss)) | ||
self.assertEqual( | ||
{ | ||
**ss.parameters, | ||
"x": RangeParameter( | ||
"x", parameter_type=ParameterType.FLOAT, lower=1, upper=2 | ||
), | ||
"y": RangeParameter( | ||
"y", parameter_type=ParameterType.INT, lower=4, upper=5 | ||
), | ||
}, | ||
ss_transformed.parameters, | ||
) | ||
self.assertEqual(ss_transformed.parameter_constraints, []) | ||
self.assertEqual( # No-op | ||
self.observation_features, | ||
t.transform_observation_features(self.observation_features), | ||
) | ||
|
||
def test_transform_to_fixed(self) -> None: | ||
t = SimplifyParameterConstraints() | ||
ss = SearchSpace( | ||
parameters=self.parameters, | ||
parameter_constraints=[ | ||
ParameterConstraint(constraint_dict={"x": 1}, bound=1), # x == 1 | ||
ParameterConstraint(constraint_dict={"y": -1}, bound=-5), # y == 5 | ||
], | ||
) | ||
ss_transformed = t.transform_search_space(search_space=deepcopy(ss)) | ||
self.assertEqual( | ||
{ | ||
**ss.parameters, | ||
"x": FixedParameter("x", parameter_type=ParameterType.FLOAT, value=1), | ||
"y": FixedParameter("y", parameter_type=ParameterType.INT, value=5), | ||
}, | ||
ss_transformed.parameters, | ||
) | ||
self.assertEqual(ss_transformed.parameter_constraints, []) | ||
self.assertEqual( # No-op | ||
self.observation_features, | ||
t.transform_observation_features(self.observation_features), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters