dataset factory #1945

Merged · 13 commits · Oct 15, 2024
2 changes: 2 additions & 0 deletions dlt/__init__.py
@@ -42,6 +42,7 @@
)
from dlt.pipeline import progress
from dlt import destinations
from dlt.destinations.dataset import dataset

pipeline = _pipeline
current = _current
@@ -79,6 +80,7 @@
"TCredentials",
"sources",
"destinations",
"dataset",
rudolfix marked this conversation as resolved.
]

# verify that no injection context was created
6 changes: 4 additions & 2 deletions dlt/common/destination/reference.py
@@ -66,6 +66,8 @@
TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration")
TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase")
TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration")
TDatasetType = Literal["dbapi", "ibis"]


DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}"

@@ -657,8 +659,8 @@ def __exit__(

class WithStateSync(ABC):
@abstractmethod
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
Collaborator (rudolfix):
I'd rather add a new method, but it really does not fit here: this interface assumes that there's a known schema name.

my take would be to change the signature to

get_stored_schema(self, schema_name: str = None)

If None is specified, we load the newest schema; if a name is provided, we load the newest schema with the given name.

Collaborator Author @sh-rp (Oct 14, 2024):

Sounds good. I had the same idea but thought it might not be good to change the default behavior of this method. I have changed it now and updated all the places in the code and tests where it is used.

"""Retrieves newest schema from destination storage, setting any_schema_name to true will return the newest schema regardless of the schema name"""
pass
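A minimal sketch of the two lookup modes the flag enables; `client` is assumed to be any destination job client implementing WithStateSync, and the helper below is illustrative, not code from this PR:

```python
# Sketch only: exercising the two lookup modes of get_stored_schema.
# `client` is assumed to implement WithStateSync.

def newest_schema_info(client, ignore_name: bool = False):
    if ignore_name:
        # newest stored schema in the dataset, regardless of its name
        return client.get_stored_schema(any_schema_name=True)
    # default: newest stored schema matching the client's bound schema name
    return client.get_stored_schema()
```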

@abstractmethod
101 changes: 94 additions & 7 deletions dlt/destinations/dataset.py
@@ -1,13 +1,20 @@
from typing import Any, Generator, AnyStr, Optional
from typing import Any, Generator, Optional, Union
from dlt.common.json import json

from contextlib import contextmanager
from dlt.common.destination.reference import (
SupportsReadableRelation,
SupportsReadableDataset,
TDatasetType,
TDestinationReferenceArg,
Destination,
JobClientBase,
WithStateSync,
DestinationClientDwhConfiguration,
)

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.schema import Schema


@@ -71,29 +78,109 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""

def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None:
self.client = client
self.schema = schema
def __init__(
self,
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
) -> None:
self._destination = Destination.from_reference(destination)
self._schema = schema
self._resolved_schema: Optional[Schema] = None
self._dataset_name = dataset_name
self._sql_client: Optional[SqlClientBase[Any]] = None

def _destination_client(self, schema: Schema) -> JobClientBase:
client_spec = self._destination.spec()
if isinstance(client_spec, DestinationClientDwhConfiguration):
client_spec._bind_dataset_name(
dataset_name=self._dataset_name, default_schema_name=schema.name
)
return self._destination.client(schema, client_spec)

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""
# full schema given, nothing to do
if not self._resolved_schema and isinstance(self._schema, Schema):
self._resolved_schema = self._schema

# schema name given, resolve it from destination by name
elif not self._resolved_schema and isinstance(self._schema, str):
with self._destination_client(Schema(self._schema)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema()
if stored_schema:
self._resolved_schema = Schema.from_stored_schema(
json.loads(stored_schema.schema)
)

# no schema name given, load newest schema from destination
elif not self._resolved_schema:
with self._destination_client(Schema(self._dataset_name)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema(any_schema_name=True)
if stored_schema:
self._resolved_schema = Schema.from_stored_schema(
json.loads(stored_schema.schema)
)

# default to empty schema with dataset name if nothing found
if not self._resolved_schema:
self._resolved_schema = Schema(self._dataset_name)

# here we create the client bound to the resolved schema
if not self._sql_client:
destination_client = self._destination_client(self._resolved_schema)
if isinstance(destination_client, WithSqlClient):
self._sql_client = destination_client.sql_client
else:
raise Exception(
f"Destination {destination_client.config.destination_type} does not support"
" SqlClient."
)

def __call__(
self, query: Any, schema_columns: TTableSchemaColumns = None
) -> ReadableDBAPIRelation:
schema_columns = schema_columns or {}
return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract]
return ReadableDBAPIRelation(client=self.sql_client, query=query, schema_columns=schema_columns) # type: ignore[abstract]

def table(self, table_name: str) -> SupportsReadableRelation:
# prepare query for table relation
schema_columns = (
self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {}
)
table_name = self.client.make_qualified_table_name(table_name)
table_name = self.sql_client.make_qualified_table_name(table_name)
query = f"SELECT * FROM {table_name}"
return self(query, schema_columns)

@property
def schema(self) -> Schema:
"""Lazy load schema from destination"""
self._ensure_client_and_schema()
return self._resolved_schema

@property
def sql_client(self) -> SqlClientBase[Any]:
"""Lazy instantiate client"""
self._ensure_client_and_schema()
return self._sql_client

def __getitem__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via dict notation"""
return self.table(table_name)

def __getattr__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via property notation"""
return self.table(table_name)


def dataset(
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
dataset_type: TDatasetType = "dbapi",
Collaborator Author (@sh-rp):

We allow a given Schema object, or alternatively a schema name that will be loaded from the destination, or no schema at all, in which case the newest stored schema is auto-discovered, as discussed.

Collaborator (rudolfix):

OK, cool. But as discussed, we'll need to implement a dataset compatible with the pipeline dataset (many schemas, a different database layout: we support schema separation, but it is rarely used).

) -> SupportsReadableDataset:
if dataset_type == "dbapi":
return ReadableDBAPIDataset(destination, dataset_name, schema)
raise NotImplementedError(f"Dataset of type {dataset_type} not implemented")
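For context, a hedged usage sketch of the factory and the lazy resolution described above; the duckdb destination and the items table are illustrative assumptions, not part of this diff:

```python
import dlt

# schema resolution, as discussed above:
# - Schema object: used as-is, nothing is fetched
# - str: newest stored schema with that name is loaded from the destination
# - None: newest stored schema in the dataset is auto-discovered
ds = dlt.dataset("duckdb", dataset_name="my_dataset")

# the sql client and schema are instantiated lazily, on first access
print(ds.schema.name)

# tables are reachable via method call, dict notation, or attribute notation
rel = ds.table("items")
rel = ds["items"]
rel = ds.items
```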
44 changes: 24 additions & 20 deletions dlt/destinations/impl/filesystem/filesystem.py
@@ -650,29 +650,33 @@ def _iter_stored_schema_files(self) -> Iterator[Tuple[str, List[str]]]:
yield filepath, fileparts

def _get_stored_schema_by_hash_or_newest(
self, version_hash: str = None
self, version_hash: str = None, any_schema_name: bool = False
) -> Optional[StorageSchemaInfo]:
"""Get the schema by supplied hash, falls back to getting the newest version matching the existing schema name"""
version_hash = self._to_path_safe_string(version_hash)
# find newest schema for pipeline or by version hash
selected_path = None
newest_load_id = "0"
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and fileparts[0] == self.schema.name
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
selected_path = filepath
elif fileparts[2] == version_hash:
selected_path = filepath
break
try:
selected_path = None
newest_load_id = "0"
for filepath, fileparts in self._iter_stored_schema_files():
if (
not version_hash
and (fileparts[0] == self.schema.name or any_schema_name)
and fileparts[1] > newest_load_id
):
newest_load_id = fileparts[1]
selected_path = filepath
elif fileparts[2] == version_hash:
selected_path = filepath
break

if selected_path:
return StorageSchemaInfo(
**json.loads(self.fs_client.read_text(selected_path, encoding="utf-8"))
)
if selected_path:
return StorageSchemaInfo(
**json.loads(self.fs_client.read_text(selected_path, encoding="utf-8"))
)
except DestinationUndefinedEntity:
# ignore missing table
pass

return None

Expand All @@ -699,9 +703,9 @@ def _store_current_schema(self) -> None:
# we always keep tabs on what the current schema is
self._write_to_json_file(filepath, version_info)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
return self._get_stored_schema_by_hash_or_newest()
return self._get_stored_schema_by_hash_or_newest(any_schema_name=any_schema_name)

def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema_by_hash_or_newest(version_hash)
11 changes: 5 additions & 6 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -539,7 +539,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
return None

@lancedb_error
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

@@ -553,11 +553,10 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
p_schema = self.schema.naming.normalize_identifier("schema")

try:
schemas = (
version_table.search().where(
f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True
)
).to_list()
query = version_table.search()
if not any_schema_name:
query = query.where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True)
schemas = query.to_list()

# LanceDB's ORDER BY clause doesn't seem to work.
# See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341
21 changes: 12 additions & 9 deletions dlt/destinations/impl/qdrant/qdrant_job_client.py
@@ -377,23 +377,26 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
raise DestinationUndefinedEntity(str(e)) from e
raise

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
try:
scroll_table_name = self._make_qualified_collection_name(self.schema.version_table_name)
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")

name_filter = models.Filter(
must=[
models.FieldCondition(
key=p_schema_name,
match=models.MatchValue(value=self.schema.name),
)
]
)

response = self.db_client.scroll(
scroll_table_name,
with_payload=True,
scroll_filter=models.Filter(
must=[
models.FieldCondition(
key=p_schema_name,
match=models.MatchValue(value=self.schema.name),
)
]
),
scroll_filter=None if any_schema_name else name_filter,
limit=1,
order_by=models.OrderBy(
key=p_inserted_at,
13 changes: 9 additions & 4 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -240,7 +240,10 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
self.sql_client.execute_sql(table_obj.insert().values(schema_mapping))

def _get_stored_schema(
self, version_hash: Optional[str] = None, schema_name: Optional[str] = None
self,
version_hash: Optional[str] = None,
schema_name: Optional[str] = None,
any_schema_name: bool = False,
) -> Optional[StorageSchemaInfo]:
version_table = self.schema.tables[self.schema.version_table_name]
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
@@ -249,7 +252,7 @@ def _get_stored_schema(
if version_hash is not None:
version_hash_col = self.schema.naming.normalize_identifier("version_hash")
q = q.where(table_obj.c[version_hash_col] == version_hash)
if schema_name is not None:
if schema_name is not None and not any_schema_name:
schema_name_col = self.schema.naming.normalize_identifier("schema_name")
q = q.where(table_obj.c[schema_name_col] == schema_name)
inserted_at_col = self.schema.naming.normalize_identifier("inserted_at")
@@ -267,9 +270,11 @@ def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]:
return self._get_stored_schema(version_hash)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Get the latest stored schema"""
return self._get_stored_schema(schema_name=self.schema.name)
return self._get_stored_schema(
schema_name=self.schema.name, any_schema_name=any_schema_name
)

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.schema.tables.get(
15 changes: 9 additions & 6 deletions dlt/destinations/impl/weaviate/weaviate_client.py
@@ -516,19 +516,22 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
if len(load_records):
return StateInfo(**state)

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
def get_stored_schema(self, any_schema_name: bool = False) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage"""
p_schema_name = self.schema.naming.normalize_identifier("schema_name")
p_inserted_at = self.schema.naming.normalize_identifier("inserted_at")

name_filter = {
"path": [p_schema_name],
"operator": "Equal",
"valueString": self.schema.name,
}

try:
record = self.get_records(
self.schema.version_table_name,
sort={"path": [p_inserted_at], "order": "desc"},
where={
"path": [p_schema_name],
"operator": "Equal",
"valueString": self.schema.name,
},
where=None if any_schema_name else name_filter,
limit=1,
)[0]
return StorageSchemaInfo(**record)
19 changes: 13 additions & 6 deletions dlt/destinations/job_client_impl.py
@@ -397,14 +397,21 @@ def _from_db_type(
) -> TColumnType:
pass

def get_stored_schema(self) -> StorageSchemaInfo:
def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo:
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at")
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)
if any_schema_name:
Collaborator Author (@sh-rp):

This can be compressed, and it also needs to be implemented for all destinations.

Collaborator Author (@sh-rp):

Also: is this signature OK, or do we want to add a new function for this? I'm also not sure about the name any_schema_name... :)

query = (
f"SELECT {self.version_table_schema_columns} FROM {name}"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query)
else:
query = (
f"SELECT {self.version_table_schema_columns} FROM {name} WHERE {c_schema_name} = %s"
f" ORDER BY {c_inserted_at} DESC;"
)
return self._row_to_schema_info(query, self.schema.name)
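Following up on the compression note, one possible folding of the two branches; a sketch against this diff only, assuming _row_to_schema_info accepts variadic query arguments:

```python
def get_stored_schema(self, any_schema_name: bool = False) -> StorageSchemaInfo:
    name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
    c_schema_name, c_inserted_at = self._norm_and_escape_columns("schema_name", "inserted_at")
    # add the schema-name filter only when a specific schema is requested
    where = "" if any_schema_name else f" WHERE {c_schema_name} = %s"
    query = (
        f"SELECT {self.version_table_schema_columns} FROM {name}{where}"
        f" ORDER BY {c_inserted_at} DESC;"
    )
    args = () if any_schema_name else (self.schema.name,)
    return self._row_to_schema_info(query, *args)
```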

def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)