dataset factory (#1945)
* create first version of dataset factory

* update all destination implementations for getting the newest schema, fixed linter errors, made dataset aware of config types

* test retrieval of schema for all destinations (except custom destination)

* add simple tests for schema selection in dataset tests

* unify filesystem schema behavior with other destinations

* fix gcs delta tests

* try to fix ci errors

* allow athena in a kind of "read only" mode

* fix delta table tests?

* mark dataset factory as private

* change signature and behavior of get_stored_schema

* fix weaviate schema retrieval

* switch back to properties
sh-rp authored Oct 15, 2024
1 parent 55e1c3c commit bc13448
Showing 20 changed files with 365 additions and 93 deletions.
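
As a quick orientation before the per-file diffs, here is a minimal usage sketch of the dataset factory this commit introduces (a sketch only: the "duckdb" shorthand, dataset name and table name are assumed, and the relation methods df()/arrow()/fetchall() follow dlt's readable-relation interface):

import dlt

# the factory is exported (marked private) from the top-level package by this commit
ds = dlt._dataset(destination="duckdb", dataset_name="my_pipeline_dataset")

# tables are exposed as readable relations via indexing, attribute access or table()
items = ds["items"]                # equivalent to ds.items or ds.table("items")
print(items.df())                  # materialize as a pandas DataFrame
print(items.arrow())               # or as an Arrow table

# arbitrary SQL runs against the same lazily created sql_client
rel = ds("SELECT COUNT(*) FROM items")
print(rel.fetchall())
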
2 changes: 2 additions & 0 deletions dlt/__init__.py
@@ -42,6 +42,7 @@
)
from dlt.pipeline import progress
from dlt import destinations
from dlt.destinations.dataset import dataset as _dataset

pipeline = _pipeline
current = _current
@@ -79,6 +80,7 @@
"TCredentials",
"sources",
"destinations",
"_dataset",
]

# verify that no injection context was created
9 changes: 7 additions & 2 deletions dlt/common/destination/reference.py
@@ -66,6 +66,8 @@
TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration")
TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase")
TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration")
TDatasetType = Literal["dbapi", "ibis"]


DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}"

@@ -657,8 +659,11 @@ def __exit__(

class WithStateSync(ABC):
@abstractmethod
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""
Retrieves newest schema with given name from destination storage
If no name is provided, the newest schema found is retrieved.
"""
pass

@abstractmethod
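
In practice, the new optional parameter changes the WithStateSync contract as sketched below (a sketch; the helper name is illustrative and deserialization mirrors dlt/destinations/dataset.py further down):

from typing import Optional
from dlt.common.json import json
from dlt.common.schema import Schema
from dlt.common.destination.reference import WithStateSync

def load_newest_schema(client: WithStateSync, schema_name: str = None) -> Optional[Schema]:
    """Fetch the newest stored schema; when schema_name is given, restrict to that name,
    otherwise take the newest schema of any name (the new default behavior)."""
    stored = client.get_stored_schema(schema_name)
    if stored is None:
        return None
    return Schema.from_stored_schema(json.loads(stored.schema))
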
95 changes: 88 additions & 7 deletions dlt/destinations/dataset.py
@@ -1,13 +1,20 @@
from typing import Any, Generator, AnyStr, Optional
from typing import Any, Generator, Optional, Union
from dlt.common.json import json

from contextlib import contextmanager
from dlt.common.destination.reference import (
SupportsReadableRelation,
SupportsReadableDataset,
TDatasetType,
TDestinationReferenceArg,
Destination,
JobClientBase,
WithStateSync,
DestinationClientDwhConfiguration,
)

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.schema import Schema


@@ -71,22 +78,85 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""

def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None:
self.client = client
self.schema = schema
def __init__(
self,
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
) -> None:
self._destination = Destination.from_reference(destination)
self._provided_schema = schema
self._dataset_name = dataset_name
self._sql_client: SqlClientBase[Any] = None
self._schema: Schema = None

@property
def schema(self) -> Schema:
self._ensure_client_and_schema()
return self._schema

@property
def sql_client(self) -> SqlClientBase[Any]:
self._ensure_client_and_schema()
return self._sql_client

def _destination_client(self, schema: Schema) -> JobClientBase:
client_spec = self._destination.spec()
if isinstance(client_spec, DestinationClientDwhConfiguration):
client_spec._bind_dataset_name(
dataset_name=self._dataset_name, default_schema_name=schema.name
)
return self._destination.client(schema, client_spec)

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""
# full schema given, nothing to do
if not self._schema and isinstance(self._provided_schema, Schema):
self._schema = self._provided_schema

# schema name given, resolve it from destination by name
elif not self._schema and isinstance(self._provided_schema, str):
with self._destination_client(Schema(self._provided_schema)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema(self._provided_schema)
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# no schema name given, load newest schema from destination
elif not self._schema:
with self._destination_client(Schema(self._dataset_name)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema()
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# default to empty schema with dataset name if nothing found
if not self._schema:
self._schema = Schema(self._dataset_name)

# here we create the client bound to the resolved schema
if not self._sql_client:
destination_client = self._destination_client(self._schema)
if isinstance(destination_client, WithSqlClient):
self._sql_client = destination_client.sql_client
else:
raise Exception(
f"Destination {destination_client.config.destination_type} does not support"
" SqlClient."
)

def __call__(
self, query: Any, schema_columns: TTableSchemaColumns = None
) -> ReadableDBAPIRelation:
schema_columns = schema_columns or {}
return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract]
return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract]

def table(self, table_name: str) -> SupportsReadableRelation:
# prepare query for table relation
schema_columns = (
self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {}
)
table_name = self.client.make_qualified_table_name(table_name)
table_name = self.sql_client.make_qualified_table_name(table_name)
query = f"SELECT * FROM {table_name}"
return self(query, schema_columns)

@@ -97,3 +167,14 @@ def __getitem__(self, table_name: str) -> SupportsReadableRelation:
def __getattr__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via property notation"""
return self.table(table_name)


def dataset(
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
dataset_type: TDatasetType = "dbapi",
) -> SupportsReadableDataset:
if dataset_type == "dbapi":
return ReadableDBAPIDataset(destination, dataset_name, schema)
raise NotImplementedError(f"Dataset of type {dataset_type} not implemented")
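
The three schema arguments the factory accepts, mirroring _ensure_client_and_schema above (a sketch; destination and names are illustrative):

from dlt.common.schema import Schema
from dlt.destinations.dataset import dataset

# 1. a full Schema object is used as-is
ds = dataset("duckdb", "my_dataset", schema=Schema("my_schema"))

# 2. a schema name: the newest stored schema with that name is loaded from the destination
ds = dataset("duckdb", "my_dataset", schema="my_schema")

# 3. no schema: the newest stored schema of any name is loaded,
#    falling back to an empty schema named after the dataset
ds = dataset("duckdb", "my_dataset")
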
11 changes: 6 additions & 5 deletions dlt/destinations/impl/athena/athena.py
@@ -318,11 +318,12 @@ def __init__(
# verify if staging layout is valid for Athena
# this will raise if the table prefix is not properly defined
# we actually require that {table_name} is first, no {schema_name} is allowed
self.table_prefix_layout = path_utils.get_table_prefix_layout(
config.staging_config.layout,
supported_prefix_placeholders=[],
table_needs_own_folder=True,
)
if config.staging_config:
self.table_prefix_layout = path_utils.get_table_prefix_layout(
config.staging_config.layout,
supported_prefix_placeholders=[],
table_needs_own_folder=True,
)

sql_client = AthenaSQLClient(
config.normalize_dataset_name(schema),
44 changes: 24 additions & 20 deletions dlt/destinations/impl/filesystem/filesystem.py
@@ -650,29 +650,33 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]:
yield filepath, fileparts

def _get_stored_schema_by_hash_or_newest(
self, version_hash: str = None
self, version_hash: str = None, schema_name: str = None
) -> Optional[StorageSchemaInfo]:
"""Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name"""
version_hash = self._to_path_safe_string(version_hash)
# find newest schema for pipeline or by version hash
selected_path = None
newest_load_id = "0"
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and fileparts[0] == self.schema.name
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
selected_path = filepath
elif fileparts[2] == version_hash:
selected_path = filepath
break
try:
selected_path = None
newest_load_id = "0"
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and (fileparts[0] == schema_name or (not schema_name))
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
selected_path = filepath
elif fileparts[2] == version_hash:
selected_path = filepath
break

if selected_path:
return StorageSchemaInfo(
**json.loads(self.fs_client.read_text(selected_path, encoding="utf-8"))
)
if selected_path:
return StorageSchemaInfo(
**json.loads(self.fs_client.read_text(selected_path, encoding="utf-8"))
)
except DestinationUndefinedEntity:
# ignore missing table
pass

return None

Expand All @@ -699,9 +703,9 @@ def _store_current_schema(self) -> None:
# we always keep tabs on what the current schema is
self._write_to_json_file(filepath, version_info)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
return self._get_stored_schema_by_hash_or_newest()
return self._get_stored_schema_by_hash_or_newest(schema_name=schema_name)

def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema_by_hash_or_newest(version_hash)
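
The selection rule above, restated as a standalone helper for readability (an illustration only; fileparts are assumed to be [schema_name, load_id, version_hash] lists parsed from the stored schema file names):

from typing import List, Optional, Sequence

def pick_schema_file(
    files: Sequence[List[str]],
    schema_name: Optional[str] = None,
    version_hash: Optional[str] = None,
) -> Optional[List[str]]:
    """Prefer an exact version_hash match; otherwise take the highest load_id
    among files whose schema name matches (or any name when none is given)."""
    selected = None
    newest_load_id = "0"
    for parts in files:
        if not version_hash and (not schema_name or parts[0] == schema_name) and parts[1] > newest_load_id:
            newest_load_id = parts[1]
            selected = parts
        elif parts[2] == version_hash:
            return parts
    return selected
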
11 changes: 5 additions & 6 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI
return None

@lancedb_error
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

Expand All @@ -553,11 +553,10 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
p_schema = self.schema.naming.normalize_identifier("schema")

try:
schemas = (
version_table.search().where(
f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True
)
).to_list()
query = version_table.search()
if schema_name:
query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True)
schemas = query.to_list()

# LanceDB's ORDER BY clause doesn't seem to work.
# See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341
21 changes: 14 additions & 7 deletions dlt/destinations/impl/qdrant/qdrant_job_client.py
@@ -377,23 +377,30 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
raise DestinationUndefinedEntity(str(e)) from e
raise

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
try:
scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name)
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")
response = self.db_client.scroll(
scroll_table_name,
with_payload=True,
scroll_filter=models.Filter(

name_filter = (
models.Filter(
must=[
models.FieldCondition(
key=p_schema_name,
match=models.MatchValue(value=self.schema.name),
match=models.MatchValue(value=schema_name),
)
]
),
)
if schema_name
else None
)

response = self.db_client.scroll(
scroll_table_name,
with_payload=True,
scroll_filter=name_filter,
limit=1,
order_by=models.OrderBy(
key=p_inserted_at,
8 changes: 5 additions & 3 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -240,7 +240,9 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
self.sql_client.execute_sql(table_obj.insert().values(schema_mapping))

def _get_stored_schema(
self, version_hash: Optional[str] = None, schema_name: Optional[str] = None
self,
version_hash: Optional[str] = None,
schema_name: Optional[str] = None,
) -> Optional[StorageSchemaInfo]:
version_table = self.schema.tables[self.schema.version_table_name]
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
@@ -267,9 +269,9 @@ def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema(version_hash)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Get the latest stored schema"""
return self._get_stored_schema(schema_name=self.schema.name)
return self._get_stored_schema(schema_name=schema_name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.schema.tables.get(
19 changes: 13 additions & 6 deletions dlt/destinations/impl/weaviate/weaviate_client.py
@@ -516,19 +516,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
if len(load_records):
return StateInfo(**state)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")

name_filter = (
{
"path": [p_schema_name],
"operator": "Equal",
"valueString": schema_name,
}
if schema_name
else None
)

try:
record = self.get_records(
self.schema.version_table_name,
sort={"path": [p_inserted_at], "order": "desc"},
where={
"path": [p_schema_name],
"operator": "Equal",
"valueString": self.schema.name,
},
where=name_filter,
limit=1,
)[0]
return StorageSchemaInfo(**record)
19 changes: 13 additions & 6 deletions dlt/destinations/job_client_impl.py
@@ -397,14 +397,21 @@ def _from_db_type(
) -> TColumnType:
pass

def get_stored_schema(self) -> StorageSchemaInfo:
def get_stored_schema(self, schema_name: str = None) -> StorageSchemaInfo:
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at")
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)
if not schema_name:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name}"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query)
else:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, schema_name)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)