diff --git a/ax/modelbridge/transforms/deprecated_transform_mixin.py b/ax/modelbridge/transforms/deprecated_transform_mixin.py new file mode 100644 index 00000000000..ab49456cb60 --- /dev/null +++ b/ax/modelbridge/transforms/deprecated_transform_mixin.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from logging import Logger +from typing import Any + +from ax.utils.common.logger import get_logger + +logger: Logger = get_logger(__name__) + + +class DeprecatedTransformMixin: + """ + Mixin class for deprecated transforms. + + This class is used to log warnings when a deprecated transform is used, + and will construct the new transform that should be used instead. + + The deprecated transform should inherit as follows: + + class DeprecatedTransform(DeprecatedTransformMixin, NewTransform): + ... + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """ + Log a warning that the transform is deprecated, and construct the + new transform. + """ + warning_msg = self.warn_deprecated_message( + self.__class__.__name__, type(self).__bases__[1].__name__ + ) + logger.warning(warning_msg) + + super().__init__(*args, **kwargs) + + @staticmethod + def warn_deprecated_message( + deprecated_transform_name: str, new_transform_name: str + ) -> str: + """ + Constructs the warning message. + """ + return ( + f"`{deprecated_transform_name}` transform has been deprecated " + "and will be removed in a future release. " + f"Using `{new_transform_name}` instead." + ) diff --git a/ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py b/ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py new file mode 100644 index 00000000000..3bf08b96eea --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging + +from typing import Any +from unittest.mock import MagicMock + +from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transforms.deprecated_transform_mixin import ( + DeprecatedTransformMixin, +) +from ax.utils.common.testutils import TestCase + + +class DeprecatedTransformTest(TestCase): + class DeprecatedTransform(DeprecatedTransformMixin, Transform): + def __init__(self, *args: Any) -> None: + super().__init__(*args) + + class DummyTransform(Transform): + def __init__(self, *args: Any) -> None: + super().__init__(*args) + + class DeprecatedDummyTransform(DeprecatedTransformMixin, DummyTransform): + def __init__(self, *args: Any) -> None: + super().__init__(*args) + + def setUp(self) -> None: + self.deprecated_t = self.DeprecatedTransform(MagicMock(), MagicMock()) + self.t = Transform(MagicMock(), MagicMock()) + + def test_isinstance(self) -> None: + self.assertTrue(isinstance(self.deprecated_t, type(self.t))) + self.assertTrue(isinstance(self.deprecated_t, Transform)) + self.assertTrue(isinstance(self.deprecated_t, self.DeprecatedTransform)) + self.assertTrue(isinstance(self.deprecated_t, DeprecatedTransformMixin)) + + def test_deprecated_transform_equality(self) -> None: + class DeprecatedTransform(DeprecatedTransformMixin, Transform): + def __init__(self, *args): + super().__init__(*args) + + t = Transform(MagicMock(), MagicMock()) + t2 = Transform(MagicMock(), MagicMock()) + self.assertEqual(t.__dict__, t2.__dict__) + + dt = DeprecatedTransform(MagicMock(), MagicMock()) + self.assertEqual(t.__dict__, dt.__dict__) + + def test_logging(self) -> None: + with self.assertLogs( + "ax.modelbridge.transforms.deprecated_transform_mixin", + level=logging.WARNING, + ) as logger: + _ = self.DeprecatedTransform(MagicMock(), MagicMock()) + message = DeprecatedTransformMixin.warn_deprecated_message( + self.DeprecatedTransform.__name__, + Transform.__name__, + ) + self.assertTrue(any(message in s for s in logger.output))