Skip to content

Commit

Permalink
Always consider choice parameters with 2 values as ordered (#2464)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2464

See title

Reviewed By: saitcakmak

Differential Revision: D57303167

fbshipit-source-id: b47b29c4ebdcd07fadbf0512f1367dc212546754
  • Loading branch information
David Eriksson authored and facebook-github-bot committed May 16, 2024
1 parent 80974db commit 333e695
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 31 deletions.
39 changes: 33 additions & 6 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,16 +571,24 @@ def __init__(
)
values = list(dict_values)

if is_ordered is False and len(values) == 2:
is_ordered = True
warn(
f"Changing `is_ordered` to `True` for `ChoiceParameter` '{name}' since "
"there are only two possible values.",
AxParameterWarning,
stacklevel=3,
)
self._is_ordered: bool = (
is_ordered
if is_ordered is not None
else self._get_default_bool_and_warn(param_string="is_ordered")
else self._get_default_is_ordered_and_warn(num_choices=len(values))
)
# sort_values defaults to True if the parameter is not a string
self._sort_values: bool = (
sort_values
if sort_values is not None
else self._get_default_bool_and_warn(param_string="sort_values")
else self._get_default_sort_values_and_warn()
)
if self.sort_values:
values = cast(List[TParamValue], sorted([not_none(v) for v in values]))
Expand All @@ -597,14 +605,33 @@ def __init__(
# that is done in `HierarchicalSearchSpace` constructor.
self._dependents = dependents

def _get_default_bool_and_warn(self, param_string: str) -> bool:
def _get_default_is_ordered_and_warn(self, num_choices: int) -> bool:
default_bool = self._parameter_type != ParameterType.STRING or num_choices == 2
if self._parameter_type == ParameterType.STRING and num_choices > 2:
motivation = " since the parameter is a string with more than 2 choices."
elif num_choices == 2:
motivation = " since there are exactly two choices."
else:
motivation = " since the parameter is not of type string."
warn(
f'`is_ordered` is not specified for `ChoiceParameter` "{self._name}". '
f"Defaulting to `{default_bool}` {motivation}. To override this behavior "
f"(or avoid this warning), specify `is_ordered` during `ChoiceParameter` "
"construction. Note that choice parameters with exactly 2 choices are "
"always considered ordered and that the user-supplied `is_ordered` has no "
"effect in this particular case.",
AxParameterWarning,
stacklevel=3,
)
return default_bool

def _get_default_sort_values_and_warn(self) -> bool:
default_bool = self._parameter_type != ParameterType.STRING
warn(
f'`{param_string}` is not specified for `ChoiceParameter` "{self._name}". '
f'`sort_values` is not specified for `ChoiceParameter` "{self._name}". '
f"Defaulting to `{default_bool}` for parameters of `ParameterType` "
f"{self.parameter_type.name}. To override this behavior (or avoid this "
f"warning), specify `{param_string}` during `ChoiceParameter` "
"construction.",
f"warning), specify `sort_values` during `ChoiceParameter` construction.",
AxParameterWarning,
stacklevel=3,
)
Expand Down
48 changes: 45 additions & 3 deletions ax/core/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ParameterType,
RangeParameter,
)
from ax.exceptions.core import AxWarning, UserInputError
from ax.exceptions.core import AxParameterWarning, AxWarning, UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none

Expand Down Expand Up @@ -242,7 +242,7 @@ def setUp(self) -> None:
)
self.param3_repr = (
"ChoiceParameter(name='x', parameter_type=STRING, "
"values=['foo', 'bar'], is_fidelity=True, is_ordered=False, "
"values=['foo', 'bar'], is_fidelity=True, is_ordered=True, "
"sort_values=False, target_value='bar')"
)
self.param4 = ChoiceParameter(
Expand Down Expand Up @@ -485,7 +485,7 @@ def test_summary_dict(self) -> None:
"type": "Choice",
"domain": "values=['foo', 'bar']",
"parameter_type": "string",
"flags": "fidelity, unordered, unsorted",
"flags": "fidelity, ordered, unsorted",
"target_value": "bar",
},
)
Expand All @@ -509,6 +509,48 @@ def test_duplicate_values(self) -> None:
)
self.assertEqual(p.values, ["foo", "bar"])

def test_two_values_is_ordered(self) -> None:
parameter_types = (
ParameterType.INT,
ParameterType.FLOAT,
ParameterType.BOOL,
ParameterType.STRING,
)
parameter_values = ([0, 4], [0, 1.234], [False, True], ["foo", "bar"])
for parameter_type, values in zip(parameter_types, parameter_values):
p = ChoiceParameter(
name="x",
parameter_type=parameter_type,
values=values, # pyre-ignore
)
self.assertEqual(p._is_ordered, True)

# Change `is_ordered` to True and warn
with self.assertWarnsRegex(
AxParameterWarning,
"Changing `is_ordered` to `True` for `ChoiceParameter` 'x' since "
"there are only two possible values",
):
p = ChoiceParameter(
name="x",
parameter_type=parameter_type,
values=values, # pyre-ignore
is_ordered=False,
)
self.assertEqual(p._is_ordered, True)

# Set to True if `is_ordered` is not specified
with self.assertWarnsRegex(
AxParameterWarning, "since there are exactly two choices"
):
p = ChoiceParameter(
name="x",
parameter_type=parameter_type,
values=values, # pyre-ignore
sort_values=False,
)
self.assertEqual(p._is_ordered, True)


