diff --git a/dlt/__init__.py b/dlt/__init__.py
index 328817efd2..c1c49150f9 100644
--- a/dlt/__init__.py
+++ b/dlt/__init__.py
@@ -43,6 +43,8 @@ from dlt.pipeline import progress
 from dlt import destinations
 
+from dlt.destinations.transformations import transformation, transformation_group
+
 pipeline = _pipeline
 current = _current
 mark = _mark
@@ -79,6 +81,8 @@
     "TCredentials",
     "sources",
     "destinations",
+    "transformation",
+    "transformation_group",
 ]
 
 # verify that no injection context was created
diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index 827034ddca..df293c65d4 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -15,6 +15,7 @@
     Union,
     List,
     ContextManager,
+    runtime_checkable,
     Dict,
     Any,
     TypeVar,
@@ -483,6 +484,7 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRequest]:
         return []
 
 
+@runtime_checkable
 class SupportsReadableRelation(Protocol):
     """A readable relation retrieved from a destination that supports it"""
 
diff --git a/dlt/destinations/transformations/__init__.py b/dlt/destinations/transformations/__init__.py
new file mode 100644
index 0000000000..a09000eece
--- /dev/null
+++ b/dlt/destinations/transformations/__init__.py
@@ -0,0 +1,126 @@
+from typing import (
+    Callable,
+    Literal,
+    Union,
+    Any,
+    List,
+    Iterable,
+    Optional,
+)
+
+from functools import wraps
+
+from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation
+
+
+TTransformationMaterialization = Literal["table", "view"]
+TTransformationWriteDisposition = Literal["replace", "append"]
+
+TTransformationFunc = Callable[[SupportsReadableDataset], SupportsReadableRelation]
+
+TTransformationGroupFunc = Callable[[], List[TTransformationFunc]]
+
+
+def transformation(
+    table_name: str,
+    materialization: TTransformationMaterialization = "table",
+    write_disposition: TTransformationWriteDisposition = "replace",
+) -> Callable[[TTransformationFunc], TTransformationFunc]:
+    def decorator(func: TTransformationFunc) -> TTransformationFunc:
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> SupportsReadableRelation:
+            return func(*args, **kwargs)
+
+        # save the decorator arguments on the wrapper
+        wrapper.__transformation_args__ = {  # type: ignore
+            "table_name": table_name,
+            "materialization": materialization,
+            "write_disposition": write_disposition,
+        }
+
+        return wrapper
+
+    return decorator
+
+
+def transformation_group(
+    name: str,
+) -> Callable[[TTransformationGroupFunc], TTransformationGroupFunc]:
+    def decorator(func: TTransformationGroupFunc) -> TTransformationGroupFunc:
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> List[TTransformationFunc]:
+            return func(*args, **kwargs)
+
+        # store the group name on the wrapper, mirroring `transformation` above
+        wrapper.__transformation_group_args__ = {  # type: ignore
+            "name": name,
+        }
+        return wrapper
+
+    return decorator
+
+
+def execute_transformation(
+    select_clause: str,
+    client: Any,
+    table_name: Optional[str] = None,
+    write_disposition: Optional[str] = None,
+    materialization: Optional[str] = None,
+) -> None:
+    if write_disposition == "replace":
+        client.execute(f"CREATE OR REPLACE {materialization} {table_name} AS {select_clause}")
+    elif write_disposition == "append" and materialization == "table":
+        # append to an existing table, create it from the select if it does not exist yet
+        try:
+            client.execute(f"INSERT INTO {table_name} {select_clause}")
+        except Exception:
+            client.execute(f"CREATE TABLE {table_name} AS {select_clause}")
+    else:
+        raise ValueError(
+            f"Write disposition {write_disposition} is not supported for "
+            f"materialization {materialization}"
+        )
+
+
+def run_transformations(
+    dataset: SupportsReadableDataset,
+    transformations: Union[
+        TTransformationFunc, List[TTransformationFunc], SupportsReadableRelation
+    ],
+    *,
+    table_name: Optional[str] = None,
+    write_disposition: Optional[str] = None,
+    materialization: Optional[str] = None,
+) -> None:
+    with dataset.sql_client as client:  # type: ignore
+        # a readable relation is executed directly with the given settings
+        if isinstance(transformations, SupportsReadableRelation):
+            execute_transformation(
+                transformations.query, client, table_name, write_disposition, materialization
+            )
+            return
+
+        if not isinstance(transformations, Iterable):
+            transformations = [transformations]
+
+        # TODO: fix typing
+        for transformation_func in transformations:
+            # get transformation settings
+            table_name = transformation_func.__transformation_args__["table_name"]  # type: ignore
+            materialization = transformation_func.__transformation_args__["materialization"]  # type: ignore
+            write_disposition = transformation_func.__transformation_args__["write_disposition"]  # type: ignore
+            table_name = client.make_qualified_table_name(table_name)
+
+            # get relation from transformation
+            relation = transformation_func(dataset)
+            if not isinstance(relation, SupportsReadableRelation):
+                raise ValueError(
+                    f"Transformation {transformation_func.__name__} did not return a ReadableRelation"
+                )
+
+            # materialize the result of the relation's select
+            select_clause = relation.query
+
+            execute_transformation(
+                select_clause, client, table_name, write_disposition, materialization
+            )
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index 74466a09e4..28ba4e4425 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -152,6 +152,8 @@ from dlt.common.storages.load_package import TLoadPackageState
 from dlt.pipeline.helpers import refresh_source
 
+from dlt.destinations.transformations import TTransformationFunc
+
 
 def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]:
     def decorator(f: TFun) -> TFun:
@@ -1770,3 +1772,21 @@ def dataset(
             schema=schema,
             dataset_type=dataset_type,
         )
+
+    def transform(
+        self,
+        transformations: Union[TTransformationFunc, List[TTransformationFunc]],
+        *,
+        table_name: Optional[str] = None,
+        write_disposition: Optional[str] = "replace",
+        materialization: Optional[str] = "table",
+    ) -> None:
+        from dlt.destinations.transformations import run_transformations
+
+        run_transformations(
+            self.dataset(),
+            transformations,
+            table_name=table_name,
+            write_disposition=write_disposition,
+            materialization=materialization,
+        )
diff --git a/tests/load/test_transformation.py b/tests/load/test_transformation.py
new file mode 100644
index 0000000000..bff8d1de50
--- /dev/null
+++ b/tests/load/test_transformation.py
@@ -0,0 +1,45 @@
+import dlt
+
+from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation
+
+from functools import reduce
+
+
+def test_simple_transformation() -> None:
+    # load some example rows into the items table
+
+    @dlt.resource(table_name="items")
+    def items_resource():
+        for i in range(10):
+            yield {"id": i, "value": i * 2}
+
+    p = dlt.pipeline("test_pipeline", destination="duckdb", dataset_name="test_dataset")
+    p.run(items_resource)
+
+    print(p.dataset().items.df())
+
+    @dlt.transformation(table_name="quadrupled_items")
+    def simple_transformation(dataset: SupportsReadableDataset) -> SupportsReadableRelation:
+        items_table = dataset.items
+        return items_table.mutate(quadruple_id=items_table.id * 4).select("id", "quadruple_id")
+
+    @dlt.transformation(table_name="aggregated_items")
+    def aggregate_transformation(dataset: SupportsReadableDataset) -> SupportsReadableRelation:
+        items_table = dataset.items
+        return items_table.aggregate(sum_id=items_table.id.sum(), value_sum=items_table.value.sum())
+
+    # run two decorated transformations in one call
+    p.transform([simple_transformation, aggregate_transformation])
+
+    # check table with quadrupled ids
+    assert list(p.dataset().quadrupled_items.df()["quadruple_id"]) == [i * 4 for i in range(10)]
+
+    # check aggregated table for both fields
+    assert p.dataset().aggregated_items.fetchone()[0] == reduce(lambda a, b: a + b, range(10))
+    assert p.dataset().aggregated_items.fetchone()[1] == (reduce(lambda a, b: a + b, range(10)) * 2)
+
+    # check passing a readable relation directly with an explicit table name
+    items_table = p.dataset().items
+    p.transform(items_table.mutate(new_col=items_table.id), table_name="direct")
+
+    print(p.dataset().direct.df())
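
Usage sketch (not part of the diff): a minimal example of how the new decorator and
Pipeline.transform compose, following the test above. The pipeline name, the "duckdb"
destination, and the table names are illustrative; relation methods such as mutate()
come from the ibis-backed readable dataset used in the test.

import dlt
from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation

# load a small illustrative table
pipeline = dlt.pipeline("demo_pipeline", destination="duckdb", dataset_name="demo_data")
pipeline.run([{"id": i} for i in range(5)], table_name="numbers")

# a transformation takes the readable dataset and returns a readable relation;
# the decorator records the target table and how to materialize the result
@dlt.transformation(table_name="doubled_numbers", write_disposition="replace")
def doubled(dataset: SupportsReadableDataset) -> SupportsReadableRelation:
    numbers = dataset.numbers
    return numbers.mutate(doubled_id=numbers.id * 2)

# executes the relation's SELECT as "CREATE OR REPLACE TABLE ... AS ..." on the destination
pipeline.transform(doubled)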