diff --git a/kedro/io/core.py b/kedro/io/core.py index 4d05a14ff6..1ba0a2f2b4 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -20,6 +20,7 @@ from cachetools import Cache, cachedmethod from cachetools.keys import hashkey +from typing_extensions import Self from kedro.utils import load_obj @@ -178,9 +179,9 @@ def _logger(self) -> logging.Logger: return logging.getLogger(__name__) @classmethod - def _load_wrapper(cls, load_func: Callable[[], _DO]) -> Callable[[], _DO]: + def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]: @wraps(load_func) - def load(self) -> _DO: + def load(self: Self) -> _DO: self._logger.debug("Loading %s", str(self)) try: @@ -200,9 +201,11 @@ def load(self) -> _DO: return load @classmethod - def _save_wrapper(cls, save_func: Callable[[_DI], None]) -> Callable[[_DI], None]: + def _save_wrapper( + cls, save_func: Callable[[Self, _DI], None] + ) -> Callable[[Self, _DI], None]: @wraps(save_func) - def save(self, data: _DI) -> None: + def save(self: Self, data: _DI) -> None: if data is None: raise DatasetError("Saving 'None' to a 'Dataset' is not allowed") @@ -226,14 +229,14 @@ def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"): - cls.load = cls._load_wrapper( # type: ignore[method-assign] + cls.load = cls._load_wrapper( # type: ignore[assignment] cls.load if not getattr(cls.load, "__loadwrapped__", False) else cls.load.__wrapped__ # type: ignore[attr-defined] ) if hasattr(cls, "save") and not cls.save.__qualname__.startswith("Abstract"): - cls.save = cls._save_wrapper( # type: ignore[method-assign] + cls.save = cls._save_wrapper( # type: ignore[assignment] cls.save if not getattr(cls.save, "__savewrapped__", False) else cls.save.__wrapped__ # type: ignore[attr-defined] @@ -678,9 +681,11 @@ def load(self) -> _DO: return super().load() @classmethod - def _save_wrapper(cls, save_func: Callable[[_DI], None]) -> Callable[[_DI], None]: + def _save_wrapper( + cls, save_func: Callable[[Self, _DI], None] + ) -> Callable[[Self, _DI], None]: @wraps(save_func) - def save(self, data: _DI) -> None: + def save(self: Self, data: _DI) -> None: self._version_cache.clear() save_version = ( self.resolve_save_version() diff --git a/pyproject.toml b/pyproject.toml index 6ece4851cd..e7d80de0ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "rich>=12.0,<14.0", "rope>=0.21,<2.0", # subject to LGPLv3 license "toml>=0.10.0", + "typing_extensions>=4.0", "graphlib_backport>=1.0.0; python_version < '3.9'", ] keywords = [