-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding DeprecatedTransformMixin class
Summary: This task starts the backlog task "Rename Ax transforms in terms of what they transform from and to, when it isn't clear" It has a list of transforms to have their names updated to clearer values ``` OrderedChoiceEncode -> OrderedChoiceToIntegerRange ChoiceEncode -> ChoiceToNumericChoice TaskEncode -> TaskChoiceToIntTaskChoice Cast -> Map ``` This change - Adds a "DeprecatedTransformMixin", which classes can inherit from in order to print a logging message with the deprecated transform and the new transform to update to. Subsequent changes will add the new transform classes, and update the transform registry to point to the new classes instead of the old. ## Warning The warning is as follows: ``` [WARNING 04-04 09:58:45] ax.modelbridge.transforms.deprecated_transform_mixin: `DeprecatedTransform` transform has been deprecated and will be removed in a future release. Using `Transform` instead. ``` Reviewed By: mpolson64 Differential Revision: D55643016 fbshipit-source-id: dd05ffc1a350799ee46563023df7dfed316ccdc0
- Loading branch information
1 parent
9026a58
commit a1f8be6
Showing
2 changed files
with
119 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from logging import Logger | ||
from typing import Any | ||
|
||
from ax.utils.common.logger import get_logger | ||
|
||
logger: Logger = get_logger(__name__) | ||
|
||
|
||
class DeprecatedTransformMixin: | ||
""" | ||
Mixin class for deprecated transforms. | ||
This class is used to log warnings when a deprecated transform is used, | ||
and will construct the new transform that should be used instead. | ||
The deprecated transform should inherit as follows: | ||
class DeprecatedTransform(DeprecatedTransformMixin, NewTransform): | ||
... | ||
""" | ||
|
||
def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
""" | ||
Log a warning that the transform is deprecated, and construct the | ||
new transform. | ||
""" | ||
warning_msg = self.warn_deprecated_message( | ||
self.__class__.__name__, type(self).__bases__[1].__name__ | ||
) | ||
logger.warning(warning_msg) | ||
|
||
super().__init__(*args, **kwargs) | ||
|
||
@staticmethod | ||
def warn_deprecated_message( | ||
deprecated_transform_name: str, new_transform_name: str | ||
) -> str: | ||
""" | ||
Constructs the warning message. | ||
""" | ||
return ( | ||
f"`{deprecated_transform_name}` transform has been deprecated " | ||
"and will be removed in a future release. " | ||
f"Using `{new_transform_name}` instead." | ||
) |
66 changes: 66 additions & 0 deletions
66
ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import logging | ||
|
||
from typing import Any | ||
from unittest.mock import MagicMock | ||
|
||
from ax.modelbridge.transforms.base import Transform | ||
from ax.modelbridge.transforms.deprecated_transform_mixin import ( | ||
DeprecatedTransformMixin, | ||
) | ||
from ax.utils.common.testutils import TestCase | ||
|
||
|
||
class DeprecatedTransformTest(TestCase): | ||
class DeprecatedTransform(DeprecatedTransformMixin, Transform): | ||
def __init__(self, *args: Any) -> None: | ||
super().__init__(*args) | ||
|
||
class DummyTransform(Transform): | ||
def __init__(self, *args: Any) -> None: | ||
super().__init__(*args) | ||
|
||
class DeprecatedDummyTransform(DeprecatedTransformMixin, DummyTransform): | ||
def __init__(self, *args: Any) -> None: | ||
super().__init__(*args) | ||
|
||
def setUp(self) -> None: | ||
self.deprecated_t = self.DeprecatedTransform(MagicMock(), MagicMock()) | ||
self.t = Transform(MagicMock(), MagicMock()) | ||
|
||
def test_isinstance(self) -> None: | ||
self.assertTrue(isinstance(self.deprecated_t, type(self.t))) | ||
self.assertTrue(isinstance(self.deprecated_t, Transform)) | ||
self.assertTrue(isinstance(self.deprecated_t, self.DeprecatedTransform)) | ||
self.assertTrue(isinstance(self.deprecated_t, DeprecatedTransformMixin)) | ||
|
||
def test_deprecated_transform_equality(self) -> None: | ||
class DeprecatedTransform(DeprecatedTransformMixin, Transform): | ||
def __init__(self, *args): | ||
super().__init__(*args) | ||
|
||
t = Transform(MagicMock(), MagicMock()) | ||
t2 = Transform(MagicMock(), MagicMock()) | ||
self.assertEqual(t.__dict__, t2.__dict__) | ||
|
||
dt = DeprecatedTransform(MagicMock(), MagicMock()) | ||
self.assertEqual(t.__dict__, dt.__dict__) | ||
|
||
def test_logging(self) -> None: | ||
with self.assertLogs( | ||
"ax.modelbridge.transforms.deprecated_transform_mixin", | ||
level=logging.WARNING, | ||
) as logger: | ||
_ = self.DeprecatedTransform(MagicMock(), MagicMock()) | ||
message = DeprecatedTransformMixin.warn_deprecated_message( | ||
self.DeprecatedTransform.__name__, | ||
Transform.__name__, | ||
) | ||
self.assertTrue(any(message in s for s in logger.output)) |