Skip to content

Commit

Permalink
Rename OrderedChoiceEncode => OrderedChoiceToIntegerRange (#2323)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2323

This change renames the OrderedChoiceEncode transform to one which reflects its behavior- see T182722751 for the overall task.

- Adds a new "OrderedChoiceToIntegerRange" class with the logic from the original OrderedChoiceEncode
- Updates OrderedChoiceEncode to inherit from DeprecatedTransformMixin and OrderedChoiceToIntegerRange
- Updates the registry to support the new transform.

Initially, the new classes will be decoded into the deprecated classes to maintain backwards compatibility. Once the new classes are landed, call sites will be updated to use the new class instead of the old.

Reviewed By: mpolson64

Differential Revision: D55754487

fbshipit-source-id: cb880209845f1fd45b552b74c586d2692e0e05ad
  • Loading branch information
mgrange1998 authored and facebook-github-bot committed Apr 5, 2024
1 parent 360bdcb commit 60330a8
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 17 deletions.
3 changes: 2 additions & 1 deletion ax/modelbridge/modelbridge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def extract_search_space_digest(
assumptions regarding the parameters being transformed.
For ChoiceParameters:
* The choices are assumed to be numerical. ChoiceEncode and OrderedChoiceEncode
* The choices are assumed to be numerical. ChoiceEncode
and OrderedChoiceToIntegerRange
transforms handle this.
* If is_task, its index is added to task_features.
* If ordered, its index is added to ordinal_features.
Expand Down
7 changes: 5 additions & 2 deletions ax/modelbridge/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.centered_unit_x import CenteredUnitX
from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode
from ax.modelbridge.transforms.choice_encode import (
ChoiceEncode,
OrderedChoiceToIntegerRange,
)
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
Expand Down Expand Up @@ -87,7 +90,7 @@

Cont_X_trans: List[Type[Transform]] = [
RemoveFixed,
OrderedChoiceEncode,
OrderedChoiceToIntegerRange,
OneHot,
IntToFloat,
Log,
Expand Down
16 changes: 13 additions & 3 deletions ax/modelbridge/transforms/choice_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@

# pyre-strict

from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import numpy as np
from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TParamValue
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.deprecated_transform_mixin import (
DeprecatedTransformMixin,
)
from ax.modelbridge.transforms.utils import (
ClosestLookupDict,
construct_new_search_space,
Expand All @@ -39,7 +42,7 @@ class ChoiceEncode(Transform):
This transform does not transform task parameters (use TaskEncode for this).
Note that this behavior is different from that of OrderedChoiceEncode, which
Note that this behavior is different from that of OrderedChoiceToIntegerRange, which
transforms (ordered) ChoiceParameters to integer RangeParameters (rather than
ChoiceParameters).
Expand Down Expand Up @@ -120,7 +123,7 @@ def untransform_observation_features(
return observation_features


class OrderedChoiceEncode(ChoiceEncode):
class OrderedChoiceToIntegerRange(ChoiceEncode):
"""Convert ordered ChoiceParameters to integer RangeParameters.
Parameters will be transformed to an integer RangeParameters, mapped from the
Expand Down Expand Up @@ -187,6 +190,13 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
)


class OrderedChoiceEncode(DeprecatedTransformMixin, OrderedChoiceToIntegerRange):
"""Deprecated alias for OrderedChoiceToIntegerRange."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)


def transform_choice_values(p: ChoiceParameter) -> Tuple[np.ndarray, ParameterType]:
"""Transforms the choice values and returns the new parameter type.
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/transforms/task_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ax.core.parameter import ChoiceParameter, Parameter, ParameterType
from ax.core.search_space import SearchSpace
from ax.core.types import TParamValue
from ax.modelbridge.transforms.choice_encode import OrderedChoiceEncode
from ax.modelbridge.transforms.choice_encode import OrderedChoiceToIntegerRange
from ax.modelbridge.transforms.utils import construct_new_search_space
from ax.models.types import TConfig

Expand All @@ -21,7 +21,7 @@
from ax import modelbridge as modelbridge_module # noqa F401


class TaskEncode(OrderedChoiceEncode):
class TaskEncode(OrderedChoiceToIntegerRange):
"""Convert task ChoiceParameters to integer-valued ChoiceParameters.
Parameters will be transformed to an integer ChoiceParameter with
Expand Down
32 changes: 27 additions & 5 deletions ax/modelbridge/transforms/tests/test_choice_encode_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode
from ax.modelbridge.transforms.choice_encode import (
ChoiceEncode,
OrderedChoiceEncode,
OrderedChoiceToIntegerRange,
)
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.core_stubs import get_robust_search_space
Expand Down Expand Up @@ -166,7 +170,7 @@ def test_TransformSearchSpace(self) -> None:
)
]
)
t = OrderedChoiceEncode(search_space=ss3, observations=[])
t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[])
with self.assertRaises(ValueError):
t.transform_search_space(ss3)

Expand Down Expand Up @@ -208,8 +212,8 @@ def test_w_parameter_distributions(self) -> None:
self.assertEqual(rss_new.parameters.get("c").parameter_type, ParameterType.INT)


class OrderedChoiceEncodeTransformTest(ChoiceEncodeTransformTest):
t_class = OrderedChoiceEncode
class OrderedChoiceToIntegerRangeTransformTest(ChoiceEncodeTransformTest):
t_class = OrderedChoiceToIntegerRange

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -258,10 +262,28 @@ def test_TransformSearchSpace(self) -> None:
)
]
)
t = OrderedChoiceEncode(search_space=ss3, observations=[])
t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[])
with self.assertRaises(ValueError):
t.transform_search_space(ss3)