class FixedParameterTest(TestCase):
def setUp(self) -> None:
Expand Down
4 changes: 0 additions & 4 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,10 +850,6 @@ def test_hierarchical_search_space(self) -> None:
all_parameter_names = checked_cast(
HierarchicalSearchSpace, experiment.search_space
)._all_parameter_names.copy()
# One of the parameter names is modified by transforms (because it's
# one-hot encoded).
all_parameter_names.remove("model")
all_parameter_names.add("model_OH_PARAM_")
for obs in observations:
for p_name in all_parameter_names:
self.assertIn(p_name, obs.features.parameters)
Expand Down
24 changes: 8 additions & 16 deletions ax/modelbridge/transforms/tests/test_one_hot_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def setUp(self) -> None:
"c",
parameter_type=ParameterType.BOOL,
values=[True, False],
is_ordered=False,
),
ChoiceParameter(
"d",
Expand Down Expand Up @@ -66,8 +65,7 @@ def setUp(self) -> None:
"b" + OH_PARAM_INFIX + "_0": 0,
"b" + OH_PARAM_INFIX + "_1": 1,
"b" + OH_PARAM_INFIX + "_2": 0,
# Only two choices => one parameter.
"c" + OH_PARAM_INFIX: 0,
"c": False,
"d": 10.0,
}
)
Expand All @@ -76,8 +74,8 @@ def setUp(self) -> None:
)

def test_Init(self) -> None:
self.assertEqual(list(self.t.encoded_parameters.keys()), ["b", "c"])
self.assertEqual(list(self.t2.encoded_parameters.keys()), ["b", "c"])
self.assertEqual(list(self.t.encoded_parameters.keys()), ["b"])
self.assertEqual(list(self.t2.encoded_parameters.keys()), ["b"])

def test_TransformObservationFeatures(self) -> None:
observation_features = [self.observation_features]
Expand Down Expand Up @@ -126,18 +124,15 @@ def test_TransformSearchSpace(self) -> None:
ss2.parameters["b" + OH_PARAM_INFIX + "_1"].parameter_type,
ParameterType.FLOAT,
)
self.assertEqual(
ss2.parameters["c" + OH_PARAM_INFIX].parameter_type, ParameterType.FLOAT
)
self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.BOOL)
self.assertEqual(ss2.parameters["d"].parameter_type, ParameterType.FLOAT)

# Parameter range fixed to [0,1].
# pyre-fixme[16]: `Parameter` has no attribute `lower`.
self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "_0"].lower, 0.0)
# pyre-fixme[16]: `Parameter` has no attribute `upper`.
self.assertEqual(ss2.parameters["b" + OH_PARAM_INFIX + "_1"].upper, 1.0)
self.assertEqual(ss2.parameters["c" + OH_PARAM_INFIX].lower, 0.0)
self.assertEqual(ss2.parameters["c" + OH_PARAM_INFIX].upper, 1.0)
self.assertEqual(ss2.parameters["c"].parameter_type, ParameterType.BOOL)

# Ensure we error if we try to transform a fidelity parameter
ss3 = SearchSpace(
Expand All @@ -158,14 +153,11 @@ def test_TransformSearchSpace(self) -> None:
def test_w_parameter_distributions(self) -> None:
rss = get_robust_search_space()
# Transform a non-distributional parameter.
t = OneHot(
search_space=rss,
observations=[],
)
t = OneHot(search_space=rss, observations=[])
rss_new = t.transform_search_space(rss)
# Make sure that the return value is still a RobustSearchSpace.
self.assertIsInstance(rss_new, RobustSearchSpace)
self.assertEqual(len(rss_new.parameters.keys()), 4)
self.assertEqual(len(rss_new.parameters.keys()), 6)
# pyre-fixme[16]: `SearchSpace` has no attribute `parameter_distributions`.
self.assertEqual(rss.parameter_distributions, rss_new.parameter_distributions)
self.assertNotIn("c", rss_new.parameters)
Expand All @@ -183,7 +175,7 @@ def test_w_parameter_distributions(self) -> None:
)
rss_new = t.transform_search_space(rss)
self.assertIsInstance(rss_new, RobustSearchSpace)
self.assertEqual(len(rss_new.parameters.keys()), 4)
self.assertEqual(len(rss_new.parameters.keys()), 6)
self.assertEqual(rss.parameter_distributions, rss_new.parameter_distributions)
# pyre-fixme[16]: `SearchSpace` has no attribute `_environmental_variables`.
self.assertEqual(rss._environmental_variables, rss_new._environmental_variables)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_TransformSearchSpace(self) -> None:
self.assertEqual(p.parameter_type, ParameterType.STRING)
self.assertEqual(set(p.values), {"u1", "u2"})
self.assertTrue(p.is_task)
self.assertFalse(p.is_ordered)
self.assertTrue(p.is_ordered) # 2 choices so always ordered
self.assertEqual(p.target_value, "u1")
t = TrialAsTask(
search_space=self.search_space,
Expand Down
2 changes: 1 addition & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def get_robust_search_space(
RangeParameter("x", ParameterType.FLOAT, lb, ub),
RangeParameter("y", ParameterType.FLOAT, lb, ub),
RangeParameter("z", ParameterType.INT, lb, ub),
ChoiceParameter("c", ParameterType.STRING, ["red", "panda"]),
ChoiceParameter("c", ParameterType.STRING, ["red", "blue", "green"]),
]
if multivariate:
if use_discrete:
Expand Down

0 comments on commit 333e695

Please sign in to comment.