Skip to content

Commit

Permalink
Adding DeprecatedTransformMixin class
Browse files Browse the repository at this point in the history
Summary:
This task starts the backlog task "Rename Ax transforms in terms of what they transform from and to, when it isn't clear"

It has a list of transforms to have their names updated to clearer values
```
OrderedChoiceEncode -> OrderedChoiceToIntegerRange
ChoiceEncode -> ChoiceToNumericChoice
TaskEncode -> TaskChoiceToIntTaskChoice
Cast -> Map
```

This change
- Adds a "DeprecatedTransformMixin", which classes can inherit from in order to print a logging message with the deprecated transform and the new transform to update to.

Subsequent changes will add the new transform classes, and update the transform registry to point to the new classes instead of the old.

## Warning

The warning is as follows:
```
[WARNING 04-04 09:58:45] ax.modelbridge.transforms.deprecated_transform_mixin:
`DeprecatedTransform` transform has been deprecated
and will be removed in a future release. Using `Transform` instead.
```

Reviewed By: mpolson64

Differential Revision: D55643016

fbshipit-source-id: dd05ffc1a350799ee46563023df7dfed316ccdc0
  • Loading branch information
mgrange1998 authored and facebook-github-bot committed Apr 5, 2024
1 parent 9026a58 commit a1f8be6
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
53 changes: 53 additions & 0 deletions ax/modelbridge/transforms/deprecated_transform_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from logging import Logger
from typing import Any

from ax.utils.common.logger import get_logger

logger: Logger = get_logger(__name__)


class DeprecatedTransformMixin:
"""
Mixin class for deprecated transforms.
This class is used to log warnings when a deprecated transform is used,
and will construct the new transform that should be used instead.
The deprecated transform should inherit as follows:
class DeprecatedTransform(DeprecatedTransformMixin, NewTransform):
...
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Log a warning that the transform is deprecated, and construct the
new transform.
"""
warning_msg = self.warn_deprecated_message(
self.__class__.__name__, type(self).__bases__[1].__name__
)
logger.warning(warning_msg)

super().__init__(*args, **kwargs)

@staticmethod
def warn_deprecated_message(
deprecated_transform_name: str, new_transform_name: str
) -> str:
"""
Constructs the warning message.
"""
return (
f"`{deprecated_transform_name}` transform has been deprecated "
"and will be removed in a future release. "
f"Using `{new_transform_name}` instead."
)
66 changes: 66 additions & 0 deletions ax/modelbridge/transforms/tests/test_deprecated_transform_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging

from typing import Any
from unittest.mock import MagicMock

from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.deprecated_transform_mixin import (
DeprecatedTransformMixin,
)
from ax.utils.common.testutils import TestCase


class DeprecatedTransformTest(TestCase):
class DeprecatedTransform(DeprecatedTransformMixin, Transform):
def __init__(self, *args: Any) -> None:
super().__init__(*args)

class DummyTransform(Transform):
def __init__(self, *args: Any) -> None:
super().__init__(*args)

class DeprecatedDummyTransform(DeprecatedTransformMixin, DummyTransform):
def __init__(self, *args: Any) -> None:
super().__init__(*args)

def setUp(self) -> None:
self.deprecated_t = self.DeprecatedTransform(MagicMock(), MagicMock())
self.t = Transform(MagicMock(), MagicMock())

def test_isinstance(self) -> None:
self.assertTrue(isinstance(self.deprecated_t, type(self.t)))
self.assertTrue(isinstance(self.deprecated_t, Transform))
self.assertTrue(isinstance(self.deprecated_t, self.DeprecatedTransform))
self.assertTrue(isinstance(self.deprecated_t, DeprecatedTransformMixin))

def test_deprecated_transform_equality(self) -> None:
class DeprecatedTransform(DeprecatedTransformMixin, Transform):
def __init__(self, *args):
super().__init__(*args)

t = Transform(MagicMock(), MagicMock())
t2 = Transform(MagicMock(), MagicMock())
self.assertEqual(t.__dict__, t2.__dict__)

dt = DeprecatedTransform(MagicMock(), MagicMock())
self.assertEqual(t.__dict__, dt.__dict__)

def test_logging(self) -> None:
with self.assertLogs(
"ax.modelbridge.transforms.deprecated_transform_mixin",
level=logging.WARNING,
) as logger:
_ = self.DeprecatedTransform(MagicMock(), MagicMock())
message = DeprecatedTransformMixin.warn_deprecated_message(
self.DeprecatedTransform.__name__,
Transform.__name__,
)
self.assertTrue(any(message in s for s in logger.output))

0 comments on commit a1f8be6

Please sign in to comment.