From b482aaa7c28233d3a01a432c9bc9cc9246a27719 Mon Sep 17 00:00:00 2001
From: ElenaKhaustova <157851531+ElenaKhaustova@users.noreply.github.com>
Date: Fri, 29 Nov 2024 21:11:33 +0000
Subject: [PATCH] Catalog to config (#4323)

* Captured init arguments
Signed-off-by: Elena Khaustova
* Implemented unresolving credentials
Signed-off-by: Elena Khaustova
* Added some comments
Signed-off-by: Elena Khaustova
* Put type in first place for dataset config
Signed-off-by: Elena Khaustova
* Handled version key
Signed-off-by: Elena Khaustova
* Added lazy dataset to_config
Signed-off-by: Elena Khaustova
* Removed data key from MemoryDataset
Signed-off-by: Elena Khaustova
* Added TODOs
Signed-off-by: Elena Khaustova
* Saved call args
Signed-off-by: Elena Khaustova
* Saved only set credentials
Signed-off-by: Elena Khaustova
* Processed CachedDataset case
Signed-off-by: Elena Khaustova
* Updated TODOs
Signed-off-by: Elena Khaustova
* Tested with PartitionedDataset
Signed-off-by: Elena Khaustova
* Popped metadata
Signed-off-by: Elena Khaustova
* Fixed versioning when loading
Signed-off-by: Elena Khaustova
* Fixed linter
Signed-off-by: Elena Khaustova
* Tested dataset factories
Signed-off-by: Elena Khaustova
* Tested transcoding
Signed-off-by: Elena Khaustova
* Removed TODOs
Signed-off-by: Elena Khaustova
* Removed debug output
Signed-off-by: Elena Khaustova
* Removed debug output
Signed-off-by: Elena Khaustova
* Added logic to set VERSIONED_FLAG_KEY
Signed-off-by: Elena Khaustova
* Updated version setup
Signed-off-by: Elena Khaustova
* Added TODO for versioning
Signed-off-by: Elena Khaustova
* Added tests for unresolve_config_credentials
Signed-off-by: Elena Khaustova
* Implemented test_to_config
Signed-off-by: Elena Khaustova
* Added test with MemoryDataset
Signed-off-by: Elena Khaustova
* Extended test examples
Signed-off-by: Elena Khaustova
* Materialized cached_ds
Signed-off-by: Elena Khaustova
* Excluded parameters
Signed-off-by: Elena Khaustova
* Fixed import
Signed-off-by: Elena Khaustova
* Added test with parameters
Signed-off-by: Elena Khaustova
* Moved tests for CatalogConfigResolver to a separate file
Signed-off-by: Elena Khaustova
* Made unresolve_config_credentials a staticmethod
Signed-off-by: Elena Khaustova
* Updated comment to clarify meaning
Signed-off-by: Elena Khaustova
* Moved to_config after from_config
Signed-off-by: Elena Khaustova
* Returned is_parameter for catalog and added TODOs
Signed-off-by: Elena Khaustova
* Renamed catalog config resolver methods
Signed-off-by: Elena Khaustova
* Implemented _validate_versions method
Signed-off-by: Elena Khaustova
* Added _validate_versions calls
Signed-off-by: Elena Khaustova
* Updated error descriptions
Signed-off-by: Elena Khaustova
* Added validation to the old catalog
Signed-off-by: Elena Khaustova
* Fixed linter
Signed-off-by: Elena Khaustova
* Implemented unit tests for KedroDataCatalog
Signed-off-by: Elena Khaustova
* Removed odd comments
Signed-off-by: Elena Khaustova
* Implemented tests for DataCatalog
Signed-off-by: Elena Khaustova
* Added docstrings
Signed-off-by: Elena Khaustova
* Added release notes
Signed-off-by: Elena Khaustova
* Updated version logic
Signed-off-by: Elena Khaustova
* Added CachedDataset case
Signed-off-by: Elena Khaustova
* Updated release notes
Signed-off-by: Elena Khaustova
* Added tests for CachedDataset use case
Signed-off-by: Elena Khaustova
* Updated unit test after version validation is applied
Signed-off-by: Elena Khaustova
* Removed MemoryDatasets
Signed-off-by: Elena Khaustova
* Removed _is_parameter
Signed-off-by: Elena Khaustova
* Pop metadata from cached dataset configuration
Signed-off-by: Elena Khaustova
* Fixed lint
Signed-off-by: Elena Khaustova
* Fixed unit test
Signed-off-by: Elena Khaustova
* Added docstrings for AbstractDataset.to_config()
Signed-off-by: Elena Khaustova
* Updated docstrings
Signed-off-by: Elena Khaustova
* Fixed typos
Signed-off-by: Elena Khaustova
* Updated TODOs
Signed-off-by: Elena Khaustova
* Added docstring for KedroDataCatalog.to_config
Signed-off-by: Elena Khaustova
* Added docstrings for unresolve_credentials
Signed-off-by: Elena Khaustova
* Updated release notes
Signed-off-by: Elena Khaustova
* Fixed indentation
Signed-off-by: Elena Khaustova
* Fixed to_config() example
Signed-off-by: Elena Khaustova
* Fixed indentation
Signed-off-by: Elena Khaustova
* Fixed indentation
Signed-off-by: Elena Khaustova
* Added a note about to_config() constraints
Signed-off-by: Elena Khaustova
* Fixed typo
Signed-off-by: Elena Khaustova
* Replaced type string with the constant
Signed-off-by: Elena Khaustova
* Replaced type string with the constant
Signed-off-by: Elena Khaustova
* Moved _is_memory_dataset
Signed-off-by: Elena Khaustova
* Simplified nested decorator
Signed-off-by: Elena Khaustova
* Fixed lint
Signed-off-by: Elena Khaustova
* Removed _init_args class attribute
Signed-off-by: Elena Khaustova
* Returned @wraps

---------

Signed-off-by: Elena Khaustova
---
 RELEASE.md                               |  3 +
 kedro/framework/cli/catalog.py           |  1 +
 kedro/io/catalog_config_resolver.py      | 53 +++++++++++++--
 kedro/io/core.py                         | 84 ++++++++++++++++++++++--
 kedro/io/kedro_data_catalog.py           | 73 +++++++++++++++++++-
 kedro/io/memory_dataset.py               | 10 +++
 tests/io/conftest.py                     | 46 +++++++++++++
 tests/io/test_catalog_config_resolver.py | 38 +++++++++++
 tests/io/test_kedro_data_catalog.py      | 62 +++++++++++++++++
 tests/io/test_memory_dataset.py          | 16 +++++
 10 files changed, 375 insertions(+), 11 deletions(-)
 create mode 100644 tests/io/test_catalog_config_resolver.py

diff --git a/RELEASE.md b/RELEASE.md
index 94fa345843..0cc0fdf013 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,8 +1,11 @@
 # Upcoming Release
 
 ## Major features and improvements
+* Implemented the `KedroDataCatalog.to_config()` method that converts the catalog instance into a configuration format suitable for serialization.
+
 ## Bug fixes and other changes
 * Added validation to ensure dataset versions consistency across catalog.
+
 ## Breaking changes to the API
 ## Documentation changes
 ## Community contributions
diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py
index 25fad6083d..99350dc01c 100644
--- a/kedro/framework/cli/catalog.py
+++ b/kedro/framework/cli/catalog.py
@@ -28,6 +28,7 @@ def _create_session(package_name: str, **kwargs: Any) -> KedroSession:
 
 
 def is_parameter(dataset_name: str) -> bool:
+    # TODO: move this to kedro/io/core.py at the next breaking change
    """Check if dataset is a parameter."""
     return dataset_name.startswith("params:") or dataset_name == "parameters"
 
diff --git a/kedro/io/catalog_config_resolver.py b/kedro/io/catalog_config_resolver.py
index f722bedb6e..d4582d8e25 100644
--- a/kedro/io/catalog_config_resolver.py
+++ b/kedro/io/catalog_config_resolver.py
@@ -30,7 +30,7 @@ def __init__(
         self._dataset_patterns, self._default_pattern = self._extract_patterns(
             config, credentials
         )
-        self._resolved_configs = self._resolve_config_credentials(config, credentials)
+        self._resolved_configs = self.resolve_credentials(config, credentials)
 
     @property
     def config(self) -> dict[str, dict[str, Any]]:
@@ -237,8 +237,9 @@ def _extract_patterns(
 
         return sorted_patterns, user_default
 
-    def _resolve_config_credentials(
-        self,
+    @classmethod
+    def resolve_credentials(
+        cls,
         config: dict[str, dict[str, Any]] | None,
         credentials: dict[str, dict[str, Any]] | None,
     ) -> dict[str, dict[str, Any]]:
@@ -254,13 +255,55 @@
                     "\nHint: If this catalog entry is intended for variable interpolation, "
                     "make sure that the key is preceded by an underscore."
                 )
-            if not self.is_pattern(ds_name):
-                resolved_configs[ds_name] = self._resolve_credentials(
+            if not cls.is_pattern(ds_name):
+                resolved_configs[ds_name] = cls._resolve_credentials(
                     ds_config, credentials
                 )
 
         return resolved_configs
 
+    @staticmethod
+    def unresolve_credentials(
+        cred_name: str, ds_config: dict[str, dict[str, Any]] | None
+    ) -> tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]:
+        """Extracts and replaces credentials in a dataset configuration with
+        references, ensuring separation of credentials from the dataset configuration.
+
+        Credentials are searched for recursively in the dataset configuration.
+        The first occurrence of `CREDENTIALS_KEY` is replaced with a generated
+        reference key.
+
+        Args:
+            cred_name: A unique identifier for the credentials being unresolved.
+                This is used to generate a reference key for the credentials.
+            ds_config: The dataset configuration containing potential credentials
+                under the key `CREDENTIALS_KEY`.
+
+        Returns:
+            A tuple containing:
+                ds_config_copy: A deep copy of the original dataset
+                    configuration with credentials replaced by reference keys.
+                credentials: A dictionary mapping generated reference keys to the original credentials.
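+
+        Example (an illustrative sketch; the dataset config and credentials
+        values below are assumed, not taken from a real project):
+        ::
+
+            >>> resolved = {
+            >>>     "type": "pandas.CSVDataset",
+            >>>     "filepath": "s3://bucket/cars.csv",
+            >>>     "credentials": {"key": "KEY", "secret": "SECRET"},
+            >>> }
+            >>> ds_config, creds = CatalogConfigResolver.unresolve_credentials(
+            >>>     cred_name="cars", ds_config=resolved
+            >>> )
+            >>> assert ds_config["credentials"] == "cars_credentials"
+            >>> assert creds == {"cars_credentials": {"key": "KEY", "secret": "SECRET"}}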
+        """
+        ds_config_copy = copy.deepcopy(ds_config) or {}
+        credentials: dict[str, Any] = {}
+        credentials_ref = f"{cred_name}_{CREDENTIALS_KEY}"
+
+        def unresolve(config: Any) -> None:
+            # We don't expect the credentials key to appear more than once within
+            # the same dataset config, so once we find it we unresolve the
+            # credentials and stop iterating.
+            for key, val in config.items():
+                if key == CREDENTIALS_KEY and config[key]:
+                    credentials[credentials_ref] = config[key]
+                    config[key] = credentials_ref
+                    return
+                if isinstance(val, dict):
+                    unresolve(val)
+
+        unresolve(ds_config_copy)
+
+        return ds_config_copy, credentials
+
     def resolve_pattern(self, ds_name: str) -> dict[str, Any]:
         """Resolve dataset patterns and return resolved configurations based on the existing patterns."""
         matched_pattern = self.match_pattern(ds_name)
diff --git a/kedro/io/core.py b/kedro/io/core.py
index c83e77c7a6..1e518d5c7a 100644
--- a/kedro/io/core.py
+++ b/kedro/io/core.py
@@ -15,6 +15,7 @@
 from datetime import datetime, timezone
 from functools import partial, wraps
 from glob import iglob
+from inspect import getcallargs
 from operator import attrgetter
 from pathlib import Path, PurePath, PurePosixPath
 from typing import (
@@ -57,6 +58,7 @@
     "s3a",
     "s3n",
 )
+TYPE_KEY = "type"
 
 
 class DatasetError(Exception):
@@ -210,6 +212,59 @@ def from_config(
             ) from err
         return dataset
 
+    def to_config(self) -> dict[str, Any]:
+        """Converts the dataset instance into a dictionary-based configuration for
+        serialization. Ensures that any subclass-specific details are handled, with
+        additional logic for versioning and caching implemented for `CachedDataset`.
+
+        Adds a key for the dataset's type using its module and class name, and
+        includes the initialization arguments.
+
+        For `CachedDataset`, it extracts the underlying dataset's configuration,
+        handles the `versioned` flag and removes unnecessary metadata. It also
+        ensures the embedded dataset's configuration is appropriately flattened
+        or transformed.
+
+        If the dataset has a version key, it sets the `versioned` flag in the
+        configuration.
+
+        Removes the `metadata` key from the configuration if present.
+
+        Returns:
+            A dictionary containing the dataset's type and initialization arguments.
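+
+        Example (an illustrative sketch; the filepath and save_args values are
+        assumed):
+        ::
+
+            >>> from kedro_datasets.pandas import CSVDataset
+            >>>
+            >>> dataset = CSVDataset(filepath="cars.csv", save_args={"index": False})
+            >>> config = dataset.to_config()
+            >>> config["type"]
+            'kedro_datasets.pandas.csv_dataset.CSVDataset'
+            >>> config["filepath"]
+            'cars.csv'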
+        """
+        return_config: dict[str, Any] = {
+            f"{TYPE_KEY}": f"{type(self).__module__}.{type(self).__name__}"
+        }
+
+        if self._init_args:  # type: ignore[attr-defined]
+            self._init_args.pop("self", None)  # type: ignore[attr-defined]
+            return_config.update(self._init_args)  # type: ignore[attr-defined]
+
+        if type(self).__name__ == "CachedDataset":
+            cached_ds = return_config.pop("dataset")
+            cached_ds_return_config: dict[str, Any] = {}
+            if isinstance(cached_ds, dict):
+                cached_ds_return_config = cached_ds
+            elif isinstance(cached_ds, AbstractDataset):
+                cached_ds_return_config = cached_ds.to_config()
+            if VERSIONED_FLAG_KEY in cached_ds_return_config:
+                return_config[VERSIONED_FLAG_KEY] = cached_ds_return_config.pop(
+                    VERSIONED_FLAG_KEY
+                )
+            # Pop metadata from configuration
+            cached_ds_return_config.pop("metadata", None)
+            return_config["dataset"] = cached_ds_return_config
+
+        # Set `versioned` key if version present in the dataset
+        if return_config.pop(VERSION_KEY, None):
+            return_config[VERSIONED_FLAG_KEY] = True
+
+        # Pop metadata from configuration
+        return_config.pop("metadata", None)
+
+        return return_config
+
     @property
     def _logger(self) -> logging.Logger:
         return logging.getLogger(__name__)
@@ -290,11 +345,32 @@ def save(self: Self, data: _DI) -> None:
         return save
 
     def __init_subclass__(cls, **kwargs: Any) -> None:
-        """Decorate the `load` and `save` methods provided by the class.
+        """Customizes the behavior of subclasses of AbstractDataset during
+        their creation. This method is automatically invoked when a subclass
+        of AbstractDataset is defined.
 
+        Decorates the `load` and `save` methods provided by the class.
         If `_load` or `_save` are defined, alias them as a prerequisite.
         """
+
+        # Save the original __init__ method of the subclass
+        init_func: Callable = cls.__init__
+
+        @wraps(init_func)
+        def new_init(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
+            """Executes the original __init__, then saves the arguments used
+            to initialize the instance.
+            """
+            # Call the original __init__ method
+            init_func(self, *args, **kwargs)
+            # Capture and save the arguments passed to the original __init__
+            self._init_args = getcallargs(init_func, self, *args, **kwargs)
+
+        # Replace the subclass's __init__ with new_init: a hook for subclasses
+        # to capture initialization arguments and save them in the
+        # AbstractDataset._init_args field
+        cls.__init__ = new_init  # type: ignore[method-assign]
+
         super().__init_subclass__(**kwargs)
 
         if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"):
@@ -493,14 +569,14 @@ def parse_dataset_definition(
     config = copy.deepcopy(config)
 
     # TODO: remove when removing old catalog as moved to KedroDataCatalog
-    if "type" not in config:
+    if TYPE_KEY not in config:
         raise DatasetError(
             "'type' is missing from dataset catalog configuration."
             "\nHint: If this catalog entry is intended for variable interpolation, "
             "make sure that the top level key is preceded by an underscore."
         )
 
-    dataset_type = config.pop("type")
+    dataset_type = config.pop(TYPE_KEY)
     class_obj = None
     if isinstance(dataset_type, str):
         if len(dataset_type.strip(".")) != len(dataset_type):
diff --git a/kedro/io/kedro_data_catalog.py b/kedro/io/kedro_data_catalog.py
index 9555cf1f69..33128fd809 100644
--- a/kedro/io/kedro_data_catalog.py
+++ b/kedro/io/kedro_data_catalog.py
@@ -18,6 +18,7 @@
 
 from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns
 from kedro.io.core import (
+    TYPE_KEY,
     AbstractDataset,
     AbstractVersionedDataset,
     CatalogProtocol,
@@ -28,7 +29,7 @@
     _validate_versions,
     generate_timestamp,
 )
-from kedro.io.memory_dataset import MemoryDataset
+from kedro.io.memory_dataset import MemoryDataset, _is_memory_dataset
 from kedro.utils import _format_rich, _has_rich_handler
 
 
@@ -97,7 +98,7 @@ def __init__(
             >>> catalog = KedroDataCatalog(datasets={"cars": cars})
         """
         self._config_resolver = config_resolver or CatalogConfigResolver()
-        self._datasets = datasets or {}
+        self._datasets: dict[str, AbstractDataset] = datasets or {}
         self._lazy_datasets: dict[str, _LazyDataset] = {}
         self._load_versions, self._save_version = _validate_versions(
             datasets, load_versions or {}, save_version
@@ -369,6 +370,74 @@ class to be loaded is specified with the key ``type`` and their
             config_resolver=config_resolver,
         )
 
+    def to_config(
+        self,
+    ) -> tuple[
+        dict[str, dict[str, Any]],
+        dict[str, dict[str, Any]],
+        dict[str, str | None],
+        str | None,
+    ]:
+        """Converts the `KedroDataCatalog` instance into a configuration format suitable for
+        serialization. This includes datasets, credentials, and versioning information.
+
+        This method is only applicable to catalogs that contain datasets initialized with static,
+        primitive parameters. For example, it will work fine if one passes credentials as a
+        dictionary to `GBQQueryDataset`, but not as a `google.auth.credentials.Credentials` object.
+        See https://github.com/kedro-org/kedro-plugins/issues/950 for details.
+
+        Returns:
+            A tuple containing:
+                catalog: A dictionary mapping dataset names to their unresolved configurations,
+                    excluding in-memory datasets.
+                credentials: A dictionary of unresolved credentials extracted from dataset configurations.
+                load_versions: A dictionary mapping dataset names to specific versions to be loaded,
+                    or `None` if no version is set.
+                save_version: A global version identifier for saving datasets, or `None` if not specified.
+
+        Example:
+        ::
+
+            >>> from kedro.io import KedroDataCatalog
+            >>> from kedro_datasets.pandas import CSVDataset
+            >>>
+            >>> cars = CSVDataset(
+            >>>     filepath="cars.csv",
+            >>>     load_args=None,
+            >>>     save_args={"index": False}
+            >>> )
+            >>> catalog = KedroDataCatalog(datasets={"cars": cars})
+            >>>
+            >>> config, credentials, load_versions, save_version = catalog.to_config()
+            >>>
+            >>> new_catalog = KedroDataCatalog.from_config(config, credentials, load_versions, save_version)
+        """
+        catalog: dict[str, dict[str, Any]] = {}
+        credentials: dict[str, dict[str, Any]] = {}
+        load_versions: dict[str, str | None] = {}
+
+        for ds_name, ds in self._lazy_datasets.items():
+            if _is_memory_dataset(ds.config.get(TYPE_KEY, "")):
+                continue
+            unresolved_config, unresolved_credentials = (
+                self._config_resolver.unresolve_credentials(ds_name, ds.config)
+            )
+            catalog[ds_name] = unresolved_config
+            credentials.update(unresolved_credentials)
+            load_versions[ds_name] = self._load_versions.get(ds_name, None)
+
+        for ds_name, ds in self._datasets.items():  # type: ignore[assignment]
+            if _is_memory_dataset(ds):  # type: ignore[arg-type]
+                continue
+            resolved_config = ds.to_config()  # type: ignore[attr-defined]
+            unresolved_config, unresolved_credentials = (
+                self._config_resolver.unresolve_credentials(ds_name, resolved_config)
+            )
+            catalog[ds_name] = unresolved_config
+            credentials.update(unresolved_credentials)
+            load_versions[ds_name] = self._load_versions.get(ds_name, None)
+
+        return catalog, credentials, load_versions, self._save_version
+
     @staticmethod
     def _validate_dataset_config(ds_name: str, ds_config: Any) -> None:
         if not isinstance(ds_config, dict):
diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py
index 1e8eef8452..2fdedf29b5 100644
--- a/kedro/io/memory_dataset.py
+++ b/kedro/io/memory_dataset.py
@@ -140,3 +140,13 @@ def _copy_with_mode(data: Any, copy_mode: str) -> Any:
         )
 
     return copied_data
+
+
+def _is_memory_dataset(ds_or_type: AbstractDataset | str) -> bool:
+    """Check if the given dataset or string type is a MemoryDataset."""
+    if isinstance(ds_or_type, MemoryDataset):
+        return True
+    if isinstance(ds_or_type, str):
+        return ds_or_type in {"MemoryDataset", "kedro.io.memory_dataset.MemoryDataset"}
+
+    return False
diff --git a/tests/io/conftest.py b/tests/io/conftest.py
index ce466469dd..1596620100 100644
--- a/tests/io/conftest.py
+++ b/tests/io/conftest.py
@@ -73,6 +73,52 @@ def correct_config(filepath):
     }
 
 
+@pytest.fixture
+def correct_config_versioned(filepath):
+    return {
+        "catalog": {
+            "boats": {
+                "type": "pandas.CSVDataset",
+                "filepath": filepath,
+                "versioned": True,
+            },
+            "cars": {
+                "type": "pandas.CSVDataset",
+                "filepath": "s3://test_bucket/test_file.csv",
+                "credentials": "cars_credentials",
+            },
+            "cars_ibis": {
+                "type": "ibis.FileDataset",
+                "filepath": "cars_ibis.csv",
+                "file_format": "csv",
+                "table_name": "cars",
+                "connection": {"backend": "duckdb", "database": "company.db"},
+                "load_args": {"sep": ",", "nullstr": "#NA"},
+                "save_args": {"sep": ",", "nullstr": "#NA"},
+            },
+            "cached_ds": {
+                "type": "kedro.io.cached_dataset.CachedDataset",
+                "versioned": True,
+                "dataset": {
+                    "type": "pandas.CSVDataset",
+                    "filepath": "cached_ds.csv",
+                    "credentials": "cached_ds_credentials",
+                },
+                "copy_mode": None,
+            },
+            "parameters": {
+                "type": "kedro.io.memory_dataset.MemoryDataset",
+                "data": [4, 5, 6],
+                "copy_mode": None,
+            },
+        },
+        "credentials": {
+            "cars_credentials": {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"},
+            "cached_ds_credentials": {"key": "KEY", "secret": "SECRET"},
+        },
+    }
+
+
 @pytest.fixture
 def correct_config_with_nested_creds(correct_config):
     correct_config["catalog"]["cars"]["credentials"] = {
diff --git a/tests/io/test_catalog_config_resolver.py b/tests/io/test_catalog_config_resolver.py
new file mode 100644
index 0000000000..db5ee6741c
--- /dev/null
+++ b/tests/io/test_catalog_config_resolver.py
@@ -0,0 +1,38 @@
+from kedro.io import CatalogConfigResolver
+
+
+class TestCatalogConfigResolver:
+    def test_unresolve_credentials(self, correct_config):
+        """Test unresolving dataset credentials back to the original format."""
+        config = correct_config["catalog"]
+        credentials = correct_config["credentials"]
+        resolved_configs = CatalogConfigResolver.resolve_credentials(
+            config, credentials
+        )
+
+        unresolved_config, unresolved_credentials = (
+            CatalogConfigResolver.unresolve_credentials(
+                cred_name="s3", ds_config=resolved_configs
+            )
+        )
+        assert config == unresolved_config
+        assert credentials == unresolved_credentials
+
+    def test_unresolve_credentials_two_keys(self, correct_config):
+        """Test unresolving dataset credentials back to the original format when
+        two credentials keys are provided."""
+        config = correct_config["catalog"]
+        credentials = correct_config["credentials"]
+
+        resolved_configs = CatalogConfigResolver.resolve_credentials(
+            config, credentials
+        )
+        resolved_configs["cars"]["metadata"] = {"credentials": {}}
+
+        unresolved_config, unresolved_credentials = (
+            CatalogConfigResolver.unresolve_credentials(
+                cred_name="s3", ds_config=resolved_configs
+            )
+        )
+        unresolved_config["cars"].pop("metadata")
+        assert config == unresolved_config
+        assert credentials == unresolved_credentials
diff --git a/tests/io/test_kedro_data_catalog.py b/tests/io/test_kedro_data_catalog.py
index efd5a8a68e..e6ffbf88aa 100644
--- a/tests/io/test_kedro_data_catalog.py
+++ b/tests/io/test_kedro_data_catalog.py
@@ -10,6 +10,7 @@
 from pandas.testing import assert_frame_equal
 
 from kedro.io import (
+    CachedDataset,
     DatasetAlreadyExistsError,
     DatasetError,
     DatasetNotFoundError,
@@ -294,6 +295,67 @@ def test_release(self, data_catalog):
         """Test release is called without errors"""
         data_catalog.release("test")
 
+    class TestKedroDataCatalogToConfig:
+        def test_to_config(self, correct_config_versioned, dataset, filepath):
+            """Test dumping the catalog config"""
+            config = correct_config_versioned["catalog"]
+            credentials = correct_config_versioned["credentials"]
+            catalog = KedroDataCatalog.from_config(config, credentials)
+            catalog["resolved_ds"] = dataset
+            catalog["memory_ds"] = [1, 2, 3]
+            catalog["params:a.b"] = {"abc": "def"}
+            # Materialize cached_ds
+            _ = catalog["cached_ds"]
+
+            version = Version(
+                load="fake_load_version.csv",  # load exact version
+                save=None,  # save version not set, defaults to generated timestamp
+            )
+            versioned_dataset = CSVDataset(
+                filepath="shuttles.csv", version=version, metadata=[1, 2, 3]
+            )
+            cached_versioned_dataset = CachedDataset(dataset=versioned_dataset)
+            catalog["cached_versioned_dataset"] = cached_versioned_dataset
+
+            catalog_config, catalog_credentials, load_version, save_version = (
+                catalog.to_config()
+            )
+
+            expected_config = {
+                "resolved_ds": {
+                    "type": "kedro_datasets.pandas.csv_dataset.CSVDataset",
+                    "filepath": filepath,
+                    "save_args": {"index": False},
+                    "load_args": None,
+                    "credentials": None,
+                    "fs_args": None,
+                },
+                "cached_versioned_dataset": {
+                    "type": "kedro.io.cached_dataset.CachedDataset",
+                    "copy_mode": None,
+                    "versioned": True,
+                    "dataset": {
+                        "type": "kedro_datasets.pandas.csv_dataset.CSVDataset",
+                        "filepath": "shuttles.csv",
+                        "load_args": None,
+                        "save_args": None,
+                        "credentials": None,
+                        "fs_args": None,
+                    },
+                },
+            }
+            expected_config.update(config)
+            expected_config.pop("parameters", None)
+
+            assert catalog_config == expected_config
+            assert catalog_credentials == credentials
+            # Load version is set only for cached_versioned_dataset
+            assert catalog._load_versions == {
+                "cached_versioned_dataset": "fake_load_version.csv"
+            }
+            # Save version is not None and set to default
+            assert catalog._save_version
+
     class TestKedroDataCatalogFromConfig:
         def test_from_correct_config(self, data_catalog_from_config, dummy_dataframe):
             """Test populating the data catalog from config"""
diff --git a/tests/io/test_memory_dataset.py b/tests/io/test_memory_dataset.py
index c2dbe56925..5a85400d66 100644
--- a/tests/io/test_memory_dataset.py
+++ b/tests/io/test_memory_dataset.py
@@ -3,11 +3,13 @@
 import numpy as np
 import pandas as pd
 import pytest
+from kedro_datasets.pandas import CSVDataset
 
 from kedro.io import DatasetError, MemoryDataset
 from kedro.io.memory_dataset import (
     _copy_with_mode,
     _infer_copy_mode,
+    _is_memory_dataset,
 )
 
 
@@ -233,3 +235,17 @@ class DataFrame:
     data = DataFrame()
     copy_mode = _infer_copy_mode(data)
 
     assert copy_mode == "assign"
+
+
+@pytest.mark.parametrize(
+    "ds_or_type,expected_result",
+    [
+        ("MemoryDataset", True),
+        ("kedro.io.memory_dataset.MemoryDataset", True),
+        ("NotMemoryDataset", False),
+        (MemoryDataset(data=""), True),
+        (CSVDataset(filepath="abc.csv"), False),
+    ],
+)
+def test_is_memory_dataset(ds_or_type, expected_result):
+    assert _is_memory_dataset(ds_or_type) == expected_result