This repository has been archived by the owner on Mar 11, 2024. It is now read-only.

Commit

🐛 DeltaLake - fix loading single partitions, implement loading all partitions in one go
danielgafni committed Jan 19, 2024
1 parent a902ebd commit 2e1516a
Showing 3 changed files with 151 additions and 4 deletions.
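
For context, a minimal sketch (not part of the commit; asset names are hypothetical, and a `PolarsDeltaIOManager` is assumed to be bound as the IO manager) of the single-partition behavior this commit fixes: with native DeltaLake partitioning enabled via `partition_by`, a downstream partitioned asset should receive only the rows of its own partition, because loading now filters the table by the partition column.

```python
import polars as pl
from dagster import OpExecutionContext, StaticPartitionsDefinition, asset

partitions_def = StaticPartitionsDefinition(["a", "b"])


@asset(partitions_def=partitions_def, metadata={"partition_by": "partition"})
def upstream(context: OpExecutionContext) -> pl.DataFrame:
    # every asset partition is written into the same DeltaLake table,
    # natively partitioned by the "partition" column
    return pl.DataFrame({"value": [1, 2, 3]}).with_columns(
        pl.lit(context.partition_key).alias("partition")
    )


@asset(partitions_def=partitions_def)
def downstream(context: OpExecutionContext, upstream: pl.DataFrame) -> None:
    # with the fix, only the rows of the current partition are loaded
    assert upstream["partition"].unique().to_list() == [context.partition_key]
```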
36 changes: 36 additions & 0 deletions dagster_polars/io_managers/base.py
@@ -28,6 +28,7 @@
_check as check,
)
from dagster._annotations import experimental
from dagster._core.storage.upath_io_manager import is_dict_type
from pydantic.fields import Field, PrivateAttr

from dagster_polars.io_managers.utils import get_polars_metadata
@@ -187,6 +188,40 @@ def scan_df_from_path(
) -> Union[pl.LazyFrame, LazyFrameWithMetadata]:
...

# tmp fix until https://github.com/dagster-io/dagster/pull/19294 is merged
def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]:
# If no asset key, we are dealing with an op output which is always non-partitioned
if not context.has_asset_key or not context.has_asset_partitions:
path = self._get_path(context)
return self._load_single_input(path, context)
else:
asset_partition_keys = context.asset_partition_keys
if len(asset_partition_keys) == 0:
return None
elif len(asset_partition_keys) == 1:
paths = self._get_paths_for_partitions(context)
check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}")
path = next(iter(paths.values()))
backcompat_paths = self._get_multipartition_backcompat_paths(context)
backcompat_path = None if not backcompat_paths else next(iter(backcompat_paths.values()))

return self._load_partition_from_path(
context=context,
partition_key=asset_partition_keys[0],
path=path,
backcompat_path=backcompat_path,
)
else: # we are dealing with multiple partitions of an asset
type_annotation = context.dagster_type.typing_type
if type_annotation != Any and not is_dict_type(type_annotation):
check.failed(
"Loading an input that corresponds to multiple partitions, but the"
" type annotation on the op input is not a dict, Dict, Mapping, or"
f" Any: is '{type_annotation}'."
)

return self._load_multiple_inputs(context)

def dump_to_path(
self,
context: OutputContext,
@@ -252,6 +287,7 @@ def load_from_path(
# otherwise we would have been dealing with a multi-partition key
# which is not straightforward to filter by
if partition_by is not None and isinstance(partition_by, str):
context.log.debug(f"Filtering by {partition_by}=={partition_key}")
ldf = ldf.filter(pl.col(partition_by) == partition_key)

if context.dagster_type.typing_type in POLARS_EAGER_FRAME_ANNOTATIONS:
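
As a rough guide to the three branches in `load_input` above (a sketch under assumed names, not code from this commit; it assumes an upstream partitioned asset like the `upstream` sketch near the top): zero or one partition in scope loads as a plain frame, while multiple partitions must be annotated as a dict keyed by partition key, which is what `is_dict_type` checks.

```python
from typing import Dict

import polars as pl
from dagster import StaticPartitionsDefinition, asset

partitions_def = StaticPartitionsDefinition(["a", "b"])


# one partition key in scope -> _load_partition_from_path, a plain frame annotation is fine
@asset(partitions_def=partitions_def)
def same_partition(upstream: pl.DataFrame) -> None:
    ...


# a non-partitioned asset depends on all upstream partitions by default
# -> _load_multiple_inputs, so the annotation must be dict-like (or Any)
@asset
def all_partitions(upstream: Dict[str, pl.LazyFrame]) -> None:
    assert set(upstream) == {"a", "b"}
```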
76 changes: 73 additions & 3 deletions dagster_polars/io_managers/delta.py
@@ -1,13 +1,15 @@
import json
from enum import Enum
from pprint import pformat
from typing import TYPE_CHECKING, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import dagster._check as check
import polars as pl
from dagster import InputContext, MetadataValue, OutputContext
from dagster._annotations import experimental
from dagster._core.storage.upath_io_manager import is_dict_type

from dagster_polars.types import LazyFrameWithMetadata, StorageMetadata
from dagster_polars.types import DataFrameWithMetadata, LazyFrameWithMetadata, StorageMetadata

try:
from deltalake import DeltaTable
@@ -22,6 +24,8 @@

DAGSTER_POLARS_STORAGE_METADATA_SUBDIR = ".dagster_polars_metadata"

SINGLE_LOADING_TYPES = (pl.DataFrame, pl.LazyFrame, LazyFrameWithMetadata, DataFrameWithMetadata)


class DeltaWriteMode(str, Enum):
error = "error"
@@ -37,7 +41,11 @@ class PolarsDeltaIOManager(BasePolarsUPathIOManager):
Features:
- All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`.
- All read/write options can be set via corresponding metadata or config parameters (metadata takes precedence).
- Supports native DeltaLake partitioning by storing different asset partitions in the same DeltaLake table. To enable this behavior, set the `partition_by` metadata value or config parameter (it's passed to `delta_write_options` of `pl.DataFrame.write_delta`). Automatically filters loaded partitions, unless `MultiPartitionsDefinition` are used. In this case you are responsible for filtering the partitions in the downstream asset, as it's non-trivial to do so in the IOManager.
- Supports native DeltaLake partitioning by storing different asset partitions in the same DeltaLake table.
To enable this behavior, set the `partition_by` metadata value or config parameter (it's passed to `delta_write_options` of `pl.DataFrame.write_delta`).
Automatically filters loaded partitions, unless a `MultiPartitionsDefinition` is used.
In that case you are responsible for filtering the partitions in the downstream asset, as it's non-trivial to do so in the IOManager.
When all available asset partitions are requested, the whole table can be read in one go by annotating the input as `pl.DataFrame` or `pl.LazyFrame`.
- Supports writing/reading custom metadata to/from `.dagster_polars_metadata/<version>.json` file in the DeltaLake table directory.
Install `dagster-polars[delta]` to use this IOManager.
@@ -110,6 +118,68 @@ def downstream(upstream: LazyFramePartitions) -> pl.DataFrame:
overwrite_schema: bool = False
version: Optional[int] = None

# tmp fix until UPathIOManager supports this: added special handling for loading all partitions of an asset

def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]:
# If no asset key, we are dealing with an op output which is always non-partitioned
if not context.has_asset_key or not context.has_asset_partitions:
path = self._get_path(context)
return self._load_single_input(path, context)
else:
asset_partition_keys = context.asset_partition_keys
if len(asset_partition_keys) == 0:
return None
elif len(asset_partition_keys) == 1:
paths = self._get_paths_for_partitions(context)
check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}")
path = next(iter(paths.values()))
backcompat_paths = self._get_multipartition_backcompat_paths(context)
backcompat_path = None if not backcompat_paths else next(iter(backcompat_paths.values()))

return self._load_partition_from_path(
context=context,
partition_key=asset_partition_keys[0],
path=path,
backcompat_path=backcompat_path,
)
else: # we are dealing with multiple partitions of an asset
type_annotation = context.dagster_type.typing_type
if type_annotation == Any or is_dict_type(type_annotation):
return self._load_multiple_inputs(context)

# special case of loading the whole DeltaLake table at once
# when using AllPartitionMapping and native DeltaLake partitioning
elif (
context.upstream_output is not None
and context.upstream_output.metadata is not None
and context.upstream_output.metadata.get("partition_by") is not None
and type_annotation in SINGLE_LOADING_TYPES
and context.upstream_output.asset_info is not None
and context.upstream_output.asset_info.partitions_def is not None
and set(asset_partition_keys)
== set(
context.upstream_output.asset_info.partitions_def.get_partition_keys(
dynamic_partitions_store=context.instance
)
)
):
# load all partitions at once
return self.load_from_path(
context=context,
path=self.get_path_for_partition(
context=context,
partition=asset_partition_keys[0],  # any partition key works: with native partitioning they all resolve to the same table path
path=self._get_paths_for_partitions(context)[asset_partition_keys[0]],  # any partition key works
),
partition_key=None,
)
else:
check.failed(
"Loading an input that corresponds to multiple partitions, but the"
f" type annotation on the op input is not a dict, Dict, Mapping, one of {SINGLE_LOADING_TYPES},"
f" or Any: is '{type_annotation}'."
)

def dump_df_to_path(
self,
context: OutputContext,
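
The new branch above only fires when the input annotation is one of `SINGLE_LOADING_TYPES` and every partition of the upstream asset is requested. A hedged sketch of triggering it from the asset side (hypothetical names; `AllPartitionMapping` is the stock Dagster mapping that selects all upstream partitions, and it is also the default for a non-partitioned asset reading a partitioned one):

```python
import polars as pl
from dagster import AllPartitionMapping, AssetIn, asset


# upstream_partitioned is assumed to be written with metadata={"partition_by": "partition"},
# as in the tests below
@asset(ins={"upstream_partitioned": AssetIn(partition_mapping=AllPartitionMapping())})
def whole_table(upstream_partitioned: pl.DataFrame) -> None:
    # all partitions arrive as a single frame read straight from the Delta table,
    # instead of a Dict[str, pl.DataFrame]
    assert upstream_partitioned["partition"].n_unique() > 1
```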
43 changes: 42 additions & 1 deletion tests/test_polars_delta.py
@@ -6,6 +6,7 @@
import polars.testing as pl_testing
import pytest
from dagster import (
AssetExecutionContext,
AssetIn,
Config,
DagsterInstance,
@@ -211,7 +212,7 @@ def downstream_load_multiple_partitions(upstream_partitioned: Dict[str, pl.LazyF

assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys()

saved_path = None
saved_path = None # noqa

for partition_key in ["a", "b"]:
result = materialize(
@@ -235,6 +236,46 @@ def downstream_load_multiple_partitions(upstream_partitioned: Dict[str, pl.LazyF
],
)

@asset(io_manager_def=manager)
def downstream_load_multiple_partitions_as_single_df(upstream_partitioned: pl.DataFrame) -> None:
assert set(upstream_partitioned["partition"].unique()) == {"a", "b"}

materialize(
[
upstream_partitioned.to_source_asset(),
downstream_load_multiple_partitions_as_single_df,
],
)


def test_polars_delta_native_partitioning_loading_single_partition(
polars_delta_io_manager: PolarsDeltaIOManager, df_for_delta: pl.DataFrame
):
manager = polars_delta_io_manager
df = df_for_delta

partitions_def = StaticPartitionsDefinition(["a", "b"])

@asset(
io_manager_def=manager,
partitions_def=partitions_def,
metadata={"partition_by": "partition"},
)
def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame:
return df.with_columns(pl.lit(context.partition_key).alias("partition"))

@asset(io_manager_def=manager, partitions_def=partitions_def)
def downstream_partitioned(context: AssetExecutionContext, upstream_partitioned: pl.DataFrame) -> None:
partitions = upstream_partitioned["partition"].unique().to_list()
assert len(partitions) == 1
assert partitions[0] == context.partition_key

for partition_key in ["a", "b"]:
materialize(
[upstream_partitioned, downstream_partitioned],
partition_key=partition_key,
)


def test_polars_delta_time_travel(polars_delta_io_manager: PolarsDeltaIOManager, df_for_delta: pl.DataFrame):
manager = polars_delta_io_manager
