create first version of dataset factory
sh-rp committed Oct 10, 2024
1 parent 47633c6 commit eb43626
Showing 6 changed files with 151 additions and 21 deletions.
2 changes: 2 additions & 0 deletions dlt/__init__.py
@@ -42,6 +42,7 @@
)
from dlt.pipeline import progress
from dlt import destinations
from dlt.destinations.dataset import dataset

pipeline = _pipeline
current = _current
@@ -79,6 +80,7 @@
"TCredentials",
"sources",
"destinations",
"dataset",
]

# verify that no injection context was created
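With the import and the __all__ entry above, the factory becomes reachable from the package root. A minimal sketch, assuming a duckdb destination that already holds a dataset from an earlier pipeline run (names are illustrative):

import dlt

# dlt.dataset is the new top-level entry point added by this commit
ds = dlt.dataset(destination="duckdb", dataset_name="my_pipeline_dataset")
print(ds.schema.name)
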
6 changes: 4 additions & 2 deletions dlt/common/destination/reference.py
@@ -66,6 +66,8 @@
TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration")
TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase")
TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration")
TDatasetType = Literal["dbapi", "ibis"]


DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}"

@@ -657,8 +659,8 @@ def __exit__(

class WithStateSync(ABC):
@abstractmethod
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage, setting any_schema_name to true will return the newest schema regardless of the schema name"""
pass

@abstractmethod
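TDatasetType is the selector consumed by the dataset factory below, and the widened get_stored_schema signature is now part of the WithStateSync contract, so concrete clients must accept the new flag. A hedged sketch of a conforming override (class name and body are illustrative, not part of this commit):

from typing import Optional

from dlt.common.destination.reference import WithStateSync


class MyStateSyncClient(WithStateSync):
    def get_stored_schema(self, any_schema_name: bool = False) -> Optional["StorageSchemaInfo"]:
        # default: return the newest version of this client's own schema (filtered by name)
        # any_schema_name=True: return the newest schema in storage regardless of its name
        raise NotImplementedError
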
88 changes: 82 additions & 6 deletions dlt/destinations/dataset.py
@@ -1,9 +1,14 @@
from typing import Any, Generator, AnyStr, Optional
from typing import Any, Generator, Optional, Union
from dlt.common.json import json

from contextlib import contextmanager
from dlt.common.destination.reference import (
SupportsReadableRelation,
SupportsReadableDataset,
TDatasetType,
TDestinationReferenceArg,
Destination,
JobClientBase,
)

from dlt.common.schema.typing import TTableSchemaColumns
@@ -71,29 +76,100 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""

def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None:
self.client = client
self.schema = schema
def __init__(
self,
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
) -> None:
self._destination = Destination.from_reference(destination)
self._schema = schema
self._resolved_schema: Schema = None
self._dataset_name = dataset_name
self._sql_client: SqlClientBase[Any] = None

def _destination_client(self, schema: Schema) -> JobClientBase:
client_spec = self._destination.spec()
client_spec._bind_dataset_name(
dataset_name=self._dataset_name, default_schema_name=schema.name
)
return self._destination.client(schema, client_spec)

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""
# full schema given, nothing to do
if not self._resolved_schema and isinstance(self._schema, Schema):
self._resolved_schema = self._schema

# schema name given, resolve it from destination by name
elif not self._resolved_schema and isinstance(self._schema, str):
with self._destination_client(Schema(self._schema)) as client:
stored_schema = client.get_stored_schema()
if stored_schema:
self._resolved_schema = Schema.from_stored_schema(
json.loads(stored_schema.schema)
)

# no schema name given, load newest schema from destination
elif not self._resolved_schema:
with self._destination_client(Schema(self._dataset_name)) as client:
stored_schema = client.get_stored_schema(any_schema_name=True)
if stored_schema:
self._resolved_schema = Schema.from_stored_schema(
json.loads(stored_schema.schema)
)

# default to empty schema with dataset name if nothing found
if not self._resolved_schema:
self._resolved_schema = Schema(self._dataset_name)

# here we create the client bound to the resolved schema
# TODO: ensure that this destination supports the sql_client. otherwise error
if not self._sql_client:
self._sql_client = self._destination_client(self._resolved_schema).sql_client

def __call__(
self, query: Any, schema_columns: TTableSchemaColumns = None
) -> ReadableDBAPIRelation:
schema_columns = schema_columns or {}
return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract]
return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract]

def table(self, table_name: str) -> SupportsReadableRelation:
# prepare query for table relation
schema_columns = (
self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {}
)
table_name = self.client.make_qualified_table_name(table_name)
table_name = self.sql_client.make_qualified_table_name(table_name)
query = f"SELECT * FROM {table_name}"
return self(query, schema_columns)

@property
def schema(self) -> Schema:
"""Lazy load schema from destination"""
self._ensure_client_and_schema()
return self._resolved_schema

@property
def sql_client(self) -> SqlClientBase[Any]:
"""Lazy instantiate client"""
self._ensure_client_and_schema()
return self._sql_client

def __getitem__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via dict notation"""
return self.table(table_name)

def __getattr__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via property notation"""
return self.table(table_name)


def dataset(
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
dataset_type: TDatasetType = "dbapi",
) -> SupportsReadableDataset:
if dataset_type == "dbapi":
return ReadableDBAPIDataset(destination, dataset_name, schema)
raise NotImplementedError(f"Dataset of type {dataset_type} not implemented")
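
Taken together, the factory supports three schema-resolution modes, mirroring _ensure_client_and_schema above. A minimal usage sketch (destination, dataset and table names are illustrative; the duckdb shorthand is an assumption):

import dlt
from dlt.common.schema import Schema

# explicit Schema object: used as-is, nothing is fetched from the destination
ds = dlt.dataset("duckdb", "my_dataset", schema=Schema("my_schema"))

# schema name as a string: resolved from the destination's stored schemas by name
ds = dlt.dataset("duckdb", "my_dataset", schema="my_schema")

# no schema: the newest stored schema is used, falling back to an empty schema
# named after the dataset if nothing is found
ds = dlt.dataset("duckdb", "my_dataset")
rows = ds["my_table"].fetchall()
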
19 changes: 13 additions & 6 deletions dlt/destinations/job_client_impl.py
@@ -397,14 +397,21 @@ def _from_db_type(
) -> TColumnType:
pass

def get_stored_schema(self) -> StorageSchemaInfo:
def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo:
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at")
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)
if any_schema_name:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name}"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query)
else:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)
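The only difference between the two branches is the WHERE filter on schema_name: with any_schema_name=True the newest row in the version table wins, whichever schema wrote it. A brief usage sketch against an open job client (variable names are illustrative):

# previous behaviour (default): newest version of the client's own schema
own_schema_info = job_client.get_stored_schema()

# new behaviour: newest schema stored in the destination, regardless of name
newest_info = job_client.get_stored_schema(any_schema_name=True)
if newest_info is not None:
    print(newest_info.schema_name, newest_info.inserted_at)
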
16 changes: 9 additions & 7 deletions dlt/pipeline/pipeline.py
@@ -84,6 +84,7 @@
DestinationClientStagingConfiguration,
DestinationClientDwhWithStagingConfiguration,
SupportsReadableDataset,
TDatasetType,
)
from dlt.common.normalizers.naming import NamingConvention
from dlt.common.pipeline import (
@@ -113,7 +114,7 @@
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.destinations.fs_client import FSClientBase
from dlt.destinations.job_client_impl import SqlJobClientBase
from dlt.destinations.dataset import ReadableDBAPIDataset
from dlt.destinations.dataset import dataset
from dlt.load.configuration import LoaderConfiguration
from dlt.load import Load

@@ -1717,10 +1718,11 @@ def __getstate__(self) -> Any:
# pickle only the SupportsPipeline protocol fields
return {"pipeline_name": self.pipeline_name}

def _dataset(self, dataset_type: Literal["dbapi", "ibis"] = "dbapi") -> SupportsReadableDataset:
def _dataset(self, dataset_type: TDatasetType = "dbapi") -> SupportsReadableDataset:
"""Access helper to dataset"""
if dataset_type == "dbapi":
return ReadableDBAPIDataset(
self.sql_client(), schema=self.default_schema if self.default_schema_name else None
)
raise NotImplementedError(f"Dataset of type {dataset_type} not implemented")
return dataset(
self.destination,
self.dataset_name,
schema=(self.default_schema if self.default_schema_name else None),
dataset_type=dataset_type,
)
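
After this change the helper is a thin wrapper over the new factory, so pipeline-scoped datasets get the same lazy schema resolution. A brief usage sketch, assuming a pipeline that has already loaded an items table:

dataset = pipeline._dataset()
items_relation = dataset["items"]   # table access via __getitem__
rows = items_relation.fetchall()
# attribute access works as well, e.g. dataset.items
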
41 changes: 41 additions & 0 deletions tests/load/test_read_interfaces.py
@@ -212,6 +212,47 @@ def double_items():
loads_table = pipeline._dataset()[pipeline.default_schema.loads_table_name]
loads_table.fetchall()

# check dataset factory
dataset = dlt.dataset(
destination=destination_config.destination_type, dataset_name=pipeline.dataset_name
)
table_relationship = dataset.items
table = table_relationship.fetchall()
assert len(table) == total_records

# check that schema is loaded by name
dataset = dlt.dataset(
destination=destination_config.destination_type,
dataset_name=pipeline.dataset_name,
schema=pipeline.default_schema_name,
)
assert dataset.schema.tables["items"]["write_disposition"] == "replace"

# check that schema is not loaded when wrong name given
dataset = dlt.dataset(
destination=destination_config.destination_type,
dataset_name=pipeline.dataset_name,
schema="wrong_schema_name",
)
assert "items" not in dataset.schema.tables
assert dataset.schema.name == pipeline.dataset_name

# check that schema is loaded if no schema name given
dataset = dlt.dataset(
destination=destination_config.destination_type,
dataset_name=pipeline.dataset_name,
)
assert dataset.schema.name == pipeline.default_schema_name
assert dataset.schema.tables["items"]["write_disposition"] == "replace"

# check that there is no error when creating dataset without schema table
dataset = dlt.dataset(
destination=destination_config.destination_type,
dataset_name="unknown_dataset",
)
assert dataset.schema.name == "unknown_dataset"
assert "items" not in dataset.schema.tables


@pytest.mark.essential
@pytest.mark.parametrize(
