diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 28512554..b2b273de 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -1384,7 +1384,8 @@ def _clone(estimator, *, safe=True): found in :ref:`randomness`. """ estimator_type = type(estimator) - # XXX: not handling dictionaries + if estimator_type is dict: + return {k: _clone(v, safe=safe) for k, v in estimator.items()} if estimator_type in (list, tuple, set, frozenset): return estimator_type([_clone(e, safe=safe) for e in estimator]) elif not hasattr(estimator, "get_params") or isinstance(estimator, type):