diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 24dea32b..daf9fdca 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -60,6 +60,7 @@ class name: BaseEstimator from typing import List from skbase._exceptions import NotFittedError +from skbase.base._clone_base import _check_clone, _clone from skbase.base._pretty_printing._object_html_repr import _object_html_repr from skbase.base._tagmanager import _FlagManager @@ -175,7 +176,7 @@ def clone(self): ------ RuntimeError if the clone is non-conforming, due to faulty ``__init__``. """ - self_clone = _clone(self) + self_clone = _clone(self, base_cls=BaseObject) if self.get_config()["check_clone"]: _check_clone(original=self, clone=self_clone) return self_clone @@ -1653,107 +1654,3 @@ def _get_fitted_params(self): fitted parameters, keyed by names of fitted parameter """ return self._get_fitted_params_default() - - -# Adapted from sklearn's `_clone_parametrized()` -def _clone(estimator, *, safe=True): - """Construct a new unfitted estimator with the same parameters. - - Clone does a deep copy of the model in an estimator - without actually copying attached data. It returns a new estimator - with the same parameters that has not been fitted on any data. - - Parameters - ---------- - estimator : {list, tuple, set} of estimator instance or a single \ - estimator instance - The estimator or group of estimators to be cloned. - safe : bool, default=True - If safe is False, clone will fall back to a deep copy on objects - that are not estimators. - - Returns - ------- - estimator : object - The deep copy of the input, an estimator if input is an estimator. - - Notes - ----- - If the estimator's `random_state` parameter is an integer (or if the - estimator doesn't have a `random_state` parameter), an *exact clone* is - returned: the clone and the original estimator will give the exact same - results. 
Otherwise, *statistical clone* is returned: the clone might - return different results from the original estimator. More details can be - found in :ref:`randomness`. - """ - estimator_type = type(estimator) - if estimator_type is dict: - return {k: _clone(v, safe=safe) for k, v in estimator.items()} - if estimator_type in (list, tuple, set, frozenset): - return estimator_type([_clone(e, safe=safe) for e in estimator]) - elif not hasattr(estimator, "get_params") or isinstance(estimator, type): - if not safe: - return deepcopy(estimator) - else: - if isinstance(estimator, type): - raise TypeError( - "Cannot clone object. " - + "You should provide an instance of " - + "scikit-learn estimator instead of a class." - ) - else: - raise TypeError( - "Cannot clone object '%s' (type %s): " - "it does not seem to be a scikit-learn " - "estimator as it does not implement a " - "'get_params' method." % (repr(estimator), type(estimator)) - ) - - klass = estimator.__class__ - new_object_params = estimator.get_params(deep=False) - for name, param in new_object_params.items(): - new_object_params[name] = _clone(param, safe=False) - new_object = klass(**new_object_params) - params_set = new_object.get_params(deep=False) - - # quick sanity check of the parameters of the clone - for name in new_object_params: - param1 = new_object_params[name] - param2 = params_set[name] - if param1 is not param2: - raise RuntimeError( - "Cannot clone object %s, as the constructor " - "either does not set or modifies parameter %s" % (estimator, name) - ) - - # This is an extension to the original sklearn implementation - if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]: - new_object.set_config(**estimator.get_config()) - - return new_object - - -def _check_clone(original, clone): - from skbase.utils.deep_equals import deep_equals - - self_params = original.get_params(deep=False) - - # check that all attributes are written to the clone - for attrname in 
self_params.keys(): - if not hasattr(clone, attrname): - raise RuntimeError( - f"error in {original}.clone, __init__ must write all arguments " - f"to self and not mutate them, but {attrname} was not found. " - f"Please check __init__ of {original}." - ) - - clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()} - - # check equality of parameters post-clone and pre-clone - clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True) - if not clone_attrs_valid: - raise RuntimeError( - f"error in {original}.clone, __init__ must write all arguments " - f"to self and not mutate them, but this is not the case. " - f"Error on equality check of arguments (x) vs parameters (y): {msg}" - ) diff --git a/skbase/base/_clone_base.py b/skbase/base/_clone_base.py new file mode 100644 index 00000000..d96f178a --- /dev/null +++ b/skbase/base/_clone_base.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Elements of BaseObject reuse code developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Logic and plugins for cloning objects. 
+ +This module contains logic for cloning objects: + +_clone(estimator, *, safe=True, clone_plugins=None, base_cls=None) - clone entry point +_check_clone(original, clone) - validation utility to check clones + +Default plugins for _clone are stored in _clone_plugins: + +DEFAULT_CLONE_PLUGINS - list with default plugins for cloning + +Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods: + +* check(obj) -> boolean - fast checker whether plugin applies +* clone(obj) -> type(obj) - method to clone obj +""" +__all__ = ["_clone", "_check_clone"] + +from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS + + +# Adapted from sklearn's `_clone_parametrized()` +def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None): + """Construct a new unfitted estimator with the same parameters. + + Clone does a deep copy of the model in an estimator + without actually copying attached data. It returns a new estimator + with the same parameters that has not been fitted on any data. + + Parameters + ---------- + estimator : {list, tuple, set} of estimator instance or a single estimator instance + The estimator or group of estimators to be cloned. + safe : bool, default=True + If ``safe`` is False, clone will fall back to a deep copy on objects + that are not estimators. + clone_plugins : list of BaseCloner clone plugins, concrete descendant classes. + Must implement ``_check`` and ``_clone`` methods, see ``BaseCloner`` interface. + If passed, will work through clone plugins in ``clone_plugins`` + before working through ``DEFAULT_CLONE_PLUGINS``. To override + a cloner in ``DEFAULT_CLONE_PLUGINS``, simply ensure a cloner with + the same ``_check`` logic is present in ``clone_plugins``. + base_cls : reference to BaseObject + Reference to the BaseObject class from skbase.base._base. + Present for easy reference, fast imports, and potential extensions. 
+ + Returns + ------- + estimator : object + The deep copy of the input, an estimator if input is an estimator. + + Notes + ----- + If the estimator's `random_state` parameter is an integer (or if the + estimator doesn't have a `random_state` parameter), an *exact clone* is + returned: the clone and the original estimator will give the exact same + results. Otherwise, *statistical clone* is returned: the clone might + return different results from the original estimator. More details can be + found in :ref:`randomness`. + """ + # handle cloning plugins: + # if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS + # if provided by user, work through user provided plugins first, then defaults + if clone_plugins is not None: + all_plugins = clone_plugins.copy() + all_plugins.extend(DEFAULT_CLONE_PLUGINS) + else: + all_plugins = DEFAULT_CLONE_PLUGINS + + for cloner_plugin in all_plugins: + cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls) + # we clone with the first plugin in the list that claims it is applicable, + # via check (an exception raised inside check counts as not applicable); + # exceptions raised by the plugin's clone are not caught and propagate + if cloner.check(obj=estimator): + return cloner.clone(obj=estimator) + + raise RuntimeError( + "Error in skbase _clone, catch-all plugin did not catch all " + "remaining cases. This is likely due to custom modification of the module." + ) + + +def _check_clone(original, clone): + """Check that clone is a valid clone of original. + + Called from BaseObject.clone to validate the clone, if + the config flag check_clone is set to True. + + Parameters + ---------- + original : object + The original object. + clone : object + The cloned object. + + Raises + ------ + RuntimeError + If the clone is not a valid clone of the original. 
+ """ + from skbase.utils.deep_equals import deep_equals + + self_params = original.get_params(deep=False) + + # check that all attributes are written to the clone + for attrname in self_params.keys(): + if not hasattr(clone, attrname): + raise RuntimeError( + f"error in {original}.clone, __init__ must write all arguments " + f"to self and not mutate them, but {attrname} was not found. " + f"Please check __init__ of {original}." + ) + + clone_attrs = {attr: getattr(clone, attr) for attr in self_params.keys()} + + # check equality of parameters post-clone and pre-clone + clone_attrs_valid, msg = deep_equals(self_params, clone_attrs, return_msg=True) + if not clone_attrs_valid: + raise RuntimeError( + f"error in {original}.clone, __init__ must write all arguments " + f"to self and not mutate them, but this is not the case. " + f"Error on equality check of arguments (x) vs parameters (y): {msg}" + ) diff --git a/skbase/base/_clone_plugins.py b/skbase/base/_clone_plugins.py new file mode 100644 index 00000000..71a2be3f --- /dev/null +++ b/skbase/base/_clone_plugins.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +# Elements of BaseObject reuse code developed in scikit-learn. These elements +# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For +# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING +"""Logic and plugins for cloning objects - default plugins. + +This module contains default plugins for _clone, from _clone_base. 
+ +DEFAULT_CLONE_PLUGINS - list with default plugins for cloning + +Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods: + +* check(obj) -> boolean - fast checker whether plugin applies +* clone(obj) -> type(obj) - method to clone obj +""" +from functools import lru_cache +from inspect import isclass + + +# imports wrapped in functions to avoid exceptions on skbase init +# wrapped in _safe_import to avoid exceptions on skbase init +@lru_cache(maxsize=None) +def _is_sklearn_present(): + """Check whether scikit-learn is present.""" + from skbase.utils.dependencies import _check_soft_dependencies + + return _check_soft_dependencies("scikit-learn", severity="none") + + +@lru_cache(maxsize=None) +def _get_sklearn_clone(): + """Get sklearn's clone function.""" + from skbase.utils.dependencies._import import _safe_import + + return _safe_import("sklearn.base:clone", condition=_is_sklearn_present()) + + +class BaseCloner: + """Base class for clone plugins. + + Concrete classes must implement methods: + + * _check(obj) -> boolean - fast checker whether plugin applies + * _clone(obj) -> type(obj) - method to clone obj + """ + + def __init__(self, safe, clone_plugins=None, base_cls=None): + self.safe = safe + self.clone_plugins = clone_plugins + self.base_cls = base_cls + + def check(self, obj): + """Check whether the plugin applies to obj.""" + try: + return self._check(obj) + except Exception: + return False + + def clone(self, obj): + """Return a clone of obj.""" + return self._clone(obj) + + def recursive_clone(self, obj, **kwargs): + """Recursive call to _clone, for explicit code and to avoid circular imports.""" + from skbase.base._clone_base import _clone + + recursion_kwargs = { + "safe": self.safe, + "clone_plugins": self.clone_plugins, + "base_cls": self.base_cls, + } + recursion_kwargs.update(kwargs) + return _clone(obj, **recursion_kwargs) + + +class _CloneClass(BaseCloner): + """Clone plugin for classes. 
Returns the class.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + return isclass(obj) + + def _clone(self, obj): + """Return a clone of obj.""" + return obj + + +class _CloneDict(BaseCloner): + """Clone plugin for dicts. Performs recursive cloning.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + return isinstance(obj, dict) + + def _clone(self, obj): + """Return a clone of obj.""" + _clone = self.recursive_clone + return {k: _clone(v) for k, v in obj.items()} + + +class _CloneListTupleSet(BaseCloner): + """Clone plugin for lists, tuples, sets. Performs recursive cloning.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + return isinstance(obj, (list, tuple, set, frozenset)) + + def _clone(self, obj): + """Return a clone of obj.""" + _clone = self.recursive_clone + return type(obj)([_clone(e) for e in obj]) + + +def _default_clone(estimator, recursive_clone): + """Clone estimator. Default used in skbase native and generic get_params clone.""" + klass = estimator.__class__ + new_object_params = estimator.get_params(deep=False) + for name, param in new_object_params.items(): + new_object_params[name] = recursive_clone(param, safe=False) + new_object = klass(**new_object_params) + params_set = new_object.get_params(deep=False) + + # quick sanity check of the parameters of the clone + for name in new_object_params: + param1 = new_object_params[name] + param2 = params_set[name] + if param1 is not param2: + raise RuntimeError( + "Cannot clone object %s, as the constructor " + "either does not set or modifies parameter %s" % (estimator, name) + ) + + return new_object + + +class _CloneSkbase(BaseCloner): + """Clone plugin for scikit-base BaseObject descendants.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + return isinstance(obj, self.base_cls) + + def _clone(self, obj): + """Return a clone of obj.""" + new_object = 
_default_clone(estimator=obj, recursive_clone=self.recursive_clone) + + # Ensure that configs are retained in the new object + if obj.get_config()["clone_config"]: + new_object.set_config(**obj.get_config()) + + return new_object + + +class _CloneSklearn(BaseCloner): + """Clone plugin for scikit-learn BaseEstimator descendants.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + if not _is_sklearn_present(): + return False + + from sklearn.base import BaseEstimator + + return isinstance(obj, BaseEstimator) + + def _clone(self, obj): + """Return a clone of obj.""" + _sklearn_clone = _get_sklearn_clone() + return _sklearn_clone(obj) + + +class _CloneGetParams(BaseCloner): + """Clone plugin for objects that implement get_params but are not the above.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + return hasattr(obj, "get_params") + + def _clone(self, obj): + """Return a clone of obj.""" + return _default_clone(estimator=obj, recursive_clone=self.recursive_clone) + + +class _CloneCatchAll(BaseCloner): + """Catch-all plugin, catches all remaining objects at the end of the list.""" + + def _check(self, obj): + """Check whether the plugin applies to obj.""" + return True + + def _clone(self, obj): + """Return a clone of obj.""" + from copy import deepcopy + + if not self.safe: + return deepcopy(obj) + else: + raise TypeError( + "Cannot clone object '%s' (type %s): " + "it does not seem to be a scikit-base object or scikit-learn " + "estimator, as it does not implement a " + "'get_params' method." 
% (repr(obj), type(obj)) + ) + + +DEFAULT_CLONE_PLUGINS = [ + _CloneClass, + _CloneDict, + _CloneListTupleSet, + _CloneSkbase, + _CloneSklearn, + _CloneGetParams, + _CloneCatchAll, +] diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 019cb097..5c6e0bcf 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -22,6 +22,8 @@ "skbase._nopytest_tests", "skbase.base", "skbase.base._base", + "skbase.base._clone_base", + "skbase.base._clone_plugins", "skbase.base._meta", "skbase.base._pretty_printing", "skbase.base._pretty_printing._object_html_repr", @@ -53,6 +55,7 @@ "skbase.utils.deep_equals._deep_equals", "skbase.utils.dependencies", "skbase.utils.dependencies._dependencies", + "skbase.utils.dependencies._import", "skbase.utils.random_state", "skbase.utils.stderr_mute", "skbase.utils.stdout_mute", @@ -96,6 +99,7 @@ "BaseObject", ), "skbase.base._base": ("BaseEstimator", "BaseObject"), + "skbase.base._clone_plugins": ("BaseCloner",), "skbase.base._meta": ( "BaseMetaObject", "BaseMetaObjectMixin", @@ -116,6 +120,16 @@ SKBASE_CLASSES_BY_MODULE = SKBASE_PUBLIC_CLASSES_BY_MODULE.copy() SKBASE_CLASSES_BY_MODULE.update( { + "skbase.base._clone_plugins": ( + "BaseCloner", + "_CloneClass", + "_CloneSkbase", + "_CloneSklearn", + "_CloneDict", + "_CloneListTupleSet", + "_CloneGetParams", + "_CloneCatchAll", + ), "skbase.base._meta": ( "BaseMetaObject", "BaseMetaObjectMixin", @@ -184,10 +198,8 @@ SKBASE_FUNCTIONS_BY_MODULE = SKBASE_PUBLIC_FUNCTIONS_BY_MODULE.copy() SKBASE_FUNCTIONS_BY_MODULE.update( { - "skbase.base._base": ( - "_clone", - "_check_clone", - ), + "skbase.base._clone_base": {"_check_clone", "_clone"}, + "skbase.base._clone_plugins": ("_default_clone",), "skbase.base._pretty_printing._object_html_repr": ( "_get_visual_block", "_object_html_repr", @@ -218,6 +230,7 @@ "_check_python_version", "_check_estimator_deps", ), + "skbase.utils.dependencies._import": ("_safe_import",), "skbase.utils._iter": ( "_format_seq_to_str", 
"_remove_type_text", diff --git a/skbase/tests/test_base.py b/skbase/tests/test_base.py index 3f053ed2..d1789f64 100644 --- a/skbase/tests/test_base.py +++ b/skbase/tests/test_base.py @@ -1123,8 +1123,8 @@ def test_clone_class_rather_than_instance_raises_error( not _check_soft_dependencies("scikit-learn", severity="none"), reason="skip test if sklearn is not available", ) # sklearn is part of the dev dependency set, test should be executed with that -def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]): - """Test clone with keyword parameter set to None.""" +def test_clone_sklearn_composite(): + """Test clone with a composite of sklearn and skbase.""" from sklearn.ensemble import GradientBoostingRegressor sklearn_obj = GradientBoostingRegressor(random_state=5, learning_rate=0.02) @@ -1134,6 +1134,23 @@ def test_clone_sklearn_composite(fixture_class_parent: Type[Parent]): assert composite_set.get_params()["a__random_state"] == 42 +@pytest.mark.skipif( + not _check_soft_dependencies("scikit-learn", severity="none"), + reason="skip test if sklearn is not available", +) # sklearn is part of the dev dependency set, test should be executed with that +def test_clone_sklearn_composite_retains_config(): + """Test that clone retains sklearn config if inside skbase composite.""" + from sklearn.preprocessing import StandardScaler + + sklearn_obj_w_config = StandardScaler().set_output(transform="pandas") + + composite = ResetTester(a=sklearn_obj_w_config) + composite_clone = composite.clone() + + assert hasattr(composite_clone.a, "_sklearn_output_config") + assert composite_clone.a._sklearn_output_config.get("transform", None) == "pandas" + + # Tests of BaseObject pretty printing representation inspired by sklearn def test_baseobject_repr( fixture_class_parent: Type[Parent], diff --git a/skbase/utils/dependencies/_import.py b/skbase/utils/dependencies/_import.py new file mode 100644 index 00000000..1a9faea3 --- /dev/null +++ 
b/skbase/utils/dependencies/_import.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +"""Utility for safe import.""" +import importlib + + +def _safe_import(path, condition=True): + """Safely imports an object from a module given its string location. + + Parameters + ---------- + path: str + A string representing the module and object. + In the form ``"module.submodule:object"``. + condition: bool, default=True + If False, the import will not be attempted. + + Returns + ------- + Any: The imported object, or None if it could not be imported. + """ + if not condition: + return None + try: + module_name, object_name = path.split(":") + module = importlib.import_module(module_name) + return getattr(module, object_name, None) + except (ImportError, AttributeError, ValueError): + return None