diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index 6ec1212d1a..2160025ea0 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -83,11 +83,11 @@ jobs: run: poetry install --no-interaction -E duckdb --with sentry-sdk - run: | - poetry run pytest tests/pipeline/test_pipeline.py + poetry run pytest tests/pipeline/test_pipeline.py tests/pipeline/test_import_export_schema.py if: runner.os != 'Windows' name: Run pipeline smoke tests with minimum deps Linux/MAC - run: | - poetry run pytest tests/pipeline/test_pipeline.py + poetry run pytest tests/pipeline/test_pipeline.py tests/pipeline/test_import_export_schema.py if: runner.os == 'Windows' name: Run smoke tests with minimum deps Windows shell: cmd diff --git a/.github/workflows/test_destination_athena_iceberg.yml b/.github/workflows/test_destination_athena_iceberg.yml index fa45b1b49b..d77e35f088 100644 --- a/.github/workflows/test_destination_athena_iceberg.yml +++ b/.github/workflows/test_destination_athena_iceberg.yml @@ -65,7 +65,7 @@ jobs: - name: Install dependencies # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction -E --with sentry-sdk --with pipeline + run: poetry install --no-interaction -E athena --with sentry-sdk --with pipeline - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index 80b77ce6c9..d73c109894 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -17,7 +17,8 @@ env: # Slack hook for chess in production example RUNTIME__SLACK_INCOMING_HOOK: ${{ secrets.RUNTIME__SLACK_INCOMING_HOOK }} - + # detect if the workflow is executed in a repo fork + IS_FORK: ${{ github.event.pull_request.head.repo.fork }} jobs: run_lint: diff --git a/Makefile b/Makefile index 0680f463ec..5aa2b2786c 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ help: @echo " test" @echo " tests all the components including destinations" @echo " test-load-local" - @echo " tests all components unsing local destinations: duckdb and postgres" + @echo " tests all components using local destinations: duckdb and postgres" @echo " test-common" @echo " tests common components" @echo " test-and-lint-snippets" diff --git a/dlt/__main__.py b/dlt/__main__.py new file mode 100644 index 0000000000..4399f0b557 --- /dev/null +++ b/dlt/__main__.py @@ -0,0 +1,4 @@ +from dlt.cli._dlt import main + +if __name__ == "__main__": + main() diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index e7b0e5f900..9dd6f00942 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,5 +1,5 @@ from typing import Any, ClassVar, Dict, List, Optional -from sqlalchemy.engine import URL, make_url +from dlt.common.libs.sql_alchemy import URL, make_url from dlt.common.configuration.specs.exceptions import InvalidConnectionString from dlt.common.typing import TSecretValue @@ -26,6 +26,7 @@ def parse_native_representation(self, native_value: Any) -> None: # update only values that are not None self.update({k: v for k, v in url._asdict().items() if v is not None}) if self.query is not None: + # query may be immutable so make it mutable self.query = dict(self.query) except Exception: raise 
InvalidConnectionString(self.__class__, native_value, self.drivername) diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index 6ce961d72c..659b4951df 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -13,7 +13,6 @@ from dlt.common.data_types.typing import TDataType from dlt.common.time import ( ensure_pendulum_datetime, - parse_iso_like_datetime, ensure_pendulum_date, ensure_pendulum_time, ) diff --git a/dlt/common/libs/numpy.py b/dlt/common/libs/numpy.py new file mode 100644 index 0000000000..ccf255c6a8 --- /dev/null +++ b/dlt/common/libs/numpy.py @@ -0,0 +1,6 @@ +from dlt.common.exceptions import MissingDependencyException + +try: + import numpy +except ModuleNotFoundError: + raise MissingDependencyException("DLT Numpy Helpers", ["numpy"]) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 31423665f7..183c27954b 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,4 +1,5 @@ from datetime import datetime, date # noqa: I251 +from pendulum.tz import UTC from typing import Any, Tuple, Optional, Union, Callable, Iterable, Iterator, Sequence, Tuple from dlt import version @@ -14,6 +15,7 @@ try: import pyarrow import pyarrow.parquet + import pyarrow.compute except ModuleNotFoundError: raise MissingDependencyException( "dlt parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "dlt Helpers for for parquet." @@ -314,13 +316,13 @@ def is_arrow_item(item: Any) -> bool: return isinstance(item, (pyarrow.Table, pyarrow.RecordBatch)) -def to_arrow_compute_input(value: Any, arrow_type: pyarrow.DataType) -> Any: +def to_arrow_scalar(value: Any, arrow_type: pyarrow.DataType) -> Any: """Converts python value to an arrow compute friendly version""" return pyarrow.scalar(value, type=arrow_type) -def from_arrow_compute_output(arrow_value: pyarrow.Scalar) -> Any: - """Converts arrow scalar into Python type. Currently adds "UTC" to naive date times.""" +def from_arrow_scalar(arrow_value: pyarrow.Scalar) -> Any: + """Converts arrow scalar into Python type. Currently adds "UTC" to naive date times and converts all others to UTC""" row_value = arrow_value.as_py() # dates are not represented as datetimes but I see connector-x represents # datetimes as dates and keeping the exact time inside. probably a bug @@ -328,7 +330,7 @@ def from_arrow_compute_output(arrow_value: pyarrow.Scalar) -> Any: if isinstance(row_value, date) and not isinstance(row_value, datetime): row_value = pendulum.from_timestamp(arrow_value.cast(pyarrow.int64()).as_py() / 1000) elif isinstance(row_value, datetime): - row_value = pendulum.instance(row_value) + row_value = pendulum.instance(row_value).in_tz("UTC") return row_value diff --git a/dlt/common/libs/sql_alchemy.py b/dlt/common/libs/sql_alchemy.py new file mode 100644 index 0000000000..2f3b51ec0d --- /dev/null +++ b/dlt/common/libs/sql_alchemy.py @@ -0,0 +1,446 @@ +""" +Ports fragments of URL class from Sql Alchemy to use them when dependency is not available. 
+""" + +from typing import cast + + +try: + import sqlalchemy +except ImportError: + # port basic functionality without the whole Sql Alchemy + + import re + from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, + ) + import collections.abc as collections_abc + from urllib.parse import ( + quote_plus, + parse_qsl, + quote, + unquote, + ) + + _KT = TypeVar("_KT", bound=Any) + _VT = TypeVar("_VT", bound=Any) + + class ImmutableDict(Dict[_KT, _VT]): + """Not a real immutable dict""" + + def __setitem__(self, __key: _KT, __value: _VT) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def __delitem__(self, _KT: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def update(self, *arg: Any, **kw: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + EMPTY_DICT: ImmutableDict[Any, Any] = ImmutableDict() + + def to_list(value: Any, default: Optional[List[Any]] = None) -> List[Any]: + if value is None: + return default + if not isinstance(value, collections_abc.Iterable) or isinstance(value, str): + return [value] + elif isinstance(value, list): + return value + else: + return list(value) + + class URL(NamedTuple): + """ + Represent the components of a URL used to connect to a database. + + Based on SqlAlchemy URL class with copyright as below: + + # engine/url.py + # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors + # + # This module is part of SQLAlchemy and is released under + # the MIT License: https://www.opensource.org/licenses/mit-license.php + """ + + drivername: str + """database backend and driver name, such as `postgresql+psycopg2`""" + username: Optional[str] + "username string" + password: Optional[str] + """password, which is normally a string but may also be any object that has a `__str__()` method.""" + host: Optional[str] + """hostname or IP number. May also be a data source name for some drivers.""" + port: Optional[int] + """integer port number""" + database: Optional[str] + """database name""" + query: ImmutableDict[str, Union[Tuple[str, ...], str]] + """an immutable mapping representing the query string. 
contains strings + for keys and either strings or tuples of strings for values""" + + @classmethod + def create( + cls, + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = None, + ) -> "URL": + """Create a new `URL` object.""" + return cls( + cls._assert_str(drivername, "drivername"), + cls._assert_none_str(username, "username"), + password, + cls._assert_none_str(host, "host"), + cls._assert_port(port), + cls._assert_none_str(database, "database"), + cls._str_dict(query or EMPTY_DICT), + ) + + @classmethod + def _assert_port(cls, port: Optional[int]) -> Optional[int]: + if port is None: + return None + try: + return int(port) + except TypeError: + raise TypeError("Port argument must be an integer or None") + + @classmethod + def _assert_str(cls, v: str, paramname: str) -> str: + if not isinstance(v, str): + raise TypeError("%s must be a string" % paramname) + return v + + @classmethod + def _assert_none_str(cls, v: Optional[str], paramname: str) -> Optional[str]: + if v is None: + return v + + return cls._assert_str(v, paramname) + + @classmethod + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> ImmutableDict[str, Union[Tuple[str, ...], str]]: + if dict_ is None: + return EMPTY_DICT + + @overload + def _assert_value( + val: str, + ) -> str: ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: + if isinstance(val, str): + return val + elif isinstance(val, collections_abc.Sequence): + return tuple(_assert_value(elem) for elem in val) + else: + raise TypeError( + "Query dictionary values must be strings or sequences of strings" + ) + + def _assert_str(v: str) -> str: + if not isinstance(v, str): + raise TypeError("Query dictionary keys must be strings") + return v + + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] + if isinstance(dict_, collections_abc.Sequence): + dict_items = dict_ + else: + dict_items = dict_.items() + + return ImmutableDict( + { + _assert_str(key): _assert_value( + value, + ) + for key, value in dict_items + } + ) + + def set( # noqa + self, + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> "URL": + """return a new `URL` object with modifications.""" + + kw: Dict[str, Any] = {} + if drivername is not None: + kw["drivername"] = drivername + if username is not None: + kw["username"] = username + if password is not None: + kw["password"] = password + if host is not None: + kw["host"] = host + if port is not None: + kw["port"] = port + if database is not None: + kw["database"] = database + if query is not None: + kw["query"] = query + + return self._assert_replace(**kw) + + def _assert_replace(self, **kw: Any) -> "URL": + """argument checks before calling _replace()""" + + if "drivername" in kw: + self._assert_str(kw["drivername"], "drivername") + for name in "username", "host", "database": + if name in kw: + self._assert_none_str(kw[name], name) + if "port" in kw: + self._assert_port(kw["port"]) + if "query" in kw: + 
kw["query"] = self._str_dict(kw["query"]) + + return self._replace(**kw) + + def update_query_string(self, query_string: str, append: bool = False) -> "URL": + return self.update_query_pairs(parse_qsl(query_string), append=append) + + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> "URL": + """Return a new `URL` object with the `query` parameter dictionary updated by the given sequence of key/value pairs""" + existing_query = self.query + new_keys: Dict[str, Union[str, List[str]]] = {} + + for key, value in key_value_pairs: + if key in new_keys: + new_keys[key] = to_list(new_keys[key]) + cast("List[str]", new_keys[key]).append(cast(str, value)) + else: + new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value + + new_query: Mapping[str, Union[str, Sequence[str]]] + if append: + new_query = {} + + for k in new_keys: + if k in existing_query: + new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) + else: + new_query[k] = new_keys[k] + + new_query.update( + {k: existing_query[k] for k in set(existing_query).difference(new_keys)} + ) + else: + new_query = ImmutableDict( + { + **self.query, + **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, + } + ) + return self.set(query=new_query) + + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> "URL": + return self.update_query_pairs(query_parameters.items(), append=append) + + def render_as_string(self, hide_password: bool = True) -> str: + """Render this `URL` object as a string.""" + s = self.drivername + "://" + if self.username is not None: + s += quote(self.username, safe=" +") + if self.password is not None: + s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) + s += "@" + if self.host is not None: + if ":" in self.host: + s += f"[{self.host}]" + else: + s += self.host + if self.port is not None: + s += ":" + str(self.port) + if self.database is not None: + s += "/" + self.database + if self.query: + keys = to_list(self.query) + keys.sort() + s += "?" + "&".join( + f"{quote_plus(k)}={quote_plus(element)}" + for k in keys + for element in to_list(self.query[k]) + ) + return s + + def __repr__(self) -> str: + return self.render_as_string() + + def __copy__(self) -> "URL": + return self.__class__.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + self.query.copy(), + ) + + def __deepcopy__(self, memo: Any) -> "URL": + return self.__copy__() + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + and self.port == other.port + ) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def get_backend_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the database backend in + use, and is the portion of the `drivername` + that is to the left of the plus sign. + + """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[0] + + def get_driver_name(self) -> str: + """Return the backend name. 
+ + This is the name that corresponds to the DBAPI driver in + use, and is the portion of the `drivername` + that is to the right of the plus sign. + """ + + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[1] + + def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. + + The format of the URL generally follows `RFC-1738`, with some exceptions, including + that underscores, and not dashes or periods, are accepted within the + "scheme" portion. + + If a `URL` object is passed, it is returned as is.""" + + if isinstance(name_or_url, str): + return _parse_url(name_or_url) + elif not isinstance(name_or_url, URL): + raise ValueError(f"Expected string or URL object, got {name_or_url!r}") + else: + return name_or_url + + def _parse_url(name: str) -> URL: + pattern = re.compile( + r""" + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P[^@]*))? + @)? + (?: + (?: + \[(?P[^/\?]+)\] | + (?P[^/:\?]+) + )? + (?::(?P[^/\?]*))? + )? + (?:/(?P[^\?]*))? + (?:\?(?P.*))? + """, + re.X, + ) + + m = pattern.match(name) + if m is not None: + components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] + if components["query"] is not None: + query = {} + + for key, value in parse_qsl(components["query"]): + if key in query: + query[key] = to_list(query[key]) + cast("List[str]", query[key]).append(value) + else: + query[key] = value + else: + query = None + + components["query"] = query + if components["username"] is not None: + components["username"] = unquote(components["username"]) + + if components["password"] is not None: + components["password"] = unquote(components["password"]) + + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") + + if components["port"]: + components["port"] = int(components["port"]) + + return URL.create(name, **components) # type: ignore + + else: + raise ValueError("Could not parse SQLAlchemy URL from string '%s'" % name) + +else: + from sqlalchemy.engine import URL, make_url # type: ignore[assignment] diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 6b7b308b44..df221ec703 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -102,13 +102,19 @@ def finished_at(self) -> datetime.datetime: def asdict(self) -> DictStrAny: # to be mixed with NamedTuple - d: DictStrAny = self._asdict() # type: ignore - d["pipeline"] = {"pipeline_name": self.pipeline.pipeline_name} - d["load_packages"] = [package.asdict() for package in self.load_packages] + step_info: DictStrAny = self._asdict() # type: ignore + step_info["pipeline"] = {"pipeline_name": self.pipeline.pipeline_name} + step_info["load_packages"] = [package.asdict() for package in self.load_packages] if self.metrics: - d["started_at"] = self.started_at - d["finished_at"] = self.finished_at - return d + step_info["started_at"] = self.started_at + step_info["finished_at"] = self.finished_at + all_metrics = [] + for load_id, metrics in step_info["metrics"].items(): + for metric in metrics: + all_metrics.append({**dict(metric), "load_id": load_id}) + + step_info["metrics"] = all_metrics + return step_info def __str__(self) -> str: return self.asstr(verbosity=0) diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 2a5cc75135..8e64c8ba64 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -64,5 +64,9 @@ def delayed_signals() -> Iterator[None]: 
signal.signal(signal.SIGINT, original_sigint_handler) signal.signal(signal.SIGTERM, original_sigterm_handler) else: - print("Running in daemon thread, signals not enabled") + if not TYPE_CHECKING: + from dlt.common.runtime import logger + else: + logger: Any = None + logger.info("Running in daemon thread, signals not enabled") yield diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index b73e45d489..4c81c8af72 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -124,10 +124,24 @@ def from_stored_schema(cls, stored_schema: TStoredSchema) -> "Schema": self._from_stored_schema(stored_schema) return self - def replace_schema_content(self, schema: "Schema") -> None: - self._reset_schema(schema.name, schema._normalizers_config) + def replace_schema_content( + self, schema: "Schema", link_to_replaced_schema: bool = False + ) -> None: + """Replaces content of the current schema with `schema` content. Does not compute new schema hash and + does not increase the numeric version. Optionally will link the replaced schema to incoming schema + by keeping its hash in prev hashes and setting stored hash to replaced schema hash. + """ # do not bump version so hash from `schema` is preserved - self._from_stored_schema(schema.to_dict(bump_version=False)) + stored_schema = schema.to_dict(bump_version=False) + if link_to_replaced_schema: + replaced_version_hash = self.stored_version_hash + assert replaced_version_hash is not None + # do not store hash if the replaced schema is identical + if stored_schema["version_hash"] != replaced_version_hash: + utils.store_prev_hash(stored_schema, replaced_version_hash) + stored_schema["version_hash"] = replaced_version_hash + self._reset_schema(schema.name, schema._normalizers_config) + self._from_stored_schema(stored_schema) def to_dict(self, remove_defaults: bool = False, bump_version: bool = True) -> TStoredSchema: stored_schema: TStoredSchema = { diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9cbd7266f2..ec60e4c365 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -18,6 +18,7 @@ from dlt.common.data_types import TDataType from dlt.common.normalizers.typing import TNormalizersConfig +from dlt.common.typing import TSortOrder try: from pydantic import BaseModel as _PydanticBaseModel @@ -71,7 +72,6 @@ TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]] TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" -TSortOrder = Literal["asc", "desc"] COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 835fe4279e..4f2a4aa22d 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -169,15 +169,21 @@ def bump_version_if_modified(stored_schema: TStoredSchema) -> Tuple[int, str, st pass elif hash_ != previous_hash: stored_schema["version"] += 1 - # unshift previous hash to previous_hashes and limit array to 10 entries - if previous_hash not in stored_schema["previous_hashes"]: - stored_schema["previous_hashes"].insert(0, previous_hash) - stored_schema["previous_hashes"] = stored_schema["previous_hashes"][:10] + store_prev_hash(stored_schema, previous_hash) stored_schema["version_hash"] = hash_ return stored_schema["version"], hash_, previous_hash, stored_schema["previous_hashes"] +def store_prev_hash( + stored_schema: TStoredSchema, previous_hash: str, max_history_len: int = 10 +) -> 
None: + # unshift previous hash to previous_hashes and limit array to 10 entries + if previous_hash not in stored_schema["previous_hashes"]: + stored_schema["previous_hashes"].insert(0, previous_hash) + stored_schema["previous_hashes"] = stored_schema["previous_hashes"][:max_history_len] + + def generate_version_hash(stored_schema: TStoredSchema) -> str: # generates hash out of stored schema content, excluding the hash itself and version schema_copy = copy(stored_schema) diff --git a/dlt/common/storages/live_schema_storage.py b/dlt/common/storages/live_schema_storage.py index e3fd07cf72..d3d5f14fe5 100644 --- a/dlt/common/storages/live_schema_storage.py +++ b/dlt/common/storages/live_schema_storage.py @@ -1,7 +1,8 @@ -from typing import Dict, List +from typing import Dict, List, cast from dlt.common.schema.schema import Schema from dlt.common.configuration.accessors import config +from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages.schema_storage import SchemaStorage from dlt.common.storages.configuration import SchemaStorageConfiguration @@ -23,10 +24,10 @@ def __getitem__(self, name: str) -> Schema: return schema - def load_schema(self, name: str) -> Schema: - self.commit_live_schema(name) - # now live schema is saved so we can load it with the changes - return super().load_schema(name) + # def load_schema(self, name: str) -> Schema: + # self.commit_live_schema(name) + # # now live schema is saved so we can load it with the changes + # return super().load_schema(name) def save_schema(self, schema: Schema) -> str: rv = super().save_schema(schema) @@ -55,6 +56,17 @@ def commit_live_schema(self, name: str) -> Schema: self._save_schema(live_schema) return live_schema + def is_live_schema_committed(self, name: str) -> bool: + """Checks if live schema is present in storage and have same hash""" + live_schema = self.live_schemas.get(name) + if live_schema is None: + raise SchemaNotFoundError(name, f"live-schema://{name}") + try: + stored_schema_json = self._load_schema_json(name) + return live_schema.version_hash == cast(str, stored_schema_json.get("version_hash")) + except FileNotFoundError: + return False + def update_live_schema(self, schema: Schema, can_create_new: bool = True) -> None: """Will update live schema content without writing to storage. 
Optionally allows to create a new live schema""" live_schema = self.live_schemas.get(schema.name) @@ -62,7 +74,7 @@ def update_live_schema(self, schema: Schema, can_create_new: bool = True) -> Non if id(live_schema) != id(schema): # replace content without replacing instance # print(f"live schema {live_schema} updated in place") - live_schema.replace_schema_content(schema) + live_schema.replace_schema_content(schema, link_to_replaced_schema=True) elif can_create_new: # print(f"live schema {schema.name} created from schema") self.live_schemas[schema.name] = schema diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 01f3923455..63409aa878 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -354,9 +354,11 @@ def remove_completed_jobs(self, load_id: str) -> None: recursively=True, ) - def delete_package(self, load_id: str) -> None: + def delete_package(self, load_id: str, not_exists_ok: bool = False) -> None: package_path = self.get_package_path(load_id) if not self.storage.has_folder(package_path): + if not_exists_ok: + return raise LoadPackageNotFound(load_id) self.storage.delete_folder(package_path, recursively=True) diff --git a/dlt/common/storages/schema_storage.py b/dlt/common/storages/schema_storage.py index a43b8a1f9b..4745d50dcc 100644 --- a/dlt/common/storages/schema_storage.py +++ b/dlt/common/storages/schema_storage.py @@ -1,5 +1,5 @@ import yaml -from typing import Iterator, List, Mapping, Tuple +from typing import Iterator, List, Mapping, Tuple, cast from dlt.common import json, logger from dlt.common.configuration import with_config @@ -31,12 +31,15 @@ def __init__( self.config = config self.storage = FileStorage(config.schema_volume_path, makedirs=makedirs) + def _load_schema_json(self, name: str) -> DictStrAny: + schema_file = self._file_name_in_store(name, "json") + return cast(DictStrAny, json.loads(self.storage.load(schema_file))) + def load_schema(self, name: str) -> Schema: # loads a schema from a store holding many schemas - schema_file = self._file_name_in_store(name, "json") storage_schema: DictStrAny = None try: - storage_schema = json.loads(self.storage.load(schema_file)) + storage_schema = self._load_schema_json(name) # prevent external modifications of schemas kept in storage if not verify_schema_hash(storage_schema, verifies_if_not_migrated=True): raise InStorageSchemaModified(name, self.config.schema_volume_path) diff --git a/dlt/common/time.py b/dlt/common/time.py index c06e2e2581..d3c8f9746c 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -44,9 +44,12 @@ def timestamp_before(timestamp: float, max_inclusive: Optional[float]) -> bool: def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]: - # we use internal pendulum parse function. the generic function, for example, parses string "now" as now() - # it also tries to parse ISO intervals but the code is very low quality + """Parses ISO8601 string into pendulum datetime, date or time. Preserves timezone info. + Note: naive datetimes will generated from string without timezone + we use internal pendulum parse function. 
the generic function, for example, parses string "now" as now() + it also tries to parse ISO intervals but the code is very low quality + """ # only iso dates are allowed dtv = None with contextlib.suppress(ValueError): @@ -57,7 +60,7 @@ def parse_iso_like_datetime(value: Any) -> Union[pendulum.DateTime, pendulum.Dat if isinstance(dtv, datetime.time): return pendulum.time(dtv.hour, dtv.minute, dtv.second, dtv.microsecond) if isinstance(dtv, datetime.datetime): - return pendulum.instance(dtv) + return pendulum.instance(dtv, tz=dtv.tzinfo) if isinstance(dtv, pendulum.Duration): raise ValueError("Interval ISO 8601 not supported: " + value) return pendulum.date(dtv.year, dtv.month, dtv.day) # type: ignore[union-attr] diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 9a776bd51d..05720fe7d9 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -24,6 +24,8 @@ Union, runtime_checkable, IO, + Iterator, + Generator, ) from typing_extensions import ( @@ -69,6 +71,9 @@ AnyFun: TypeAlias = Callable[..., Any] TFun = TypeVar("TFun", bound=AnyFun) # any function TAny = TypeVar("TAny", bound=Any) +TAnyFunOrGenerator = TypeVar( + "TAnyFunOrGenerator", AnyFun, Generator[Any, Optional[Any], Optional[Any]] +) TAnyClass = TypeVar("TAnyClass", bound=object) TimedeltaSeconds = Union[int, float, timedelta] # represent secret value ie. coming from Kubernetes/Docker secrets or other providers @@ -88,6 +93,7 @@ TVariantRV = Tuple[str, Any] VARIANT_FIELD_FORMAT = "v_%s" TFileOrPath = Union[str, os.PathLike, IO[Any]] +TSortOrder = Literal["asc", "desc"] @runtime_checkable diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 49a425780b..4ddde87758 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -43,9 +43,18 @@ RowCounts = Dict[str, int] -def chunks(seq: Sequence[T], n: int) -> Iterator[Sequence[T]]: - for i in range(0, len(seq), n): - yield seq[i : i + n] +def chunks(iterable: Iterable[T], n: int) -> Iterator[Sequence[T]]: + it = iter(iterable) + while True: + chunk = list() + try: + for _ in range(n): + chunk.append(next(it)) + except StopIteration: + if chunk: + yield chunk + break + yield chunk def uniq_id(len_: int = 16) -> str: @@ -272,8 +281,10 @@ def update_dict_with_prune(dest: DictStrAny, update: StrAny) -> None: del dest[k] -def update_dict_nested(dst: TDict, src: StrAny) -> TDict: - """Merges `src` into `dst` key wise. Does not recur into lists. Values in `src` overwrite `dst` if both keys exit.""" +def update_dict_nested(dst: TDict, src: StrAny, keep_dst_values: bool = False) -> TDict: + """Merges `src` into `dst` key wise. Does not recur into lists. Values in `src` overwrite `dst` if both keys exit. + Optionally (`keep_dst_values`) you can keep the `dst` value on conflict + """ # based on https://github.com/clarketm/mergedeep/blob/master/mergedeep/mergedeep.py def _is_recursive_merge(a: StrAny, b: StrAny) -> bool: @@ -290,7 +301,9 @@ def _is_recursive_merge(a: StrAny, b: StrAny) -> bool: # If a key exists in both objects and the values are `same`, the value from the `dst` object will be used. pass else: - dst[key] = src[key] + if not keep_dst_values: + # if not keep then overwrite + dst[key] = src[key] else: # If the key exists only in `src`, the value from the `src` object will be used. 
dst[key] = src[key] diff --git a/dlt/common/validation.py b/dlt/common/validation.py index 65b488bf8b..6bf1356aeb 100644 --- a/dlt/common/validation.py +++ b/dlt/common/validation.py @@ -90,7 +90,11 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: has_passed = True if not has_passed: type_names = [ - str(get_args(ut)) if is_literal_type(ut) else ut.__name__ + ( + str(get_args(ut)) + if is_literal_type(ut) + else getattr(ut, "__name__", str(ut)) + ) for ut in union_types ] raise DictValidationException( @@ -160,8 +164,11 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None: pass else: if not validator_f(path, pk, pv, t): + # TODO: when Python 3.9 and earlier support is + # dropped, just __name__ can be used + type_name = getattr(t, "__name__", str(t)) raise DictValidationException( - f"In {path}: field {pk} has expected type {t.__name__} which lacks validator", + f"In {path}: field {pk} has expected type {type_name} which lacks validator", path, pk, ) diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 26ca4a3883..1d630e9802 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -9,7 +9,7 @@ ) from dlt.destinations.utils import ensure_resource from dlt.extract import DltResource -from dlt.extract.typing import TTableHintTemplate +from dlt.extract.items import TTableHintTemplate PARTITION_HINT: Literal["x-bigquery-partition"] = "x-bigquery-partition" diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index f00998cfb2..45c448fab7 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,5 +1,5 @@ from typing import Final, ClassVar, Any, List, Dict, Optional, TYPE_CHECKING -from sqlalchemy.engine import URL +from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index f1d30a7342..109d422650 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -1,5 +1,5 @@ from typing import Final, ClassVar, Any, List, TYPE_CHECKING -from sqlalchemy.engine import URL +from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 01f5ca6e03..4f97f08700 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -3,7 +3,7 @@ from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING -from sqlalchemy.engine import URL +from dlt.common.libs.sql_alchemy import URL from dlt import version from dlt.common.exceptions import MissingDependencyException diff --git a/dlt/destinations/impl/synapse/synapse_adapter.py b/dlt/destinations/impl/synapse/synapse_adapter.py index 24932736f9..8b262f3621 100644 --- a/dlt/destinations/impl/synapse/synapse_adapter.py +++ b/dlt/destinations/impl/synapse/synapse_adapter.py @@ -1,7 +1,7 @@ from typing import Any, Literal, Set, get_args, Final, Dict from dlt.extract import DltResource, resource as make_resource -from dlt.extract.typing import TTableHintTemplate +from 
dlt.extract.items import TTableHintTemplate from dlt.extract.hints import TResourceHints from dlt.destinations.utils import ensure_resource diff --git a/dlt/extract/concurrency.py b/dlt/extract/concurrency.py new file mode 100644 index 0000000000..6a330b2645 --- /dev/null +++ b/dlt/extract/concurrency.py @@ -0,0 +1,237 @@ +import asyncio +from concurrent.futures import ( + ThreadPoolExecutor, + as_completed, + wait as wait_for_futures, +) +from threading import Thread +from typing import Awaitable, Dict, Optional + +from dlt.common.exceptions import PipelineException +from dlt.common.configuration.container import Container +from dlt.common.runtime.signals import sleep +from dlt.extract.items import DataItemWithMeta, TItemFuture, ResolvablePipeItem, FuturePipeItem + +from dlt.extract.exceptions import ( + DltSourceException, + ExtractorException, + PipeException, + ResourceExtractionError, +) + + +class FuturesPool: + """Worker pool for pipe items that can be resolved asynchronously. + + Items can be either asyncio coroutines or regular callables which will be executed in a thread pool. + """ + + def __init__( + self, workers: int = 5, poll_interval: float = 0.01, max_parallel_items: int = 20 + ) -> None: + self.futures: Dict[TItemFuture, FuturePipeItem] = {} + self._thread_pool: ThreadPoolExecutor = None + self._async_pool: asyncio.AbstractEventLoop = None + self._async_pool_thread: Thread = None + self.workers = workers + self.poll_interval = poll_interval + self.max_parallel_items = max_parallel_items + self.used_slots: int = 0 + + def __len__(self) -> int: + return len(self.futures) + + @property + def free_slots(self) -> int: + # Done futures don't count as slots, so we can still add futures + return self.max_parallel_items - self.used_slots + + @property + def empty(self) -> bool: + return len(self.futures) == 0 + + def _ensure_thread_pool(self) -> ThreadPoolExecutor: + # lazily start or return thread pool + if self._thread_pool: + return self._thread_pool + + self._thread_pool = ThreadPoolExecutor( + self.workers, thread_name_prefix=Container.thread_pool_prefix() + "threads" + ) + return self._thread_pool + + def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: + # lazily create async pool is separate thread + if self._async_pool: + return self._async_pool + + def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + self._async_pool = asyncio.new_event_loop() + self._async_pool_thread = Thread( + target=start_background_loop, + args=(self._async_pool,), + daemon=True, + name=Container.thread_pool_prefix() + "futures", + ) + self._async_pool_thread.start() + + # start or return async pool + return self._async_pool + + def _vacate_slot(self, _: TItemFuture) -> None: + # Used as callback to free up slot when future is done + self.used_slots -= 1 + + def submit(self, pipe_item: ResolvablePipeItem) -> TItemFuture: + """Submit an item to the pool. + + Args: + pipe_item: The pipe item to submit. `pipe_item.item` must be either an asyncio coroutine or a callable. + + Returns: + The resulting future object + """ + + # Sanity check, negative free slots means there's a bug somewhere + assert self.free_slots >= 0, "Worker pool has negative free slots, this should never happen" + + if self.free_slots == 0: + # Wait until some future is completed to ensure there's a free slot + # Note: This is probably not thread safe. 
If ever multiple threads will be submitting + # jobs to the pool, we ned to change this whole method to be inside a `threading.Lock` + self._wait_for_free_slot() + + future: Optional[TItemFuture] = None + + # submit to thread pool or async pool + item = pipe_item.item + if isinstance(item, Awaitable): + future = asyncio.run_coroutine_threadsafe(item, self._ensure_async_pool()) + elif callable(item): + future = self._ensure_thread_pool().submit(item) + else: + raise ValueError(f"Unsupported item type: {type(item)}") + + # Future is not removed from self.futures until it's been consumed by the + # pipe iterator. But we always want to vacate a slot so new jobs can be submitted + future.add_done_callback(self._vacate_slot) + self.used_slots += 1 + + self.futures[future] = FuturePipeItem( + future, pipe_item.step, pipe_item.pipe, pipe_item.meta + ) + return future + + def sleep(self) -> None: + sleep(self.poll_interval) + + def _resolve_future(self, future: TItemFuture) -> Optional[ResolvablePipeItem]: + future, step, pipe, meta = self.futures.pop(future) + + if ex := future.exception(): + if isinstance(ex, StopAsyncIteration): + return None + # Raise if any future fails + if isinstance( + ex, (PipelineException, ExtractorException, DltSourceException, PipeException) + ): + raise ex + raise ResourceExtractionError(pipe.name, future, str(ex), "future") from ex + + item = future.result() + + if item is None: + return None + elif isinstance(item, DataItemWithMeta): + return ResolvablePipeItem(item.data, step, pipe, item.meta) + else: + return ResolvablePipeItem(item, step, pipe, meta) + + def _next_done_future(self) -> Optional[TItemFuture]: + """Get the done future in the pool (if any). This does not block.""" + return next((fut for fut in self.futures if fut.done() and not fut.cancelled()), None) + + def resolve_next_future( + self, use_configured_timeout: bool = False + ) -> Optional[ResolvablePipeItem]: + """Block until the next future is done and return the result. Returns None if no futures done. + + Args: + use_configured_timeout: If True, use the value of `self.poll_interval` as the max wait time, + raises `concurrent.futures.TimeoutError` if no future is done within that time. + + Returns: + The resolved future item or None if no future is done. + """ + if not self.futures: + return None + + if (future := self._next_done_future()) is not None: + # When there are multiple already done futures from the same pipe we return results in insertion order + return self._resolve_future(future) + for future in as_completed( + self.futures, timeout=self.poll_interval if use_configured_timeout else None + ): + if future.cancelled(): + # Get the next not-cancelled future + continue + + return self._resolve_future(future) + + return None + + def resolve_next_future_no_wait(self) -> Optional[ResolvablePipeItem]: + """Resolve the first done future in the pool. + This does not block and returns None if no future is done. 
+ """ + # Get next done future + future = self._next_done_future() + if not future: + return None + + return self._resolve_future(future) + + def _wait_for_free_slot(self) -> None: + """Wait until any future in the pool is completed to ensure there's a free slot.""" + if self.free_slots >= 1: + return + + for future in as_completed(self.futures): + if future.cancelled(): + # Get the next not-cancelled future + continue + if self.free_slots == 0: + # Future was already completed so slot was not freed + continue + return # Return when first future completes + + def close(self) -> None: + # Cancel all futures + for f in self.futures: + if not f.done(): + f.cancel() + + def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: + loop.stop() + + if self._async_pool: + # wait for all async generators to be closed + future = asyncio.run_coroutine_threadsafe( + self._async_pool.shutdown_asyncgens(), self._ensure_async_pool() + ) + + wait_for_futures([future]) + self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) + + self._async_pool_thread.join() + self._async_pool = None + self._async_pool_thread = None + + if self._thread_pool: + self._thread_pool.shutdown(wait=True) + self._thread_pool = None + + self.futures.clear() diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index af8bb69c42..6e916ff6e1 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -62,7 +62,7 @@ ) from dlt.extract.incremental import IncrementalResourceWrapper -from dlt.extract.typing import TTableHintTemplate +from dlt.extract.items import TTableHintTemplate from dlt.extract.source import DltSource from dlt.extract.resource import DltResource, TUnboundDltResource @@ -301,6 +301,7 @@ def resource( table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, ) -> DltResource: ... @@ -318,6 +319,7 @@ def resource( table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: ... @@ -335,6 +337,7 @@ def resource( table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, standalone: Literal[True] = True, ) -> Callable[[Callable[TResourceFunParams, Any]], Callable[TResourceFunParams, DltResource]]: ... @@ -353,6 +356,7 @@ def resource( table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, ) -> DltResource: ... @@ -369,6 +373,7 @@ def resource( table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, standalone: bool = False, data_from: TUnboundDltResource = None, ) -> Any: @@ -427,6 +432,8 @@ def resource( data_from (TUnboundDltResource, optional): Allows to pipe data from one resource to another to build multi-step pipelines. + parallelized (bool, optional): If `True`, the resource generator will be extracted in parallel with other resources. Defaults to `False`. + Raises: ResourceNameMissing: indicates that name of the resource cannot be inferred from the `data` being passed. 
InvalidResourceDataType: indicates that the `data` argument cannot be converted into `dlt resource` @@ -447,7 +454,7 @@ def make_resource( schema_contract=schema_contract, table_format=table_format, ) - return DltResource.from_data( + resource = DltResource.from_data( _data, _name, _section, @@ -456,6 +463,9 @@ def make_resource( cast(DltResource, data_from), incremental=incremental, ) + if parallelized: + return resource.parallelize() + return resource def decorator( f: Callable[TResourceFunParams, Any] @@ -565,6 +575,7 @@ def transformer( merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, ) -> Callable[[Callable[Concatenate[TDataItem, TResourceFunParams], Any]], DltResource]: ... @@ -581,6 +592,7 @@ def transformer( merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, standalone: Literal[True] = True, ) -> Callable[ [Callable[Concatenate[TDataItem, TResourceFunParams], Any]], @@ -601,6 +613,7 @@ def transformer( merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, ) -> DltResource: ... @@ -617,6 +630,7 @@ def transformer( merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, standalone: Literal[True] = True, ) -> Callable[TResourceFunParams, DltResource]: ... @@ -633,6 +647,7 @@ def transformer( merge_key: TTableHintTemplate[TColumnNames] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, standalone: bool = False, ) -> Any: """A form of `dlt resource` that takes input from other resources via `data_from` argument in order to enrich or transform the data. @@ -707,6 +722,7 @@ def transformer( spec=spec, standalone=standalone, data_from=data_from, + parallelized=parallelized, ) diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index d24b6f5250..c3a20e72e5 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -1,9 +1,9 @@ -from inspect import Signature, isgenerator +from inspect import Signature, isgenerator, isgeneratorfunction, unwrap from typing import Any, Set, Type from dlt.common.exceptions import DltException from dlt.common.utils import get_callable_name -from dlt.extract.typing import ValidateItem, TDataItems +from dlt.extract.items import ValidateItem, TDataItems class ExtractorException(DltException): @@ -101,6 +101,17 @@ def __init__(self, pipe_name: str, gen: Any) -> None: super().__init__(pipe_name, msg) +class UnclosablePipe(PipeException): + def __init__(self, pipe_name: str, gen: Any) -> None: + type_name = str(type(gen)) + if gen_name := getattr(gen, "__name__", None): + type_name = f"{type_name} ({gen_name})" + msg = f"Pipe with gen of type {type_name} cannot be closed." + if callable(gen) and isgeneratorfunction(unwrap(gen)): + msg += " Closing of partially evaluated transformers is not yet supported." 
+ super().__init__(pipe_name, msg) + + class ResourceNameMissing(DltResourceException): def __init__(self) -> None: super().__init__( @@ -144,15 +155,14 @@ def __init__(self, resource_name: str, item: Any, _typ: Type[Any], msg: str) -> ) -class InvalidResourceDataTypeAsync(InvalidResourceDataType): +class InvalidParallelResourceDataType(InvalidResourceDataType): def __init__(self, resource_name: str, item: Any, _typ: Type[Any]) -> None: super().__init__( resource_name, item, _typ, - "Async iterators and generators are not valid resources. Please use standard iterators" - " and generators that yield Awaitables instead (for example by yielding from async" - " function without await", + "Parallel resource data must be a generator or a generator function. The provided" + f" data type for resource '{resource_name}' was {_typ.__name__}.", ) diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index c1ff5da80b..2ff813a2de 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -1,6 +1,6 @@ import contextlib from collections.abc import Sequence as C_Sequence -from datetime import datetime # noqa: 251 +from copy import copy import itertools from typing import List, Set, Dict, Optional, Set, Any import yaml @@ -33,7 +33,8 @@ from dlt.extract.decorators import SourceInjectableContext, SourceSchemaInjectableContext from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints -from dlt.extract.pipe import PipeIterator +from dlt.extract.incremental import IncrementalResourceWrapper +from dlt.extract.pipe_iterator import PipeIterator from dlt.extract.source import DltSource from dlt.extract.resource import DltResource from dlt.extract.storage import ExtractStorage @@ -75,6 +76,8 @@ def choose_schema() -> Schema: """Except of explicitly passed schema, use a clone that will get discarded if extraction fails""" if schema: schema_ = schema + # TODO: We should start with a new schema of the same name here ideally, but many tests fail + # because of this. So some investigation is needed. elif pipeline.default_schema_name: schema_ = pipeline.schemas[pipeline.default_schema_name].clone() else: @@ -200,14 +203,20 @@ def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics: for resource in source.selected_resources.values(): # cleanup the hints hints = clean_hints[resource.name] = {} - resource_hints = resource._hints or resource.compute_table_schema() + resource_hints = copy(resource._hints) or resource.compute_table_schema() + if resource.incremental and "incremental" not in resource_hints: + resource_hints["incremental"] = resource.incremental # type: ignore for name, hint in resource_hints.items(): if hint is None or name in ["validator"]: continue if name == "incremental": # represent incremental as dictionary (it derives from BaseConfiguration) - hints[name] = dict(hint) # type: ignore[call-overload] + if isinstance(hint, IncrementalResourceWrapper): + hint = hint._incremental + # sometimes internal incremental is not bound + if hint: + hints[name] = dict(hint) # type: ignore[call-overload] continue if name == "original_columns": # this is original type of the columns ie. 
Pydantic model diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index f6c3fde5d4..84abb4f3a8 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -1,6 +1,7 @@ from copy import copy from typing import Set, Dict, Any, Optional, Set +from dlt.common import logger from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import BaseConfiguration, configspec from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -20,7 +21,7 @@ ) from dlt.extract.hints import HintsMeta from dlt.extract.resource import DltResource -from dlt.extract.typing import TableNameMeta +from dlt.extract.items import TableNameMeta from dlt.extract.storage import ExtractStorage, ExtractorItemStorage try: @@ -30,9 +31,9 @@ pyarrow = None try: - import pandas as pd -except ModuleNotFoundError: - pd = None + from dlt.common.libs.pandas import pandas +except MissingDependencyException: + pandas = None class Extractor: @@ -78,7 +79,9 @@ def item_format(items: TDataItems) -> Optional[TLoaderFileFormat]: """ for item in items if isinstance(items, list) else [items]: # Assume all items in list are the same type - if (pyarrow and pyarrow.is_arrow_item(item)) or (pd and isinstance(item, pd.DataFrame)): + if (pyarrow and pyarrow.is_arrow_item(item)) or ( + pandas and isinstance(item, pandas.DataFrame) + ): return "arrow" return "puae-jsonl" return None # Empty list is unknown format @@ -222,7 +225,7 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No ( # 1. Convert pandas frame(s) to arrow Table pa.Table.from_pandas(item) - if (pd and isinstance(item, pd.DataFrame)) + if (pandas and isinstance(item, pandas.DataFrame)) else item ) for item in (items if isinstance(items, list) else [items]) @@ -289,9 +292,22 @@ def _compute_table(self, resource: DltResource, items: TDataItems) -> TPartialTa arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(items.schema) # normalize arrow table before merging arrow_table = self.schema.normalize_table_identifiers(arrow_table) + # issue warnings when overriding computed with arrow + for col_name, column in arrow_table["columns"].items(): + if src_column := computed_table["columns"].get(col_name): + print(src_column) + for hint_name, hint in column.items(): + if (src_hint := src_column.get(hint_name)) is not None: + if src_hint != hint: + logger.warning( + f"In resource: {resource.name}, when merging arrow schema on column" + f" {col_name}. The hint {hint_name} value {src_hint} defined in" + f" resource is overwritten from arrow with value {hint}." 
+ ) + # we must override the columns to preserve the order in arrow table arrow_table["columns"] = update_dict_nested( - arrow_table["columns"], computed_table["columns"] + arrow_table["columns"], computed_table["columns"], keep_dst_values=True ) return arrow_table diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 7f4f54389f..f298e414a1 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -21,7 +21,7 @@ InconsistentTableTemplate, ) from dlt.extract.incremental import Incremental -from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate, ValidateItem +from dlt.extract.items import TFunHintTemplate, TTableHintTemplate, ValidateItem from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint from dlt.extract.validation import create_item_validator diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index 955aa12efd..54e8b3d447 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -1,14 +1,11 @@ import os +from datetime import datetime # noqa: I251 from typing import Generic, ClassVar, Any, Optional, Type, Dict from typing_extensions import get_origin, get_args + import inspect from functools import wraps -try: - import pandas as pd -except ModuleNotFoundError: - pd = None - import dlt from dlt.common.exceptions import MissingDependencyException from dlt.common import pendulum, logger @@ -17,6 +14,7 @@ TDataItem, TDataItems, TFun, + TSortOrder, extract_inner_type, get_generic_type_argument_from_instance, is_optional_type, @@ -38,7 +36,7 @@ ) from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc from dlt.extract.pipe import Pipe -from dlt.extract.typing import SupportsPipe, TTableHintTemplate, ItemTransform +from dlt.extract.items import SupportsPipe, TTableHintTemplate, ItemTransform from dlt.extract.incremental.transform import ( JsonIncremental, ArrowIncremental, @@ -50,6 +48,11 @@ except MissingDependencyException: is_arrow_item = lambda item: False +try: + from dlt.common.libs.pandas import pandas +except MissingDependencyException: + pandas = None + @configspec class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorValue]): @@ -84,6 +87,9 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa end_value: Optional value used to load a limited range of records between `initial_value` and `end_value`. Use in conjunction with `initial_value`, e.g. load records from given month `incremental(initial_value="2022-01-01T00:00:00Z", end_value="2022-02-01T00:00:00Z")` Note, when this is set the incremental filtering is stateless and `initial_value` always supersedes any previous incremental value in state. + row_order: Declares that data source returns rows in descending (desc) or ascending (asc) order as defined by `last_value_func`. If row order is know, Incremental class + is able to stop requesting new rows by closing pipe generator. This prevents getting more data from the source. Defaults to None, which means that + row order is not known. allow_external_schedulers: If set to True, allows dlt to look for external schedulers from which it will take "initial_value" and "end_value" resulting in loading only specified range of data. Currently Airflow scheduler is detected: "data_interval_start" and "data_interval_end" are taken from the context and passed Incremental class. The values passed explicitly to Incremental will be ignored. 
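# A minimal sketch of how the new `row_order` argument described above might be used
# (the `fetch_tickets_page` helper and the field names are hypothetical). With
# row_order="desc" and the default last_value_func (max), Incremental can close the
# pipe generator once a row falls out of the incremental range, so older pages are
# never requested; the page currently being processed still completes.
import dlt

def fetch_tickets_page(page: int) -> list:
    # stand-in for a paginated API that returns tickets newest-first
    pages = [
        [{"id": 3, "updated_at": "2024-03-01T00:00:00Z"}],
        [{"id": 2, "updated_at": "2024-02-01T00:00:00Z"}],
        [{"id": 1, "updated_at": "2023-12-01T00:00:00Z"}],  # older than initial_value
    ]
    return pages[page] if page < len(pages) else []

@dlt.resource(primary_key="id")
def tickets(
    updated_at=dlt.sources.incremental(
        "updated_at",
        initial_value="2024-01-01T00:00:00Z",
        row_order="desc",  # source yields rows in descending cursor order
    ),
):
    page = 0
    while items := fetch_tickets_page(page):
        yield items
        page += 1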
@@ -95,6 +101,8 @@ class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorVa # TODO: Support typevar here initial_value: Optional[Any] = None end_value: Optional[Any] = None + row_order: Optional[TSortOrder] = None + allow_external_schedulers: bool = False # incremental acting as empty EMPTY: ClassVar["Incremental[Any]"] = None @@ -106,6 +114,7 @@ def __init__( last_value_func: Optional[LastValueFunc[TCursorValue]] = max, primary_key: Optional[TTableHintTemplate[TColumnNames]] = None, end_value: Optional[TCursorValue] = None, + row_order: Optional[TSortOrder] = None, allow_external_schedulers: bool = False, ) -> None: # make sure that path is valid @@ -120,6 +129,7 @@ def __init__( """Value of last_value at the beginning of current pipeline run""" self.resource_name: Optional[str] = None self._primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key + self.row_order = row_order self.allow_external_schedulers = allow_external_schedulers self._cached_state: IncrementalColumnState = None @@ -132,6 +142,8 @@ def __init__( """Becomes true on the first item that is out of range of `start_value`. I.e. when using `max` this is a value that is lower than `start_value`""" self._transformers: Dict[str, IncrementalTransform] = {} + self._bound_pipe: SupportsPipe = None + """Bound pipe""" @property def primary_key(self) -> Optional[TTableHintTemplate[TColumnNames]]: @@ -168,18 +180,6 @@ def from_existing_state( i.resource_name = resource_name return i - def copy(self) -> "Incremental[TCursorValue]": - # preserve Generic param information - constructor = self.__orig_class__ if hasattr(self, "__orig_class__") else self.__class__ - return constructor( # type: ignore - self.cursor_path, - initial_value=self.initial_value, - last_value_func=self.last_value_func, - primary_key=self._primary_key, - end_value=self.end_value, - allow_external_schedulers=self.allow_external_schedulers, - ) - def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue]": """Create a new incremental instance which merges the two instances. Only properties which are not `None` from `other` override the current instance properties. @@ -190,6 +190,7 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue >>> >>> my_resource(updated=incremental(initial_value='2023-01-01', end_value='2023-02-01')) """ + # func, resource name and primary key are not part of the dict kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self._primary_key) for key, value in dict( other, last_value_func=other.last_value_func, primary_key=other.primary_key @@ -204,7 +205,15 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue other.__orig_class__ if hasattr(other, "__orig_class__") else other.__class__ ) constructor = extract_inner_type(constructor) - return constructor(**kwargs) # type: ignore + merged = constructor(**kwargs) + merged.resource_name = self.resource_name + if other.resource_name: + merged.resource_name = other.resource_name + return merged # type: ignore + + def copy(self) -> "Incremental[TCursorValue]": + # merge creates a copy + return self.merge(self) def on_resolved(self) -> None: compile_path(self.cursor_path) @@ -213,6 +222,8 @@ def on_resolved(self) -> None: "Incremental 'end_value' was specified without 'initial_value'. 'initial_value' is" " required when using 'end_value'." 
) + self._cursor_datetime_check(self.initial_value, "initial_value") + self._cursor_datetime_check(self.initial_value, "end_value") # Ensure end value is "higher" than initial value if ( self.end_value is not None @@ -241,7 +252,10 @@ def parse_native_representation(self, native_value: Any) -> None: self.initial_value = native_value.initial_value self.last_value_func = native_value.last_value_func self.end_value = native_value.end_value - self.resource_name = self.resource_name + self.resource_name = native_value.resource_name + self._primary_key = native_value._primary_key + self.allow_external_schedulers = native_value.allow_external_schedulers + self.row_order = native_value.row_order else: # TODO: Maybe check if callable(getattr(native_value, '__lt__', None)) # Passing bare value `incremental=44` gets parsed as initial_value self.initial_value = native_value @@ -281,6 +295,16 @@ def _get_state(resource_name: str, cursor_path: str) -> IncrementalColumnState: # if state params is empty return state + @staticmethod + def _cursor_datetime_check(value: Any, arg_name: str) -> None: + if value and isinstance(value, datetime) and value.tzinfo is None: + logger.warning( + f"The {arg_name} argument {value} is a datetime without timezone. This may result" + " in an error when such values are compared by Incremental class. Note that `dlt`" + " stores datetimes in timezone-aware types so the UTC timezone will be added by" + " the destination" + ) + @property def last_value(self) -> Optional[TCursorValue]: s = self.get_state() @@ -289,9 +313,15 @@ def last_value(self) -> Optional[TCursorValue]: def _transform_item( self, transformer: IncrementalTransform, row: TDataItem ) -> Optional[TDataItem]: - row, start_out_of_range, end_out_of_range = transformer(row) - self.start_out_of_range = start_out_of_range - self.end_out_of_range = end_out_of_range + row, self.start_out_of_range, self.end_out_of_range = transformer(row) + # if we know that rows are ordered we can close the generator automatically + # mind that closing pipe will not immediately close processing. it only closes the + # generator so this page will be fully processed + # TODO: we cannot close partially evaluated transformer gen. to implement that + # we'd need to pass the source gen along with each yielded item and close this particular gen + # NOTE: with that implemented we could implement add_limit as a regular transform having access to gen + if self.can_close() and not self._bound_pipe.has_parent: + self._bound_pipe.close() return row def get_incremental_value_type(self) -> Type[Any]: @@ -372,6 +402,7 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]": if self.is_partial(): raise IncrementalCursorPathMissing(pipe.name, None, None) self.resource_name = pipe.name + self._bound_pipe = pipe # try to join external scheduler if self.allow_external_schedulers: self._join_external_scheduler() @@ -386,6 +417,21 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]": self._make_transforms() return self + def can_close(self) -> bool: + """Checks if incremental is out of range and can be closed. + + Returns true only when `row_order` was set and + 1. results are ordered ascending and are above upper bound (end_value) + 2. 
results are ordered descending and are below or equal lower bound (start_value) + """ + # ordered ascending, check if we cross upper bound + return ( + self.row_order == "asc" + and self.end_out_of_range + or self.row_order == "desc" + and self.start_out_of_range + ) + def __str__(self) -> str: return ( f"Incremental at {id(self)} for resource {self.resource_name} with cursor path:" @@ -397,7 +443,7 @@ def _get_transformer(self, items: TDataItems) -> IncrementalTransform: for item in items if isinstance(items, list) else [items]: if is_arrow_item(item): return self._transformers["arrow"] - elif pd is not None and isinstance(item, pd.DataFrame): + elif pandas is not None and isinstance(item, pandas.DataFrame): return self._transformers["arrow"] return self._transformers["json"] return self._transformers["json"] @@ -438,6 +484,7 @@ def __init__(self, primary_key: Optional[TTableHintTemplate[TColumnNames]] = Non self.primary_key = primary_key self.incremental_state: IncrementalColumnState = None self._allow_external_schedulers: bool = None + self._bound_pipe: SupportsPipe = None @staticmethod def should_wrap(sig: inspect.Signature) -> bool: @@ -511,7 +558,9 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: self._incremental.resolve() # in case of transformers the bind will be called before this wrapper is set: because transformer is called for a first time late in the pipe if self._resource_name: - self._incremental.bind(Pipe(self._resource_name)) + # rebind internal _incremental from wrapper that already holds + # instance of a Pipe + self.bind(None) bound_args.arguments[p.name] = self._incremental return func(*bound_args.args, **bound_args.kwargs) @@ -531,6 +580,9 @@ def allow_external_schedulers(self, value: bool) -> None: self._incremental.allow_external_schedulers = value def bind(self, pipe: SupportsPipe) -> "IncrementalResourceWrapper": + # if pipe is None we are re-binding internal incremental + pipe = pipe or self._bound_pipe + self._bound_pipe = pipe self._resource_name = pipe.name if self._incremental: if self._allow_external_schedulers is not None: diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 2fc78fe4ee..e20617cf63 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -1,16 +1,6 @@ from datetime import datetime, date # noqa: I251 from typing import Any, Optional, Tuple, List -try: - import pandas as pd -except ModuleNotFoundError: - pd = None - -try: - import numpy as np -except ModuleNotFoundError: - np = None - from dlt.common.exceptions import MissingDependencyException from dlt.common.utils import digest128 from dlt.common.json import json @@ -23,16 +13,20 @@ ) from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc from dlt.extract.utils import resolve_column_value -from dlt.extract.typing import TTableHintTemplate +from dlt.extract.items import TTableHintTemplate from dlt.common.schema.typing import TColumnNames try: from dlt.common.libs import pyarrow + from dlt.common.libs.pandas import pandas + from dlt.common.libs.numpy import numpy from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem - from dlt.common.libs.pyarrow import from_arrow_compute_output, to_arrow_compute_input + from dlt.common.libs.pyarrow import from_arrow_scalar, to_arrow_scalar except MissingDependencyException: pa = None pyarrow = None + numpy = None + pandas = None class IncrementalTransform: @@ -115,18 +109,21 @@ def __call__( Returns: Tuple (row, 
start_out_of_range, end_out_of_range) where row is either the data item or `None` if it is completely filtered out """ - start_out_of_range = end_out_of_range = False if row is None: - return row, start_out_of_range, end_out_of_range + return row, False, False row_value = self.find_cursor_value(row) + last_value = self.incremental_state["last_value"] # For datetime cursor, ensure the value is a timezone aware datetime. # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable - if isinstance(row_value, datetime): - row_value = pendulum.instance(row_value) - - last_value = self.incremental_state["last_value"] + if ( + isinstance(row_value, datetime) + and row_value.tzinfo is None + and isinstance(last_value, datetime) + and last_value.tzinfo is not None + ): + row_value = pendulum.instance(row_value).in_tz("UTC") # Check whether end_value has been reached # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value @@ -134,8 +131,7 @@ def __call__( self.last_value_func((row_value, self.end_value)) != self.end_value or self.last_value_func((row_value,)) == self.end_value ): - end_out_of_range = True - return None, start_out_of_range, end_out_of_range + return None, False, True check_values = (row_value,) + ((last_value,) if last_value is not None else ()) new_value = self.last_value_func(check_values) @@ -148,10 +144,10 @@ def __call__( # if unique value exists then use it to deduplicate if unique_value: if unique_value in self.incremental_state["unique_hashes"]: - return None, start_out_of_range, end_out_of_range + return None, False, False # add new hash only if the record row id is same as current last value self.incremental_state["unique_hashes"].append(unique_value) - return row, start_out_of_range, end_out_of_range + return row, False, False # skip the record that is not a last_value or new_value: that record was already processed check_values = (row_value,) + ( (self.start_value,) if self.start_value is not None else () @@ -159,17 +155,16 @@ def __call__( new_value = self.last_value_func(check_values) # Include rows == start_value but exclude "lower" if new_value == self.start_value and processed_row_value != self.start_value: - start_out_of_range = True - return None, start_out_of_range, end_out_of_range + return None, True, False else: - return row, start_out_of_range, end_out_of_range + return row, False, False else: self.incremental_state["last_value"] = new_value unique_value = self.unique_value(row, self.primary_key, self.resource_name) if unique_value: self.incremental_state["unique_hashes"] = [unique_value] - return row, start_out_of_range, end_out_of_range + return row, False, False class ArrowIncremental(IncrementalTransform): @@ -193,14 +188,14 @@ def _deduplicate( """Creates unique index if necessary.""" # create unique index if necessary if self._dlt_index not in tbl.schema.names: - tbl = pyarrow.append_column(tbl, self._dlt_index, pa.array(np.arange(tbl.num_rows))) + tbl = pyarrow.append_column(tbl, self._dlt_index, pa.array(numpy.arange(tbl.num_rows))) return tbl def __call__( self, tbl: "TAnyArrowItem", ) -> Tuple[TDataItem, bool, bool]: - is_pandas = pd is not None and isinstance(tbl, pd.DataFrame) + is_pandas = pandas is not None and isinstance(tbl, pandas.DataFrame) if is_pandas: tbl = pa.Table.from_pandas(tbl) @@ -250,9 +245,10 @@ def __call__( cursor_path = self.cursor_path # The new max/min value try: - row_value = from_arrow_compute_output(compute(tbl[cursor_path])) + # 
NOTE: datetimes are always pendulum in UTC + row_value = from_arrow_scalar(compute(tbl[cursor_path])) cursor_data_type = tbl.schema.field(cursor_path).type - row_value_scalar = to_arrow_compute_input(row_value, cursor_data_type) + row_value_scalar = to_arrow_scalar(row_value, cursor_data_type) except KeyError as e: raise IncrementalCursorPathMissing( self.resource_name, @@ -265,7 +261,7 @@ def __call__( # If end_value is provided, filter to include table rows that are "less" than end_value if self.end_value is not None: - end_value_scalar = to_arrow_compute_input(self.end_value, cursor_data_type) + end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type) tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar)) # Is max row value higher than end value? # NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary @@ -275,13 +271,13 @@ def __call__( if self.start_value is not None: # Remove rows lower than the last start value keep_filter = last_value_compare( - tbl[cursor_path], to_arrow_compute_input(self.start_value, cursor_data_type) + tbl[cursor_path], to_arrow_scalar(self.start_value, cursor_data_type) ) start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) tbl = tbl.filter(keep_filter) # Deduplicate after filtering old values - last_value_scalar = to_arrow_compute_input(last_value, cursor_data_type) + last_value_scalar = to_arrow_scalar(last_value, cursor_data_type) tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path) # Remove already processed rows where the cursor is equal to the last value eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value_scalar)) diff --git a/dlt/extract/typing.py b/dlt/extract/items.py similarity index 74% rename from dlt/extract/typing.py rename to dlt/extract/items.py index e0096a255f..c6e1f0a4b8 100644 --- a/dlt/extract/typing.py +++ b/dlt/extract/items.py @@ -5,13 +5,18 @@ Callable, Generic, Iterator, + Iterable, Literal, Optional, Protocol, TypeVar, Union, Awaitable, + TYPE_CHECKING, + NamedTuple, + Generator, ) +from concurrent.futures import Future from dlt.common.typing import TAny, TDataItem, TDataItems @@ -25,6 +30,55 @@ TFunHintTemplate = Callable[[TDataItem], TDynHintType] TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] +if TYPE_CHECKING: + TItemFuture = Future[TPipedDataItems] +else: + TItemFuture = Future + + +class PipeItem(NamedTuple): + item: TDataItems + step: int + pipe: "SupportsPipe" + meta: Any + + +class ResolvablePipeItem(NamedTuple): + # mypy unable to handle recursive types, ResolvablePipeItem should take itself in "item" + item: Union[TPipedDataItems, Iterator[TPipedDataItems]] + step: int + pipe: "SupportsPipe" + meta: Any + + +class FuturePipeItem(NamedTuple): + item: TItemFuture + step: int + pipe: "SupportsPipe" + meta: Any + + +class SourcePipeItem(NamedTuple): + item: Union[Iterator[TPipedDataItems], Iterator[ResolvablePipeItem]] + step: int + pipe: "SupportsPipe" + meta: Any + + +# pipeline step may be iterator of data items or mapping function that returns data item or another iterator +TPipeStep = Union[ + Iterable[TPipedDataItems], + Iterator[TPipedDataItems], + # Callable with meta + Callable[[TDataItems, Optional[Any]], TPipedDataItems], + Callable[[TDataItems, Optional[Any]], Iterator[TPipedDataItems]], + Callable[[TDataItems, Optional[Any]], Iterator[ResolvablePipeItem]], + # Callable without meta + Callable[[TDataItems], TPipedDataItems], + Callable[[TDataItems], Iterator[TPipedDataItems]], + 
Callable[[TDataItems], Iterator[ResolvablePipeItem]], +] + class DataItemWithMeta: __slots__ = "meta", "data" @@ -54,11 +108,28 @@ class SupportsPipe(Protocol): parent: "SupportsPipe" """A parent of the current pipe""" + @property + def gen(self) -> TPipeStep: + """A data generating step""" + ... + + def __getitem__(self, i: int) -> TPipeStep: + """Get pipe step at index""" + ... + + def __len__(self) -> int: + """Length of a pipe""" + ... + @property def has_parent(self) -> bool: """Checks if pipe is connected to parent pipe from which it takes data items. Connected pipes are created from transformer resources""" ... + def close(self) -> None: + """Closes pipe generator""" + ... + ItemTransformFunctionWithMeta = Callable[[TDataItem, str], TAny] ItemTransformFunctionNoMeta = Callable[[TDataItem], TAny] diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py index 3062ed083d..6517273db5 100644 --- a/dlt/extract/pipe.py +++ b/dlt/extract/pipe.py @@ -1,55 +1,27 @@ import inspect -import types -import asyncio import makefun -from asyncio import Future -from concurrent.futures import ThreadPoolExecutor from copy import copy -from threading import Thread -from typing import ( - Any, - AsyncIterator, - Dict, - Optional, - Sequence, - Union, - Callable, - Iterable, - Iterator, - List, - NamedTuple, - Awaitable, - Tuple, - Type, - TYPE_CHECKING, - Literal, -) +from typing import Any, AsyncIterator, Optional, Union, Callable, Iterable, Iterator, List, Tuple -from dlt.common import sleep -from dlt.common.configuration import configspec -from dlt.common.configuration.inject import with_config -from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext -from dlt.common.configuration.container import Container -from dlt.common.exceptions import PipelineException -from dlt.common.source import unset_current_pipe_name, set_current_pipe_name from dlt.common.typing import AnyFun, AnyType, TDataItems from dlt.common.utils import get_callable_name from dlt.extract.exceptions import ( CreatePipeException, - DltSourceException, - ExtractorException, InvalidStepFunctionArguments, InvalidResourceDataTypeFunctionNotAGenerator, InvalidTransformerGeneratorFunction, ParametrizedResourceUnbound, - PipeException, - PipeGenInvalid, - PipeItemProcessingError, PipeNotBoundToData, - ResourceExtractionError, + UnclosablePipe, +) +from dlt.extract.items import ( + ItemTransform, + ResolvablePipeItem, + SupportsPipe, + TPipeStep, + TPipedDataItems, ) -from dlt.extract.typing import DataItemWithMeta, ItemTransform, SupportsPipe, TPipedDataItems from dlt.extract.utils import ( check_compat_transformer, simulate_func_call, @@ -58,59 +30,6 @@ wrap_async_iterator, ) -if TYPE_CHECKING: - TItemFuture = Future[Union[TDataItems, DataItemWithMeta]] -else: - TItemFuture = Future - - -class PipeItem(NamedTuple): - item: TDataItems - step: int - pipe: "Pipe" - meta: Any - - -class ResolvablePipeItem(NamedTuple): - # mypy unable to handle recursive types, ResolvablePipeItem should take itself in "item" - item: Union[TPipedDataItems, Iterator[TPipedDataItems]] - step: int - pipe: "Pipe" - meta: Any - - -class FuturePipeItem(NamedTuple): - item: TItemFuture - step: int - pipe: "Pipe" - meta: Any - - -class SourcePipeItem(NamedTuple): - item: Union[Iterator[TPipedDataItems], Iterator[ResolvablePipeItem]] - step: int - pipe: "Pipe" - meta: Any - - -# pipeline step may be iterator of data items or mapping function that returns data item or another iterator -from dlt.common.typing import TDataItem - -TPipeStep = 
Union[ - Iterable[TPipedDataItems], - Iterator[TPipedDataItems], - # Callable with meta - Callable[[TDataItems, Optional[Any]], TPipedDataItems], - Callable[[TDataItems, Optional[Any]], Iterator[TPipedDataItems]], - Callable[[TDataItems, Optional[Any]], Iterator[ResolvablePipeItem]], - # Callable without meta - Callable[[TDataItems], TPipedDataItems], - Callable[[TDataItems], Iterator[TPipedDataItems]], - Callable[[TDataItems], Iterator[ResolvablePipeItem]], -] - -TPipeNextItemMode = Literal["fifo", "round_robin"] - class ForkPipe: def __init__(self, pipe: "Pipe", step: int = -1, copy_on_fork: bool = False) -> None: @@ -257,6 +176,15 @@ def replace_gen(self, gen: TPipeStep) -> None: assert not self.is_empty self._steps[self._gen_idx] = gen + def close(self) -> None: + """Closes pipe generator""" + gen = self.gen + # NOTE: async generator are wrapped in generators + if inspect.isgenerator(gen): + gen.close() + else: + raise UnclosablePipe(self.name, gen) + def full_pipe(self) -> "Pipe": """Creates a pipe that from the current and all the parent pipes.""" # prevent creating full pipe with unbound heads @@ -486,444 +414,3 @@ def __repr__(self) -> str: else: bound_str = "" return f"Pipe {self.name} [steps: {len(self._steps)}] at {id(self)}{bound_str}" - - -class PipeIterator(Iterator[PipeItem]): - @configspec - class PipeIteratorConfiguration(BaseConfiguration): - max_parallel_items: int = 20 - workers: int = 5 - futures_poll_interval: float = 0.01 - copy_on_fork: bool = False - next_item_mode: str = "fifo" - - __section__ = "extract" - - def __init__( - self, - max_parallel_items: int, - workers: int, - futures_poll_interval: float, - sources: List[SourcePipeItem], - next_item_mode: TPipeNextItemMode, - ) -> None: - self.max_parallel_items = max_parallel_items - self.workers = workers - self.futures_poll_interval = futures_poll_interval - self._async_pool: asyncio.AbstractEventLoop = None - self._async_pool_thread: Thread = None - self._thread_pool: ThreadPoolExecutor = None - self._sources = sources - self._futures: List[FuturePipeItem] = [] - self._next_item_mode: TPipeNextItemMode = next_item_mode - self._initial_sources_count = len(sources) - self._current_source_index: int = 0 - - @classmethod - @with_config(spec=PipeIteratorConfiguration) - def from_pipe( - cls, - pipe: Pipe, - *, - max_parallel_items: int = 20, - workers: int = 5, - futures_poll_interval: float = 0.01, - next_item_mode: TPipeNextItemMode = "fifo", - ) -> "PipeIterator": - # join all dependent pipes - if pipe.parent: - pipe = pipe.full_pipe() - # clone pipe to allow multiple iterations on pipe based on iterables/callables - pipe = pipe._clone() - # head must be iterator - pipe.evaluate_gen() - if not isinstance(pipe.gen, Iterator): - raise PipeGenInvalid(pipe.name, pipe.gen) - - # create extractor - sources = [SourcePipeItem(pipe.gen, 0, pipe, None)] - return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) - - @classmethod - @with_config(spec=PipeIteratorConfiguration) - def from_pipes( - cls, - pipes: Sequence[Pipe], - yield_parents: bool = True, - *, - max_parallel_items: int = 20, - workers: int = 5, - futures_poll_interval: float = 0.01, - copy_on_fork: bool = False, - next_item_mode: TPipeNextItemMode = "fifo", - ) -> "PipeIterator": - # print(f"max_parallel_items: {max_parallel_items} workers: {workers}") - sources: List[SourcePipeItem] = [] - - # clone all pipes before iterating (recursively) as we will fork them (this add steps) and evaluate gens - pipes, _ = 
PipeIterator.clone_pipes(pipes) - - def _fork_pipeline(pipe: Pipe) -> None: - if pipe.parent: - # fork the parent pipe - pipe.evaluate_gen() - pipe.parent.fork(pipe, copy_on_fork=copy_on_fork) - # make the parent yield by sending a clone of item to itself with position at the end - if yield_parents and pipe.parent in pipes: - # fork is last step of the pipe so it will yield - pipe.parent.fork(pipe.parent, len(pipe.parent) - 1, copy_on_fork=copy_on_fork) - _fork_pipeline(pipe.parent) - else: - # head of independent pipe must be iterator - pipe.evaluate_gen() - if not isinstance(pipe.gen, Iterator): - raise PipeGenInvalid(pipe.name, pipe.gen) - # add every head as source only once - if not any(i.pipe == pipe for i in sources): - sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) - - # reverse pipes for current mode, as we start processing from the back - pipes.reverse() - for pipe in pipes: - _fork_pipeline(pipe) - - # create extractor - return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) - - def __next__(self) -> PipeItem: - pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None - # __next__ should call itself to remove the `while` loop and continue clauses but that may lead to stack overflows: there's no tail recursion opt in python - # https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion (see Y combinator on how it could be emulated) - while True: - # do we need new item? - if pipe_item is None: - # process element from the futures - if len(self._futures) > 0: - pipe_item = self._resolve_futures() - # if none then take element from the newest source - if pipe_item is None: - pipe_item = self._get_source_item() - - if pipe_item is None: - if len(self._futures) == 0 and len(self._sources) == 0: - # no more elements in futures or sources - raise StopIteration() - else: - sleep(self.futures_poll_interval) - continue - - item = pipe_item.item - # if item is iterator, then add it as a new source - if isinstance(item, Iterator): - # print(f"adding iterable {item}") - self._sources.append( - SourcePipeItem(item, pipe_item.step, pipe_item.pipe, pipe_item.meta) - ) - pipe_item = None - continue - - # handle async iterator items as new source - if isinstance(item, AsyncIterator): - self._sources.append( - SourcePipeItem( - wrap_async_iterator(item), pipe_item.step, pipe_item.pipe, pipe_item.meta - ), - ) - pipe_item = None - continue - - if isinstance(item, Awaitable) or callable(item): - # do we have a free slot or one of the slots is done? 
- if len(self._futures) < self.max_parallel_items or self._next_future() >= 0: - # check if Awaitable first - awaitable can also be a callable - if isinstance(item, Awaitable): - future = asyncio.run_coroutine_threadsafe(item, self._ensure_async_pool()) - elif callable(item): - future = self._ensure_thread_pool().submit(item) - # print(future) - self._futures.append(FuturePipeItem(future, pipe_item.step, pipe_item.pipe, pipe_item.meta)) # type: ignore - # pipe item consumed for now, request a new one - pipe_item = None - continue - else: - # print("maximum futures exceeded, waiting") - sleep(self.futures_poll_interval) - # try same item later - continue - - # if we are at the end of the pipe then yield element - if pipe_item.step == len(pipe_item.pipe) - 1: - # must be resolved - if isinstance(item, (Iterator, Awaitable, AsyncIterator)) or callable(item): - raise PipeItemProcessingError( - pipe_item.pipe.name, - f"Pipe item at step {pipe_item.step} was not fully evaluated and is of type" - f" {type(pipe_item.item).__name__}. This is internal error or you are" - " yielding something weird from resources ie. functions or awaitables.", - ) - # mypy not able to figure out that item was resolved - return pipe_item # type: ignore - - # advance to next step - step = pipe_item.pipe[pipe_item.step + 1] - try: - set_current_pipe_name(pipe_item.pipe.name) - next_meta = pipe_item.meta - next_item = step(item, meta=pipe_item.meta) # type: ignore - if isinstance(next_item, DataItemWithMeta): - next_meta = next_item.meta - next_item = next_item.data - except TypeError as ty_ex: - assert callable(step) - raise InvalidStepFunctionArguments( - pipe_item.pipe.name, - get_callable_name(step), - inspect.signature(step), - str(ty_ex), - ) - except (PipelineException, ExtractorException, DltSourceException, PipeException): - raise - except Exception as ex: - raise ResourceExtractionError( - pipe_item.pipe.name, step, str(ex), "transform" - ) from ex - # create next pipe item if a value was returned. 
A None means that item was consumed/filtered out and should not be further processed - if next_item is not None: - pipe_item = ResolvablePipeItem( - next_item, pipe_item.step + 1, pipe_item.pipe, next_meta - ) - else: - pipe_item = None - - def close(self) -> None: - # unregister the pipe name right after execution of gen stopped - unset_current_pipe_name() - - def stop_background_loop(loop: asyncio.AbstractEventLoop) -> None: - loop.stop() - - # close all generators - for gen, _, _, _ in self._sources: - if inspect.isgenerator(gen): - gen.close() - self._sources.clear() - - # stop all futures - for f, _, _, _ in self._futures: - if not f.done(): - f.cancel() - - # let tasks cancel - if self._async_pool: - # wait for all async generators to be closed - future = asyncio.run_coroutine_threadsafe( - self._async_pool.shutdown_asyncgens(), self._ensure_async_pool() - ) - while not future.done(): - sleep(self.futures_poll_interval) - self._async_pool.call_soon_threadsafe(stop_background_loop, self._async_pool) - # print("joining thread") - self._async_pool_thread.join() - self._async_pool = None - self._async_pool_thread = None - if self._thread_pool: - self._thread_pool.shutdown(wait=True) - self._thread_pool = None - - self._futures.clear() - - def _ensure_async_pool(self) -> asyncio.AbstractEventLoop: - # lazily create async pool is separate thread - if self._async_pool: - return self._async_pool - - def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: - asyncio.set_event_loop(loop) - loop.run_forever() - - self._async_pool = asyncio.new_event_loop() - self._async_pool_thread = Thread( - target=start_background_loop, - args=(self._async_pool,), - daemon=True, - name=Container.thread_pool_prefix() + "futures", - ) - self._async_pool_thread.start() - - # start or return async pool - return self._async_pool - - def _ensure_thread_pool(self) -> ThreadPoolExecutor: - # lazily start or return thread pool - if self._thread_pool: - return self._thread_pool - - self._thread_pool = ThreadPoolExecutor( - self.workers, thread_name_prefix=Container.thread_pool_prefix() + "threads" - ) - return self._thread_pool - - def __enter__(self) -> "PipeIterator": - return self - - def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType - ) -> None: - self.close() - - def _next_future(self) -> int: - return next((i for i, val in enumerate(self._futures) if val.item.done()), -1) - - def _resolve_futures(self) -> ResolvablePipeItem: - # no futures at all - if len(self._futures) == 0: - return None - - # anything done? 
- idx = self._next_future() - if idx == -1: - # nothing done - return None - - future, step, pipe, meta = self._futures.pop(idx) - - if future.cancelled(): - # get next future - return self._resolve_futures() - - if future.exception(): - ex = future.exception() - if isinstance(ex, StopAsyncIteration): - return None - if isinstance( - ex, (PipelineException, ExtractorException, DltSourceException, PipeException) - ): - raise ex - raise ResourceExtractionError(pipe.name, future, str(ex), "future") from ex - - item = future.result() - - # we also interpret future items that are None to not be value to be consumed - if item is None: - return None - elif isinstance(item, DataItemWithMeta): - return ResolvablePipeItem(item.data, step, pipe, item.meta) - else: - return ResolvablePipeItem(item, step, pipe, meta) - - def _get_source_item(self) -> ResolvablePipeItem: - sources_count = len(self._sources) - # no more sources to iterate - if sources_count == 0: - return None - try: - first_evaluated_index: int = None - # always reset to end of list for fifo mode, also take into account that new sources can be added - # if too many new sources is added we switch to fifo not to exhaust them - if ( - self._next_item_mode == "fifo" - or (sources_count - self._initial_sources_count) >= self.max_parallel_items - ): - self._current_source_index = sources_count - 1 - else: - self._current_source_index = (self._current_source_index - 1) % sources_count - while True: - # if we have checked all sources once and all returned None, then we can sleep a bit - if self._current_source_index == first_evaluated_index: - sleep(self.futures_poll_interval) - # get next item from the current source - gen, step, pipe, meta = self._sources[self._current_source_index] - set_current_pipe_name(pipe.name) - if (item := next(gen)) is not None: - # full pipe item may be returned, this is used by ForkPipe step - # to redirect execution of an item to another pipe - if isinstance(item, ResolvablePipeItem): - return item - else: - # keep the item assigned step and pipe when creating resolvable item - if isinstance(item, DataItemWithMeta): - return ResolvablePipeItem(item.data, step, pipe, item.meta) - else: - return ResolvablePipeItem(item, step, pipe, meta) - # remember the first evaluated index - if first_evaluated_index is None: - first_evaluated_index = self._current_source_index - # always go round robin if None was returned - self._current_source_index = (self._current_source_index - 1) % sources_count - except StopIteration: - # remove empty iterator and try another source - self._sources.pop(self._current_source_index) - # decrease initial source count if we popped an initial source - if self._current_source_index < self._initial_sources_count: - self._initial_sources_count -= 1 - return self._get_source_item() - except (PipelineException, ExtractorException, DltSourceException, PipeException): - raise - except Exception as ex: - raise ResourceExtractionError(pipe.name, gen, str(ex), "generator") from ex - - @staticmethod - def clone_pipes( - pipes: Sequence[Pipe], existing_cloned_pairs: Dict[int, Pipe] = None - ) -> Tuple[List[Pipe], Dict[int, Pipe]]: - """This will clone pipes and fix the parent/dependent references""" - cloned_pipes = [p._clone() for p in pipes if id(p) not in (existing_cloned_pairs or {})] - cloned_pairs = {id(p): c for p, c in zip(pipes, cloned_pipes)} - if existing_cloned_pairs: - cloned_pairs.update(existing_cloned_pairs) - - for clone in cloned_pipes: - while True: - if not clone.parent: - break - # if 
already a clone - if clone.parent in cloned_pairs.values(): - break - # clone if parent pipe not yet cloned - parent_id = id(clone.parent) - if parent_id not in cloned_pairs: - # print("cloning:" + clone.parent.name) - cloned_pairs[parent_id] = clone.parent._clone() - # replace with clone - # print(f"replace depends on {clone.name} to {clone.parent.name}") - clone.parent = cloned_pairs[parent_id] - # recur with clone - clone = clone.parent - - return cloned_pipes, cloned_pairs - - -class ManagedPipeIterator(PipeIterator): - """A version of the pipe iterator that gets closed automatically on an exception in _next_""" - - _ctx: List[ContainerInjectableContext] = None - _container: Container = None - - def set_context(self, ctx: List[ContainerInjectableContext]) -> None: - """Sets list of injectable contexts that will be injected into Container for each call to __next__""" - self._ctx = ctx - self._container = Container() - - def __next__(self) -> PipeItem: - if self._ctx: - managers = [self._container.injectable_context(ctx) for ctx in self._ctx] - for manager in managers: - manager.__enter__() - try: - item = super().__next__() - except Exception as ex: - # release context manager - if self._ctx: - if isinstance(ex, StopIteration): - for manager in managers: - manager.__exit__(None, None, None) - else: - for manager in managers: - manager.__exit__(type(ex), ex, None) - # crash in next - self.close() - raise - if self._ctx: - for manager in managers: - manager.__exit__(None, None, None) - return item diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py new file mode 100644 index 0000000000..145b517802 --- /dev/null +++ b/dlt/extract/pipe_iterator.py @@ -0,0 +1,390 @@ +import inspect +import types +from typing import ( + AsyncIterator, + Dict, + Sequence, + Union, + Iterator, + List, + Awaitable, + Tuple, + Type, + Literal, +) +from concurrent.futures import TimeoutError as FutureTimeoutError + +from dlt.common.configuration import configspec +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext +from dlt.common.configuration.container import Container +from dlt.common.exceptions import PipelineException +from dlt.common.source import unset_current_pipe_name, set_current_pipe_name +from dlt.common.utils import get_callable_name + +from dlt.extract.exceptions import ( + DltSourceException, + ExtractorException, + InvalidStepFunctionArguments, + PipeException, + PipeGenInvalid, + PipeItemProcessingError, + ResourceExtractionError, +) +from dlt.extract.pipe import Pipe +from dlt.extract.items import DataItemWithMeta, PipeItem, ResolvablePipeItem, SourcePipeItem +from dlt.extract.utils import wrap_async_iterator +from dlt.extract.concurrency import FuturesPool + +TPipeNextItemMode = Literal["fifo", "round_robin"] + + +class PipeIterator(Iterator[PipeItem]): + @configspec + class PipeIteratorConfiguration(BaseConfiguration): + max_parallel_items: int = 20 + workers: int = 5 + futures_poll_interval: float = 0.01 + copy_on_fork: bool = False + next_item_mode: str = "fifo" + + __section__ = "extract" + + def __init__( + self, + max_parallel_items: int, + workers: int, + futures_poll_interval: float, + sources: List[SourcePipeItem], + next_item_mode: TPipeNextItemMode, + ) -> None: + self._sources = sources + self._next_item_mode: TPipeNextItemMode = next_item_mode + self._initial_sources_count = len(sources) + self._current_source_index: int = 0 + self._futures_pool = FuturesPool( + 
workers=workers, + poll_interval=futures_poll_interval, + max_parallel_items=max_parallel_items, + ) + + @classmethod + @with_config(spec=PipeIteratorConfiguration) + def from_pipe( + cls, + pipe: Pipe, + *, + max_parallel_items: int = 20, + workers: int = 5, + futures_poll_interval: float = 0.01, + next_item_mode: TPipeNextItemMode = "fifo", + ) -> "PipeIterator": + # join all dependent pipes + if pipe.parent: + pipe = pipe.full_pipe() + # clone pipe to allow multiple iterations on pipe based on iterables/callables + pipe = pipe._clone() + # head must be iterator + pipe.evaluate_gen() + if not isinstance(pipe.gen, Iterator): + raise PipeGenInvalid(pipe.name, pipe.gen) + + # create extractor + sources = [SourcePipeItem(pipe.gen, 0, pipe, None)] + return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) + + @classmethod + @with_config(spec=PipeIteratorConfiguration) + def from_pipes( + cls, + pipes: Sequence[Pipe], + yield_parents: bool = True, + *, + max_parallel_items: int = 20, + workers: int = 5, + futures_poll_interval: float = 0.01, + copy_on_fork: bool = False, + next_item_mode: TPipeNextItemMode = "fifo", + ) -> "PipeIterator": + # print(f"max_parallel_items: {max_parallel_items} workers: {workers}") + sources: List[SourcePipeItem] = [] + + # clone all pipes before iterating (recursively) as we will fork them (this add steps) and evaluate gens + pipes, _ = PipeIterator.clone_pipes(pipes) + + def _fork_pipeline(pipe: Pipe) -> None: + if pipe.parent: + # fork the parent pipe + pipe.evaluate_gen() + pipe.parent.fork(pipe, copy_on_fork=copy_on_fork) + # make the parent yield by sending a clone of item to itself with position at the end + if yield_parents and pipe.parent in pipes: + # fork is last step of the pipe so it will yield + pipe.parent.fork(pipe.parent, len(pipe.parent) - 1, copy_on_fork=copy_on_fork) + _fork_pipeline(pipe.parent) + else: + # head of independent pipe must be iterator + pipe.evaluate_gen() + if not isinstance(pipe.gen, Iterator): + raise PipeGenInvalid(pipe.name, pipe.gen) + # add every head as source only once + if not any(i.pipe == pipe for i in sources): + sources.append(SourcePipeItem(pipe.gen, 0, pipe, None)) + + # reverse pipes for current mode, as we start processing from the back + pipes.reverse() + for pipe in pipes: + _fork_pipeline(pipe) + + # create extractor + return cls(max_parallel_items, workers, futures_poll_interval, sources, next_item_mode) + + def __next__(self) -> PipeItem: + pipe_item: Union[ResolvablePipeItem, SourcePipeItem] = None + # __next__ should call itself to remove the `while` loop and continue clauses but that may lead to stack overflows: there's no tail recursion opt in python + # https://stackoverflow.com/questions/13591970/does-python-optimize-tail-recursion (see Y combinator on how it could be emulated) + while True: + # do we need new item? 
+ if pipe_item is None: + # Always check for done futures to avoid starving the pool + pipe_item = self._futures_pool.resolve_next_future_no_wait() + + if pipe_item is None: + # if none then take element from the newest source + pipe_item = self._get_source_item() + + if pipe_item is None: + # Wait for some time for futures to resolve + try: + pipe_item = self._futures_pool.resolve_next_future( + use_configured_timeout=True + ) + except FutureTimeoutError: + pass + else: + if pipe_item is None: + # pool was empty - then do a regular poll sleep + self._futures_pool.sleep() + + if pipe_item is None: + if self._futures_pool.empty and len(self._sources) == 0: + # no more elements in futures or sources + raise StopIteration() + else: + continue + + item = pipe_item.item + # if item is iterator, then add it as a new source + if isinstance(item, Iterator): + # print(f"adding iterable {item}") + self._sources.append( + SourcePipeItem(item, pipe_item.step, pipe_item.pipe, pipe_item.meta) + ) + pipe_item = None + continue + + # handle async iterator items as new source + if isinstance(item, AsyncIterator): + self._sources.append( + SourcePipeItem( + wrap_async_iterator(item), pipe_item.step, pipe_item.pipe, pipe_item.meta + ), + ) + pipe_item = None + continue + + if isinstance(item, Awaitable) or callable(item): + # Callables are submitted to the pool to be executed in the background + self._futures_pool.submit(pipe_item) # type: ignore[arg-type] + pipe_item = None + # Future will be resolved later, move on to the next item + continue + + # if we are at the end of the pipe then yield element + if pipe_item.step == len(pipe_item.pipe) - 1: + # must be resolved + if isinstance(item, (Iterator, Awaitable, AsyncIterator)) or callable(item): + raise PipeItemProcessingError( + pipe_item.pipe.name, + f"Pipe item at step {pipe_item.step} was not fully evaluated and is of type" + f" {type(pipe_item.item).__name__}. This is internal error or you are" + " yielding something weird from resources ie. functions or awaitables.", + ) + # mypy not able to figure out that item was resolved + return pipe_item # type: ignore + + # advance to next step + step = pipe_item.pipe[pipe_item.step + 1] + try: + set_current_pipe_name(pipe_item.pipe.name) + next_meta = pipe_item.meta + next_item = step(item, meta=pipe_item.meta) # type: ignore + if isinstance(next_item, DataItemWithMeta): + next_meta = next_item.meta + next_item = next_item.data + except TypeError as ty_ex: + assert callable(step) + raise InvalidStepFunctionArguments( + pipe_item.pipe.name, + get_callable_name(step), + inspect.signature(step), + str(ty_ex), + ) + except (PipelineException, ExtractorException, DltSourceException, PipeException): + raise + except Exception as ex: + raise ResourceExtractionError( + pipe_item.pipe.name, step, str(ex), "transform" + ) from ex + # create next pipe item if a value was returned. 
A None means that item was consumed/filtered out and should not be further processed + if next_item is not None: + pipe_item = ResolvablePipeItem( + next_item, pipe_item.step + 1, pipe_item.pipe, next_meta + ) + else: + pipe_item = None + + def _get_source_item(self) -> ResolvablePipeItem: + sources_count = len(self._sources) + # no more sources to iterate + if sources_count == 0: + return None + try: + first_evaluated_index: int = None + # always reset to end of list for fifo mode, also take into account that new sources can be added + # if too many new sources is added we switch to fifo not to exhaust them + if self._next_item_mode == "fifo" or ( + sources_count - self._initial_sources_count >= self._futures_pool.max_parallel_items + ): + self._current_source_index = sources_count - 1 + else: + self._current_source_index = (self._current_source_index - 1) % sources_count + while True: + # if we have checked all sources once and all returned None, return and poll/resolve some futures + if self._current_source_index == first_evaluated_index: + return None + # get next item from the current source + gen, step, pipe, meta = self._sources[self._current_source_index] + set_current_pipe_name(pipe.name) + + pipe_item = next(gen) + if pipe_item is not None: + # full pipe item may be returned, this is used by ForkPipe step + # to redirect execution of an item to another pipe + # else + if not isinstance(pipe_item, ResolvablePipeItem): + # keep the item assigned step and pipe when creating resolvable item + if isinstance(pipe_item, DataItemWithMeta): + return ResolvablePipeItem(pipe_item.data, step, pipe, pipe_item.meta) + else: + return ResolvablePipeItem(pipe_item, step, pipe, meta) + + if pipe_item is not None: + return pipe_item + + # remember the first evaluated index + if first_evaluated_index is None: + first_evaluated_index = self._current_source_index + # always go round robin if None was returned or item is to be run as future + self._current_source_index = (self._current_source_index - 1) % sources_count + + except StopIteration: + # remove empty iterator and try another source + self._sources.pop(self._current_source_index) + # decrease initial source count if we popped an initial source + if self._current_source_index < self._initial_sources_count: + self._initial_sources_count -= 1 + return self._get_source_item() + except (PipelineException, ExtractorException, DltSourceException, PipeException): + raise + except Exception as ex: + raise ResourceExtractionError(pipe.name, gen, str(ex), "generator") from ex + + def close(self) -> None: + # unregister the pipe name right after execution of gen stopped + unset_current_pipe_name() + + # Close the futures pool and cancel all tasks + # It's important to do this before closing generators as we can't close a running generator + self._futures_pool.close() + + # close all generators + for gen, _, _, _ in self._sources: + if inspect.isgenerator(gen): + gen.close() + + self._sources.clear() + + def __enter__(self) -> "PipeIterator": + return self + + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: types.TracebackType + ) -> None: + self.close() + + @staticmethod + def clone_pipes( + pipes: Sequence[Pipe], existing_cloned_pairs: Dict[int, Pipe] = None + ) -> Tuple[List[Pipe], Dict[int, Pipe]]: + """This will clone pipes and fix the parent/dependent references""" + cloned_pipes = [p._clone() for p in pipes if id(p) not in (existing_cloned_pairs or {})] + cloned_pairs = {id(p): c for p, c in zip(pipes, 
cloned_pipes)} + if existing_cloned_pairs: + cloned_pairs.update(existing_cloned_pairs) + + for clone in cloned_pipes: + while True: + if not clone.parent: + break + # if already a clone + if clone.parent in cloned_pairs.values(): + break + # clone if parent pipe not yet cloned + parent_id = id(clone.parent) + if parent_id not in cloned_pairs: + # print("cloning:" + clone.parent.name) + cloned_pairs[parent_id] = clone.parent._clone() + # replace with clone + # print(f"replace depends on {clone.name} to {clone.parent.name}") + clone.parent = cloned_pairs[parent_id] + # recur with clone + clone = clone.parent + + return cloned_pipes, cloned_pairs + + +class ManagedPipeIterator(PipeIterator): + """A version of the pipe iterator that gets closed automatically on an exception in _next_""" + + _ctx: List[ContainerInjectableContext] = None + _container: Container = None + + def set_context(self, ctx: List[ContainerInjectableContext]) -> None: + """Sets list of injectable contexts that will be injected into Container for each call to __next__""" + self._ctx = ctx + self._container = Container() + + def __next__(self) -> PipeItem: + if self._ctx: + managers = [self._container.injectable_context(ctx) for ctx in self._ctx] + for manager in managers: + manager.__enter__() + try: + item = super().__next__() + except Exception as ex: + # release context manager + if self._ctx: + if isinstance(ex, StopIteration): + for manager in managers: + manager.__exit__(None, None, None) + else: + for manager in managers: + manager.__exit__(type(ex), ex, None) + # crash in next + self.close() + raise + if self._ctx: + for manager in managers: + manager.__exit__(None, None, None) + return item diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index e08cafe2c1..0fef502112 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -1,5 +1,6 @@ from copy import deepcopy import inspect +from functools import partial from typing import ( AsyncIterable, AsyncIterator, @@ -24,9 +25,9 @@ pipeline_state, ) from dlt.common.utils import flatten_list_or_items, get_callable_name, uniq_id -from dlt.extract.utils import wrap_async_iterator +from dlt.extract.utils import wrap_async_iterator, wrap_parallel_iterator -from dlt.extract.typing import ( +from dlt.extract.items import ( DataItemWithMeta, ItemTransformFunc, ItemTransformFunctionWithMeta, @@ -36,7 +37,8 @@ YieldMapItem, ValidateItem, ) -from dlt.extract.pipe import Pipe, ManagedPipeIterator, TPipeStep +from dlt.extract.pipe_iterator import ManagedPipeIterator +from dlt.extract.pipe import Pipe, TPipeStep from dlt.extract.hints import DltResourceHints, HintsMeta, TResourceHints from dlt.extract.incremental import Incremental, IncrementalResourceWrapper from dlt.extract.exceptions import ( @@ -48,6 +50,7 @@ InvalidTransformerGeneratorFunction, InvalidResourceDataTypeBasic, InvalidResourceDataTypeMultiplePipes, + InvalidParallelResourceDataType, ParametrizedResourceUnbound, ResourceNameMissing, ResourceNotATransformer, @@ -311,12 +314,21 @@ def add_limit(self, max_items: int) -> "DltResource": # noqa: A003 "DltResource": returns self """ + # make sure max_items is a number, to allow "None" as value for unlimited + if max_items is None: + max_items = -1 + def _gen_wrap(gen: TPipeStep) -> TPipeStep: """Wrap a generator to take the first `max_items` records""" + + # zero items should produce empty generator + if max_items == 0: + return + count = 0 is_async_gen = False - if inspect.isfunction(gen): - gen = gen() + if callable(gen): + gen = gen() # type: 
ignore # wrap async gen already here if isinstance(gen, AsyncIterator): @@ -340,7 +352,31 @@ def _gen_wrap(gen: TPipeStep) -> TPipeStep: # transformers should be limited by their input, so we only limit non-transformers if not self.is_transformer: - self._pipe.replace_gen(_gen_wrap(self._pipe.gen)) + gen = self._pipe.gen + # wrap gen directly + if inspect.isgenerator(gen): + self._pipe.replace_gen(_gen_wrap(gen)) + else: + # keep function as function to not evaluate generators before pipe starts + self._pipe.replace_gen(partial(_gen_wrap, gen)) + return self + + def parallelize(self) -> "DltResource": + """Wraps the resource to execute each item in a threadpool to allow multiple resources to extract in parallel. + + The resource must be a generator or generator function or a transformer function. + """ + if ( + not inspect.isgenerator(self._pipe.gen) + and not ( + callable(self._pipe.gen) + and inspect.isgeneratorfunction(inspect.unwrap(self._pipe.gen)) + ) + and not (callable(self._pipe.gen) and self.is_transformer) + ): + raise InvalidParallelResourceDataType(self.name, self._pipe.gen, type(self._pipe.gen)) + + self._pipe.replace_gen(wrap_parallel_iterator(self._pipe.gen)) # type: ignore # TODO return self def add_step( diff --git a/dlt/extract/source.py b/dlt/extract/source.py index bc33394d4d..5d9799e29c 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -23,14 +23,16 @@ ) from dlt.common.utils import graph_find_scc_nodes, flatten_list_or_items, graph_edges_to_nodes -from dlt.extract.typing import TDecompositionStrategy -from dlt.extract.pipe import Pipe, ManagedPipeIterator +from dlt.extract.items import TDecompositionStrategy +from dlt.extract.pipe_iterator import ManagedPipeIterator +from dlt.extract.pipe import Pipe from dlt.extract.hints import DltResourceHints, make_hints from dlt.extract.resource import DltResource from dlt.extract.exceptions import ( DataItemRequiredForDynamicTableHints, ResourcesNotFoundError, DeletingResourcesNotSupported, + InvalidParallelResourceDataType, ) @@ -333,6 +335,18 @@ def add_limit(self, max_items: int) -> "DltSource": # noqa: A003 resource.add_limit(max_items) return self + def parallelize(self) -> "DltSource": + """Mark all resources in the source to run in parallel. + + Only transformers and resources based on generators and generator functions are supported, unsupported resources will be skipped. + """ + for resource in self.resources.selected.values(): + try: + resource.parallelize() + except InvalidParallelResourceDataType: + pass + return self + @property def run(self) -> SupportsPipelineRun: """A convenience method that will call `run` run on the currently active `dlt` pipeline. 
If pipeline instance is not found, one with default settings will be created.""" diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index fc27a5c39e..69edcab93d 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -2,6 +2,7 @@ import makefun import asyncio from typing import ( + Callable, Optional, Tuple, Union, @@ -13,20 +14,27 @@ AsyncGenerator, Awaitable, Generator, + Iterator, ) from collections.abc import Mapping as C_Mapping +from functools import wraps, partial from dlt.common.exceptions import MissingDependencyException from dlt.common.pipeline import reset_resource_state from dlt.common.schema.typing import TColumnNames, TAnySchemaColumns, TTableSchemaColumns -from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems +from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems, TAnyFunOrGenerator from dlt.common.utils import get_callable_name from dlt.extract.exceptions import ( InvalidResourceDataTypeFunctionNotAGenerator, InvalidStepFunctionArguments, ) -from dlt.extract.typing import TTableHintTemplate, TDataItem, TFunHintTemplate, SupportsPipe +from dlt.extract.items import ( + TTableHintTemplate, + TDataItem, + TFunHintTemplate, + SupportsPipe, +) try: from dlt.common.libs import pydantic @@ -171,6 +179,55 @@ async def run() -> TDataItems: exhausted = True +def wrap_parallel_iterator(f: TAnyFunOrGenerator) -> TAnyFunOrGenerator: + """Wraps a generator for parallel extraction""" + + def _gen_wrapper(*args: Any, **kwargs: Any) -> Iterator[TDataItems]: + gen: TAnyFunOrGenerator + if callable(f): + gen = f(*args, **kwargs) + else: + gen = f + + exhausted = False + busy = False + + def _parallel_gen() -> TDataItems: + nonlocal busy + nonlocal exhausted + try: + return next(gen) # type: ignore[call-overload] + except StopIteration: + exhausted = True + return None + finally: + busy = False + + while not exhausted: + try: + while busy: + yield None + busy = True + yield _parallel_gen + except GeneratorExit: + gen.close() # type: ignore[attr-defined] + raise + + if callable(f): + if inspect.isgeneratorfunction(inspect.unwrap(f)): + return wraps(f)(_gen_wrapper) # type: ignore[arg-type] + else: + + def _fun_wrapper(*args: Any, **kwargs: Any) -> Any: + def _curry() -> Any: + return f(*args, **kwargs) + + return _curry + + return wraps(f)(_fun_wrapper) # type: ignore[arg-type] + return _gen_wrapper() # type: ignore[return-value] + + def wrap_compat_transformer( name: str, f: AnyFun, sig: inspect.Signature, *args: Any, **kwargs: Any ) -> AnyFun: diff --git a/dlt/extract/validation.py b/dlt/extract/validation.py index 72b70c5661..504eee1bfc 100644 --- a/dlt/extract/validation.py +++ b/dlt/extract/validation.py @@ -8,7 +8,7 @@ from dlt.common.typing import TDataItems from dlt.common.schema.typing import TAnySchemaColumns, TSchemaContract, TSchemaEvolutionMode -from dlt.extract.typing import TTableHintTemplate, ValidateItem +from dlt.extract.items import TTableHintTemplate, ValidateItem _TPydanticModel = TypeVar("_TPydanticModel", bound=PydanticBaseModel) diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index 437602d3a4..9a6616e9ea 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -1,3 +1,4 @@ +import functools import os from tempfile import gettempdir from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple @@ -131,6 +132,158 @@ def __init__( if ConfigProvidersContext in Container(): del Container()[ConfigProvidersContext] + def run( + self, + pipeline: Pipeline, + data: 
Any, + table_name: str = None, + write_disposition: TWriteDisposition = None, + loader_file_format: TLoaderFileFormat = None, + schema_contract: TSchemaContract = None, + pipeline_name: str = None, + **kwargs: Any, + ) -> PythonOperator: + """ + Create a task to run the given pipeline with the + given data in Airflow. + + Args: + pipeline (Pipeline): The pipeline to run + data (Any): The data to run the pipeline with + table_name (str, optional): The name of the table to + which the data should be loaded within the `dataset`. + write_disposition (TWriteDisposition, optional): Same as + in `run` command. + loader_file_format (TLoaderFileFormat, optional): + The file format the loader will use to create the + load package. + schema_contract (TSchemaContract, optional): On override + for the schema contract settings, this will replace + the schema contract settings for all tables in the schema. + pipeline_name (str, optional): The name of the derived pipeline. + + Returns: + PythonOperator: Airflow task instance. + """ + f = functools.partial( + self._run, + pipeline, + data, + table_name=table_name, + write_disposition=write_disposition, + loader_file_format=loader_file_format, + schema_contract=schema_contract, + pipeline_name=pipeline_name, + ) + return PythonOperator(task_id=_task_name(pipeline, data), python_callable=f, **kwargs) + + def _run( + self, + pipeline: Pipeline, + data: Any, + table_name: str = None, + write_disposition: TWriteDisposition = None, + loader_file_format: TLoaderFileFormat = None, + schema_contract: TSchemaContract = None, + pipeline_name: str = None, + ) -> None: + """Run the given pipeline with the given data. + + Args: + pipeline (Pipeline): The pipeline to run + data (Any): The data to run the pipeline with + table_name (str, optional): The name of the + table to which the data should be loaded + within the `dataset`. + write_disposition (TWriteDisposition, optional): + Same as in `run` command. + loader_file_format (TLoaderFileFormat, optional): + The file format the loader will use to create + the load package. + schema_contract (TSchemaContract, optional): On + override for the schema contract settings, + this will replace the schema contract settings + for all tables in the schema. + pipeline_name (str, optional): The name of the + derived pipeline. 
+ """ + # activate pipeline + pipeline.activate() + # drop local data + task_pipeline = pipeline.drop(pipeline_name=pipeline_name) + + # use task logger + if self.use_task_logger: + ti: TaskInstance = get_current_context()["ti"] # type: ignore + logger.LOGGER = ti.log + + # set global number of buffered items + if dlt.config.get("data_writer.buffer_max_items") is None and self.buffer_max_items > 0: + dlt.config["data_writer.buffer_max_items"] = self.buffer_max_items + logger.info(f"Set data_writer.buffer_max_items to {self.buffer_max_items}") + + # enable abort package if job failed + if self.abort_task_if_any_job_failed: + dlt.config["load.raise_on_failed_jobs"] = True + logger.info("Set load.abort_task_if_any_job_failed to True") + + if self.log_progress_period > 0 and task_pipeline.collector == NULL_COLLECTOR: + task_pipeline.collector = log(log_period=self.log_progress_period, logger=logger.LOGGER) + logger.info(f"Enabled log progress with period {self.log_progress_period}") + + logger.info(f"Pipeline data in {task_pipeline.working_dir}") + + def log_after_attempt(retry_state: RetryCallState) -> None: + if not retry_state.retry_object.stop(retry_state): + logger.error( + "Retrying pipeline run due to exception: %s", + retry_state.outcome.exception(), + ) + + try: + # retry with given policy on selected pipeline steps + for attempt in self.retry_policy.copy( + retry=retry_if_exception( + retry_load(retry_on_pipeline_steps=self.retry_pipeline_steps) + ), + after=log_after_attempt, + ): + with attempt: + logger.info( + "Running the pipeline, attempt=%s" % attempt.retry_state.attempt_number + ) + load_info = task_pipeline.run( + data, + table_name=table_name, + write_disposition=write_disposition, + loader_file_format=loader_file_format, + schema_contract=schema_contract, + ) + logger.info(str(load_info)) + # save load and trace + if self.save_load_info: + logger.info("Saving the load info in the destination") + task_pipeline.run( + [load_info], + table_name="_load_info", + loader_file_format=loader_file_format, + ) + if self.save_trace_info: + logger.info("Saving the trace in the destination") + task_pipeline.run( + [task_pipeline.last_trace], + table_name="_trace", + loader_file_format=loader_file_format, + ) + # raise on failed jobs if requested + if self.fail_task_if_any_job_failed: + load_info.raise_on_failed_jobs() + finally: + # always completely wipe out pipeline folder, in case of success and failure + if self.wipe_local_data: + logger.info(f"Removing folder {pipeline.working_dir}") + task_pipeline._wipe_working_folder() + @with_telemetry("helper", "airflow_add_run", False, "decompose") def add_run( self, @@ -194,106 +347,23 @@ def add_run( " pipelines directory is not set correctly." 
) - def task_name(pipeline: Pipeline, data: Any) -> str: - task_name = pipeline.pipeline_name - if isinstance(data, DltSource): - resource_names = list(data.selected_resources.keys()) - task_name = data.name + "_" + "-".join(resource_names[:4]) - if len(resource_names) > 4: - task_name += f"-{len(resource_names)-4}-more" - return task_name - with self: # use factory function to make a task, in order to parametrize it # passing arguments to task function (_run) is serializing # them and running template engine on them def make_task(pipeline: Pipeline, data: Any, name: str = None) -> PythonOperator: - def _run() -> None: - # activate pipeline - pipeline.activate() - # drop local data - task_pipeline = pipeline.drop(pipeline_name=name) - - # use task logger - if self.use_task_logger: - ti: TaskInstance = get_current_context()["ti"] # type: ignore - logger.LOGGER = ti.log - - # set global number of buffered items - if ( - dlt.config.get("data_writer.buffer_max_items") is None - and self.buffer_max_items > 0 - ): - dlt.config["data_writer.buffer_max_items"] = self.buffer_max_items - logger.info(f"Set data_writer.buffer_max_items to {self.buffer_max_items}") - - # enable abort package if job failed - if self.abort_task_if_any_job_failed: - dlt.config["load.raise_on_failed_jobs"] = True - logger.info("Set load.abort_task_if_any_job_failed to True") - - if self.log_progress_period > 0 and task_pipeline.collector == NULL_COLLECTOR: - task_pipeline.collector = log( - log_period=self.log_progress_period, logger=logger.LOGGER - ) - logger.info(f"Enabled log progress with period {self.log_progress_period}") - - logger.info(f"Pipeline data in {task_pipeline.working_dir}") - - def log_after_attempt(retry_state: RetryCallState) -> None: - if not retry_state.retry_object.stop(retry_state): - logger.error( - "Retrying pipeline run due to exception: %s", - retry_state.outcome.exception(), - ) - - try: - # retry with given policy on selected pipeline steps - for attempt in self.retry_policy.copy( - retry=retry_if_exception( - retry_load(retry_on_pipeline_steps=self.retry_pipeline_steps) - ), - after=log_after_attempt, - ): - with attempt: - logger.info( - "Running the pipeline, attempt=%s" - % attempt.retry_state.attempt_number - ) - load_info = task_pipeline.run( - data, - table_name=table_name, - write_disposition=write_disposition, - loader_file_format=loader_file_format, - schema_contract=schema_contract, - ) - logger.info(str(load_info)) - # save load and trace - if self.save_load_info: - logger.info("Saving the load info in the destination") - task_pipeline.run( - [load_info], - table_name="_load_info", - loader_file_format=loader_file_format, - ) - if self.save_trace_info: - logger.info("Saving the trace in the destination") - task_pipeline.run( - [task_pipeline.last_trace], - table_name="_trace", - loader_file_format=loader_file_format, - ) - # raise on failed jobs if requested - if self.fail_task_if_any_job_failed: - load_info.raise_on_failed_jobs() - finally: - # always completely wipe out pipeline folder, in case of success and failure - if self.wipe_local_data: - logger.info(f"Removing folder {pipeline.working_dir}") - task_pipeline._wipe_working_folder() - + f = functools.partial( + self._run, + pipeline, + data, + table_name=table_name, + write_disposition=write_disposition, + loader_file_format=loader_file_format, + schema_contract=schema_contract, + pipeline_name=name, + ) return PythonOperator( - task_id=task_name(pipeline, data), python_callable=_run, **kwargs + task_id=_task_name(pipeline, 
data), python_callable=f, **kwargs ) if decompose == "none": @@ -323,7 +393,7 @@ def log_after_attempt(retry_state: RetryCallState) -> None: tasks = [] sources = data.decompose("scc") - t_name = task_name(pipeline, data) + t_name = _task_name(pipeline, data) start = make_task(pipeline, sources[0]) # parallel tasks @@ -364,16 +434,16 @@ def log_after_attempt(retry_state: RetryCallState) -> None: start = make_task( pipeline, sources[0], - naming.normalize_identifier(task_name(pipeline, sources[0])), + naming.normalize_identifier(_task_name(pipeline, sources[0])), ) # parallel tasks for source in sources[1:]: # name pipeline the same as task - new_pipeline_name = naming.normalize_identifier(task_name(pipeline, source)) + new_pipeline_name = naming.normalize_identifier(_task_name(pipeline, source)) tasks.append(make_task(pipeline, source, new_pipeline_name)) - t_name = task_name(pipeline, data) + t_name = _task_name(pipeline, data) end = DummyOperator(task_id=f"{t_name}_end") if tasks: @@ -388,10 +458,6 @@ def log_after_attempt(retry_state: RetryCallState) -> None: " 'parallel-isolated']" ) - def add_fun(self, f: Callable[..., Any], **kwargs: Any) -> Any: - """Will execute a function `f` inside an Airflow task. It is up to the function to create pipeline and source(s)""" - raise NotImplementedError() - def airflow_get_execution_dates() -> Tuple[pendulum.DateTime, Optional[pendulum.DateTime]]: # prefer logging to task logger @@ -402,3 +468,25 @@ def airflow_get_execution_dates() -> Tuple[pendulum.DateTime, Optional[pendulum. return context["data_interval_start"], context["data_interval_end"] except Exception: return None, None + + +def _task_name(pipeline: Pipeline, data: Any) -> str: + """Generate a task name. + + Args: + pipeline (Pipeline): The pipeline to run. + data (Any): The data to run the pipeline with. + + Returns: + str: The name of the task. + """ + task_name = pipeline.pipeline_name + + if isinstance(data, DltSource): + resource_names = list(data.selected_resources.keys()) + task_name = data.name + "_" + "-".join(resource_names[:4]) + + if len(resource_names) > 4: + task_name += f"-{len(resource_names)-4}-more" + + return task_name diff --git a/dlt/helpers/streamlit_helper.py b/dlt/helpers/streamlit_helper.py index d3e194b18d..f6b2f3a62f 100644 --- a/dlt/helpers/streamlit_helper.py +++ b/dlt/helpers/streamlit_helper.py @@ -9,7 +9,7 @@ from dlt.common.destination.reference import WithStateSync from dlt.common.utils import flatten_list_or_items -from dlt.common.libs.pandas import pandas as pd +from dlt.common.libs.pandas import pandas from dlt.pipeline import Pipeline from dlt.pipeline.exceptions import CannotRestorePipelineException, SqlClientNotAvailable from dlt.pipeline.state_sync import load_state_from_destination @@ -102,7 +102,7 @@ def write_load_status_page(pipeline: Pipeline) -> None: """Display pipeline loading information. 
Will be moved to dlt package once tested""" @cache_data(ttl=600) - def _query_data(query: str, schema_name: str = None) -> pd.DataFrame: + def _query_data(query: str, schema_name: str = None) -> pandas.DataFrame: try: with pipeline.sql_client(schema_name) as client: with client.execute_query(query) as curr: @@ -111,7 +111,7 @@ def _query_data(query: str, schema_name: str = None) -> pd.DataFrame: st.error("Cannot load data - SqlClient not available") @cache_data(ttl=5) - def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: + def _query_data_live(query: str, schema_name: str = None) -> pandas.DataFrame: try: with pipeline.sql_client(schema_name) as client: with client.execute_query(query) as curr: @@ -244,7 +244,7 @@ def write_data_explorer_page( """ @cache_data(ttl=60) - def _query_data(query: str, chunk_size: int = None) -> pd.DataFrame: + def _query_data(query: str, chunk_size: int = None) -> pandas.DataFrame: try: with pipeline.sql_client(schema_name) as client: with client.execute_query(query) as curr: diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index d360a1c7c4..c5762af680 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -340,6 +340,8 @@ def spool_files( ) def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str]) -> str: + # delete existing folder for the case that this is a retry + self.load_storage.new_packages.delete_package(load_id, not_exists_ok=True) # normalized files will go here before being atomically renamed self.load_storage.new_packages.create_package(load_id) logger.info(f"Created new load package {load_id} on loading volume") @@ -372,6 +374,20 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: for load_id in load_ids: # read schema from package schema = self.normalize_storage.extracted_packages.load_schema(load_id) + # prefer schema from schema storage if it exists + try: + # also import the schema + storage_schema = self.schema_storage.load_schema(schema.name) + if schema.stored_version_hash != storage_schema.stored_version_hash: + logger.warning( + f"When normalizing package {load_id} with schema {schema.name}: the storage" + f" schema hash {storage_schema.stored_version_hash} is different from" + f" extract package schema hash {schema.stored_version_hash}. Storage schema" + " was used." 
+ ) + schema = storage_schema + except FileNotFoundError: + pass # read all files to normalize placed as new jobs schema_files = self.normalize_storage.extracted_packages.list_new_jobs(load_id) logger.info( diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 44edcf2da5..185a11962a 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -161,13 +161,30 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: for name in self._schema_storage.live_schemas: # refresh live schemas in storage or import schema path self._schema_storage.commit_live_schema(name) - rv = f(self, *args, **kwargs) - # save modified live schemas - for name in self._schema_storage.live_schemas: - self._schema_storage.commit_live_schema(name) - # refresh list of schemas if any new schemas are added - self.schema_names = self._list_schemas_sorted() - return rv + try: + rv = f(self, *args, **kwargs) + except Exception: + # because we committed live schema before calling f, we may safely + # drop all changes in live schemas + for name in list(self._schema_storage.live_schemas.keys()): + try: + schema = self._schema_storage.load_schema(name) + self._schema_storage.update_live_schema(schema, can_create_new=False) + except FileNotFoundError: + # no storage schema yet so pop live schema (created in call to f) + self._schema_storage.live_schemas.pop(name, None) + # NOTE: with_state_sync will restore schema_names and default_schema_name + # so we do not need to do that here + raise + else: + # save modified live schemas + for name, schema in self._schema_storage.live_schemas.items(): + self._schema_storage.commit_live_schema(name) + # also save import schemas only here + self._schema_storage.save_import_schema_if_not_exists(schema) + # refresh list of schemas if any new schemas are added + self.schema_names = self._list_schemas_sorted() + return rv return _wrap # type: ignore @@ -1019,20 +1036,32 @@ def _extract_source( self, extract: Extract, source: DltSource, max_parallel_items: int, workers: int ) -> str: # discover the existing pipeline schema - if source.schema.name in self.schemas: - # use clone until extraction complete - pipeline_schema = self.schemas[source.schema.name].clone() + try: + # all live schemas are initially committed and during the extract will accumulate changes in memory + # if schema is committed try to take schema from storage + if self._schema_storage.is_live_schema_committed(source.schema.name): + # this will (1) save live schema if modified (2) look for import schema if present + # (3) load import schema an overwrite pipeline schema if import schema modified + # (4) load pipeline schema if no import schema is present + pipeline_schema = self.schemas.load_schema(source.schema.name) + else: + # if schema is not committed we know we are in process of extraction + pipeline_schema = self.schemas[source.schema.name] + pipeline_schema = pipeline_schema.clone() # use clone until extraction complete # apply all changes in the source schema to pipeline schema # NOTE: we do not apply contracts to changes done programmatically pipeline_schema.update_schema(source.schema) # replace schema in the source source.schema = pipeline_schema + except FileNotFoundError: + pass # extract into pipeline schema load_id = extract.extract(source, max_parallel_items, workers) # save import with fully discovered schema - self._schema_storage.save_import_schema_if_not_exists(source.schema) + # NOTE: moved to with_schema_sync, remove this if all test pass + # 
self._schema_storage.save_import_schema_if_not_exists(source.schema) # update live schema but not update the store yet self._schema_storage.update_live_schema(source.schema) diff --git a/dlt/reflection/script_inspector.py b/dlt/reflection/script_inspector.py index d8d96804c8..f9068d31e4 100644 --- a/dlt/reflection/script_inspector.py +++ b/dlt/reflection/script_inspector.py @@ -13,7 +13,7 @@ from dlt.pipeline import Pipeline from dlt.extract import DltSource -from dlt.extract.pipe import ManagedPipeIterator +from dlt.extract.pipe_iterator import ManagedPipeIterator def patch__init__(self: Any, *args: Any, **kwargs: Any) -> None: diff --git a/dlt/sources/helpers/transform.py b/dlt/sources/helpers/transform.py index 1975c20586..3949823be7 100644 --- a/dlt/sources/helpers/transform.py +++ b/dlt/sources/helpers/transform.py @@ -1,5 +1,5 @@ from dlt.common.typing import TDataItem -from dlt.extract.typing import ItemTransformFunctionNoMeta +from dlt.extract.items import ItemTransformFunctionNoMeta def take_first(max_items: int) -> ItemTransformFunctionNoMeta[bool]: diff --git a/docs/examples/chess_production/chess.py b/docs/examples/chess_production/chess.py index 2e85805781..e2d0b9c10d 100644 --- a/docs/examples/chess_production/chess.py +++ b/docs/examples/chess_production/chess.py @@ -6,6 +6,7 @@ from dlt.common.typing import StrAny, TDataItems from dlt.sources.helpers.requests import client + @dlt.source def chess( chess_url: str = dlt.config.value, @@ -25,11 +26,8 @@ def players() -> Iterator[TDataItems]: yield from _get_data_with_retry(f"titled/{title}")["players"][:max_players] # this resource takes data from players and returns profiles - # it uses `defer` decorator to enable parallel run in thread pool. - # defer requires return at the end so we convert yield into return (we return one item anyway) - # you can still have yielding transformers, look for the test named `test_evolve_schema` - @dlt.transformer(data_from=players, write_disposition="replace") - @dlt.defer + # it uses `paralellized` flag to enable parallel run in thread pool. 
+ @dlt.transformer(data_from=players, write_disposition="replace", parallelized=True) def players_profiles(username: Any) -> TDataItems: print(f"getting {username} profile via thread {threading.current_thread().name}") sleep(1) # add some latency to show parallel runs @@ -59,6 +57,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: MAX_PLAYERS = 5 + def load_data_with_retry(pipeline, data): try: for attempt in Retrying( @@ -68,9 +67,7 @@ def load_data_with_retry(pipeline, data): reraise=True, ): with attempt: - logger.info( - f"Running the pipeline, attempt={attempt.retry_state.attempt_number}" - ) + logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}") load_info = pipeline.run(data) logger.info(str(load_info)) @@ -92,9 +89,7 @@ def load_data_with_retry(pipeline, data): # print the information on the first load package and all jobs inside logger.info(f"First load package info: {load_info.load_packages[0]}") # print the information on the first completed job in first load package - logger.info( - f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}" - ) + logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}") # check for schema updates: schema_updates = [p.schema_update for p in load_info.load_packages] @@ -152,4 +147,4 @@ def load_data_with_retry(pipeline, data): ) # get data for a few famous players data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS) - load_data_with_retry(pipeline, data) \ No newline at end of file + load_data_with_retry(pipeline, data) diff --git a/docs/examples/connector_x_arrow/load_arrow.py b/docs/examples/connector_x_arrow/load_arrow.py index 06ca4e17b3..b3c654cef9 100644 --- a/docs/examples/connector_x_arrow/load_arrow.py +++ b/docs/examples/connector_x_arrow/load_arrow.py @@ -3,6 +3,7 @@ import dlt from dlt.sources.credentials import ConnectionStringCredentials + def read_sql_x( conn_str: ConnectionStringCredentials = dlt.secrets.value, query: str = dlt.config.value, @@ -14,6 +15,7 @@ def read_sql_x( protocol="binary", ) + def genome_resource(): # create genome resource with merge on `upid` primary key genome = dlt.resource( diff --git a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 8a93df9970..1ba330e4ca 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -9,6 +9,7 @@ ) from dlt.common.typing import DictStrAny, StrAny + def _initialize_sheets( credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials] ) -> Any: @@ -16,6 +17,7 @@ def _initialize_sheets( service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service + @dlt.source def google_spreadsheet( spreadsheet_id: str, @@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: for name in sheet_names ] + if __name__ == "__main__": pipeline = dlt.pipeline(destination="duckdb") # see example.secrets.toml to where to put credentials @@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: sheet_names=range_names, ) ) - print(info) \ No newline at end of file + print(info) diff --git a/docs/examples/incremental_loading/zendesk.py b/docs/examples/incremental_loading/zendesk.py index 4b8597886a..6113f98793 100644 --- a/docs/examples/incremental_loading/zendesk.py +++ b/docs/examples/incremental_loading/zendesk.py @@ -6,12 +6,11 @@ from dlt.common.typing import TAnyDateTime from 
dlt.sources.helpers.requests import client + @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -113,6 +112,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create dlt pipeline pipeline = dlt.pipeline( @@ -120,4 +120,4 @@ def get_pages( ) load_info = pipeline.run(zendesk_support()) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index 3464448de6..7f85f0522e 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -13,6 +13,7 @@ CHUNK_SIZE = 10000 + # You can limit how deep dlt goes when generating child tables. # By default, the library will descend and generate child tables # for all nested lists, without a limit. @@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]: while docs_slice := list(islice(cursor, CHUNK_SIZE)): yield map_nested_in_place(convert_mongo_objs, docs_slice) + def convert_mongo_objs(value: Any) -> Any: if isinstance(value, (ObjectId, Decimal128)): return str(value) diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 8f7833e7d7..e7f57853ed 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -4,6 +4,7 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from PyPDF2 import PdfReader + @dlt.resource(selected=False) def list_files(folder_path: str): folder_path = os.path.abspath(folder_path) @@ -15,6 +16,7 @@ def list_files(folder_path: str): "mtime": os.path.getmtime(file_path), } + @dlt.transformer(primary_key="page_id", write_disposition="merge") def pdf_to_text(file_item, separate_pages: bool = False): if not separate_pages: @@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False): page_item["page_id"] = file_item["file_name"] + "_" + str(page_no) yield page_item + pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate") # this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf" @@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above -print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) \ No newline at end of file +print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) diff --git a/docs/examples/qdrant_zendesk/qdrant.py b/docs/examples/qdrant_zendesk/qdrant.py index 300d8dc6ad..bd0cbafc99 100644 --- a/docs/examples/qdrant_zendesk/qdrant.py +++ b/docs/examples/qdrant_zendesk/qdrant.py @@ -10,13 +10,12 @@ from dlt.common.configuration.inject import with_config + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: 
Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -80,6 +79,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: return None return ensure_pendulum_datetime(value) + # modify dates to return datetime objects instead def _fix_date(ticket): ticket["updated_at"] = _parse_date_or_none(ticket["updated_at"]) @@ -87,6 +87,7 @@ def _fix_date(ticket): ticket["due_at"] = _parse_date_or_none(ticket["due_at"]) return ticket + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk def get_pages( url: str, @@ -127,6 +128,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create a pipeline with an appropriate name pipeline = dlt.pipeline( @@ -146,7 +148,6 @@ def get_pages( print(load_info) - # running the Qdrant client to connect to your Qdrant database @with_config(sections=("destination", "qdrant", "credentials")) diff --git a/docs/examples/transformers/pokemon.py b/docs/examples/transformers/pokemon.py index c17beff6a8..ca32c570ef 100644 --- a/docs/examples/transformers/pokemon.py +++ b/docs/examples/transformers/pokemon.py @@ -1,6 +1,7 @@ import dlt from dlt.sources.helpers import requests + @dlt.source(max_table_nesting=2) def source(pokemon_api_url: str): """""" @@ -28,15 +29,13 @@ def _get_pokemon(_pokemon): # a special case where just one item is retrieved in transformer # a whole transformer may be marked for parallel execution - @dlt.transformer - @dlt.defer + @dlt.transformer(parallelized=True) def species(pokemon_details): """Yields species details for a pokemon""" species_data = requests.get(pokemon_details["species"]["url"]).json() # link back to pokemon so we have a relation in loaded data species_data["pokemon_id"] = pokemon_details["id"] - # just return the results, if you yield, - # generator will be evaluated in main thread + # You can return the result instead of yield since the transformer only generates one result return species_data # create two simple pipelines with | operator @@ -46,6 +45,7 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) + if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -54,4 +54,4 @@ def species(pokemon_details): # the pokemon_list resource does not need to be loaded load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/website/blog/2024-02-28-what-is-pyairbyte.md b/docs/website/blog/2024-02-28-what-is-pyairbyte.md index ffacb1c2d5..02ab1b6de3 100644 --- a/docs/website/blog/2024-02-28-what-is-pyairbyte.md +++ b/docs/website/blog/2024-02-28-what-is-pyairbyte.md @@ -18,10 +18,8 @@ Here at dltHub, we work on the python library for data ingestion. So when I hear PyAirbyte is an interesting Airbyte’s initiative - similar to the one that Meltano had undertook 3 years ago. It provides a convenient way to download and install Airbyte sources and run them locally storing the data in a cache dataset. Users are allowed to then read the data from this cache. - A Python wrapper on the Airbyte source is quite nice and has a feeling close to [Alto](https://github.com/z3z1ma/alto). The whole process of cloning/pip installing the repository, spawning a separate process to run Airbyte connector and read the data via UNIX pipe is hidden behind Pythonic interface. 
- Note that this library is not an Airbyte replacement - the loaders of Airbyte and the library are very different. The library loader uses pandas.to_sql and sql alchemy and is not a replacement for Airbyte destinations that are available in Open Source Airbyte # Questions I had, answered diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md index f3ac6f83d6..df968422d7 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md @@ -103,7 +103,7 @@ import pandas as pd # Create a resource using that yields a dataframe, using the `ordered_at` field as an incremental cursor @dlt.resource(primary_key="order_id") -def orders(ordered_at = dlt.sources.incremental('ordered_at')) +def orders(ordered_at = dlt.sources.incremental('ordered_at')): # Get dataframe/arrow table from somewhere # If your database supports it, you can use the last_value to filter data at the source. # Otherwise it will be filtered automatically after loading the data. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md b/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md index 3051c740e1..aa8fbe10d4 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/salesforce.md @@ -1,13 +1,12 @@ # Salesforce :::info Need help deploying these sources, or figuring out how to run them in your data stack? - -[Join our Slack community](https://dlthub.com/community) -or [book a call](https://calendar.app.google/kiLhuMsWKpZUpfho6) with our support engineer Adrian. +[Join our Slack community](https://dlthub.com/community) or +[book a call](https://calendar.app.google/kiLhuMsWKpZUpfho6) with our support engineer Adrian. ::: -[Salesforce](https://www.salesforce.com) is a cloud platform that streamlines business operations and customer relationship -management, encompassing sales, marketing, and customer service. +[Salesforce](https://www.salesforce.com) is a cloud platform that streamlines business operations +and customer relationship management, encompassing sales, marketing, and customer service. This Salesforce `dlt` verified source and [pipeline example](https://github.com/dlt-hub/verified-sources/blob/master/sources/salesforce_pipeline.py) @@ -37,8 +36,8 @@ The resources that this verified source supports are: ### Grab credentials -To set up your pipeline, you'll need your Salesforce `user_name`, `password`, and `security_token`. Use -your login credentials for user_name and password. +To set up your pipeline, you'll need your Salesforce `user_name`, `password`, and `security_token`. +Use your login credentials for user_name and password. To obtain the `security_token`, follow these steps: @@ -55,9 +54,8 @@ To obtain the `security_token`, follow these steps: 1. Check your email for the token sent by Salesforce. -> Note: The Salesforce UI, which is described here, might change. -The full guide is available at [this link.](https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/quickstart_oauth.htm) - +> Note: The Salesforce UI, which is described here, might change. The full guide is available at +> [this link.](https://developer.salesforce.com/docs/atlas.en-us.api_rest.meta/api_rest/quickstart_oauth.htm) ### Initialize the verified source @@ -80,7 +78,8 @@ To get started with your data pipeline, follow these steps: 1. 
After running this command, a new directory will be created with the necessary files and configuration settings to get started. -For more information, read the guide on [how to add a verified source.](../../walkthroughs/add-a-verified-source) +For more information, read the guide on +[how to add a verified source.](../../walkthroughs/add-a-verified-source) ### Add credentials @@ -169,8 +168,8 @@ destination. | user_role() | contact() | lead() | campaign() | product_2() | pricebook_2() | pricebook_entry() | |-------------|-----------|--------|------------|-------------|---------------|-------------------| -The described functions fetch records from endpoints based on their names, e.g. user_role() -accesses the "user_role" endpoint. +The described functions fetch records from endpoints based on their names, e.g. user_role() accesses +the "user_role" endpoint. ### Resource `opportunity` (incremental loading): @@ -182,7 +181,7 @@ mode. def opportunity( last_timestamp: Incremental[str] = dlt.sources.incremental( "SystemModstamp", initial_value=None - ) + ) ) -> Iterator[Dict[str, Any]]: yield from get_records( @@ -190,9 +189,10 @@ def opportunity( ) ``` -`last_timestamp`: Argument that will receive [incremental](../../general-usage/incremental-loading) state, initialized with "initial_value". -It is configured to track "SystemModstamp" field in data item returned by "get_records" and then yielded. -It will store the newest "SystemModstamp" value in dlt state and make it available in "last_timestamp.last_value" on next pipeline run. +`last_timestamp`: Argument that will receive [incremental](../../general-usage/incremental-loading) +state, initialized with "initial_value". It is configured to track "SystemModstamp" field in data +item returned by "get_records" and then yielded. It will store the newest "SystemModstamp" value in +dlt state and make it available in "last_timestamp.last_value" on next pipeline run. Besides "opportunity", there are several resources that use replace mode for data writing to the destination. @@ -211,8 +211,7 @@ If you wish to create your own pipelines, you can leverage source and resource m above. To create your data pipeline using single loading and -[incremental data loading](../../general-usage/incremental-loading), follow these -steps: +[incremental data loading](../../general-usage/incremental-loading), follow these steps: 1. Configure the pipeline by specifying the pipeline name, destination, and dataset as follows: @@ -220,7 +219,7 @@ steps: pipeline = dlt.pipeline( pipeline_name="salesforce_pipeline", # Use a custom name if desired destination="duckdb", # Choose the appropriate destination (e.g., duckdb, redshift, post) - dataset_name="salesforce_data" # Use a custom name if desired + dataset_name="salesforce_data", # Use a custom name if desired ) ``` @@ -231,7 +230,7 @@ steps: ```python load_data = salesforce_source() - source.schema.merge_hints({"not_null": ["id"]}) #Hint for id field not null + source.schema.merge_hints({"not_null": ["id"]}) # Hint for id field not null load_info = pipeline.run(load_data) # print the information on data that was loaded print(load_info) @@ -254,15 +253,12 @@ steps: endpoints in merge mode with the “dlt.sources.incremental” parameter. > For incremental loading of endpoints, maintain the pipeline name and destination dataset name. - > The pipeline name is important for accessing the - > [state](../../general-usage/state) from the last run, including the end date - > for incremental data loads. 
Altering these names could trigger a - > [“full_refresh”](../../general-usage/pipeline#do-experiments-with-full-refresh), - > disrupting the metadata tracking for - > [incremental data loading](../../general-usage/incremental-loading). + > The pipeline name is important for accessing the [state](../../general-usage/state) from the + > last run, including the end date for incremental data loads. Altering these names could trigger + > a [“full_refresh”](../../general-usage/pipeline#do-experiments-with-full-refresh), disrupting + > the metadata tracking for [incremental data loading](../../general-usage/incremental-loading). -1. To load data from the “contact” in replace mode and “task” incrementally merge mode - endpoints: +1. To load data from the “contact” in replace mode and “task” incrementally merge mode endpoints: ```python load_info = pipeline.run(load_data.with_resources("contact", "task")) @@ -274,6 +270,19 @@ steps: > overwriting existing data. Conversely, the "task" endpoint supports "merge" mode for > incremental loads, updating or adding data based on the 'last_timestamp' value without erasing > previously loaded data. + +1. Salesforce enforces specific limits on API data requests. These limits + vary based on the Salesforce edition and license type, as outlined in the [Salesforce API Request Limits documentation](https://developer.salesforce.com/docs/atlas.en-us.salesforce_app_limits_cheatsheet.meta/salesforce_app_limits_cheatsheet/salesforce_app_limits_platform_api.htm). + + To limit the number of Salesforce API data requests, developers can control the environment for production or + development purposes. For development, you can set the `IS_PRODUCTION` variable + to `False` in "[salesforce/settings.py](https://github.com/dlt-hub/verified-sources/blob/master/sources/salesforce/settings.py)", + which limits API call requests to 100. To modify this limit, you can update the query limit in + "[salesforce/helpers.py](https://github.com/dlt-hub/verified-sources/blob/756edaa00f56234cd06699178098f44c16d6d597/sources/salesforce/helpers.py#L56)" + as required. + + >To read more about Salesforce query limits, please refer to their official + >[documentation here](https://developer.salesforce.com/docs/atlas.en-us.soql_sosl.meta/soql_sosl/sforce_api_calls_soql_select_limit.htm). - \ No newline at end of file + diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md index 3f0532e9d2..67965863ce 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/sql_database.md @@ -215,6 +215,10 @@ def sql_database( schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, table_names: Optional[List[str]] = dlt.config.value, + chunk_size: int = 1000, + detect_precision_hints: Optional[bool] = dlt.config.value, + defer_table_reflect: Optional[bool] = dlt.config.value, + table_adapter_callback: Callable[[Table], None] = None, ) -> Iterable[DltResource]: ``` @@ -226,6 +230,16 @@ def sql_database( `table_names`: List of tables to load; defaults to all if not provided. +`chunk_size`: Number of records in a batch. Internally SqlAlchemy maintains a buffer twice that size + +`detect_precision_hints`: Infers full schema for columns including data type, precision and scale + +`defer_table_reflect`: Will connect to the source database and reflect the tables +only at runtime. 
Use when running on Airflow + +`table_adapter_callback`: A callback with SQLAlchemy `Table` where you can, for example, +remove certain columns to be selected. + ### Resource `sql_table` This function loads data from specific database tables. @@ -240,21 +254,18 @@ def sql_table( schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, incremental: Optional[dlt.sources.incremental[Any]] = None, + chunk_size: int = 1000, + detect_precision_hints: Optional[bool] = dlt.config.value, + defer_table_reflect: Optional[bool] = dlt.config.value, + table_adapter_callback: Callable[[Table], None] = None, ) -> DltResource: ``` - -`credentials`: Database info or an Engine instance. - -`table`: Table to load, set in code or default from "config.toml". - -`schema`: Optional name of the table schema. - -`metadata`: Optional SQLAlchemy.MetaData; takes precedence over schema. - `incremental`: Optional, enables incremental loading. `write_disposition`: Can be "merge", "replace", or "append". +for other arguments, see `sql_database` source above. + ## Incremental Loading Efficient data management often requires loading only new or updated data from your SQL databases, rather than reprocessing the entire dataset. This is where incremental loading comes into play. @@ -264,15 +275,13 @@ Incremental loading uses a cursor column (e.g., timestamp or auto-incrementing I ### Configuring Incremental Loading 1. **Choose a Cursor Column**: Identify a column in your SQL table that can serve as a reliable indicator of new or updated rows. Common choices include timestamp columns or auto-incrementing IDs. 1. **Set an Initial Value**: Choose a starting value for the cursor to begin loading data. This could be a specific timestamp or ID from which you wish to start loading data. -1. **Apply Incremental Configuration**: Enable incremental loading with your configuration's `incremental` argument. 1. **Deduplication**: When using incremental loading, the system automatically handles the deduplication of rows based on the primary key (if available) or row hash for tables without a primary key. - -:::note -Incorporating incremental loading into your SQL data pipelines can significantly enhance performance by minimizing unnecessary data processing and transfer. -::: +1. **Set end_value for backfill**: Set `end_value` if you want to backfill data from +certain range. +1. **Order returned rows**. Set `row_order` to `asc` or `desc` to order returned rows. #### Incremental Loading Example -1. Consider a table with a `last_modified` timestamp column. By setting this column as your cursor and specifying an +1. Consider a table with a `last_modified` timestamp column. By setting this column as your cursor and specifying an initial value, the loader generates a SQL query filtering rows with `last_modified` values greater than the specified initial value. ```python @@ -288,8 +297,8 @@ Incorporating incremental loading into your SQL data pipelines can significantly ) ) - info = pipeline.extract(table, write_disposition="merge") - print(info) + info = pipeline.extract(table, write_disposition="merge") + print(info) ``` 1. To incrementally load the "family" table using the sql_database source method: @@ -325,6 +334,20 @@ Incorporating incremental loading into your SQL data pipelines can significantly * `apply_hints` is a powerful method that enables schema modifications after resource creation, like adjusting write disposition and primary keys. 
You can choose from various tables and use `apply_hints` multiple times to create pipelines with merged, appendend, or replaced resources. ::: +### Run on Airflow +When running on Airflow +1. Use `dlt` [Airflow Helper](../../walkthroughs/deploy-a-pipeline/deploy-with-airflow-composer.md#2-modify-dag-file) to create tasks from `sql_database` source. You should be able to run table extraction in parallel with `parallel-isolated` source->DAG conversion. +2. Reflect tables at runtime with `defer_table_reflect` argument. +3. Set `allow_external_schedulers` to load data using [Airflow intervals](../../general-usage/incremental-loading.md#using-airflow-schedule-for-backfill-and-incremental-loading). + +### Parallel extraction +You can extract each table in a separate thread (no multiprocessing at this point). This will decrease loading time if your queries take time to execute or your network latency/speed is low. +```python +database = sql_database().parallelize() +table = sql_table().parallelize() +``` + + ### Troubleshooting If you encounter issues where the expected WHERE clause for incremental loading is not generated, ensure your configuration aligns with the `sql_table` resource rather than applying hints post-resource creation. This ensures the loader generates the correct query for incremental loading. diff --git a/docs/website/docs/examples/chess_production/code/chess-snippets.py b/docs/website/docs/examples/chess_production/code/chess-snippets.py index b22dc3693b..39ddf14836 100644 --- a/docs/website/docs/examples/chess_production/code/chess-snippets.py +++ b/docs/website/docs/examples/chess_production/code/chess-snippets.py @@ -32,11 +32,8 @@ def players() -> Iterator[TDataItems]: yield from _get_data_with_retry(f"titled/{title}")["players"][:max_players] # this resource takes data from players and returns profiles - # it uses `defer` decorator to enable parallel run in thread pool. - # defer requires return at the end so we convert yield into return (we return one item anyway) - # you can still have yielding transformers, look for the test named `test_evolve_schema` - @dlt.transformer(data_from=players, write_disposition="replace") - @dlt.defer + # it uses `paralellized` flag to enable parallel run in thread pool. + @dlt.transformer(data_from=players, write_disposition="replace", parallelized=True) def players_profiles(username: Any) -> TDataItems: print(f"getting {username} profile via thread {threading.current_thread().name}") sleep(1) # add some latency to show parallel runs diff --git a/docs/website/docs/examples/chess_production/index.md b/docs/website/docs/examples/chess_production/index.md index f372d26d80..d80558e745 100644 --- a/docs/website/docs/examples/chess_production/index.md +++ b/docs/website/docs/examples/chess_production/index.md @@ -55,11 +55,8 @@ def chess( yield from _get_data_with_retry(f"titled/{title}")["players"][:max_players] # this resource takes data from players and returns profiles - # it uses `defer` decorator to enable parallel run in thread pool. - # defer requires return at the end so we convert yield into return (we return one item anyway) - # you can still have yielding transformers, look for the test named `test_evolve_schema` - @dlt.transformer(data_from=players, write_disposition="replace") - @dlt.defer + # it uses `paralellized` flag to enable parallel run in thread pool. 
+ @dlt.transformer(data_from=players, write_disposition="replace", parallelized=True) def players_profiles(username: Any) -> TDataItems: print(f"getting {username} profile via thread {threading.current_thread().name}") sleep(1) # add some latency to show parallel runs @@ -180,6 +177,15 @@ def load_data_with_retry(pipeline, data): ``` +:::warning +To run this example you need to provide Slack incoming hook in `.dlt/secrets.toml`: +```python +[runtime] +slack_incoming_hook="https://hooks.slack.com/services/***" +``` +Read [Using Slack to send messages.](https://dlthub.com/docs/running-in-production/running#using-slack-to-send-messages) +::: + ### Run the pipeline @@ -195,4 +201,4 @@ if __name__ == "__main__": data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS) load_data_with_retry(pipeline, data) ``` - \ No newline at end of file + diff --git a/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py b/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py index 5ec3015741..ff12a00fca 100644 --- a/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py +++ b/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py @@ -140,4 +140,4 @@ def get_pages( # check that stuff was loaded row_counts = pipeline.last_trace.last_normalize_info.row_counts - assert row_counts["ticket_events"] >= 17 + assert row_counts["ticket_events"] == 17 \ No newline at end of file diff --git a/docs/website/docs/examples/pdf_to_weaviate/code/pdf_to_weaviate-snippets.py b/docs/website/docs/examples/pdf_to_weaviate/code/pdf_to_weaviate-snippets.py index 1ad7cc8159..ae61af3746 100644 --- a/docs/website/docs/examples/pdf_to_weaviate/code/pdf_to_weaviate-snippets.py +++ b/docs/website/docs/examples/pdf_to_weaviate/code/pdf_to_weaviate-snippets.py @@ -1,6 +1,8 @@ from tests.pipeline.utils import assert_load_info +from tests.utils import skipifgithubfork +@skipifgithubfork def pdf_to_weaviate_snippet() -> None: # @@@DLT_SNIPPET_START example # @@@DLT_SNIPPET_START pdf_to_weaviate diff --git a/docs/website/docs/examples/transformers/code/pokemon-snippets.py b/docs/website/docs/examples/transformers/code/pokemon-snippets.py index d8fe4f41ba..ff8757b94e 100644 --- a/docs/website/docs/examples/transformers/code/pokemon-snippets.py +++ b/docs/website/docs/examples/transformers/code/pokemon-snippets.py @@ -30,15 +30,13 @@ def _get_pokemon(_pokemon): # a special case where just one item is retrieved in transformer # a whole transformer may be marked for parallel execution - @dlt.transformer - @dlt.defer + @dlt.transformer(parallelized=True) def species(pokemon_details): """Yields species details for a pokemon""" species_data = requests.get(pokemon_details["species"]["url"]).json() # link back to pokemon so we have a relation in loaded data species_data["pokemon_id"] = pokemon_details["id"] - # just return the results, if you yield, - # generator will be evaluated in main thread + # You can return the result instead of yield since the transformer only generates one result return species_data # create two simple pipelines with | operator @@ -48,7 +46,6 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) - __name__ = "__main__" # @@@DLT_REMOVE if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -60,6 +57,14 @@ def species(pokemon_details): print(load_info) # @@@DLT_SNIPPET_END example + # Run without __main__ + pipeline = dlt.pipeline( + pipeline_name="pokemon", 
destination="duckdb", dataset_name="pokemon_data" + ) + + # the pokemon_list resource does not need to be loaded + load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) + # test assertions row_counts = pipeline.last_trace.last_normalize_info.row_counts assert row_counts["pokemon"] == 20 diff --git a/docs/website/docs/examples/transformers/index.md b/docs/website/docs/examples/transformers/index.md index 860e830aae..2f5c3dd532 100644 --- a/docs/website/docs/examples/transformers/index.md +++ b/docs/website/docs/examples/transformers/index.md @@ -60,15 +60,13 @@ def source(pokemon_api_url: str): # a special case where just one item is retrieved in transformer # a whole transformer may be marked for parallel execution - @dlt.transformer - @dlt.defer + @dlt.transformer(parallelized=True) def species(pokemon_details): """Yields species details for a pokemon""" species_data = requests.get(pokemon_details["species"]["url"]).json() # link back to pokemon so we have a relation in loaded data species_data["pokemon_id"] = pokemon_details["id"] - # just return the results, if you yield, - # generator will be evaluated in main thread + # You can return the result instead of yield since the transformer only generates one result return species_data # create two simple pipelines with | operator diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 7e4021214e..dd52c9c750 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -192,7 +192,7 @@ def resource(): yield [ {"id": 2, "val": "foo", "lsn": 1, "deleted_flag": False}, {"id": 2, "lsn": 2, "deleted_flag": True} - ] + ] ... ``` @@ -267,7 +267,7 @@ In essence, `dlt.sources.incremental` instance above * **updated_at.initial_value** which is always equal to "1970-01-01T00:00:00Z" passed in constructor * **updated_at.start_value** a maximum `updated_at` value from the previous run or the **initial_value** on first run * **updated_at.last_value** a "real time" `updated_at` value updated with each yielded item or page. before first yield it equals **start_value** -* **updated_at.end_value** (here not used) [marking end of backfill range](#using-dltsourcesincremental-for-backfill) +* **updated_at.end_value** (here not used) [marking end of backfill range](#using-end_value-for-backfill) When paginating you probably need **start_value** which does not change during the execution of the resource, however most paginators will return a **next page** link which you should use. @@ -284,31 +284,21 @@ duplicates and past issues. # use naming function in table name to generate separate tables for each event @dlt.resource(primary_key="id", table_name=lambda i: i['type']) # type: ignore def repo_events( - last_created_at = dlt.sources.incremental("created_at", initial_value="1970-01-01T00:00:00Z", last_value_func=max) + last_created_at = dlt.sources.incremental("created_at", initial_value="1970-01-01T00:00:00Z", last_value_func=max), row_order="desc" ) -> Iterator[TDataItems]: repos_path = "/repos/%s/%s/events" % (urllib.parse.quote(owner), urllib.parse.quote(name)) for page in _get_rest_pages(access_token, repos_path + "?per_page=100"): yield page - - # ---> part below is an optional optimization - # Stop requesting more pages when we encounter an element that - # is older than the incremental value at the beginning of the run. 
- # The start_out_of_range boolean flag is set in this case - if last_created_at.start_out_of_range: - break ``` We just yield all the events and `dlt` does the filtering (using `id` column declared as `primary_key`). -As an optimization we stop requesting more pages once the incremental value is out of range, -in this case that means we got an element which has a smaller `created_at` than the the `last_created_at.start_value`. -The `start_out_of_range` boolean flag is set when the first such element is yielded from the resource, and -since we know that github returns results ordered from newest to oldest, we know that all subsequent -items will be filtered out anyway and there's no need to fetch more data. + +Github returns events ordered from newest to oldest so we declare the `rows_order` as **descending** to [stop requesting more pages once the incremental value is out of range](#declare-row-order-to-not-request-unnecessary-data). We stop requesting more data from the API after finding first event with `created_at` earlier than `initial_value`. ### max, min or custom `last_value_func` -`dlt.sources.incremental` allows to choose a function that orders (compares) values coming from the items to current `last_value`. +`dlt.sources.incremental` allows to choose a function that orders (compares) cursor values to current `last_value`. * The default function is built-in `max` which returns bigger value of the two * Another built-in `min` returns smaller value. @@ -341,9 +331,134 @@ def get_events(last_created_at = dlt.sources.incremental("$", last_value_func=by yield json.load(f) ``` +### Using `end_value` for backfill +You can specify both initial and end dates when defining incremental loading. Let's go back to our Github example: +```python +@dlt.resource(primary_key="id") +def repo_issues( + access_token, + repository, + created_at = dlt.sources.incremental("created_at", initial_value="1970-01-01T00:00:00Z", end_value="2022-07-01T00:00:00Z") +): + # get issues from created from last "created_at" value + for page in _get_issues_page(access_token, repository, since=created_at.start_value, until=created_at.end_value): + yield page +``` +Above we use `initial_value` and `end_value` arguments of the `incremental` to define the range of issues that we want to retrieve +and pass this range to the Github API (`since` and `until`). As in the examples above, `dlt` will make sure that only the issues from +defined range are returned. + +Please note that when `end_date` is specified, `dlt` **will not modify the existing incremental state**. The backfill is **stateless** and: +1. You can run backfill and incremental load in parallel (ie. in Airflow DAG) in a single pipeline. +2. You can partition your backfill into several smaller chunks and run them in parallel as well. + +To define specific ranges to load, you can simply override the incremental argument in the resource, for example: + +```python +july_issues = repo_issues( + created_at=dlt.sources.incremental( + initial_value='2022-07-01T00:00:00Z', end_value='2022-08-01T00:00:00Z' + ) +) +august_issues = repo_issues( + created_at=dlt.sources.incremental( + initial_value='2022-08-01T00:00:00Z', end_value='2022-09-01T00:00:00Z' + ) +) +... +``` + +Note that `dlt`'s incremental filtering considers the ranges half closed. `initial_value` is inclusive, `end_value` is exclusive, so chaining ranges like above works without overlaps. 
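For illustration only (not part of this patch): a minimal sketch of chaining such half-closed backfill ranges programmatically. It reuses the `repo_issues` resource defined above; the token and repository values are placeholders.

```python
import dlt

# contiguous monthly partitions: each end_value equals the next initial_value,
# so the half-closed ranges [initial_value, end_value) never overlap
months = [
    ("2022-07-01T00:00:00Z", "2022-08-01T00:00:00Z"),
    ("2022-08-01T00:00:00Z", "2022-09-01T00:00:00Z"),
    ("2022-09-01T00:00:00Z", "2022-10-01T00:00:00Z"),
]

pipeline = dlt.pipeline(pipeline_name="github_backfill", destination="duckdb")

for initial_value, end_value in months:
    # the backfill is stateless: the incremental state is not modified,
    # so these chunks could also run as separate parallel tasks (e.g. in Airflow)
    chunk = repo_issues(
        access_token="ghp_...",    # placeholder
        repository="dlt-hub/dlt",  # placeholder
        created_at=dlt.sources.incremental(
            "created_at", initial_value=initial_value, end_value=end_value
        ),
    )
    pipeline.run(chunk)
```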
+ + +### Declare row order to not request unnecessary data +With the `row_order` argument set, `dlt` will stop getting data from the data source (ie. GitHub API) if it detects that the values of the cursor field are out of range of the **start** and **end** values. + +In particular: +* `dlt` stops processing when the resource yields any item with an _equal or greater_ cursor value than the `end_value` and `row_order` is set to **asc**. (`end_value` is not included) +* `dlt` stops processing when the resource yields any item with a _lower_ cursor value than the `last_value` and `row_order` is set to **desc**. (`last_value` is included) + +:::note +"higher" and "lower" here refer to when the default `last_value_func` is used (`max()`), +when using `min()`, "higher" and "lower" are inverted. +::: + +:::caution +If you use `row_order`, **make sure that the data source returns ordered records** (ascending / descending) on the cursor field, +e.g. if an API returns results both higher and lower +than the given `end_value` in no particular order, data reading stops and you'll miss the data items that were out of order. +::: + +Row order is most useful when: + +1. The data source does **not** offer start/end filtering of results (e.g. there is no `start_time/end_time` query parameter or similar) +2. The source returns results **ordered by the cursor field** + +The GitHub events example is exactly such a case. The results are ordered by cursor value descending, but there's no way to tell the API to limit returned items to those created before a certain date. Without the `row_order` setting, we'd be getting all events each time we extract the `github_events` resource. + +In the same fashion, `row_order` can be used to **optimize backfill** so we don't continue +making unnecessary API requests after the end of the range is reached. For example: + +```python +@dlt.resource(primary_key="id") +def tickets( +    zendesk_client, +    updated_at=dlt.sources.incremental( +        "updated_at", +        initial_value="2023-01-01T00:00:00Z", +        end_value="2023-02-01T00:00:00Z", +        row_order="asc" +    ), +): +    for page in zendesk_client.get_pages( +        "/api/v2/incremental/tickets", "tickets", start_time=updated_at.start_value +    ): +        yield page +``` + +In this example we're loading tickets from Zendesk. The Zendesk API yields items paginated and ordered from oldest to newest, +but only offers a `start_time` parameter for filtering, so we cannot tell it to +stop getting data at `end_value`. Instead we set `row_order` to `asc` and `dlt` will stop +getting more pages from the API after the first page containing an item with `updated_at` equal to +or newer than `end_value`. + +:::caution +In rare cases, when you use Incremental with a transformer, `dlt` will not be able to automatically close +the generator associated with a row that is out of range. You can still call the `can_close()` method on +the incremental instance and exit the yield loop when it returns `True`. +::: + +:::tip +The `dlt.sources.incremental` instance provides `start_out_of_range` and `end_out_of_range` +attributes which are set when the resource yields an element with a higher/lower cursor value than the +initial or end values. 
If you do not want `dlt` to stop processing automatically and want to handle such events yourself, do not specify `row_order`: +```python +@dlt.transformer(primary_key="id") +def tickets( +    zendesk_client, +    updated_at=dlt.sources.incremental( +        "updated_at", +        initial_value="2023-01-01T00:00:00Z", +        end_value="2023-02-01T00:00:00Z", +        row_order="asc" +    ), +): +    for page in zendesk_client.get_pages( +        "/api/v2/incremental/tickets", "tickets", start_time=updated_at.start_value +    ): +        yield page +        # Stop loading when we reach the end value +        if updated_at.end_out_of_range: +            return + +``` +::: + ### Deduplication primary_key -`dlt.sources.incremental` let's you optionally set a `primary_key` that is used exclusively to +`dlt.sources.incremental` will inherit the primary key that is set on the resource. + +It also lets you optionally set a `primary_key` that is used exclusively to deduplicate and which does not become a table hint. The same setting lets you disable the deduplication altogether when empty tuple is passed. Below we pass `primary_key` directly to `incremental` to disable deduplication. That overrides `delta` primary_key set in the resource: @@ -395,45 +510,6 @@ is created. That prevents `dlt` from controlling the **created** argument during result in `IncrementalUnboundError` exception. ::: -### Using `dlt.sources.incremental` for backfill -You can specify both initial and end dates when defining incremental loading. Let's go back to our Github example: -```python -@dlt.resource(primary_key="id") -def repo_issues( -    access_token, -    repository, -    created_at = dlt.sources.incremental("created_at", initial_value="1970-01-01T00:00:00Z", end_value="2022-07-01T00:00:00Z") -): -    # get issues from created from last "created_at" value -    for page in _get_issues_page(access_token, repository, since=created_at.start_value, until=created_at.end_value): -        yield page -``` -Above we use `initial_value` and `end_value` arguments of the `incremental` to define the range of issues that we want to retrieve -and pass this range to the Github API (`since` and `until`). As in the examples above, `dlt` will make sure that only the issues from -defined range are returned. - -Please note that when `end_date` is specified, `dlt` **will not modify the existing incremental state**. The backfill is **stateless** and: -1. You can run backfill and incremental load in parallel (ie. in Airflow DAG) in a single pipeline. -2. You can partition your backfill into several smaller chunks and run them in parallel as well. - -To define specific ranges to load, you can simply override the incremental argument in the resource, for example: - -```python -july_issues = repo_issues( -    created_at=dlt.sources.incremental( -        initial_value='2022-07-01T00:00:00Z', end_value='2022-08-01T00:00:00Z' -    ) -) -august_issues = repo_issues( -    created_at=dlt.sources.incremental( -        initial_value='2022-08-01T00:00:00Z', end_value='2022-09-01T00:00:00Z' -    ) -) -... -``` - -Note that `dlt`'s incremental filtering considers the ranges half closed. `initial_value` is inclusive, `end_value` is exclusive, so chaining ranges like above works without overlaps. - ### Using Airflow schedule for backfill and incremental loading When [running in Airflow task](../walkthroughs/deploy-a-pipeline/deploy-with-airflow-composer.md#2-modify-dag-file), you can opt-in your resource to get the `initial_value`/`start_value` and `end_value` from Airflow schedule associated with your DAG.
Let's assume that **Zendesk tickets** resource contains a year of data with thousands of tickets. We want to backfill the last year of data week by week and then continue incremental loading daily. ```python @@ -527,59 +603,6 @@ Before `dlt` starts executing incremental resources, it looks for `data_interval You can run DAGs manually but you must remember to specify the Airflow logical date of the run in the past (use Run with config option). For such run `dlt` will load all data from that past date until now. If you do not specify the past date, a run with a range (now, now) will happen yielding no data. -### Using `start/end_out_of_range` flags with incremental resources - -The `dlt.sources.incremental` instance provides `start_out_of_range` and `end_out_of_range` -attributes which are set when the resource yields an element with a higher/lower cursor value than the -initial or end values. -This makes it convenient to optimize resources in some cases. - -* `start_out_of_range` is `True` when the resource yields any item with a _lower_ cursor value than the `initial_value` -* `end_out_of_range` is `True` when the resource yields any item with an equal or _higher_ cursor value than the `end_value` - -**Note**: "higher" and "lower" here refers to when the default `last_value_func` is used (`max()`), -when using `min()` "higher" and "lower" are inverted. - -You can use these flags when both: - -1. The source does **not** offer start/end filtering of results (e.g. there is no `start_time/end_time` query parameter or similar) -2. The source returns results **ordered by the cursor field** - -:::caution -If you use those flags, **make sure that the data source returns record ordered** (ascending / descending) on the cursor field, -e.g. if an API returns results both higher and lower -than the given `end_value` in no particular order, the `end_out_of_range` flag can be `True` but you'll still want to keep loading. -::: - -The github events example above demonstrates how to use `start_out_of_range` as a stop condition. -This approach works in any case where the API returns items in descending order and we're incrementally loading newer data. - -In the same fashion the `end_out_of_range` filter can be used to optimize backfill so we don't continue -making unnecessary API requests after the end of range is reached. For example: - -```python -@dlt.resource(primary_key="id") -def tickets( - zendesk_client, - updated_at=dlt.sources.incremental( - "updated_at", - initial_value="2023-01-01T00:00:00Z", - end_value="2023-02-01T00:00:00Z", - ), -): - for page in zendesk_client.get_pages( - "/api/v2/incremental/tickets", "tickets", start_time=updated_at.start_value - ): - yield page - - # Optimization: Stop loading when we reach the end value - if updated_at.end_out_of_range: - return -``` - -In this example we're loading tickets from Zendesk. The Zendesk API yields items paginated and ordered by oldest to newest, -but only offers a `start_time` parameter for filtering. The incremental `end_out_of_range` flag is set on the first item which -has a timestamp equal or higher than `end_value`. All subsequent items get filtered out so there's no need to request more data. 
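A rough sketch of the Airflow opt-in described in this section, reusing the Zendesk example from above. It assumes the `allow_external_schedulers` flag of `dlt.sources.incremental` and a `zendesk_client` helper, which are not shown in this diff:

```python
import dlt

@dlt.resource(primary_key="id")
def tickets(
    zendesk_client,
    updated_at=dlt.sources.incremental(
        "updated_at",
        initial_value="2023-01-01T00:00:00Z",
        # assumption: opt in to taking start/end values from the Airflow
        # data interval instead of the stored incremental state
        allow_external_schedulers=True,
    ),
):
    for page in zendesk_client.get_pages(
        "/api/v2/incremental/tickets", "tickets", start_time=updated_at.start_value
    ):
        yield page
```

When such a resource runs outside of Airflow, it should behave like a regular incremental resource.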
## Doing a full refresh diff --git a/docs/website/docs/general-usage/resource.md b/docs/website/docs/general-usage/resource.md index b7026c454e..9b8d45982d 100644 --- a/docs/website/docs/general-usage/resource.md +++ b/docs/website/docs/general-usage/resource.md @@ -265,6 +265,38 @@ kinesis_stream = kinesis("telemetry_stream") ``` `kinesis_stream` resource has a name **telemetry_stream** + +### Declare parallel and async resources +You can extract multiple resources in parallel threads or with async IO. +To enable this for a sync resource you can set the `parallelized` flag to `True` in the resource decorator: + + +```python +@dlt.resource(parallelized=True) +def get_users(): + for u in _get_users(): + yield u + +@dlt.resource(parallelized=True) +def get_orders(): + for o in _get_orders(): + yield o + +# users and orders will be iterated in parallel in two separate threads +pipeline.run(get_users(), get_orders()) +``` + +Async generators are automatically extracted concurrently with other resources: + +```python +@dlt.resource +async def get_users(): + async for u in _get_users(): # Assuming _get_users is an async generator + yield u +``` + +Please find more details in [extract performance](../reference/performance.md#extract) + ## Customize resources ### Filter, transform and pivot data @@ -330,6 +362,9 @@ assert list(r) == list(range(10)) > 💡 You cannot limit transformers. They should process all the data they receive fully to avoid > inconsistencies in generated datasets. +> 💡 If you are paremetrizing the value of `add_limit` and sometimes need it to be disabled, you can set `None` or `-1` +> to disable the limiting. You can also set the limit to `0` for the resource to not yield any items. + ### Set table name and adjust schema You can change the schema of a resource, be it standalone or as a part of a source. Look for method diff --git a/docs/website/docs/intro.md b/docs/website/docs/intro.md index 04af626566..6df0dad82d 100644 --- a/docs/website/docs/intro.md +++ b/docs/website/docs/intro.md @@ -165,7 +165,7 @@ print(load_info) Install **pymysql** driver: ```sh -pip install pymysql +pip install sqlalchemy pymysql ``` diff --git a/docs/website/docs/reference/performance.md b/docs/website/docs/reference/performance.md index af2d791324..7c095b53d4 100644 --- a/docs/website/docs/reference/performance.md +++ b/docs/website/docs/reference/performance.md @@ -141,32 +141,66 @@ You can create pipelines that extract, normalize and load data in parallel. ### Extract You can extract data concurrently if you write your pipelines to yield callables or awaitables or use async generators for your resources that can be then evaluated in a thread or futures pool respectively. -The example below simulates a typical situation where a dlt resource is used to fetch a page of items and then details of individual items are fetched separately in the transformer. The `@dlt.defer` decorator wraps the `get_details` function in another callable that will be executed in the thread pool. +This is easily accomplished by using the `parallelized` argument in the resource decorator. +Resources based on sync generators will execute each step (yield) of the generator in a thread pool, so each individual resource is still extracted one item at a time but multiple such resources can run in parallel with each other. + +Consider an example source which consists of 2 resources fetching pages of items from different API endpoints, and each of those resources are piped to transformers to fetch complete data items respectively. 
+ +The `parallelized=True` argument wraps the resources in a generator that yields callables to evaluate each generator step. These callables are executed in the thread pool. Transformer that are not generators (as shown in the example) are internally wrapped in a generator that yields once. + ```py import dlt -from time import sleep +import time from threading import currentThread -@dlt.resource -def list_items(start, limit): - yield from range(start, start + limit) - -@dlt.transformer -@dlt.defer -def get_details(item_id): - # simulate a slow REST API where you wait 0.3 sec for each item - sleep(0.3) - print(f"item_id {item_id} in thread {currentThread().name}") - # just return the results, if you yield, generator will be evaluated in main thread - return {"row": item_id} +@dlt.resource(parallelized=True) +def list_users(n_users): + for i in range(1, 1 + n_users): + # Simulate network delay of a rest API call fetching a page of items + if i % 10 == 0: + time.sleep(0.1) + yield i + +@dlt.transformer(parallelized=True) +def get_user_details(user_id): + # Transformer that fetches details for users in a page + time.sleep(0.1) # Simulate latency of a rest API call + print(f"user_id {user_id} in thread {currentThread().name}") + return {"entity": "user", "id": user_id} + +@dlt.resource(parallelized=True) +def list_products(n_products): + for i in range(1, 1 + n_products): + if i % 10 == 0: + time.sleep(0.1) + yield i + +@dlt.transformer(parallelized=True) +def get_product_details(product_id): + time.sleep(0.1) + print(f"product_id {product_id} in thread {currentThread().name}") + return {"entity": "product", "id": product_id} + +@dlt.source +def api_data(): + return [ + list_users(24) | get_user_details, + list_products(32) | get_product_details, + ] # evaluate the pipeline and print all the items -# resources are iterators and they are evaluated in the same way in the pipeline.run -print(list(list_items(0, 10) | get_details)) +# sources are iterators and they are evaluated in the same way in the pipeline.run +print(list(api_data())) ``` +The `parallelized` flag in the `resource` and `transformer` decorators is supported for: + +* Generator functions (as shown in the example) +* Generators without functions (e.g. `dlt.resource(name='some_data', parallelized=True)(iter(range(100)))`) +* `dlt.transformer` decorated functions. These can be either generator functions or regular functions that return one value + You can control the number of workers in the thread pool with **workers** setting. The default number of workers is **5**. Below you see a few ways to do that with different granularity ```toml @@ -185,7 +219,8 @@ workers=4 -The example below does the same but using an async generator as the main resource and async/await and futures pool for the transformer: +The example below does the same but using an async generator as the main resource and async/await and futures pool for the transformer. +The `parallelized` flag is not supported or needed for async generators, these are wrapped and evaluated concurrently by default: ```py import asyncio @@ -234,7 +269,8 @@ of callables to be evaluated in a thread pool with a size of 5. This limit will ::: :::caution -Generators and iterators are always evaluated in the main thread. If you have a loop that yields items, instead yield functions or async functions that will create the items when evaluated in the pool. +Generators and iterators are always evaluated in a single thread: item by item. 
If you have a loop that yields items that you want to evaluate +in parallel, instead yield functions or async functions that will be evaluates in separate threads or in async pool. ::: ### Normalize @@ -394,26 +430,18 @@ import dlt from time import sleep from concurrent.futures import ThreadPoolExecutor -# create both futures and thread parallel resources - -def async_table(): - async def _gen(idx): - await asyncio.sleep(0.1) - return {"async_gen": idx} - - # just yield futures in a loop +# create both asyncio and thread parallel resources +@dlt.resource +async def async_table(): for idx_ in range(10): - yield _gen(idx_) + await asyncio.sleep(0.1) + yield {"async_gen": idx_} +@dlt.resource(parallelized=True) def defer_table(): - @dlt.defer - def _gen(idx): - sleep(0.1) - return {"thread_gen": idx} - - # just yield futures in a loop for idx_ in range(5): - yield _gen(idx_) + sleep(0.1) + yield idx_ def _run_pipeline(pipeline, gen_): # run the pipeline in a thread, also instantiate generators here! @@ -440,9 +468,9 @@ async def _run_async(): asyncio.run(_run_async()) # activate pipelines before they are used pipeline_1.activate() -# assert load_data_table_counts(pipeline_1) == {"async_table": 10} +assert pipeline_1.last_trace.last_normalize_info.row_counts["async_table"] == 10 pipeline_2.activate() -# assert load_data_table_counts(pipeline_2) == {"defer_table": 5} +assert pipeline_2.last_trace.last_normalize_info.row_counts["defer_table"] == 5 ``` diff --git a/docs/website/docs/reference/performance_snippets/performance-snippets.py b/docs/website/docs/reference/performance_snippets/performance-snippets.py index a2ebd102a6..68ec8ed72d 100644 --- a/docs/website/docs/reference/performance_snippets/performance-snippets.py +++ b/docs/website/docs/reference/performance_snippets/performance-snippets.py @@ -44,25 +44,47 @@ def read_table(limit): def parallel_extract_callables_snippet() -> None: # @@@DLT_SNIPPET_START parallel_extract_callables import dlt - from time import sleep + import time from threading import currentThread - @dlt.resource - def list_items(start, limit): - yield from range(start, start + limit) - - @dlt.transformer - @dlt.defer - def get_details(item_id): - # simulate a slow REST API where you wait 0.3 sec for each item - sleep(0.3) - print(f"item_id {item_id} in thread {currentThread().name}") - # just return the results, if you yield, generator will be evaluated in main thread - return {"row": item_id} + @dlt.resource(parallelized=True) + def list_users(n_users): + for i in range(1, 1 + n_users): + # Simulate network delay of a rest API call fetching a page of items + if i % 10 == 0: + time.sleep(0.1) + yield i + + @dlt.transformer(parallelized=True) + def get_user_details(user_id): + # Transformer that fetches details for users in a page + time.sleep(0.1) # Simulate latency of a rest API call + print(f"user_id {user_id} in thread {currentThread().name}") + return {"entity": "user", "id": user_id} + + @dlt.resource(parallelized=True) + def list_products(n_products): + for i in range(1, 1 + n_products): + if i % 10 == 0: + time.sleep(0.1) + yield i + + @dlt.transformer(parallelized=True) + def get_product_details(product_id): + time.sleep(0.1) + print(f"product_id {product_id} in thread {currentThread().name}") + return {"entity": "product", "id": product_id} + + @dlt.source + def api_data(): + return [ + list_users(24) | get_user_details, + list_products(32) | get_product_details, + ] # evaluate the pipeline and print all the items - # resources are iterators and they are 
evaluated in the same way in the pipeline.run - print(list(list_items(0, 10) | get_details)) + # sources are iterators and they are evaluated in the same way in the pipeline.run + print(list(api_data())) # @@@DLT_SNIPPET_END parallel_extract_callables # @@@DLT_SNIPPET_START parallel_extract_awaitables @@ -127,26 +149,18 @@ def parallel_pipelines_asyncio_snippet() -> None: from time import sleep from concurrent.futures import ThreadPoolExecutor - # create both futures and thread parallel resources - - def async_table(): - async def _gen(idx): - await asyncio.sleep(0.1) - return {"async_gen": idx} - - # just yield futures in a loop + # create both asyncio and thread parallel resources + @dlt.resource + async def async_table(): for idx_ in range(10): - yield _gen(idx_) + await asyncio.sleep(0.1) + yield {"async_gen": idx_} + @dlt.resource(parallelized=True) def defer_table(): - @dlt.defer - def _gen(idx): - sleep(0.1) - return {"thread_gen": idx} - - # just yield futures in a loop for idx_ in range(5): - yield _gen(idx_) + sleep(0.1) + yield idx_ def _run_pipeline(pipeline, gen_): # run the pipeline in a thread, also instantiate generators here! @@ -173,9 +187,9 @@ async def _run_async(): asyncio.run(_run_async()) # activate pipelines before they are used pipeline_1.activate() - # assert load_data_table_counts(pipeline_1) == {"async_table": 10} + assert pipeline_1.last_trace.last_normalize_info.row_counts["async_table"] == 10 pipeline_2.activate() - # assert load_data_table_counts(pipeline_2) == {"defer_table": 5} + assert pipeline_2.last_trace.last_normalize_info.row_counts["defer_table"] == 5 # @@@DLT_SNIPPET_END parallel_pipelines diff --git a/docs/website/docs/running-in-production/alerting.md b/docs/website/docs/running-in-production/alerting.md index 65a9d05eae..1364c1f988 100644 --- a/docs/website/docs/running-in-production/alerting.md +++ b/docs/website/docs/running-in-production/alerting.md @@ -40,5 +40,33 @@ receiving rich information on executed pipelines, including encountered errors a ## Slack -Read [here](./running#using-slack-to-send-messages) about how to send -messages to Slack. +Alerts can be sent to a Slack channel via Slack's incoming webhook URL. The code snippet below demonstrates automated Slack notifications for database table updates using the `send_slack_message` function. + +```python +# Import the send_slack_message function from the dlt library +from dlt.common.runtime.slack import send_slack_message + +# Define the URL for your Slack webhook +hook = "https://hooks.slack.com/services/xxx/xxx/xxx" + +# Iterate over each package in the load_info object +for package in info.load_packages: + # Iterate over each table in the schema_update of the current package + for table_name, table in package.schema_update.items(): + # Iterate over each column in the current table + for column_name, column in table["columns"].items(): + # Send a message to the Slack channel with the table + # and column update information + send_slack_message( + hook, + message=( + f"\tTable updated: {table_name}: " + f"Column changed: {column_name}: " + f"{column['data_type']}" + ) + ) +``` +Refer to this [example](../examples/chess_production/) for a practical application of the method in a production environment. + +Similarly, Slack notifications can be extended to include information on pipeline execution times, loading durations, schema modifications, and more. For comprehensive details on configuring and sending messages to Slack, please read [here](./running#using-slack-to-send-messages). 
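Building on the same `send_slack_message` helper, a minimal sketch of alerting on a failed run could look like the following; the `my_source` source and the webhook URL are placeholders:

```python
import dlt
from dlt.common.runtime.slack import send_slack_message

# placeholder for your Slack incoming webhook URL
hook = "https://hooks.slack.com/services/xxx/xxx/xxx"

pipeline = dlt.pipeline(
    pipeline_name="alerting_example", destination="duckdb", dataset_name="alerting_data"
)

try:
    # my_source is a placeholder for your own dlt source
    info = pipeline.run(my_source())
    send_slack_message(
        hook,
        message=f"Pipeline {pipeline.pipeline_name} loaded {len(info.loads_ids)} package(s)",
    )
except Exception as ex:
    # any failed extract, normalize or load step ends up here
    send_slack_message(hook, message=f"Pipeline {pipeline.pipeline_name} failed: {ex}")
    raise
```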
+ diff --git a/docs/website/docs/walkthroughs/create-new-destination.md b/docs/website/docs/walkthroughs/create-new-destination.md index 728abb2506..3e64cc55ab 100644 --- a/docs/website/docs/walkthroughs/create-new-destination.md +++ b/docs/website/docs/walkthroughs/create-new-destination.md @@ -1,13 +1,20 @@ # Create new destination -:::caution -This guide is compatible with `dlt` **0.3.x**. Version **0.4.x** has a different module layout. We are working on an update. -::: `dlt` can import destinations from external python modules. Below we show how to quickly add a [dbapi](https://peps.python.org/pep-0249/) based destination. `dbapi` is a standardized interface to access databases in Python. If you used ie. postgres (ie. `psycopg2`) you are already familiar with it. -> 🧪 This guide is not comprehensive. The internal interfaces are still evolving. Besides reading info below, you should check out [source code of existing destinations](https://github.com/dlt-hub/dlt/tree/devel/dlt/destinations) +> 🧪 This guide is not comprehensive. The internal interfaces are still evolving. Besides reading info below, you should check out [source code of existing destinations](https://github.com/dlt-hub/dlt/tree/devel/dlt/destinations/impl) + +## 0. Prerequisites + +Destinations are implemented in python packages under: `dlt.destinations.impl.`. Generally a destination consists of the following modules: + +* `__init__.py` - this module contains the destination capabilities +* `.py` - this module contains the job client and load job implementations for the destination +* `configuration.py` - this module contains the destination and credentials configuration classes +* `sql_client.py` - this module contains the SQL client implementation for the destination, this is a wrapper over `dbapi` that provides consistent interface to `dlt` for executing queries +* `factory.py` - this module contains a `Destination` subclass that is the entry point for the destination. ## 1. Copy existing destination to your `dlt` project Initialize a new project with [dlt init](../reference/command-line-interface.md#dlt-init) @@ -16,7 +23,7 @@ dlt init github postgres ``` This adds `github` verified source (it produces quite complicated datasets and that good for testing, does not require credentials to use) and `postgres` credentials (connection-string-like) that we'll repurpose later. -Clone [dlt](https://github.com/dlt-hub/dlt) repository to a separate folder. In the repository look for **dlt/destinations** folder and copy one of the destinations to your project. Pick your starting point: +Clone [dlt](https://github.com/dlt-hub/dlt) repository to a separate folder. In the repository look for **dlt/destinations/impl** folder and copy one of the destinations to your project. Pick your starting point: * **postgres** - a simple destination without staging storage support and COPY jobs * **redshift** - based on postgres, adds staging storage support and remote COPY jobs * **snowflake** - a destination supporting additional authentication schemes, local and remote COPY jobs and no support for direct INSERTs @@ -30,11 +37,10 @@ Below we'll use **postgres** as starting point. We keep config and credentials in `configuration.py`. You should: - rename the classes properly to match your destination name - if you need more properties (ie. look at `iam_role` in `redshift` credentials) then add them, remember about typing. Behind the hood credentials and configs are **dataclasses**. 
-- import and use new configuration class in `_configure()` method in `__init__.py` -- tell `dlt` the default configuration section by placing your destination name in `sections` argument of `@with_config` decorator. -- expose the configuration type in `spec` method in `__init__.py` +- adjust `__init__` arguments in your `Destination` class in `factory.py` to match the new credentials and config classes +- expose the configuration type in `spec` attribute in `factory.py` -> 💡 Each destination module implements `DestinationReference` protocol defined in [reference.py](https://github.com/dlt-hub/dlt/blob/devel/dlt/common/destination/reference.py). +> 💡 Each destination implements `Destination` abstract class defined in [reference.py](https://github.com/dlt-hub/dlt/blob/devel/dlt/common/destination/reference.py). > 💡 See how `snowflake` destination adds additional authorization methods and configuration options. @@ -42,8 +48,8 @@ We keep config and credentials in `configuration.py`. You should: `dlt` needs to know a few things about the destination to correctly work with it. Those are stored in `capabilities()` function in `__init__.py`. * supported loader file formats both for direct and staging loading (see below) -* `escape_identifier` a function that escapes database identifiers ie. table or column name. provided implementation for postgres should work for you. -* `escape_literal` a function that escapes string literal. it is only used if destination supports **insert-values** loader format +* `escape_identifier` a function that escapes database identifiers ie. table or column name. Look in `dlt.common.data_writers.escape` module to see how this is implemented for existing destinations. +* `escape_literal` a function that escapes string literal. it is only used if destination supports **insert-values** loader format (also see existing implementations in `dlt.common.data_writers.escape`) * `decimal_precision` precision and scale of decimal/numeric types. also used to create right decimal types in loader files ie. parquet * `wei_precision` precision and scale of decimal/numeric to store very large (up to 2**256) integers. specify maximum precision for scale 0 * `max_identifier_length` max length of table and schema/dataset names @@ -55,6 +61,8 @@ We keep config and credentials in `configuration.py`. You should: * `supports_ddl_transactions` tells if the destination supports ddl transactions. * `alter_add_multi_column` tells if destination can add multiple columns in **ALTER** statement * `supports_truncate_command` tells dlt if **truncate** command is used, otherwise it will use **DELETE** to clear tables. +* `schema_supports_numeric_precision` whether numeric data types support precision/scale configuration +* `max_rows_per_insert` max number of rows supported per insert statement, used with `insert-values` loader file format (set to `None` for no limit). E.g. MS SQL has a limit of 1000 rows per statement, but most databases have no limit and the statement is divided according to `max_query_length`. ### Supported loader file formats Specify which [loader file formats](../dlt-ecosystem/file-formats/) your destination will support directly and via [storage staging](../dlt-ecosystem/staging.md). Direct support means that destination is able to load a local file or supports INSERT command. Loading via staging is using `filesystem` to send load package to a (typically) bucket storage and then load from there. 
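To tie the capability bullets above together, here is a hedged sketch of what a `capabilities()` function for a hypothetical dbapi destination might look like. The concrete numbers are illustrative only and must come from your database's documentation; the escape helpers are the postgres ones from `dlt.common.data_writers.escape` (swap in your own as needed):

```python
from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal
from dlt.common.destination import DestinationCapabilitiesContext

def capabilities() -> DestinationCapabilitiesContext:
    caps = DestinationCapabilitiesContext()
    # loader file formats supported directly (staging formats are configured separately)
    caps.preferred_loader_file_format = "insert_values"
    caps.supported_loader_file_formats = ["insert_values"]
    # identifier / literal escaping reused from the escape module
    caps.escape_identifier = escape_postgres_identifier
    caps.escape_literal = escape_postgres_literal
    # precision and scale of decimal and wei (very large integer) types
    caps.decimal_precision = (38, 9)
    caps.wei_precision = (76, 0)
    # identifier and query limits -- illustrative values only
    caps.max_identifier_length = 127
    caps.max_column_identifier_length = 127
    caps.max_query_length = 16 * 1024 * 1024
    caps.is_max_query_length_in_bytes = True
    # transaction and DDL behavior
    caps.supports_ddl_transactions = True
    caps.alter_add_multi_column = True
    caps.supports_truncate_command = True
    # no per-statement row limit for insert-values
    caps.max_rows_per_insert = None
    return caps
```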
@@ -107,10 +115,10 @@ When created, `sql_client` is bound to particular dataset name (which typically - `DROP TABLE` only for CLI command (`pipeline drop`) ## 5. Adjust the job client -Job client is responsible for creating/starting load jobs and managing the schema updates. Here we'll adjust the `SqlJobClientBase` base class which uses the `sql_client` to manage the destination. Typically only a few methods needs to be overridden by a particular implementation. The job client code customarily resides in a file with name `.py` ie. `postgres.py` and is exposed in `__init__.py` by `client` method. +Job client is responsible for creating/starting load jobs and managing the schema updates. Here we'll adjust the `SqlJobClientBase` base class which uses the `sql_client` to manage the destination. Typically only a few methods needs to be overridden by a particular implementation. The job client code customarily resides in a file with name `.py` ie. `postgres.py` and is exposed in `factory.py` by the `client_class` property on the destination class. ### Database type mappings -You must map `dlt` data types to destination data types. This happens in `_to_db_type` and `_from_db_type` class methods. Typically just a mapping dictionary is enough. A few tricks to remember: +You must map `dlt` data types to destination data types. For this you can implement a subclass of `TypeMapper`. You can specify there dicts to map `dlt` data types to destination data types, with or without precision. A few tricks to remember: * the database types must be exactly those as used in `INFORMATION_SCHEMA.COLUMNS` * decimal precision and scale are filled from the capabilities (in all our implementations) * until now all destinations could handle binary types @@ -159,18 +167,23 @@ The `postgres` destination does not implement any copy jobs. - See `RedshiftCopyFileLoadJob` in `redshift.py` how we create and start a copy job from a bucket. It uses `CopyRemoteFileLoadJob` base to handle the references and creates a `COPY` SQL statement in `execute()` method. - See `SnowflakeLoadJob` in `snowflake.py` how to implement a job that can load local and reference files. It also forwards AWS credentials from staging destination. At the end the code just generates a COPY command for various loader file formats. +## 7. Expose your destination to dlt + +The `Destination` subclass in `dlt.destinations.impl..factory` module is the entry point for the destination. +Add an import to your factory in [`dlt.destinations.__init__`](https://github.com/dlt-hub/dlt/blob/devel/dlt/destinations/__init__.py). `dlt` looks in this module when you reference a destination by name, i.e. `dlt.pipeline(..., destination="postgres")`. + ## Testing We can quickly repurpose existing github source and `secrets.toml` already present in the project to test new destination. Let's assume that the module name is `presto`, same for the destination name and config section name. Here's our testing script `github_pipeline.py` ```python import dlt from github import github_repo_events -import presto # importing destination module +from presto import presto # importing destination factory def load_airflow_events() -> None: """Loads airflow events. Shows incremental loading. 
Forces anonymous access token""" pipeline = dlt.pipeline( - "github_events", destination=presto, dataset_name="airflow_events" + "github_events", destination=presto(), dataset_name="airflow_events" ) data = github_repo_events("apache", "airflow", access_token="") print(pipeline.run(data)) diff --git a/poetry.lock b/poetry.lock index 9c1c9b4226..cad68180dc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8986,4 +8986,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "3380a5a646776e0fc0d895b5271bb769872ac1cdb09a842af61ba1741d1c03b3" +content-hash = "a7aa3e523522ab3260a7a19f097a34349b66cf046289db1e17b48f88f7fd189f" diff --git a/pyproject.toml b/pyproject.toml index ee98c4d1e0..88e6bd9390 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.4.5" +version = "0.4.6" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Ty Dunn "] @@ -33,7 +33,6 @@ hexbytes = ">=0.2.2" tzdata = ">=2022.1" tomlkit = ">=0.11.3" pathvalidate = ">=2.5.2" -SQLAlchemy = ">=1.4.0" typing-extensions = ">=4.0.0" makefun = ">=1.15.0" click = ">=7.1" @@ -137,7 +136,6 @@ types-python-dateutil = ">=2.8.15" flake8-tidy-imports = ">=4.8.0" flake8-encodings = "^0.5.0" flake8-builtins = "^1.5.3" -types-SQLAlchemy = ">=1.4.53" boto3-stubs = "^1.28.28" types-tqdm = "^4.66.0.2" types-psutil = "^5.9.5.16" @@ -145,6 +143,7 @@ types-psycopg2 = "^2.9.21.14" cryptography = "^41.0.7" google-api-python-client = ">=1.7.11" pytest-asyncio = "^0.23.5" +types-sqlalchemy = "^1.4.53.38" [tool.poetry.group.pipeline] optional=true @@ -186,6 +185,7 @@ sentry-sdk = "^1.5.6" optional = true [tool.poetry.group.docs.dependencies] +SQLAlchemy = ">=1.4.0" pymysql = "^1.1.0" pypdf2 = "^3.0.1" pydoc-markdown = "^4.8.2" diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 81d49432d7..a883f76ddb 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -778,10 +778,11 @@ def test_values_serialization() -> None: assert deserialize_value("K", v, Wei) == Wei("0.01") # test credentials - credentials_str = "databricks+connector://token:@:443/?conn_timeout=15&search_path=a%2Cb%2Cc" + credentials_str = "databricks+connector://token:-databricks_token-@:443/?conn_timeout=15&search_path=a%2Cb%2Cc" credentials = deserialize_value("credentials", credentials_str, ConnectionStringCredentials) assert credentials.drivername == "databricks+connector" assert credentials.query == {"conn_timeout": "15", "search_path": "a,b,c"} + assert credentials.password == "-databricks_token-" assert serialize_value(credentials) == credentials_str # using dict also works credentials_dict = dict(credentials) diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index 25c32920bc..4892967ab7 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -1,7 +1,6 @@ import itertools import os import pytest -from sqlalchemy.engine import Engine, create_engine from typing import Optional, Union, Any import dlt @@ -236,6 +235,10 @@ def test_google_auth_union(environment: Any) -> None: assert isinstance(credentials, GcpServiceAccountCredentials) +class Engine: + pass + + @dlt.source def 
sql_database(credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value): yield dlt.resource([credentials], name="creds") @@ -243,7 +246,7 @@ def sql_database(credentials: Union[ConnectionStringCredentials, Engine, str] = def test_union_concrete_type(environment: Any) -> None: # we can pass engine explicitly - engine = create_engine("sqlite:///:memory:", echo=True) + engine = Engine() db = sql_database(credentials=engine) creds = list(db)[0] assert isinstance(creds, Engine) diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index ba817b946f..653e9cc351 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -237,6 +237,38 @@ def test_replace_schema_content() -> None: schema.replace_schema_content(schema_eth) assert schema.version_hash != schema.stored_version_hash + # make sure we linked the replaced schema to the incoming + schema = Schema("simple") + eth_v5 = load_yml_case("schemas/eth/ethereum_schema_v5") + schema_eth = Schema.from_dict(eth_v5, bump_version=False) # type: ignore[arg-type] + schema_eth.bump_version() + # modify simple schema by adding a table + schema.update_table(schema_eth.get_table("blocks")) + replaced_stored_hash = schema.stored_version_hash + schema.replace_schema_content(schema_eth, link_to_replaced_schema=True) + assert replaced_stored_hash in schema.previous_hashes + assert replaced_stored_hash == schema.stored_version_hash + assert schema.stored_version_hash != schema.version_hash + + # replace with self + eth_v5 = load_yml_case("schemas/eth/ethereum_schema_v5") + schema_eth = Schema.from_dict(eth_v5, bump_version=True) # type: ignore[arg-type] + stored_hash = schema_eth.stored_version_hash + schema_eth.replace_schema_content(schema_eth) + assert stored_hash == schema_eth.stored_version_hash + assert stored_hash == schema_eth.version_hash + assert stored_hash not in schema_eth.previous_hashes + + # replace with self but version is not bumped + eth_v5 = load_yml_case("schemas/eth/ethereum_schema_v5") + schema_eth = Schema.from_dict(eth_v5, bump_version=False) # type: ignore[arg-type] + stored_hash = schema_eth.stored_version_hash + schema_eth.replace_schema_content(schema_eth) + assert stored_hash == schema_eth.stored_version_hash + assert stored_hash != schema_eth.version_hash + assert stored_hash in schema_eth.previous_hashes + assert schema_eth.version_hash not in schema_eth.previous_hashes + @pytest.mark.parametrize( "columns,hint,value", diff --git a/tests/common/test_time.py b/tests/common/test_time.py index 72a9098e4d..7568e84046 100644 --- a/tests/common/test_time.py +++ b/tests/common/test_time.py @@ -4,6 +4,7 @@ from dlt.common import pendulum from dlt.common.time import ( + parse_iso_like_datetime, timestamp_before, timestamp_within, ensure_pendulum_datetime, @@ -40,27 +41,27 @@ def test_before() -> None: # python datetime without tz ( datetime(2021, 1, 1, 0, 0, 0), - pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), + pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), ), # python datetime with tz ( datetime(2021, 1, 1, 0, 0, 0, tzinfo=timezone(timedelta(hours=-8))), - pendulum.datetime(2021, 1, 1, 8, 0, 0).in_tz("UTC"), + pendulum.DateTime(2021, 1, 1, 8, 0, 0).in_tz("UTC"), ), # python date object - (date(2021, 1, 1), pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")), + (date(2021, 1, 1), pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")), # pendulum datetime with tz ( - pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), - pendulum.datetime(2021, 
1, 1, 0, 0, 0).in_tz("UTC"), + pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), + pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), ), # pendulum datetime without tz ( - pendulum.datetime(2021, 1, 1, 0, 0, 0), - pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), + pendulum.DateTime(2021, 1, 1, 0, 0, 0), + pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC"), ), # iso datetime in UTC - ("2021-01-01T00:00:00+00:00", pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")), + ("2021-01-01T00:00:00+00:00", pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")), # iso datetime with non utc tz ( "2021-01-01T00:00:00+05:00", @@ -69,13 +70,18 @@ def test_before() -> None: # iso datetime without tz ( "2021-01-01T05:02:32", - pendulum.datetime(2021, 1, 1, 5, 2, 32).in_tz("UTC"), + pendulum.DateTime(2021, 1, 1, 5, 2, 32).in_tz("UTC"), ), # iso date - ("2021-01-01", pendulum.datetime(2021, 1, 1, 0, 0, 0).in_tz("UTC")), + ("2021-01-01", pendulum.DateTime(2021, 1, 1, 0, 0, 0).in_tz("UTC")), ] +def test_parse_iso_like_datetime() -> None: + # naive datetime is still naive + assert parse_iso_like_datetime("2021-01-01T05:02:32") == pendulum.DateTime(2021, 1, 1, 5, 2, 32) + + @pytest.mark.parametrize("date_value, expected", test_params) def test_ensure_pendulum_datetime(date_value: TAnyDateTime, expected: pendulum.DateTime) -> None: dt = ensure_pendulum_datetime(date_value) diff --git a/tests/common/test_utils.py b/tests/common/test_utils.py index 7cd8e9f1a2..456ef3cb91 100644 --- a/tests/common/test_utils.py +++ b/tests/common/test_utils.py @@ -21,6 +21,7 @@ extend_list_deduplicated, get_exception_trace, get_exception_trace_chain, + update_dict_nested, ) @@ -277,3 +278,14 @@ def test_exception_trace_chain() -> None: assert traces[0]["exception_type"] == "dlt.common.exceptions.PipelineException" assert traces[1]["exception_type"] == "dlt.common.exceptions.IdentifierTooLongException" assert traces[2]["exception_type"] == "dlt.common.exceptions.TerminalValueError" + + +def test_nested_dict_merge() -> None: + dict_1 = {"a": 1, "b": 2} + dict_2 = {"a": 2, "c": 4} + + assert update_dict_nested(dict(dict_1), dict_2) == {"a": 2, "b": 2, "c": 4} + assert update_dict_nested(dict(dict_2), dict_1) == {"a": 1, "b": 2, "c": 4} + assert update_dict_nested(dict(dict_1), dict_2, keep_dst_values=True) == update_dict_nested( + dict_2, dict_1 + ) diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py index 533b91808c..3fff3bf2ea 100644 --- a/tests/common/test_validation.py +++ b/tests/common/test_validation.py @@ -1,18 +1,22 @@ from copy import deepcopy import pytest import yaml -from typing import Dict, List, Literal, Mapping, Sequence, TypedDict, Optional, Union +from typing import Callable, List, Literal, Mapping, Sequence, TypedDict, TypeVar, Optional, Union -from dlt.common import json from dlt.common.exceptions import DictValidationException from dlt.common.schema.typing import TStoredSchema, TColumnSchema from dlt.common.schema.utils import simple_regex_validator -from dlt.common.typing import DictStrStr, StrStr +from dlt.common.typing import DictStrStr, StrStr, TDataItem from dlt.common.validation import validate_dict, validate_dict_ignoring_xkeys TLiteral = Literal["uno", "dos", "tres"] +# some typevars for testing +TDynHintType = TypeVar("TDynHintType") +TFunHintTemplate = Callable[[TDataItem], TDynHintType] +TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]] + class TDict(TypedDict): field: TLiteral @@ -241,3 +245,31 @@ def test_nested_union(test_doc: TTestRecord) -> 
None: validate_dict(TTestRecord, test_doc, ".") assert e.value.field == "f_optional_union" assert e.value.value == "blah" + + +def test_no_name() -> None: + class TTestRecordNoName(TypedDict): + name: TTableHintTemplate[str] + + test_item = {"name": "test"} + try: + validate_dict(TTestRecordNoName, test_item, path=".") + except AttributeError: + pytest.fail("validate_dict raised AttributeError unexpectedly") + + test_item_2 = {"name": True} + with pytest.raises(DictValidationException): + validate_dict(TTestRecordNoName, test_item_2, path=".") + + +def test_callable() -> None: + class TTestRecordCallable(TypedDict): + prop: TTableHintTemplate # type: ignore + + def f(item: Union[TDataItem, TDynHintType]) -> TDynHintType: + return item + + test_item = {"prop": f} + validate_dict( + TTestRecordCallable, test_item, path=".", validator_f=lambda p, pk, pv, t: callable(pv) + ) diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index 03b3cb32c4..03f87db923 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -38,8 +38,9 @@ SourceIsAClassTypeError, SourceNotAFunction, CurrentSourceSchemaNotAvailable, + InvalidParallelResourceDataType, ) -from dlt.extract.typing import TableNameMeta +from dlt.extract.items import TableNameMeta from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9 @@ -910,3 +911,79 @@ async def _assert_source(source_coro_f, expected_data) -> None: @pytest.mark.skip("Not implemented") def test_class_resource() -> None: pass + + +def test_parallelized_resource_decorator() -> None: + """Test parallelized resources are wrapped correctly. + Note: tests for parallel execution are in test_resource_evaluation + """ + + def some_gen(): + yield from [1, 2, 3] + + # Create resource with decorated function + resource = dlt.resource(some_gen, parallelized=True) + + # Generator func is wrapped with parallelized gen that yields callables + gen = resource._pipe.gen() # type: ignore + result = next(gen) # type: ignore[arg-type] + assert result() == 1 + assert list(resource) == [1, 2, 3] + + # Same but wrapping generator directly + resource = dlt.resource(some_gen(), parallelized=True) + + result = next(resource._pipe.gen) # type: ignore + assert result() == 1 + # get remaining items + assert list(resource) == [2, 3] + + # Wrap a yielding transformer + def some_tx(item): + yield item + 1 + + resource = dlt.resource(some_gen, parallelized=True) + + transformer = dlt.transformer(some_tx, parallelized=True, data_from=resource) + pipe_gen = transformer._pipe.gen + # Calling transformer returns the parallel wrapper generator + inner = pipe_gen(1) # type: ignore + assert next(inner)() == 2 # type: ignore + assert list(transformer) == [2, 3, 4] # add 1 to resource + + # Wrap a transformer function + def some_tx_func(item): + return list(range(item)) + + transformer = dlt.transformer(some_tx_func, parallelized=True, data_from=resource) + pipe_gen = transformer._pipe.gen + inner = pipe_gen(3) # type: ignore + # this is a regular function returning list + assert inner() == [0, 1, 2] # type: ignore[operator] + assert list(transformer) == [0, 0, 1, 0, 1, 2] + + # Invalid parallel resources + + # From async generator + with pytest.raises(InvalidParallelResourceDataType): + + @dlt.resource(parallelized=True) + async def some_data(): + yield 1 + yield 2 + + # From list + with pytest.raises(InvalidParallelResourceDataType): + dlt.resource([1, 2, 3], name="T", parallelized=True) + + # Test that inner generator is closed when wrapper is closed + 
gen_orig = some_gen() + resource = dlt.resource(gen_orig, parallelized=True) + gen = resource._pipe.gen + + next(gen) # type: ignore + gen.close() # type: ignore + + with pytest.raises(StopIteration): + # Inner generator is also closed + next(gen_orig) diff --git a/tests/extract/test_extract_pipe.py b/tests/extract/test_extract_pipe.py index 38dc8a9319..68c1c82124 100644 --- a/tests/extract/test_extract_pipe.py +++ b/tests/extract/test_extract_pipe.py @@ -9,9 +9,10 @@ import dlt from dlt.common import sleep from dlt.common.typing import TDataItems -from dlt.extract.exceptions import CreatePipeException, ResourceExtractionError -from dlt.extract.typing import DataItemWithMeta, FilterItem, MapItem, YieldMapItem -from dlt.extract.pipe import ManagedPipeIterator, Pipe, PipeItem, PipeIterator +from dlt.extract.exceptions import CreatePipeException, ResourceExtractionError, UnclosablePipe +from dlt.extract.items import DataItemWithMeta, FilterItem, MapItem, YieldMapItem +from dlt.extract.pipe import Pipe +from dlt.extract.pipe_iterator import PipeIterator, ManagedPipeIterator, PipeItem def test_next_item_mode() -> None: @@ -72,7 +73,7 @@ def test_rotation_on_none() -> None: def source_gen1(): gen_1_started = time.time() yield None - while time.time() - gen_1_started < 0.6: + while time.time() - gen_1_started < 3: time.sleep(0.05) yield None yield 1 @@ -80,7 +81,7 @@ def source_gen1(): def source_gen2(): gen_2_started = time.time() yield None - while time.time() - gen_2_started < 0.2: + while time.time() - gen_2_started < 1: time.sleep(0.05) yield None yield 2 @@ -88,7 +89,7 @@ def source_gen2(): def source_gen3(): gen_3_started = time.time() yield None - while time.time() - gen_3_started < 0.4: + while time.time() - gen_3_started < 2: time.sleep(0.05) yield None yield 3 @@ -105,7 +106,7 @@ def get_pipes(): # items will be round robin, nested iterators are fully iterated and appear inline as soon as they are encountered assert [pi.item for pi in _l] == [2, 3, 1] # jobs should have been executed in parallel - assert time.time() - started < 0.8 + assert time.time() - started < 3.5 def test_add_step() -> None: @@ -613,6 +614,21 @@ def pass_gen(item, meta): _f_items(list(PipeIterator.from_pipes(pipes))) +def test_explicit_close_pipe() -> None: + list_pipe = Pipe.from_data("list_pipe", iter([1, 2, 3])) + with pytest.raises(UnclosablePipe): + list_pipe.close() + + # generator function cannot be closed + genfun_pipe = Pipe.from_data("genfun_pipe", lambda _: (yield from [1, 2, 3])) + with pytest.raises(UnclosablePipe): + genfun_pipe.close() + + gen_pipe = Pipe.from_data("gen_pipe", (lambda: (yield from [1, 2, 3]))()) + gen_pipe.close() + assert inspect.getgeneratorstate(gen_pipe.gen) == "GEN_CLOSED" # type: ignore[arg-type] + + close_pipe_got_exit = False close_pipe_yielding = False diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index 6b6c7887d3..7956c83947 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -1,8 +1,10 @@ import os +import asyncio from time import sleep from typing import Optional, Any +from unittest import mock from datetime import datetime # noqa: I251 -from itertools import chain +from itertools import chain, count import duckdb import pytest @@ -18,6 +20,7 @@ from dlt.common.json import json from dlt.extract import DltSource +from dlt.extract.exceptions import InvalidStepFunctionArguments from dlt.sources.helpers.transform import take_first from dlt.extract.incremental.exceptions import ( 
IncrementalCursorPathMissing, @@ -26,7 +29,12 @@ from dlt.pipeline.exceptions import PipelineStepFailed from tests.extract.utils import AssertItems, data_item_to_list -from tests.utils import data_to_item_format, TDataItemFormat, ALL_DATA_ITEM_FORMATS +from tests.utils import ( + data_item_length, + data_to_item_format, + TDataItemFormat, + ALL_DATA_ITEM_FORMATS, +) @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) @@ -977,11 +985,17 @@ def some_data( "updated_at", initial_value=pendulum_start_dt ), max_hours: int = 2, + tz: str = None, ): data = [ {"updated_at": start_dt + timedelta(hours=hour), "hour": hour} for hour in range(1, max_hours + 1) ] + # make sure this is naive datetime + assert data[0]["updated_at"].tzinfo is None # type: ignore[attr-defined] + if tz: + data = [{**d, "updated_at": pendulum.instance(d["updated_at"])} for d in data] # type: ignore[call-overload] + yield data_to_item_format(item_type, data) pipeline = dlt.pipeline(pipeline_name=uniq_id()) @@ -1024,6 +1038,44 @@ def some_data( == 2 ) + # initial value is naive + resource = some_data(max_hours=4).with_name("copy_1") # also make new resource state + resource.apply_hints(incremental=dlt.sources.incremental("updated_at", initial_value=start_dt)) + # and the data is naive. so it will work as expected with naive datetimes in the result set + data = list(resource) + if item_type == "json": + # we do not convert data in arrow tables + assert data[0]["updated_at"].tzinfo is None + + # end value is naive + resource = some_data(max_hours=4).with_name("copy_2") # also make new resource state + resource.apply_hints( + incremental=dlt.sources.incremental( + "updated_at", initial_value=start_dt, end_value=start_dt + timedelta(hours=3) + ) + ) + data = list(resource) + if item_type == "json": + assert data[0]["updated_at"].tzinfo is None + + # now use naive initial value but data is UTC + resource = some_data(max_hours=4, tz="UTC").with_name("copy_3") # also make new resource state + resource.apply_hints( + incremental=dlt.sources.incremental( + "updated_at", initial_value=start_dt + timedelta(hours=3) + ) + ) + # will cause invalid comparison + if item_type == "json": + with pytest.raises(InvalidStepFunctionArguments): + list(resource) + else: + data = data_item_to_list(item_type, list(resource)) + # we select two rows by adding 3 hours to start_dt. 
rows have hours: + # 1, 2, 3, 4 + # and we select >=3 + assert len(data) == 2 + @dlt.resource def endless_sequence( @@ -1288,6 +1340,119 @@ def ascending_single_item( pipeline.extract(ascending_single_item()) +@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +def test_async_row_order_out_of_range(item_type: TDataItemFormat) -> None: + @dlt.resource + async def descending( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10, row_order="desc" + ) + ) -> Any: + for chunk in chunks(count(start=48, step=-1), 10): + await asyncio.sleep(0.01) + data = [{"updated_at": i} for i in chunk] + yield data_to_item_format(item_type, data) + + data = list(descending) + assert data_item_length(data) == 48 - 10 + 1 # both bounds included + + +@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +def test_parallel_row_order_out_of_range(item_type: TDataItemFormat) -> None: + """Test automatic generator close for ordered rows""" + + @dlt.resource(parallelized=True) + def descending( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10, row_order="desc" + ) + ) -> Any: + for chunk in chunks(count(start=48, step=-1), 10): + data = [{"updated_at": i} for i in chunk] + yield data_to_item_format(item_type, data) + + data = list(descending) + assert data_item_length(data) == 48 - 10 + 1 # both bounds included + + +def test_transformer_row_order_out_of_range() -> None: + out_of_range = [] + + @dlt.transformer + def descending( + package: int, + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10, row_order="desc", primary_key="updated_at" + ), + ) -> Any: + for chunk in chunks(count(start=48, step=-1), 10): + data = [{"updated_at": i, "package": package} for i in chunk] + yield data_to_item_format("json", data) + if updated_at.can_close(): + out_of_range.append(package) + return + + data = list([3, 2, 1] | descending) + assert len(data) == 48 - 10 + 1 + # we take full package 3 and then nothing in 1 and 2 + assert len(out_of_range) == 3 + + +@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +def test_row_order_out_of_range(item_type: TDataItemFormat) -> None: + """Test automatic generator close for ordered rows""" + + @dlt.resource + def descending( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=10, row_order="desc" + ) + ) -> Any: + for chunk in chunks(count(start=48, step=-1), 10): + data = [{"updated_at": i} for i in chunk] + yield data_to_item_format(item_type, data) + + data = list(descending) + assert data_item_length(data) == 48 - 10 + 1 # both bounds included + + @dlt.resource + def ascending( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=22, end_value=45, row_order="asc" + ) + ) -> Any: + # use INFINITE sequence so this test wil not stop if closing logic is flawed + for chunk in chunks(count(start=22), 10): + data = [{"updated_at": i} for i in chunk] + yield data_to_item_format(item_type, data) + + data = list(ascending) + assert data_item_length(data) == 45 - 22 + + # use wrong row order, this will prevent end value to close pipe + + @dlt.resource + def ascending_desc( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", initial_value=22, end_value=45, row_order="desc" + ) + ) -> Any: + for chunk in chunks(range(22, 100), 10): + data = [{"updated_at": i} for i in chunk] + yield 
data_to_item_format(item_type, data) + + from dlt.extract import pipe + + with mock.patch.object( + pipe.Pipe, + "close", + side_effect=RuntimeError("Close pipe should not be called"), + ) as close_pipe: + data = list(ascending_desc) + close_pipe.assert_not_called() + assert data_item_length(data) == 45 - 22 + + @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) def test_get_incremental_value_type(item_type: TDataItemFormat) -> None: assert dlt.sources.incremental("id").get_incremental_value_type() is Any diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index a94cf680fa..5895c3b658 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -2,6 +2,7 @@ from typing import Iterator import pytest +import asyncio import dlt from dlt.common.configuration.container import Container @@ -789,6 +790,31 @@ def test_limit_infinite_counter() -> None: assert list(r) == list(range(10)) + +@pytest.mark.parametrize("limit", (None, -1, 0, 10)) +def test_limit_edge_cases(limit: int) -> None: + r = dlt.resource(range(20), name="infinity").add_limit(limit) # type: ignore + + @dlt.resource() + async def r_async(): + for i in range(20): + await asyncio.sleep(0.01) + yield i + + sync_list = list(r) + async_list = list(r_async().add_limit(limit)) + + # check the expected results + assert sync_list == async_list + if limit == 10: + assert sync_list == list(range(10)) + elif limit in [None, -1]: + assert sync_list == list(range(20)) + elif limit == 0: + assert sync_list == [] + else: + raise AssertionError(f"Unexpected limit: {limit}") + + def test_limit_source() -> None: def mul_c(item): yield from "A" * (item + 2) @@ -1104,8 +1130,29 @@ def mysource(): s = mysource() assert s.exhausted is False - assert next(iter(s)) == 2 # transformer is returned befor resource - assert s.exhausted is True + assert next(iter(s)) == 2 # transformer is returned before resource + assert s.exhausted is False + + +def test_exhausted_with_limit() -> None: + def open_generator_data(): + yield from [1, 2, 3, 4] + + s = DltSource( + Schema("source"), + "module", + [dlt.resource(open_generator_data)], + ) + assert s.exhausted is False + list(s) + assert s.exhausted is False + + # use limit + s.add_limit(1) + list(s) + # must still be false, the limit should not open the generator if it is still a generator function + assert s.exhausted is False + assert list(s) == [1] def test_clone_resource_with_name() -> None: diff --git a/tests/extract/test_validation.py b/tests/extract/test_validation.py index 045f75ab73..b9307ab97c 100644 --- a/tests/extract/test_validation.py +++ b/tests/extract/test_validation.py @@ -10,7 +10,7 @@ from dlt.common.libs.pydantic import BaseModel from dlt.extract import DltResource -from dlt.extract.typing import ValidateItem +from dlt.extract.items import ValidateItem from dlt.extract.validation import PydanticValidator from dlt.extract.exceptions import ResourceExtractionError from dlt.pipeline.exceptions import PipelineStepFailed diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 98e798d0f0..170781ba3c 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -6,7 +6,7 @@ from dlt.common.typing import TDataItem, TDataItems from dlt.extract.extract import ExtractStorage -from dlt.extract.typing import ItemTransform +from dlt.extract.items import ItemTransform from tests.utils import TDataItemFormat diff --git a/tests/helpers/airflow_tests/test_airflow_wrapper.py b/tests/helpers/airflow_tests/test_airflow_wrapper.py index
0399e3875d..d01330c8b2 100644 --- a/tests/helpers/airflow_tests/test_airflow_wrapper.py +++ b/tests/helpers/airflow_tests/test_airflow_wrapper.py @@ -241,6 +241,53 @@ def dag_decomposed(): assert pipeline_dag_decomposed_counts == pipeline_standalone_counts +def test_run() -> None: + task: PythonOperator = None + + pipeline_standalone = dlt.pipeline( + pipeline_name="pipeline_standalone", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=":pipeline:", + ) + pipeline_standalone.run(mock_data_source()) + pipeline_standalone_counts = load_table_counts( + pipeline_standalone, *[t["name"] for t in pipeline_standalone.default_schema.data_tables()] + ) + + quackdb_path = os.path.join(TEST_STORAGE_ROOT, "pipeline_dag_regular.duckdb") + + @dag(schedule=None, start_date=DEFAULT_DATE, catchup=False, default_args=default_args) + def dag_regular(): + nonlocal task + tasks = PipelineTasksGroup( + "pipeline_dag_regular", local_data_folder=TEST_STORAGE_ROOT, wipe_local_data=False + ) + + # set duckdb to be outside of pipeline folder which is dropped on each task + pipeline_dag_regular = dlt.pipeline( + pipeline_name="pipeline_dag_regular", + dataset_name="mock_data_" + uniq_id(), + destination="duckdb", + credentials=quackdb_path, + ) + task = tasks.run(pipeline_dag_regular, mock_data_source()) + + dag_def: DAG = dag_regular() + assert task.task_id == "mock_data_source__r_init-_t_init_post-_t1-_t2-2-more" + + dag_def.test() + + pipeline_dag_regular = dlt.attach(pipeline_name="pipeline_dag_regular") + pipeline_dag_regular_counts = load_table_counts( + pipeline_dag_regular, + *[t["name"] for t in pipeline_dag_regular.default_schema.data_tables()], + ) + assert pipeline_dag_regular_counts == pipeline_standalone_counts + + assert isinstance(task, PythonOperator) + + def test_parallel_run(): pipeline_standalone = dlt.pipeline( pipeline_name="pipeline_parallel", diff --git a/tests/libs/test_pyarrow.py b/tests/libs/test_pyarrow.py index dffda35005..68541e96e0 100644 --- a/tests/libs/test_pyarrow.py +++ b/tests/libs/test_pyarrow.py @@ -1,9 +1,17 @@ from copy import deepcopy - +from datetime import timezone, datetime, timedelta # noqa: I251 import pyarrow as pa -from dlt.common.libs.pyarrow import py_arrow_to_table_schema_columns, get_py_arrow_datatype +from dlt.common import pendulum +from dlt.common.libs.pyarrow import ( + from_arrow_scalar, + get_py_arrow_timestamp, + py_arrow_to_table_schema_columns, + get_py_arrow_datatype, + to_arrow_scalar, +) from dlt.common.destination import DestinationCapabilitiesContext + from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA @@ -49,3 +57,55 @@ def test_py_arrow_to_table_schema_columns(): # Resulting schema should match the original assert result == dlt_schema + + +def test_to_arrow_scalar() -> None: + naive_dt = get_py_arrow_timestamp(6, tz=None) + # print(naive_dt) + # naive datetimes are converted as UTC when time aware python objects are used + assert to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt).as_py() == datetime( + 2021, 1, 1, 5, 2, 32 + ) + assert to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt + ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + assert to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), naive_dt + ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + timedelta(hours=8) + + # naive datetimes are treated like UTC + utc_dt = get_py_arrow_timestamp(6, tz="UTC") + dt_converted = to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, 
tzinfo=timezone(timedelta(hours=-8))), utc_dt + ).as_py() + assert dt_converted.utcoffset().seconds == 0 + assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) + + berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin") + dt_converted = to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt + ).as_py() + # no dst + assert dt_converted.utcoffset().seconds == 60 * 60 + assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) + + +def test_from_arrow_scalar() -> None: + naive_dt = get_py_arrow_timestamp(6, tz=None) + sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt) + + # this value is like UTC + py_dt = from_arrow_scalar(sc_dt) + assert isinstance(py_dt, pendulum.DateTime) + # and we convert to explicit UTC + assert py_dt == datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc) + + # converts to UTC + berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin") + sc_dt = to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt + ) + py_dt = from_arrow_scalar(sc_dt) + assert isinstance(py_dt, pendulum.DateTime) + assert py_dt.tzname() == "UTC" + assert py_dt == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 5fa656ada9..a93599831d 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -878,7 +878,6 @@ def test_pipeline_upfront_tables_two_loads( # use staging tables for replace os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy - print(destination_config) pipeline = destination_config.setup_pipeline( "test_pipeline_upfront_tables_two_loads", dataset_name="test_pipeline_upfront_tables_two_loads", @@ -984,6 +983,48 @@ def table_3(make_data=False): ) +# @pytest.mark.skip(reason="Finalize the test: compare some_data values to values from database") +# @pytest.mark.parametrize( +# "destination_config", +# destinations_configs(all_staging_configs=True, default_sql_configs=True, file_format=["insert_values", "jsonl", "parquet"]), +# ids=lambda x: x.name, +# ) +# def test_load_non_utc_timestamps_with_arrow(destination_config: DestinationTestConfiguration) -> None: +# """Checks if dates are stored properly and timezones are not mangled""" +# from datetime import timedelta, datetime, timezone +# start_dt = datetime.now() + +# # columns=[{"name": "Hour", "data_type": "bool"}] +# @dlt.resource(standalone=True, primary_key="Hour") +# def some_data( +# max_hours: int = 2, +# ): +# data = [ +# { +# "naive_dt": start_dt + timedelta(hours=hour), "hour": hour, +# "utc_dt": pendulum.instance(start_dt + timedelta(hours=hour)), "hour": hour, +# # tz="Europe/Berlin" +# "berlin_dt": pendulum.instance(start_dt + timedelta(hours=hour), tz=timezone(offset=timedelta(hours=-8))), "hour": hour, +# } +# for hour in range(0, max_hours) +# ] +# data = data_to_item_format("arrow", data) +# # print(py_arrow_to_table_schema_columns(data[0].schema)) +# # print(data) +# yield data + +# pipeline = destination_config.setup_pipeline( +# "test_load_non_utc_timestamps", +# dataset_name="test_load_non_utc_timestamps", +# full_refresh=True, +# ) +# info = pipeline.run(some_data()) +# # print(pipeline.default_schema.to_pretty_yaml()) +# assert_load_info(info) +# table_name = pipeline.sql_client().make_qualified_table_name("some_data") +# print(select_data(pipeline, f"SELECT * FROM {table_name}")) + + def simple_nested_pipeline( 
destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool ) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index fb8ff925c0..d0ca4de41b 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -1,7 +1,7 @@ import os import pytest from pathlib import Path -from sqlalchemy.engine import make_url +from dlt.common.libs.sql_alchemy import make_url pytest.importorskip("snowflake") diff --git a/tests/load/utils.py b/tests/load/utils.py index d8a20d5518..7b4cf72b47 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -159,7 +159,7 @@ def destinations_configs( all_buckets_filesystem_configs: bool = False, subset: Sequence[str] = (), exclude: Sequence[str] = (), - file_format: Optional[TLoaderFileFormat] = None, + file_format: Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]] = None, supports_merge: Optional[bool] = None, supports_dbt: Optional[bool] = None, ) -> List[DestinationTestConfiguration]: @@ -383,8 +383,12 @@ def destinations_configs( conf for conf in destination_configs if conf.destination not in exclude ] if file_format: + if not isinstance(file_format, Sequence): + file_format = [file_format] destination_configs = [ - conf for conf in destination_configs if conf.file_format == file_format + conf + for conf in destination_configs + if conf.file_format and conf.file_format in file_format ] if supports_merge is not None: destination_configs = [ diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index a345a05ebe..39a18c5de2 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -18,6 +18,7 @@ from dlt.extract.extract import ExtractStorage from dlt.normalize import Normalize +from dlt.normalize.exceptions import NormalizeJobFailed from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_TYPES from tests.utils import ( @@ -436,6 +437,28 @@ def assert_schema(_schema: Schema): assert_schema(schema) + +def test_normalize_retry(raw_normalize: Normalize) -> None: + load_id = extract_cases(raw_normalize, ["github.issues.load_page_5_duck"]) + schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_id) + schema.set_schema_contract("freeze") + raw_normalize.normalize_storage.extracted_packages.save_schema(load_id, schema) + # will fail on contract violation + with pytest.raises(NormalizeJobFailed): + raw_normalize.run(None) + + # drop the contract requirements + schema.set_schema_contract("evolve") + # save this schema into schema storage from which normalizer must pick it up + raw_normalize.schema_storage.save_schema(schema) + # raw_normalize.normalize_storage.extracted_packages.save_schema(load_id, schema) + # subsequent run must succeed + raw_normalize.run(None) + _, table_files = expect_load_package( + raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"] + ) + assert len(table_files["issues"]) == 1 + + def test_group_worker_files() -> None: files = ["f%03d" % idx for idx in range(0, 100)] @@ -543,7 +566,7 @@ def normalize_pending(normalize: Normalize) -> str: return load_id -def extract_cases(normalize: Normalize, cases: Sequence[str]) -> None: +def extract_cases(normalize: Normalize, cases: Sequence[str]) -> str: items: List[StrAny] = [] for case in cases: # our cases have schema and table name encoded in file name @@ -555,7 +578,7 @@ def
extract_cases(normalize: Normalize, cases: Sequence[str]) -> None: else: items.append(item) # we assume that all items belonged to a single schema - extract_items( + return extract_items( normalize.normalize_storage, items, load_or_create_schema(normalize, schema_name), diff --git a/tests/pipeline/test_import_export_schema.py b/tests/pipeline/test_import_export_schema.py new file mode 100644 index 0000000000..b1c2284f24 --- /dev/null +++ b/tests/pipeline/test_import_export_schema.py @@ -0,0 +1,206 @@ +import dlt, os, pytest + +from dlt.common.utils import uniq_id + +from tests.utils import TEST_STORAGE_ROOT +from dlt.common.schema import Schema +from dlt.common.storages.schema_storage import SchemaStorage +from dlt.common.schema.exceptions import CannotCoerceColumnException +from dlt.pipeline.exceptions import PipelineStepFailed + +from dlt.destinations import dummy + + +IMPORT_SCHEMA_PATH = os.path.join(TEST_STORAGE_ROOT, "schemas", "import") +EXPORT_SCHEMA_PATH = os.path.join(TEST_STORAGE_ROOT, "schemas", "export") + + +EXAMPLE_DATA = [{"id": 1, "name": "dave"}] + + +def _get_import_schema(schema_name: str) -> Schema: + return SchemaStorage.load_schema_file(IMPORT_SCHEMA_PATH, schema_name) + + +def _get_export_schema(schema_name: str) -> Schema: + return SchemaStorage.load_schema_file(EXPORT_SCHEMA_PATH, schema_name) + + +def test_schemas_files_get_created() -> None: + name = "schema_test" + uniq_id() + + p = dlt.pipeline( + pipeline_name=name, + destination=dummy(completed_prob=1), + import_schema_path=IMPORT_SCHEMA_PATH, + export_schema_path=EXPORT_SCHEMA_PATH, + ) + + p.run(EXAMPLE_DATA, table_name="person") + + # basic check we have the table def in the export schema + export_schema = _get_export_schema(name) + assert export_schema.tables["person"]["columns"]["id"]["data_type"] == "bigint" + assert export_schema.tables["person"]["columns"]["name"]["data_type"] == "text" + + # discovered columns are not present in the import schema + import_schema = _get_import_schema(name) + assert "id" not in import_schema.tables["person"]["columns"] + assert "name" not in import_schema.tables["person"]["columns"] + + +def test_provided_columns_exported_to_import() -> None: + name = "schema_test" + uniq_id() + + p = dlt.pipeline( + pipeline_name=name, + destination=dummy(completed_prob=1), + import_schema_path=IMPORT_SCHEMA_PATH, + export_schema_path=EXPORT_SCHEMA_PATH, + ) + + p.run(EXAMPLE_DATA, table_name="person", columns={"id": {"data_type": "text"}}) + + # updated columns are in export + export_schema = _get_export_schema(name) + assert export_schema.tables["person"]["columns"]["id"]["data_type"] == "text" + assert export_schema.tables["person"]["columns"]["name"]["data_type"] == "text" + + # discovered columns are not present in the import schema + # but provided column is + import_schema = _get_import_schema(name) + assert "name" not in import_schema.tables["person"]["columns"] + assert import_schema.tables["person"]["columns"]["id"]["data_type"] == "text" + + +def test_import_schema_is_respected() -> None: + name = "schema_test" + uniq_id() + + p = dlt.pipeline( + pipeline_name=name, + destination=dummy(completed_prob=1), + import_schema_path=IMPORT_SCHEMA_PATH, + export_schema_path=EXPORT_SCHEMA_PATH, + ) + p.run(EXAMPLE_DATA, table_name="person") + assert p.default_schema.tables["person"]["columns"]["id"]["data_type"] == "bigint" + + # take default schema, modify column type and save it to import folder + modified_schema = p.default_schema.clone() + 
modified_schema.tables["person"]["columns"]["id"]["data_type"] = "text" + with open(os.path.join(IMPORT_SCHEMA_PATH, name + ".schema.yaml"), "w", encoding="utf-8") as f: + f.write(modified_schema.to_pretty_yaml()) + + # this will provoke a CannotCoerceColumnException + with pytest.raises(PipelineStepFailed) as exc: + p.run(EXAMPLE_DATA, table_name="person") + assert type(exc.value.exception) == CannotCoerceColumnException + + # schema is changed + assert p.default_schema.tables["person"]["columns"]["id"]["data_type"] == "text" + + # import schema is not overwritten + assert _get_import_schema(name).tables["person"]["columns"]["id"]["data_type"] == "text" + + # when creating a new schema (e.g. with full refresh), this will work + p = dlt.pipeline( + pipeline_name=name, + destination=dummy(completed_prob=1), + import_schema_path=IMPORT_SCHEMA_PATH, + export_schema_path=EXPORT_SCHEMA_PATH, + full_refresh=True, + ) + p.run(EXAMPLE_DATA, table_name="person") + assert p.default_schema.tables["person"]["columns"]["id"]["data_type"] == "text" + + # import schema is not overwritten + assert _get_import_schema(name).tables["person"]["columns"]["id"]["data_type"] == "text" + + # export now includes the modified column type + export_schema = _get_export_schema(name) + assert export_schema.tables["person"]["columns"]["id"]["data_type"] == "text" + assert export_schema.tables["person"]["columns"]["name"]["data_type"] == "text" + + +def test_only_explicit_hints_in_import_schema() -> None: + @dlt.source(schema_contract={"columns": "evolve"}) + def source(): + @dlt.resource(primary_key="id", name="person") + def resource(): + yield EXAMPLE_DATA + + return resource() + + p = dlt.pipeline( + pipeline_name=uniq_id(), + destination=dummy(completed_prob=1), + import_schema_path=IMPORT_SCHEMA_PATH, + export_schema_path=EXPORT_SCHEMA_PATH, + full_refresh=True, + ) + p.run(source()) + + # import schema has only the primary key hint, but no name or data types + import_schema = _get_import_schema("source") + assert import_schema.tables["person"]["columns"].keys() == {"id"} + assert import_schema.tables["person"]["columns"]["id"] == { + "nullable": False, + "primary_key": True, + "name": "id", + } + + # pipeline schema has all the stuff + assert p.default_schema.tables["person"]["columns"].keys() == { + "id", + "name", + "_dlt_load_id", + "_dlt_id", + } + assert p.default_schema.tables["person"]["columns"]["id"] == { + "nullable": False, + "primary_key": True, + "name": "id", + "data_type": "bigint", + } + + # adding column to the resource will not change the import schema, but the pipeline schema will evolve + @dlt.resource(primary_key="id", name="person", columns={"email": {"data_type": "text"}}) + def resource(): + yield EXAMPLE_DATA + + p.run(resource()) + + # check schemas + import_schema = _get_import_schema("source") + assert import_schema.tables["person"]["columns"].keys() == {"id"} + assert p.default_schema.tables["person"]["columns"].keys() == { + "id", + "name", + "_dlt_load_id", + "_dlt_id", + "email", + } + + # changing the import schema will force full update + import_schema.tables["person"]["columns"]["age"] = { + "data_type": "bigint", + "nullable": True, + "name": "age", + } + with open( + os.path.join(IMPORT_SCHEMA_PATH, "source" + ".schema.yaml"), "w", encoding="utf-8" + ) as f: + f.write(import_schema.to_pretty_yaml()) + + # run with the original source, email hint should be gone after this, but we now have age + p.run(source()) + + assert p.default_schema.tables["person"]["columns"].keys() == { + 
"id", + "name", + "_dlt_load_id", + "_dlt_id", + "age", + } + import_schema = _get_import_schema("source") + assert import_schema.tables["person"]["columns"].keys() == {"id", "age"} diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index e1f7397ef9..0cebeb2ff7 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -3,6 +3,7 @@ import itertools import logging import os +import random from time import sleep from typing import Any, Tuple, cast import threading @@ -580,9 +581,8 @@ def data_piece_2(): assert p.first_run is True assert p.has_data is False assert p.default_schema_name is None - # one of the schemas is in memory - # TODO: we may want to fix that - assert len(p._schema_storage.list_schemas()) == 1 + # live schemas created during extract are popped from mem + assert len(p._schema_storage.list_schemas()) == 0 # restore the pipeline p = dlt.attach(pipeline_name) @@ -616,9 +616,8 @@ def data_schema_3(): # first run didn't really happen assert p.first_run is True assert p.has_data is False - # schemas from two sources are in memory - # TODO: we may want to fix that - assert len(p._schema_storage.list_schemas()) == 2 + # live schemas created during extract are popped from mem + assert len(p._schema_storage.list_schemas()) == 0 assert p.default_schema_name is None os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately @@ -1663,3 +1662,198 @@ def api_fetch(page_num): load_info = pipeline.run(product()) assert_load_info(load_info) assert pipeline.last_trace.last_normalize_info.row_counts["product"] == 12 + + +def test_run_with_pua_payload() -> None: + # prepare some data and complete load with run + os.environ["COMPLETED_PROB"] = "1.0" + pipeline_name = "pipe_" + uniq_id() + p = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") + print(pipeline_name) + from dlt.common.json import PUA_START, PUA_CHARACTER_MAX + + def some_data(): + yield from [ + # text is only PUA + {"id": 1, "text": chr(PUA_START)}, + {"id": 2, "text": chr(PUA_START - 1)}, + {"id": 3, "text": chr(PUA_START + 1)}, + {"id": 4, "text": chr(PUA_START + PUA_CHARACTER_MAX + 1)}, + # PUA inside text + {"id": 5, "text": f"a{chr(PUA_START)}b"}, + {"id": 6, "text": f"a{chr(PUA_START - 1)}b"}, + {"id": 7, "text": f"a{chr(PUA_START + 1)}b"}, + # text starts with PUA + {"id": 8, "text": f"{chr(PUA_START)}a"}, + {"id": 9, "text": f"{chr(PUA_START - 1)}a"}, + {"id": 10, "text": f"{chr(PUA_START + 1)}a"}, + ] + + @dlt.source + def source(): + return dlt.resource(some_data(), name="pua_data") + + load_info = p.run(source()) + assert p.last_trace.last_normalize_info.row_counts["pua_data"] == 10 + + with p.sql_client() as client: + rows = client.execute_sql("SELECT text FROM pua_data ORDER BY id") + + values = [r[0] for r in rows] + assert values == [ + "\uf026", + "\uf025", + "\uf027", + "\uf02f", + "a\uf026b", + "a\uf025b", + "a\uf027b", + "\uf026a", + "\uf025a", + "\uf027a", + ] + assert len(load_info.loads_ids) == 1 + + +def test_pipeline_load_info_metrics_schema_is_not_chaning() -> None: + """Test if load info schema is idempotent throughout multiple load cycles + + ## Setup + + We will run the same pipeline with + + 1. A single source returning one resource and collect `schema.version_hash`, + 2. Another source returning 2 resources with more complex data and collect `schema.version_hash`, + 3. At last we run both sources, + 4. For each 1. 2. 3. 
we load `last_extract_info`, `last_normalize_info` and `last_load_info` and collect `schema.version_hash` + + ## Expected + + `version_hash` collected in each stage should remain the same at all times. + """ + data = [ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"}, + ] + + # this source must have all the hints so other sources do not change trace schema (extract/hints) + + @dlt.source + def users_source(): + return dlt.resource([data], name="users_resource") + + @dlt.source + def taxi_demand_source(): + @dlt.resource( + primary_key="city", columns=[{"name": "id", "data_type": "bigint", "precision": 4}] + ) + def locations(idx=dlt.sources.incremental("id")): + for idx in range(10): + yield { + "id": idx, + "address": f"address-{idx}", + "city": f"city-{idx}", + } + + @dlt.resource(primary_key="id") + def demand_map(): + for idx in range(10): + yield { + "id": idx, + "city": f"city-{idx}", + "demand": random.randint(0, 10000), + } + + return [locations, demand_map] + + schema = dlt.Schema(name="nice_load_info_schema") + pipeline = dlt.pipeline( + pipeline_name="quick_start", + destination="duckdb", + dataset_name="mydata", + # export_schema_path="schemas", + ) + + taxi_load_info = pipeline.run( + taxi_demand_source(), + ) + + schema_hashset = set() + pipeline.run( + [taxi_load_info], + table_name="_load_info", + schema=schema, + ) + + pipeline.run( + [pipeline.last_trace.last_normalize_info], + table_name="_normalize_info", + schema=schema, + ) + + pipeline.run( + [pipeline.last_trace.last_extract_info], + table_name="_extract_info", + schema=schema, + ) + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + trace_schema = pipeline.schemas["nice_load_info_schema"].to_pretty_yaml() + + users_load_info = pipeline.run( + users_source(), + ) + + pipeline.run( + [users_load_info], + table_name="_load_info", + schema=schema, + ) + assert trace_schema == pipeline.schemas["nice_load_info_schema"].to_pretty_yaml() + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + assert len(schema_hashset) == 1 + + pipeline.run( + [pipeline.last_trace.last_normalize_info], + table_name="_normalize_info", + schema=schema, + ) + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + assert len(schema_hashset) == 1 + + pipeline.run( + [pipeline.last_trace.last_extract_info], + table_name="_extract_info", + schema=schema, + ) + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + assert len(schema_hashset) == 1 + + load_info = pipeline.run( + [users_source(), taxi_demand_source()], + ) + + pipeline.run( + [load_info], + table_name="_load_info", + schema=schema, + ) + + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + + pipeline.run( + [pipeline.last_trace.last_normalize_info], + table_name="_normalize_info", + schema=schema, + ) + + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + + pipeline.run( + [pipeline.last_trace.last_extract_info], + table_name="_extract_info", + schema=schema, + ) + + schema_hashset.add(pipeline.schemas["nice_load_info_schema"].version_hash) + + assert len(schema_hashset) == 1 diff --git a/tests/pipeline/test_resources_evaluation.py b/tests/pipeline/test_resources_evaluation.py index 7f0a7890a7..5a85c06462 100644 --- a/tests/pipeline/test_resources_evaluation.py +++ b/tests/pipeline/test_resources_evaluation.py @@ -1,6 +1,11 @@ -from typing import Any +from typing import Any, List +import time +import threading +import random +from 
itertools import product import dlt, asyncio, pytest, os +from dlt.extract.exceptions import ResourceExtractionError def test_async_iterator_resource() -> None: @@ -212,99 +217,276 @@ async def async_resource1(): assert len(result) == 13 -# @pytest.mark.skip(reason="To be properly implemented in an upcoming PR") -# @pytest.mark.parametrize("parallelized", [True, False]) -# def test_async_decorator_experiment(parallelized) -> None: -# os.environ["EXTRACT__NEXT_ITEM_MODE"] = "fifo" -# execution_order = [] -# threads = set() - -# def parallelize(f) -> Any: -# """converts regular itarable to generator of functions that can be run in parallel in the pipe""" - -# @wraps(f) -# def _wrap(*args: Any, **kwargs: Any) -> Any: -# exhausted = False -# busy = False - -# gen = f(*args, **kwargs) -# # unpack generator -# if inspect.isfunction(gen): -# gen = gen() -# # if we have an async gen, no further action is needed -# if inspect.isasyncgen(gen): -# raise Exception("Already async gen") - -# # get next item from generator -# def _gen(): -# nonlocal exhausted -# # await asyncio.sleep(0.1) -# try: -# return next(gen) -# # on stop iteration mark as exhausted -# except StopIteration: -# exhausted = True -# return None -# finally: -# nonlocal busy -# busy = False - -# try: -# while not exhausted: -# while busy: -# yield None -# busy = True -# yield _gen -# except GeneratorExit: -# # clean up inner generator -# gen.close() - -# return _wrap - -# @parallelize -# def resource1(): -# for l_ in ["a", "b", "c"]: -# time.sleep(0.5) -# nonlocal execution_order -# execution_order.append("one") -# threads.add(threading.get_ident()) -# yield {"letter": l_} - -# @parallelize -# def resource2(): -# time.sleep(0.25) -# for l_ in ["e", "f", "g"]: -# time.sleep(0.5) -# nonlocal execution_order -# execution_order.append("two") -# threads.add(threading.get_ident()) -# yield {"letter": l_} - -# @dlt.source -# def source(): -# if parallelized: -# return [resource1(), resource2()] -# else: # return unwrapped resources -# return [resource1.__wrapped__(), resource2.__wrapped__()] - -# pipeline_1 = dlt.pipeline("pipeline_1", destination="duckdb", full_refresh=True) -# pipeline_1.run(source()) - -# # all records should be here -# with pipeline_1.sql_client() as c: -# with c.execute_query("SELECT * FROM resource1") as cur: -# rows = list(cur.fetchall()) -# assert len(rows) == 3 -# assert {r[0] for r in rows} == {"a", "b", "c"} - -# with c.execute_query("SELECT * FROM resource2") as cur: -# rows = list(cur.fetchall()) -# assert len(rows) == 3 -# assert {r[0] for r in rows} == {"e", "f", "g"} - -# if parallelized: -# assert len(threads) > 1 -# assert execution_order == ["one", "two", "one", "two", "one", "two"] -# else: -# assert execution_order == ["one", "one", "one", "two", "two", "two"] -# assert len(threads) == 1 +@pytest.mark.parametrize("parallelized", [True, False]) +def test_parallelized_resource(parallelized: bool) -> None: + os.environ["EXTRACT__NEXT_ITEM_MODE"] = "fifo" + execution_order = [] + threads = set() + + @dlt.resource(parallelized=parallelized) + def resource1(): + for l_ in ["a", "b", "c"]: + time.sleep(0.01) + execution_order.append("one") + threads.add(threading.get_ident()) + yield {"letter": l_} + + @dlt.resource(parallelized=parallelized) + def resource2(): + for l_ in ["e", "f", "g"]: + time.sleep(0.01) + execution_order.append("two") + threads.add(threading.get_ident()) + yield {"letter": l_} + + @dlt.source + def source(): + return [resource1(), resource2()] + + pipeline_1 = dlt.pipeline("pipeline_1", 
destination="duckdb", full_refresh=True) + pipeline_1.run(source()) + + # all records should be here + with pipeline_1.sql_client() as c: + with c.execute_query("SELECT * FROM resource1") as cur: + rows = list(cur.fetchall()) + assert len(rows) == 3 + assert {r[0] for r in rows} == {"a", "b", "c"} + + with c.execute_query("SELECT * FROM resource2") as cur: + rows = list(cur.fetchall()) + assert len(rows) == 3 + assert {r[0] for r in rows} == {"e", "f", "g"} + + if parallelized: + assert ( + len(threads) > 1 and threading.get_ident() not in threads + ) # Nothing runs in main thread + else: + assert execution_order == ["one", "one", "one", "two", "two", "two"] + assert threads == {threading.get_ident()} # Everything runs in main thread + + +# Parametrize with different resource counts to excersize the worker pool: +# 1. More than number of workers +# 2. 1 resource only +# 3. Exact number of workers +# 4. More than future pool max size +# 5. Exact future pool max size +@pytest.mark.parametrize( + "n_resources,next_item_mode", product([8, 1, 5, 25, 20], ["fifo", "round_robin"]) +) +def test_parallelized_resource_extract_order(n_resources: int, next_item_mode: str) -> None: + os.environ["EXTRACT__NEXT_ITEM_MODE"] = next_item_mode + + threads = set() + + item_counts = [random.randrange(10, 30) for _ in range(n_resources)] + item_ranges = [] # Create numeric ranges that each resource will yield + # Use below to check the extraction order + for i, n_items in enumerate(item_counts): + if i == 0: + start_range = 0 + else: + start_range = sum(item_counts[:i]) + end_range = start_range + n_items + item_ranges.append(range(start_range, end_range)) + + @dlt.source + def some_source(): + def some_data(resource_num: int): + for item in item_ranges[resource_num]: + threads.add(threading.get_ident()) + print(f"RESOURCE {resource_num}") + # Sleep for a random duration each yield + time.sleep(random.uniform(0.005, 0.012)) + yield f"item-{item}" + print(f"RESOURCE {resource_num}:", item) + + for i in range(n_resources): + yield dlt.resource(some_data, name=f"some_data_{i}", parallelized=True)(i) + + source = some_source() + result = list(source) + result = [int(item.split("-")[1]) for item in result] + + assert len(result) == sum(item_counts) + + # Check extracted results from each resource + chunked_results = [] + start_range = 0 + for item_range in item_ranges: + chunked_results.append([item for item in result if item in item_range]) + + for i, chunk in enumerate(chunked_results): + # All items are included + assert len(chunk) == item_counts[i] + assert len(set(chunk)) == len(chunk) + # Items are extracted in order per resource + assert chunk == sorted(chunk) + + assert len(threads) >= min(2, n_resources) and threading.get_ident() not in threads + + +def test_test_parallelized_resource_transformers() -> None: + item_count = 6 + threads = set() + transformer_threads = set() + + @dlt.resource(parallelized=True) + def pos_data(): + for i in range(1, item_count + 1): + threads.add(threading.get_ident()) + time.sleep(0.1) + yield i + + @dlt.resource(parallelized=True) + def neg_data(): + for i in range(-1, -item_count - 1, -1): + threads.add(threading.get_ident()) + time.sleep(0.1) + yield i + + @dlt.transformer(parallelized=True) + def multiply(item): + transformer_threads.add(threading.get_ident()) + time.sleep(0.05) + yield item * 10 + + @dlt.source + def some_source(): + return [ + neg_data | multiply.with_name("t_a"), + pos_data | multiply.with_name("t_b"), + ] + + result = list(some_source()) + + 
expected_result = [i * 10 for i in range(-item_count, item_count + 1)] + expected_result.remove(0) + + assert sorted(result) == expected_result + # Nothing runs in main thread + assert threads and threading.get_ident() not in threads + assert transformer_threads and threading.get_ident() not in transformer_threads + + threads = set() + transformer_threads = set() + + @dlt.transformer(parallelized=True) # type: ignore[no-redef] + def multiply(item): + # Transformer that is not a generator + transformer_threads.add(threading.get_ident()) + time.sleep(0.05) + return item * 10 + + @dlt.source # type: ignore[no-redef] + def some_source(): + return [ + neg_data | multiply.with_name("t_a"), + pos_data | multiply.with_name("t_b"), + ] + + result = list(some_source()) + + expected_result = [i * 10 for i in range(-item_count, item_count + 1)] + expected_result.remove(0) + + assert sorted(result) == expected_result + + # Nothing runs in main thread + assert len(threads) > 1 and threading.get_ident() not in threads + assert len(transformer_threads) > 1 and threading.get_ident() not in transformer_threads + + +def test_parallelized_resource_bare_generator() -> None: + main_thread = threading.get_ident() + threads = set() + + def pos_data(): + for i in range(1, 6): + threads.add(threading.get_ident()) + time.sleep(0.01) + yield i + + def neg_data(): + for i in range(-1, -6, -1): + threads.add(threading.get_ident()) + time.sleep(0.01) + yield i + + @dlt.source + def some_source(): + return [ + # Resources created from generators directly (not generator functions) can be parallelized + dlt.resource(pos_data(), parallelized=True, name="pos_data"), + dlt.resource(neg_data(), parallelized=True, name="neg_data"), + ] + + result = list(some_source()) + + assert len(threads) > 1 and main_thread not in threads + assert set(result) == {1, 2, 3, 4, 5, -1, -2, -3, -4, -5} + assert len(result) == 10 + + +def test_parallelized_resource_wrapped_generator() -> None: + threads = set() + + def some_data(): + for i in range(1, 6): + time.sleep(0.01) + threads.add(threading.get_ident()) + yield i + + def some_data2(): + for i in range(-1, -6, -1): + time.sleep(0.01) + threads.add(threading.get_ident()) + yield i + + @dlt.source + def some_source(): + # Bound resources result in a wrapped generator function, + return [ + dlt.resource(some_data, parallelized=True, name="some_data")(), + dlt.resource(some_data2, parallelized=True, name="some_data2")(), + ] + + source = some_source() + + result = list(source) + + assert len(threads) > 1 and threading.get_ident() not in threads + assert set(result) == {1, 2, 3, 4, 5, -1, -2, -3, -4, -5} + + +def test_parallelized_resource_exception_pool_is_closed() -> None: + """Checking that futures pool is closed before generators are closed when a parallel resource raises. + For now just checking that we don't get any "generator is already closed" errors, as would happen + when futures aren't cancelled before closing generators. 
+ """ + + def some_data(): + for i in range(1, 6): + time.sleep(0.1) + yield i + + def some_data2(): + for i in range(1, 6): + time.sleep(0.005) + yield i + if i == 3: + raise RuntimeError("we have failed") + + @dlt.source + def some_source(): + yield dlt.resource(some_data, parallelized=True, name="some_data") + yield dlt.resource(some_data2, parallelized=True, name="some_data2") + + source = some_source() + + with pytest.raises(ResourceExtractionError) as einfo: + list(source) + + assert "we have failed" in str(einfo.value) diff --git a/tests/sources/helpers/test_requests.py b/tests/sources/helpers/test_requests.py index 695fa93eca..aefdf23e77 100644 --- a/tests/sources/helpers/test_requests.py +++ b/tests/sources/helpers/test_requests.py @@ -1,7 +1,5 @@ -from contextlib import contextmanager -from typing import Iterator, Any, cast, Type +from typing import Iterator, Type from unittest import mock -from email.utils import format_datetime import os import random @@ -105,6 +103,25 @@ def test_retry_on_status_without_raise_for_status(mock_sleep: mock.MagicMock) -> assert m.call_count == RunConfiguration.request_max_attempts +def test_hooks_with_raise_for_statue() -> None: + url = "https://example.com/data" + session = Client(raise_for_status=True).session + + def _no_content(resp: requests.Response, *args, **kwargs) -> requests.Response: + resp.status_code = 204 + resp._content = b"[]" + return resp + + with requests_mock.mock(session=session) as m: + m.get(url, status_code=503) + response = session.get(url, hooks={"response": _no_content}) + # we simulate empty response + assert response.status_code == 204 + assert response.json() == [] + + assert m.call_count == 1 + + @pytest.mark.parametrize( "exception_class", [requests.ConnectionError, requests.ConnectTimeout, requests.exceptions.ChunkedEncodingError], diff --git a/tests/utils.py b/tests/utils.py index 1d2ace2533..dd03279def 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -205,6 +205,25 @@ def data_to_item_format( raise ValueError(f"Unknown item format: {item_format}") +def data_item_length(data: TDataItem) -> int: + import pandas as pd + from dlt.common.libs.pyarrow import pyarrow as pa + + if isinstance(data, list): + # If data is a list, check if it's a list of supported data types + if all(isinstance(item, (list, pd.DataFrame, pa.Table, pa.RecordBatch)) for item in data): + return sum(data_item_length(item) for item in data) + # If it's a list but not a list of supported types, treat it as a single list object + else: + return len(data) + elif isinstance(data, pd.DataFrame): + return len(data.index) + elif isinstance(data, pa.Table) or isinstance(data, pa.RecordBatch): + return data.num_rows + else: + raise TypeError("Unsupported data type.") + + def init_test_logging(c: RunConfiguration = None) -> None: if not c: c = resolve_configuration(RunConfiguration())