Skip to content

Commit

Permalink
[ENH] refactor _clone to a plugin structure (#381)
Browse files Browse the repository at this point in the history
This PR refactors the current `_clone` logic, which had a lot of if/else
case distinctions (as in the `sklearn` native code) to a list of plugins
that can be extended.

In addition to the refactor, it also adds a `sklearn`-specific plugin that
dispatches to `sklearn` clone, ensuring that configs get cloned as well.

An approach to sktime/sktime#7333, which would
be solved directly after update.

An alternative to #380, with
advantages:

* `sklearn` compatibility is automatic, for any dependency of `skbase`
* the plugins can later be extended easily, for instance in a case like
the missing `dict` support

A test for sktime/sktime#7333, namely that
output configs are retained in `sklearn` objects, has been added.
  • Loading branch information
fkiraly authored Nov 13, 2024
1 parent 49cc6ca commit badd7d4
Show file tree
Hide file tree
Showing 6 changed files with 410 additions and 111 deletions.
107 changes: 2 additions & 105 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class name: BaseEstimator
from typing import List

from skbase._exceptions import NotFittedError
from skbase.base._clone_base import _check_clone, _clone
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager

Expand Down Expand Up @@ -175,7 +176,7 @@ def clone(self):
------
RuntimeError if the clone is non-conforming, due to faulty ``__init__``.
"""
self_clone = _clone(self)
self_clone = _clone(self, base_cls=BaseObject)
if self.get_config()["check_clone"]:
_check_clone(original=self, clone=self_clone)
return self_clone
Expand Down Expand Up @@ -1653,107 +1654,3 @@ def _get_fitted_params(self):
fitted parameters, keyed by names of fitted parameter
"""
return self._get_fitted_params_default()


# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True):
"""Construct a new unfitted estimator with the same parameters.
Clone does a deep copy of the model in an estimator
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.
Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single \
estimator instance
The estimator or group of estimators to be cloned.
safe : bool, default=True
If safe is False, clone will fall back to a deep copy on objects
that are not estimators.
Returns
-------
estimator : object
The deep copy of the input, an estimator if input is an estimator.
Notes
-----
If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
results. Otherwise, *statistical clone* is returned: the clone might
return different results from the original estimator. More details can be
found in :ref:`randomness`.
"""
estimator_type = type(estimator)
if estimator_type is dict:
return {k: _clone(v, safe=safe) for k, v in estimator.items()}
if estimator_type in (list, tuple, set, frozenset):
return estimator_type([_clone(e, safe=safe) for e in estimator])
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
if not safe:
return deepcopy(estimator)
else:
if isinstance(estimator, type):
raise TypeError(
"Cannot clone object. "
+ "You should provide an instance of "
+ "scikit-learn estimator instead of a class."
)
else:
raise TypeError(
"Cannot clone object '%s' (type %s): "
"it does not seem to be a scikit-learn "
"estimator as it does not implement a "
"'get_params' method." % (repr(estimator), type(estimator))
)

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
new_object_params[name] = _clone(param, safe=False)
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
param2 = params_set[name]
if param1 is not param2:
raise RuntimeError(
"Cannot clone object %s, as the constructor "
"either does not set or modifies parameter %s" % (estimator, name)
)

# This is an extension to the original sklearn implementation
if isinstance(estimator, BaseObject) and estimator.get_config()["clone_config"]:
new_object.set_config(**estimator.get_config())

return new_object


def _check_clone(original, clone):
    """Check that ``clone`` carries the same parameters as ``original``.

    Parameters
    ----------
    original : object
        Object providing ``get_params(deep=False)``; its parameters are
        the reference values.
    clone : object
        Object expected to expose every parameter of ``original`` as an
        attribute of the same name, with an equal value.

    Raises
    ------
    RuntimeError
        If a parameter name is not an attribute of ``clone``, or if the
        attribute values of ``clone`` do not compare equal to the parameters
        of ``original`` under ``deep_equals``.
    """
    from skbase.utils.deep_equals import deep_equals

    expected = original.get_params(deep=False)

    # every parameter must be present as an attribute on the clone
    for attrname in expected:
        if hasattr(clone, attrname):
            continue
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but {attrname} was not found. "
            f"Please check __init__ of {original}."
        )

    found = {key: getattr(clone, key) for key in expected}

    # parameters before cloning must equal the attributes after cloning
    is_equal, msg = deep_equals(expected, found, return_msg=True)
    if is_equal:
        return

    raise RuntimeError(
        f"error in {original}.clone, __init__ must write all arguments "
        f"to self and not mutate them, but this is not the case. "
        f"Error on equality check of arguments (x) vs parameters (y): {msg}"
    )
