Merge pull request #129 from dlt-hub/rfix/fixes-duckdb
fixes file rotation on schema changes + bumps duckdb to 0.7
rudolfix authored Feb 15, 2023
2 parents 566093c + 1d9ad09 commit 77c1291
Showing 14 changed files with 678 additions and 225 deletions.
20 changes: 13 additions & 7 deletions dlt/common/data_writers/buffered.py
@@ -1,6 +1,4 @@
from typing import List, IO, Any

from copy import deepcopy
from typing import List, IO, Any, Type

from dlt.common.utils import uniq_id
from dlt.common.typing import TDataItem, TDataItems
@@ -32,7 +30,7 @@ def __init__(
self._caps = _caps
# validate if template has correct placeholders
self.file_name_template = file_name_template
self.all_files: List[str] = []
self.closed_files: List[str] = [] # all fully processed files
# buffered items must be less than max items in file
self.buffer_max_items = min(buffer_max_items, file_max_items or buffer_max_items)
self.file_max_bytes = file_max_bytes
@@ -54,9 +52,11 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> Non
# rotate file if columns changed and writer does not allow for that
# as the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths
if self._writer and not self._writer.data_format().supports_schema_changes and len(columns) != len(self._current_columns):
assert len(columns) > len(self._current_columns)
self._rotate_file()
# until the first chunk is written we can change the columns schema freely
self._current_columns = deepcopy(columns)
if columns is not None:
self._current_columns = dict(columns)
if isinstance(item, List):
# items coming in a single list will be written together, no matter how many there are
self._buffered_items.extend(item)
@@ -74,7 +74,7 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> Non
if self.file_max_items and self._writer.items_count >= self.file_max_items:
self._rotate_file()

def close_writer(self) -> None:
def close(self) -> None:
self._ensure_open()
self._flush_and_close_file()
self._closed = True
@@ -83,6 +83,12 @@ def close_writer(self) -> None:
def closed(self) -> bool:
return self._closed

def __enter__(self) -> "BufferedDataWriter":
return self

def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None:
self.close()

def _rotate_file(self) -> None:
self._flush_and_close_file()
self._file_name = self.file_name_template % uniq_id(5) + "." + self._file_format_spec.file_extension
@@ -111,7 +117,7 @@ def _flush_and_close_file(self) -> None:
self._writer.write_footer()
self._file.close()
# add file written to the list so we can commit all the files later
self.all_files.append(self._file_name)
self.closed_files.append(self._file_name)
self._writer = None
self._file = None

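The writer now supports the context manager protocol (`close()` replaces `close_writer()`, and fully processed files live on `closed_files` instead of `all_files`). A minimal usage sketch, assuming the `BufferedDataWriter` instance, items, and columns are constructed elsewhere (the helper name is illustrative):

```python
from dlt.common.data_writers.buffered import BufferedDataWriter

def write_and_collect(writer: BufferedDataWriter, items, columns):
    # __exit__ calls close() (formerly close_writer()), flushing and closing the current file
    with writer:
        writer.write_data_item(items, columns)
    # fully processed files are exposed on the renamed attribute
    return writer.closed_files
```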
19 changes: 18 additions & 1 deletion dlt/common/destination.py
@@ -2,8 +2,9 @@
from importlib import import_module
from types import TracebackType, ModuleType
from typing import Any, Callable, ClassVar, Final, List, Optional, Literal, Type, Protocol, Union, TYPE_CHECKING, cast
from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule

from dlt.common.configuration.utils import serialize_value
from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule
from dlt.common.schema import Schema
from dlt.common.schema.typing import TTableSchema
from dlt.common.configuration import configspec
@@ -37,6 +38,22 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
can_create_default: ClassVar[bool] = False


def generic_destination_capabilities() -> DestinationCapabilitiesContext:
caps = DestinationCapabilitiesContext()
caps.preferred_loader_file_format = None
caps.supported_loader_file_formats = ["jsonl", "insert_values"]
caps.escape_identifier = lambda x: x
caps.escape_literal = lambda x: serialize_value(x)
caps.max_identifier_length = 65536
caps.max_column_identifier_length = 65536
caps.max_query_length = 32 * 1024 * 1024
caps.is_max_query_length_in_bytes = True
caps.max_text_data_type_length = 1024 * 1024 * 1024
caps.is_max_text_data_type_length_in_bytes = True
caps.supports_ddl_transactions = True
return caps


@configspec(init=True)
class DestinationClientConfiguration(BaseConfiguration):
destination_name: str = None # which destination to load data to
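The new `generic_destination_capabilities()` helper returns a permissive default `DestinationCapabilitiesContext`. A short sketch of how a caller might consume it, using only the fields set above:

```python
from dlt.common.destination import generic_destination_capabilities

caps = generic_destination_capabilities()
# no preferred loader format is set, so fall back to the first supported one
file_format = caps.preferred_loader_file_format or caps.supported_loader_file_formats[0]
assert file_format == "jsonl"
# generic capabilities escape identifiers as-is and serialize literals
assert caps.escape_identifier("event_id") == "event_id"
```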
11 changes: 9 additions & 2 deletions dlt/common/storages/data_item_storage.py
@@ -1,4 +1,4 @@
from typing import Dict, Any
from typing import Dict, Any, List
from abc import ABC, abstractmethod

from dlt.common import logger
@@ -30,7 +30,14 @@ def close_writers(self, extract_id: str) -> None:
for name, writer in self.buffered_writers.items():
if name.startswith(extract_id):
logger.debug(f"Closing writer for {name} with file {writer._file} and actual name {writer._file_name}")
writer.close_writer()
writer.close()

def closed_files(self) -> List[str]:
files: List[str] = []
for writer in self.buffered_writers.values():
files.extend(writer.closed_files)

return files

@abstractmethod
def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str:
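`closed_files()` aggregates the `closed_files` lists of every buffered writer managed by the storage. A hedged sketch, assuming `storage` is a concrete `DataItemStorage` subclass and `extract_id` comes from the extraction step (both illustrative):

```python
# close every writer that belongs to this extract, then collect the finished files
storage.close_writers(extract_id)
for file_name in storage.closed_files():
    print("ready to commit:", file_name)
```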
6 changes: 3 additions & 3 deletions dlt/destinations/bigquery/bigquery.py
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, cast
from dlt.common.storages.file_storage import FileStorage
import google.cloud.bigquery as bigquery # noqa: I250
from google.cloud import exceptions as gcp_exceptions
@@ -98,7 +98,7 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None:
)
super().__init__(schema, config, sql_client)
self.config: BigQueryClientConfiguration = config
self.sql_client: BigQuerySqlClient = sql_client
self.sql_client: BigQuerySqlClient = sql_client # type: ignore

def restore_file_load(self, file_path: str) -> LoadJob:
try:
@@ -209,7 +209,7 @@ def _create_load_job(self, table_name: str, write_disposition: TWriteDisposition

def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob:
job_id = BigQueryClient._get_job_id_from_file_path(file_path)
return self.sql_client.native_connection.get_job(job_id)
return cast(bigquery.LoadJob, self.sql_client.native_connection.get_job(job_id))

@staticmethod
def _get_job_id_from_file_path(file_path: str) -> str:
4 changes: 2 additions & 2 deletions dlt/destinations/bigquery/sql_client.py
@@ -38,7 +38,7 @@ def __init__(self, dataset_name: str, credentials: GcpClientCredentials) -> None

self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(credentials.retry_deadline)
self._default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name(escape=False))
self._session_query = None
self._session_query: bigquery.QueryJobConfig = None

@raise_open_connection_error
def open_connection(self) -> None:
@@ -59,7 +59,7 @@ def query_patch(
) -> Any:
return query_orig(query, retry=retry, timeout=timeout, **kwargs)

self._client.query = query_patch
self._client.query = query_patch # type: ignore

def close_connection(self) -> None:
if self._session_query:
7 changes: 5 additions & 2 deletions dlt/destinations/duckdb/sql_client.py
@@ -26,14 +26,17 @@ def __init__(self, dataset_name: str, credentials: DuckDbCredentials) -> None:
def open_connection(self) -> None:
self._conn = self.credentials.borrow_conn(read_only=self.credentials.read_only)
# TODO: apply config settings from credentials
config={"search_path": self.fully_qualified_dataset_name()}
config={
"search_path": self.fully_qualified_dataset_name(),
"TimeZone": "UTC"
}
if config:
for k, v in config.items():
try:
# TODO: serialize str and ints, dbapi args do not work here
# TODO: enable various extensions ie. parquet
self._conn.execute(f"SET {k} = '{v}'")
except duckdb.CatalogException:
except (duckdb.CatalogException, duckdb.BinderException):
pass

def close_connection(self) -> None:
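The duckdb connection setup now also pins `TimeZone` to UTC and additionally tolerates `BinderException` for settings a given duckdb build does not accept. A standalone sketch of the same pattern against an in-memory connection (the dataset name is illustrative):

```python
import duckdb

conn = duckdb.connect()  # in-memory database, for illustration only
config = {"search_path": "my_dataset", "TimeZone": "UTC"}
for k, v in config.items():
    try:
        conn.execute(f"SET {k} = '{v}'")
    except (duckdb.CatalogException, duckdb.BinderException):
        # ignore settings this duckdb build does not recognize, as the client does
        pass
```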
8 changes: 4 additions & 4 deletions dlt/helpers/dbt/__init__.py
@@ -12,7 +12,7 @@

from dlt.helpers.dbt.runner import create_runner, DBTPackageRunner

DEFAULT_DLT_VERSION = ">=1.1,<1.4"
DEFAULT_DBT_VERSION = ">=1.1,<1.5"


def _default_profile_name(credentials: DestinationClientDwhConfiguration) -> str:
@@ -24,7 +24,7 @@ def _default_profile_name(credentials: DestinationClientDwhConfiguration) -> str
return profile_name


def _create_dbt_deps(destination_names: List[str], dbt_version: str = DEFAULT_DLT_VERSION) -> List[str]:
def _create_dbt_deps(destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION) -> List[str]:
if dbt_version:
# if parses as version use "==" operator
with contextlib.suppress(ValueError):
@@ -45,13 +45,13 @@ def _create_dbt_deps(destination_names: List[str], dbt_version: str = DEFAULT_DL
return all_packages + [dlt_requirement]


def restore_venv(venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DLT_VERSION) -> Venv:
def restore_venv(venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION) -> Venv:
venv = Venv.restore(venv_dir)
venv.add_dependencies(_create_dbt_deps(destination_names, dbt_version))
return venv


def create_venv(venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DLT_VERSION) -> Venv:
def create_venv(venv_dir: str, destination_names: List[str], dbt_version: str = DEFAULT_DBT_VERSION) -> Venv:
return Venv.create(venv_dir, _create_dbt_deps(destination_names, dbt_version))


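The default dbt pin widens to `>=1.1,<1.5` and the constant is renamed from `DEFAULT_DLT_VERSION` to `DEFAULT_DBT_VERSION`. Below is a hypothetical helper illustrating the pinning rule described in the comment ("if parses as version use '==' operator"); the actual `_create_dbt_deps` implementation may differ:

```python
import contextlib
from packaging.version import Version  # InvalidVersion subclasses ValueError

DEFAULT_DBT_VERSION = ">=1.1,<1.5"

def dbt_requirement(dbt_version: str = DEFAULT_DBT_VERSION) -> str:
    # an exact version such as "1.3.2" is pinned with "=="; range specs pass through unchanged
    with contextlib.suppress(ValueError):
        Version(dbt_version)
        return f"dbt-core=={dbt_version}"
    return f"dbt-core{dbt_version}"
```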
4 changes: 2 additions & 2 deletions dlt/helpers/dbt/dbt_utils.py
@@ -23,9 +23,9 @@
raise MissingDependencyException("DBT Core", ["dbt-core"])

try:
from dbt.exceptions import FailFastException
from dbt.exceptions import FailFastException # type: ignore
except ImportError:
from dbt.exceptions import FailFastError as FailFastException # type: ignore
from dbt.exceptions import FailFastError as FailFastException

_DBT_LOGGER_INITIALIZED = False

4 changes: 2 additions & 2 deletions dlt/pipeline/dbt.py
@@ -5,12 +5,12 @@
from dlt.common.schema import Schema
from dlt.common.typing import TSecretValue

from dlt.helpers.dbt import create_venv as _create_venv, package_runner as _package_runner, DBTPackageRunner, DEFAULT_DLT_VERSION as _DEFAULT_DLT_VERSION, restore_venv as _restore_venv
from dlt.helpers.dbt import create_venv as _create_venv, package_runner as _package_runner, DBTPackageRunner, DEFAULT_DBT_VERSION as _DEFAULT_DBT_VERSION, restore_venv as _restore_venv
from dlt.pipeline.pipeline import Pipeline



def get_venv(pipeline: Pipeline, venv_path: str = "dbt", dbt_version: str = _DEFAULT_DLT_VERSION) -> Venv:
def get_venv(pipeline: Pipeline, venv_path: str = "dbt", dbt_version: str = _DEFAULT_DBT_VERSION) -> Venv:
"""Creates or restores a virtual environment in which the `dbt` packages are executed.
The recommended way to execute dbt package is to use a separate virtual environment where only the dbt-core