From a1f8be66f0ef37f0deacfb2caea21b6b10c2f0f2 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Fri, 5 Apr 2024 10:57:36 -0700 Subject: [PATCH] Adding DeprecatedTransformMixin class Summary: This task starts the backlog task "Rename Ax transforms in terms of what they transform from and to, when it isn't clear" It has a list of transforms to have their names updated to clearer values ``` OrderedChoiceEncode -> OrderedChoiceToIntegerRange ChoiceEncode -> ChoiceToNumericChoice TaskEncode -> TaskChoiceToIntTaskChoice Cast -> Map ``` This change - Adds a "DeprecatedTransformMixin", which classes can inherit from in order to print a logging message with the deprecated transform and the new transform to update to. Subsequent changes will add the new transform classes, and update the transform registry to point to the new classes instead of the old. ## Warning The warning is as follows: ``` [WARNING 04-04 09:58:45] ax.modelbridge.transforms.deprecated_transform_mixin: `DeprecatedTransform` transform has been deprecated and will be removed in a future release. Using `Transform` instead. ``` Reviewed By: mpolson64 Differential Revision: D55643016 fbshipit-source-id: dd05ffc1a350799ee46563023df7dfed316ccdc0 --- .../transforms/deprecated_transform_mixin.py | 53 +++++++++++++++ .../tests/test_deprecated_transform_mixin.py | 66 +++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 ax/modelbridge/transforms/deprecated_transform_mixin.py create mode 100644 ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py diff --git a/ax/modelbridge/transforms/deprecated_transform_mixin.py b/ax/modelbridge/transforms/deprecated_transform_mixin.py new file mode 100644 index 00000000000..ab49456cb60 --- /dev/null +++ b/ax/modelbridge/transforms/deprecated_transform_mixin.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from logging import Logger +from typing import Any + +from ax.utils.common.logger import get_logger + +logger: Logger = get_logger(__name__) + + +class DeprecatedTransformMixin: + """ + Mixin class for deprecated transforms. + + This class is used to log warnings when a deprecated transform is used, + and will construct the new transform that should be used instead. + + The deprecated transform should inherit as follows: + + class DeprecatedTransform(DeprecatedTransformMixin, NewTransform): + ... + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Log a warning that the transform is deprecated, and construct the + new transform. + """ + warning_msg = self.warn_deprecated_message( + self.__class__.__name__, type(self).__bases__[1].__name__ + ) + logger.warning(warning_msg) + + super().__init__(*args, **kwargs) + + @staticmethod + def warn_deprecated_message( + deprecated_transform_name: str, new_transform_name: str + ) -> str: + """ + Constructs the warning message. + """ + return ( + f"`{deprecated_transform_name}` transform has been deprecated " + "and will be removed in a future release. " + f"Using `{new_transform_name}` instead." + ) diff --git a/ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py b/ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py new file mode 100644 index 00000000000..3bf08b96eea --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging + +from typing import Any +from unittest.mock import MagicMock + +from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transforms.deprecated_transform_mixin import ( + DeprecatedTransformMixin, +) +from ax.utils.common.testutils import TestCase + + +class DeprecatedTransformTest(TestCase): + class DeprecatedTransform(DeprecatedTransformMixin, Transform): + def __init__(self, *args: Any) -> None: + super().__init__(*args) + + class DummyTransform(Transform): + def __init__(self, *args: Any) -> None: + super().__init__(*args) + + class DeprecatedDummyTransform(DeprecatedTransformMixin, DummyTransform): + def __init__(self, *args: Any) -> None: + super().__init__(*args) + + def setUp(self) -> None: + self.deprecated_t = self.DeprecatedTransform(MagicMock(), MagicMock()) + self.t = Transform(MagicMock(), MagicMock()) + + def test_isinstance(self) -> None: + self.assertTrue(isinstance(self.deprecated_t, type(self.t))) + self.assertTrue(isinstance(self.deprecated_t, Transform)) + self.assertTrue(isinstance(self.deprecated_t, self.DeprecatedTransform)) + self.assertTrue(isinstance(self.deprecated_t, DeprecatedTransformMixin)) + + def test_deprecated_transform_equality(self) -> None: + class DeprecatedTransform(DeprecatedTransformMixin, Transform): + def __init__(self, *args): + super().__init__(*args) + + t = Transform(MagicMock(), MagicMock()) + t2 = Transform(MagicMock(), MagicMock()) + self.assertEqual(t.__dict__, t2.__dict__) + + dt = DeprecatedTransform(MagicMock(), MagicMock()) + self.assertEqual(t.__dict__, dt.__dict__) + + def test_logging(self) -> None: + with self.assertLogs( + "ax.modelbridge.transforms.deprecated_transform_mixin", + level=logging.WARNING, + ) as logger: + _ = self.DeprecatedTransform(MagicMock(), MagicMock()) + message = DeprecatedTransformMixin.warn_deprecated_message( + self.DeprecatedTransform.__name__, + Transform.__name__, + ) + self.assertTrue(any(message in s for s in logger.output))