From cfcb5a10fb98bb8f39d05ce53f6b7d53e17b3f3c Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 19 Dec 2024 13:02:26 +0100 Subject: [PATCH 1/2] first version of ibis dataset base transformations --- dlt/__init__.py | 4 + dlt/common/destination/reference.py | 2 + dlt/destinations/transformations/__init__.py | 94 ++++++++++++++++++++ dlt/pipeline/pipeline.py | 9 ++ tests/load/test_transformation.py | 39 ++++++++ 5 files changed, 148 insertions(+) create mode 100644 dlt/destinations/transformations/__init__.py create mode 100644 tests/load/test_transformation.py diff --git a/dlt/__init__.py b/dlt/__init__.py index 328817efd2..c1c49150f9 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -43,6 +43,8 @@ from dlt.pipeline import progress from dlt import destinations +from dlt.destinations.transformations import transformation, transformation_group + pipeline = _pipeline current = _current mark = _mark @@ -79,6 +81,8 @@ "TCredentials", "sources", "destinations", + "transformation", + "transformation_group", ] # verify that no injection context was created diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 827034ddca..df293c65d4 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -15,6 +15,7 @@ Union, List, ContextManager, + runtime_checkable, Dict, Any, TypeVar, @@ -483,6 +484,7 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe return [] +@runtime_checkable class SupportsReadableRelation(Protocol): """A readable relation retrieved from a destination that supports it""" diff --git a/dlt/destinations/transformations/__init__.py b/dlt/destinations/transformations/__init__.py new file mode 100644 index 0000000000..a09cdef2d8 --- /dev/null +++ b/dlt/destinations/transformations/__init__.py @@ -0,0 +1,94 @@ +from typing import Callable, Literal, Union, Any, Generator, List, TYPE_CHECKING, Iterable + +from dataclasses import dataclass +from functools import wraps + +from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation + + +TTransformationMaterialization = Literal["table", "view"] +TTransformationWriteDisposition = Literal["replace", "append"] + +TTransformationFunc = Callable[[SupportsReadableDataset], SupportsReadableRelation] + +TTransformationGroupFunc = Callable[[], List[TTransformationFunc]] + + +def transformation( + table_name: str, + materialization: TTransformationMaterialization = "table", + write_disposition: TTransformationWriteDisposition = "replace", +) -> Callable[[TTransformationFunc], TTransformationFunc]: + def decorator(func: TTransformationFunc) -> TTransformationFunc: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> SupportsReadableRelation: + return func(*args, **kwargs) + + # save the arguments to the function + wrapper.__transformation_args__ = { # type: ignore + "table_name": table_name, + "materialization": materialization, + "write_disposition": write_disposition, + } + + return wrapper + + return decorator + + +def transformation_group( + name: str, +) -> Callable[[TTransformationGroupFunc], TTransformationGroupFunc]: + def decorator(func: TTransformationGroupFunc) -> TTransformationGroupFunc: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> List[TTransformationFunc]: + return func(*args, **kwargs) + + func.__transformation_group_args__ = { # type: ignore + "name": name, + } + return wrapper + + return decorator + + +def run_transformations( + dataset: SupportsReadableDataset, + transformations: Union[TTransformationFunc, List[TTransformationFunc]], +) -> None: + if not isinstance(transformations, Iterable): + transformations = [transformations] + + # TODO: fix typing + with dataset.sql_client as client: # type: ignore + for transformation in transformations: + # get transformation settings + table_name = transformation.__transformation_args__["table_name"] # type: ignore + materialization = transformation.__transformation_args__["materialization"] # type: ignore + write_disposition = transformation.__transformation_args__["write_disposition"] # type: ignore + table_name = client.make_qualified_table_name(table_name) + + # get relation from transformation + relation = transformation(dataset) + if not isinstance(relation, SupportsReadableRelation): + raise ValueError( + f"Transformation {transformation.__name__} did not return a ReadableRelation" + ) + + # materialize result + select_clause = relation.query + + if write_disposition == "replace": + client.execute( + f"CREATE OR REPLACE {materialization} {table_name} AS {select_clause}" + ) + elif write_disposition == "append" and materialization == "table": + try: + client.execute(f"INSERT INTO {table_name} {select_clause}") + except Exception: + client.execute(f"CREATE TABLE {table_name} AS {select_clause}") + else: + raise ValueError( + f"Write disposition {write_disposition} is not supported for " + f"materialization {materialization}" + ) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 74466a09e4..c403c68e37 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -152,6 +152,8 @@ from dlt.common.storages.load_package import TLoadPackageState from dlt.pipeline.helpers import refresh_source +from dlt.destinations.transformations import TTransformationFunc + def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]: def decorator(f: TFun) -> TFun: @@ -1770,3 +1772,10 @@ def dataset( schema=schema, dataset_type=dataset_type, ) + + def transform( + self, transformations: Union[TTransformationFunc, List[TTransformationFunc]] + ) -> None: + from dlt.destinations.transformations import run_transformations + + run_transformations(self.dataset(), transformations) diff --git a/tests/load/test_transformation.py b/tests/load/test_transformation.py new file mode 100644 index 0000000000..9cbee73110 --- /dev/null +++ b/tests/load/test_transformation.py @@ -0,0 +1,39 @@ +import dlt + +from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation + +from functools import reduce + + +def test_simple_transformation() -> None: + # load some stuff into items table + + @dlt.resource(table_name="items") + def items_resource(): + for i in range(10): + yield {"id": i, "value": i * 2} + + p = dlt.pipeline("test_pipeline", destination="duckdb", dataset_name="test_dataset") + p.run(items_resource) + + print(p.dataset().items.df()) + + @dlt.transformation(table_name="quadrupled_items") + def simple_transformation(dataset: SupportsReadableDataset) -> SupportsReadableRelation: + items_table = dataset.items + return items_table.mutate(quadruple_id=items_table.id * 4).select("id", "quadruple_id") + + @dlt.transformation(table_name="aggregated_items") + def aggregate_transformation(dataset: SupportsReadableDataset) -> SupportsReadableRelation: + items_table = dataset.items + return items_table.aggregate(sum_id=items_table.id.sum(), value_sum=items_table.value.sum()) + + # we run two transformations + p.transform([simple_transformation, aggregate_transformation]) + + # check table with quadrupled ids + assert list(p.dataset().quadrupled_items.df()["quadruple_id"]) == [i * 4 for i in range(10)] + + # check aggregated table for both fields + assert p.dataset().aggregated_items.fetchone()[0] == reduce(lambda a, b: a + b, range(10)) + assert p.dataset().aggregated_items.fetchone()[1] == (reduce(lambda a, b: a + b, range(10)) * 2) From 3e352aa4372e29c2fffe4a2fa3617248bd5a13fd Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 19 Dec 2024 13:50:56 +0100 Subject: [PATCH 2/2] add direct execution from readable relation --- dlt/destinations/transformations/__init__.py | 73 ++++++++++++++------ dlt/pipeline/pipeline.py | 15 +++- tests/load/test_transformation.py | 6 ++ 3 files changed, 72 insertions(+), 22 deletions(-) diff --git a/dlt/destinations/transformations/__init__.py b/dlt/destinations/transformations/__init__.py index a09cdef2d8..a09000eece 100644 --- a/dlt/destinations/transformations/__init__.py +++ b/dlt/destinations/transformations/__init__.py @@ -1,4 +1,15 @@ -from typing import Callable, Literal, Union, Any, Generator, List, TYPE_CHECKING, Iterable +from typing import ( + Callable, + Literal, + Union, + Any, + Generator, + List, + TYPE_CHECKING, + Iterable, + Optional, + Any, +) from dataclasses import dataclass from functools import wraps @@ -52,15 +63,48 @@ def wrapper(*args: Any, **kwargs: Any) -> List[TTransformationFunc]: return decorator +def execute_transformation( + select_clause: str, + client: Any, + table_name: Optional[str] = None, + write_disposition: Optional[str] = None, + materialization: Optional[str] = None, +) -> None: + if write_disposition == "replace": + client.execute(f"CREATE OR REPLACE {materialization} {table_name} AS {select_clause}") + elif write_disposition == "append" and materialization == "table": + try: + client.execute(f"INSERT INTO {table_name} {select_clause}") + except Exception: + client.execute(f"CREATE TABLE {table_name} AS {select_clause}") + else: + raise ValueError( + f"Write disposition {write_disposition} is not supported for " + f"materialization {materialization}" + ) + + def run_transformations( dataset: SupportsReadableDataset, - transformations: Union[TTransformationFunc, List[TTransformationFunc]], + transformations: Union[ + TTransformationFunc, List[TTransformationFunc], SupportsReadableRelation + ], + *, + table_name: Optional[str] = None, + write_disposition: Optional[str] = None, + materialization: Optional[str] = None, ) -> None: - if not isinstance(transformations, Iterable): - transformations = [transformations] - - # TODO: fix typing with dataset.sql_client as client: # type: ignore + if isinstance(transformations, SupportsReadableRelation): + execute_transformation( + transformations.query, client, table_name, write_disposition, materialization + ) + return + + if not isinstance(transformations, Iterable): + transformations = [transformations] + + # TODO: fix typing for transformation in transformations: # get transformation settings table_name = transformation.__transformation_args__["table_name"] # type: ignore @@ -78,17 +122,6 @@ def run_transformations( # materialize result select_clause = relation.query - if write_disposition == "replace": - client.execute( - f"CREATE OR REPLACE {materialization} {table_name} AS {select_clause}" - ) - elif write_disposition == "append" and materialization == "table": - try: - client.execute(f"INSERT INTO {table_name} {select_clause}") - except Exception: - client.execute(f"CREATE TABLE {table_name} AS {select_clause}") - else: - raise ValueError( - f"Write disposition {write_disposition} is not supported for " - f"materialization {materialization}" - ) + execute_transformation( + select_clause, client, table_name, write_disposition, materialization + ) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index c403c68e37..28ba4e4425 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1774,8 +1774,19 @@ def dataset( ) def transform( - self, transformations: Union[TTransformationFunc, List[TTransformationFunc]] + self, + transformations: Union[TTransformationFunc, List[TTransformationFunc]], + *, + table_name: Optional[str] = None, + write_disposition: Optional[str] = "replace", + materialization: Optional[str] = "table", ) -> None: from dlt.destinations.transformations import run_transformations - run_transformations(self.dataset(), transformations) + run_transformations( + self.dataset(), + transformations, + table_name=table_name, + write_disposition=write_disposition, + materialization=materialization, + ) diff --git a/tests/load/test_transformation.py b/tests/load/test_transformation.py index 9cbee73110..bff8d1de50 100644 --- a/tests/load/test_transformation.py +++ b/tests/load/test_transformation.py @@ -37,3 +37,9 @@ def aggregate_transformation(dataset: SupportsReadableDataset) -> SupportsReadab # check aggregated table for both fields assert p.dataset().aggregated_items.fetchone()[0] == reduce(lambda a, b: a + b, range(10)) assert p.dataset().aggregated_items.fetchone()[1] == (reduce(lambda a, b: a + b, range(10)) * 2) + + # check simple transformation function + items_table = p.dataset().items + p.transform(items_table.mutate(new_col=items_table.id), table_name="direct") + + print(p.dataset().direct.df())