Commit 1e92b50: Merge branch 'main' into develop

Kedro committed Nov 29, 2024
2 parents 704c66d + b482aaa
Showing 10 changed files with 375 additions and 11 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -1,8 +1,11 @@
# Upcoming Release

## Major features and improvements
* Implemented `KedroDataCatalog.to_config()` method that converts the catalog instance into a configuration format suitable for serialization.

## Bug fixes and other changes
* Added validation to ensure dataset version consistency across the catalog.

## Breaking changes to the API
## Documentation changes
## Community contributions
1 change: 1 addition & 0 deletions kedro/framework/cli/catalog.py
@@ -28,6 +28,7 @@ def _create_session(package_name: str, **kwargs: Any) -> KedroSession:


def is_parameter(dataset_name: str) -> bool:
    # TODO: when breaking change move it to kedro/io/core.py
    """Check if dataset is a parameter."""
    return dataset_name.startswith("params:") or dataset_name == "parameters"
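For instance (illustrative calls):

is_parameter("params:train.test_split")  # True
is_parameter("parameters")               # True
is_parameter("cars")                     # False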

53 changes: 48 additions & 5 deletions kedro/io/catalog_config_resolver.py
@@ -30,7 +30,7 @@ def __init__(
        self._dataset_patterns, self._default_pattern = self._extract_patterns(
            config, credentials
        )
-        self._resolved_configs = self._resolve_config_credentials(config, credentials)
+        self._resolved_configs = self.resolve_credentials(config, credentials)

    @property
    def config(self) -> dict[str, dict[str, Any]]:
@@ -237,8 +237,9 @@ def _extract_patterns(

        return sorted_patterns, user_default

-    def _resolve_config_credentials(
-        self,
+    @classmethod
+    def resolve_credentials(
+        cls,
        config: dict[str, dict[str, Any]] | None,
        credentials: dict[str, dict[str, Any]] | None,
    ) -> dict[str, dict[str, Any]]:
@@ -254,13 +255,55 @@ def _resolve_config_credentials(
"\nHint: If this catalog entry is intended for variable interpolation, "
"make sure that the key is preceded by an underscore."
)
if not self.is_pattern(ds_name):
resolved_configs[ds_name] = self._resolve_credentials(
if not cls.is_pattern(ds_name):
resolved_configs[ds_name] = cls._resolve_credentials(
ds_config, credentials
)

return resolved_configs
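Because `resolve_credentials` is now a classmethod, it can be called without constructing a resolver. A minimal sketch of the intended behaviour, with illustrative dataset and credential names:

config = {
    "cars": {
        "type": "pandas.CSVDataset",
        "filepath": "s3://my_bucket/cars.csv",
        "credentials": "s3_creds",  # reference into the credentials dict
    }
}
credentials = {"s3_creds": {"key": "FAKE_KEY", "secret": "FAKE_SECRET"}}

resolved = CatalogConfigResolver.resolve_credentials(config, credentials)
# The string reference is replaced by the actual credentials mapping:
# resolved["cars"]["credentials"] == {"key": "FAKE_KEY", "secret": "FAKE_SECRET"}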

    @staticmethod
    def unresolve_credentials(
        cred_name: str, ds_config: dict[str, dict[str, Any]] | None
    ) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
        """Extracts and replaces credentials in a dataset configuration with
        references, ensuring separation of credentials from the dataset configuration.

        Credentials are searched for recursively in the dataset configuration.
        The first occurrence of the `CREDENTIALS_KEY` is replaced with a generated
        reference key.

        Args:
            cred_name: A unique identifier for the credentials being unresolved.
                This is used to generate a reference key for the credentials.
            ds_config: The dataset configuration containing potential credentials
                under the key `CREDENTIALS_KEY`.

        Returns:
            A tuple containing:
                ds_config_copy: A deep copy of the original dataset configuration
                    with credentials replaced by reference keys.
                credentials: A dictionary mapping generated reference keys to the
                    original credentials.
        """
        ds_config_copy = copy.deepcopy(ds_config) or {}
        credentials: dict[str, Any] = {}
        credentials_ref = f"{cred_name}_{CREDENTIALS_KEY}"

        def unresolve(config: Any) -> None:
            # We don't expect the credentials key to appear more than once within
            # the same dataset config, so once we find it we unresolve it and stop
            # iterating.
            for key, val in config.items():
                if key == CREDENTIALS_KEY and config[key]:
                    credentials[credentials_ref] = config[key]
                    config[key] = credentials_ref
                    return
                if isinstance(val, dict):
                    unresolve(val)

        unresolve(ds_config_copy)

        return ds_config_copy, credentials
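A short sketch of what `unresolve_credentials` does with an inline credentials block (names illustrative):

ds_config = {
    "type": "pandas.CSVDataset",
    "filepath": "s3://my_bucket/cars.csv",
    "credentials": {"key": "FAKE_KEY", "secret": "FAKE_SECRET"},
}

ds_config_copy, creds = CatalogConfigResolver.unresolve_credentials("cars", ds_config)
# ds_config_copy["credentials"] == "cars_credentials"  (the generated reference key)
# creds == {"cars_credentials": {"key": "FAKE_KEY", "secret": "FAKE_SECRET"}}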

    def resolve_pattern(self, ds_name: str) -> dict[str, Any]:
        """Resolve dataset patterns and return resolved configurations based on the existing patterns."""
        matched_pattern = self.match_pattern(ds_name)
84 changes: 80 additions & 4 deletions kedro/io/core.py
@@ -15,6 +15,7 @@
from datetime import datetime, timezone
from functools import partial, wraps
from glob import iglob
from inspect import getcallargs
from operator import attrgetter
from pathlib import Path, PurePath, PurePosixPath
from typing import (
@@ -57,6 +58,7 @@
"s3a",
"s3n",
)
TYPE_KEY = "type"


class DatasetError(Exception):
@@ -210,6 +212,59 @@ def from_config(
            ) from err
        return dataset

    def to_config(self) -> dict[str, Any]:
        """Converts the dataset instance into a dictionary-based configuration for
        serialization. Ensures that any subclass-specific details are handled, with
        additional logic for versioning and caching implemented for `CachedDataset`.

        Adds a key for the dataset's type using its module and class name and
        includes the initialization arguments.

        For `CachedDataset` it extracts the underlying dataset's configuration,
        handles the `versioned` flag and removes unnecessary metadata. It also
        ensures the embedded dataset's configuration is appropriately flattened
        or transformed.

        If the dataset has a version key, it sets the `versioned` flag in the
        configuration.

        Removes the `metadata` key from the configuration if present.

        Returns:
            A dictionary containing the dataset's type and initialization arguments.
        """
        return_config: dict[str, Any] = {
            f"{TYPE_KEY}": f"{type(self).__module__}.{type(self).__name__}"
        }

        if self._init_args:  # type: ignore[attr-defined]
            self._init_args.pop("self", None)  # type: ignore[attr-defined]
            return_config.update(self._init_args)  # type: ignore[attr-defined]

        if type(self).__name__ == "CachedDataset":
            cached_ds = return_config.pop("dataset")
            cached_ds_return_config: dict[str, Any] = {}
            if isinstance(cached_ds, dict):
                cached_ds_return_config = cached_ds
            elif isinstance(cached_ds, AbstractDataset):
                cached_ds_return_config = cached_ds.to_config()
            if VERSIONED_FLAG_KEY in cached_ds_return_config:
                return_config[VERSIONED_FLAG_KEY] = cached_ds_return_config.pop(
                    VERSIONED_FLAG_KEY
                )
            # Pop metadata from configuration
            cached_ds_return_config.pop("metadata", None)
            return_config["dataset"] = cached_ds_return_config

        # Set `versioned` key if version present in the dataset
        if return_config.pop(VERSION_KEY, None):
            return_config[VERSIONED_FLAG_KEY] = True

        # Pop metadata from configuration
        return_config.pop("metadata", None)

        return return_config
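As a rough illustration of the output (a sketch assuming a `kedro_datasets` CSV dataset; the exact keys depend on the dataset's captured `__init__` arguments):

from kedro_datasets.pandas import CSVDataset

dataset = CSVDataset(filepath="cars.csv", save_args={"index": False})
config = dataset.to_config()
# Roughly:
# {
#     "type": "kedro_datasets.pandas.csv_dataset.CSVDataset",
#     "filepath": "cars.csv",
#     "save_args": {"index": False},
#     ...  # any remaining captured __init__ arguments
# }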

    @property
    def _logger(self) -> logging.Logger:
        return logging.getLogger(__name__)
@@ -290,11 +345,32 @@ def save(self: Self, data: _DI) -> None:
        return save

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Customizes the behavior of subclasses of AbstractDataset during
        their creation. This method is automatically invoked when a subclass
        of AbstractDataset is defined.

        Decorates the `load` and `save` methods provided by the class.
        If `_load` or `_save` are defined, alias them as a prerequisite.
        """
        # Save the original __init__ method of the subclass
        init_func: Callable = cls.__init__

        @wraps(init_func)
        def new_init(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
            """Executes the original __init__, then saves the arguments used
            to initialize the instance.
            """
            # Call the original __init__ method
            init_func(self, *args, **kwargs)
            # Capture and save the arguments passed to the original __init__
            self._init_args = getcallargs(init_func, self, *args, **kwargs)

        # Replace the subclass's __init__ with new_init: a hook for subclasses
        # to capture initialization arguments and save them in the
        # AbstractDataset._init_args field
        cls.__init__ = new_init  # type: ignore[method-assign]

        super().__init_subclass__(**kwargs)

        if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"):
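The capture relies on `inspect.getcallargs`, which binds positional and keyword arguments to parameter names and fills in unspecified defaults. A standalone toy sketch (not Kedro code):

from inspect import getcallargs

def init(self, filepath, load_args=None):
    pass

print(getcallargs(init, "instance", "cars.csv"))
# {'self': 'instance', 'filepath': 'cars.csv', 'load_args': None}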
@@ -493,14 +569,14 @@ def parse_dataset_definition(
    config = copy.deepcopy(config)

    # TODO: remove when removing old catalog as moved to KedroDataCatalog
-    if "type" not in config:
+    if TYPE_KEY not in config:
        raise DatasetError(
            "'type' is missing from dataset catalog configuration."
            "\nHint: If this catalog entry is intended for variable interpolation, "
            "make sure that the top level key is preceded by an underscore."
        )

-    dataset_type = config.pop("type")
+    dataset_type = config.pop(TYPE_KEY)
    class_obj = None
    if isinstance(dataset_type, str):
        if len(dataset_type.strip(".")) != len(dataset_type):
73 changes: 71 additions & 2 deletions kedro/io/kedro_data_catalog.py
@@ -18,6 +18,7 @@

from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns
from kedro.io.core import (
    TYPE_KEY,
    AbstractDataset,
    AbstractVersionedDataset,
    CatalogProtocol,
@@ -28,7 +29,7 @@
    _validate_versions,
    generate_timestamp,
)
-from kedro.io.memory_dataset import MemoryDataset
+from kedro.io.memory_dataset import MemoryDataset, _is_memory_dataset
from kedro.utils import _format_rich, _has_rich_handler


@@ -97,7 +98,7 @@ def __init__(
        >>> catalog = KedroDataCatalog(datasets={"cars": cars})
        """
        self._config_resolver = config_resolver or CatalogConfigResolver()
-        self._datasets = datasets or {}
+        self._datasets: dict[str, AbstractDataset] = datasets or {}
        self._lazy_datasets: dict[str, _LazyDataset] = {}
        self._load_versions, self._save_version = _validate_versions(
            datasets, load_versions or {}, save_version
@@ -369,6 +370,74 @@ class to be loaded is specified with the key ``type`` and their
            config_resolver=config_resolver,
        )

    def to_config(
        self,
    ) -> tuple[
        dict[str, dict[str, Any]],
        dict[str, dict[str, Any]],
        dict[str, str | None],
        str | None,
    ]:
        """Converts the `KedroDataCatalog` instance into a configuration format suitable for
        serialization. This includes datasets, credentials, and versioning information.

        This method is only applicable to catalogs that contain datasets initialized with static,
        primitive parameters. For example, it will work fine if one passes credentials as a
        dictionary to `GBQQueryDataset` but not as a `google.auth.credentials.Credentials` object.
        See https://github.com/kedro-org/kedro-plugins/issues/950 for details.

        Returns:
            A tuple containing:
                catalog: A dictionary mapping dataset names to their unresolved configurations,
                    excluding in-memory datasets.
                credentials: A dictionary of unresolved credentials extracted from dataset
                    configurations.
                load_versions: A dictionary mapping dataset names to specific versions to be
                    loaded, or `None` if no version is set.
                save_version: A global version identifier for saving datasets, or `None` if not
                    specified.

        Example:
        ::

            >>> from kedro.io import KedroDataCatalog
            >>> from kedro_datasets.pandas import CSVDataset
            >>>
            >>> cars = CSVDataset(
            >>>     filepath="cars.csv",
            >>>     load_args=None,
            >>>     save_args={"index": False}
            >>> )
            >>> catalog = KedroDataCatalog(datasets={'cars': cars})
            >>>
            >>> config, credentials, load_versions, save_version = catalog.to_config()
            >>>
            >>> new_catalog = KedroDataCatalog.from_config(config, credentials, load_versions, save_version)
        """
        catalog: dict[str, dict[str, Any]] = {}
        credentials: dict[str, dict[str, Any]] = {}
        load_versions: dict[str, str | None] = {}

        for ds_name, ds in self._lazy_datasets.items():
            if _is_memory_dataset(ds.config.get(TYPE_KEY, "")):
                continue
            unresolved_config, unresolved_credentials = (
                self._config_resolver.unresolve_credentials(ds_name, ds.config)
            )
            catalog[ds_name] = unresolved_config
            credentials.update(unresolved_credentials)
            load_versions[ds_name] = self._load_versions.get(ds_name, None)

        for ds_name, ds in self._datasets.items():  # type: ignore[assignment]
            if _is_memory_dataset(ds):  # type: ignore[arg-type]
                continue
            resolved_config = ds.to_config()  # type: ignore[attr-defined]
            unresolved_config, unresolved_credentials = (
                self._config_resolver.unresolve_credentials(ds_name, resolved_config)
            )
            catalog[ds_name] = unresolved_config
            credentials.update(unresolved_credentials)
            load_versions[ds_name] = self._load_versions.get(ds_name, None)

        return catalog, credentials, load_versions, self._save_version

@staticmethod
def _validate_dataset_config(ds_name: str, ds_config: Any) -> None:
if not isinstance(ds_config, dict):
10 changes: 10 additions & 0 deletions kedro/io/memory_dataset.py
@@ -140,3 +140,13 @@ def _copy_with_mode(data: Any, copy_mode: str) -> Any:
        )

    return copied_data


def _is_memory_dataset(ds_or_type: AbstractDataset | str) -> bool:
    """Check if dataset or str type provided is a MemoryDataset."""
    if isinstance(ds_or_type, MemoryDataset):
        return True
    if isinstance(ds_or_type, str):
        return ds_or_type in {"MemoryDataset", "kedro.io.memory_dataset.MemoryDataset"}

    return False
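Illustrative calls (note that only the two string forms in the set above are recognised):

from kedro.io.memory_dataset import MemoryDataset, _is_memory_dataset

_is_memory_dataset(MemoryDataset())                           # True
_is_memory_dataset("kedro.io.memory_dataset.MemoryDataset")   # True
_is_memory_dataset("pandas.CSVDataset")                       # False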
46 changes: 46 additions & 0 deletions tests/io/conftest.py
@@ -73,6 +73,52 @@ def correct_config(filepath):
    }


@pytest.fixture
def correct_config_versioned(filepath):
    return {
        "catalog": {
            "boats": {
                "type": "pandas.CSVDataset",
                "filepath": filepath,
                "versioned": True,
            },
            "cars": {
                "type": "pandas.CSVDataset",
                "filepath": "s3://test_bucket/test_file.csv",
                "credentials": "cars_credentials",
            },
            "cars_ibis": {
                "type": "ibis.FileDataset",
                "filepath": "cars_ibis.csv",
                "file_format": "csv",
                "table_name": "cars",
                "connection": {"backend": "duckdb", "database": "company.db"},
                "load_args": {"sep": ",", "nullstr": "#NA"},
                "save_args": {"sep": ",", "nullstr": "#NA"},
            },
            "cached_ds": {
                "type": "kedro.io.cached_dataset.CachedDataset",
                "versioned": True,
                "dataset": {
                    "type": "pandas.CSVDataset",
                    "filepath": "cached_ds.csv",
                    "credentials": "cached_ds_credentials",
                },
                "copy_mode": None,
            },
            "parameters": {
                "type": "kedro.io.memory_dataset.MemoryDataset",
                "data": [4, 5, 6],
                "copy_mode": None,
            },
        },
        "credentials": {
            "cars_credentials": {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"},
            "cached_ds_credentials": {"key": "KEY", "secret": "SECRET"},
        },
    }
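A hypothetical round-trip test built on this fixture might look like the following (test name and assertions are illustrative, not part of this diff):

from kedro.io import KedroDataCatalog

def test_to_config_round_trip(correct_config_versioned):
    catalog = KedroDataCatalog.from_config(**correct_config_versioned)
    config, credentials, load_versions, save_version = catalog.to_config()

    # Credentials are unresolved back into a separate mapping
    assert config["cars"]["credentials"] == "cars_credentials"
    assert credentials["cars_credentials"]["key"] == "FAKE_ACCESS_KEY"

    # The returned tuple is sufficient to rebuild an equivalent catalog
    KedroDataCatalog.from_config(config, credentials, load_versions, save_version)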


@pytest.fixture
def correct_config_with_nested_creds(correct_config):
    correct_config["catalog"]["cars"]["credentials"] = {