129 changes: 129 additions & 0 deletions skbase/base/_clone_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Elements of BaseObject reuse code developed in scikit-learn. These elements
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Logic and plugins for cloning objects.
This module contains logic for cloning objects:
_clone(estimator, *, safe=True, clone_plugins=None, base_cls=None) - central entry point for cloning
_check_clone(original, clone) - validation utility to check clones
Default plugins for _clone are stored in _clone_plugins:
DEFAULT_CLONE_PLUGINS - list with default plugins for cloning
Each element of DEFAULT_CLONE_PLUGINS inherits from BaseCloner, with methods:
* check(obj) -> boolean - fast checker whether plugin applies
* clone(obj) -> type(obj) - method to clone obj
"""
__all__ = ["_clone", "_check_clone"]

from skbase.base._clone_plugins import DEFAULT_CLONE_PLUGINS


# Adapted from sklearn's `_clone_parametrized()`
def _clone(estimator, *, safe=True, clone_plugins=None, base_cls=None):
"""Construct a new unfitted estimator with the same parameters.
Clone does a deep copy of the model in an estimator
without actually copying attached data. It returns a new estimator
with the same parameters that has not been fitted on any data.
Parameters
----------
estimator : {list, tuple, set} of estimator instance or a single estimator instance
The estimator or group of estimators to be cloned.
safe : bool, default=True
If ``safe`` is False, clone will fall back to a deep copy on objects
that are not estimators.
clone_plugins : list of BaseCloner clone plugins, concrete descendant classes.
Must implement ``_check`` and ``_clone`` method, see ``BaseCloner`` interface.
If passed, will work through clone plugins in ``clone_plugins``
before working through ``DEFAULT_CLONE_PLUGINS``. To override
a cloner in ``DEAULT_CLONE_PLUGINS``, simply ensure a cloner with
the same ``_check`` logis is present in ``clone_plugins``.
base_cls : reference to BaseObject
Reference to the BaseObject class from skbase.base._base.
Present for easy reference, fast imports, and potential extensions.
Returns
-------
estimator : object
The deep copy of the input, an estimator if input is an estimator.
Notes
-----
If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
results. Otherwise, *statistical clone* is returned: the clone might
return different results from the original estimator. More details can be
found in :ref:`randomness`.
"""
# handle cloning plugins:
# if no plugins provided by user, work through the DEFAULT_CLONE_PLUGINS
# if provided by user, work through user provided plugins first, then defaults
if clone_plugins is not None:
all_plugins = clone_plugins.copy()
all_plugins.append(DEFAULT_CLONE_PLUGINS.copy())
else:
all_plugins = DEFAULT_CLONE_PLUGINS

for cloner_plugin in all_plugins:
cloner = cloner_plugin(safe=safe, clone_plugins=all_plugins, base_cls=base_cls)
# we clone with the first plugin in the list that:
# 1. claims it is applicable, via check
# 2. does not produce an Exception when cloning
if cloner.check(obj=estimator):
return cloner.clone(obj=estimator)

raise RuntimeError(
"Error in skbase _clone, catch-all plugin did not catch all "
"remaining cases. This is likely due to custom modification of the module."
)


def _check_clone(original, clone):
    """Validate that ``clone`` is a faithful clone of ``original``.

    Called from BaseObject.clone to validate the clone, if
    the config flag check_clone is set to True.

    Parameters
    ----------
    original : object
        The original object. Its ``get_params(deep=False)`` result
        defines the reference parameter names and values.
    clone : object
        The cloned object. It must expose each parameter of ``original`` as
        an attribute of the same name.

    Raises
    ------
    RuntimeError
        If the clone is not a valid clone of the original: a parameter is
        missing as attribute on the clone, or the attribute values do not
        compare equal to the original parameters under ``deep_equals``.
    """
    from skbase.utils.deep_equals import deep_equals

    reference_params = original.get_params(deep=False)

    # check that all attributes are written to the clone
    for attrname in reference_params:
        if hasattr(clone, attrname):
            continue
        raise RuntimeError(
            f"error in {original}.clone, __init__ must write all arguments "
            f"to self and not mutate them, but {attrname} was not found. "
            f"Please check __init__ of {original}."
        )

    attrs_on_clone = {key: getattr(clone, key) for key in reference_params}

    # check equality of parameters post-clone and pre-clone
    params_match, msg = deep_equals(reference_params, attrs_on_clone, return_msg=True)
    if params_match:
        return

    raise RuntimeError(
        f"error in {original}.clone, __init__ must write all arguments "
        f"to self and not mutate them, but this is not the case. "
        f"Error on equality check of arguments (x) vs parameters (y): {msg}"
    )
Loading

0 comments on commit badd7d4

Please sign in to comment.