diff --git a/kedro-airflow/kedro_airflow/grouping.py b/kedro-airflow/kedro_airflow/grouping.py index 7ac1a8339..913d6d817 100644 --- a/kedro-airflow/kedro_airflow/grouping.py +++ b/kedro-airflow/kedro_airflow/grouping.py @@ -1,9 +1,16 @@ from __future__ import annotations -from kedro.io import CatalogProtocol +from typing import Any + +from kedro.io import DataCatalog from kedro.pipeline.node import Node from kedro.pipeline.pipeline import Pipeline +try: + from kedro.io import CatalogProtocol +except ImportError: # pragma: no cover + pass + def _is_memory_dataset(catalog, dataset_name: str) -> bool: if dataset_name not in catalog: @@ -11,7 +18,9 @@ def _is_memory_dataset(catalog, dataset_name: str) -> bool: return False -def get_memory_datasets(catalog: CatalogProtocol, pipeline: Pipeline) -> set[str]: +def get_memory_datasets( + catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline +) -> set[str]: """Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'.""" return { dataset_name @@ -21,7 +30,7 @@ def get_memory_datasets(catalog: CatalogProtocol, pipeline: Pipeline) -> set[str def create_adjacency_list( - catalog: CatalogProtocol, pipeline: Pipeline + catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline ) -> tuple[dict[str, set], dict[str, set]]: """ Builds adjacency list (adj_list) to search connected components - undirected graph, @@ -48,7 +57,7 @@ def create_adjacency_list( def group_memory_nodes( - catalog: CatalogProtocol, pipeline: Pipeline + catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline ) -> tuple[dict[str, list[Node]], dict[str, list[str]]]: """ Nodes that are connected through MemoryDatasets cannot be distributed across