def test_deprecated_OrderedChoiceEncode(self) -> None:
# Ensure we error if we try to transform a fidelity parameter
ss3 = SearchSpace(
parameters=[
ChoiceParameter(
"b",
parameter_type=ParameterType.FLOAT,
values=[1.0, 10.0, 100.0],
is_ordered=True,
is_fidelity=True,
target_value=100.0,
)
]
)
t = OrderedChoiceToIntegerRange(search_space=ss3, observations=[])
t_deprecated = OrderedChoiceEncode(search_space=ss3, observations=[])
self.assertEqual(t.__dict__, t_deprecated.__dict__)


def normalize_values(values: Sized) -> List[float]:
values = np.array(values, dtype=float)
Expand Down
22 changes: 18 additions & 4 deletions ax/storage/transform_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

# pyre-strict

from typing import Dict, Type
from typing import Dict, List, Type

from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cap_parameter import CapParameter
from ax.modelbridge.transforms.choice_encode import ChoiceEncode, OrderedChoiceEncode
from ax.modelbridge.transforms.choice_encode import (
ChoiceEncode,
OrderedChoiceEncode,
OrderedChoiceToIntegerRange,
)
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.modelbridge.transforms.int_range_to_choice import IntRangeToChoice
Expand Down Expand Up @@ -59,7 +63,8 @@
IVW: 4,
Log: 5,
OneHot: 6,
OrderedChoiceEncode: 7,
OrderedChoiceEncode: 7, # TO BE DEPRECATED
OrderedChoiceToIntegerRange: 7,
# This transform was upstreamed into the base modelbridge.
# Old transforms serialized with this will have the OutOfDesign transform
# replaced with a no-op, the base transform.
Expand All @@ -85,7 +90,16 @@
MergeRepeatedMeasurements: 26,
}

"""
List transforms which are be deprecated.
These will be present in TRANSFORM_REGISTRY so that old call sites
can still store properly, but when loading back the new class will
be used.
"""
DEPRECATED_TRANSFORMS: List[Type[Transform]] = [
OrderedChoiceEncode # replaced by OrderedChoiceToIntegerRange
]

REVERSE_TRANSFORM_REGISTRY: Dict[int, Type[Transform]] = {
v: k for k, v in TRANSFORM_REGISTRY.items()
v: k for k, v in TRANSFORM_REGISTRY.items() if k not in DEPRECATED_TRANSFORMS
}

0 comments on commit 60330a8

Please sign in to comment.