Skip to content

Commit

Permalink
Add checks that only Range parameters can have ParameterConstraints i…
Browse files Browse the repository at this point in the history
…nstantiated (facebook#2936)

Summary:

As titled. This is a little janky but will do for now, a clearner validation scheme will be set up with via the Ax API workstream.

Differential Revision: D64784254
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Oct 23, 2024
1 parent 02c0ce5 commit 4d99a93
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 8 deletions.
135 changes: 128 additions & 7 deletions ax/service/tests/test_instantiation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,20 @@ def test_constraint_from_str(self) -> None:
"x1 + x2 <= not_numerical_bound",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=2.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=2.0,
),
},
)
with self.assertRaisesRegex(ValueError, "Outcome constraint bound"):
InstantiationBase.outcome_constraint_from_str("m1 <= not_numerical_bound")
Expand Down Expand Up @@ -92,15 +105,29 @@ def test_constraint_from_str(self) -> None:
"x1 <= 0",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None},
{
"x1": RangeParameter(
name="x1", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
"x2": RangeParameter(
name="x2", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
},
)
self.assertEqual(one_val_constraint.bound, 0.0)
self.assertEqual(one_val_constraint.constraint_dict, {"x1": 1.0})
one_val_constraint = InstantiationBase.constraint_from_str(
"-0.5*x1 >= -0.1",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None},
{
"x1": RangeParameter(
name="x1", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
"x2": RangeParameter(
name="x2", parameter_type=ParameterType.FLOAT, lower=0.1, upper=2.0
),
},
)
self.assertEqual(one_val_constraint.bound, 0.1)
self.assertEqual(one_val_constraint.constraint_dict, {"x1": 0.5})
Expand Down Expand Up @@ -128,28 +155,122 @@ def test_constraint_from_str(self) -> None:
"x1 - e*x2 + x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
InstantiationBase.constraint_from_str(
"x1 - 2 *x2 + 3 *x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
InstantiationBase.constraint_from_str(
"x1 - 2* x2 + 3* x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
InstantiationBase.constraint_from_str(
"x1 - 2 * x2 + 3*x3 <= 3",
# pyre-fixme[6]: For 2nd param expected `Dict[str, Parameter]` but
# got `Dict[str, None]`.
{"x1": None, "x2": None, "x3": None},
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x2": RangeParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
"x3": RangeParameter(
name="x3",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=4.0,
),
},
)

with self.assertRaisesRegex(
ValueError, "Parameter constraints not supported for ChoiceParameter"
):
InstantiationBase.constraint_from_str(
"x1 + x2 <= 3",
{
"x1": RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=0.1,
upper=2.0,
),
"x2": ChoiceParameter(
name="x2", parameter_type=ParameterType.FLOAT, values=[0, 1, 2]
),
},
)

def test_add_tracking_metrics(self) -> None:
Expand Down
12 changes: 11 additions & 1 deletion ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
RangeParameter,
TParameterType,
)
from ax.core.parameter_constraint import OrderConstraint, ParameterConstraint
from ax.core.parameter_constraint import (
OrderConstraint,
ParameterConstraint,
validate_constraint_parameters,
)
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
from ax.core.types import ComparisonOp, TParameterization, TParamValue
from ax.exceptions.core import UnsupportedError
Expand Down Expand Up @@ -403,6 +407,9 @@ def constraint_from_str(
assert (
right in parameter_names
), f"Parameter {right} not in {parameter_names}."
validate_constraint_parameters(
parameters=[parameters[left], parameters[right]]
)
return (
OrderConstraint(
lower_parameter=parameters[left], upper_parameter=parameters[right]
Expand Down Expand Up @@ -451,9 +458,12 @@ def constraint_from_str(
multiplier = -1.0
else:
multiplier = 1.0

assert (
parameter in parameter_names
), f"Parameter {parameter} not in {parameter_names}."
validate_constraint_parameters(parameters=[parameters[parameter]])

parameter_weight[parameter] = operator_sign * multiplier
# for operators
else:
Expand Down

0 comments on commit 4d99a93

Please sign in to comment.