diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index ada73b85d9..95fbd83ad9 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -77,11 +77,14 @@ jobs: - name: Install dependencies # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' - run: poetry install --no-interaction -E redshift -E gs -E s3 -E az -E parquet -E duckdb -E cli --with sentry-sdk --with pipeline -E deltalake + run: poetry install --no-interaction -E redshift -E gs -E s3 -E az -E parquet -E duckdb -E cli -E filesystem --with sentry-sdk --with pipeline -E deltalake - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + - name: clear duckdb secrets and cache + run: rm -rf ~/.duckdb + - run: | poetry run pytest tests/load --ignore tests/load/sources -m "essential" name: Run essential tests Linux diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index 2bff0df899..faa2c59a0b 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -60,7 +60,7 @@ jobs: uses: actions/checkout@master - name: Start weaviate - run: docker compose -f ".github/weaviate-compose.yml" up -d + run: docker compose -f "tests/load/weaviate/docker-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 8911e05ecc..51a078b1ab 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -73,7 +73,7 @@ jobs: uses: actions/checkout@master - name: Start weaviate - run: docker compose -f ".github/weaviate-compose.yml" up -d + run: docker compose -f "tests/load/weaviate/docker-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 @@ -122,7 +122,7 @@ jobs: - name: Stop weaviate if: always() - run: docker compose -f ".github/weaviate-compose.yml" down -v + run: docker compose -f "tests/load/weaviate/docker-compose.yml" down -v - name: Stop SFTP server if: always() diff --git a/.github/workflows/test_pyarrow17.yml b/.github/workflows/test_pyarrow17.yml index 78d6742ac1..dc776e4ce1 100644 --- a/.github/workflows/test_pyarrow17.yml +++ b/.github/workflows/test_pyarrow17.yml @@ -65,14 +65,18 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-pyarrow17 - name: Install dependencies - run: poetry install --no-interaction --with sentry-sdk --with pipeline -E deltalake -E gs -E s3 -E az + run: poetry install --no-interaction --with sentry-sdk --with pipeline -E deltalake -E duckdb -E filesystem -E gs -E s3 -E az + - name: Upgrade pyarrow run: poetry run pip install pyarrow==17.0.0 - + - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + - name: clear duckdb secrets and cache + run: rm -rf ~/.duckdb + - name: Run needspyarrow17 tests Linux run: | poetry run pytest tests/libs -m "needspyarrow17" diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 5da2dac04b..a38d644158 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -94,6 +94,3 @@ jobs: # always run full suite, also on branches - run: poetry run pytest tests/load -x --ignore tests/load/sources name: Run tests Linux - env: - 
DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases - DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite diff --git a/Makefile b/Makefile index a01d6ae8b9..d22ed6732d 100644 --- a/Makefile +++ b/Makefile @@ -110,4 +110,11 @@ test-build-images: build-library preprocess-docs: # run docs preprocessing to run a few checks and ensure examples can be parsed - cd docs/website && npm i && npm run preprocess-docs \ No newline at end of file + cd docs/website && npm i && npm run preprocess-docs + +start-test-containers: + docker compose -f "tests/load/dremio/docker-compose.yml" up -d + docker compose -f "tests/load/postgres/docker-compose.yml" up -d + docker compose -f "tests/load/weaviate/docker-compose.yml" up -d + docker compose -f "tests/load/filesystem_sftp/docker-compose.yml" up -d + docker compose -f "tests/load/sqlalchemy/docker-compose.yml" up -d diff --git a/dlt/__init__.py b/dlt/__init__.py index eee105e47e..328817efd2 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -22,7 +22,7 @@ from dlt.version import __version__ from dlt.common.configuration.accessors import config, secrets -from dlt.common.typing import TSecretValue as _TSecretValue +from dlt.common.typing import TSecretValue as _TSecretValue, TSecretStrValue as _TSecretStrValue from dlt.common.configuration.specs import CredentialsConfiguration as _CredentialsConfiguration from dlt.common.pipeline import source_state as state from dlt.common.schema import Schema @@ -50,10 +50,12 @@ TSecretValue = _TSecretValue "When typing source/resource function arguments it indicates that a given argument is a secret and should be taken from dlt.secrets." +TSecretStrValue = _TSecretStrValue +"When typing source/resource function arguments it indicates that a given argument is a secret STRING and should be taken from dlt.secrets." + TCredentials = _CredentialsConfiguration "When typing source/resource function arguments it indicates that a given argument represents credentials and should be taken from dlt.secrets. Credentials may be a string, dictionary or any other type." - __all__ = [ "__version__", "config", @@ -78,3 +80,12 @@ "sources", "destinations", ] + +# verify that no injection context was created +from dlt.common.configuration.container import Container as _Container + +assert ( + _Container._INSTANCE is None +), "Injection container should not be initialized during initial import" +# create injection container +_Container() diff --git a/dlt/cli/config_toml_writer.py b/dlt/cli/config_toml_writer.py index 1b39653a55..59b16b16e1 100644 --- a/dlt/cli/config_toml_writer.py +++ b/dlt/cli/config_toml_writer.py @@ -104,6 +104,7 @@ def write_spec(toml_table: TOMLTable, config: BaseConfiguration, overwrite_exist def write_values( toml: TOMLContainer, values: Iterable[WritableConfigValue], overwrite_existing: bool ) -> None: + # TODO: decouple writers from a particular object model ie. 
TOML for value in values: toml_table: TOMLTable = toml # type: ignore for section in value.sections: diff --git a/dlt/cli/deploy_command.py b/dlt/cli/deploy_command.py index b48dffa881..88c132f5e2 100644 --- a/dlt/cli/deploy_command.py +++ b/dlt/cli/deploy_command.py @@ -5,7 +5,6 @@ from importlib.metadata import version as pkg_version from dlt.common.configuration.providers import SECRETS_TOML, SECRETS_TOML_KEY -from dlt.common.configuration.paths import make_dlt_settings_path from dlt.common.configuration.utils import serialize_value from dlt.common.git import is_dirty @@ -210,7 +209,7 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: fmt.echo( "1. Add the following secret values (typically stored in %s): \n%s\nin %s" % ( - fmt.bold(make_dlt_settings_path(SECRETS_TOML)), + fmt.bold(utils.make_dlt_settings_path(SECRETS_TOML)), fmt.bold( "\n".join( self.env_prov.get_key_name(s_v.key, *s_v.sections) @@ -368,7 +367,7 @@ def _echo_instructions(self, *args: Optional[Any]) -> None: "3. Add the following secret values (typically stored in %s): \n%s\n%s\nin" " ENVIRONMENT VARIABLES using Google Composer UI" % ( - fmt.bold(make_dlt_settings_path(SECRETS_TOML)), + fmt.bold(utils.make_dlt_settings_path(SECRETS_TOML)), fmt.bold( "\n".join( self.env_prov.get_key_name(s_v.key, *s_v.sections) diff --git a/dlt/cli/deploy_command_helpers.py b/dlt/cli/deploy_command_helpers.py index 2afbfbf46e..38e95ce5d0 100644 --- a/dlt/cli/deploy_command_helpers.py +++ b/dlt/cli/deploy_command_helpers.py @@ -15,8 +15,11 @@ from dlt.common import git from dlt.common.configuration.exceptions import LookupTrace, ConfigFieldMissingException -from dlt.common.configuration.providers import ConfigTomlProvider, EnvironProvider -from dlt.common.configuration.providers.toml import BaseDocProvider, StringTomlProvider +from dlt.common.configuration.providers import ( + ConfigTomlProvider, + EnvironProvider, + StringTomlProvider, +) from dlt.common.git import get_origin, get_repo, Repo from dlt.common.configuration.specs.run_configuration import get_default_pipeline_name from dlt.common.typing import StrAny @@ -242,7 +245,7 @@ def _display_missing_secret_info(self) -> None: ) def _lookup_secret_value(self, trace: LookupTrace) -> Any: - return dlt.secrets[BaseDocProvider.get_key_name(trace.key, *trace.sections)] + return dlt.secrets[StringTomlProvider.get_key_name(trace.key, *trace.sections)] def _echo_envs(self) -> None: for v in self.envs: diff --git a/dlt/cli/init_command.py b/dlt/cli/init_command.py index 797917a165..0d3b5fe99e 100644 --- a/dlt/cli/init_command.py +++ b/dlt/cli/init_command.py @@ -2,14 +2,10 @@ import ast import shutil import tomlkit -from types import ModuleType -from typing import Dict, List, Sequence, Tuple -from importlib.metadata import version as pkg_version +from typing import Dict, Sequence, Tuple from pathlib import Path -from importlib import import_module from dlt.common import git -from dlt.common.configuration.paths import get_dlt_settings_dir, make_dlt_settings_path from dlt.common.configuration.specs import known_sections from dlt.common.configuration.providers import ( CONFIG_TOML, @@ -18,28 +14,29 @@ SecretsTomlProvider, ) from dlt.common.pipeline import get_dlt_repos_dir -from dlt.common.source import _SOURCES from dlt.version import DLT_PKG_NAME, __version__ from dlt.common.destination import Destination from dlt.common.reflection.utils import rewrite_python_script +from dlt.common.runtime import run_context from dlt.common.schema.utils import is_valid_schema_name from 
dlt.common.schema.exceptions import InvalidSchemaName from dlt.common.storages.file_storage import FileStorage -from dlt.sources import pipeline_templates as init_module + +from dlt.sources import SourceReference import dlt.reflection.names as n -from dlt.reflection.script_inspector import inspect_pipeline_script, load_script_module +from dlt.reflection.script_inspector import inspect_pipeline_script from dlt.cli import echo as fmt, pipeline_files as files_ops, source_detection from dlt.cli import utils from dlt.cli.config_toml_writer import WritableConfigValue, write_values from dlt.cli.pipeline_files import ( + TEMPLATE_FILES, SourceConfiguration, TVerifiedSourceFileEntry, TVerifiedSourceFileIndex, ) from dlt.cli.exceptions import CliCommandException -from dlt.cli.requirements import SourceRequirements DLT_INIT_DOCS_URL = "https://dlthub.com/docs/reference/command-line-interface#dlt-init" @@ -213,7 +210,7 @@ def _welcome_message( if is_new_source: fmt.echo( "* Add credentials for %s and other secrets in %s" - % (fmt.bold(destination_type), fmt.bold(make_dlt_settings_path(SECRETS_TOML))) + % (fmt.bold(destination_type), fmt.bold(utils.make_dlt_settings_path(SECRETS_TOML))) ) if destination_type == "destination": @@ -308,6 +305,9 @@ def init_command( core_sources_storage = _get_core_sources_storage() templates_storage = _get_templates_storage() + # get current run context + run_ctx = run_context.current() + # discover type of source source_type: files_ops.TSourceType = "template" if ( @@ -324,9 +324,9 @@ def init_command( source_type = "verified" # prepare destination storage - dest_storage = FileStorage(os.path.abspath(".")) - if not dest_storage.has_folder(get_dlt_settings_dir()): - dest_storage.create_folder(get_dlt_settings_dir()) + dest_storage = FileStorage(run_ctx.run_dir) + if not dest_storage.has_folder(run_ctx.settings_dir): + dest_storage.create_folder(run_ctx.settings_dir) # get local index of verified source files local_index = files_ops.load_verified_sources_local_index(source_name) # folder deleted at dest - full refresh @@ -376,8 +376,6 @@ def init_command( f"The verified sources repository is dirty. {source_name} source files may not" " update correctly in the future." 
) - # add template files - source_configuration.files.extend(files_ops.TEMPLATE_FILES) else: if source_type == "core": @@ -399,9 +397,9 @@ def init_command( return # add .dlt/*.toml files to be copied - source_configuration.files.extend( - [make_dlt_settings_path(CONFIG_TOML), make_dlt_settings_path(SECRETS_TOML)] - ) + # source_configuration.files.extend( + # [run_ctx.get_setting(CONFIG_TOML), run_ctx.get_setting(SECRETS_TOML)] + # ) # add dlt extras line to requirements source_configuration.requirements.update_dlt_extras(destination_type) @@ -449,8 +447,6 @@ def init_command( visitor, [ ("destination", destination_type), - ("pipeline_name", source_name), - ("dataset_name", source_name + "_data"), ], source_configuration.src_pipeline_script, ) @@ -465,54 +461,48 @@ def init_command( # detect all the required secrets and configs that should go into tomls files if source_configuration.source_type == "template": # replace destination, pipeline_name and dataset_name in templates - transformed_nodes = source_detection.find_call_arguments_to_replace( - visitor, - [ - ("destination", destination_type), - ("pipeline_name", source_name), - ("dataset_name", source_name + "_data"), - ], - source_configuration.src_pipeline_script, - ) + # transformed_nodes = source_detection.find_call_arguments_to_replace( + # visitor, + # [ + # ("destination", destination_type), + # ("pipeline_name", source_name), + # ("dataset_name", source_name + "_data"), + # ], + # source_configuration.src_pipeline_script, + # ) # template sources are always in module starting with "pipeline" # for templates, place config and secrets into top level section required_secrets, required_config, checked_sources = source_detection.detect_source_configs( - _SOURCES, source_configuration.source_module_prefix, () + SourceReference.SOURCES, source_configuration.source_module_prefix, () ) # template has a strict rules where sources are placed - for source_q_name, source_config in checked_sources.items(): - if source_q_name not in visitor.known_sources_resources: - raise CliCommandException( - "init", - f"The pipeline script {source_configuration.src_pipeline_script} imports a" - f" source/resource {source_config.f.__name__} from module" - f" {source_config.module.__name__}. In init scripts you must declare all" - " sources and resources in single file.", - ) + # for source_q_name, source_config in checked_sources.items(): + # if source_q_name not in visitor.known_sources_resources: + # raise CliCommandException( + # "init", + # f"The pipeline script {source_configuration.src_pipeline_script} imports a" + # f" source/resource {source_config.name} from section" + # f" {source_config.section}. In init scripts you must declare all" + # f" sources and resources in single file. 
Known names are {list(visitor.known_sources_resources.keys())}.", + # ) # rename sources and resources - transformed_nodes.extend( - source_detection.find_source_calls_to_replace(visitor, source_name) - ) + # transformed_nodes.extend( + # source_detection.find_source_calls_to_replace(visitor, source_name) + # ) else: - # replace only destination for existing pipelines - transformed_nodes = source_detection.find_call_arguments_to_replace( - visitor, [("destination", destination_type)], source_configuration.src_pipeline_script - ) # pipeline sources are in module with name starting from {pipeline_name} # for verified pipelines place in the specific source section required_secrets, required_config, checked_sources = source_detection.detect_source_configs( - _SOURCES, + SourceReference.SOURCES, source_configuration.source_module_prefix, (known_sections.SOURCES, source_name), ) - - # the intro template does not use sources, for now allow it to pass here - if len(checked_sources) == 0 and source_name != "intro": - raise CliCommandException( - "init", - f"The pipeline script {source_configuration.src_pipeline_script} is not creating or" - " importing any sources or resources. Exiting...", - ) + if len(checked_sources) == 0: + raise CliCommandException( + "init", + f"The pipeline script {source_configuration.src_pipeline_script} is not creating or" + " importing any sources or resources. Exiting...", + ) # add destination spec to required secrets required_secrets["destinations:" + destination_type] = WritableConfigValue( @@ -570,23 +560,32 @@ def init_command( ) # copy files at the very end - for file_name in source_configuration.files: + copy_files = [] + # copy template files + for file_name in TEMPLATE_FILES: dest_path = dest_storage.make_full_path(file_name) - # get files from init section first if templates_storage.has_file(file_name): if dest_storage.has_file(dest_path): # do not overwrite any init files continue - src_path = templates_storage.make_full_path(file_name) - else: - # only those that were modified should be copied from verified sources - if file_name in remote_modified: - src_path = source_configuration.storage.make_full_path(file_name) - else: - continue + copy_files.append((templates_storage.make_full_path(file_name), dest_path)) + + # only those that were modified should be copied from verified sources + for file_name in remote_modified: + copy_files.append( + ( + source_configuration.storage.make_full_path(file_name), + # copy into where "sources" reside in run context, being root dir by default + dest_storage.make_full_path( + os.path.join(run_ctx.get_run_entity("sources"), file_name) + ), + ) + ) + + # modify storage at the end + for src_path, dest_path in copy_files: os.makedirs(os.path.dirname(dest_path), exist_ok=True) shutil.copy2(src_path, dest_path) - if remote_index: # delete files for file_name in remote_deleted: @@ -600,15 +599,11 @@ def init_command( dest_storage.save(source_configuration.dest_pipeline_script, dest_script_source) # generate tomls with comments - secrets_prov = SecretsTomlProvider() - secrets_toml = tomlkit.document() - write_values(secrets_toml, required_secrets.values(), overwrite_existing=False) - secrets_prov._config_doc = secrets_toml - - config_prov = ConfigTomlProvider() - config_toml = tomlkit.document() - write_values(config_toml, required_config.values(), overwrite_existing=False) - config_prov._config_doc = config_toml + secrets_prov = SecretsTomlProvider(settings_dir=run_ctx.settings_dir) + write_values(secrets_prov._config_toml, 
required_secrets.values(), overwrite_existing=False) + + config_prov = ConfigTomlProvider(settings_dir=run_ctx.settings_dir) + write_values(config_prov._config_toml, required_config.values(), overwrite_existing=False) # write toml files secrets_prov.write_toml() diff --git a/dlt/cli/pipeline_files.py b/dlt/cli/pipeline_files.py index 6ca39e0195..c15f988e54 100644 --- a/dlt/cli/pipeline_files.py +++ b/dlt/cli/pipeline_files.py @@ -8,7 +8,6 @@ from dlt.cli.exceptions import VerifiedSourceRepoError from dlt.common import git -from dlt.common.configuration.paths import make_dlt_settings_path from dlt.common.storages import FileStorage from dlt.common.reflection.utils import get_module_docstring @@ -31,7 +30,7 @@ PIPELINE_FILE_SUFFIX = "_pipeline.py" # hardcode default template files here -TEMPLATE_FILES = [".gitignore", ".dlt/config.toml", ".dlt/secrets.toml"] +TEMPLATE_FILES = [".gitignore", ".dlt/config.toml"] DEFAULT_PIPELINE_TEMPLATE = "default_pipeline.py" @@ -67,13 +66,13 @@ class TVerifiedSourcesFileIndex(TypedDict): def _save_dot_sources(index: TVerifiedSourcesFileIndex) -> None: - with open(make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "w", encoding="utf-8") as f: + with open(utils.make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "w", encoding="utf-8") as f: yaml.dump(index, f, allow_unicode=True, default_flow_style=False, sort_keys=False) def _load_dot_sources() -> TVerifiedSourcesFileIndex: try: - with open(make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "r", encoding="utf-8") as f: + with open(utils.make_dlt_settings_path(SOURCES_INIT_INFO_FILE), "r", encoding="utf-8") as f: index: TVerifiedSourcesFileIndex = yaml.safe_load(f) if not index: raise FileNotFoundError(SOURCES_INIT_INFO_FILE) @@ -215,7 +214,7 @@ def get_template_configuration( sources_storage, source_pipeline_file_name, destination_pipeline_file_name, - TEMPLATE_FILES, + [], SourceRequirements([]), docstring, source_pipeline_file_name == DEFAULT_PIPELINE_TEMPLATE, @@ -233,7 +232,7 @@ def get_core_source_configuration( sources_storage, pipeline_file, pipeline_file, - [".gitignore"], + [], SourceRequirements([]), _get_docstring_for_module(sources_storage, source_name), False, diff --git a/dlt/cli/source_detection.py b/dlt/cli/source_detection.py index 636615af61..787f28881d 100644 --- a/dlt/cli/source_detection.py +++ b/dlt/cli/source_detection.py @@ -7,8 +7,8 @@ from dlt.common.configuration.specs import BaseConfiguration from dlt.common.reflection.utils import creates_func_def_name_node from dlt.common.typing import is_optional_type -from dlt.common.source import SourceInfo +from dlt.sources import SourceReference from dlt.cli.config_toml_writer import WritableConfigValue from dlt.cli.exceptions import CliCommandException from dlt.reflection.script_visitor import PipelineScriptVisitor @@ -72,19 +72,23 @@ def find_source_calls_to_replace( def detect_source_configs( - sources: Dict[str, SourceInfo], module_prefix: str, section: Tuple[str, ...] -) -> Tuple[Dict[str, WritableConfigValue], Dict[str, WritableConfigValue], Dict[str, SourceInfo]]: + sources: Dict[str, SourceReference], module_prefix: str, section: Tuple[str, ...] +) -> Tuple[ + Dict[str, WritableConfigValue], Dict[str, WritableConfigValue], Dict[str, SourceReference] +]: + """Creates sample secret and configs for `sources` belonging to `module_prefix`. 
Assumes that + all sources belong to a single section so only source name is used to create sample layouts""" # all detected secrets with sections required_secrets: Dict[str, WritableConfigValue] = {} # all detected configs with sections required_config: Dict[str, WritableConfigValue] = {} - # all sources checked - checked_sources: Dict[str, SourceInfo] = {} + # all sources checked, indexed by source name + checked_sources: Dict[str, SourceReference] = {} - for source_name, source_info in sources.items(): + for _, source_info in sources.items(): # accept only sources declared in the `init` or `pipeline` modules if source_info.module.__name__.startswith(module_prefix): - checked_sources[source_name] = source_info + checked_sources[source_info.name] = source_info source_config = source_info.SPEC() if source_info.SPEC else BaseConfiguration() spec_fields = source_config.get_resolvable_fields() for field_name, field_type in spec_fields.items(): @@ -99,8 +103,8 @@ def detect_source_configs( val_store = required_config if val_store is not None: - # we are sure that all resources come from single file so we can put them in single section - val_store[source_name + ":" + field_name] = WritableConfigValue( + # we are sure that all sources come from single file so we can put them in single section + val_store[source_info.name + ":" + field_name] = WritableConfigValue( field_name, field_type, None, section ) diff --git a/dlt/cli/telemetry_command.py b/dlt/cli/telemetry_command.py index 45e9c270f9..094a6763a8 100644 --- a/dlt/cli/telemetry_command.py +++ b/dlt/cli/telemetry_command.py @@ -28,20 +28,19 @@ def change_telemetry_status_command(enabled: bool) -> None: WritableConfigValue("dlthub_telemetry", bool, enabled, (RunConfiguration.__section__,)) ] # write local config + # TODO: use designated (main) config provider (for non secret values) ie. 
taken from run context config = ConfigTomlProvider(add_global_config=False) - config_toml = tomlkit.document() if not config.is_empty: - write_values(config_toml, telemetry_value, overwrite_existing=True) - config._config_doc = config_toml + write_values(config._config_toml, telemetry_value, overwrite_existing=True) config.write_toml() # write global config - global_path = ConfigTomlProvider.global_config_path() + from dlt.common.runtime import run_context + + global_path = run_context.current().global_dir os.makedirs(global_path, exist_ok=True) - config = ConfigTomlProvider(project_dir=global_path, add_global_config=False) - config_toml = tomlkit.document() - write_values(config_toml, telemetry_value, overwrite_existing=True) - config._config_doc = config_toml + config = ConfigTomlProvider(settings_dir=global_path, add_global_config=False) + write_values(config._config_toml, telemetry_value, overwrite_existing=True) config.write_toml() if enabled: @@ -49,5 +48,5 @@ def change_telemetry_status_command(enabled: bool) -> None: else: fmt.echo("Telemetry switched %s" % fmt.bold("OFF")) # reload config providers - ctx = Container()[ConfigProvidersContext] - ctx.providers = ConfigProvidersContext.initial_providers() + if ConfigProvidersContext in Container(): + del Container()[ConfigProvidersContext] diff --git a/dlt/cli/utils.py b/dlt/cli/utils.py index 8699116628..9635348253 100644 --- a/dlt/cli/utils.py +++ b/dlt/cli/utils.py @@ -7,6 +7,7 @@ from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import RunConfiguration from dlt.common.runtime.telemetry import with_telemetry +from dlt.common.runtime import run_context from dlt.reflection.script_visitor import PipelineScriptVisitor @@ -61,3 +62,11 @@ def track_command(command: str, track_before: bool, *args: str) -> Callable[[TFu def get_telemetry_status() -> bool: c = resolve_configuration(RunConfiguration()) return c.dlthub_telemetry + + +def make_dlt_settings_path(path: str = None) -> str: + """Returns path to file in dlt settings folder. Returns settings folder if path not specified.""" + ctx = run_context.current() + if not path: + return ctx.settings_dir + return ctx.get_setting(path) diff --git a/dlt/common/configuration/paths.py b/dlt/common/configuration/paths.py deleted file mode 100644 index 9d0b47f8b6..0000000000 --- a/dlt/common/configuration/paths.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -import tempfile - -from dlt.common import known_env - - -# dlt settings folder -DOT_DLT = os.environ.get(known_env.DLT_CONFIG_FOLDER, ".dlt") - - -def get_dlt_project_dir() -> str: - """The dlt project dir is the current working directory but may be overridden by DLT_PROJECT_DIR env variable.""" - return os.environ.get(known_env.DLT_PROJECT_DIR, ".") - - -def get_dlt_settings_dir() -> str: - """Returns a path to dlt settings directory. If not overridden it resides in current working directory - - The name of the setting folder is '.dlt'. The path is current working directory '.' but may be overridden by DLT_PROJECT_DIR env variable. - """ - return os.path.join(get_dlt_project_dir(), DOT_DLT) - - -def make_dlt_settings_path(path: str) -> str: - """Returns path to file in dlt settings folder.""" - return os.path.join(get_dlt_settings_dir(), path) - - -def get_dlt_data_dir() -> str: - """Gets default directory where pipelines' data (working directories) will be stored - 1. if DLT_DATA_DIR is set in env then it is used - 2. in user home directory: ~/.dlt/ - 3. if current user is root: in /var/dlt/ - 4. 
if current user does not have a home directory: in /tmp/dlt/ - """ - if known_env.DLT_DATA_DIR in os.environ: - return os.environ[known_env.DLT_DATA_DIR] - - # geteuid not available on Windows - if hasattr(os, "geteuid") and os.geteuid() == 0: - # we are root so use standard /var - return os.path.join("/var", "dlt") - - home = _get_user_home_dir() - if home is None: - # no home dir - use temp - return os.path.join(tempfile.gettempdir(), "dlt") - else: - # if home directory is available use ~/.dlt/pipelines - return os.path.join(home, DOT_DLT) - - -def _get_user_home_dir() -> str: - return os.path.expanduser("~") diff --git a/dlt/common/configuration/plugins.py b/dlt/common/configuration/plugins.py new file mode 100644 index 0000000000..727725a758 --- /dev/null +++ b/dlt/common/configuration/plugins.py @@ -0,0 +1,55 @@ +from typing import ClassVar +import pluggy +import importlib.metadata + +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext + +hookspec = pluggy.HookspecMarker("dlt") +hookimpl = pluggy.HookimplMarker("dlt") + + +class PluginContext(ContainerInjectableContext): + global_affinity: ClassVar[bool] = True + + manager: pluggy.PluginManager + + def __init__(self) -> None: + super().__init__() + self.manager = pluggy.PluginManager("dlt") + + # we need to solve circular deps somehow + from dlt.common.runtime import run_context + + # register + self.manager.add_hookspecs(run_context) + self.manager.register(run_context) + load_setuptools_entrypoints(self.manager) + + +def manager() -> pluggy.PluginManager: + """Returns current plugin context""" + from .container import Container + + return Container()[PluginContext].manager + + +def load_setuptools_entrypoints(m: pluggy.PluginManager) -> None: + """Scans setuptools distributions that are path or have name starting with `dlt-` + loads entry points in group `dlt` and instantiates them to initialize contained plugins + """ + + for dist in list(importlib.metadata.distributions()): + # skip named dists that do not start with dlt- + if hasattr(dist, "name") and not dist.name.startswith("dlt-"): + continue + for ep in dist.entry_points: + if ( + ep.group != "dlt" + # already registered + or m.get_plugin(ep.name) + or m.is_blocked(ep.name) + ): + continue + plugin = ep.load() + m.register(plugin, name=ep.name) + m._plugin_distinfo.append((plugin, pluggy._manager.DistFacade(dist))) diff --git a/dlt/common/configuration/providers/__init__.py b/dlt/common/configuration/providers/__init__.py index 7338b82b7c..26b017ceda 100644 --- a/dlt/common/configuration/providers/__init__.py +++ b/dlt/common/configuration/providers/__init__.py @@ -4,12 +4,13 @@ from .toml import ( SecretsTomlProvider, ConfigTomlProvider, - ProjectDocProvider, + SettingsTomlProvider, CONFIG_TOML, SECRETS_TOML, StringTomlProvider, CustomLoaderDocProvider, ) +from .doc import CustomLoaderDocProvider from .vault import SECRETS_TOML_KEY from .google_secrets import GoogleSecretsProvider from .context import ContextProvider @@ -20,7 +21,7 @@ "DictionaryProvider", "SecretsTomlProvider", "ConfigTomlProvider", - "ProjectDocProvider", + "SettingsTomlProvider", "CONFIG_TOML", "SECRETS_TOML", "StringTomlProvider", diff --git a/dlt/common/configuration/providers/dictionary.py b/dlt/common/configuration/providers/dictionary.py index 5358d80be3..01bf62aa76 100644 --- a/dlt/common/configuration/providers/dictionary.py +++ b/dlt/common/configuration/providers/dictionary.py @@ -4,7 +4,7 @@ from dlt.common.typing import DictStrAny from .provider import 
get_key_name -from .toml import BaseDocProvider +from .doc import BaseDocProvider class DictionaryProvider(BaseDocProvider): diff --git a/dlt/common/configuration/providers/doc.py b/dlt/common/configuration/providers/doc.py new file mode 100644 index 0000000000..4be0875c70 --- /dev/null +++ b/dlt/common/configuration/providers/doc.py @@ -0,0 +1,169 @@ +import tomlkit +import yaml +from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple, Type + +from dlt.common.configuration.utils import auto_cast, auto_config_fragment +from dlt.common.utils import update_dict_nested + +from .provider import ConfigProvider, get_key_name + + +class BaseDocProvider(ConfigProvider): + _config_doc: Dict[str, Any] + """Holds a dict with config values""" + + def __init__(self, config_doc: Dict[str, Any]) -> None: + self._config_doc = config_doc + + @staticmethod + def get_key_name(key: str, *sections: str) -> str: + return get_key_name(key, ".", *sections) + + def get_value( + self, key: str, hint: Type[Any], pipeline_name: str, *sections: str + ) -> Tuple[Optional[Any], str]: + full_path = sections + (key,) + if pipeline_name: + full_path = (pipeline_name,) + full_path + full_key = self.get_key_name(key, pipeline_name, *sections) + node = self._config_doc + try: + for k in full_path: + if not isinstance(node, dict): + raise KeyError(k) + node = node[k] + return node, full_key + except KeyError: + return None, full_key + + def set_value(self, key: str, value: Any, pipeline_name: Optional[str], *sections: str) -> None: + """Sets `value` under `key` in `sections` and optionally for `pipeline_name` + + If key already has value of type dict and value to set is also of type dict, the new value + is merged with old value. + """ + self._set_value(self._config_doc, key, value, pipeline_name, *sections) + + def set_fragment( + self, key: Optional[str], value_or_fragment: str, pipeline_name: str, *sections: str + ) -> None: + """Tries to interpret `value_or_fragment` as a fragment of toml, yaml or json string and replace/merge into config doc. + + If `key` is not provided, fragment is considered a full document and will replace internal config doc. Otherwise + fragment is merged with config doc from the root element and not from the element under `key`! + + For simple values it falls back to `set_value` method. 
+ """ + self._config_doc = self._set_fragment( + self._config_doc, key, value_or_fragment, pipeline_name, *sections + ) + + def to_toml(self) -> str: + return tomlkit.dumps(self._config_doc) + + def to_yaml(self) -> str: + return yaml.dump( + self._config_doc, allow_unicode=True, default_flow_style=False, sort_keys=False + ) + + @property + def supports_sections(self) -> bool: + return True + + @property + def is_empty(self) -> bool: + return len(self._config_doc) == 0 + + @staticmethod + def _set_value( + master: MutableMapping[str, Any], + key: str, + value: Any, + pipeline_name: Optional[str], + *sections: str + ) -> None: + if pipeline_name: + sections = (pipeline_name,) + sections + if key is None: + raise ValueError("dlt_secrets_toml must contain toml document") + + # descend from root, create tables if necessary + for k in sections: + if not isinstance(master, dict): + raise KeyError(k) + if k not in master: + master[k] = {} + master = master[k] + if isinstance(value, dict): + # remove none values, TODO: we need recursive None removal + value = {k: v for k, v in value.items() if v is not None} + # if target is also dict then merge recursively + if isinstance(master.get(key), dict): + update_dict_nested(master[key], value) + return + master[key] = value + + @staticmethod + def _set_fragment( + master: MutableMapping[str, Any], + key: Optional[str], + value_or_fragment: str, + pipeline_name: str, + *sections: str + ) -> Any: + """Tries to interpret `value_or_fragment` as a fragment of toml, yaml or json string and replace/merge into config doc. + + If `key` is not provided, fragment is considered a full document and will replace internal config doc. Otherwise + fragment is merged with config doc from the root element and not from the element under `key`! + + For simple values it falls back to `set_value` method. + """ + fragment = auto_config_fragment(value_or_fragment) + if fragment is not None: + # always update the top document + if key is None: + master = fragment + else: + # TODO: verify that value contains only the elements under key + update_dict_nested(master, fragment) + else: + # set value using auto_cast + BaseDocProvider._set_value( + master, key, auto_cast(value_or_fragment), pipeline_name, *sections + ) + return master + + +class CustomLoaderDocProvider(BaseDocProvider): + def __init__( + self, name: str, loader: Callable[[], Dict[str, Any]], supports_secrets: bool = True + ) -> None: + """Provider that calls `loader` function to get a Python dict with config/secret values to be queried. + The `loader` function typically loads a string (ie. from file), parses it (ie. as toml or yaml), does additional + processing and returns a Python dict to be queried. + + Instance of CustomLoaderDocProvider must be registered for the returned dict to be used to resolve config values. + >>> import dlt + >>> dlt.config.register_provider(provider) + + Args: + name(str): name of the provider that will be visible ie. 
in exceptions + loader(Callable[[], Dict[str, Any]]): user-supplied function that will load the document with config/secret values + supports_secrets(bool): allows to store secret values in this provider + + """ + self._name = name + self._supports_secrets = supports_secrets + super().__init__(loader()) + + @property + def name(self) -> str: + return self._name + + @property + def supports_secrets(self) -> bool: + return self._supports_secrets + + @property + def is_writable(self) -> bool: + return True diff --git a/dlt/common/configuration/providers/environ.py b/dlt/common/configuration/providers/environ.py index f83ea9a24d..5381f9ee90 100644 --- a/dlt/common/configuration/providers/environ.py +++ b/dlt/common/configuration/providers/environ.py @@ -2,7 +2,7 @@ from os.path import isdir from typing import Any, Optional, Type, Tuple -from dlt.common.typing import TSecretValue +from dlt.common.configuration.specs.base_configuration import is_secret_hint from .provider import ConfigProvider, get_key_name @@ -23,10 +23,10 @@ def get_value( ) -> Tuple[Optional[Any], str]: # apply section to the key key = self.get_key_name(key, pipeline_name, *sections) - if hint is TSecretValue: + if is_secret_hint(hint): # try secret storage try: - # must conform to RFC1123 + # must conform to RFC1123 DNS LABELS (https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names) secret_name = key.lower().replace("_", "-") secret_path = SECRET_STORAGE_PATH % secret_name # kubernetes stores secrets as files in a dir, docker compose plainly diff --git a/dlt/common/configuration/providers/google_secrets.py b/dlt/common/configuration/providers/google_secrets.py index 55cc35e02c..d73d98f431 100644 --- a/dlt/common/configuration/providers/google_secrets.py +++ b/dlt/common/configuration/providers/google_secrets.py @@ -23,7 +23,7 @@ def normalize_key(in_string: str) -> str: in_string(str): input string Returns: - (str): a string without punctuatio characters and whitespaces + (str): a string without punctuation characters and whitespaces """ # Strip punctuation from the string diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index c13d1f8454..fce394caba 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -1,114 +1,18 @@ import os import tomlkit -import yaml +import tomlkit.items import functools -from tomlkit.items import Item as TOMLItem -from tomlkit.container import Container as TOMLContainer -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Optional -from dlt.common.configuration.paths import get_dlt_settings_dir, get_dlt_data_dir -from dlt.common.configuration.utils import auto_cast, auto_config_fragment from dlt.common.utils import update_dict_nested -from .provider import ConfigProvider, ConfigProviderException, get_key_name +from .provider import ConfigProviderException +from .doc import BaseDocProvider, CustomLoaderDocProvider CONFIG_TOML = "config.toml" SECRETS_TOML = "secrets.toml" -class BaseDocProvider(ConfigProvider): - def __init__(self, config_doc: Dict[str, Any]) -> None: - self._config_doc = config_doc - - @staticmethod - def get_key_name(key: str, *sections: str) -> str: - return get_key_name(key, ".", *sections) - - def get_value( - self, key: str, hint: Type[Any], pipeline_name: str, *sections: str - ) -> Tuple[Optional[Any], str]: - full_path = sections + (key,) - if pipeline_name: - full_path = (pipeline_name,) + full_path - 
full_key = self.get_key_name(key, pipeline_name, *sections) - node = self._config_doc - try: - for k in full_path: - if not isinstance(node, dict): - raise KeyError(k) - node = node[k] - return node, full_key - except KeyError: - return None, full_key - - def set_value(self, key: str, value: Any, pipeline_name: Optional[str], *sections: str) -> None: - """Sets `value` under `key` in `sections` and optionally for `pipeline_name` - - If key already has value of type dict and value to set is also of type dict, the new value - is merged with old value. - """ - if pipeline_name: - sections = (pipeline_name,) + sections - if key is None: - raise ValueError("dlt_secrets_toml must contain toml document") - - master: Dict[str, Any] - # descend from root, create tables if necessary - master = self._config_doc - for k in sections: - if not isinstance(master, dict): - raise KeyError(k) - if k not in master: - master[k] = {} - master = master[k] - if isinstance(value, dict): - # remove none values, TODO: we need recursive None removal - value = {k: v for k, v in value.items() if v is not None} - # if target is also dict then merge recursively - if isinstance(master.get(key), dict): - update_dict_nested(master[key], value) - return - master[key] = value - - def set_fragment( - self, key: Optional[str], value_or_fragment: str, pipeline_name: str, *sections: str - ) -> None: - """Tries to interpret `value_or_fragment` as a fragment of toml, yaml or json string and replace/merge into config doc. - - If `key` is not provided, fragment is considered a full document and will replace internal config doc. Otherwise - fragment is merged with config doc from the root element and not from the element under `key`! - - For simple values it falls back to `set_value` method. - """ - fragment = auto_config_fragment(value_or_fragment) - if fragment is not None: - # always update the top document - if key is None: - self._config_doc = fragment - else: - # TODO: verify that value contains only the elements under key - update_dict_nested(self._config_doc, fragment) - else: - # set value using auto_cast - self.set_value(key, auto_cast(value_or_fragment), pipeline_name, *sections) - - def to_toml(self) -> str: - return tomlkit.dumps(self._config_doc) - - def to_yaml(self) -> str: - return yaml.dump( - self._config_doc, allow_unicode=True, default_flow_style=False, sort_keys=False - ) - - @property - def supports_sections(self) -> bool: - return True - - @property - def is_empty(self) -> bool: - return len(self._config_doc) == 0 - - class StringTomlProvider(BaseDocProvider): def __init__(self, toml_string: str) -> None: super().__init__(StringTomlProvider.loads(toml_string).unwrap()) @@ -132,54 +36,23 @@ def name(self) -> str: return "memory" -class CustomLoaderDocProvider(BaseDocProvider): - def __init__( - self, name: str, loader: Callable[[], Dict[str, Any]], supports_secrets: bool = True - ) -> None: - """Provider that calls `loader` function to get a Python dict with config/secret values to be queried. - The `loader` function typically loads a string (ie. from file), parses it (ie. as toml or yaml), does additional - processing and returns a Python dict to be queried. - - Instance of CustomLoaderDocProvider must be registered for the returned dict to be used to resolve config values. - >>> import dlt - >>> dlt.config.register_provider(provider) - - Args: - name(str): name of the provider that will be visible ie. 
in exceptions - loader(Callable[[], Dict[str, Any]]): user-supplied function that will load the document with config/secret values - supports_secrets(bool): allows to store secret values in this provider - - """ - self._name = name - self._supports_secrets = supports_secrets - super().__init__(loader()) - - @property - def name(self) -> str: - return self._name - - @property - def supports_secrets(self) -> bool: - return self._supports_secrets - - @property - def is_writable(self) -> bool: - return True +class SettingsTomlProvider(CustomLoaderDocProvider): + _config_toml: tomlkit.TOMLDocument + """Holds tomlkit document with config values that is in sync with _config_doc""" - -class ProjectDocProvider(CustomLoaderDocProvider): def __init__( self, name: str, supports_secrets: bool, file_name: str, - project_dir: str = None, + settings_dir: str = None, add_global_config: bool = False, ) -> None: """Creates config provider from a `toml` file The provider loads the `toml` file with specified name and from specified folder. If `add_global_config` flags is specified, - it will look for `file_name` in `dlt` home dir. The "project" (`project_dir`) values overwrite the "global" values. + it will additionally look for `file_name` in `dlt` global dir (home dir by default) and merge the content. + The "settings" (`settings_dir`) values overwrite the "global" values. If none of the files exist, an empty provider is created. @@ -187,44 +60,72 @@ def __init__( name(str): name of the provider when registering in context supports_secrets(bool): allows to store secret values in this provider file_name (str): The name of `toml` file to load - project_dir (str, optional): The location of `file_name`. If not specified, defaults to $cwd/.dlt + settings_dir (str, optional): The location of `file_name`. 
If not specified, defaults to $cwd/.dlt add_global_config (bool, optional): Looks for `file_name` in `dlt` home directory which in most cases is $HOME/.dlt Raises: TomlProviderReadException: File could not be read, most probably `toml` parsing error """ - self._toml_path = os.path.join(project_dir or get_dlt_settings_dir(), file_name) + from dlt.common.runtime import run_context + + self._toml_path = os.path.join( + settings_dir or run_context.current().settings_dir, file_name + ) self._add_global_config = add_global_config + self._config_toml = self._read_toml_files( + name, file_name, self._toml_path, add_global_config + ) super().__init__( name, - functools.partial( - self._read_toml_files, name, file_name, self._toml_path, add_global_config - ), + self._config_toml.unwrap, supports_secrets, ) - @staticmethod - def global_config_path() -> str: - return get_dlt_data_dir() - def write_toml(self) -> None: assert ( not self._add_global_config ), "Will not write configs when `add_global_config` flag was set" with open(self._toml_path, "w", encoding="utf-8") as f: - tomlkit.dump(self._config_doc, f) + tomlkit.dump(self._config_toml, f) + + def set_value(self, key: str, value: Any, pipeline_name: Optional[str], *sections: str) -> None: + # write both into tomlkit and dict representations + try: + self._set_value(self._config_toml, key, value, pipeline_name, *sections) + except tomlkit.items._ConvertError: + pass + if hasattr(value, "unwrap"): + value = value.unwrap() + super().set_value(key, value, pipeline_name, *sections) + + def set_fragment( + self, key: Optional[str], value_or_fragment: str, pipeline_name: str, *sections: str + ) -> None: + # write both into tomlkit and dict representations + try: + self._config_toml = self._set_fragment( + self._config_toml, key, value_or_fragment, pipeline_name, *sections + ) + except tomlkit.items._ConvertError: + pass + super().set_fragment(key, value_or_fragment, pipeline_name, *sections) + + def to_toml(self) -> str: + return tomlkit.dumps(self._config_toml) @staticmethod def _read_toml_files( name: str, file_name: str, toml_path: str, add_global_config: bool - ) -> Dict[str, Any]: + ) -> tomlkit.TOMLDocument: try: - project_toml = ProjectDocProvider._read_toml(toml_path).unwrap() + project_toml = SettingsTomlProvider._read_toml(toml_path) if add_global_config: - global_toml = ProjectDocProvider._read_toml( - os.path.join(ProjectDocProvider.global_config_path(), file_name) - ).unwrap() + from dlt.common.runtime import run_context + + global_toml = SettingsTomlProvider._read_toml( + os.path.join(run_context.current().global_dir, file_name) + ) project_toml = update_dict_nested(global_toml, project_toml) return project_toml except Exception as ex: @@ -240,13 +141,13 @@ def _read_toml(toml_path: str) -> tomlkit.TOMLDocument: return tomlkit.document() -class ConfigTomlProvider(ProjectDocProvider): - def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: +class ConfigTomlProvider(SettingsTomlProvider): + def __init__(self, settings_dir: str = None, add_global_config: bool = False) -> None: super().__init__( CONFIG_TOML, False, CONFIG_TOML, - project_dir=project_dir, + settings_dir=settings_dir, add_global_config=add_global_config, ) @@ -255,13 +156,13 @@ def is_writable(self) -> bool: return True -class SecretsTomlProvider(ProjectDocProvider): - def __init__(self, project_dir: str = None, add_global_config: bool = False) -> None: +class SecretsTomlProvider(SettingsTomlProvider): + def __init__(self, settings_dir: str = 
None, add_global_config: bool = False) -> None: super().__init__( SECRETS_TOML, True, SECRETS_TOML, - project_dir=project_dir, + settings_dir=settings_dir, add_global_config=add_global_config, ) diff --git a/dlt/common/configuration/providers/vault.py b/dlt/common/configuration/providers/vault.py index 0dcaa1b5c4..0ed8842d55 100644 --- a/dlt/common/configuration/providers/vault.py +++ b/dlt/common/configuration/providers/vault.py @@ -7,7 +7,7 @@ from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.base_configuration import is_secret_hint -from .toml import BaseDocProvider +from .doc import BaseDocProvider SECRETS_TOML_KEY = "dlt_secrets_toml" diff --git a/dlt/common/configuration/specs/api_credentials.py b/dlt/common/configuration/specs/api_credentials.py index 918cd4ee45..0b328c3945 100644 --- a/dlt/common/configuration/specs/api_credentials.py +++ b/dlt/common/configuration/specs/api_credentials.py @@ -1,17 +1,17 @@ from typing import ClassVar, List, Union, Optional -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec @configspec class OAuth2Credentials(CredentialsConfiguration): client_id: str = None - client_secret: TSecretValue = None - refresh_token: Optional[TSecretValue] = None + client_secret: TSecretStrValue = None + refresh_token: Optional[TSecretStrValue] = None scopes: Optional[List[str]] = None - token: Optional[TSecretValue] = None + token: Optional[TSecretStrValue] = None """Access token""" # add refresh_token when generating config samples diff --git a/dlt/common/configuration/specs/azure_credentials.py b/dlt/common/configuration/specs/azure_credentials.py index 6794b581ce..371a988109 100644 --- a/dlt/common/configuration/specs/azure_credentials.py +++ b/dlt/common/configuration/specs/azure_credentials.py @@ -39,7 +39,7 @@ def to_object_store_rs_credentials(self) -> Dict[str, str]: def create_sas_token(self) -> None: from azure.storage.blob import generate_account_sas, ResourceTypes - self.azure_storage_sas_token = generate_account_sas( # type: ignore[assignment] + self.azure_storage_sas_token = generate_account_sas( account_name=self.azure_storage_account_name, account_key=self.azure_storage_account_key, resource_types=ResourceTypes(container=True, object=True), diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 2504fdeaef..c7c4bfb1ce 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -30,10 +30,13 @@ from dlt.common.typing import ( AnyType, + SecretSentinel, ConfigValueSentinel, TAnyClass, + Annotated, extract_inner_type, is_annotated, + is_any_type, is_final_type, is_optional_type, is_subclass, @@ -111,7 +114,7 @@ def is_valid_hint(hint: Type[Any]) -> bool: hint = get_config_if_union_hint(hint) or hint hint = get_origin(hint) or hint - if hint is Any: + if is_any_type(hint): return True if is_base_configuration_inner_hint(hint): return True @@ -122,27 +125,31 @@ def is_valid_hint(hint: Type[Any]) -> bool: def extract_inner_hint( - hint: Type[Any], preserve_new_types: bool = False, preserve_literal: bool = False + hint: Type[Any], + preserve_new_types: bool = False, + preserve_literal: bool = False, + preserve_annotated: bool = False, ) -> Type[Any]: # extract hint from Optional / Literal / NewType hints - inner_hint = extract_inner_type(hint, 
preserve_new_types, preserve_literal) + inner_hint = extract_inner_type(hint, preserve_new_types, preserve_literal, preserve_annotated) # get base configuration from union type inner_hint = get_config_if_union_hint(inner_hint) or inner_hint # extract origin from generic types (ie List[str] -> List) origin = get_origin(inner_hint) or inner_hint - if preserve_literal and origin is Literal: + if preserve_literal and origin is Literal or preserve_annotated and origin is Annotated: return inner_hint return origin or inner_hint def is_secret_hint(hint: Type[Any]) -> bool: is_secret = False - if hasattr(hint, "__name__"): - is_secret = hint.__name__ == "TSecretValue" + if is_annotated(hint): + _, *a_m = get_args(hint) + is_secret = SecretSentinel in a_m if not is_secret: is_secret = is_credentials_inner_hint(hint) if not is_secret: - inner_hint = extract_inner_hint(hint, preserve_new_types=True) + inner_hint = extract_inner_hint(hint, preserve_annotated=True, preserve_new_types=True) # something was encapsulated if inner_hint is not hint: is_secret = is_secret_hint(inner_hint) @@ -319,7 +326,7 @@ def parse_native_representation(self, native_value: Any) -> None: """Initialize the configuration fields by parsing the `native_value` which should be a native representation of the configuration or credentials, for example database connection string or JSON serialized GCP service credentials file. - #### Args: + Args: native_value (Any): A native representation of the configuration Raises: diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index d77d97cee8..5c482173f4 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -19,7 +19,6 @@ configspec, known_sections, ) -from dlt.common.runtime.exec_info import is_airflow_installed @configspec diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index 1e6cd56155..14b85eca27 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple, TYPE_CHECKING +from typing import Callable, List, Optional, Tuple from dlt.common.configuration.specs import known_sections from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 1da7961ef8..6673ee6e45 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,9 +1,10 @@ import dataclasses from typing import Any, ClassVar, Dict, List, Optional, Union -from dlt.common.libs.sql_alchemy_shims import URL, make_url +# avoid importing sqlalchemy +from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.configuration.specs.exceptions import InvalidConnectionString -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec @@ -11,7 +12,7 @@ class ConnectionStringCredentials(CredentialsConfiguration): drivername: str = dataclasses.field(default=None, init=False, repr=False, compare=False) database: Optional[str] = None - password: 
Optional[TSecretValue] = None + password: Optional[TSecretStrValue] = None username: Optional[str] = None host: Optional[str] = None port: Optional[int] = None @@ -34,6 +35,8 @@ def parse_native_representation(self, native_value: Any) -> None: if not isinstance(native_value, str): raise InvalidConnectionString(self.__class__, native_value, self.drivername) try: + from dlt.common.libs.sql_alchemy_compat import make_url + url = make_url(native_value) # update only values that are not None self.update({k: v for k, v in url._asdict().items() if v is not None}) @@ -45,7 +48,7 @@ def parse_native_representation(self, native_value: Any) -> None: def on_resolved(self) -> None: if self.password: - self.password = TSecretValue(self.password.strip()) + self.password = self.password.strip() def to_native_representation(self) -> str: return self.to_url().render_as_string(hide_password=False) @@ -66,6 +69,10 @@ def _serialize_value(v_: Any) -> str: # query must be str -> str query = {k: _serialize_value(v) for k, v in self.get_query().items()} + + # import "real" URL + from dlt.common.libs.sql_alchemy_compat import URL + return URL.create( self.drivername, self.username, diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index ca5bd076f1..7d852dd67e 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -13,7 +13,7 @@ OAuth2ScopesRequired, ) from dlt.common.exceptions import MissingDependencyException -from dlt.common.typing import DictStrAny, TSecretValue, StrAny +from dlt.common.typing import DictStrAny, TSecretStrValue, StrAny from dlt.common.configuration.specs.base_configuration import ( CredentialsConfiguration, CredentialsWithDefault, @@ -67,7 +67,7 @@ def to_gcs_credentials(self) -> Dict[str, Any]: @configspec class GcpServiceAccountCredentialsWithoutDefaults(GcpCredentials): - private_key: TSecretValue = None + private_key: TSecretStrValue = None private_key_id: Optional[str] = None client_email: str = None type: Final[str] = dataclasses.field( # noqa: A003 @@ -105,7 +105,7 @@ def parse_native_representation(self, native_value: Any) -> None: def on_resolved(self) -> None: if self.private_key and self.private_key[-1] != "\n": # must end with new line, otherwise won't be parsed by Crypto - self.private_key = TSecretValue(self.private_key + "\n") + self.private_key = self.private_key + "\n" def to_native_credentials(self) -> Any: """Returns google.oauth2.service_account.Credentials""" @@ -128,7 +128,7 @@ def __str__(self) -> str: @configspec class GcpOAuthCredentialsWithoutDefaults(GcpCredentials, OAuth2Credentials): # only desktop app supported - refresh_token: TSecretValue = None + refresh_token: TSecretStrValue = None client_type: Final[str] = dataclasses.field( default="installed", init=False, repr=False, compare=False ) @@ -195,13 +195,13 @@ def auth(self, scopes: Union[str, List[str]] = None, redirect_url: str = None) - def on_partial(self) -> None: """Allows for an empty refresh token if the session is interactive or tty is attached""" if sys.stdin.isatty() or is_interactive(): - self.refresh_token = TSecretValue("") + self.refresh_token = "" # still partial - raise if not self.is_partial(): self.resolve() self.refresh_token = None - def _get_access_token(self) -> TSecretValue: + def _get_access_token(self) -> str: try: from requests_oauthlib import OAuth2Session except ModuleNotFoundError: @@ -209,19 +209,19 @@ def _get_access_token(self) -> TSecretValue: 
google = OAuth2Session(client_id=self.client_id, scope=self.scopes) extra = {"client_id": self.client_id, "client_secret": self.client_secret} - token = google.refresh_token( + token: str = google.refresh_token( token_url=self.token_uri, refresh_token=self.refresh_token, **extra )["access_token"] - return TSecretValue(token) + return token - def _get_refresh_token(self, redirect_url: str) -> Tuple[TSecretValue, TSecretValue]: + def _get_refresh_token(self, redirect_url: str) -> Tuple[str, str]: try: from google_auth_oauthlib.flow import InstalledAppFlow except ModuleNotFoundError: raise MissingDependencyException("GcpOAuthCredentials", ["google-auth-oauthlib"]) flow = InstalledAppFlow.from_client_config(self._installed_dict(redirect_url), self.scopes) credentials = flow.run_local_server(port=0) - return TSecretValue(credentials.refresh_token), TSecretValue(credentials.token) + return credentials.refresh_token, credentials.token def to_native_credentials(self) -> Any: """Returns google.oauth2.credentials.Credentials""" diff --git a/dlt/common/configuration/specs/pluggable_run_context.py b/dlt/common/configuration/specs/pluggable_run_context.py new file mode 100644 index 0000000000..190d8d2aae --- /dev/null +++ b/dlt/common/configuration/specs/pluggable_run_context.py @@ -0,0 +1,55 @@ +from typing import ClassVar, Protocol + +from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext + + +class SupportsRunContext(Protocol): + """Describes where `dlt` looks for settings, pipeline working folder""" + + @property + def name(self) -> str: + """Name of the run context. Entities like sources and destinations added to registries when this context + is active, will be scoped to it. Typically corresponds to Python package name ie. `dlt`. + """ + + @property + def global_dir(self) -> str: + """Directory in which global settings are stored ie ~/.dlt/""" + + @property + def run_dir(self) -> str: + """Defines the current working directory""" + + @property + def settings_dir(self) -> str: + """Defines where the current settings (secrets and configs) are located""" + + @property + def data_dir(self) -> str: + """Defines where the pipelines working folders are stored.""" + + def get_data_entity(self, entity: str) -> str: + """Gets path in data_dir where `entity` (ie. `pipelines`, `repos`) are stored""" + + def get_run_entity(self, entity: str) -> str: + """Gets path in run_dir where `entity` (ie. `sources`, `destinations` etc.) are stored""" + + def get_setting(self, setting_path: str) -> str: + """Gets path in settings_dir where setting (ie. 
`secrets.toml`) are stored""" + + +class PluggableRunContext(ContainerInjectableContext): + """Injectable run context taken via plugin""" + + global_affinity: ClassVar[bool] = True + + context: SupportsRunContext + + def __init__(self) -> None: + super().__init__() + + from dlt.common.configuration import plugins + + m = plugins.manager() + self.context = m.hook.plug_run_context() + assert self.context, "plug_run_context hook returned None" diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index dcb78683fb..ffc2a0deb1 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -11,6 +11,7 @@ @configspec class RunConfiguration(BaseConfiguration): + # TODO: deprecate pipeline_name, it is not used in any reasonable way pipeline_name: Optional[str] = None sentry_dsn: Optional[str] = None # keep None to disable Sentry slack_incoming_hook: Optional[TSecretStrValue] = None diff --git a/dlt/common/configuration/utils.py b/dlt/common/configuration/utils.py index 450dde29df..7b1ed72d2c 100644 --- a/dlt/common/configuration/utils.py +++ b/dlt/common/configuration/utils.py @@ -20,7 +20,7 @@ import yaml from dlt.common.json import json -from dlt.common.typing import AnyType, DictStrAny, TAny +from dlt.common.typing import AnyType, DictStrAny, TAny, is_any_type from dlt.common.data_types import coerce_value, py_type_to_sc_type from dlt.common.configuration.providers import EnvironProvider from dlt.common.configuration.exceptions import ConfigValueCannotBeCoercedException, LookupTrace @@ -45,7 +45,7 @@ class ResolvedValueTrace(NamedTuple): def deserialize_value(key: str, value: Any, hint: Type[TAny]) -> TAny: try: - if hint != Any: + if not is_any_type(hint): # if deserializing to base configuration, try parse the value if is_base_configuration_inner_hint(hint): c = hint() diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index d6be15abdd..b3b997629f 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -320,23 +320,10 @@ def _create_writer(self, schema: "pa.Schema") -> "pa.parquet.ParquetWriter": ) def write_header(self, columns_schema: TTableSchemaColumns) -> None: - from dlt.common.libs.pyarrow import pyarrow, get_py_arrow_datatype + from dlt.common.libs.pyarrow import columns_to_arrow # build schema - self.schema = pyarrow.schema( - [ - pyarrow.field( - name, - get_py_arrow_datatype( - schema_item, - self._caps, - self.timestamp_timezone, - ), - nullable=is_nullable_column(schema_item), - ) - for name, schema_item in columns_schema.items() - ] - ) + self.schema = columns_to_arrow(columns_schema, self._caps, self.timestamp_timezone) # find row items that are of the json type (could be abstracted out for use in other writers?) 
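# Aside (illustrative sketch, not part of this diff): a custom run context satisfying the
# SupportsRunContext protocol above. A plugin would return an instance of it from the
# `plug_run_context` hook so PluggableRunContext picks it up instead of the default context.
# All names and paths below are made-up examples.
import os
from dlt.common.configuration import plugins

class ProjectRunContext:
    name = "my_project"
    run_dir = "/srv/my_project"
    settings_dir = "/srv/my_project/.dlt"
    data_dir = "/var/lib/my_project/dlt"
    global_dir = data_dir

    def get_data_entity(self, entity: str) -> str:
        return os.path.join(self.data_dir, entity)    # e.g. pipelines working folders

    def get_run_entity(self, entity: str) -> str:
        return os.path.join(self.run_dir, entity)     # e.g. sources, destinations

    def get_setting(self, setting_path: str) -> str:
        return os.path.join(self.settings_dir, setting_path)  # e.g. secrets.toml

@plugins.hookimpl(specname="plug_run_context")
def plug_run_context():
    # assumes this module is registered as a dlt plugin; otherwise the default RunContext is used
    return ProjectRunContext()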
self.nested_indices = [ i for i, field in columns_schema.items() if field["data_type"] == "json" diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 527b9419e8..0c572379de 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import dataclasses from importlib import import_module +from contextlib import contextmanager + from types import TracebackType from typing import ( Callable, @@ -18,24 +20,33 @@ Any, TypeVar, Generic, + Generator, + TYPE_CHECKING, + Protocol, + Tuple, + AnyStr, ) from typing_extensions import Annotated import datetime # noqa: 251 import inspect from dlt.common import logger, pendulum + from dlt.common.configuration.specs.base_configuration import extract_inner_hint from dlt.common.destination.typing import PreparedTableSchema from dlt.common.destination.utils import verify_schema_capabilities, verify_supported_data_types from dlt.common.exceptions import TerminalException from dlt.common.metrics import LoadJobMetrics from dlt.common.normalizers.naming import NamingConvention -from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema.typing import TTableSchemaColumns + +from dlt.common.schema import Schema, TSchemaTables, TTableSchema from dlt.common.schema.typing import ( C_DLT_LOAD_ID, TLoaderReplaceStrategy, ) from dlt.common.schema.utils import fill_hints_from_parent_and_clone_table + from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -49,6 +60,8 @@ from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.storages.load_package import LoadJobInfo, TPipelineStateDoc +from dlt.common.exceptions import MissingDependencyException + TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") @@ -56,6 +69,17 @@ DEFAULT_FILE_LAYOUT = "{table_name}/{load_id}.{file_id}.{ext}" +if TYPE_CHECKING: + try: + from dlt.common.libs.pandas import DataFrame + from dlt.common.libs.pyarrow import Table as ArrowTable + except MissingDependencyException: + DataFrame = Any + ArrowTable = Any +else: + DataFrame = Any + ArrowTable = Any + class StorageSchemaInfo(NamedTuple): version_hash: str @@ -442,6 +466,65 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe return [] +class SupportsReadableRelation(Protocol): + """A readable relation retrieved from a destination that supports it""" + + schema_columns: TTableSchemaColumns + """Known dlt table columns for this relation""" + + def df(self, chunk_size: int = None) -> Optional[DataFrame]: + """Fetches the results as data frame. For large queries the results may be chunked + + Fetches the results into a data frame. The default implementation uses helpers in `pandas.io.sql` to generate Pandas data frame. + This function will try to use native data frame generation for particular destination. For `BigQuery`: `QueryJob.to_dataframe` is used. + For `duckdb`: `DuckDBPyConnection.df' + + Args: + chunk_size (int, optional): Will chunk the results into several data frames. 
Defaults to None + **kwargs (Any): Additional parameters which will be passed to native data frame generation function. + + Returns: + Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results + """ + ... + + def arrow(self, chunk_size: int = None) -> Optional[ArrowTable]: ... + + def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: ... + + def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: ... + + def fetchall(self) -> List[Tuple[Any, ...]]: ... + + def fetchmany(self, chunk_size: int) -> List[Tuple[Any, ...]]: ... + + def iter_fetch(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], Any, Any]: ... + + def fetchone(self) -> Optional[Tuple[Any, ...]]: ... + + +class DBApiCursor(SupportsReadableRelation): + """Protocol for DBAPI cursor""" + + description: Tuple[Any, ...] + + native_cursor: "DBApiCursor" + """Cursor implementation native to current destination""" + + def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ... + def close(self) -> None: ... + + +class SupportsReadableDataset(Protocol): + """A readable dataset retrieved from a destination, has support for creating readable relations for a query or table""" + + def __call__(self, query: Any) -> SupportsReadableRelation: ... + + def __getitem__(self, table: str) -> SupportsReadableRelation: ... + + def __getattr__(self, table: str) -> SupportsReadableRelation: ... + + class JobClientBase(ABC): def __init__( self, diff --git a/dlt/common/libs/pandas.py b/dlt/common/libs/pandas.py index 022aa9b9cd..a165ea8747 100644 --- a/dlt/common/libs/pandas.py +++ b/dlt/common/libs/pandas.py @@ -3,6 +3,7 @@ try: import pandas + from pandas import DataFrame except ModuleNotFoundError: raise MissingDependencyException("dlt Pandas Helpers", ["pandas"]) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index adba832c43..805b43b163 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -18,6 +18,8 @@ from dlt.common.pendulum import pendulum from dlt.common.exceptions import MissingDependencyException from dlt.common.schema.typing import C_DLT_ID, C_DLT_LOAD_ID, TTableSchemaColumns +from dlt.common import logger, json +from dlt.common.json import custom_encode, map_nested_in_place from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema.typing import TColumnType @@ -31,6 +33,7 @@ import pyarrow.compute import pyarrow.dataset from pyarrow.parquet import ParquetFile + from pyarrow import Table except ModuleNotFoundError: raise MissingDependencyException( "dlt pyarrow helpers", @@ -394,6 +397,37 @@ def py_arrow_to_table_schema_columns(schema: pyarrow.Schema) -> TTableSchemaColu return result +def columns_to_arrow( + columns: TTableSchemaColumns, + caps: DestinationCapabilitiesContext, + timestamp_timezone: str = "UTC", +) -> pyarrow.Schema: + """Convert a table schema columns dict to a pyarrow schema. 
+ + Args: + columns (TTableSchemaColumns): table schema columns + + Returns: + pyarrow.Schema: pyarrow schema + + """ + return pyarrow.schema( + [ + pyarrow.field( + name, + get_py_arrow_datatype( + schema_item, + caps or DestinationCapabilitiesContext.generic_capabilities(), + timestamp_timezone, + ), + nullable=schema_item.get("nullable", True), + ) + for name, schema_item in columns.items() + if schema_item.get("data_type") is not None + ] + ) + + def get_parquet_metadata(parquet_file: TFileOrPath) -> Tuple[int, pyarrow.Schema]: """Gets parquet file metadata (including row count and schema) @@ -531,6 +565,119 @@ def concat_batches_and_tables_in_order( return pyarrow.concat_tables(tables, promote_options="none") +def row_tuples_to_arrow( + rows: Sequence[Any], caps: DestinationCapabilitiesContext, columns: TTableSchemaColumns, tz: str +) -> Any: + """Converts the rows to an arrow table using the columns schema. + Columns missing `data_type` will be inferred from the row data. + Columns with object types not supported by arrow are excluded from the resulting table. + """ + from dlt.common.libs.pyarrow import pyarrow as pa + import numpy as np + + try: + from pandas._libs import lib + + pivoted_rows = lib.to_object_array_tuples(rows).T + except ImportError: + logger.info( + "Pandas not installed, reverting to numpy.asarray to create a table which is slower" + ) + pivoted_rows = np.asarray(rows, dtype="object", order="k").T # type: ignore[call-overload] + + columnar = { + col: dat.ravel() for col, dat in zip(columns, np.vsplit(pivoted_rows, len(columns))) + } + columnar_known_types = { + col["name"]: columnar[col["name"]] + for col in columns.values() + if col.get("data_type") is not None + } + columnar_unknown_types = { + col["name"]: columnar[col["name"]] + for col in columns.values() + if col.get("data_type") is None + } + + arrow_schema = columns_to_arrow(columns, caps, tz) + + for idx in range(0, len(arrow_schema.names)): + field = arrow_schema.field(idx) + py_type = type(rows[0][idx]) + # cast double / float ndarrays to decimals if type mismatch, looks like decimals and floats are often mixed up in dialects + if pa.types.is_decimal(field.type) and issubclass(py_type, (str, float)): + logger.warning( + f"Field {field.name} was reflected as decimal type, but rows contains" + f" {py_type.__name__}. Additional cast is required which may slow down arrow table" + " generation." + ) + float_array = pa.array(columnar_known_types[field.name], type=pa.float64()) + columnar_known_types[field.name] = float_array.cast(field.type, safe=False) + if issubclass(py_type, (dict, list)): + logger.warning( + f"Field {field.name} was reflected as JSON type and needs to be serialized back to" + " string to be placed in arrow table. This will slow data extraction down. You" + " should cast JSON field to STRING in your database system ie. by creating and" + " extracting an SQL VIEW that selects with cast." + ) + json_str_array = pa.array( + [None if s is None else json.dumps(s) for s in columnar_known_types[field.name]] + ) + columnar_known_types[field.name] = json_str_array + + # If there are unknown type columns, first create a table to infer their types + if columnar_unknown_types: + new_schema_fields = [] + for key in list(columnar_unknown_types): + arrow_col: Optional[pa.Array] = None + try: + arrow_col = pa.array(columnar_unknown_types[key]) + if pa.types.is_null(arrow_col.type): + logger.warning( + f"Column {key} contains only NULL values and data type could not be" + " inferred. 
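# Aside (illustrative, not part of this diff): what the columns_to_arrow() helper above
# produces for a small dlt columns dict. Columns without a data_type are skipped; the
# column names and the use of generic capabilities are assumptions made for this example.
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.libs.pyarrow import columns_to_arrow

columns = {
    "id": {"name": "id", "data_type": "bigint", "nullable": False},
    "created_at": {"name": "created_at", "data_type": "timestamp"},
    "payload": {"name": "payload"},  # no data_type yet -> excluded from the arrow schema
}
arrow_schema = columns_to_arrow(columns, DestinationCapabilitiesContext.generic_capabilities())
# roughly: id: int64 (not null), created_at: timestamp[us, tz=UTC]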
This column is removed from a arrow table" + ) + continue + + except pa.ArrowInvalid as e: + # Try coercing types not supported by arrow to a json friendly format + # E.g. dataclasses -> dict, UUID -> str + try: + arrow_col = pa.array( + map_nested_in_place(custom_encode, list(columnar_unknown_types[key])) + ) + logger.warning( + f"Column {key} contains a data type which is not supported by pyarrow and" + f" got converted into {arrow_col.type}. This slows down arrow table" + " generation." + ) + except (pa.ArrowInvalid, TypeError): + logger.warning( + f"Column {key} contains a data type which is not supported by pyarrow. This" + f" column will be ignored. Error: {e}" + ) + if arrow_col is not None: + columnar_known_types[key] = arrow_col + new_schema_fields.append( + pa.field( + key, + arrow_col.type, + nullable=columns[key]["nullable"], + ) + ) + + # New schema + column_order = {name: idx for idx, name in enumerate(columns)} + arrow_schema = pa.schema( + sorted( + list(arrow_schema) + new_schema_fields, + key=lambda x: column_order[x.name], + ) + ) + + return pa.Table.from_pydict(columnar_known_types, schema=arrow_schema) + + class NameNormalizationCollision(ValueError): def __init__(self, reason: str) -> None: msg = f"Arrow column name collision after input data normalization. {reason}" diff --git a/dlt/common/libs/sql_alchemy_compat.py b/dlt/common/libs/sql_alchemy_compat.py new file mode 100644 index 0000000000..28678ee25d --- /dev/null +++ b/dlt/common/libs/sql_alchemy_compat.py @@ -0,0 +1,6 @@ +try: + import sqlalchemy +except ImportError: + from dlt.common.libs.sql_alchemy_shims import URL, make_url +else: + from sqlalchemy.engine import URL, make_url # type: ignore[assignment] diff --git a/dlt/common/libs/sql_alchemy_shims.py b/dlt/common/libs/sql_alchemy_shims.py index 2f3b51ec0d..11f4513be9 100644 --- a/dlt/common/libs/sql_alchemy_shims.py +++ b/dlt/common/libs/sql_alchemy_shims.py @@ -4,443 +4,440 @@ from typing import cast +# port basic functionality without the whole Sql Alchemy + +import re +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, +) +import collections.abc as collections_abc +from urllib.parse import ( + quote_plus, + parse_qsl, + quote, + unquote, +) + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class ImmutableDict(Dict[_KT, _VT]): + """Not a real immutable dict""" + + def __setitem__(self, __key: _KT, __value: _VT) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def __delitem__(self, _KT: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def update(self, *arg: Any, **kw: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + +EMPTY_DICT: ImmutableDict[Any, Any] = ImmutableDict() + + +def to_list(value: Any, default: Optional[List[Any]] = None) -> List[Any]: + if value is None: + return default + if not isinstance(value, collections_abc.Iterable) or isinstance(value, str): + return [value] + elif isinstance(value, list): + return value + else: + return list(value) + + +class URL(NamedTuple): + """ + Represent the components of a URL used to connect to a database. 
+ + Based on SqlAlchemy URL class with copyright as below: + + # engine/url.py + # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors + # + # This module is part of SQLAlchemy and is released under + # the MIT License: https://www.opensource.org/licenses/mit-license.php + """ + + drivername: str + """database backend and driver name, such as `postgresql+psycopg2`""" + username: Optional[str] + "username string" + password: Optional[str] + """password, which is normally a string but may also be any object that has a `__str__()` method.""" + host: Optional[str] + """hostname or IP number. May also be a data source name for some drivers.""" + port: Optional[int] + """integer port number""" + database: Optional[str] + """database name""" + query: ImmutableDict[str, Union[Tuple[str, ...], str]] + """an immutable mapping representing the query string. contains strings + for keys and either strings or tuples of strings for values""" + + @classmethod + def create( + cls, + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = None, + ) -> "URL": + """Create a new `URL` object.""" + return cls( + cls._assert_str(drivername, "drivername"), + cls._assert_none_str(username, "username"), + password, + cls._assert_none_str(host, "host"), + cls._assert_port(port), + cls._assert_none_str(database, "database"), + cls._str_dict(query or EMPTY_DICT), + ) -try: - import sqlalchemy -except ImportError: - # port basic functionality without the whole Sql Alchemy - - import re - from typing import ( - Any, - Dict, - Iterable, - List, - Mapping, - NamedTuple, - Optional, - Sequence, - Tuple, - TypeVar, - Union, - overload, - ) - import collections.abc as collections_abc - from urllib.parse import ( - quote_plus, - parse_qsl, - quote, - unquote, - ) - - _KT = TypeVar("_KT", bound=Any) - _VT = TypeVar("_VT", bound=Any) - - class ImmutableDict(Dict[_KT, _VT]): - """Not a real immutable dict""" - - def __setitem__(self, __key: _KT, __value: _VT) -> None: - raise NotImplementedError("Cannot modify immutable dict") - - def __delitem__(self, _KT: Any) -> None: - raise NotImplementedError("Cannot modify immutable dict") + @classmethod + def _assert_port(cls, port: Optional[int]) -> Optional[int]: + if port is None: + return None + try: + return int(port) + except TypeError: + raise TypeError("Port argument must be an integer or None") + + @classmethod + def _assert_str(cls, v: str, paramname: str) -> str: + if not isinstance(v, str): + raise TypeError("%s must be a string" % paramname) + return v + + @classmethod + def _assert_none_str(cls, v: Optional[str], paramname: str) -> Optional[str]: + if v is None: + return v - def update(self, *arg: Any, **kw: Any) -> None: - raise NotImplementedError("Cannot modify immutable dict") + return cls._assert_str(v, paramname) + + @classmethod + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> ImmutableDict[str, Union[Tuple[str, ...], str]]: + if dict_ is None: + return EMPTY_DICT + + @overload + def _assert_value( + val: str, + ) -> str: ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: ... 
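# Aside (illustrative, not part of this diff): the ported URL / make_url API mirrors
# SQLAlchemy's behavior without importing it. The connection string below is an example.
from dlt.common.libs.sql_alchemy_shims import URL, make_url

url = make_url("postgresql+psycopg2://loader:pass@localhost:5432/dlt_data")
assert url.get_backend_name() == "postgresql" and url.port == 5432
assert url.render_as_string() == "postgresql+psycopg2://loader:***@localhost:5432/dlt_data"
assert url == URL.create("postgresql+psycopg2", "loader", "pass", "localhost", 5432, "dlt_data")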
+ + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: + if isinstance(val, str): + return val + elif isinstance(val, collections_abc.Sequence): + return tuple(_assert_value(elem) for elem in val) + else: + raise TypeError("Query dictionary values must be strings or sequences of strings") - EMPTY_DICT: ImmutableDict[Any, Any] = ImmutableDict() + def _assert_str(v: str) -> str: + if not isinstance(v, str): + raise TypeError("Query dictionary keys must be strings") + return v - def to_list(value: Any, default: Optional[List[Any]] = None) -> List[Any]: - if value is None: - return default - if not isinstance(value, collections_abc.Iterable) or isinstance(value, str): - return [value] - elif isinstance(value, list): - return value + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] + if isinstance(dict_, collections_abc.Sequence): + dict_items = dict_ else: - return list(value) - - class URL(NamedTuple): - """ - Represent the components of a URL used to connect to a database. - - Based on SqlAlchemy URL class with copyright as below: + dict_items = dict_.items() - # engine/url.py - # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors - # - # This module is part of SQLAlchemy and is released under - # the MIT License: https://www.opensource.org/licenses/mit-license.php - """ + return ImmutableDict( + { + _assert_str(key): _assert_value( + value, + ) + for key, value in dict_items + } + ) - drivername: str - """database backend and driver name, such as `postgresql+psycopg2`""" - username: Optional[str] - "username string" - password: Optional[str] - """password, which is normally a string but may also be any object that has a `__str__()` method.""" - host: Optional[str] - """hostname or IP number. May also be a data source name for some drivers.""" - port: Optional[int] - """integer port number""" - database: Optional[str] - """database name""" - query: ImmutableDict[str, Union[Tuple[str, ...], str]] - """an immutable mapping representing the query string. 
contains strings - for keys and either strings or tuples of strings for values""" - - @classmethod - def create( - cls, - drivername: str, - username: Optional[str] = None, - password: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - database: Optional[str] = None, - query: Mapping[str, Union[Sequence[str], str]] = None, - ) -> "URL": - """Create a new `URL` object.""" - return cls( - cls._assert_str(drivername, "drivername"), - cls._assert_none_str(username, "username"), - password, - cls._assert_none_str(host, "host"), - cls._assert_port(port), - cls._assert_none_str(database, "database"), - cls._str_dict(query or EMPTY_DICT), - ) + def set( # noqa + self, + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> "URL": + """return a new `URL` object with modifications.""" + + kw: Dict[str, Any] = {} + if drivername is not None: + kw["drivername"] = drivername + if username is not None: + kw["username"] = username + if password is not None: + kw["password"] = password + if host is not None: + kw["host"] = host + if port is not None: + kw["port"] = port + if database is not None: + kw["database"] = database + if query is not None: + kw["query"] = query + + return self._assert_replace(**kw) + + def _assert_replace(self, **kw: Any) -> "URL": + """argument checks before calling _replace()""" + + if "drivername" in kw: + self._assert_str(kw["drivername"], "drivername") + for name in "username", "host", "database": + if name in kw: + self._assert_none_str(kw[name], name) + if "port" in kw: + self._assert_port(kw["port"]) + if "query" in kw: + kw["query"] = self._str_dict(kw["query"]) + + return self._replace(**kw) + + def update_query_string(self, query_string: str, append: bool = False) -> "URL": + return self.update_query_pairs(parse_qsl(query_string), append=append) + + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> "URL": + """Return a new `URL` object with the `query` parameter dictionary updated by the given sequence of key/value pairs""" + existing_query = self.query + new_keys: Dict[str, Union[str, List[str]]] = {} + + for key, value in key_value_pairs: + if key in new_keys: + new_keys[key] = to_list(new_keys[key]) + cast("List[str]", new_keys[key]).append(cast(str, value)) + else: + new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value - @classmethod - def _assert_port(cls, port: Optional[int]) -> Optional[int]: - if port is None: - return None - try: - return int(port) - except TypeError: - raise TypeError("Port argument must be an integer or None") - - @classmethod - def _assert_str(cls, v: str, paramname: str) -> str: - if not isinstance(v, str): - raise TypeError("%s must be a string" % paramname) - return v + new_query: Mapping[str, Union[str, Sequence[str]]] + if append: + new_query = {} - @classmethod - def _assert_none_str(cls, v: Optional[str], paramname: str) -> Optional[str]: - if v is None: - return v - - return cls._assert_str(v, paramname) - - @classmethod - def _str_dict( - cls, - dict_: Optional[ - Union[ - Sequence[Tuple[str, Union[Sequence[str], str]]], - Mapping[str, Union[Sequence[str], str]], - ] - ], - ) -> ImmutableDict[str, Union[Tuple[str, ...], str]]: - if dict_ is None: - return EMPTY_DICT - - @overload - def 
_assert_value( - val: str, - ) -> str: ... - - @overload - def _assert_value( - val: Sequence[str], - ) -> Union[str, Tuple[str, ...]]: ... - - def _assert_value( - val: Union[str, Sequence[str]], - ) -> Union[str, Tuple[str, ...]]: - if isinstance(val, str): - return val - elif isinstance(val, collections_abc.Sequence): - return tuple(_assert_value(elem) for elem in val) + for k in new_keys: + if k in existing_query: + new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) else: - raise TypeError( - "Query dictionary values must be strings or sequences of strings" - ) - - def _assert_str(v: str) -> str: - if not isinstance(v, str): - raise TypeError("Query dictionary keys must be strings") - return v - - dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] - if isinstance(dict_, collections_abc.Sequence): - dict_items = dict_ - else: - dict_items = dict_.items() + new_query[k] = new_keys[k] - return ImmutableDict( + new_query.update( + {k: existing_query[k] for k in set(existing_query).difference(new_keys)} + ) + else: + new_query = ImmutableDict( { - _assert_str(key): _assert_value( - value, - ) - for key, value in dict_items + **self.query, + **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, } ) - - def set( # noqa - self, - drivername: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - database: Optional[str] = None, - query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, - ) -> "URL": - """return a new `URL` object with modifications.""" - - kw: Dict[str, Any] = {} - if drivername is not None: - kw["drivername"] = drivername - if username is not None: - kw["username"] = username - if password is not None: - kw["password"] = password - if host is not None: - kw["host"] = host - if port is not None: - kw["port"] = port - if database is not None: - kw["database"] = database - if query is not None: - kw["query"] = query - - return self._assert_replace(**kw) - - def _assert_replace(self, **kw: Any) -> "URL": - """argument checks before calling _replace()""" - - if "drivername" in kw: - self._assert_str(kw["drivername"], "drivername") - for name in "username", "host", "database": - if name in kw: - self._assert_none_str(kw[name], name) - if "port" in kw: - self._assert_port(kw["port"]) - if "query" in kw: - kw["query"] = self._str_dict(kw["query"]) - - return self._replace(**kw) - - def update_query_string(self, query_string: str, append: bool = False) -> "URL": - return self.update_query_pairs(parse_qsl(query_string), append=append) - - def update_query_pairs( - self, - key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], - append: bool = False, - ) -> "URL": - """Return a new `URL` object with the `query` parameter dictionary updated by the given sequence of key/value pairs""" - existing_query = self.query - new_keys: Dict[str, Union[str, List[str]]] = {} - - for key, value in key_value_pairs: - if key in new_keys: - new_keys[key] = to_list(new_keys[key]) - cast("List[str]", new_keys[key]).append(cast(str, value)) - else: - new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value - - new_query: Mapping[str, Union[str, Sequence[str]]] - if append: - new_query = {} - - for k in new_keys: - if k in existing_query: - new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) - else: - new_query[k] = new_keys[k] - - new_query.update( - {k: existing_query[k] for k in 
set(existing_query).difference(new_keys)} - ) + return self.set(query=new_query) + + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> "URL": + return self.update_query_pairs(query_parameters.items(), append=append) + + def render_as_string(self, hide_password: bool = True) -> str: + """Render this `URL` object as a string.""" + s = self.drivername + "://" + if self.username is not None: + s += quote(self.username, safe=" +") + if self.password is not None: + s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) + s += "@" + if self.host is not None: + if ":" in self.host: + s += f"[{self.host}]" else: - new_query = ImmutableDict( - { - **self.query, - **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, - } - ) - return self.set(query=new_query) - - def update_query_dict( - self, - query_parameters: Mapping[str, Union[str, List[str]]], - append: bool = False, - ) -> "URL": - return self.update_query_pairs(query_parameters.items(), append=append) - - def render_as_string(self, hide_password: bool = True) -> str: - """Render this `URL` object as a string.""" - s = self.drivername + "://" - if self.username is not None: - s += quote(self.username, safe=" +") - if self.password is not None: - s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) - s += "@" - if self.host is not None: - if ":" in self.host: - s += f"[{self.host}]" - else: - s += self.host - if self.port is not None: - s += ":" + str(self.port) - if self.database is not None: - s += "/" + self.database - if self.query: - keys = to_list(self.query) - keys.sort() - s += "?" + "&".join( - f"{quote_plus(k)}={quote_plus(element)}" - for k in keys - for element in to_list(self.query[k]) - ) - return s - - def __repr__(self) -> str: - return self.render_as_string() - - def __copy__(self) -> "URL": - return self.__class__.create( - self.drivername, - self.username, - self.password, - self.host, - self.port, - self.database, - self.query.copy(), + s += self.host + if self.port is not None: + s += ":" + str(self.port) + if self.database is not None: + s += "/" + self.database + if self.query: + keys = to_list(self.query) + keys.sort() + s += "?" 
+ "&".join( + f"{quote_plus(k)}={quote_plus(element)}" + for k in keys + for element in to_list(self.query[k]) ) + return s + + def __repr__(self) -> str: + return self.render_as_string() + + def __copy__(self) -> "URL": + return self.__class__.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + self.query.copy(), + ) - def __deepcopy__(self, memo: Any) -> "URL": - return self.__copy__() - - def __hash__(self) -> int: - return hash(str(self)) - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, URL) - and self.drivername == other.drivername - and self.username == other.username - and self.password == other.password - and self.host == other.host - and self.database == other.database - and self.query == other.query - and self.port == other.port - ) + def __deepcopy__(self, memo: Any) -> "URL": + return self.__copy__() + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + and self.port == other.port + ) - def __ne__(self, other: Any) -> bool: - return not self == other + def __ne__(self, other: Any) -> bool: + return not self == other - def get_backend_name(self) -> str: - """Return the backend name. + def get_backend_name(self) -> str: + """Return the backend name. - This is the name that corresponds to the database backend in - use, and is the portion of the `drivername` - that is to the left of the plus sign. + This is the name that corresponds to the database backend in + use, and is the portion of the `drivername` + that is to the left of the plus sign. - """ - if "+" not in self.drivername: - return self.drivername - else: - return self.drivername.split("+")[0] + """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[0] + + def get_driver_name(self) -> str: + """Return the backend name. - def get_driver_name(self) -> str: - """Return the backend name. + This is the name that corresponds to the DBAPI driver in + use, and is the portion of the `drivername` + that is to the right of the plus sign. + """ - This is the name that corresponds to the DBAPI driver in - use, and is the portion of the `drivername` - that is to the right of the plus sign. - """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[1] - if "+" not in self.drivername: - return self.drivername - else: - return self.drivername.split("+")[1] - def make_url(name_or_url: Union[str, URL]) -> URL: - """Given a string, produce a new URL instance. +def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. - The format of the URL generally follows `RFC-1738`, with some exceptions, including - that underscores, and not dashes or periods, are accepted within the - "scheme" portion. + The format of the URL generally follows `RFC-1738`, with some exceptions, including + that underscores, and not dashes or periods, are accepted within the + "scheme" portion. 
- If a `URL` object is passed, it is returned as is.""" + If a `URL` object is passed, it is returned as is.""" - if isinstance(name_or_url, str): - return _parse_url(name_or_url) - elif not isinstance(name_or_url, URL): - raise ValueError(f"Expected string or URL object, got {name_or_url!r}") - else: - return name_or_url + if isinstance(name_or_url, str): + return _parse_url(name_or_url) + elif not isinstance(name_or_url, URL): + raise ValueError(f"Expected string or URL object, got {name_or_url!r}") + else: + return name_or_url - def _parse_url(name: str) -> URL: - pattern = re.compile( - r""" - (?P[\w\+]+):// - (?: - (?P[^:/]*) - (?::(?P[^@]*))? - @)? + +def _parse_url(name: str) -> URL: + pattern = re.compile( + r""" + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P[^@]*))? + @)? + (?: (?: - (?: - \[(?P[^/\?]+)\] | - (?P[^/:\?]+) - )? - (?::(?P[^/\?]*))? + \[(?P[^/\?]+)\] | + (?P[^/:\?]+) )? - (?:/(?P[^\?]*))? - (?:\?(?P.*))? - """, - re.X, - ) - - m = pattern.match(name) - if m is not None: - components = m.groupdict() - query: Optional[Dict[str, Union[str, List[str]]]] - if components["query"] is not None: - query = {} - - for key, value in parse_qsl(components["query"]): - if key in query: - query[key] = to_list(query[key]) - cast("List[str]", query[key]).append(value) - else: - query[key] = value - else: - query = None + (?::(?P[^/\?]*))? + )? + (?:/(?P[^\?]*))? + (?:\?(?P.*))? + """, + re.X, + ) - components["query"] = query - if components["username"] is not None: - components["username"] = unquote(components["username"]) + m = pattern.match(name) + if m is not None: + components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] + if components["query"] is not None: + query = {} + + for key, value in parse_qsl(components["query"]): + if key in query: + query[key] = to_list(query[key]) + cast("List[str]", query[key]).append(value) + else: + query[key] = value + else: + query = None - if components["password"] is not None: - components["password"] = unquote(components["password"]) + components["query"] = query + if components["username"] is not None: + components["username"] = unquote(components["username"]) - ipv4host = components.pop("ipv4host") - ipv6host = components.pop("ipv6host") - components["host"] = ipv4host or ipv6host - name = components.pop("name") + if components["password"] is not None: + components["password"] = unquote(components["password"]) - if components["port"]: - components["port"] = int(components["port"]) + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") - return URL.create(name, **components) # type: ignore + if components["port"]: + components["port"] = int(components["port"]) - else: - raise ValueError("Could not parse SQLAlchemy URL from string '%s'" % name) + return URL.create(name, **components) # type: ignore -else: - from sqlalchemy.engine import URL, make_url # type: ignore[assignment] + else: + raise ValueError("Could not parse SQLAlchemy URL from string '%s'" % name) diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 8a07ddbd33..e2727153ad 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -4,7 +4,7 @@ import datetime # noqa: 251 import humanize import contextlib - +import threading from typing import ( Any, Callable, @@ -30,11 +30,14 @@ from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated from dlt.common.configuration.specs import ContainerInjectableContext from 
dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.configuration.specs import RunConfiguration from dlt.common.destination import TDestinationReferenceArg, TDestination from dlt.common.destination.exceptions import DestinationHasFailedJobs -from dlt.common.exceptions import PipelineStateNotAvailable, SourceSectionNotAvailable +from dlt.common.exceptions import ( + PipelineStateNotAvailable, + SourceSectionNotAvailable, + ResourceNameNotAvailable, +) from dlt.common.metrics import ( DataWriterMetrics, ExtractDataInfo, @@ -50,7 +53,6 @@ TWriteDispositionConfig, TSchemaContract, ) -from dlt.common.source import get_current_pipe_name from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.storages.load_storage import LoadPackageInfo from dlt.common.time import ensure_pendulum_datetime, precise_time @@ -546,9 +548,7 @@ def __call__( @configspec class PipelineContext(ContainerInjectableContext): - _deferred_pipeline: Callable[[], SupportsPipeline] = dataclasses.field( - default=None, init=False, repr=False, compare=False - ) + _DEFERRED_PIPELINE: ClassVar[Callable[[], SupportsPipeline]] = None _pipeline: SupportsPipeline = dataclasses.field( default=None, init=False, repr=False, compare=False ) @@ -559,11 +559,11 @@ def pipeline(self) -> SupportsPipeline: """Creates or returns exiting pipeline""" if not self._pipeline: # delayed pipeline creation - assert self._deferred_pipeline is not None, ( + assert PipelineContext._DEFERRED_PIPELINE is not None, ( "Deferred pipeline creation function not provided to PipelineContext. Are you" " calling dlt.pipeline() from another thread?" ) - self.activate(self._deferred_pipeline()) + self.activate(PipelineContext._DEFERRED_PIPELINE()) return self._pipeline def activate(self, pipeline: SupportsPipeline) -> None: @@ -582,9 +582,10 @@ def deactivate(self) -> None: self._pipeline._set_context(False) self._pipeline = None - def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> None: + @classmethod + def cls__init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> None: """Initialize the context with a function returning the Pipeline object to allow creation on first use""" - self._deferred_pipeline = deferred_pipeline + self._DEFERRED_PIPELINE = deferred_pipeline def current_pipeline() -> SupportsPipeline: @@ -781,9 +782,38 @@ def get_dlt_pipelines_dir() -> str: 2. if current user is root in /var/dlt/pipelines 3. 
if current user does not have a home directory in /tmp/dlt/pipelines """ - return os.path.join(get_dlt_data_dir(), "pipelines") + from dlt.common.runtime import run_context + + return run_context.current().get_data_entity("pipelines") def get_dlt_repos_dir() -> str: """Gets default directory where command repositories will be stored""" - return os.path.join(get_dlt_data_dir(), "repos") + from dlt.common.runtime import run_context + + return run_context.current().get_data_entity("repos") + + +_CURRENT_PIPE_NAME: Dict[int, str] = {} +"""Name of currently executing pipe per thread id set during execution of a gen in pipe""" + + +def set_current_pipe_name(name: str) -> None: + """Set pipe name in current thread""" + _CURRENT_PIPE_NAME[threading.get_ident()] = name + + +def unset_current_pipe_name() -> None: + """Unset pipe name in current thread""" + _CURRENT_PIPE_NAME[threading.get_ident()] = None + + +def get_current_pipe_name() -> str: + """When executed from withing dlt.resource decorated function, gets pipe name associated with current thread. + + Pipe name is the same as resource name for all currently known cases. In some multithreading cases, pipe name may be not available. + """ + name = _CURRENT_PIPE_NAME.get(threading.get_ident()) + if name is None: + raise ResourceNameNotAvailable() + return name diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py index db791c60cd..00ed6e6727 100644 --- a/dlt/common/reflection/spec.py +++ b/dlt/common/reflection/spec.py @@ -1,9 +1,17 @@ import re import inspect -from typing import Dict, List, Tuple, Type, Any, Optional, NewType +from typing import Dict, Tuple, Type, Any, Optional from inspect import Signature, Parameter -from dlt.common.typing import AnyType, AnyFun, ConfigValueSentinel, NoneType, TSecretValue +from dlt.common.typing import ( + AnyType, + AnyFun, + ConfigValueSentinel, + NoneType, + TSecretValue, + Annotated, + SecretSentinel, +) from dlt.common.configuration import configspec, is_valid_hint, is_secret_hint from dlt.common.configuration.specs import BaseConfiguration from dlt.common.utils import get_callable_name @@ -87,7 +95,7 @@ def spec_from_signature( field_type = TSecretValue else: # generate typed SecretValue - field_type = NewType("TSecretValue", field_type) # type: ignore + field_type = Annotated[field_type, SecretSentinel] # remove sentinel from default p = p.replace(default=None) elif field_type is AnyType: diff --git a/dlt/common/runners/venv.py b/dlt/common/runners/venv.py index 9a92b30326..b59456dcc2 100644 --- a/dlt/common/runners/venv.py +++ b/dlt/common/runners/venv.py @@ -120,6 +120,7 @@ def add_dependencies(self, dependencies: List[str] = None) -> None: @staticmethod def _install_deps(context: types.SimpleNamespace, dependencies: List[str]) -> None: cmd = [context.env_exe, "-Im", "pip", "install"] + # cmd = ["uv", "pip", "install", "--python", context.env_exe] try: subprocess.check_output(cmd + dependencies, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as exc: diff --git a/dlt/common/runtime/anon_tracker.py b/dlt/common/runtime/anon_tracker.py index 2e45daa65b..6c881fb36c 100644 --- a/dlt/common/runtime/anon_tracker.py +++ b/dlt/common/runtime/anon_tracker.py @@ -9,8 +9,8 @@ from dlt.common import logger from dlt.common.managed_thread_pool import ManagedThreadPool from dlt.common.configuration.specs import RunConfiguration -from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.runtime.exec_info import get_execution_context, TExecutionContext +from 
dlt.common.runtime import run_context from dlt.common.typing import DictStrAny, StrAny from dlt.common.utils import uniq_id @@ -113,7 +113,8 @@ def _tracker_request_header(write_key: str) -> StrAny: def get_anonymous_id() -> str: """Creates or reads a anonymous user id""" - home_dir = get_dlt_data_dir() + home_dir = run_context.current().global_dir + if not os.path.isdir(home_dir): os.makedirs(home_dir, exist_ok=True) anonymous_id_file = os.path.join(home_dir, ".anonymous_id") diff --git a/dlt/common/runtime/run_context.py b/dlt/common/runtime/run_context.py new file mode 100644 index 0000000000..f8e7920577 --- /dev/null +++ b/dlt/common/runtime/run_context.py @@ -0,0 +1,90 @@ +import os +import tempfile +from typing import ClassVar + +from dlt.common import known_env +from dlt.common.configuration import plugins +from dlt.common.configuration.container import Container +from dlt.common.configuration.specs.pluggable_run_context import ( + SupportsRunContext, + PluggableRunContext, +) + +# dlt settings folder +DOT_DLT = os.environ.get(known_env.DLT_CONFIG_FOLDER, ".dlt") + + +class RunContext(SupportsRunContext): + """A default run context used by dlt""" + + CONTEXT_NAME: ClassVar[str] = "dlt" + + @property + def global_dir(self) -> str: + return self.data_dir + + @property + def run_dir(self) -> str: + """The default run dir is the current working directory but may be overridden by DLT_PROJECT_DIR env variable.""" + return os.environ.get(known_env.DLT_PROJECT_DIR, ".") + + @property + def settings_dir(self) -> str: + """Returns a path to dlt settings directory. If not overridden it resides in current working directory + + The name of the setting folder is '.dlt'. The path is current working directory '.' but may be overridden by DLT_PROJECT_DIR env variable. + """ + return os.path.join(self.run_dir, DOT_DLT) + + @property + def data_dir(self) -> str: + """Gets default directory where pipelines' data (working directories) will be stored + 1. if DLT_DATA_DIR is set in env then it is used + 2. in user home directory: ~/.dlt/ + 3. if current user is root: in /var/dlt/ + 4. 
if current user does not have a home directory: in /tmp/dlt/ + """ + if known_env.DLT_DATA_DIR in os.environ: + return os.environ[known_env.DLT_DATA_DIR] + + # geteuid not available on Windows + if hasattr(os, "geteuid") and os.geteuid() == 0: + # we are root so use standard /var + return os.path.join("/var", "dlt") + + home = os.path.expanduser("~") + if home is None: + # no home dir - use temp + return os.path.join(tempfile.gettempdir(), "dlt") + else: + # if home directory is available use ~/.dlt/pipelines + return os.path.join(home, DOT_DLT) + + def get_data_entity(self, entity: str) -> str: + return os.path.join(self.data_dir, entity) + + def get_run_entity(self, entity: str) -> str: + """Default run context assumes that entities are defined in root dir""" + return self.run_dir + + def get_setting(self, setting_path: str) -> str: + return os.path.join(self.settings_dir, setting_path) + + @property + def name(self) -> str: + return self.__class__.CONTEXT_NAME + + +@plugins.hookspec(firstresult=True) +def plug_run_context() -> SupportsRunContext: + """Spec for plugin hook that returns current run context.""" + + +@plugins.hookimpl(specname="plug_run_context") +def plug_run_context_impl() -> SupportsRunContext: + return RunContext() + + +def current() -> SupportsRunContext: + """Returns currently active run context""" + return Container()[PluggableRunContext].context diff --git a/dlt/common/source.py b/dlt/common/source.py deleted file mode 100644 index ea2a25f1d7..0000000000 --- a/dlt/common/source.py +++ /dev/null @@ -1,51 +0,0 @@ -import threading -from types import ModuleType -from typing import Dict, NamedTuple, Optional, Type - -from dlt.common.configuration.specs import BaseConfiguration -from dlt.common.exceptions import ResourceNameNotAvailable -from dlt.common.typing import AnyFun -from dlt.common.utils import get_callable_name - - -class SourceInfo(NamedTuple): - """Runtime information on the source/resource""" - - SPEC: Type[BaseConfiguration] - f: AnyFun - module: ModuleType - - -_SOURCES: Dict[str, SourceInfo] = {} -"""A registry of all the decorated sources and resources discovered when importing modules""" - -_CURRENT_PIPE_NAME: Dict[int, str] = {} -"""Name of currently executing pipe per thread id set during execution of a gen in pipe""" - - -def set_current_pipe_name(name: str) -> None: - """Set pipe name in current thread""" - _CURRENT_PIPE_NAME[threading.get_ident()] = name - - -def unset_current_pipe_name() -> None: - """Unset pipe name in current thread""" - _CURRENT_PIPE_NAME[threading.get_ident()] = None - - -def get_current_pipe_name() -> str: - """When executed from withing dlt.resource decorated function, gets pipe name associated with current thread. - - Pipe name is the same as resource name for all currently known cases. In some multithreading cases, pipe name may be not available. 
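# Aside (sketch of intended usage, not part of this diff): path lookups that previously went
# through dlt.common.configuration.paths now go through the active run context.
from dlt.common.runtime import run_context

ctx = run_context.current()                       # resolved via the plug_run_context hook
secrets_path = ctx.get_setting("secrets.toml")    # <settings_dir>/secrets.toml
pipelines_dir = ctx.get_data_entity("pipelines")  # <data_dir>/pipelines, cf. get_dlt_pipelines_dir()
print(ctx.name, ctx.run_dir, secrets_path, pipelines_dir)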
- """ - name = _CURRENT_PIPE_NAME.get(threading.get_ident()) - if name is None: - raise ResourceNameNotAvailable() - return name - - -def _get_source_for_inner_function(f: AnyFun) -> Optional[SourceInfo]: - # find source function - parts = get_callable_name(f, "__qualname__").split(".") - parent_fun = ".".join(parts[:-2]) - return _SOURCES.get(parent_fun) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 8d18d84400..94edb57194 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -79,7 +79,10 @@ REPattern = _REPattern PathLike = os.PathLike + AnyType: TypeAlias = Any +CallableAny = NewType("CallableAny", Any) # type: ignore[valid-newtype] +"""A special callable Any that returns argument but is recognized as Any type by dlt hint checkers""" NoneType = type(None) DictStrAny: TypeAlias = Dict[str, Any] DictStrStr: TypeAlias = Dict[str, str] @@ -95,8 +98,20 @@ TAnyClass = TypeVar("TAnyClass", bound=object) TimedeltaSeconds = Union[int, float, timedelta] # represent secret value ie. coming from Kubernetes/Docker secrets or other providers -TSecretValue = NewType("TSecretValue", Any) # type: ignore -TSecretStrValue = NewType("TSecretValue", str) # type: ignore + + +class SecretSentinel: + """Marks a secret type when part of type annotations""" + + +if TYPE_CHECKING: + TSecretValue = Annotated[Any, SecretSentinel] +else: + # use callable Any type for backward compatibility at runtime + TSecretValue = Annotated[CallableAny, SecretSentinel] + +TSecretStrValue = Annotated[str, SecretSentinel] + TDataItem: TypeAlias = Any """A single data item as extracted from data source""" TDataItems: TypeAlias = Union[TDataItem, List[TDataItem]] @@ -184,8 +199,9 @@ def is_callable_type(hint: Type[Any]) -> bool: return False -def extract_type_if_modifier(t: Type[Any]) -> Optional[Type[Any]]: - if get_origin(t) in (Final, ClassVar, Annotated): +def extract_type_if_modifier(t: Type[Any], preserve_annotated: bool = False) -> Optional[Type[Any]]: + modifiers = (Final, ClassVar) if preserve_annotated else (Final, ClassVar, Annotated) + if get_origin(t) in modifiers: t = get_args(t)[0] if m_t := extract_type_if_modifier(t): return m_t @@ -219,6 +235,11 @@ def is_union_type(hint: Type[Any]) -> bool: return False +def is_any_type(t: Type[Any]) -> bool: + """Checks if `t` is one of recognized Any types""" + return t in (Any, CallableAny) + + def is_optional_type(t: Type[Any]) -> bool: origin = get_origin(t) is_union = origin is Union or origin is UnionType @@ -324,7 +345,10 @@ def is_dict_generic_type(t: Type[Any]) -> bool: def extract_inner_type( - hint: Type[Any], preserve_new_types: bool = False, preserve_literal: bool = False + hint: Type[Any], + preserve_new_types: bool = False, + preserve_literal: bool = False, + preserve_annotated: bool = False, ) -> Type[Any]: """Gets the inner type from Literal, Optional, Final and NewType @@ -335,17 +359,23 @@ def extract_inner_type( Returns: Type[Any]: Inner type if hint was Literal, Optional or NewType, otherwise hint """ - if maybe_modified := extract_type_if_modifier(hint): - return extract_inner_type(maybe_modified, preserve_new_types, preserve_literal) + if maybe_modified := extract_type_if_modifier(hint, preserve_annotated): + return extract_inner_type( + maybe_modified, preserve_new_types, preserve_literal, preserve_annotated + ) # make sure we deal with optional directly if is_union_type(hint) and is_optional_type(hint): - return extract_inner_type(get_args(hint)[0], preserve_new_types, preserve_literal) + return extract_inner_type( + 
get_args(hint)[0], preserve_new_types, preserve_literal, preserve_annotated + ) if is_literal_type(hint) and not preserve_literal: # assume that all literals are of the same type return type(get_args(hint)[0]) - if is_newtype_type(hint) and not preserve_new_types: + if hasattr(hint, "__supertype__") and not preserve_new_types: # descend into supertypes of NewType - return extract_inner_type(hint.__supertype__, preserve_new_types, preserve_literal) + return extract_inner_type( + hint.__supertype__, preserve_new_types, preserve_literal, preserve_annotated + ) return hint @@ -408,7 +438,7 @@ def get_generic_type_argument_from_instance( cls_ = bases_[0] if cls_: orig_param_type = get_args(cls_)[0] - if orig_param_type is Any and sample_value is not None: + if orig_param_type in (Any, CallableAny) and sample_value is not None: orig_param_type = type(sample_value) return orig_param_type # type: ignore diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 436e5504f7..be8b28fc6b 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -271,7 +271,6 @@ def update_dict_nested(dst: TDict, src: TDict, copy_src_dicts: bool = False) -> dst[key] = update_dict_nested({}, src_val, True) else: dst[key] = src_val - return dst diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py new file mode 100644 index 0000000000..a5584851e9 --- /dev/null +++ b/dlt/destinations/dataset.py @@ -0,0 +1,99 @@ +from typing import Any, Generator, AnyStr, Optional + +from contextlib import contextmanager +from dlt.common.destination.reference import ( + SupportsReadableRelation, + SupportsReadableDataset, +) + +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.destinations.sql_client import SqlClientBase +from dlt.common.schema import Schema + + +class ReadableDBAPIRelation(SupportsReadableRelation): + def __init__( + self, + *, + client: SqlClientBase[Any], + query: Any, + schema_columns: TTableSchemaColumns = None, + ) -> None: + """Create a lazy evaluated relation to for the dataset of a destination""" + self.client = client + self.schema_columns = schema_columns + self.query = query + + # wire protocol functions + self.df = self._wrap_func("df") # type: ignore + self.arrow = self._wrap_func("arrow") # type: ignore + self.fetchall = self._wrap_func("fetchall") # type: ignore + self.fetchmany = self._wrap_func("fetchmany") # type: ignore + self.fetchone = self._wrap_func("fetchone") # type: ignore + + self.iter_df = self._wrap_iter("iter_df") # type: ignore + self.iter_arrow = self._wrap_iter("iter_arrow") # type: ignore + self.iter_fetch = self._wrap_iter("iter_fetch") # type: ignore + + @contextmanager + def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]: + """Gets a DBApiCursor for the current relation""" + with self.client as client: + # this hacky code is needed for mssql to disable autocommit, read iterators + # will not work otherwise. 
in the future we should be able to create a readony + # client which will do this automatically + if hasattr(self.client, "_conn") and hasattr(self.client._conn, "autocommit"): + self.client._conn.autocommit = False + with client.execute_query(self.query) as cursor: + if self.schema_columns: + cursor.schema_columns = self.schema_columns + yield cursor + + def _wrap_iter(self, func_name: str) -> Any: + """wrap SupportsReadableRelation generators in cursor context""" + + def _wrap(*args: Any, **kwargs: Any) -> Any: + with self.cursor() as cursor: + yield from getattr(cursor, func_name)(*args, **kwargs) + + return _wrap + + def _wrap_func(self, func_name: str) -> Any: + """wrap SupportsReadableRelation functions in cursor context""" + + def _wrap(*args: Any, **kwargs: Any) -> Any: + with self.cursor() as cursor: + return getattr(cursor, func_name)(*args, **kwargs) + + return _wrap + + +class ReadableDBAPIDataset(SupportsReadableDataset): + """Access to dataframes and arrowtables in the destination dataset via dbapi""" + + def __init__(self, client: SqlClientBase[Any], schema: Optional[Schema]) -> None: + self.client = client + self.schema = schema + + def __call__( + self, query: Any, schema_columns: TTableSchemaColumns = None + ) -> ReadableDBAPIRelation: + schema_columns = schema_columns or {} + return ReadableDBAPIRelation(client=self.client, query=query, schema_columns=schema_columns) # type: ignore[abstract] + + def table(self, table_name: str) -> SupportsReadableRelation: + # prepare query for table relation + schema_columns = ( + self.schema.tables.get(table_name, {}).get("columns", {}) if self.schema else {} + ) + table_name = self.client.make_qualified_table_name(table_name) + query = f"SELECT * FROM {table_name}" + return self(query, schema_columns) + + def __getitem__(self, table_name: str) -> SupportsReadableRelation: + """access of table via dict notation""" + return self.table(table_name) + + def __getattr__(self, table_name: str) -> SupportsReadableRelation: + """access of table via property notation""" + return self.table(table_name) diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index c398086fc0..c4110035b9 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -50,7 +50,7 @@ def destination( Here all incoming data will be sent to the destination function with the items in the requested format and the dlt table schema. The config and secret values will be resolved from the path destination.my_destination.api_url and destination.my_destination.api_secret. - #### Args: + Args: batch_size: defines how many items per function call are batched together and sent as an array. If you set a batch-size of 0, instead of passing in actual dataitems, you will receive one call per load job with the path of the file as the items argument. You can then open and process that file in any way you like. loader_file_format: defines in which format files are stored in the load package before being sent to the destination function, this can be puae-jsonl or parquet. 
name: defines the name of the destination that get's created by the destination decorator, defaults to the name of the function diff --git a/dlt/destinations/fs_client.py b/dlt/destinations/fs_client.py index 14e77b6b4e..ab4c91544a 100644 --- a/dlt/destinations/fs_client.py +++ b/dlt/destinations/fs_client.py @@ -1,5 +1,6 @@ -import gzip from typing import Iterable, cast, Any, List + +import gzip from abc import ABC, abstractmethod from fsspec import AbstractFileSystem diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 72611a9568..a2e2566a76 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -56,7 +56,7 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApiCursor +from dlt.common.destination.reference import DBApiCursor from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index c56742f1ff..650db1d8b9 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -23,7 +23,8 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.typing import DBApi, DBTransaction, DataFrame, ArrowTable +from dlt.common.destination.reference import DBApiCursor # terminal reasons as returned in BQ gRPC error response @@ -44,32 +45,15 @@ class BigQueryDBApiCursorImpl(DBApiCursorImpl): """Use native BigQuery data frame support if available""" native_cursor: BQDbApiCursor # type: ignore - df_iterator: Generator[Any, None, None] def __init__(self, curr: DBApiCursor) -> None: super().__init__(curr) - self.df_iterator = None - def df(self, chunk_size: Optional[int] = None, **kwargs: Any) -> DataFrame: - query_job: bigquery.QueryJob = getattr( - self.native_cursor, "_query_job", self.native_cursor.query_job - ) - if self.df_iterator: - return next(self.df_iterator, None) - try: - if chunk_size is not None: - # create iterator with given page size - self.df_iterator = query_job.result(page_size=chunk_size).to_dataframe_iterable() - return next(self.df_iterator, None) - return query_job.to_dataframe(**kwargs) - except ValueError as ex: - # no pyarrow/db-types, fallback to our implementation - logger.warning(f"Native BigQuery pandas reader could not be used: {str(ex)}") - return super().df(chunk_size=chunk_size) - - def close(self) -> None: - if self.df_iterator: - self.df_iterator.close() + def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: + yield from self.native_cursor.query_job.result(page_size=chunk_size).to_dataframe_iterable() + + def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: + yield from self.native_cursor.query_job.result(page_size=chunk_size).to_arrow_iterable() class BigQuerySqlClient(SqlClientBase[bigquery.Client], DBTransaction): diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 54d37f8c08..fbf552d3b1 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -33,6 +33,7 @@ from dlt.destinations.job_impl import ReferenceFollowupJobRequest 
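# Illustrative sketch (not part of this patch): the cursor refactor above replaces the
# eager per-destination df() overrides with chunked iter_df()/iter_arrow() generators,
# while df()/arrow() just take the first yielded chunk. The toy class below mirrors that
# pattern without any dlt or BigQuery dependency; only `pyarrow` is assumed installed and
# the names ChunkedCursor/_batches are made up for the example.
from typing import Any, Generator, List, Optional

import pyarrow as pa


class ChunkedCursor:
    def __init__(self, batches: List[List[dict]]) -> None:
        # stand-in for rows already fetched from a DBAPI cursor
        self._batches = batches

    def iter_arrow(self, chunk_size: int) -> Generator[pa.Table, None, None]:
        # a real implementation would delegate to a native reader, e.g.
        # query_job.result(page_size=chunk_size).to_arrow_iterable() on BigQuery
        for batch in self._batches:
            yield pa.Table.from_pylist(batch)

    def iter_df(self, chunk_size: int) -> Generator[Any, None, None]:
        # default conversion goes via arrow, as in DBApiCursorImpl.iter_df
        for table in self.iter_arrow(chunk_size):
            yield table.to_pandas()

    def arrow(self, chunk_size: int = 2048) -> Optional[pa.Table]:
        # full/first-chunk read built on top of the generator
        return next(self.iter_arrow(chunk_size), None)


# usage: stream two chunks of two and one rows respectively
for table in ChunkedCursor([[{"id": 1}, {"id": 2}], [{"id": 3}]]).iter_arrow(2):
    print(table.num_rows)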
AZURE_BLOB_STORAGE_PROTOCOLS = ["az", "abfss", "abfs"] +SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + ["s3", "gs", "gcs"] class DatabricksLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -69,11 +70,12 @@ def run(self) -> None: bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - if bucket_scheme not in AZURE_BLOB_STORAGE_PROTOCOLS + ["s3"]: + if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS: raise LoadJobTerminalException( self._file_path, - f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" - " azure buckets are supported", + f"Databricks cannot load data from staging bucket {bucket_path}. Only s3, azure" + " and gcs buckets are supported. Please note that gcs buckets are supported" + " only via named credential", ) if self._job_client.config.is_staging_external_location: @@ -106,6 +108,12 @@ def run(self) -> None: bucket_path = self.ensure_databricks_abfss_url( bucket_path, staging_credentials.azure_storage_account_name ) + else: + raise LoadJobTerminalException( + self._file_path, + "You need to use Databricks named credential to use google storage." + " Passing explicit Google credentials is not supported by Databricks.", + ) if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: assert isinstance( @@ -125,7 +133,7 @@ def run(self) -> None: raise LoadJobTerminalException( self._file_path, "Cannot load from local file. Databricks does not support loading from local files." - " Configure staging with an s3 or azure storage bucket.", + " Configure staging with an s3, azure or google storage bucket.", ) # decide on source format, stage_file_path will either be a local file or a bucket path diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 8228fa06a4..88d47410d5 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -1,5 +1,17 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Tuple, Union, Dict +from typing import ( + Any, + AnyStr, + ClassVar, + Generator, + Iterator, + Optional, + Sequence, + List, + Tuple, + Union, + Dict, +) from databricks import sql as databricks_lib @@ -21,25 +33,30 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.typing import ArrowTable, DBApi, DBTransaction, DataFrame from dlt.destinations.impl.databricks.configuration import DatabricksCredentials +from dlt.common.destination.reference import DBApiCursor class DatabricksCursorImpl(DBApiCursorImpl): """Use native data frame support if available""" native_cursor: DatabricksSqlCursor # type: ignore[assignment] - vector_size: ClassVar[int] = 2048 + vector_size: ClassVar[int] = 2048 # vector size is 2048 - def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: + def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: if chunk_size is None: - return self.native_cursor.fetchall_arrow().to_pandas() - else: - df = self.native_cursor.fetchmany_arrow(chunk_size).to_pandas() - if df.shape[0] == 0: - return None - else: - return df + yield self.native_cursor.fetchall_arrow() + return + while True: + table = self.native_cursor.fetchmany_arrow(chunk_size) + if table.num_rows == 0: + return + yield table + + def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: + for table in 
self.iter_arrow(chunk_size=chunk_size): + yield table.to_pandas() class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): diff --git a/dlt/destinations/impl/dremio/configuration.py b/dlt/destinations/impl/dremio/configuration.py index d1893e76b7..0a95c2807c 100644 --- a/dlt/destinations/impl/dremio/configuration.py +++ b/dlt/destinations/impl/dremio/configuration.py @@ -4,7 +4,6 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 @@ -21,6 +20,8 @@ class DremioCredentials(ConnectionStringCredentials): __config_gen_annotations__: ClassVar[List[str]] = ["port"] def to_native_credentials(self) -> str: + from dlt.common.libs.sql_alchemy_compat import URL + return URL.create( drivername=self.drivername, host=self.host, port=self.port ).render_as_string(hide_password=False) diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index 7dee056da7..030009c74b 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -18,7 +18,8 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.typing import DBApi, DBTransaction, DataFrame +from dlt.common.destination.reference import DBApiCursor class DremioCursorImpl(DBApiCursorImpl): @@ -26,9 +27,14 @@ class DremioCursorImpl(DBApiCursorImpl): def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: if chunk_size is None: - return self.native_cursor.fetch_arrow_table().to_pandas() + return self.arrow(chunk_size=chunk_size).to_pandas() return super().df(chunk_size=chunk_size, **kwargs) + def arrow(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: + if chunk_size is None: + return self.native_cursor.fetch_arrow_table() + return super().arrow(chunk_size=chunk_size, **kwargs) + class DremioSqlClient(SqlClientBase[pydremio.DremioConnection]): dbapi: ClassVar[DBApi] = pydremio diff --git a/dlt/destinations/impl/duckdb/configuration.py b/dlt/destinations/impl/duckdb/configuration.py index ec58d66c8b..0f35770747 100644 --- a/dlt/destinations/impl/duckdb/configuration.py +++ b/dlt/destinations/impl/duckdb/configuration.py @@ -10,7 +10,6 @@ from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.specs.exceptions import InvalidConnectionString from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -from dlt.common.typing import TSecretValue from dlt.destinations.impl.duckdb.exceptions import InvalidInMemoryDuckdbCredentials try: @@ -83,6 +82,11 @@ def parse_native_representation(self, native_value: Any) -> None: else: raise + @property + def has_open_connection(self) -> bool: + """Returns true if connection was not yet created or no connections were borrowed in case of external connection""" + return not hasattr(self, "_conn") or self._conn_borrows == 0 + def _get_conn_config(self) -> Dict[str, Any]: return {} @@ -90,7 +94,6 @@ def _conn_str(self) -> str: return self.database def _delete_conn(self) -> None: - # print("Closing conn because is owner") self._conn.close() delattr(self, "_conn") diff --git 
a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py index 80bbbedc9c..89a522c8f7 100644 --- a/dlt/destinations/impl/duckdb/sql_client.py +++ b/dlt/destinations/impl/duckdb/sql_client.py @@ -1,7 +1,9 @@ import duckdb +import math + from contextlib import contextmanager -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, Generator from dlt.common.destination import DestinationCapabilitiesContext from dlt.destinations.exceptions import ( @@ -9,7 +11,7 @@ DatabaseTransientException, DatabaseUndefinedRelation, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.typing import DBApi, DBTransaction, DataFrame, ArrowTable from dlt.destinations.sql_client import ( SqlClientBase, DBApiCursorImpl, @@ -18,26 +20,42 @@ ) from dlt.destinations.impl.duckdb.configuration import DuckDbBaseCredentials +from dlt.common.destination.reference import DBApiCursor class DuckDBDBApiCursorImpl(DBApiCursorImpl): """Use native duckdb data frame support if available""" native_cursor: duckdb.DuckDBPyConnection # type: ignore - vector_size: ClassVar[int] = 2048 - - def df(self, chunk_size: int = None, **kwargs: Any) -> DataFrame: - if chunk_size is None: - return self.native_cursor.df(**kwargs) - else: - multiple = chunk_size // self.vector_size + ( - 0 if self.vector_size % chunk_size == 0 else 1 - ) - df = self.native_cursor.fetch_df_chunk(multiple, **kwargs) + vector_size: ClassVar[int] = 2048 # vector size is 2048 + + def _get_page_count(self, chunk_size: int) -> int: + """get the page count for vector size""" + if chunk_size < self.vector_size: + return 1 + return math.floor(chunk_size / self.vector_size) + + def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: + # full frame + if not chunk_size: + yield self.native_cursor.fetch_df() + return + # iterate + while True: + df = self.native_cursor.fetch_df_chunk(self._get_page_count(chunk_size)) if df.shape[0] == 0: - return None - else: - return df + break + yield df + + def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: + if not chunk_size: + yield self.native_cursor.fetch_arrow_table() + return + # iterate + try: + yield from self.native_cursor.fetch_record_batch(chunk_size) + except StopIteration: + pass class DuckDbSqlClient(SqlClientBase[duckdb.DuckDBPyConnection], DBTransaction): diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 3f2f793559..d6d9865a06 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -1,9 +1,23 @@ import posixpath import os import base64 - +from contextlib import contextmanager from types import TracebackType -from typing import Dict, List, Type, Iterable, Iterator, Optional, Tuple, Sequence, cast, Any +from typing import ( + ContextManager, + List, + Type, + Iterable, + Iterator, + Optional, + Tuple, + Sequence, + cast, + Generator, + Literal, + Any, + Dict, +) from fsspec import AbstractFileSystem from contextlib import contextmanager @@ -23,10 +37,12 @@ TPipelineStateDoc, load_package as current_load_package, ) +from dlt.destinations.sql_client import DBApiCursor, WithSqlClient, SqlClientBase from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJobRequest, PreparedTableSchema, + SupportsReadableRelation, TLoadJobState, 
RunnableLoadJob, JobClientBase, @@ -38,6 +54,7 @@ LoadJob, ) from dlt.common.destination.exceptions import DestinationUndefinedEntity + from dlt.destinations.job_impl import ( ReferenceFollowupJobRequest, FinalizedLoadJob, @@ -46,6 +63,7 @@ from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations import path_utils from dlt.destinations.fs_client import FSClientBase +from dlt.destinations.dataset import ReadableDBAPIDataset from dlt.destinations.utils import verify_schema_merge_disposition INIT_FILE_NAME = "init" @@ -209,7 +227,9 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe return jobs -class FilesystemClient(FSClientBase, JobClientBase, WithStagingDataset, WithStateSync): +class FilesystemClient( + FSClientBase, WithSqlClient, JobClientBase, WithStagingDataset, WithStateSync +): """filesystem client storing jobs in memory""" fs_client: AbstractFileSystem @@ -238,6 +258,21 @@ def __init__( # cannot be replaced and we cannot initialize folders consistently self.table_prefix_layout = path_utils.get_table_prefix_layout(config.layout) self.dataset_name = self.config.normalize_dataset_name(self.schema) + self._sql_client: SqlClientBase[Any] = None + + @property + def sql_client(self) -> SqlClientBase[Any]: + # we use an inner import here, since the sql client depends on duckdb and will + # only be used for read access on data, some users will not need the dependency + from dlt.destinations.impl.filesystem.sql_client import FilesystemSqlClient + + if not self._sql_client: + self._sql_client = FilesystemSqlClient(self) + return self._sql_client + + @sql_client.setter + def sql_client(self, client: SqlClientBase[Any]) -> None: + self._sql_client = client def drop_storage(self) -> None: if self.is_storage_initialized(): diff --git a/dlt/destinations/impl/filesystem/sql_client.py b/dlt/destinations/impl/filesystem/sql_client.py new file mode 100644 index 0000000000..87aa254e96 --- /dev/null +++ b/dlt/destinations/impl/filesystem/sql_client.py @@ -0,0 +1,279 @@ +from typing import Any, Iterator, AnyStr, List, cast, TYPE_CHECKING, Dict + +import os +import re + +import dlt + +import duckdb + +import sqlglot +import sqlglot.expressions as exp +from dlt.common import logger + +from contextlib import contextmanager + +from dlt.common.destination.reference import DBApiCursor +from dlt.common.destination.typing import PreparedTableSchema + +from dlt.destinations.sql_client import raise_database_error + +from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient +from dlt.destinations.impl.duckdb.factory import duckdb as duckdb_factory, DuckDbCredentials +from dlt.common.configuration.specs import ( + AwsCredentials, + AzureServicePrincipalCredentialsWithoutDefaults, + AzureCredentialsWithoutDefaults, +) + +SUPPORTED_PROTOCOLS = ["gs", "gcs", "s3", "file", "memory", "az", "abfss"] + +if TYPE_CHECKING: + from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +else: + FilesystemClient = Any + + +class FilesystemSqlClient(DuckDbSqlClient): + memory_db: duckdb.DuckDBPyConnection = None + """Internally created in-mem database in case external is not provided""" + + def __init__( + self, + fs_client: FilesystemClient, + dataset_name: str = None, + credentials: DuckDbCredentials = None, + ) -> None: + # if no credentials are passed from the outside + # we know to keep an in memory instance here + if not credentials: + self.memory_db = duckdb.connect(":memory:") + credentials = 
DuckDbCredentials(self.memory_db) + + super().__init__( + dataset_name=dataset_name or fs_client.dataset_name, + staging_dataset_name=None, + credentials=credentials, + capabilities=duckdb_factory()._raw_capabilities(), + ) + self.fs_client = fs_client + + if self.fs_client.config.protocol not in SUPPORTED_PROTOCOLS: + raise NotImplementedError( + f"Protocol {self.fs_client.config.protocol} currently not supported for" + f" FilesystemSqlClient. Supported protocols are {SUPPORTED_PROTOCOLS}." + ) + + def _create_default_secret_name(self) -> str: + regex = re.compile("[^a-zA-Z]") + escaped_bucket_name = regex.sub("", self.fs_client.config.bucket_url.lower()) + return f"secret_{escaped_bucket_name}" + + def drop_authentication(self, secret_name: str = None) -> None: + if not secret_name: + secret_name = self._create_default_secret_name() + self._conn.sql(f"DROP PERSISTENT SECRET IF EXISTS {secret_name}") + + def create_authentication(self, persistent: bool = False, secret_name: str = None) -> None: + if not secret_name: + secret_name = self._create_default_secret_name() + + persistent_stmt = "" + if persistent: + persistent_stmt = " PERSISTENT " + + # abfss buckets have an @ compontent + scope = self.fs_client.config.bucket_url + if "@" in scope: + scope = scope.split("@")[0] + + # add secrets required for creating views + if self.fs_client.config.protocol == "s3": + aws_creds = cast(AwsCredentials, self.fs_client.config.credentials) + endpoint = ( + aws_creds.endpoint_url.replace("https://", "") + if aws_creds.endpoint_url + else "s3.amazonaws.com" + ) + self._conn.sql(f""" + CREATE OR REPLACE {persistent_stmt} SECRET {secret_name} ( + TYPE S3, + KEY_ID '{aws_creds.aws_access_key_id}', + SECRET '{aws_creds.aws_secret_access_key}', + REGION '{aws_creds.region_name}', + ENDPOINT '{endpoint}', + SCOPE '{scope}' + );""") + + # azure with storage account creds + elif self.fs_client.config.protocol in ["az", "abfss"] and isinstance( + self.fs_client.config.credentials, AzureCredentialsWithoutDefaults + ): + azsa_creds = self.fs_client.config.credentials + self._conn.sql(f""" + CREATE OR REPLACE {persistent_stmt} SECRET {secret_name} ( + TYPE AZURE, + CONNECTION_STRING 'AccountName={azsa_creds.azure_storage_account_name};AccountKey={azsa_creds.azure_storage_account_key}', + SCOPE '{scope}' + );""") + + # azure with service principal creds + elif self.fs_client.config.protocol in ["az", "abfss"] and isinstance( + self.fs_client.config.credentials, AzureServicePrincipalCredentialsWithoutDefaults + ): + azsp_creds = self.fs_client.config.credentials + self._conn.sql(f""" + CREATE OR REPLACE {persistent_stmt} SECRET {secret_name} ( + TYPE AZURE, + PROVIDER SERVICE_PRINCIPAL, + TENANT_ID '{azsp_creds.azure_tenant_id}', + CLIENT_ID '{azsp_creds.azure_client_id}', + CLIENT_SECRET '{azsp_creds.azure_client_secret}', + ACCOUNT_NAME '{azsp_creds.azure_storage_account_name}', + SCOPE '{scope}' + );""") + elif persistent: + raise Exception( + "Cannot create persistent secret for filesystem protocol" + f" {self.fs_client.config.protocol}. If you are trying to use persistent secrets" + " with gs/gcs, please use the s3 compatibility layer." + ) + + # native google storage implementation is not supported.. + elif self.fs_client.config.protocol in ["gs", "gcs"]: + logger.warn( + "For gs/gcs access via duckdb please use the gs/gcs s3 compatibility layer. Falling" + " back to fsspec." 
+ ) + self._conn.register_filesystem(self.fs_client.fs_client) + + # for memory we also need to register filesystem + elif self.fs_client.config.protocol == "memory": + self._conn.register_filesystem(self.fs_client.fs_client) + + # the line below solves problems with certificate path lookup on linux + # see duckdb docs + if self.fs_client.config.protocol in ["az", "abfss"]: + self._conn.sql("SET azure_transport_option_type = 'curl';") + + def open_connection(self) -> duckdb.DuckDBPyConnection: + # we keep the in memory instance around, so if this prop is set, return it + first_connection = self.credentials.has_open_connection + super().open_connection() + + if first_connection: + # set up dataset + if not self.has_dataset(): + self.create_dataset() + self._conn.sql(f"USE {self.fully_qualified_dataset_name()}") + + # create authentication to data provider + self.create_authentication() + + return self._conn + + @raise_database_error + def create_views_for_tables(self, tables: Dict[str, str]) -> None: + """Add the required tables as views to the duckdb in memory instance""" + + # create all tables in duck instance + for table_name in tables.keys(): + view_name = tables[table_name] + + if table_name not in self.fs_client.schema.tables: + # unknown views will not be created + continue + + # only create view if it does not exist in the current schema yet + existing_tables = [tname[0] for tname in self._conn.execute("SHOW TABLES").fetchall()] + if view_name in existing_tables: + continue + + # discover file type + schema_table = cast(PreparedTableSchema, self.fs_client.schema.tables[table_name]) + folder = self.fs_client.get_table_dir(table_name) + files = self.fs_client.list_table_files(table_name) + first_file_type = os.path.splitext(files[0])[1][1:] + + # build files string + supports_wildcard_notation = self.fs_client.config.protocol != "abfss" + protocol = ( + "" if self.fs_client.is_local_filesystem else f"{self.fs_client.config.protocol}://" + ) + resolved_folder = f"{protocol}{folder}" + resolved_files_string = f"'{resolved_folder}/**/*.{first_file_type}'" + if not supports_wildcard_notation: + resolved_files_string = ",".join(map(lambda f: f"'{protocol}{f}'", files)) + + # build columns definition + type_mapper = self.capabilities.get_type_mapper() + columns = ",".join( + map( + lambda c: ( + f'{self.escape_column_name(c["name"])}:' + f' "{type_mapper.to_destination_type(c, schema_table)}"' + ), + self.fs_client.schema.tables[table_name]["columns"].values(), + ) + ) + + # discover wether compression is enabled + compression = ( + "" + if dlt.config.get("data_writer.disable_compression") + else ", compression = 'gzip'" + ) + + # dlt tables are never compressed for now... + if table_name in self.fs_client.schema.dlt_table_names(): + compression = "" + + # create from statement + from_statement = "" + if schema_table.get("table_format") == "delta": + from_statement = f"delta_scan('{resolved_folder}')" + elif first_file_type == "parquet": + from_statement = f"read_parquet([{resolved_files_string}])" + elif first_file_type == "jsonl": + from_statement = ( + f"read_json([{resolved_files_string}], columns = {{{columns}}}) {compression}" + ) + else: + raise NotImplementedError( + f"Unknown filetype {first_file_type} for table {table_name}. Currently only" + " jsonl and parquet files as well as delta tables are supported." 
+ ) + + # create table + view_name = self.make_qualified_table_name(view_name) + create_table_sql_base = f"CREATE VIEW {view_name} AS SELECT * FROM {from_statement}" + self._conn.execute(create_table_sql_base) + + @contextmanager + @raise_database_error + def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + # skip parametrized queries, we could also render them but currently user is not able to + # do parametrized queries via dataset interface + if not args and not kwargs: + # find all tables to preload + expression = sqlglot.parse_one(query, read="duckdb") # type: ignore + load_tables: Dict[str, str] = {} + for table in expression.find_all(exp.Table): + # sqlglot has tables without tables ie. schemas are tables + if not table.this: + continue + schema = table.db + # add only tables from the dataset schema + if not schema or schema.lower() == self.dataset_name.lower(): + load_tables[table.name] = table.name + + if load_tables: + self.create_views_for_tables(load_tables) + + with super().execute_query(query, *args, **kwargs) as cursor: + yield cursor + + def __del__(self) -> None: + if self.memory_db: + self.memory_db.close() + self.memory_db = None diff --git a/dlt/destinations/impl/motherduck/configuration.py b/dlt/destinations/impl/motherduck/configuration.py index 695cf3766d..d842a6ae69 100644 --- a/dlt/destinations/impl/motherduck/configuration.py +++ b/dlt/destinations/impl/motherduck/configuration.py @@ -6,7 +6,7 @@ from dlt.common.configuration import configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.destination.exceptions import DestinationTerminalException -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.utils import digest128 from dlt.destinations.impl.duckdb.configuration import DuckDbBaseCredentials @@ -21,7 +21,7 @@ class MotherDuckCredentials(DuckDbBaseCredentials): default="md", init=False, repr=False, compare=False ) username: str = "motherduck" - password: TSecretValue = None + password: TSecretStrValue = None database: str = "my_db" custom_user_agent: Optional[str] = MOTHERDUCK_USER_AGENT @@ -35,7 +35,7 @@ def _conn_str(self) -> str: def _token_to_password(self) -> None: # could be motherduck connection if self.query and "token" in self.query: - self.password = TSecretValue(self.query.pop("token")) + self.password = self.query.pop("token") def borrow_conn(self, read_only: bool) -> Any: from duckdb import HTTPException, InvalidInputException diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index a30b300343..c95f52a566 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,11 +1,10 @@ import dataclasses from typing import Final, ClassVar, Any, List, Dict -from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.utils import digest128 -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.exceptions import SystemConfigurationException from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -16,7 +15,7 @@ class MsSqlCredentials(ConnectionStringCredentials): drivername: Final[str] = dataclasses.field(default="mssql", init=False, repr=False, compare=False) # type: ignore database: str = 
None username: str = None - password: TSecretValue = None + password: TSecretStrValue = None host: str = None port: int = 1433 connect_timeout: int = 15 diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 9eabfcf392..27aebe07f2 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,6 +1,9 @@ from typing import Dict, Optional, Sequence, List, Any -from dlt.common.destination.reference import FollowupJobRequest, PreparedTableSchema +from dlt.common.destination.reference import ( + FollowupJobRequest, + PreparedTableSchema, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TColumnType diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index e1b51743f5..6ec2beb95e 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -13,7 +13,7 @@ DatabaseTransientException, DatabaseUndefinedRelation, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction +from dlt.destinations.typing import DBApi, DBTransaction from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, @@ -22,6 +22,7 @@ ) from dlt.destinations.impl.mssql.configuration import MsSqlCredentials +from dlt.common.destination.reference import DBApiCursor def handle_datetimeoffset(dto_value: bytes) -> datetime: diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 656d1b3ac1..14eb499f89 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -2,11 +2,10 @@ from typing import Dict, Final, ClassVar, Any, List, Optional from dlt.common.data_writers.configuration import CsvFormatConfiguration -from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.utils import digest128 -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -16,7 +15,7 @@ class PostgresCredentials(ConnectionStringCredentials): drivername: Final[str] = dataclasses.field(default="postgresql", init=False, repr=False, compare=False) # type: ignore database: str = None username: str = None - password: TSecretValue = None + password: TSecretStrValue = None host: str = None port: int = 5432 connect_timeout: int = 15 diff --git a/dlt/destinations/impl/postgres/sql_client.py b/dlt/destinations/impl/postgres/sql_client.py index d867248196..a97c8511f1 100644 --- a/dlt/destinations/impl/postgres/sql_client.py +++ b/dlt/destinations/impl/postgres/sql_client.py @@ -17,7 +17,8 @@ DatabaseTransientException, DatabaseUndefinedRelation, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction +from dlt.common.destination.reference import DBApiCursor +from dlt.destinations.typing import DBApi, DBTransaction from dlt.destinations.sql_client import ( DBApiCursorImpl, SqlClientBase, diff --git a/dlt/destinations/impl/redshift/configuration.py b/dlt/destinations/impl/redshift/configuration.py index 3b84c8663e..bab7e371cf 100644 --- a/dlt/destinations/impl/redshift/configuration.py +++ b/dlt/destinations/impl/redshift/configuration.py @@ -1,7 +1,7 @@ import dataclasses from typing 
import Final, Optional -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.configuration import configspec from dlt.common.utils import digest128 @@ -14,7 +14,7 @@ @configspec(init=False) class RedshiftCredentials(PostgresCredentials): port: int = 5439 - password: TSecretValue = None + password: TSecretStrValue = None username: str = None host: str = None diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index de8faa91a6..4a89a1564b 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -4,7 +4,6 @@ from dlt import version from dlt.common.data_writers.configuration import CsvFormatConfiguration -from dlt.common.libs.sql_alchemy_shims import URL from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs import ConnectionStringCredentials diff --git a/dlt/destinations/impl/snowflake/sql_client.py b/dlt/destinations/impl/snowflake/sql_client.py index 8d11c23363..e52c5424d3 100644 --- a/dlt/destinations/impl/snowflake/sql_client.py +++ b/dlt/destinations/impl/snowflake/sql_client.py @@ -16,8 +16,9 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.typing import DBApi, DBTransaction, DataFrame from dlt.destinations.impl.snowflake.configuration import SnowflakeCredentials +from dlt.common.destination.reference import DBApiCursor class SnowflakeCursorImpl(DBApiCursorImpl): diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 829fe8db82..a407e53d70 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -17,6 +17,7 @@ import sqlalchemy as sa from sqlalchemy.engine import Connection +from sqlalchemy.exc import ResourceClosedError from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import PreparedTableSchema @@ -27,11 +28,13 @@ LoadClientNotConnected, DatabaseException, ) -from dlt.destinations.typing import DBTransaction, DBApiCursor -from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl +from dlt.common.destination.reference import DBApiCursor +from dlt.destinations.typing import DBTransaction +from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials from dlt.destinations.impl.sqlalchemy.alter_table import MigrationMaker from dlt.common.typing import TFun +from dlt.destinations.sql_client import DBApiCursorImpl class SqlaTransactionWrapper(DBTransaction): @@ -77,8 +80,14 @@ def __init__(self, curr: sa.engine.CursorResult) -> None: self.fetchone = curr.fetchone # type: ignore[assignment] self.fetchmany = curr.fetchmany # type: ignore[assignment] + self.set_default_schema_columns() + def _get_columns(self) -> List[str]: - return list(self.native_cursor.keys()) # type: ignore[attr-defined] + try: + return list(self.native_cursor.keys()) # type: ignore[attr-defined] + except ResourceClosedError: + # this happens if now rows are returned + return [] # @property # def description(self) -> Any: @@ -238,16 +247,16 @@ def _sqlite_create_dataset(self, dataset_name: str) -> None: """Mimic multiple schemas in sqlite using ATTACH DATABASE to attach a new database 
file to the current connection. """ - if dataset_name == "main": - # main always exists - return if self._sqlite_is_memory_db(): new_db_fn = ":memory:" else: new_db_fn = self._sqlite_dataset_filename(dataset_name) - statement = "ATTACH DATABASE :fn AS :name" - self.execute_sql(statement, fn=new_db_fn, name=dataset_name) + if dataset_name != "main": # main is the current file, it is always attached + statement = "ATTACH DATABASE :fn AS :name" + self.execute_sql(statement, fn=new_db_fn, name=dataset_name) + # WAL mode is applied to all currently attached databases + self.execute_sql("PRAGMA journal_mode=WAL") self._sqlite_attached_datasets.add(dataset_name) def _sqlite_drop_dataset(self, dataset_name: str) -> None: diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py index 360dd89192..edd827ed00 100644 --- a/dlt/destinations/impl/sqlalchemy/factory.py +++ b/dlt/destinations/impl/sqlalchemy/factory.py @@ -1,5 +1,6 @@ import typing as t +from dlt.common import pendulum from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.destination.capabilities import DataTypeMapper from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE @@ -9,14 +10,7 @@ SqlalchemyCredentials, SqlalchemyClientConfiguration, ) - -SqlalchemyTypeMapper: t.Type[DataTypeMapper] - -try: - from dlt.destinations.impl.sqlalchemy.type_mapper import SqlalchemyTypeMapper -except ModuleNotFoundError: - # assign mock type mapper if no sqlalchemy - from dlt.common.destination.capabilities import UnsupportedTypeMapper as SqlalchemyTypeMapper +from dlt.common.data_writers.escape import format_datetime_literal if t.TYPE_CHECKING: # from dlt.destinations.impl.sqlalchemy.sqlalchemy_client import SqlalchemyJobClient @@ -24,10 +18,28 @@ from sqlalchemy.engine import Engine +def _format_mysql_datetime_literal( + v: pendulum.DateTime, precision: int = 6, no_tz: bool = False +) -> str: + # Format without timezone to prevent tz conversion in SELECT + return format_datetime_literal(v, precision, no_tz=True) + + class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]): spec = SqlalchemyClientConfiguration def _raw_capabilities(self) -> DestinationCapabilitiesContext: + # lazy import to avoid sqlalchemy dep + SqlalchemyTypeMapper: t.Type[DataTypeMapper] + + try: + from dlt.destinations.impl.sqlalchemy.type_mapper import SqlalchemyTypeMapper + except ModuleNotFoundError: + # assign mock type mapper if no sqlalchemy + from dlt.common.destination.capabilities import ( + UnsupportedTypeMapper as SqlalchemyTypeMapper, + ) + # https://www.sqlalchemyql.org/docs/current/limits.html caps = DestinationCapabilitiesContext.generic_capabilities() caps.preferred_loader_file_format = "typed-jsonl" @@ -50,6 +62,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_multiple_statements = False caps.type_mapper = SqlalchemyTypeMapper caps.supported_replace_strategies = ["truncate-and-insert", "insert-from-staging"] + caps.supported_merge_strategies = ["delete-insert", "scd2"] return caps @@ -67,6 +80,8 @@ def adjust_capabilities( caps.max_identifier_length = dialect.max_identifier_length caps.max_column_identifier_length = dialect.max_identifier_length caps.supports_native_boolean = dialect.supports_native_boolean + if dialect.name == "mysql": + caps.format_datetime_literal = _format_mysql_datetime_literal return caps diff --git a/dlt/destinations/impl/sqlalchemy/load_jobs.py 
b/dlt/destinations/impl/sqlalchemy/load_jobs.py index c8486dc0f0..3cfd6bd910 100644 --- a/dlt/destinations/impl/sqlalchemy/load_jobs.py +++ b/dlt/destinations/impl/sqlalchemy/load_jobs.py @@ -13,6 +13,7 @@ from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient +from dlt.destinations.impl.sqlalchemy.merge_job import SqlalchemyMergeFollowupJob if TYPE_CHECKING: from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient @@ -134,3 +135,11 @@ def generate_sql( statements.append(stmt) return statements + + +__all__ = [ + "SqlalchemyJsonLInsertJob", + "SqlalchemyParquetInsertJob", + "SqlalchemyStagingCopyJob", + "SqlalchemyMergeFollowupJob", +] diff --git a/dlt/destinations/impl/sqlalchemy/merge_job.py b/dlt/destinations/impl/sqlalchemy/merge_job.py new file mode 100644 index 0000000000..5360939ba0 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/merge_job.py @@ -0,0 +1,441 @@ +from typing import Sequence, Tuple, Optional, List, Union +import operator + +import sqlalchemy as sa + +from dlt.destinations.sql_jobs import SqlMergeFollowupJob +from dlt.common.destination.reference import PreparedTableSchema, DestinationCapabilitiesContext +from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + get_dedup_sort_tuple, + get_first_column_name_with_prop, + is_nested_table, + get_validity_column_names, + get_active_record_timestamp, +) +from dlt.common.time import ensure_pendulum_datetime +from dlt.common.storages.load_package import load_package as current_load_package + + +class SqlalchemyMergeFollowupJob(SqlMergeFollowupJob): + """Uses SQLAlchemy to generate merge SQL statements. + Result is equivalent to the SQL generated by `SqlMergeFollowupJob` + except for delete-insert we use concrete tables instead of temporary tables. 
+ """ + + @classmethod + def gen_merge_sql( + cls, + table_chain: Sequence[PreparedTableSchema], + sql_client: SqlalchemyClient, # type: ignore[override] + ) -> List[str]: + root_table = table_chain[0] + + root_table_obj = sql_client.get_existing_table(root_table["name"]) + staging_root_table_obj = root_table_obj.to_metadata( + sql_client.metadata, schema=sql_client.staging_dataset_name + ) + + primary_key_names = get_columns_names_with_prop(root_table, "primary_key") + merge_key_names = get_columns_names_with_prop(root_table, "merge_key") + + temp_metadata = sa.MetaData() + + append_fallback = (len(primary_key_names) + len(merge_key_names)) == 0 + + sqla_statements = [] + tables_to_drop: List[sa.Table] = ( + [] + ) # Keep track of temp tables to drop at the end of the job + + if not append_fallback: + key_clause = cls._generate_key_table_clauses( + primary_key_names, merge_key_names, root_table_obj, staging_root_table_obj + ) + + # Generate the delete statements + if len(table_chain) == 1 and not cls.requires_temp_table_for_delete(): + delete_statement = root_table_obj.delete().where( + sa.exists( + sa.select(sa.literal(1)) + .where(key_clause) + .select_from(staging_root_table_obj) + ) + ) + sqla_statements.append(delete_statement) + else: + row_key_col_name = cls._get_row_key_col(table_chain, sql_client, root_table) + row_key_col = root_table_obj.c[row_key_col_name] + # Use a real table cause sqlalchemy doesn't have TEMPORARY TABLE abstractions + delete_temp_table = sa.Table( + "delete_" + root_table_obj.name, + temp_metadata, + # Give this column a fixed name to be able to reference it later + sa.Column("_dlt_id", row_key_col.type), + schema=staging_root_table_obj.schema, + ) + tables_to_drop.append(delete_temp_table) + # Add the CREATE TABLE statement + sqla_statements.append(sa.sql.ddl.CreateTable(delete_temp_table)) + # Insert data into the "temporary" table + insert_statement = delete_temp_table.insert().from_select( + [row_key_col], + sa.select(row_key_col).where( + sa.exists( + sa.select(sa.literal(1)) + .where(key_clause) + .select_from(staging_root_table_obj) + ) + ), + ) + sqla_statements.append(insert_statement) + + for table in table_chain[1:]: + chain_table_obj = sql_client.get_existing_table(table["name"]) + root_key_name = cls._get_root_key_col(table_chain, sql_client, table) + root_key_col = chain_table_obj.c[root_key_name] + + delete_statement = chain_table_obj.delete().where( + root_key_col.in_(sa.select(delete_temp_table.c._dlt_id)) + ) + + sqla_statements.append(delete_statement) + + # Delete from root table + delete_statement = root_table_obj.delete().where( + row_key_col.in_(sa.select(delete_temp_table.c._dlt_id)) + ) + sqla_statements.append(delete_statement) + + hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( + root_table, + root_table_obj, + invert=True, + ) + + dedup_sort = get_dedup_sort_tuple(root_table) # column_name, 'asc' | 'desc' + + if len(table_chain) > 1 and (primary_key_names or hard_delete_col_name is not None): + condition_column_names = ( + None if hard_delete_col_name is None else [hard_delete_col_name] + ) + condition_columns = ( + [staging_root_table_obj.c[col_name] for col_name in condition_column_names] + if condition_column_names is not None + else [] + ) + + staging_row_key_col = staging_root_table_obj.c[row_key_col_name] + + # Create the insert "temporary" table (but use a concrete table) + insert_temp_table = sa.Table( + "insert_" + root_table_obj.name, + temp_metadata, + sa.Column(row_key_col_name, 
staging_row_key_col.type), + schema=staging_root_table_obj.schema, + ) + tables_to_drop.append(insert_temp_table) + create_insert_temp_table_statement = sa.sql.ddl.CreateTable(insert_temp_table) + sqla_statements.append(create_insert_temp_table_statement) + staging_primary_key_cols = [ + staging_root_table_obj.c[col_name] for col_name in primary_key_names + ] + + inner_cols = [staging_row_key_col] + + if primary_key_names: + if dedup_sort is not None: + order_by_col = staging_root_table_obj.c[dedup_sort[0]] + order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc + else: + order_by_col = sa.select(sa.literal(None)) + order_dir_func = sa.asc + if condition_columns: + inner_cols += condition_columns + + inner_select = sa.select( + sa.func.row_number() + .over( + partition_by=set(staging_primary_key_cols), + order_by=order_dir_func(order_by_col), + ) + .label("_dlt_dedup_rn"), + *inner_cols, + ).subquery() + + select_for_temp_insert = sa.select(inner_select.c[row_key_col_name]).where( + inner_select.c._dlt_dedup_rn == 1 + ) + hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( + root_table, + inner_select, + invert=True, + ) + + if not_delete_cond is not None: + select_for_temp_insert = select_for_temp_insert.where(not_delete_cond) + else: + hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( + root_table, + staging_root_table_obj, + invert=True, + ) + select_for_temp_insert = sa.select(staging_row_key_col).where(not_delete_cond) + + insert_into_temp_table = insert_temp_table.insert().from_select( + [row_key_col_name], select_for_temp_insert + ) + sqla_statements.append(insert_into_temp_table) + + # Insert from staging to dataset + for table in table_chain: + table_obj = sql_client.get_existing_table(table["name"]) + staging_table_obj = table_obj.to_metadata( + sql_client.metadata, schema=sql_client.staging_dataset_name + ) + select_sql = staging_table_obj.select() + + if (primary_key_names and len(table_chain) > 1) or ( + not primary_key_names + and is_nested_table(table) + and hard_delete_col_name is not None + ): + uniq_column_name = root_key_name if is_nested_table(table) else row_key_col_name + uniq_column = staging_table_obj.c[uniq_column_name] + select_sql = select_sql.where( + uniq_column.in_( + sa.select( + insert_temp_table.c[row_key_col_name].label(uniq_column_name) + ).subquery() + ) + ) + elif primary_key_names and len(table_chain) == 1: + staging_primary_key_cols = [ + staging_table_obj.c[col_name] for col_name in primary_key_names + ] + if dedup_sort is not None: + order_by_col = staging_table_obj.c[dedup_sort[0]] + order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc + else: + order_by_col = sa.select(sa.literal(None)) + order_dir_func = sa.asc + + inner_select = sa.select( + staging_table_obj, + sa.func.row_number() + .over( + partition_by=set(staging_primary_key_cols), + order_by=order_dir_func(order_by_col), + ) + .label("_dlt_dedup_rn"), + ).subquery() + + select_sql = sa.select( + *[c for c in inner_select.c if c.name != "_dlt_dedup_rn"] + ).where(inner_select.c._dlt_dedup_rn == 1) + + hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( + root_table, inner_select, invert=True + ) + + if hard_delete_col_name is not None: + select_sql = select_sql.where(not_delete_cond) + else: + hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( + root_table, staging_root_table_obj, invert=True + ) + + if hard_delete_col_name is not None: + select_sql = 
select_sql.where(not_delete_cond) + + insert_statement = table_obj.insert().from_select( + [col.name for col in table_obj.columns], select_sql + ) + sqla_statements.append(insert_statement) + + # Drop all "temp" tables at the end + for table_obj in tables_to_drop: + sqla_statements.append(sa.sql.ddl.DropTable(table_obj)) + + return [ + x + ";" if not x.endswith(";") else x + for x in ( + str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True})) + for stmt in sqla_statements + ) + ] + + @classmethod + def _get_hard_delete_col_and_cond( # type: ignore[override] + cls, + table: PreparedTableSchema, + table_obj: sa.Table, + invert: bool = False, + ) -> Tuple[Optional[str], Optional[sa.sql.elements.BinaryExpression]]: + col_name = get_first_column_name_with_prop(table, "hard_delete") + if col_name is None: + return None, None + col = table_obj.c[col_name] + if invert: + cond = col.is_(None) + else: + cond = col.isnot(None) + if table["columns"][col_name]["data_type"] == "bool": + if invert: + cond = sa.or_(cond, col.is_(False)) + else: + cond = col.is_(True) + return col_name, cond + + @classmethod + def _generate_key_table_clauses( + cls, + primary_keys: Sequence[str], + merge_keys: Sequence[str], + root_table_obj: sa.Table, + staging_root_table_obj: sa.Table, + ) -> sa.sql.ClauseElement: + # Returns an sqlalchemy or_ clause + clauses = [] + if primary_keys or merge_keys: + for key in primary_keys: + clauses.append( + sa.and_( + *[ + root_table_obj.c[key] == staging_root_table_obj.c[key] + for key in primary_keys + ] + ) + ) + for key in merge_keys: + clauses.append( + sa.and_( + *[ + root_table_obj.c[key] == staging_root_table_obj.c[key] + for key in merge_keys + ] + ) + ) + return sa.or_(*clauses) # type: ignore[no-any-return] + else: + return sa.true() # type: ignore[no-any-return] + + @classmethod + def _gen_concat_sqla( + cls, columns: Sequence[sa.Column] + ) -> Union[sa.sql.elements.BinaryExpression, sa.Column]: + # Use col1 + col2 + col3 ... 
to generate a dialect specific concat expression + result = columns[0] + if len(columns) == 1: + return result + # Cast because CONCAT is only generated for string columns + result = sa.cast(result, sa.String) + for col in columns[1:]: + result = operator.add(result, sa.cast(col, sa.String)) + return result + + @classmethod + def gen_scd2_sql( + cls, + table_chain: Sequence[PreparedTableSchema], + sql_client: SqlalchemyClient, # type: ignore[override] + ) -> List[str]: + sqla_statements = [] + root_table = table_chain[0] + root_table_obj = sql_client.get_existing_table(root_table["name"]) + staging_root_table_obj = root_table_obj.to_metadata( + sql_client.metadata, schema=sql_client.staging_dataset_name + ) + + from_, to = get_validity_column_names(root_table) + hash_ = get_first_column_name_with_prop(root_table, "x-row-version") + + caps = sql_client.capabilities + + format_datetime_literal = caps.format_datetime_literal + if format_datetime_literal is None: + format_datetime_literal = ( + DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal + ) + + boundary_ts = ensure_pendulum_datetime( + root_table.get("x-boundary-timestamp", current_load_package()["state"]["created_at"]) # type: ignore[arg-type] + ) + + boundary_literal = format_datetime_literal(boundary_ts, caps.timestamp_precision) + + active_record_timestamp = get_active_record_timestamp(root_table) + + update_statement = ( + root_table_obj.update() + .values({to: sa.text(boundary_literal)}) + .where(root_table_obj.c[hash_].notin_(sa.select(staging_root_table_obj.c[hash_]))) + ) + + if active_record_timestamp is None: + active_record_literal = None + root_is_active_clause = root_table_obj.c[to].is_(None) + else: + active_record_literal = format_datetime_literal( + active_record_timestamp, caps.timestamp_precision + ) + root_is_active_clause = root_table_obj.c[to] == sa.text(active_record_literal) + + update_statement = update_statement.where(root_is_active_clause) + + merge_keys = get_columns_names_with_prop(root_table, "merge_key") + if merge_keys: + root_merge_key_cols = [root_table_obj.c[key] for key in merge_keys] + staging_merge_key_cols = [staging_root_table_obj.c[key] for key in merge_keys] + + update_statement = update_statement.where( + cls._gen_concat_sqla(root_merge_key_cols).in_( + sa.select(cls._gen_concat_sqla(staging_merge_key_cols)) + ) + ) + + sqla_statements.append(update_statement) + + insert_statement = root_table_obj.insert().from_select( + [col.name for col in root_table_obj.columns], + sa.select( + sa.literal(boundary_literal.strip("'")).label(from_), + sa.literal( + active_record_literal.strip("'") if active_record_literal is not None else None + ).label(to), + *[c for c in staging_root_table_obj.columns if c.name not in [from_, to]], + ).where( + staging_root_table_obj.c[hash_].notin_( + sa.select(root_table_obj.c[hash_]).where(root_is_active_clause) + ) + ), + ) + sqla_statements.append(insert_statement) + + nested_tables = table_chain[1:] + for table in nested_tables: + row_key_column = cls._get_root_key_col(table_chain, sql_client, table) + + table_obj = sql_client.get_existing_table(table["name"]) + staging_table_obj = table_obj.to_metadata( + sql_client.metadata, schema=sql_client.staging_dataset_name + ) + + insert_statement = table_obj.insert().from_select( + [col.name for col in table_obj.columns], + staging_table_obj.select().where( + staging_table_obj.c[row_key_column].notin_( + sa.select(table_obj.c[row_key_column]) + ) + ), + ) + sqla_statements.append(insert_statement) 
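# Illustrative sketch (not part of this patch): the merge/scd2 follow-up jobs hand plain
# SQL strings to the loader, so each sqlalchemy statement is rendered with literal_binds,
# as in the `return` just below. A minimal standalone example of that rendering step,
# assuming only `sqlalchemy` is installed (table and column names here are made up):
import sqlalchemy as sa

metadata = sa.MetaData()
items = sa.Table("items", metadata, sa.Column("id", sa.Integer), sa.Column("name", sa.String))

engine = sa.create_engine("sqlite:///:memory:")
stmt = items.delete().where(items.c.id == 7)

# literal_binds inlines parameter values so the resulting string can be executed as-is
sql = str(stmt.compile(engine, compile_kwargs={"literal_binds": True}))
print(sql + ";")  # DELETE FROM items WHERE items.id = 7;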
+ + return [ + x + ";" if not x.endswith(";") else x + for x in ( + str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True})) + for stmt in sqla_statements + ) + ] diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index a2514a43e0..c5a6442d8a 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -18,7 +18,11 @@ from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables from dlt.common.schema.typing import TColumnType, TTableSchemaColumns -from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers +from dlt.common.schema.utils import ( + pipeline_state_table, + normalize_table_identifiers, + is_complete_column, +) from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration @@ -26,6 +30,7 @@ SqlalchemyJsonLInsertJob, SqlalchemyParquetInsertJob, SqlalchemyStagingCopyJob, + SqlalchemyMergeFollowupJob, ) @@ -65,6 +70,7 @@ def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table: *[ self._to_column_object(col, schema_table) for col in schema_table["columns"].values() + if is_complete_column(col) ], extend_existing=True, schema=self.sql_client.dataset_name, @@ -97,13 +103,10 @@ def _create_replace_followup_jobs( def _create_merge_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: + # Ensure all tables exist in metadata before generating sql job for table in table_chain: self._to_table_object(table) - return [ - SqlalchemyStagingCopyJob.from_table_chain( - table_chain, self.sql_client, {"replace": False} - ) - ] + return [SqlalchemyMergeFollowupJob.from_table_chain(table_chain, self.sql_client)] def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 0ddded98b6..0fca64d7ba 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -14,9 +14,11 @@ Type, Iterable, Iterator, + Generator, ) import zlib import re +from contextlib import contextmanager from dlt.common import pendulum, logger from dlt.common.json import json @@ -41,6 +43,7 @@ PreparedTableSchema, StateInfo, StorageSchemaInfo, + SupportsReadableDataset, WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, @@ -51,7 +54,9 @@ JobClientBase, HasFollowupJobs, CredentialsConfiguration, + SupportsReadableRelation, ) +from dlt.destinations.dataset import ReadableDBAPIDataset from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations.job_impl import ( @@ -59,7 +64,7 @@ ) from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlStagingCopyFollowupJob from dlt.destinations.typing import TNativeConn -from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.destinations.utils import ( get_pipeline_state_query_columns, info_schema_null_to_bool, @@ -123,7 +128,7 @@ def __init__( self._bucket_path = ReferenceFollowupJobRequest.resolve_reference(file_path) -class SqlJobClientBase(JobClientBase, WithStateSync): 
+class SqlJobClientBase(WithSqlClient, JobClientBase, WithStateSync): INFO_TABLES_QUERY_THRESHOLD: ClassVar[int] = 1000 """Fallback to querying all tables in the information schema if checking more than threshold""" @@ -153,6 +158,14 @@ def __init__( assert isinstance(config, DestinationClientDwhConfiguration) self.config: DestinationClientDwhConfiguration = config + @property + def sql_client(self) -> SqlClientBase[TNativeConn]: + return self._sql_client + + @sql_client.setter + def sql_client(self, client: SqlClientBase[TNativeConn]) -> None: + self._sql_client = client + def drop_storage(self) -> None: self.sql_client.drop_dataset() diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 96f18cea3d..51f3211f1b 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -16,18 +16,29 @@ Type, AnyStr, List, + Generator, TypedDict, + cast, ) from dlt.common.typing import TFun +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.utils import concat_strings_with_limit +from dlt.common.destination.reference import JobClientBase from dlt.destinations.exceptions import ( DestinationConnectionError, LoadClientNotConnected, ) -from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction +from dlt.destinations.typing import ( + DBApi, + TNativeConn, + DataFrame, + DBTransaction, + ArrowTable, +) +from dlt.common.destination.reference import DBApiCursor class TJobQueryTags(TypedDict): @@ -292,6 +303,20 @@ def _truncate_table_sql(self, qualified_table_name: str) -> str: return f"DELETE FROM {qualified_table_name} WHERE 1=1;" +class WithSqlClient(JobClientBase): + @property + @abstractmethod + def sql_client(self) -> SqlClientBase[TNativeConn]: ... + + def __enter__(self) -> "WithSqlClient": + return self + + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: + pass + + class DBApiCursorImpl(DBApiCursor): """A DBApi Cursor wrapper with dataframes reading functionality""" @@ -304,11 +329,20 @@ def __init__(self, curr: DBApiCursor) -> None: self.fetchmany = curr.fetchmany # type: ignore self.fetchone = curr.fetchone # type: ignore + self.set_default_schema_columns() + def __getattr__(self, name: str) -> Any: return getattr(self.native_cursor, name) def _get_columns(self) -> List[str]: - return [c[0] for c in self.native_cursor.description] + if self.native_cursor.description: + return [c[0] for c in self.native_cursor.description] + return [] + + def set_default_schema_columns(self) -> None: + self.schema_columns = cast( + TTableSchemaColumns, {c: {"name": c, "nullable": True} for c in self._get_columns()} + ) def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: """Fetches results as data frame in full or in specified chunks. @@ -316,18 +350,55 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: May use native pandas/arrow reader if available. Depending on the native implementation chunk size may vary. 
""" - from dlt.common.libs.pandas_sql import _wrap_result + try: + return next(self.iter_df(chunk_size=chunk_size)) + except StopIteration: + return None - columns = self._get_columns() - if chunk_size is None: - return _wrap_result(self.native_cursor.fetchall(), columns, **kwargs) - else: - df = _wrap_result(self.native_cursor.fetchmany(chunk_size), columns, **kwargs) - # if no rows return None - if df.shape[0] == 0: - return None - else: - return df + def arrow(self, chunk_size: int = None, **kwargs: Any) -> Optional[ArrowTable]: + """Fetches results as data frame in full or in specified chunks. + + May use native pandas/arrow reader if available. Depending on + the native implementation chunk size may vary. + """ + try: + return next(self.iter_arrow(chunk_size=chunk_size)) + except StopIteration: + return None + + def iter_fetch(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], Any, Any]: + while True: + if not (result := self.fetchmany(chunk_size)): + return + yield result + + def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: + """Default implementation converts arrow to df""" + from dlt.common.libs.pandas import pandas as pd + + for table in self.iter_arrow(chunk_size=chunk_size): + # NOTE: we go via arrow table, types are created for arrow is columns are known + # https://github.com/apache/arrow/issues/38644 for reference on types_mapper + yield table.to_pandas() + + def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: + """Default implementation converts query result to arrow table""" + from dlt.common.libs.pyarrow import row_tuples_to_arrow + from dlt.common.configuration.container import Container + + # get capabilities of possibly currently active pipeline + caps = ( + Container().get(DestinationCapabilitiesContext) + or DestinationCapabilitiesContext.generic_capabilities() + ) + + if not chunk_size: + result = self.fetchall() + yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC") + return + + for result in self.iter_fetch(chunk_size=chunk_size): + yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC") def raise_database_error(f: TFun) -> TFun: diff --git a/dlt/destinations/typing.py b/dlt/destinations/typing.py index 99ffed01fd..c809bf3230 100644 --- a/dlt/destinations/typing.py +++ b/dlt/destinations/typing.py @@ -1,17 +1,22 @@ -from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar +from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar, Generator + + +# native connection +TNativeConn = TypeVar("TNativeConn", bound=Any) try: from pandas import DataFrame except ImportError: DataFrame: Type[Any] = None # type: ignore -# native connection -TNativeConn = TypeVar("TNativeConn", bound=Any) +try: + from pyarrow import Table as ArrowTable +except ImportError: + ArrowTable: Type[Any] = None # type: ignore class DBTransaction(Protocol): def commit_transaction(self) -> None: ... - def rollback_transaction(self) -> None: ... @@ -19,34 +24,3 @@ class DBApi(Protocol): threadsafety: int apilevel: str paramstyle: str - - -class DBApiCursor(Protocol): - """Protocol for DBAPI cursor""" - - description: Tuple[Any, ...] - - native_cursor: "DBApiCursor" - """Cursor implementation native to current destination""" - - def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ... - def fetchall(self) -> List[Tuple[Any, ...]]: ... - def fetchmany(self, size: int = ...) -> List[Tuple[Any, ...]]: ... - def fetchone(self) -> Optional[Tuple[Any, ...]]: ... 
- def close(self) -> None: ... - - def df(self, chunk_size: int = None, **kwargs: None) -> Optional[DataFrame]: - """Fetches the results as data frame. For large queries the results may be chunked - - Fetches the results into a data frame. The default implementation uses helpers in `pandas.io.sql` to generate Pandas data frame. - This function will try to use native data frame generation for particular destination. For `BigQuery`: `QueryJob.to_dataframe` is used. - For `duckdb`: `DuckDBPyConnection.df' - - Args: - chunk_size (int, optional): Will chunk the results into several data frames. Defaults to None - **kwargs (Any): Additional parameters which will be passed to native data frame generation function. - - Returns: - Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results - """ - ... diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 5df165adb7..59cb1ff20b 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -3,11 +3,10 @@ from types import ModuleType from functools import wraps from typing import ( - TYPE_CHECKING, Any, - Awaitable, Callable, ClassVar, + Dict, Iterator, List, Literal, @@ -18,7 +17,7 @@ cast, overload, ) -from typing_extensions import TypeVar +from typing_extensions import TypeVar, Self from dlt.common.configuration import with_config, get_fun_spec, known_sections, configspec from dlt.common.configuration.container import Container @@ -31,7 +30,6 @@ from dlt.common.pipeline import PipelineContext from dlt.common.reflection.spec import spec_from_signature from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION -from dlt.common.source import _SOURCES, SourceInfo from dlt.common.schema.schema import Schema from dlt.common.schema.typing import ( TColumnNames, @@ -63,7 +61,13 @@ CurrentSourceSchemaNotAvailable, ) from dlt.extract.items import TTableHintTemplate -from dlt.extract.source import DltSource +from dlt.extract.source import ( + DltSource, + SourceReference, + SourceFactory, + TDltSourceImpl, + TSourceFunParams, +) from dlt.extract.resource import DltResource, TUnboundDltResource, TDltResourceImpl @@ -85,9 +89,223 @@ class SourceInjectableContext(ContainerInjectableContext): can_create_default: ClassVar[bool] = False -TSourceFunParams = ParamSpec("TSourceFunParams") +class _DltSingleSource(DltSource): + """Used to register standalone (non-inner) resources""" + + @property + def single_resource(self) -> DltResource: + return list(self.resources.values())[0] + + +class DltSourceFactoryWrapper(SourceFactory[TSourceFunParams, TDltSourceImpl]): + def __init__( + self, + ) -> None: + """Creates a wrapper that is returned by @source decorator. It preserves the decorated function when called and + allows to change the decorator arguments at runtime. Changing the `name` and `section` creates a clone of the source + with different name and taking the configuration from a different keys. + + This wrapper registers the source under `section`.`name` type in SourceReference registry, using the original + `section` (which corresponds to module name) and `name` (which corresponds to source function name). 
+ """ + self._f: AnyFun = None + self._ref: SourceReference = None + self._deco_f: Callable[..., TDltSourceImpl] = None + + self.name: str = None + self.section: str = None + self.max_table_nesting: int = None + self.root_key: bool = False + self.schema: Schema = None + self.schema_contract: TSchemaContract = None + self.spec: Type[BaseConfiguration] = None + self.parallelized: bool = None + self._impl_cls: Type[TDltSourceImpl] = DltSource # type: ignore[assignment] + + def with_args( + self, + *, + name: str = None, + section: str = None, + max_table_nesting: int = None, + root_key: bool = None, + schema: Schema = None, + schema_contract: TSchemaContract = None, + spec: Type[BaseConfiguration] = None, + parallelized: bool = None, + _impl_cls: Type[TDltSourceImpl] = None, + ) -> Self: + """Overrides default arguments that will be used to create DltSource instance when this wrapper is called. This method + clones this wrapper. + """ + # if source function not set, apply args in place + ovr = self.__class__() if self._f else self + + if name is not None: + ovr.name = name + else: + ovr.name = self.name + if section is not None: + ovr.section = section + else: + ovr.section = self.section + if max_table_nesting is not None: + ovr.max_table_nesting = max_table_nesting + else: + ovr.max_table_nesting = self.max_table_nesting + if root_key is not None: + ovr.root_key = root_key + else: + ovr.root_key = self.root_key + ovr.schema = schema or self.schema + if schema_contract is not None: + ovr.schema_contract = schema_contract + else: + ovr.schema_contract = self.schema_contract + ovr.spec = spec or self.spec + if parallelized is not None: + ovr.parallelized = parallelized + else: + ovr.parallelized = self.parallelized + ovr._impl_cls = _impl_cls or self._impl_cls + + # also remember original source function + ovr._f = self._f + # try to bind _f + ovr.wrap() + return ovr + + def __call__(self, *args: Any, **kwargs: Any) -> TDltSourceImpl: + assert self._deco_f, f"Attempt to call source function on {self.name} before bind" + # if source impl is a single resource source + if issubclass(self._impl_cls, _DltSingleSource): + # call special source function that will create renamed resource + source = self._deco_f(self.name, self.section, args, kwargs) + assert isinstance(source, _DltSingleSource) + # set source section to empty to not interfere with resource sections, same thing we do in extract + source.section = "" + # apply selected settings directly to resource + resource = source.single_resource + if self.max_table_nesting is not None: + resource.max_table_nesting = self.max_table_nesting + if self.schema_contract is not None: + resource.apply_hints(schema_contract=self.schema_contract) + else: + source = self._deco_f(*args, **kwargs) + return source + + def bind(self, f: AnyFun) -> Self: + """Binds wrapper to the original source function and registers the source reference. 
This method is called only once by the decorator""" + self._f = f + self._ref = self.wrap() + SourceReference.register(self._ref) + return self + + def wrap(self) -> SourceReference: + """Wrap the original source function using _deco.""" + if not self._f: + return None + if hasattr(self._f, "__qualname__"): + self.__qualname__ = self._f.__qualname__ + return self._wrap(self._f) + + def _wrap(self, f: AnyFun) -> SourceReference: + """Wraps source function `f` in configuration injector.""" + if not callable(f) or isinstance(f, DltResource): + raise SourceNotAFunction(self.name or "", f, type(f)) + + if inspect.isclass(f): + raise SourceIsAClassTypeError(self.name or "", f) + + # source name is passed directly or taken from decorated function name + effective_name = self.name or get_callable_name(f) + + if self.schema and self.name and self.name != self.schema.name: + raise ExplicitSourceNameInvalid(self.name, self.schema.name) + + # wrap source extraction function in configuration with section + func_module = inspect.getmodule(f) + source_section = self.section or _get_source_section_name(func_module) + # use effective_name which is explicit source name or callable name to represent third element in source config path + source_sections = (known_sections.SOURCES, source_section, effective_name) + conf_f = with_config(f, spec=self.spec, sections=source_sections) + + def _eval_rv(_rv: Any, schema_copy: Schema) -> TDltSourceImpl: + """Evaluates return value from the source function or coroutine""" + if _rv is None: + raise SourceDataIsNone(schema_copy.name) + # if generator, consume it immediately + if inspect.isgenerator(_rv): + _rv = list(_rv) + + # convert to source + s = self._impl_cls.from_data(schema_copy, source_section, _rv) + # apply hints + if self.max_table_nesting is not None: + s.max_table_nesting = self.max_table_nesting + s.schema_contract = self.schema_contract + # enable root propagation + s.root_key = self.root_key + # parallelize resources + if self.parallelized: + s.parallelize() + return s + + def _make_schema() -> Schema: + if not self.schema: + # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema + return _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) + else: + # clone the schema passed to decorator, update normalizers, remove processing hints + # NOTE: source may be called several times in many different settings + return self.schema.clone(update_normalizers=True, remove_processing_hints=True) + + @wraps(conf_f) + def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: + """Wrap a regular function, injection context must be a part of the wrap""" + schema_copy = _make_schema() + with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): + # configurations will be accessed in this section in the source + proxy = Container()[PipelineContext] + pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name + with inject_section( + ConfigSectionContext( + pipeline_name=pipeline_name, + sections=source_sections, + source_state_key=schema_copy.name, + ) + ): + rv = conf_f(*args, **kwargs) + return _eval_rv(rv, schema_copy) + + @wraps(conf_f) + async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: + """In case of co-routine we must wrap the whole injection context in awaitable, + there's no easy way to avoid some code duplication + """ + schema_copy = _make_schema() + with 
Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): + # configurations will be accessed in this section in the source + proxy = Container()[PipelineContext] + pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name + with inject_section( + ConfigSectionContext( + pipeline_name=pipeline_name, + sections=source_sections, + source_state_key=schema_copy.name, + ) + ): + rv = await conf_f(*args, **kwargs) + return _eval_rv(rv, schema_copy) + + # get spec for wrapped function + SPEC = get_fun_spec(conf_f) + # get correct wrapper + self._deco_f = _wrap_coro if inspect.iscoroutinefunction(inspect.unwrap(f)) else _wrap # type: ignore[assignment] + return SourceReference(SPEC, self, func_module, source_section, effective_name) # type: ignore[arg-type] + + TResourceFunParams = ParamSpec("TResourceFunParams") -TDltSourceImpl = TypeVar("TDltSourceImpl", bound=DltSource, default=DltSource) @overload @@ -101,8 +319,9 @@ def source( schema: Schema = None, schema_contract: TSchemaContract = None, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, _impl_cls: Type[TDltSourceImpl] = DltSource, # type: ignore[assignment] -) -> Callable[TSourceFunParams, TDltSourceImpl]: ... +) -> SourceFactory[TSourceFunParams, TDltSourceImpl]: ... @overload @@ -116,8 +335,11 @@ def source( schema: Schema = None, schema_contract: TSchemaContract = None, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, _impl_cls: Type[TDltSourceImpl] = DltSource, # type: ignore[assignment] -) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, TDltSourceImpl]]: ... +) -> Callable[ + [Callable[TSourceFunParams, Any]], SourceFactory[TSourceFunParams, TDltSourceImpl] +]: ... def source( @@ -130,11 +352,12 @@ def source( schema: Schema = None, schema_contract: TSchemaContract = None, spec: Type[BaseConfiguration] = None, + parallelized: bool = False, _impl_cls: Type[TDltSourceImpl] = DltSource, # type: ignore[assignment] ) -> Any: """A decorator that transforms a function returning one or more `dlt resources` into a `dlt source` in order to load it with `dlt`. - #### Note: + Note: A `dlt source` is a logical grouping of resources that are often extracted and loaded together. A source is associated with a schema, which describes the structure of the loaded data and provides instructions how to load it. Such schema contains table schemas that describe the structure of the data coming from the resources. @@ -151,7 +374,7 @@ def source( Here `username` is a required, explicit python argument, `chess_url` is a required argument, that if not explicitly passed will be taken from configuration ie. `config.toml`, `api_secret` is a required argument, that if not explicitly passed will be taken from dlt secrets ie. `secrets.toml`. See https://dlthub.com/docs/general-usage/credentials for details. - #### Args: + Args: func: A function that returns a dlt resource or a list of those or a list of any data items that can be loaded by `dlt`. name (str, optional): A name of the source which is also the name of the associated schema. If not present, the function name will be used. @@ -168,122 +391,40 @@ def source( spec (Type[BaseConfiguration], optional): A specification of configuration and secret values required by the source. + parallelized (bool, optional): If `True`, resource generators will be extracted in parallel with other resources. + Transformers that return items are also parallelized. Non-eligible resources are ignored. 
Defaults to `False` which preserves resource settings. + _impl_cls (Type[TDltSourceImpl], optional): A custom implementation of DltSource, may be also used to providing just a typing stub Returns: - `DltSource` instance + Wrapped decorated source function, see SourceFactory reference for additional wrapper capabilities """ if name and schema: raise ArgumentsOverloadException( "'name' has no effect when `schema` argument is present", source.__name__ ) - def decorator( - f: Callable[TSourceFunParams, Any] - ) -> Callable[TSourceFunParams, Union[Awaitable[TDltSourceImpl], TDltSourceImpl]]: - nonlocal schema, name - - if not callable(f) or isinstance(f, DltResource): - raise SourceNotAFunction(name or "", f, type(f)) - - if inspect.isclass(f): - raise SourceIsAClassTypeError(name or "", f) - - # source name is passed directly or taken from decorated function name - effective_name = name or get_callable_name(f) - - if schema and name and name != schema.name: - raise ExplicitSourceNameInvalid(name, schema.name) - - # wrap source extraction function in configuration with section - func_module = inspect.getmodule(f) - source_section = section or _get_source_section_name(func_module) - # use effective_name which is explicit source name or callable name to represent third element in source config path - source_sections = (known_sections.SOURCES, source_section, effective_name) - conf_f = with_config(f, spec=spec, sections=source_sections) - - def _eval_rv(_rv: Any, schema_copy: Schema) -> TDltSourceImpl: - """Evaluates return value from the source function or coroutine""" - if _rv is None: - raise SourceDataIsNone(schema_copy.name) - # if generator, consume it immediately - if inspect.isgenerator(_rv): - _rv = list(_rv) - - # convert to source - s = _impl_cls.from_data(schema_copy, source_section, _rv) - # apply hints - if max_table_nesting is not None: - s.max_table_nesting = max_table_nesting - s.schema_contract = schema_contract - # enable root propagation - s.root_key = root_key - return s - - def _make_schema() -> Schema: - if not schema: - # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema - return _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) - else: - # clone the schema passed to decorator, update normalizers, remove processing hints - # NOTE: source may be called several times in many different settings - return schema.clone(update_normalizers=True, remove_processing_hints=True) - - @wraps(conf_f) - def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: - """Wrap a regular function, injection context must be a part of the wrap""" - schema_copy = _make_schema() - with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): - # configurations will be accessed in this section in the source - proxy = Container()[PipelineContext] - pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name - with inject_section( - ConfigSectionContext( - pipeline_name=pipeline_name, - sections=source_sections, - source_state_key=schema_copy.name, - ) - ): - rv = conf_f(*args, **kwargs) - return _eval_rv(rv, schema_copy) - - @wraps(conf_f) - async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: - """In case of co-routine we must wrap the whole injection context in awaitable, - there's no easy way to avoid some code duplication - """ - schema_copy = _make_schema() - with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): - # 
configurations will be accessed in this section in the source - proxy = Container()[PipelineContext] - pipeline_name = None if not proxy.is_active() else proxy.pipeline().pipeline_name - with inject_section( - ConfigSectionContext( - pipeline_name=pipeline_name, - sections=source_sections, - source_state_key=schema_copy.name, - ) - ): - rv = await conf_f(*args, **kwargs) - return _eval_rv(rv, schema_copy) - - # get spec for wrapped function - SPEC = get_fun_spec(conf_f) - # get correct wrapper - wrapper: AnyFun = _wrap_coro if inspect.iscoroutinefunction(inspect.unwrap(f)) else _wrap # type: ignore[assignment] - # store the source information - _SOURCES[_wrap.__qualname__] = SourceInfo(SPEC, wrapper, func_module) - if inspect.iscoroutinefunction(inspect.unwrap(f)): - return _wrap_coro - else: - return _wrap + source_wrapper = ( + DltSourceFactoryWrapper[Any, TDltSourceImpl]() + .with_args( + name=name, + section=section, + max_table_nesting=max_table_nesting, + root_key=root_key, + schema=schema, + schema_contract=schema_contract, + spec=spec, + parallelized=parallelized, + _impl_cls=_impl_cls, + ) + .bind + ) if func is None: # we're called with parens. - return decorator - + return source_wrapper # we're called as @source without parens. - return decorator(func) + return source_wrapper(func) @overload @@ -414,21 +555,21 @@ def resource( See https://dlthub.com/docs/general-usage/credentials for details. Note that if decorated function is an inner function, passing of the credentials will be disabled. - #### Args: + Args: data (Callable | Any, optional): a function to be decorated or a data compatible with `dlt` `run`. name (str, optional): A name of the resource that by default also becomes the name of the table to which the data is loaded. - If not present, the name of the decorated function will be used. + If not present, the name of the decorated function will be used. table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. write_disposition (TTableHintTemplate[TWriteDispositionConfig], optional): Controls how to write data to a table. Accepts a shorthand string literal or configuration dictionary. - Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". - Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". 
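A hedged usage sketch of the resource hints documented in this docstring; the table, column, and data values below are illustrative, not part of this change.

import dlt

@dlt.resource(
    table_name="orders",
    write_disposition="merge",
    primary_key="order_id",
)
def orders():
    yield {"order_id": 1, "status": "new"}
    yield {"order_id": 2, "status": "shipped"}

# the resource can be iterated directly or passed to pipeline.run()
print(list(orders()))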
+ Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. @@ -436,18 +577,18 @@ def resource( When the argument is a pydantic model, the model will be used to validate the data yielded by the resource as well. primary_key (str | Sequence[str]): A column name or a list of column names that comprise a private key. Typically used with "merge" write disposition to deduplicate loaded data. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. merge_key (str | Sequence[str]): A column name or a list of column names that define a merge key. Typically used with "merge" write disposition to remove overlapping data ranges ie. to keep a single record for a given day. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to all resources of this source (if not overridden in the resource itself) table_format (Literal["iceberg", "delta"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, and "delta" on the filesystem. - Other destinations ignore this hint. + Other destinations ignore this hint. file_format (Literal["preferred", ...], optional): Format of the file in which resource data is stored. Useful when importing external files. Use `preferred` to force - a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. + a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. @@ -457,7 +598,8 @@ def resource( data_from (TUnboundDltResource, optional): Allows to pipe data from one resource to another to build multi-step pipelines. - parallelized (bool, optional): If `True`, the resource generator will be extracted in parallel with other resources. Defaults to `False`. + parallelized (bool, optional): If `True`, the resource generator will be extracted in parallel with other resources. + Transformers that return items are also parallelized. Defaults to `False`. 
_impl_cls (Type[TDltResourceImpl], optional): A custom implementation of DltResource, may be also used to providing just a typing stub @@ -500,6 +642,33 @@ def make_resource(_name: str, _section: str, _data: Any) -> TDltResourceImpl: return resource.parallelize() return resource + def wrap_standalone( + _name: str, _section: str, f: AnyFun + ) -> Callable[TResourceFunParams, TDltResourceImpl]: + if not standalone: + # we return a DltResource that is callable and returns dlt resource when called + # so it should match the signature + return make_resource(_name, _section, f) # type: ignore[return-value] + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> TDltResourceImpl: + skip_args = 1 if data_from else 0 + _, mod_sig, bound_args = simulate_func_call(f, skip_args, *args, **kwargs) + actual_resource_name = name(bound_args.arguments) if callable(name) else _name + r = make_resource(actual_resource_name, _section, f) + # wrap the standalone resource + data_ = r._pipe.bind_gen(*args, **kwargs) + if isinstance(data_, DltResource): + # we allow an edge case: resource can return another resource + r = data_ # type: ignore[assignment] + # consider transformer arguments bound + r._args_bound = True + # keep explicit args passed + r._set_explicit_args(f, mod_sig, *args, **kwargs) + return r + + return _wrap + def decorator( f: Callable[TResourceFunParams, Any] ) -> Callable[TResourceFunParams, TDltResourceImpl]: @@ -536,33 +705,38 @@ def decorator( # assign spec to "f" set_fun_spec(f, SPEC) - # store the non-inner resource information + # register non inner resources as source with single resource in it if not is_inner_resource: - _SOURCES[f.__qualname__] = SourceInfo(SPEC, f, func_module) - - if not standalone: - # we return a DltResource that is callable and returns dlt resource when called - # so it should match the signature - return make_resource(resource_name, source_section, f) # type: ignore[return-value] + # a source function for the source wrapper, args that go to source are forwarded + # to a single resource within + def _source( + name_ovr: str, section_ovr: str, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> TDltResourceImpl: + return wrap_standalone(name_ovr or resource_name, section_ovr or source_section, f)( + *args, **kwargs + ) - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> TDltResourceImpl: - skip_args = 1 if data_from else 0 - _, mod_sig, bound_args = simulate_func_call(f, skip_args, *args, **kwargs) - actual_resource_name = name(bound_args.arguments) if callable(name) else resource_name - r = make_resource(actual_resource_name, source_section, f) - # wrap the standalone resource - data_ = r._pipe.bind_gen(*args, **kwargs) - if isinstance(data_, DltResource): - # we allow an edge case: resource can return another resource - r = data_ # type: ignore[assignment] - # consider transformer arguments bound - r._args_bound = True - # keep explicit args passed - r._set_explicit_args(f, mod_sig, *args, **kwargs) - return r + # make the source module same as original resource + _source.__qualname__ = f.__qualname__ + _source.__module__ = f.__module__ + # setup our special single resource source + factory = ( + DltSourceFactoryWrapper[Any, DltSource]() + .with_args( + name=resource_name, + section=source_section, + spec=BaseConfiguration, + _impl_cls=_DltSingleSource, + ) + .bind(_source) + ) + # remove name and section overrides from the wrapper so resource is not unnecessarily renamed + factory.name = None + factory.section = None + # mod the reference to keep the right 
spec + factory._ref.SPEC = SPEC - return _wrap + return wrap_standalone(resource_name, source_section, f) # if data is callable or none use decorator if data is None: @@ -717,37 +891,37 @@ def transformer( >>> list(players("GM") | player_profile) Args: - f: (Callable): a function taking minimum one argument of TDataItems type which will receive data yielded from `data_from` resource. + f (Callable): a function taking minimum one argument of TDataItems type which will receive data yielded from `data_from` resource. data_from (Callable | Any, optional): a resource that will send data to the decorated function `f` name (str, optional): A name of the resource that by default also becomes the name of the table to which the data is loaded. - If not present, the name of the decorated function will be used. + If not present, the name of the decorated function will be used. table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. primary_key (str | Sequence[str]): A column name or a list of column names that comprise a private key. Typically used with "merge" write disposition to deduplicate loaded data. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. merge_key (str | Sequence[str]): A column name or a list of column names that define a merge key. Typically used with "merge" write disposition to remove overlapping data ranges ie. to keep a single record for a given day. - This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. 
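A hedged, self-contained sketch of the transformer usage documented here, following the players/player_profile example in the docstring; resource and field names are illustrative.

import dlt

@dlt.resource
def players():
    yield from ["magnuscarlsen", "rpragchess"]

@dlt.transformer(table_name="profiles", primary_key="username", write_disposition="merge")
def player_profile(username: str):
    # a real transformer would enrich each item, e.g. by calling an API
    yield {"username": username, "title": "GM"}

# pipe the resource into the transformer and evaluate, as in the docstring example
print(list(players | player_profile))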
schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to all resources of this source (if not overridden in the resource itself) table_format (Literal["iceberg", "delta"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, and "delta" on the filesystem. - Other destinations ignore this hint. + Other destinations ignore this hint. file_format (Literal["preferred", ...], optional): Format of the file in which resource data is stored. Useful when importing external files. Use `preferred` to force - a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. + a file format that is preferred by the destination used. This setting superseded the `load_file_format` passed to pipeline `run` method. selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. @@ -756,6 +930,13 @@ def transformer( standalone (bool, optional): Returns a wrapped decorated function that creates DltResource instance. Must be called before use. Cannot be part of a source. _impl_cls (Type[TDltResourceImpl], optional): A custom implementation of DltResource, may be also used to providing just a typing stub + + Raises: + ResourceNameMissing: indicates that name of the resource cannot be inferred from the `data` being passed. + InvalidResourceDataType: indicates that the `data` argument cannot be converted into `dlt resource` + + Returns: + TDltResourceImpl instance which may be loaded, iterated or combined with other resources into a pipeline. """ if isinstance(f, DltResource): raise ValueError( @@ -801,7 +982,7 @@ def _maybe_load_schema_for_callable(f: AnyFun, name: str) -> Optional[Schema]: def _get_source_section_name(m: ModuleType) -> str: - """Gets the source section name (as in SOURCES
tuple) from __source_name__ of the module `m` or from its name""" + """Gets the source section name (as in SOURCES (section, name) tuple) from __source_name__ of the module `m` or from its name""" if m is None: return None if hasattr(m, "__source_name__"): diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index c3a20e72e5..f4d2b1f302 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -1,5 +1,5 @@ from inspect import Signature, isgenerator, isgeneratorfunction, unwrap -from typing import Any, Set, Type +from typing import Any, Sequence, Set, Type from dlt.common.exceptions import DltException from dlt.common.utils import get_callable_name @@ -401,11 +401,28 @@ def __init__(self, source_name: str, schema_name: str) -> None: self.source_name = source_name self.schema_name = schema_name super().__init__( - f"Your explicit source name {source_name} is not a valid schema name. Please use a" - f" valid schema name ie. '{schema_name}'." + f"Your explicit source name {source_name} does not match explicit schema name" + f" '{schema_name}'." ) +class UnknownSourceReference(DltSourceException): + def __init__(self, ref: Sequence[str]) -> None: + self.ref = ref + msg = ( + f"{ref} is not one of registered sources and could not be imported as module with" + " source function" + ) + super().__init__(msg) + + +# class InvalidDestinationReference(DestinationException): +# def __init__(self, destination_module: Any) -> None: +# self.destination_module = destination_module +# msg = f"Destination {destination_module} is not a valid destination module." +# super().__init__(msg) + + class IncrementalUnboundError(DltResourceException): def __init__(self, cursor_path: str) -> None: super().__init__( diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py index 3a10f651c0..465040f9f4 100644 --- a/dlt/extract/pipe_iterator.py +++ b/dlt/extract/pipe_iterator.py @@ -24,7 +24,7 @@ ) from dlt.common.configuration.container import Container from dlt.common.exceptions import PipelineException -from dlt.common.source import unset_current_pipe_name, set_current_pipe_name +from dlt.common.pipeline import unset_current_pipe_name, set_current_pipe_name from dlt.common.utils import get_callable_name from dlt.extract.exceptions import ( diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index 55c0bd728f..c6ca1660f4 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -7,6 +7,7 @@ Callable, Iterable, Iterator, + Type, Union, Any, Optional, @@ -16,7 +17,7 @@ from dlt.common import logger from dlt.common.configuration.inject import get_fun_spec, with_config from dlt.common.configuration.resolve import inject_section -from dlt.common.configuration.specs import known_sections +from dlt.common.configuration.specs import BaseConfiguration, known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.typing import AnyFun, DictStrAny, StrAny, TDataItem, TDataItems, NoneType from dlt.common.configuration.container import Container @@ -89,20 +90,25 @@ class DltResource(Iterable[TDataItem], DltResourceHints): """Name of the source that contains this instance of the source, set when added to DltResourcesDict""" section: str """A config section name""" + SPEC: Type[BaseConfiguration] + """A SPEC that defines signature of callable(parametrized) resource/transformer""" def __init__( self, pipe: Pipe, hints: TResourceHints, selected: bool, + *, section: str = None, args_bound: bool = False, + SPEC: 
Type[BaseConfiguration] = None, ) -> None: self.section = section self.selected = selected self._pipe = pipe self._args_bound = args_bound self._explicit_args: DictStrAny = None + self.SPEC = SPEC self.source_name = None super().__init__(hints) @@ -132,7 +138,8 @@ def from_data( return data # type: ignore[return-value] if isinstance(data, Pipe): - r_ = cls(data, hints, selected, section=section) + SPEC_ = None if data.is_empty else get_fun_spec(data.gen) # type: ignore[arg-type] + r_ = cls(data, hints, selected, section=section, SPEC=SPEC_) if inject_config: r_._inject_config() return r_ @@ -170,6 +177,7 @@ def from_data( selected, section=section, args_bound=not callable(data), + SPEC=get_fun_spec(data), ) if inject_config: r_._inject_config() @@ -647,6 +655,7 @@ def _clone( selected=self.selected, section=self.section, args_bound=self._args_bound, + SPEC=self.SPEC, ) # try to eject and then inject configuration and incremental wrapper when resource is cloned # this makes sure that a take config values from a right section and wrapper has a separated diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 6e5d30b62f..df6f8fcc80 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -1,18 +1,27 @@ import contextlib from copy import copy +from importlib import import_module import makefun import inspect -from typing import Dict, Iterable, Iterator, List, Sequence, Tuple, Any -from typing_extensions import Self +from typing import Dict, Iterable, Iterator, List, Sequence, Tuple, Any, Generic +from typing_extensions import Self, Protocol, TypeVar +from types import ModuleType +from typing import Dict, Type, ClassVar +from dlt.common import logger from dlt.common.configuration.resolve import inject_section -from dlt.common.configuration.specs import known_sections +from dlt.common.configuration.specs import BaseConfiguration, known_sections from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.configuration.specs.pluggable_run_context import ( + PluggableRunContext, + SupportsRunContext, +) from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer +from dlt.common.runtime.run_context import RunContext from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnName, TSchemaContract from dlt.common.schema.utils import normalize_table_identifiers -from dlt.common.typing import StrAny, TDataItem +from dlt.common.typing import StrAny, TDataItem, ParamSpec from dlt.common.configuration.container import Container from dlt.common.pipeline import ( PipelineContext, @@ -26,13 +35,14 @@ from dlt.extract.items import TDecompositionStrategy from dlt.extract.pipe_iterator import ManagedPipeIterator from dlt.extract.pipe import Pipe -from dlt.extract.hints import DltResourceHints, make_hints +from dlt.extract.hints import make_hints from dlt.extract.resource import DltResource from dlt.extract.exceptions import ( DataItemRequiredForDynamicTableHints, ResourcesNotFoundError, DeletingResourcesNotSupported, InvalidParallelResourceDataType, + UnknownSourceReference, ) @@ -104,7 +114,7 @@ def selected_pipes(self) -> Sequence[Pipe]: return [r._pipe for r in self.values() if r.selected] def select(self, *resource_names: str) -> Dict[str, DltResource]: - # checks if keys are present + """Selects `resource_name` to be extracted, and unselects remaining resources.""" for name in resource_names: if name not in self: # if any key is missing, display the full info @@ -130,6 +140,14 @@ def 
add(self, *resources: DltResource) -> None: self._suppress_clone_on_setitem = False self._clone_new_pipes([r.name for r in resources]) + def detach(self, resource_name: str = None) -> DltResource: + """Clones `resource_name` (including parent resource pipes) and removes source contexts. + Defaults to the first resource in the source if `resource_name` is None. + """ + return (self[resource_name] if resource_name else list(self.values())[0])._clone( + with_parent=True + ) + def _clone_new_pipes(self, resource_names: Sequence[str]) -> None: # clone all new pipes and keep _, self._cloned_pairs = ManagedPipeIterator.clone_pipes(self._new_pipes, self._cloned_pairs) @@ -463,3 +481,130 @@ def __str__(self) -> str: info += " Note that, like any iterator, you can iterate the source only once." info += f"\ninstance id: {id(self)}" return info + + +TDltSourceImpl = TypeVar("TDltSourceImpl", bound=DltSource, default=DltSource) +TSourceFunParams = ParamSpec("TSourceFunParams") + + +class SourceFactory(Protocol, Generic[TSourceFunParams, TDltSourceImpl]): + def __call__( + self, *args: TSourceFunParams.args, **kwargs: TSourceFunParams.kwargs + ) -> TDltSourceImpl: + """Makes dlt source""" + pass + + def with_args( + self, + *, + name: str = None, + section: str = None, + max_table_nesting: int = None, + root_key: bool = False, + schema: Schema = None, + schema_contract: TSchemaContract = None, + spec: Type[BaseConfiguration] = None, + parallelized: bool = None, + _impl_cls: Type[TDltSourceImpl] = None, + ) -> Self: + """Overrides default decorator arguments that will be used to when DltSource instance and returns modified clone.""" + + +class SourceReference: + """Runtime information on the source/resource""" + + SOURCES: ClassVar[Dict[str, "SourceReference"]] = {} + """A registry of all the decorated sources and resources discovered when importing modules""" + + SPEC: Type[BaseConfiguration] + f: SourceFactory[Any, DltSource] + module: ModuleType + section: str + name: str + context: SupportsRunContext + + def __init__( + self, + SPEC: Type[BaseConfiguration], + f: SourceFactory[Any, DltSource], + module: ModuleType, + section: str, + name: str, + ) -> None: + self.SPEC = SPEC + self.f = f + self.module = module + self.section = section + self.name = name + self.context = Container()[PluggableRunContext].context + + @staticmethod + def to_fully_qualified_ref(ref: str) -> List[str]: + """Converts ref into fully qualified form, return one or more alternatives for shorthand notations. + Run context is injected in needed. 
+ """ + ref_split = ref.split(".") + if len(ref_split) > 3: + return [] + # fully qualified path + if len(ref_split) == 3: + return [ref] + # context name is needed + refs = [] + run_names = [Container()[PluggableRunContext].context.name] + # always look in default run context + if run_names[0] != RunContext.CONTEXT_NAME: + run_names.append(RunContext.CONTEXT_NAME) + for run_name in run_names: + # expand shorthand notation + if len(ref_split) == 1: + refs.append(f"{run_name}.{ref}.{ref}") + else: + # for ref with two parts two options are possible + refs.extend([f"{run_name}.{ref}", f"{ref_split[0]}.{ref_split[1]}.{ref_split[1]}"]) + return refs + + @classmethod + def register(cls, ref_obj: "SourceReference") -> None: + ref = f"{ref_obj.context.name}.{ref_obj.section}.{ref_obj.name}" + if ref in cls.SOURCES: + logger.warning(f"A source with ref {ref} is already registered and will be overwritten") + cls.SOURCES[ref] = ref_obj + + @classmethod + def find(cls, ref: str) -> "SourceReference": + refs = cls.to_fully_qualified_ref(ref) + + for ref_ in refs: + if wrapper := cls.SOURCES.get(ref_): + return wrapper + raise KeyError(refs) + + @classmethod + def from_reference(cls, ref: str) -> SourceFactory[Any, DltSource]: + """Returns registered source factory or imports source module and returns a function. + Expands shorthand notation into section.name eg. "sql_database" is expanded into "sql_database.sql_database" + """ + refs = cls.to_fully_qualified_ref(ref) + + for ref_ in refs: + if wrapper := cls.SOURCES.get(ref_): + return wrapper.f + + # try to import module + if "." in ref: + try: + module_path, attr_name = ref.rsplit(".", 1) + dest_module = import_module(module_path) + factory = getattr(dest_module, attr_name) + if hasattr(factory, "with_args"): + return factory # type: ignore[no-any-return] + else: + raise ValueError(f"{attr_name} in {module_path} is of type {type(factory)}") + except ModuleNotFoundError: + # raise regular exception later + pass + except Exception as e: + raise UnknownSourceReference([ref]) from e + + raise UnknownSourceReference(refs or [ref]) diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index 9623e65850..eedbc44b65 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -396,7 +396,8 @@ def add_run( if not pipeline.pipelines_dir.startswith(os.environ[DLT_DATA_DIR]): raise ValueError( "Please create your Pipeline instance after AirflowTasks are created. The dlt" - " pipelines directory is not set correctly." + f" pipelines directory {pipeline.pipelines_dir} is not set correctly" + f" ({os.environ[DLT_DATA_DIR]} expected)." 
) with self: diff --git a/dlt/helpers/dbt/__init__.py b/dlt/helpers/dbt/__init__.py index 08d6c23ed1..fc229ed1d0 100644 --- a/dlt/helpers/dbt/__init__.py +++ b/dlt/helpers/dbt/__init__.py @@ -6,7 +6,7 @@ from dlt.common.runners import Venv from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.configuration.specs import CredentialsWithDefault -from dlt.common.typing import TSecretValue, ConfigValue +from dlt.common.typing import TSecretStrValue, ConfigValue from dlt.version import get_installed_requirement_string from dlt.helpers.dbt.runner import create_runner, DBTPackageRunner @@ -85,7 +85,7 @@ def package_runner( working_dir: str, package_location: str, package_repository_branch: str = ConfigValue, - package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa + package_repository_ssh_key: TSecretStrValue = "", auto_full_refresh_when_out_of_sync: bool = ConfigValue, ) -> DBTPackageRunner: default_profile_name = _default_profile_name(destination_configuration) diff --git a/dlt/helpers/dbt/configuration.py b/dlt/helpers/dbt/configuration.py index bec0bace3c..7f7042f745 100644 --- a/dlt/helpers/dbt/configuration.py +++ b/dlt/helpers/dbt/configuration.py @@ -1,7 +1,7 @@ import os from typing import Optional, Sequence -from dlt.common.typing import StrAny, TSecretValue +from dlt.common.typing import StrAny, TSecretStrValue from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration, RunConfiguration @@ -10,9 +10,8 @@ class DBTRunnerConfiguration(BaseConfiguration): package_location: str = None package_repository_branch: Optional[str] = None - package_repository_ssh_key: Optional[TSecretValue] = TSecretValue( - "" - ) # the default is empty value which will disable custom SSH KEY + # the default is empty value which will disable custom SSH KEY + package_repository_ssh_key: Optional[TSecretStrValue] = "" package_profiles_dir: Optional[str] = None package_profile_name: Optional[str] = None auto_full_refresh_when_out_of_sync: bool = True @@ -27,4 +26,4 @@ def on_resolved(self) -> None: self.package_profiles_dir = os.path.dirname(__file__) if self.package_repository_ssh_key and self.package_repository_ssh_key[-1] != "\n": # must end with new line, otherwise won't be parsed by Crypto - self.package_repository_ssh_key = TSecretValue(self.package_repository_ssh_key + "\n") + self.package_repository_ssh_key = self.package_repository_ssh_key + "\n" diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index aa1c60901e..49c165b05d 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -11,7 +11,7 @@ from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout_with_result -from dlt.common.typing import StrAny, TSecretValue +from dlt.common.typing import StrAny, TSecretStrValue from dlt.common.logger import is_json_logging from dlt.common.storages import FileStorage from dlt.common.git import git_custom_key_command, ensure_remote_head, force_clone_repo @@ -306,7 +306,7 @@ def create_runner( working_dir: str, package_location: str = dlt.config.value, package_repository_branch: Optional[str] = None, - package_repository_ssh_key: Optional[TSecretValue] = TSecretValue(""), # noqa + package_repository_ssh_key: Optional[TSecretStrValue] = "", package_profiles_dir: Optional[str] = None, package_profile_name: Optional[str] = None, auto_full_refresh_when_out_of_sync: bool = True, diff 
--git a/dlt/helpers/dbt_cloud/configuration.py b/dlt/helpers/dbt_cloud/configuration.py index 3c95d53431..9d567a4aff 100644 --- a/dlt/helpers/dbt_cloud/configuration.py +++ b/dlt/helpers/dbt_cloud/configuration.py @@ -2,12 +2,12 @@ from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue @configspec class DBTCloudConfiguration(BaseConfiguration): - api_token: TSecretValue = TSecretValue("") + api_token: TSecretStrValue = "" account_id: Optional[str] = None job_id: Optional[str] = None diff --git a/dlt/helpers/streamlit_app/pages/dashboard.py b/dlt/helpers/streamlit_app/pages/dashboard.py index 941c0966f7..3584f929b1 100644 --- a/dlt/helpers/streamlit_app/pages/dashboard.py +++ b/dlt/helpers/streamlit_app/pages/dashboard.py @@ -17,7 +17,7 @@ def write_data_explorer_page( ) -> None: """Writes Streamlit app page with a schema and live data preview. - #### Args: + Args: pipeline (Pipeline): Pipeline instance to use. schema_name (str, optional): Name of the schema to display. If None, default schema is used. example_query (str, optional): Example query to be displayed in the SQL Query box. diff --git a/dlt/helpers/streamlit_app/pages/load_info.py b/dlt/helpers/streamlit_app/pages/load_info.py index ee13cf2531..699e786410 100644 --- a/dlt/helpers/streamlit_app/pages/load_info.py +++ b/dlt/helpers/streamlit_app/pages/load_info.py @@ -27,7 +27,7 @@ def write_load_status_page(pipeline: Pipeline) -> None: ) if loads_df is not None: - selected_load_id = st.selectbox("Select load id", loads_df) + selected_load_id: str = st.selectbox("Select load id", loads_df) schema = pipeline.default_schema st.markdown("**Number of loaded rows:**") diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 7af965e989..e8344cfe0f 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -9,7 +9,7 @@ TSchemaContract, ) -from dlt.common.typing import TSecretValue, Any +from dlt.common.typing import TSecretStrValue, Any from dlt.common.configuration import with_config from dlt.common.configuration.container import Container from dlt.common.configuration.inject import get_orig_args, last_config @@ -28,7 +28,7 @@ def pipeline( pipeline_name: str = None, pipelines_dir: str = None, - pipeline_salt: TSecretValue = None, + pipeline_salt: TSecretStrValue = None, destination: TDestinationReferenceArg = None, staging: TDestinationReferenceArg = None, dataset_name: str = None, @@ -51,30 +51,30 @@ def pipeline( - Pipeline architecture and data loading steps: https://dlthub.com/docs/reference - List of supported destinations: https://dlthub.com/docs/dlt-ecosystem/destinations - #### Args: + Args: pipeline_name (str, optional): A name of the pipeline that will be used to identify it in monitoring events and to restore its state and data schemas on subsequent runs. - Defaults to the file name of pipeline script with `dlt_` prefix added. + Defaults to the file name of pipeline script with `dlt_` prefix added. pipelines_dir (str, optional): A working directory in which pipeline state and temporary files will be stored. Defaults to user home directory: `~/dlt/pipelines/`. - pipeline_salt (TSecretValue, optional): A random value used for deterministic hashing during data anonymization. Defaults to a value derived from the pipeline name. - Default value should not be used for any cryptographic purposes. 
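A hedged usage sketch of the pipeline factory documented here, assuming the duckdb extra is installed; the pipeline, dataset, and table names are illustrative.

import dlt

pipeline = dlt.pipeline(
    pipeline_name="example_pipeline",
    destination="duckdb",
    dataset_name="example_data",
    progress="log",
)
load_info = pipeline.run([{"id": 1}, {"id": 2}], table_name="items")
print(load_info)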
+ pipeline_salt (TSecretStrValue, optional): A random value used for deterministic hashing during data anonymization. Defaults to a value derived from the pipeline name. + Default value should not be used for any cryptographic purposes. destination (str | DestinationReference, optional): A name of the destination to which dlt will load the data, or a destination module imported from `dlt.destination`. - May also be provided to `run` method of the `pipeline`. + May also be provided to `run` method of the `pipeline`. staging (str | DestinationReference, optional): A name of the destination where dlt will stage the data before final loading, or a destination module imported from `dlt.destination`. - May also be provided to `run` method of the `pipeline`. + May also be provided to `run` method of the `pipeline`. dataset_name (str, optional): A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. - May also be provided later to the `run` or `load` methods of the `Pipeline`. If not provided at all then defaults to the `pipeline_name` + May also be provided later to the `run` or `load` methods of the `Pipeline`. If not provided at all then defaults to the `pipeline_name` import_schema_path (str, optional): A path from which the schema `yaml` file will be imported on each pipeline run. Defaults to None which disables importing. export_schema_path (str, optional): A path where the schema `yaml` file will be exported after every schema change. Defaults to None which disables exporting. dev_mode (bool, optional): When set to True, each instance of the pipeline with the `pipeline_name` starts from scratch when run and loads the data to a separate dataset. - The datasets are identified by `dataset_name_` + datetime suffix. Use this setting whenever you experiment with your data to be sure you start fresh on each run. Defaults to False. + The datasets are identified by `dataset_name_` + datetime suffix. Use this setting whenever you experiment with your data to be sure you start fresh on each run. Defaults to False. refresh (str | TRefreshMode): Fully or partially reset sources during pipeline run. When set here the refresh is applied on each run of the pipeline. To apply refresh only once you can pass it to `pipeline.run` or `extract` instead. The following refresh modes are supported: @@ -83,10 +83,10 @@ def pipeline( * `drop_data`: Wipe all data and resource state for all resources being processed. Schema is not modified. progress(str, Collector): A progress monitor that shows progress bars, console or log messages with current information on sources, resources, data items etc. processed in - `extract`, `normalize` and `load` stage. Pass a string with a collector name or configure your own by choosing from `dlt.progress` module. - We support most of the progress libraries: try passing `tqdm`, `enlighten` or `alive_progress` or `log` to write to console/log. + `extract`, `normalize` and `load` stage. Pass a string with a collector name or configure your own by choosing from `dlt.progress` module. + We support most of the progress libraries: try passing `tqdm`, `enlighten` or `alive_progress` or `log` to write to console/log. - #### Returns: + Returns: Pipeline: An instance of `Pipeline` class with. Please check the documentation of `run` method for information on what to do with it. 
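For reference, a minimal pipeline built with the arguments documented in this docstring; the pipeline, dataset and table names are illustrative and the duckdb extra is assumed to be installed:

import dlt

pipeline = dlt.pipeline(
    pipeline_name="docs_example",
    destination="duckdb",
    dataset_name="example_data",
    dev_mode=True,   # load into a fresh, datetime-suffixed dataset on every run
    progress="log",  # or "tqdm", "enlighten", "alive_progress"
)

load_info = pipeline.run([{"id": 1}, {"id": 2}], table_name="items")
print(load_info)
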
""" @@ -101,7 +101,7 @@ def pipeline() -> Pipeline: # type: ignore def pipeline( pipeline_name: str = None, pipelines_dir: str = None, - pipeline_salt: TSecretValue = None, + pipeline_salt: TSecretStrValue = None, destination: TDestinationReferenceArg = None, staging: TDestinationReferenceArg = None, dataset_name: str = None, @@ -170,7 +170,7 @@ def pipeline( def attach( pipeline_name: str = None, pipelines_dir: str = None, - pipeline_salt: TSecretValue = None, + pipeline_salt: TSecretStrValue = None, destination: TDestinationReferenceArg = None, staging: TDestinationReferenceArg = None, progress: TCollectorArg = _NULL_COLLECTOR, @@ -246,33 +246,33 @@ def run( Next it will make sure that data from the previous is fully processed. If not, `run` method normalizes and loads pending data items. Only then the new data from `data` argument is extracted, normalized and loaded. - #### Args: + Args: data (Any): Data to be loaded to destination destination (str | DestinationReference, optional): A name of the destination to which dlt will load the data, or a destination module imported from `dlt.destination`. - If not provided, the value passed to `dlt.pipeline` will be used. + If not provided, the value passed to `dlt.pipeline` will be used. - dataset_name (str, optional):A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. - If not provided, the value passed to `dlt.pipeline` will be used. If not provided at all then defaults to the `pipeline_name` + dataset_name (str, optional): A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. + If not provided, the value passed to `dlt.pipeline` will be used. If not provided at all then defaults to the `pipeline_name` table_name (str, optional): The name of the table to which the data should be loaded within the `dataset`. This argument is required for a `data` that is a list/Iterable or Iterator without `__name__` attribute. - The behavior of this argument depends on the type of the `data`: - * generator functions: the function name is used as table name, `table_name` overrides this default - * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! - * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. + The behavior of this argument depends on the type of the `data`: + * generator functions: the function name is used as table name, `table_name` overrides this default + * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! + * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. write_disposition (TWriteDispositionConfig, optional): Controls how to write data to a table. Accepts a shorthand string literal or configuration dictionary. - Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. 
Defaults to "append". - Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. - Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. + Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. + Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. columns (Sequence[TColumnSchema], optional): A list of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. schema (Schema, optional): An explicit `Schema` object in which all table schemas will be grouped. By default `dlt` takes the schema from the source (if passed in `data` argument) or creates a default one itself. - loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional). The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. + loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional): The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. - table_format (Literal["delta", "iceberg"], optional). The table format used by the destination to store tables. Currently you can select table format on filesystem and Athena destinations. + table_format (Literal["delta", "iceberg"], optional): The table format used by the destination to store tables. Currently you can select table format on filesystem and Athena destinations. schema_contract (TSchemaContract, optional): On override for the schema contract settings, this will replace the schema contract settings for all tables in the schema. Defaults to None. @@ -282,7 +282,7 @@ def run( * `drop_data`: Wipe all data and resource state for all resources being processed. Schema is not modified. Raises: - PipelineStepFailed when a problem happened during `extract`, `normalize` or `load` steps. + PipelineStepFailed: when a problem happened during `extract`, `normalize` or `load` steps. Returns: LoadInfo: Information on loaded data including the list of package ids and failed job statuses. Please not that `dlt` will not raise if a single job terminally fails. Such information is provided via LoadInfo. 
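The configuration-dictionary form of write_disposition described above can be exercised with a short script; a sketch assuming a local duckdb destination:

import dlt

pipeline = dlt.pipeline("scd2_demo", destination="duckdb", dataset_name="scd2_data")

# SCD2 variant of the merge disposition, as documented in the docstring
load_info = pipeline.run(
    [{"customer_id": 1, "name": "simon"}],
    table_name="customers",
    primary_key="customer_id",
    write_disposition={"disposition": "merge", "strategy": "scd2"},
)
print(load_info)

# refresh can also be passed per run, e.g. refresh="drop_data" to wipe data and resource state first
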
""" @@ -309,4 +309,4 @@ def run( trace.TRACKING_MODULES = [track, platform] # setup default pipeline in the container -Container()[PipelineContext] = PipelineContext(pipeline) +PipelineContext.cls__init__(pipeline) diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 723e0ded83..6dc0c87e10 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -3,7 +3,7 @@ import dlt from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration -from dlt.common.typing import AnyFun, TSecretValue +from dlt.common.typing import AnyFun, TSecretStrValue from dlt.common.utils import digest256 from dlt.common.destination import TLoaderFileFormat from dlt.common.pipeline import TRefreshMode @@ -22,7 +22,7 @@ class PipelineConfiguration(BaseConfiguration): dataset_name: Optional[str] = None dataset_name_layout: Optional[str] = None """Layout for dataset_name, where %s is replaced with dataset_name. For example: 'prefix_%s'""" - pipeline_salt: Optional[TSecretValue] = None + pipeline_salt: Optional[TSecretStrValue] = None restore_from_destination: bool = True """Enables the `run` method of the `Pipeline` object to restore the pipeline state and schemas from the destination""" enable_runtime_trace: bool = True @@ -44,7 +44,7 @@ def on_resolved(self) -> None: else: self.runtime.pipeline_name = self.pipeline_name if not self.pipeline_salt: - self.pipeline_salt = TSecretValue(digest256(self.pipeline_name)) + self.pipeline_salt = digest256(self.pipeline_name) if self.dataset_name_layout and "%s" not in self.dataset_name_layout: raise ConfigurationValueError( "The dataset_name_layout must contain a '%s' placeholder for dataset_name. For" diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index 2ae74e2532..91c8615149 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -1,18 +1,19 @@ """Easy access to active pipelines, state, sources and schemas""" from dlt.common.pipeline import source_state as _state, resource_state, get_current_pipe_name -from dlt.pipeline.pipeline import Pipeline -from dlt.extract.decorators import get_source_schema from dlt.common.storages.load_package import ( load_package, commit_load_package_state, destination_state, clear_destination_state, ) +from dlt.common.runtime.run_context import current as run + from dlt.extract.decorators import get_source_schema, get_source +from dlt.pipeline.pipeline import Pipeline as _Pipeline -def pipeline() -> Pipeline: +def pipeline() -> _Pipeline: """Currently active pipeline ie. 
the most recently created or run""" from dlt import _pipeline diff --git a/dlt/pipeline/dbt.py b/dlt/pipeline/dbt.py index 0b6ec5f896..85126e225d 100644 --- a/dlt/pipeline/dbt.py +++ b/dlt/pipeline/dbt.py @@ -3,7 +3,7 @@ from dlt.common.exceptions import VenvNotFound from dlt.common.runners import Venv from dlt.common.schema import Schema -from dlt.common.typing import ConfigValue, TSecretValue +from dlt.common.typing import ConfigValue, TSecretStrValue from dlt.common.schema.utils import normalize_schema_name from dlt.helpers.dbt import ( @@ -53,7 +53,7 @@ def package( pipeline: Pipeline, package_location: str, package_repository_branch: str = ConfigValue, - package_repository_ssh_key: TSecretValue = TSecretValue(""), # noqa + package_repository_ssh_key: TSecretStrValue = "", auto_full_refresh_when_out_of_sync: bool = ConfigValue, venv: Venv = None, ) -> DBTPackageRunner: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 54e576b5fc..348f445967 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -16,6 +16,7 @@ get_type_hints, ContextManager, Dict, + Literal, ) from dlt import version @@ -49,7 +50,7 @@ ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import ConfigValue, TFun, TSecretValue, is_optional_type +from dlt.common.typing import ConfigValue, TFun, TSecretStrValue, is_optional_type from dlt.common.runners import pool_runner as runner from dlt.common.storages import ( LiveSchemaStorage, @@ -82,6 +83,7 @@ DestinationClientStagingConfiguration, DestinationClientStagingConfiguration, DestinationClientDwhWithStagingConfiguration, + SupportsReadableDataset, ) from dlt.common.normalizers.naming import NamingConvention from dlt.common.pipeline import ( @@ -108,9 +110,10 @@ from dlt.extract.extract import Extract, data_to_sources from dlt.normalize import Normalize from dlt.normalize.configuration import NormalizeConfiguration -from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.sql_client import SqlClientBase, WithSqlClient from dlt.destinations.fs_client import FSClientBase from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.destinations.dataset import ReadableDBAPIDataset from dlt.load.configuration import LoaderConfiguration from dlt.load import Load @@ -320,7 +323,7 @@ def __init__( self, pipeline_name: str, pipelines_dir: str, - pipeline_salt: TSecretValue, + pipeline_salt: TSecretStrValue, destination: TDestination, staging: TDestination, dataset_name: str, @@ -444,6 +447,7 @@ def extract( workers, refresh=refresh or self.refresh, ) + # this will update state version hash so it will not be extracted again by with_state_sync self._bump_version_and_extract_state( self._container[StateInjectableContext].state, @@ -622,28 +626,28 @@ def run( Next it will make sure that data from the previous is fully processed. If not, `run` method normalizes, loads pending data items and **exits** If there was no pending data, new data from `data` argument is extracted, normalized and loaded. - #### Args: + Args: data (Any): Data to be loaded to destination destination (str | DestinationReference, optional): A name of the destination to which dlt will load the data, or a destination module imported from `dlt.destination`. - If not provided, the value passed to `dlt.pipeline` will be used. + If not provided, the value passed to `dlt.pipeline` will be used. 
- dataset_name (str, optional):A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. - If not provided, the value passed to `dlt.pipeline` will be used. If not provided at all then defaults to the `pipeline_name` + dataset_name (str, optional): A name of the dataset to which the data will be loaded. A dataset is a logical group of tables ie. `schema` in relational databases or folder grouping many files. + If not provided, the value passed to `dlt.pipeline` will be used. If not provided at all then defaults to the `pipeline_name` credentials (Any, optional): Credentials for the `destination` ie. database connection string or a dictionary with google cloud credentials. - In most cases should be set to None, which lets `dlt` to use `secrets.toml` or environment variables to infer right credentials values. + In most cases should be set to None, which lets `dlt` to use `secrets.toml` or environment variables to infer right credentials values. table_name (str, optional): The name of the table to which the data should be loaded within the `dataset`. This argument is required for a `data` that is a list/Iterable or Iterator without `__name__` attribute. - The behavior of this argument depends on the type of the `data`: - * generator functions: the function name is used as table name, `table_name` overrides this default - * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! - * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. + The behavior of this argument depends on the type of the `data`: + * generator functions - the function name is used as table name, `table_name` overrides this default + * `@dlt.resource` - resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! + * `@dlt.source` - source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. write_disposition (TWriteDispositionConfig, optional): Controls how to write data to a table. Accepts a shorthand string literal or configuration dictionary. - Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". - Write behaviour can be further customized through a configuration dictionary. For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. - Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. + Allowed shorthand string literals: `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + Write behaviour can be further customized through a configuration dictionary. 
For example, to obtain an SCD2 table provide `write_disposition={"disposition": "merge", "strategy": "scd2"}`. + Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. columns (Sequence[TColumnSchema], optional): A list of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. @@ -651,19 +655,19 @@ def run( schema (Schema, optional): An explicit `Schema` object in which all table schemas will be grouped. By default `dlt` takes the schema from the source (if passed in `data` argument) or creates a default one itself. - loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional). The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. + loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional): The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. - table_format (Literal["delta", "iceberg"], optional). The table format used by the destination to store tables. Currently you can select table format on filesystem and Athena destinations. + table_format (Literal["delta", "iceberg"], optional): The table format used by the destination to store tables. Currently you can select table format on filesystem and Athena destinations. schema_contract (TSchemaContract, optional): On override for the schema contract settings, this will replace the schema contract settings for all tables in the schema. Defaults to None. refresh (str | TRefreshMode): Fully or partially reset sources before loading new data in this run. The following refresh modes are supported: - * `drop_sources`: Drop tables and source and resource state for all sources currently being processed in `run` or `extract` methods of the pipeline. (Note: schema history is erased) - * `drop_resources`: Drop tables and resource state for all resources being processed. Source level state is not modified. (Note: schema history is erased) - * `drop_data`: Wipe all data and resource state for all resources being processed. Schema is not modified. + * `drop_sources` - Drop tables and source and resource state for all sources currently being processed in `run` or `extract` methods of the pipeline. (Note: schema history is erased) + * `drop_resources`- Drop tables and resource state for all resources being processed. Source level state is not modified. (Note: schema history is erased) + * `drop_data` - Wipe all data and resource state for all resources being processed. Schema is not modified. Raises: - PipelineStepFailed when a problem happened during `extract`, `normalize` or `load` steps. + PipelineStepFailed: when a problem happened during `extract`, `normalize` or `load` steps. Returns: LoadInfo: Information on loaded data including the list of package ids and failed job statuses. Please not that `dlt` will not raise if a single job terminally fails. Such information is provided via LoadInfo. """ @@ -1005,7 +1009,12 @@ def sql_client(self, schema_name: str = None) -> SqlClientBase[Any]: # "Sql Client is not available in a pipeline without a default schema. 
Extract some data first or restore the pipeline from the destination using 'restore_from_destination' flag. There's also `_inject_schema` method for advanced users." # ) schema = self._get_schema_or_create(schema_name) - return self._sql_job_client(schema).sql_client + client_config = self._get_destination_client_initial_config() + client = self._get_destination_clients(schema, client_config)[0] + if isinstance(client, WithSqlClient): + return client.sql_client + else: + raise SqlClientNotAvailable(self.pipeline_name, self.destination.destination_name) def _fs_client(self, schema_name: str = None) -> FSClientBase: """Returns a filesystem client configured to point to the right folder / bucket for each table. @@ -1707,3 +1716,11 @@ def _save_state(self, state: TPipelineState) -> None: def __getstate__(self) -> Any: # pickle only the SupportsPipeline protocol fields return {"pipeline_name": self.pipeline_name} + + def _dataset(self, dataset_type: Literal["dbapi", "ibis"] = "dbapi") -> SupportsReadableDataset: + """Access helper to dataset""" + if dataset_type == "dbapi": + return ReadableDBAPIDataset( + self.sql_client(), schema=self.default_schema if self.default_schema_name else None + ) + raise NotImplementedError(f"Dataset of type {dataset_type} not implemented") diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index c47926e5f4..007a819729 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -24,7 +24,7 @@ StepMetrics, SupportsPipeline, ) -from dlt.common.source import get_current_pipe_name +from dlt.common.pipeline import get_current_pipe_name from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import DictStrAny, StrAny, SupportsHumanize from dlt.common.utils import uniq_id, get_exception_trace_chain diff --git a/dlt/sources/__init__.py b/dlt/sources/__init__.py index dcfc281160..4ee30d2fdd 100644 --- a/dlt/sources/__init__.py +++ b/dlt/sources/__init__.py @@ -1,12 +1,14 @@ """Module with built in sources and source building blocks""" from dlt.common.typing import TDataItem, TDataItems from dlt.extract import DltSource, DltResource, Incremental as incremental -from . import credentials -from . import config +from dlt.extract.source import SourceReference +from . 
import credentials, config + __all__ = [ "DltSource", "DltResource", + "SourceReference", "TDataItem", "TDataItems", "incremental", diff --git a/dlt/sources/filesystem/__init__.py b/dlt/sources/filesystem/__init__.py index 80dabe7e66..66e69624c2 100644 --- a/dlt/sources/filesystem/__init__.py +++ b/dlt/sources/filesystem/__init__.py @@ -2,6 +2,7 @@ from typing import Iterator, List, Optional, Tuple, Union import dlt +from dlt.extract import decorators from dlt.common.storages.fsspec_filesystem import ( FileItem, FileItemDict, @@ -25,7 +26,7 @@ from dlt.sources.filesystem.settings import DEFAULT_CHUNK_SIZE -@dlt.source(_impl_cls=ReadersSource, spec=FilesystemConfigurationResource) +@decorators.source(_impl_cls=ReadersSource, spec=FilesystemConfigurationResource) def readers( bucket_url: str = dlt.secrets.value, credentials: Union[FileSystemCredentials, AbstractFileSystem] = dlt.secrets.value, @@ -54,7 +55,7 @@ def readers( ) -@dlt.resource(primary_key="file_url", spec=FilesystemConfigurationResource, standalone=True) +@decorators.resource(primary_key="file_url", spec=FilesystemConfigurationResource, standalone=True) def filesystem( bucket_url: str = dlt.secrets.value, credentials: Union[FileSystemCredentials, AbstractFileSystem] = dlt.secrets.value, @@ -96,7 +97,7 @@ def filesystem( yield files_chunk -read_csv = dlt.transformer(standalone=True)(_read_csv) -read_jsonl = dlt.transformer(standalone=True)(_read_jsonl) -read_parquet = dlt.transformer(standalone=True)(_read_parquet) -read_csv_duckdb = dlt.transformer(standalone=True)(_read_csv_duckdb) +read_csv = decorators.transformer(standalone=True)(_read_csv) +read_jsonl = decorators.transformer(standalone=True)(_read_jsonl) +read_parquet = decorators.transformer(standalone=True)(_read_parquet) +read_csv_duckdb = decorators.transformer(standalone=True)(_read_csv_duckdb) diff --git a/dlt/sources/helpers/requests/retry.py b/dlt/sources/helpers/requests/retry.py index 7d7d6493ec..3268fd77c8 100644 --- a/dlt/sources/helpers/requests/retry.py +++ b/dlt/sources/helpers/requests/retry.py @@ -153,7 +153,7 @@ class Client: The retry is triggered when either any of the predicates or the default conditions based on status code/exception are `True`. - #### Args: + Args: request_timeout: Timeout for requests in seconds. May be passed as `timedelta` or `float/int` number of seconds. max_connections: Max connections per host in the HTTPAdapter pool raise_for_status: Whether to raise exception on error status codes (using `response.raise_for_status()`) diff --git a/dlt/sources/helpers/requests/session.py b/dlt/sources/helpers/requests/session.py index 5ba4d9b611..8f05feabb2 100644 --- a/dlt/sources/helpers/requests/session.py +++ b/dlt/sources/helpers/requests/session.py @@ -24,7 +24,7 @@ def _timeout_to_seconds(timeout: TRequestTimeout) -> Optional[Union[Tuple[float, class Session(BaseSession): """Requests session which by default adds a timeout to all requests and calls `raise_for_status()` on response - #### Args: + Args: timeout: Timeout for requests in seconds. May be passed as `timedelta` or `float/int` number of seconds. May be a single value or a tuple for separate (connect, read) timeout. 
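Tying together the Pipeline.sql_client() rework and the new, still private _dataset() helper shown earlier, a hedged sketch of inspecting loaded data (duckdb assumed; the readable-dataset API is experimental in this change set and may still change):

import dlt

pipeline = dlt.pipeline("fruitshop", destination="duckdb", dataset_name="fruitshop_data")
pipeline.run([{"id": 1, "name": "apple"}], table_name="inventory")

# sql_client() now resolves the destination client and checks for WithSqlClient support
with pipeline.sql_client() as client:
    with client.execute_query("SELECT count(*) FROM inventory") as cur:
        print(cur.fetchall())

# dbapi-backed readable dataset added above; relation-level access is left out here on purpose
dataset = pipeline._dataset()
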
raise_for_status: Whether to raise exception on error status codes (using `response.raise_for_status()`) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 31c52527da..988ce65549 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -24,7 +24,6 @@ from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.pendulum import pendulum from dlt.common.typing import TSecretStrValue -from dlt.sources.helpers import requests if TYPE_CHECKING: from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes @@ -54,7 +53,7 @@ class BearerTokenAuth(AuthConfigBase): def parse_native_representation(self, value: Any) -> None: if isinstance(value, str): - self.token = cast(TSecretStrValue, value) + self.token = value else: raise NativeValueError( type(self), @@ -77,7 +76,7 @@ class APIKeyAuth(AuthConfigBase): def parse_native_representation(self, value: Any) -> None: if isinstance(value, str): - self.api_key = cast(TSecretStrValue, value) + self.api_key = value else: raise NativeValueError( type(self), @@ -130,7 +129,7 @@ class OAuth2AuthBase(AuthConfigBase): def parse_native_representation(self, value: Any) -> None: if isinstance(value, str): - self.access_token = cast(TSecretStrValue, value) + self.access_token = value else: raise NativeValueError( type(self), @@ -146,7 +145,7 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: @configspec class OAuth2ClientCredentials(OAuth2AuthBase): """ - This class implements OAuth2 Client Credentials flow where the autorization service + This class implements OAuth2 Client Credentials flow where the authorization service gives permission without the end user approving. This is often used for machine-to-machine authorization. The client sends its client ID and client secret to the authorization service which replies @@ -154,27 +153,25 @@ class OAuth2ClientCredentials(OAuth2AuthBase): With the access token, the client can access resource services. 
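As the continuation of this hunk shows, OAuth2ClientCredentials is converted to declarative @configspec fields, so its values can either be passed as keyword arguments or injected from secrets.toml. A usage sketch with placeholder endpoint and credentials:

from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.auth import OAuth2ClientCredentials

auth = OAuth2ClientCredentials(
    access_token_url="https://example.com/oauth/token",  # placeholder token endpoint
    client_id="my-client-id",
    client_secret="my-client-secret",
    default_token_expiration=3600,
)

client = RESTClient(base_url="https://example.com/api", auth=auth)
# client.get("/resource")  # would obtain a token lazily and attach it as a Bearer header
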
""" - def __init__( - self, - access_token_url: str, - client_id: TSecretStrValue, - client_secret: TSecretStrValue, - access_token_request_data: Dict[str, Any] = None, - default_token_expiration: int = 3600, - session: Annotated[BaseSession, NotResolved()] = None, - ) -> None: - super().__init__() - self.access_token_url = access_token_url - self.client_id = client_id - self.client_secret = client_secret - if access_token_request_data is None: + access_token: Annotated[Optional[TSecretStrValue], NotResolved()] = None + access_token_url: str = None + client_id: TSecretStrValue = None + client_secret: TSecretStrValue = None + access_token_request_data: Dict[str, Any] = None + default_token_expiration: int = 3600 + session: Annotated[BaseSession, NotResolved()] = None + + def __post_init__(self) -> None: + if self.access_token_request_data is None: self.access_token_request_data = {} else: - self.access_token_request_data = access_token_request_data - self.default_token_expiration = default_token_expiration + self.access_token_request_data = self.access_token_request_data self.token_expiry: pendulum.DateTime = pendulum.now() + # use default system session unless specified otherwise + if self.session is None: + from dlt.sources.helpers import requests - self.session = session if session is not None else requests.client.session + self.session = requests.client.session def __call__(self, request: PreparedRequest) -> PreparedRequest: if self.access_token is None or self.is_token_expired(): @@ -235,6 +232,8 @@ def __post_init__(self) -> None: self.token_expiry: Optional[pendulum.DateTime] = None # use default system session unless specified otherwise if self.session is None: + from dlt.sources.helpers import requests + self.session = requests.client.session def __call__(self, r: PreparedRequest) -> PreparedRequest: diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index c05dabc30c..86e72ccf4c 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -16,8 +16,6 @@ from dlt.common import jsonpath, logger -from dlt.sources.helpers.requests.retry import Client - from .typing import HTTPMethodBasic, HTTPMethod, Hooks from .paginators import BasePaginator from .detector import PaginatorFactory, find_response_page_data @@ -84,6 +82,8 @@ def __init__( # has raise_for_status=True by default self.session = _warn_if_raise_for_status_and_return(session) else: + from dlt.sources.helpers.requests.retry import Client + self.session = Client(raise_for_status=False).session self.paginator = paginator diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index 872d4f34e8..82b97e253b 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -1,9 +1,10 @@ import warnings from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from urllib.parse import urlparse, urljoin +from urllib.parse import urljoin, urlparse + +from requests import Request, Response -from requests import Response, Request from dlt.common import jsonpath @@ -127,6 +128,7 @@ def __init__( " provided." 
) self.param_name = param_name + self.initial_value = initial_value self.current_value = initial_value self.value_step = value_step self.base_index = base_index @@ -136,6 +138,8 @@ def __init__( self.stop_after_empty_page = stop_after_empty_page def init_request(self, request: Request) -> None: + self._has_next_page = True + self.current_value = self.initial_value if request.params is None: request.params = {} diff --git a/dlt/sources/pipeline_templates/arrow_pipeline.py b/dlt/sources/pipeline_templates/arrow_pipeline.py index 92ed0664b9..e91f6e35f2 100644 --- a/dlt/sources/pipeline_templates/arrow_pipeline.py +++ b/dlt/sources/pipeline_templates/arrow_pipeline.py @@ -24,7 +24,7 @@ def add_updated_at(item: pa.Table): return item.set_column(column_count, "updated_at", [[time.time()] * item.num_rows]) -# apply tranformer to resource +# apply transformer to resource resource.add_map(add_updated_at) diff --git a/dlt/sources/pipeline_templates/default_pipeline.py b/dlt/sources/pipeline_templates/default_pipeline.py index 9fa03f9ce5..e7cd0a5d39 100644 --- a/dlt/sources/pipeline_templates/default_pipeline.py +++ b/dlt/sources/pipeline_templates/default_pipeline.py @@ -1,51 +1,111 @@ -"""The Default Pipeline Template provides a simple starting point for your dlt pipeline""" +"""The Intro Pipeline Template contains the example from the docs intro page""" # mypy: disable-error-code="no-untyped-def,arg-type" +from typing import Optional +import pandas as pd +import sqlalchemy as sa + import dlt -from dlt.common import Decimal +from dlt.sources.helpers import requests -@dlt.resource(name="customers", primary_key="id") -def customers(): - """Load customer data from a simple python list.""" - yield [ - {"id": 1, "name": "simon", "city": "berlin"}, - {"id": 2, "name": "violet", "city": "london"}, - {"id": 3, "name": "tammo", "city": "new york"}, - ] +def load_api_data() -> None: + """Load data from the chess api, for more complex examples use our rest_api source""" + # Create a dlt pipeline that will load + # chess player data to the DuckDB destination + pipeline = dlt.pipeline( + pipeline_name="chess_pipeline", destination="duckdb", dataset_name="player_data" + ) + # Grab some player data from Chess.com API + data = [] + for player in ["magnuscarlsen", "rpragchess"]: + response = requests.get(f"https://api.chess.com/pub/player/{player}") + response.raise_for_status() + data.append(response.json()) -@dlt.resource(name="inventory", primary_key="id") -def inventory(): - """Load inventory data from a simple python list.""" - yield [ - {"id": 1, "name": "apple", "price": Decimal("1.50")}, - {"id": 2, "name": "banana", "price": Decimal("1.70")}, - {"id": 3, "name": "pear", "price": Decimal("2.50")}, - ] + # Extract, normalize, and load the data + load_info = pipeline.run(data, table_name="player") + print(load_info) # noqa: T201 -@dlt.source(name="my_fruitshop") -def source(): - """A source function groups all resources into one schema.""" - return customers(), inventory() +def load_pandas_data() -> None: + """Load data from a public csv via pandas""" + owid_disasters_csv = ( + "https://raw.githubusercontent.com/owid/owid-datasets/master/datasets/" + "Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020)/" + "Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020).csv" + ) + df = pd.read_csv(owid_disasters_csv) -def load_stuff() -> None: - # specify the pipeline name, destination and dataset name when configuring pipeline, - # otherwise the defaults will be used that are derived from 
the current script name - p = dlt.pipeline( - pipeline_name="fruitshop", + pipeline = dlt.pipeline( + pipeline_name="from_csv", destination="duckdb", - dataset_name="fruitshop_data", + dataset_name="mydata", ) + load_info = pipeline.run(df, table_name="natural_disasters") + + print(load_info) # noqa: T201 + + +def load_sql_data() -> None: + """Load data from a sql database with sqlalchemy, for more complex examples use our sql_database source""" + + # Use any SQL database supported by SQLAlchemy, below we use a public + # MySQL instance to get data. + # NOTE: you'll need to install pymysql with `pip install pymysql` + # NOTE: loading data from public mysql instance may take several seconds + engine = sa.create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") + + with engine.connect() as conn: + # Select genome table, stream data in batches of 100 elements + query = "SELECT * FROM genome LIMIT 1000" + rows = conn.execution_options(yield_per=100).exec_driver_sql(query) + + pipeline = dlt.pipeline( + pipeline_name="from_database", + destination="duckdb", + dataset_name="genome_data", + ) - load_info = p.run(source()) + # Convert the rows into dictionaries on the fly with a map function + load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome") - # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +@dlt.resource(write_disposition="replace") +def github_api_resource(api_secret_key: Optional[str] = dlt.secrets.value): + from dlt.sources.helpers.rest_client import paginate + from dlt.sources.helpers.rest_client.auth import BearerTokenAuth + from dlt.sources.helpers.rest_client.paginators import HeaderLinkPaginator + + url = "https://api.github.com/repos/dlt-hub/dlt/issues" + + # Github allows both authenticated and non-authenticated requests (with low rate limits) + auth = BearerTokenAuth(api_secret_key) if api_secret_key else None + for page in paginate( + url, auth=auth, paginator=HeaderLinkPaginator(), params={"state": "open", "per_page": "100"} + ): + yield page + + +@dlt.source +def github_api_source(api_secret_key: Optional[str] = dlt.secrets.value): + return github_api_resource(api_secret_key=api_secret_key) + + +def load_data_from_source(): + pipeline = dlt.pipeline( + pipeline_name="github_api_pipeline", destination="duckdb", dataset_name="github_api_data" + ) + load_info = pipeline.run(github_api_source()) print(load_info) # noqa: T201 if __name__ == "__main__": - load_stuff() + load_api_data() + load_pandas_data() + load_sql_data() diff --git a/dlt/sources/pipeline_templates/fruitshop_pipeline.py b/dlt/sources/pipeline_templates/fruitshop_pipeline.py new file mode 100644 index 0000000000..574774aa1c --- /dev/null +++ b/dlt/sources/pipeline_templates/fruitshop_pipeline.py @@ -0,0 +1,51 @@ +"""The Default Pipeline Template provides a simple starting point for your dlt pipeline""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +import dlt +from dlt.common import Decimal + + +@dlt.resource(primary_key="id") +def customers(): + """Load customer data from a simple python list.""" + yield [ + {"id": 1, "name": "simon", "city": "berlin"}, + {"id": 2, "name": "violet", "city": "london"}, + {"id": 3, "name": "tammo", "city": "new york"}, + ] + + +@dlt.resource(primary_key="id") +def inventory(): + """Load inventory data from a simple python list.""" + yield [ + {"id": 1, "name": "apple", "price": Decimal("1.50")}, + {"id": 2, "name": "banana", "price": Decimal("1.70")}, + {"id": 3, "name": 
"pear", "price": Decimal("2.50")}, + ] + + +@dlt.source +def fruitshop(): + """A source function groups all resources into one schema.""" + return customers(), inventory() + + +def load_shop() -> None: + # specify the pipeline name, destination and dataset name when configuring pipeline, + # otherwise the defaults will be used that are derived from the current script name + p = dlt.pipeline( + pipeline_name="fruitshop", + destination="duckdb", + dataset_name="fruitshop_data", + ) + + load_info = p.run(fruitshop()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + load_shop() diff --git a/dlt/sources/pipeline_templates/github_api_pipeline.py b/dlt/sources/pipeline_templates/github_api_pipeline.py new file mode 100644 index 0000000000..80cac0c525 --- /dev/null +++ b/dlt/sources/pipeline_templates/github_api_pipeline.py @@ -0,0 +1,51 @@ +"""The Github API templates provides a starting point to read data from REST APIs with REST Client helper""" + +# mypy: disable-error-code="no-untyped-def,arg-type" + +from typing import Optional + +import dlt + +from dlt.sources.helpers.rest_client import paginate +from dlt.sources.helpers.rest_client.auth import BearerTokenAuth +from dlt.sources.helpers.rest_client.paginators import HeaderLinkPaginator + + +@dlt.resource(write_disposition="replace") +def github_api_resource(api_secret_key: Optional[str] = dlt.secrets.value): + url = "https://api.github.com/repos/dlt-hub/dlt/issues" + + # Github allows both authenticated and non-authenticated requests (with low rate limits) + auth = BearerTokenAuth(api_secret_key) if api_secret_key else None + for page in paginate( + url, auth=auth, paginator=HeaderLinkPaginator(), params={"state": "open", "per_page": "100"} + ): + yield page + + +@dlt.source +def github_api_source(api_secret_key: Optional[str] = dlt.secrets.value): + return github_api_resource(api_secret_key=api_secret_key) + + +def run_source() -> None: + # configure the pipeline with your destination details + pipeline = dlt.pipeline( + pipeline_name="github_api_pipeline", destination="duckdb", dataset_name="github_api_data" + ) + + # print credentials by running the resource + data = list(github_api_resource()) + + # print the data yielded from resource + print(data) # noqa: T201 + + # run the pipeline with your parameters + load_info = pipeline.run(github_api_source()) + + # pretty print the information on data that was loaded + print(load_info) # noqa: T201 + + +if __name__ == "__main__": + run_source() diff --git a/dlt/sources/pipeline_templates/intro_pipeline.py b/dlt/sources/pipeline_templates/intro_pipeline.py deleted file mode 100644 index a4de18daba..0000000000 --- a/dlt/sources/pipeline_templates/intro_pipeline.py +++ /dev/null @@ -1,82 +0,0 @@ -"""The Intro Pipeline Template contains the example from the docs intro page""" - -# mypy: disable-error-code="no-untyped-def,arg-type" - -import pandas as pd -import sqlalchemy as sa - -import dlt -from dlt.sources.helpers import requests - - -def load_api_data() -> None: - """Load data from the chess api, for more complex examples use our rest_api source""" - - # Create a dlt pipeline that will load - # chess player data to the DuckDB destination - pipeline = dlt.pipeline( - pipeline_name="chess_pipeline", destination="duckdb", dataset_name="player_data" - ) - # Grab some player data from Chess.com API - data = [] - for player in ["magnuscarlsen", "rpragchess"]: - response = 
requests.get(f"https://api.chess.com/pub/player/{player}") - response.raise_for_status() - data.append(response.json()) - - # Extract, normalize, and load the data - load_info = pipeline.run(data, table_name="player") - print(load_info) # noqa: T201 - - -def load_pandas_data() -> None: - """Load data from a public csv via pandas""" - - owid_disasters_csv = ( - "https://raw.githubusercontent.com/owid/owid-datasets/master/datasets/" - "Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020)/" - "Natural%20disasters%20from%201900%20to%202019%20-%20EMDAT%20(2020).csv" - ) - df = pd.read_csv(owid_disasters_csv) - data = df.to_dict(orient="records") - - pipeline = dlt.pipeline( - pipeline_name="from_csv", - destination="duckdb", - dataset_name="mydata", - ) - load_info = pipeline.run(data, table_name="natural_disasters") - - print(load_info) # noqa: T201 - - -def load_sql_data() -> None: - """Load data from a sql database with sqlalchemy, for more complex examples use our sql_database source""" - - # Use any SQL database supported by SQLAlchemy, below we use a public - # MySQL instance to get data. - # NOTE: you'll need to install pymysql with `pip install pymysql` - # NOTE: loading data from public mysql instance may take several seconds - engine = sa.create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") - - with engine.connect() as conn: - # Select genome table, stream data in batches of 100 elements - query = "SELECT * FROM genome LIMIT 1000" - rows = conn.execution_options(yield_per=100).exec_driver_sql(query) - - pipeline = dlt.pipeline( - pipeline_name="from_database", - destination="duckdb", - dataset_name="genome_data", - ) - - # Convert the rows into dictionaries on the fly with a map function - load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome") - - print(load_info) # noqa: T201 - - -if __name__ == "__main__": - load_api_data() - load_pandas_data() - load_sql_data() diff --git a/dlt/sources/pipeline_templates/requests_pipeline.py b/dlt/sources/pipeline_templates/requests_pipeline.py index 19acaa1fdb..14c30ec35d 100644 --- a/dlt/sources/pipeline_templates/requests_pipeline.py +++ b/dlt/sources/pipeline_templates/requests_pipeline.py @@ -15,7 +15,7 @@ BASE_PATH = "https://api.chess.com/pub/player" -@dlt.resource(name="players", primary_key="player_id") +@dlt.resource(primary_key="player_id") def players(): """Load player profiles from the chess api.""" for player_name in ["magnuscarlsen", "rpragchess"]: @@ -37,7 +37,7 @@ def players_games(player: Any) -> Iterator[TDataItems]: @dlt.source(name="chess") -def source(): +def chess(): """A source function groups all resources into one schema.""" return players(), players_games() @@ -51,7 +51,7 @@ def load_chess_data() -> None: dataset_name="chess_data", ) - load_info = p.run(source()) + load_info = p.run(chess()) # pretty print the information on data that was loaded print(load_info) # noqa: T201 diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py index b92ed6301c..1be634f2e5 100644 --- a/dlt/sources/rest_api/__init__.py +++ b/dlt/sources/rest_api/__init__.py @@ -1,6 +1,6 @@ """Generic API Source""" from copy import deepcopy -from typing import Type, Any, Dict, List, Optional, Generator, Callable, cast, Union +from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union import graphlib # type: ignore[import,unused-ignore] from requests.auth import AuthBase @@ -9,10 +9,8 @@ from dlt.common import jsonpath from 
dlt.common.schema.schema import Schema from dlt.common.schema.typing import TSchemaContract -from dlt.common.configuration.specs import BaseConfiguration -from dlt.extract.incremental import Incremental -from dlt.extract.source import DltResource, DltSource +from dlt.extract import Incremental, DltResource, DltSource, decorators from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.paginators import BasePaginator @@ -26,6 +24,7 @@ from .typing import ( AuthConfig, ClientConfig, + EndpointResourceBase, ResolvedParam, ResolveParamConfig, Endpoint, @@ -56,6 +55,18 @@ ] +@decorators.source +def rest_api( + client: ClientConfig = dlt.config.value, + resources: List[Union[str, EndpointResource, DltResource]] = dlt.config.value, + resource_defaults: Optional[EndpointResourceBase] = None, +) -> List[DltResource]: + """Creates and configures a REST API source with default settings""" + return rest_api_resources( + {"client": client, "resources": resources, "resource_defaults": resource_defaults} + ) + + def rest_api_source( config: RESTAPIConfig, name: str = None, @@ -64,7 +75,7 @@ def rest_api_source( root_key: bool = False, schema: Schema = None, schema_contract: TSchemaContract = None, - spec: Type[BaseConfiguration] = None, + parallelized: bool = False, ) -> DltSource: """Creates and configures a REST API source for data extraction. @@ -85,8 +96,9 @@ def rest_api_source( will be loaded from file. schema_contract (TSchemaContract, optional): Schema contract settings that will be applied to this resource. - spec (Type[BaseConfiguration], optional): A specification of configuration - and secret values required by the source. + parallelized (bool, optional): If `True`, resource generators will be + extracted in parallel with other resources. Transformers that return items are also parallelized. + Non-eligible resources are ignored. Defaults to `False` which preserves resource settings. Returns: DltSource: A configured dlt source. @@ -109,18 +121,20 @@ def rest_api_source( }, }) """ - decorated = dlt.source( - rest_api_resources, - name, - section, - max_table_nesting, - root_key, - schema, - schema_contract, - spec, + # TODO: this must be removed when TypedDicts are supported by resolve_configuration + # so secrets values are bound BEFORE validation. 
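The new decorated rest_api source and the parallelized flag documented above can be combined into a declarative pipeline; a sketch against the public GitHub API (resource, pipeline and dataset names are illustrative):

import dlt
from dlt.sources.rest_api import rest_api_source

source = rest_api_source(
    {
        "client": {"base_url": "https://api.github.com/repos/dlt-hub/dlt/"},
        "resources": [
            {
                "name": "issues",
                "endpoint": {"path": "issues", "params": {"state": "open", "per_page": 100}},
            }
        ],
    },
    name="github_issues",
    parallelized=True,  # parallelize eligible resource generators
)

pipeline = dlt.pipeline("rest_api_github", destination="duckdb", dataset_name="github")
print(pipeline.run(source))
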
validation will happen during the resolve process + _validate_config(config) + decorated = rest_api.with_args( + name=name, + section=section, + max_table_nesting=max_table_nesting, + root_key=root_key, + schema=schema, + schema_contract=schema_contract, + parallelized=parallelized, ) - return decorated(config) + return decorated(**config) def rest_api_resources(config: RESTAPIConfig) -> List[DltResource]: @@ -186,7 +200,7 @@ def rest_api_resources(config: RESTAPIConfig) -> List[DltResource]: _validate_config(config) client_config = config["client"] - resource_defaults = config.get("resource_defaults", {}) + resource_defaults = config.get("resource_defaults") or {} resource_list = config["resources"] ( @@ -212,7 +226,7 @@ def create_resources( client_config: ClientConfig, dependency_graph: graphlib.TopologicalSorter, endpoint_resource_map: Dict[str, Union[EndpointResource, DltResource]], - resolved_param_map: Dict[str, Optional[ResolvedParam]], + resolved_param_map: Dict[str, Optional[List[ResolvedParam]]], ) -> Dict[str, DltResource]: resources = {} @@ -229,10 +243,10 @@ def create_resources( paginator = create_paginator(endpoint_config.get("paginator")) processing_steps = endpoint_resource.pop("processing_steps", []) - resolved_param: ResolvedParam = resolved_param_map[resource_name] + resolved_params: List[ResolvedParam] = resolved_param_map[resource_name] include_from_parent: List[str] = endpoint_resource.get("include_from_parent", []) - if not resolved_param and include_from_parent: + if not resolved_params and include_from_parent: raise ValueError( f"Resource {resource_name} has include_from_parent but is not " "dependent on another resource" @@ -267,7 +281,7 @@ def process( resource.add_map(step["map"]) return resource - if resolved_param is None: + if resolved_params is None: def paginate_resource( method: HTTPMethodBasic, @@ -318,9 +332,10 @@ def paginate_resource( resources[resource_name] = process(resources[resource_name], processing_steps) else: - predecessor = resources[resolved_param.resolve_config["resource"]] + first_param = resolved_params[0] + predecessor = resources[first_param.resolve_config["resource"]] - base_params = exclude_keys(request_params, {resolved_param.param_name}) + base_params = exclude_keys(request_params, {x.param_name for x in resolved_params}) def paginate_dependent_resource( items: List[Dict[str, Any]], @@ -331,7 +346,7 @@ def paginate_dependent_resource( data_selector: Optional[jsonpath.TJsonPath], hooks: Optional[Dict[str, Any]], client: RESTClient = client, - resolved_param: ResolvedParam = resolved_param, + resolved_params: List[ResolvedParam] = resolved_params, include_from_parent: List[str] = include_from_parent, incremental_object: Optional[Incremental[Any]] = incremental_object, incremental_param: Optional[IncrementalParam] = incremental_param, @@ -349,7 +364,7 @@ def paginate_dependent_resource( for item in items: formatted_path, parent_record = process_parent_data_item( - path, item, resolved_param, include_from_parent + path, item, resolved_params, include_from_parent ) for child_page in client.paginate( @@ -395,7 +410,12 @@ def _validate_config(config: RESTAPIConfig) -> None: def _mask_secrets(auth_config: AuthConfig) -> AuthConfig: - if isinstance(auth_config, AuthBase) and not isinstance(auth_config, AuthConfigBase): + # skip AuthBase (derived from requests lib) or shorthand notation + if ( + isinstance(auth_config, AuthBase) + and not isinstance(auth_config, AuthConfigBase) + or isinstance(auth_config, str) + ): return auth_config 
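create_resources and the dependency resolution now carry a list of resolved params, so a child endpoint may interpolate several values from each parent record as long as they all point at the same parent resource (see the config_setup changes further below). A hedged configuration fragment; paths and field names are illustrative:

# two "resolve" params bound from the same parent resource, filling two path placeholders
resources_fragment = [
    {"name": "orgs", "endpoint": {"path": "orgs"}},
    {
        "name": "repo_issues",
        "endpoint": {
            "path": "orgs/{org}/repos/{repo}/issues",
            "params": {
                "org": {"type": "resolve", "resource": "orgs", "field": "login"},
                "repo": {"type": "resolve", "resource": "orgs", "field": "default_repo"},
            },
        },
    },
]
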
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS) @@ -449,22 +469,3 @@ def _validate_param_type( raise ValueError( f"Invalid param type: {value.get('type')}. Available options: {PARAM_TYPES}" ) - - -# XXX: This is a workaround pass test_dlt_init.py -# since the source uses dlt.source as a function -def _register_source(source_func: Callable[..., DltSource]) -> None: - import inspect - from dlt.common.configuration import get_fun_spec - from dlt.common.source import _SOURCES, SourceInfo - - spec = get_fun_spec(source_func) - func_module = inspect.getmodule(source_func) - _SOURCES[source_func.__name__] = SourceInfo( - SPEC=spec, - f=source_func, - module=func_module, - ) - - -_register_source(rest_api_source) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 916715b214..0f9857b45a 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -14,8 +14,8 @@ ) import graphlib # type: ignore[import,unused-ignore] import string +from requests import Response -import dlt from dlt.common import logger from dlt.common.configuration import resolve_configuration from dlt.common.schema.utils import merge_columns @@ -25,7 +25,6 @@ from dlt.extract.incremental import Incremental from dlt.extract.utils import ensure_table_schema_columns -from dlt.sources.helpers.requests import Response from dlt.sources.helpers.rest_client.paginators import ( BasePaginator, SinglePagePaginator, @@ -177,14 +176,14 @@ def create_auth(auth_config: Optional[AuthConfig]) -> Optional[AuthConfigBase]: if isinstance(auth_config, dict): auth_type = auth_config.get("type", "bearer") auth_class = get_auth_class(auth_type) - auth = auth_class(**exclude_keys(auth_config, {"type"})) + auth = auth_class.from_init_value(exclude_keys(auth_config, {"type"})) - if auth: + if auth and not auth.__is_resolved__: # TODO: provide explicitly (non-default) values as explicit explicit_value=dict(auth) # this will resolve auth which is a configuration using current section context - return resolve_configuration(auth, accept_partial=True) + auth = resolve_configuration(auth, accept_partial=False) - return None + return auth def setup_incremental_object( @@ -196,7 +195,7 @@ def setup_incremental_object( if ( isinstance(param_config, dict) and param_config.get("type") == "incremental" - or isinstance(param_config, dlt.sources.incremental) + or isinstance(param_config, Incremental) ): incremental_params.append(param_name) if len(incremental_params) > 1: @@ -206,7 +205,7 @@ def setup_incremental_object( ) convert: Optional[Callable[..., Any]] for param_name, param_config in request_params.items(): - if isinstance(param_config, dlt.sources.incremental): + if isinstance(param_config, Incremental): if param_config.end_value is not None: raise ValueError( f"Only initial_value is allowed in the configuration of param: {param_name}. 
To" @@ -228,7 +227,7 @@ def setup_incremental_object( config = exclude_keys(param_config, {"type", "convert", "transform"}) # TODO: implement param type to bind incremental to return ( - dlt.sources.incremental(**config), + Incremental(**config), IncrementalParam(start=param_name, end=None), convert, ) @@ -238,7 +237,7 @@ def setup_incremental_object( incremental_config, {"start_param", "end_param", "convert", "transform"} ) return ( - dlt.sources.incremental(**config), + Incremental(**config), IncrementalParam( start=incremental_config["start_param"], end=incremental_config.get("end_param"), @@ -273,10 +272,10 @@ def build_resource_dependency_graph( resource_defaults: EndpointResourceBase, resource_list: List[Union[str, EndpointResource, DltResource]], ) -> Tuple[ - Any, Dict[str, Union[EndpointResource, DltResource]], Dict[str, Optional[ResolvedParam]] + Any, Dict[str, Union[EndpointResource, DltResource]], Dict[str, Optional[List[ResolvedParam]]] ]: dependency_graph = graphlib.TopologicalSorter() - resolved_param_map: Dict[str, ResolvedParam] = {} + resolved_param_map: Dict[str, Optional[List[ResolvedParam]]] = {} endpoint_resource_map = expand_and_index_resources(resource_list, resource_defaults) # create dependency graph @@ -288,20 +287,24 @@ def build_resource_dependency_graph( assert isinstance(endpoint_resource["endpoint"], dict) # connect transformers to resources via resolved params resolved_params = _find_resolved_params(endpoint_resource["endpoint"]) - if len(resolved_params) > 1: - raise ValueError( - f"Multiple resolved params for resource {resource_name}: {resolved_params}" - ) - elif len(resolved_params) == 1: - resolved_param = resolved_params[0] - predecessor = resolved_param.resolve_config["resource"] + + # set of resources in resolved params + named_resources = {rp.resolve_config["resource"] for rp in resolved_params} + + if len(named_resources) > 1: + raise ValueError(f"Multiple parent resources for {resource_name}: {resolved_params}") + elif len(named_resources) == 1: + # validate the first parameter (note the resource is the same for all params) + first_param = resolved_params[0] + predecessor = first_param.resolve_config["resource"] if predecessor not in endpoint_resource_map: raise ValueError( f"A transformer resource {resource_name} refers to non existing parent resource" - f" {predecessor} on {resolved_param}" + f" {predecessor} on {first_param}" ) + dependency_graph.add(resource_name, predecessor) - resolved_param_map[resource_name] = resolved_param + resolved_param_map[resource_name] = resolved_params else: dependency_graph.add(resource_name) resolved_param_map[resource_name] = None @@ -574,21 +577,28 @@ def remove_field(response: Response, *args, **kwargs) -> Response: def process_parent_data_item( path: str, item: Dict[str, Any], - resolved_param: ResolvedParam, + resolved_params: List[ResolvedParam], include_from_parent: List[str], ) -> Tuple[str, Dict[str, Any]]: - parent_resource_name = resolved_param.resolve_config["resource"] + parent_resource_name = resolved_params[0].resolve_config["resource"] - field_values = jsonpath.find_values(resolved_param.field_path, item) + param_values = {} - if not field_values: - field_path = resolved_param.resolve_config["field"] - raise ValueError( - f"Transformer expects a field '{field_path}' to be present in the incoming data from" - f" resource {parent_resource_name} in order to bind it to path param" - f" {resolved_param.param_name}. 
Available parent fields are {', '.join(item.keys())}" - ) - bound_path = path.format(**{resolved_param.param_name: field_values[0]}) + for resolved_param in resolved_params: + field_values = jsonpath.find_values(resolved_param.field_path, item) + + if not field_values: + field_path = resolved_param.resolve_config["field"] + raise ValueError( + f"Transformer expects a field '{field_path}' to be present in the incoming data" + f" from resource {parent_resource_name} in order to bind it to path param" + f" {resolved_param.param_name}. Available parent fields are" + f" {', '.join(item.keys())}" + ) + + param_values[resolved_param.param_name] = field_values[0] + + bound_path = path.format(**param_values) parent_record: Dict[str, Any] = {} if include_from_parent: diff --git a/dlt/sources/sql_database/__init__.py b/dlt/sources/sql_database/__init__.py index f7c83b4b80..1574c4aa20 100644 --- a/dlt/sources/sql_database/__init__.py +++ b/dlt/sources/sql_database/__init__.py @@ -2,20 +2,16 @@ from typing import Callable, Dict, List, Optional, Union, Iterable, Any -from dlt.common.libs.sql_alchemy import MetaData, Table, Engine - import dlt -from dlt.sources import DltResource - +from dlt.common.configuration.specs import ConnectionStringCredentials +from dlt.common.libs.sql_alchemy import MetaData, Table, Engine -from dlt.sources.credentials import ConnectionStringCredentials -from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.extract import DltResource, Incremental, decorators from .helpers import ( table_rows, engine_from_credentials, TableBackend, - SqlDatabaseTableConfiguration, SqlTableResourceConfiguration, _detect_precision_hints_deprecated, TQueryAdapter, @@ -29,7 +25,7 @@ ) -@dlt.source +@decorators.source def sql_database( credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, schema: Optional[str] = dlt.config.value, @@ -121,13 +117,15 @@ def sql_database( ) -@dlt.resource(name=lambda args: args["table"], standalone=True, spec=SqlTableResourceConfiguration) +@decorators.resource( + name=lambda args: args["table"], standalone=True, spec=SqlTableResourceConfiguration +) def sql_table( credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, table: str = dlt.config.value, schema: Optional[str] = dlt.config.value, metadata: Optional[MetaData] = None, - incremental: Optional[dlt.sources.incremental[Any]] = None, + incremental: Optional[Incremental[Any]] = None, chunk_size: int = 50000, backend: TableBackend = "sqlalchemy", detect_precision_hints: Optional[bool] = None, @@ -193,7 +191,7 @@ def sql_table( table_adapter_callback(table_obj) skip_nested_on_minimal = backend == "sqlalchemy" - return dlt.resource( + return decorators.resource( table_rows, name=table_obj.name, primary_key=get_primary_key(table_obj), diff --git a/dlt/sources/sql_database/arrow_helpers.py b/dlt/sources/sql_database/arrow_helpers.py index 898d8c3280..1f72205a2a 100644 --- a/dlt/sources/sql_database/arrow_helpers.py +++ b/dlt/sources/sql_database/arrow_helpers.py @@ -1,150 +1,25 @@ -from typing import Any, Sequence, Optional +from typing import Any, Sequence from dlt.common.schema.typing import TTableSchemaColumns -from dlt.common import logger, json + from dlt.common.configuration import with_config from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.json import custom_encode, map_nested_in_place - -from .schema_types import RowAny +from dlt.common.libs.pyarrow import ( + row_tuples_to_arrow as 
_row_tuples_to_arrow, +) @with_config -def columns_to_arrow( - columns_schema: TTableSchemaColumns, +def row_tuples_to_arrow( + rows: Sequence[Any], caps: DestinationCapabilitiesContext = None, - tz: str = "UTC", + columns: TTableSchemaColumns = None, + tz: str = None, ) -> Any: """Converts `column_schema` to arrow schema using `caps` and `tz`. `caps` are injected from the container - which is always the case if run within the pipeline. This will generate arrow schema compatible with the destination. Otherwise generic capabilities are used """ - from dlt.common.libs.pyarrow import pyarrow as pa, get_py_arrow_datatype - from dlt.common.destination.capabilities import DestinationCapabilitiesContext - - return pa.schema( - [ - pa.field( - name, - get_py_arrow_datatype( - schema_item, - caps or DestinationCapabilitiesContext.generic_capabilities(), - tz, - ), - nullable=schema_item.get("nullable", True), - ) - for name, schema_item in columns_schema.items() - if schema_item.get("data_type") is not None - ] + return _row_tuples_to_arrow( + rows, caps or DestinationCapabilitiesContext.generic_capabilities(), columns, tz ) - - -def row_tuples_to_arrow(rows: Sequence[RowAny], columns: TTableSchemaColumns, tz: str) -> Any: - """Converts the rows to an arrow table using the columns schema. - Columns missing `data_type` will be inferred from the row data. - Columns with object types not supported by arrow are excluded from the resulting table. - """ - from dlt.common.libs.pyarrow import pyarrow as pa - import numpy as np - - try: - from pandas._libs import lib - - pivoted_rows = lib.to_object_array_tuples(rows).T - except ImportError: - logger.info( - "Pandas not installed, reverting to numpy.asarray to create a table which is slower" - ) - pivoted_rows = np.asarray(rows, dtype="object", order="k").T # type: ignore[call-overload] - - columnar = { - col: dat.ravel() for col, dat in zip(columns, np.vsplit(pivoted_rows, len(columns))) - } - columnar_known_types = { - col["name"]: columnar[col["name"]] - for col in columns.values() - if col.get("data_type") is not None - } - columnar_unknown_types = { - col["name"]: columnar[col["name"]] - for col in columns.values() - if col.get("data_type") is None - } - - arrow_schema = columns_to_arrow(columns, tz=tz) - - for idx in range(0, len(arrow_schema.names)): - field = arrow_schema.field(idx) - py_type = type(rows[0][idx]) - # cast double / float ndarrays to decimals if type mismatch, looks like decimals and floats are often mixed up in dialects - if pa.types.is_decimal(field.type) and issubclass(py_type, (str, float)): - logger.warning( - f"Field {field.name} was reflected as decimal type, but rows contains" - f" {py_type.__name__}. Additional cast is required which may slow down arrow table" - " generation." - ) - float_array = pa.array(columnar_known_types[field.name], type=pa.float64()) - columnar_known_types[field.name] = float_array.cast(field.type, safe=False) - if issubclass(py_type, (dict, list)): - logger.warning( - f"Field {field.name} was reflected as JSON type and needs to be serialized back to" - " string to be placed in arrow table. This will slow data extraction down. You" - " should cast JSON field to STRING in your database system ie. by creating and" - " extracting an SQL VIEW that selects with cast." 
- ) - json_str_array = pa.array( - [None if s is None else json.dumps(s) for s in columnar_known_types[field.name]] - ) - columnar_known_types[field.name] = json_str_array - - # If there are unknown type columns, first create a table to infer their types - if columnar_unknown_types: - new_schema_fields = [] - for key in list(columnar_unknown_types): - arrow_col: Optional[pa.Array] = None - try: - arrow_col = pa.array(columnar_unknown_types[key]) - if pa.types.is_null(arrow_col.type): - logger.warning( - f"Column {key} contains only NULL values and data type could not be" - " inferred. This column is removed from a arrow table" - ) - continue - - except pa.ArrowInvalid as e: - # Try coercing types not supported by arrow to a json friendly format - # E.g. dataclasses -> dict, UUID -> str - try: - arrow_col = pa.array( - map_nested_in_place(custom_encode, list(columnar_unknown_types[key])) - ) - logger.warning( - f"Column {key} contains a data type which is not supported by pyarrow and" - f" got converted into {arrow_col.type}. This slows down arrow table" - " generation." - ) - except (pa.ArrowInvalid, TypeError): - logger.warning( - f"Column {key} contains a data type which is not supported by pyarrow. This" - f" column will be ignored. Error: {e}" - ) - if arrow_col is not None: - columnar_known_types[key] = arrow_col - new_schema_fields.append( - pa.field( - key, - arrow_col.type, - nullable=columns[key]["nullable"], - ) - ) - - # New schema - column_order = {name: idx for idx, name in enumerate(columns)} - arrow_schema = pa.schema( - sorted( - list(arrow_schema) + new_schema_fields, - key=lambda x: column_order[x.name], - ) - ) - - return pa.Table.from_pydict(columnar_known_types, schema=arrow_schema) diff --git a/dlt/sources/sql_database/helpers.py b/dlt/sources/sql_database/helpers.py index 1d758fe882..24b31c3802 100644 --- a/dlt/sources/sql_database/helpers.py +++ b/dlt/sources/sql_database/helpers.py @@ -14,12 +14,16 @@ import operator import dlt -from dlt.common.configuration.specs import BaseConfiguration, configspec +from dlt.common.configuration.specs import ( + BaseConfiguration, + ConnectionStringCredentials, + configspec, +) from dlt.common.exceptions import MissingDependencyException from dlt.common.schema import TTableSchemaColumns from dlt.common.typing import TDataItem, TSortOrder -from dlt.sources.credentials import ConnectionStringCredentials +from dlt.extract import Incremental from .arrow_helpers import row_tuples_to_arrow from .schema_types import ( @@ -47,7 +51,7 @@ def __init__( table: Table, columns: TTableSchemaColumns, chunk_size: int = 1000, - incremental: Optional[dlt.sources.incremental[Any]] = None, + incremental: Optional[Incremental[Any]] = None, query_adapter_callback: Optional[TQueryAdapter] = None, ) -> None: self.engine = engine @@ -146,7 +150,7 @@ def _load_rows(self, query: SelectAny, backend_kwargs: Dict[str, Any]) -> TDataI yield df elif self.backend == "pyarrow": yield row_tuples_to_arrow( - partition, self.columns, tz=backend_kwargs.get("tz", "UTC") + partition, columns=self.columns, tz=backend_kwargs.get("tz", "UTC") ) def _load_rows_connectorx( @@ -186,7 +190,7 @@ def table_rows( table: Table, chunk_size: int, backend: TableBackend, - incremental: Optional[dlt.sources.incremental[Any]] = None, + incremental: Optional[Incremental[Any]] = None, defer_table_reflect: bool = False, table_adapter_callback: Callable[[Table], None] = None, reflection_level: ReflectionLevel = "minimal", @@ -291,18 +295,12 @@ def _detect_precision_hints_deprecated(value: 
Optional[bool]) -> None: ) -@configspec -class SqlDatabaseTableConfiguration(BaseConfiguration): - incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg] - included_columns: Optional[List[str]] = None - - @configspec class SqlTableResourceConfiguration(BaseConfiguration): credentials: Union[ConnectionStringCredentials, Engine, str] = None table: str = None schema: Optional[str] = None - incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg] + incremental: Optional[Incremental] = None # type: ignore[type-arg] chunk_size: int = 50000 backend: TableBackend = "sqlalchemy" detect_precision_hints: Optional[bool] = None diff --git a/docs/examples/conftest.py b/docs/examples/conftest.py index be1a03990b..b00436fc10 100644 --- a/docs/examples/conftest.py +++ b/docs/examples/conftest.py @@ -35,8 +35,8 @@ def setup_secret_providers(request): def _initial_providers(): return [ EnvironProvider(), - SecretsTomlProvider(project_dir=secret_dir, add_global_config=False), - ConfigTomlProvider(project_dir=config_dir, add_global_config=False), + SecretsTomlProvider(settings_dir=secret_dir, add_global_config=False), + ConfigTomlProvider(settings_dir=config_dir, add_global_config=False), ] glob_ctx = ConfigProvidersContext() diff --git a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py index 305c7d1f1a..aa2f284f5b 100644 --- a/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py +++ b/docs/examples/custom_destination_lancedb/custom_destination_lancedb.py @@ -92,7 +92,7 @@ def spotify_shows( spotify_base_api_url = "https://api.spotify.com/v1" client = RESTClient( base_url=spotify_base_api_url, - auth=SpotifyAuth(client_id=client_id, client_secret=client_secret), # type: ignore[arg-type] + auth=SpotifyAuth(client_id=client_id, client_secret=client_secret), ) for show in fields(Shows): diff --git a/docs/website/docs/conftest.py b/docs/website/docs/conftest.py index 87ccffe53b..a4b82c46bc 100644 --- a/docs/website/docs/conftest.py +++ b/docs/website/docs/conftest.py @@ -34,8 +34,8 @@ def setup_secret_providers(request): def _initial_providers(): return [ EnvironProvider(), - SecretsTomlProvider(project_dir=secret_dir, add_global_config=False), - ConfigTomlProvider(project_dir=config_dir, add_global_config=False), + SecretsTomlProvider(settings_dir=secret_dir, add_global_config=False), + ConfigTomlProvider(settings_dir=config_dir, add_global_config=False), ] glob_ctx = ConfigProvidersContext() diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index ddbf930306..08d2f0751c 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -141,7 +141,7 @@ The `jsonl` format has some limitations when used with Databricks: ## Staging support -Databricks supports both Amazon S3 and Azure Blob Storage as staging locations. `dlt` will upload files in `parquet` format to the staging location and will instruct Databricks to load data from there. +Databricks supports both Amazon S3, Azure Blob Storage and Google Cloud Storage as staging locations. `dlt` will upload files in `parquet` format to the staging location and will instruct Databricks to load data from there. 
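As a quick sketch of the staged loading flow described above (names are illustrative; the staging bucket URL and credentials are assumed to be configured under `[destination.filesystem]` in `config.toml`/`secrets.toml`), a Databricks pipeline using a bucket-backed staging destination could look like this:

```py
import dlt

# minimal sketch: load through a bucket-backed staging location;
# dlt uploads parquet files there and instructs Databricks to COPY them in
pipeline = dlt.pipeline(
    pipeline_name="databricks_staged_demo",  # illustrative name
    destination="databricks",
    staging="filesystem",  # staging bucket is taken from configuration
    dataset_name="staging_demo",
)

load_info = pipeline.run([{"id": 1}, {"id": 2}], table_name="items")
print(load_info)
```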
### Databricks and Amazon S3 @@ -187,6 +187,11 @@ pipeline = dlt.pipeline( ``` +### Databricks and Google Cloud Storage + +In order to load from Google Cloud Storage stage you must set-up the credentials via **named credential**. See below. Databricks does not allow to pass Google Credentials +explicitly in SQL Statements. + ### Use external locations and stored credentials `dlt` forwards bucket credentials to the `COPY INTO` SQL command by default. You may prefer to use [external locations or stored credentials instead](https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location) that are stored on the Databricks side. diff --git a/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md b/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md index b9014e0564..9f33c02337 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md +++ b/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md @@ -135,8 +135,7 @@ The following write dispositions are supported: - `append` - `replace` with `truncate-and-insert` and `insert-from-staging` replace strategies. `staging-optimized` falls back to `insert-from-staging`. - -The `merge` disposition is not supported and falls back to `append`. +- `merge` with `delete-insert` and `scd2` merge strategies. ## Data loading diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md index 6b9ff68269..687ae3085c 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/csv.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md @@ -1,13 +1,13 @@ --- -title: csv -description: The csv file format +title: CSV +description: The CSV file format keywords: [csv, file formats] --- import SetTheFormat from './_set_the_format.mdx'; # CSV file format -**csv** is the most basic file format for storing tabular data, where all values are strings and are separated by a delimiter (typically a comma). +**CSV** is the most basic file format for storing tabular data, where all values are strings and are separated by a delimiter (typically a comma). `dlt` uses it for specific use cases - mostly for performance and compatibility reasons. Internally, we use two implementations: @@ -16,7 +16,7 @@ Internally, we use two implementations: ## Supported destinations -The `csv` format is supported by the following destinations: **Postgres**, **Filesystem**, **Snowflake** +The CSV format is supported by the following destinations: **Postgres**, **Filesystem**, **Snowflake** ## How to configure diff --git a/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md b/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md index 54e5b1cbd2..f1783aa29e 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/jsonl.md @@ -1,11 +1,11 @@ --- -title: jsonl -description: The jsonl file format -keywords: [jsonl, file formats] +title: JSONL +description: The JSONL file format or JSON Delimited stores several JSON documents in one file. The JSON documents are separated by a new line. +keywords: [jsonl, file formats, json delimited, jsonl file format] --- import SetTheFormat from './_set_the_format.mdx'; -# jsonl - JSON delimited +# JSONL - JSON Lines - JSON Delimited JSON delimited is a file format that stores several JSON documents in one file. The JSON documents are separated by a new line. 
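To make the JSONL layout concrete, here is a minimal, standard-library-only sketch (the file name is illustrative) that writes and reads a few JSON documents separated by newlines:

```py
import json

rows = [{"id": 1, "name": "anna"}, {"id": 2, "name": "bob"}]

# write: one JSON document per line
with open("items.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# read: parse the file line by line
with open("items.jsonl", encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f if line.strip()]

assert loaded == rows
```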
diff --git a/docs/website/docs/dlt-ecosystem/transformations/pandas.md b/docs/website/docs/dlt-ecosystem/transformations/pandas.md index 4125e4e114..cda4855268 100644 --- a/docs/website/docs/dlt-ecosystem/transformations/pandas.md +++ b/docs/website/docs/dlt-ecosystem/transformations/pandas.md @@ -22,7 +22,7 @@ with pipeline.sql_client() as client: with client.execute_query( 'SELECT "reactions__+1", "reactions__-1", reactions__laugh, reactions__hooray, reactions__rocket FROM issues' ) as table: - # calling `df` on a cursor returns the data as a data frame + # calling `df` on a cursor, returns the data as a pandas data frame reactions = table.df() counts = reactions.sum(0).sort_values(0, ascending=False) ``` diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/advanced.md b/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/advanced.md index e1eeca0ee9..a66a7b1d7f 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/advanced.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/advanced.md @@ -1,5 +1,5 @@ --- -title: Advanced Filesystem Usage +title: Advanced filesystem usage description: Use filesystem source as a building block keywords: [readers source and filesystem, files, filesystem, readers source, cloud storage] --- @@ -54,7 +54,7 @@ When using a nested or recursive glob pattern, `relative_path` will include the ## Create your own transformer -Although the `filesystem` resource yields the files from cloud storage or a local filesystem, you need to apply a transformer resource to retrieve the records from files. `dlt` natively supports three file types: `csv`, `parquet`, and `jsonl` (more details in [filesystem transformer resource](../filesystem/basic#2-choose-the-right-transformer-resource)). +Although the `filesystem` resource yields the files from cloud storage or a local filesystem, you need to apply a transformer resource to retrieve the records from files. dlt natively supports three file types: [CSV](../../file-formats/csv.md), [Parquet](../../file-formats/parquet.md), and [JSONL](../../file-formats/jsonl.md) (more details in [filesystem transformer resource](../filesystem/basic#2-choose-the-right-transformer-resource)). But you can easily create your own. In order to do this, you just need a function that takes as input a `FileItemDict` iterator and yields a list of records (recommended for performance) or individual records. diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/basic.md b/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/basic.md index ac4af7862f..6df10323dd 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/basic.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/basic.md @@ -1,14 +1,14 @@ --- title: Filesystem source description: Learn how to set up and configure -keywords: [readers source and filesystem, files, filesystem, readers source, cloud storage] +keywords: [readers source and filesystem, files, filesystem, readers source, cloud storage, object storage, local file system] --- import Header from '../_source-info-header.md';
-Filesystem source allows loading files from remote locations (AWS S3, Google Cloud Storage, Google Drive, Azure Blob Storage, SFTP server) or the local filesystem seamlessly. Filesystem source natively supports `csv`, `parquet`, and `jsonl` files and allows customization for loading any type of structured files. +Filesystem source allows loading files from remote locations (AWS S3, Google Cloud Storage, Google Drive, Azure Blob Storage, SFTP server) or the local filesystem seamlessly. Filesystem source natively supports [CSV](../../file-formats/csv.md), [Parquet](../../file-formats/parquet.md), and [JSONL](../../file-formats/jsonl.md) files and allows customization for loading any type of structured files. -To load unstructured data (`.pdf`, `.txt`, e-mail), please refer to the [unstructured data source](https://github.com/dlt-hub/verified-sources/tree/master/sources/unstructured_data). +To load unstructured data (PDF, plain text, e-mail), please refer to the [unstructured data source](https://github.com/dlt-hub/verified-sources/tree/master/sources/unstructured_data). ## How filesystem source works @@ -145,11 +145,8 @@ You don't need any credentials for the local filesystem. ### Add credentials to dlt pipeline -To provide credentials to the filesystem source, you can use [any method available](../../../general-usage/credentials/setup#available-config-providers) in `dlt`. -One of the easiest ways is to use configuration files. The `.dlt` folder in your working directory -contains two files: `config.toml` and `secrets.toml`. Sensitive information, like passwords and -access tokens, should only be put into `secrets.toml`, while any other configuration, like the path to -a bucket, can be specified in `config.toml`. +To provide credentials to the filesystem source, you can use [any method available](../../../general-usage/credentials/setup#available-config-providers) in dlt. +One of the easiest ways is to use configuration files. The `.dlt` folder in your working directory contains two files: `config.toml` and `secrets.toml`. Sensitive information, like passwords and access tokens, should only be put into `secrets.toml`, while any other configuration, like the path to a bucket, can be specified in `config.toml`. -You can also specify the credentials using environment variables. The name of the corresponding environment -variable should be slightly different from the corresponding name in the TOML file. Simply replace dots `.` with double -underscores `__`: +You can also specify the credentials using environment variables. The name of the corresponding environment variable should be slightly different from the corresponding name in the TOML file. Simply replace dots `.` with double underscores `__`: ```sh export SOURCES__FILESYSTEM__AWS_ACCESS_KEY_ID = "Please set me up!" @@ -262,16 +257,12 @@ export SOURCES__FILESYSTEM__AWS_SECRET_ACCESS_KEY = "Please set me up!" ``` :::tip -`dlt` supports more ways of authorizing with cloud storage, including identity-based -and default credentials. To learn more about adding credentials to your pipeline, please refer to the -[Configuration and secrets section](../../../general-usage/credentials/complex_types#gcp-credentials). +dlt supports more ways of authorizing with cloud storage, including identity-based and default credentials. To learn more about adding credentials to your pipeline, please refer to the [Configuration and secrets section](../../../general-usage/credentials/complex_types#gcp-credentials). 
::: ## Usage -The filesystem source is quite unique since it provides you with building blocks for loading data from files. -First, it iterates over files in the storage and then processes each file to yield the records. -Usually, you need two resources: +The filesystem source is quite unique since it provides you with building blocks for loading data from files. First, it iterates over files in the storage and then processes each file to yield the records. Usually, you need two resources: 1. The `filesystem` resource enumerates files in a selected bucket using a glob pattern, returning details as `FileItem` in customizable page sizes. 2. One of the available transformer resources to process each file in a specific transforming function and yield the records. @@ -279,8 +270,7 @@ Usually, you need two resources: ### 1. Initialize a `filesystem` resource :::note -If you use just the `filesystem` resource, it will only list files in the storage based on glob parameters and yield the -files [metadata](advanced#fileitem-fields). The `filesystem` resource itself does not read or copy files. +If you use just the `filesystem` resource, it will only list files in the storage based on glob parameters and yield the files [metadata](advanced#fileitem-fields). The `filesystem` resource itself does not read or copy files. ::: All parameters of the resource can be specified directly in code: @@ -319,9 +309,8 @@ Full list of `filesystem` resource parameters: ### 2. Choose the right transformer resource -The current implementation of the filesystem source natively supports three file types: `csv`, `parquet`, and `jsonl`. -You can apply any of the above or [create your own transformer](advanced#create-your-own-transformer). To apply the selected transformer -resource, use pipe notation `|`: +The current implementation of the filesystem source natively supports three file types: CSV, Parquet, and JSONL. +You can apply any of the above or [create your own transformer](advanced#create-your-own-transformer). To apply the selected transformer resource, use pipe notation `|`: ```py from dlt.sources.filesystem import filesystem, read_csv @@ -334,17 +323,13 @@ filesystem_pipe = filesystem( #### Available transformers -- `read_csv()` - processes `csv` files using `pandas` -- `read_jsonl()` - processes `jsonl` files chunk by chunk -- `read_parquet()` - processes `parquet` files using `pyarrow` -- `read_csv_duckdb()` - this transformer processes `csv` files using DuckDB, which usually shows better performance than `pandas`. +- `read_csv()` - processes CSV files using [Pandas](https://pandas.pydata.org/) +- `read_jsonl()` - processes JSONL files chunk by chunk +- `read_parquet()` - processes Parquet files using [PyArrow](https://arrow.apache.org/docs/python/) +- `read_csv_duckdb()` - this transformer processes CSV files using DuckDB, which usually shows better performance than pandas. :::tip -We advise that you give each resource a -[specific name](../../../general-usage/resource#duplicate-and-rename-resources) -before loading with `pipeline.run`. This will ensure that data goes to a table with the name you -want and that each pipeline uses a -[separate state for incremental loading.](../../../general-usage/state#read-and-write-pipeline-state-in-a-resource) +We advise that you give each resource a [specific name](../../../general-usage/resource#duplicate-and-rename-resources) before loading with `pipeline.run`. 
This will ensure that data goes to a table with the name you want and that each pipeline uses a [separate state for incremental loading.](../../../general-usage/state#read-and-write-pipeline-state-in-a-resource) ::: ### 3. Create and run a pipeline @@ -406,6 +391,7 @@ print(load_info) In this example, we load only new records based on the field called `updated_at`. This method may be useful if you are not able to filter files by modification date because, for example, all files are modified each time a new record appears. + ```py import dlt from dlt.sources.filesystem import filesystem, read_csv @@ -462,6 +448,7 @@ print(load_info) :::tip You could also use `file_glob` to filter files by names. It works very well in simple cases, for example, filtering by extension: + ```py from dlt.sources.filesystem import filesystem @@ -505,16 +492,12 @@ bucket_url = '\\?\C:\a\b\c' ### If you get an empty list of files -If you are running a `dlt` pipeline with the filesystem source and get zero records, we recommend you check +If you are running a dlt pipeline with the filesystem source and get zero records, we recommend you check the configuration of `bucket_url` and `file_glob` parameters. -For example, with Azure Blob storage, people sometimes mistake the account name for the container name. Make sure -you've set up a URL as `"az:///"`. +For example, with Azure Blob Storage, people sometimes mistake the account name for the container name. Make sure you've set up a URL as `"az:///"`. -Also, please reference the [glob](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob) -function to configure the resource correctly. Use `**` to include recursive files. Note that the local -filesystem supports full Python [glob](https://docs.python.org/3/library/glob.html#glob.glob) functionality, -while cloud storage supports a restricted `fsspec` [version](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob). +Also, please reference the [glob](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob) function to configure the resource correctly. Use `**` to include recursive files. Note that the local filesystem supports full Python [glob](https://docs.python.org/3/library/glob.html#glob.glob) functionality, while cloud storage supports a restricted `fsspec` [version](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.glob). 
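When the resource yields nothing, it can also help to run just the `filesystem` resource with a recursive glob and print what it enumerates before piping it into any transformer. A minimal sketch (the local path is illustrative):

```py
from dlt.sources.filesystem import filesystem

# enumerate files only -- nothing is read or copied at this stage
files = filesystem(
    bucket_url="file://Users/admin/Documents/csv_files",  # illustrative path
    file_glob="**/*.csv",  # `**` makes the glob recursive
)

for file_item in files:
    # each item is file metadata; `relative_path` is one of the FileItem fields
    print(file_item["relative_path"])
```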
diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/index.md b/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/index.md index 0aaa07b0c3..5aa930c1ae 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/index.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/filesystem/index.md @@ -1,10 +1,10 @@ --- -title: Filesystem & cloud storage -description: dlt-verified source for Filesystem & cloud storage -keywords: [readers source and filesystem, files, filesystem, readers source, cloud storage] +title: Cloud storage and filesystem +description: dlt-verified source for reading files from cloud storage and local file system +keywords: [file system, files, filesystem, readers source, cloud storage, object storage, local file system] --- -The Filesystem source allows seamless loading of files from the following locations: +The filesystem source allows seamless loading of files from the following locations: * AWS S3 * Google Cloud Storage * Google Drive @@ -12,7 +12,7 @@ The Filesystem source allows seamless loading of files from the following locati * remote filesystem (via SFTP) * local filesystem -The Filesystem source natively supports `csv`, `parquet`, and `jsonl` files and allows customization for loading any type of structured file. +The filesystem source natively supports [CSV](../../file-formats/csv.md), [Parquet](../../file-formats/parquet.md), and [JSONL](../../file-formats/jsonl.md) files and allows customization for loading any type of structured file. import DocCardList from '@theme/DocCardList'; diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md index 03214950f4..b7ce29b391 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md @@ -574,7 +574,7 @@ rest_api.config_setup.register_auth("custom_auth", CustomAuth) ### Define resource relationships -When you have a resource that depends on another resource, you can define the relationship using the `resolve` configuration. With it, you link a path parameter in the child resource to a field in the parent resource's data. +When you have a resource that depends on another resource, you can define the relationship using the `resolve` configuration. This allows you to link one or more path parameters in the child resource to fields in the parent resource's data. In the GitHub example, the `issue_comments` resource depends on the `issues` resource. The `issue_number` parameter in the `issue_comments` endpoint configuration is resolved from the `number` field of the `issues` resource: @@ -638,6 +638,54 @@ The `field` value can be specified as a [JSONPath](https://github.com/h2non/json Under the hood, dlt handles this by using a [transformer resource](../../../general-usage/resource.md#process-resources-with-dlttransformer). +#### Resolving multiple path parameters from a parent resource + +When a child resource depends on multiple fields from a single parent resource, you can define multiple `resolve` parameters in the endpoint configuration. 
For example: + +```py +{ + "resources": [ + "groups", + { + "name": "users", + "endpoint": { + "path": "groups/{group_id}/users", + "params": { + "group_id": { + "type": "resolve", + "resource": "groups", + "field": "id", + }, + }, + }, + }, + { + "name": "user_details", + "endpoint": { + "path": "groups/{group_id}/users/{user_id}/details", + "params": { + "group_id": { + "type": "resolve", + "resource": "users", + "field": "group_id", + }, + "user_id": { + "type": "resolve", + "resource": "users", + "field": "id", + }, + }, + }, + }, + ], +} +``` + +In the configuration above: + +- The `users` resource depends on the `groups` resource, resolving the `group_id` parameter from the `id` field in `groups`. +- The `user_details` resource depends on the `users` resource, resolving both `group_id` and `user_id` parameters from fields in `users`. + #### Include fields from the parent resource You can include data from the parent resource in the child resource by using the `include_from_parent` field in the resource configuration. For example: diff --git a/docs/website/docs/dlt-ecosystem/visualizations/exploring-the-data.md b/docs/website/docs/dlt-ecosystem/visualizations/exploring-the-data.md index 65c937ef77..2d7a7642c2 100644 --- a/docs/website/docs/dlt-ecosystem/visualizations/exploring-the-data.md +++ b/docs/website/docs/dlt-ecosystem/visualizations/exploring-the-data.md @@ -65,7 +65,7 @@ with pipeline.sql_client() as client: with client.execute_query( 'SELECT "reactions__+1", "reactions__-1", reactions__laugh, reactions__hooray, reactions__rocket FROM issues' ) as table: - # calling `df` on a cursor returns the data as a DataFrame + # calling `df` on a cursor, returns the data as a pandas DataFrame reactions = table.df() counts = reactions.sum(0).sort_values(0, ascending=False) ``` diff --git a/docs/website/docs/general-usage/credentials/setup.md b/docs/website/docs/general-usage/credentials/setup.md index 9d459cc298..709cf09fe8 100644 --- a/docs/website/docs/general-usage/credentials/setup.md +++ b/docs/website/docs/general-usage/credentials/setup.md @@ -180,6 +180,14 @@ Check out the [example](#examples) of setting up credentials through environment To organize development and securely manage environment variables for credentials storage, you can use [python-dotenv](https://pypi.org/project/python-dotenv/) to automatically load variables from an `.env` file. ::: +:::tip +Environment Variables additionally looks for secret values in `/run/secrets/` to seamlessly resolve values defined as **Kubernetes/Docker secrets**. +For that purpose it uses alternative name format with lowercase, `-` (dash) as a separator and "_" converted into `-`: +In the example above: `sources--facebook-ads--access-token` will be used to search for the secrets (and other forms up until `access-token`). +Mind that only values marked as secret (with `dlt.secrets.value` or using ie. `TSecretStrValue` explicitly) are checked. Remember to name your secrets +in Kube resources/compose file properly. +::: + ## Vaults Vault integration methods vary based on the vault type. Check out our example involving [Google Cloud Secrets Manager](../../walkthroughs/add_credentials.md#retrieving-credentials-from-google-cloud-secret-manager). 
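To illustrate the Kubernetes/Docker secrets naming convention described in the tip above, the sketch below (a hypothetical helper, not the dlt implementation) shows the described transformation from an environment-variable style key to the lowercase, dash-separated name that would be looked up under `/run/secrets/`:

```py
def to_docker_secret_name(env_key: str) -> str:
    # hypothetical helper: lowercase, "__" becomes "--", "_" becomes "-"
    return env_key.lower().replace("__", "--").replace("_", "-")

# SOURCES__FACEBOOK_ADS__ACCESS_TOKEN -> sources--facebook-ads--access-token
assert (
    to_docker_secret_name("SOURCES__FACEBOOK_ADS__ACCESS_TOKEN")
    == "sources--facebook-ads--access-token"
)
```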
diff --git a/docs/website/docs/general-usage/http/rest-client.md b/docs/website/docs/general-usage/http/rest-client.md index 125604ab94..c1606b99bb 100644 --- a/docs/website/docs/general-usage/http/rest-client.md +++ b/docs/website/docs/general-usage/http/rest-client.md @@ -438,11 +438,10 @@ The available authentication methods are defined in the `dlt.sources.helpers.res - [BearerTokenAuth](#bearer-token-authentication) - [APIKeyAuth](#api-key-authentication) - [HttpBasicAuth](#http-basic-authentication) -- [OAuth2ClientCredentials](#oauth20-authorization) +- [OAuth2ClientCredentials](#oauth-20-authorization) -For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library. -For specific flavors of OAuth 2.0, you can [implement custom OAuth 2.0](#oauth2-authorization) -by subclassing `OAuth2ClientCredentials`. +For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class from the `dlt.sources.helpers.rest_client.auth` module. +For specific flavors of OAuth 2.0, you can [implement custom OAuth 2.0](#oauth-20-authorization) by subclassing `OAuth2ClientCredentials`. ### Bearer token authentication @@ -565,12 +564,12 @@ response = client.get("/users") ### Implementing custom authentication -You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method: +You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method: ```py -from requests.auth import AuthBase +from dlt.sources.helpers.rest_client.auth import AuthConfigBase -class CustomAuth(AuthBase): +class CustomAuth(AuthConfigBase): def __init__(self, token): self.token = token diff --git a/docs/website/netlify.toml b/docs/website/netlify.toml index e5d3550f5c..51cc4ee21f 100644 --- a/docs/website/netlify.toml +++ b/docs/website/netlify.toml @@ -37,4 +37,8 @@ to = "/docs/tutorial/load-data-from-an-api" [[redirects]] from = "/docs/telemetry" -to = "/docs/reference/telemetry" \ No newline at end of file +to = "/docs/reference/telemetry" + +[[redirects]] +from = "/docs/walkthroughs" +to = "/docs/intro" diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 223d469a47..c5bbbd5f7b 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -107,7 +107,7 @@ const sidebars = { }, { type: 'category', - label: 'Filesystem & cloud storage', + label: 'Cloud storage and filesystem', description: 'AWS S3, Google Cloud Storage, Azure, SFTP, local file system', link: { type: 'doc', diff --git a/poetry.lock b/poetry.lock index 13a8bcf5db..a9ceb1a8f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2208,6 +2208,23 @@ urllib3 = ">=1.26" alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] sqlalchemy = ["sqlalchemy (>=2.0.21)"] +[[package]] +name = "db-dtypes" +version = "1.3.0" +description = "Pandas Data Types for SQL systems (BigQuery, Spanner)" +optional = true +python-versions = ">=3.7" +files = [ + {file = "db_dtypes-1.3.0-py2.py3-none-any.whl", hash = "sha256:7e65c59f849ccbe6f7bc4d0253edcc212a7907662906921caba3e4aadd0bc277"}, + {file = "db_dtypes-1.3.0.tar.gz", hash = "sha256:7bcbc8858b07474dc85b77bb2f3ae488978d1336f5ea73b58c39d9118bc3e91b"}, +] + +[package.dependencies] +numpy = ">=1.16.6" +packaging = ">=17.0" +pandas = ">=0.24.2" +pyarrow = ">=3.0.0" + [[package]] name = "dbt-athena-community" version = "1.7.1" @@ -8900,6 
+8917,21 @@ toml = {version = "*", markers = "python_version < \"3.11\""} tqdm = "*" typing-extensions = "*" +[[package]] +name = "sqlglot" +version = "25.24.5" +description = "An easily customizable SQL parser and transpiler" +optional = true +python-versions = ">=3.7" +files = [ + {file = "sqlglot-25.24.5-py3-none-any.whl", hash = "sha256:f8a8870d1f5cdd2e2dc5c39a5030a0c7b0a91264fb8972caead3dac8e8438873"}, + {file = "sqlglot-25.24.5.tar.gz", hash = "sha256:6d3d604034301ca3b614d6b4148646b4033317b7a93d1801e9661495eb4b4fcf"}, +] + +[package.extras] +dev = ["duckdb (>=0.6)", "maturin (>=1.4,<2.0)", "mypy", "pandas", "pandas-stubs", "pdoc", "pre-commit", "python-dateutil", "pytz", "ruff (==0.4.3)", "types-python-dateutil", "types-pytz", "typing-extensions"] +rs = ["sqlglotrs (==0.2.12)"] + [[package]] name = "sqlparse" version = "0.4.4" @@ -10174,15 +10206,15 @@ cffi = ["cffi (>=1.11)"] [extras] athena = ["botocore", "pyarrow", "pyathena", "s3fs"] az = ["adlfs"] -bigquery = ["gcsfs", "google-cloud-bigquery", "grpcio", "pyarrow"] +bigquery = ["db-dtypes", "gcsfs", "google-cloud-bigquery", "grpcio", "pyarrow"] cli = ["cron-descriptor", "pipdeptree"] clickhouse = ["adlfs", "clickhouse-connect", "clickhouse-driver", "gcsfs", "pyarrow", "s3fs"] databricks = ["databricks-sql-connector"] deltalake = ["deltalake", "pyarrow"] dremio = ["pyarrow"] duckdb = ["duckdb"] -filesystem = ["botocore", "s3fs"] -gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] +filesystem = ["botocore", "s3fs", "sqlglot"] +gcp = ["db-dtypes", "gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] lancedb = ["lancedb", "pyarrow", "tantivy"] motherduck = ["duckdb", "pyarrow"] diff --git a/pyproject.toml b/pyproject.toml index 46f6af4086..db9bdc7446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "1.1.0" +version = "1.2.0" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. 
"] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] @@ -50,6 +50,7 @@ tenacity = ">=8.0.2" jsonpath-ng = ">=1.5.3" fsspec = ">=2022.4.0" packaging = ">=21.1" +pluggy = ">=1.3.0" win-precise-time = {version = ">=1.4.2", markers="os_name == 'nt'"} graphlib-backport = {version = "*", python = "<3.9"} @@ -84,18 +85,20 @@ lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= ' tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } sqlalchemy = { version = ">=1.4", optional = true } -alembic = {version = "^1.13.2", optional = true} +alembic = {version = ">1.10.0", optional = true} paramiko = {version = ">=3.3.0", optional = true} +sqlglot = {version = ">=20.0.0", optional = true} +db-dtypes = { version = ">=1.2.0", optional = true } [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] # bigquery is alias on gcp extras -bigquery = ["grpcio", "google-cloud-bigquery", "pyarrow", "db-dtypes", "gcsfs"] +bigquery = ["grpcio", "google-cloud-bigquery", "pyarrow", "gcsfs", "db-dtypes"] postgres = ["psycopg2-binary", "psycopg2cffi"] redshift = ["psycopg2-binary", "psycopg2cffi"] parquet = ["pyarrow"] duckdb = ["duckdb"] -filesystem = ["s3fs", "botocore"] +filesystem = ["s3fs", "botocore", "sqlglot"] s3 = ["s3fs", "botocore"] gs = ["gcsfs"] az = ["adlfs"] diff --git a/tests/cli/common/test_cli_invoke.py b/tests/cli/common/test_cli_invoke.py index 77c003a5c9..97db8ab86b 100644 --- a/tests/cli/common/test_cli_invoke.py +++ b/tests/cli/common/test_cli_invoke.py @@ -1,13 +1,10 @@ import os import shutil -from subprocess import CalledProcessError -import pytest from pytest_console_scripts import ScriptRunner from unittest.mock import patch import dlt from dlt.common.known_env import DLT_DATA_DIR -from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.runners.venv import Venv from dlt.common.utils import custom_environ, set_working_dir from dlt.common.pipeline import get_dlt_pipelines_dir @@ -63,7 +60,7 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: shutil.copytree("tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) with set_working_dir(TEST_STORAGE_ROOT): - with custom_environ({"COMPLETED_PROB": "1.0", DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({"COMPLETED_PROB": "1.0", DLT_DATA_DIR: dlt.current.run().data_dir}): venv = Venv.restore_current() venv.run_script("dummy_pipeline.py") # we check output test_pipeline_command else @@ -97,7 +94,7 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: def test_invoke_init_chess_and_template(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): result = script_runner.run(["dlt", "init", "chess", "dummy"]) assert "Verified source chess was added to your project!" 
in result.stdout assert result.returncode == 0 @@ -117,7 +114,7 @@ def test_invoke_list_sources(script_runner: ScriptRunner) -> None: def test_invoke_deploy_project(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): result = script_runner.run( ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] ) diff --git a/tests/cli/common/test_telemetry_command.py b/tests/cli/common/test_telemetry_command.py index 21f44b3e88..fc67dde5fa 100644 --- a/tests/cli/common/test_telemetry_command.py +++ b/tests/cli/common/test_telemetry_command.py @@ -6,7 +6,7 @@ from unittest.mock import patch from dlt.common.configuration.container import Container -from dlt.common.configuration.paths import DOT_DLT +from dlt.common.runtime.run_context import DOT_DLT from dlt.common.configuration.providers import ConfigTomlProvider, CONFIG_TOML from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.storages import FileStorage @@ -27,56 +27,66 @@ def test_main_telemetry_command(test_storage: FileStorage) -> None: def _initial_providers(): return [ConfigTomlProvider(add_global_config=True)] + container = Container() glob_ctx = ConfigProvidersContext() glob_ctx.providers = _initial_providers() - with set_working_dir(test_storage.make_full_path("project")), Container().injectable_context( - glob_ctx - ), patch( - "dlt.common.configuration.specs.config_providers_context.ConfigProvidersContext.initial_providers", - _initial_providers, - ): - # no config files: status is ON - with io.StringIO() as buf, contextlib.redirect_stdout(buf): - telemetry_status_command() - assert "ENABLED" in buf.getvalue() - # disable telemetry - with io.StringIO() as buf, contextlib.redirect_stdout(buf): - change_telemetry_status_command(False) - # enable global flag in providers (tests have global flag disabled) - glob_ctx = ConfigProvidersContext() - glob_ctx.providers = [ConfigTomlProvider(add_global_config=True)] - with Container().injectable_context(glob_ctx): + try: + with set_working_dir(test_storage.make_full_path("project")), patch( + "dlt.common.configuration.specs.config_providers_context.ConfigProvidersContext.initial_providers", + _initial_providers, + ): + # no config files: status is ON + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + telemetry_status_command() + assert "ENABLED" in buf.getvalue() + # disable telemetry + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + # force the mock config.toml provider + container[ConfigProvidersContext] = glob_ctx + change_telemetry_status_command(False) + # enable global flag in providers (tests have global flag disabled) + glob_ctx = ConfigProvidersContext() + glob_ctx.providers = [ConfigTomlProvider(add_global_config=True)] + with Container().injectable_context(glob_ctx): + telemetry_status_command() + output = buf.getvalue() + assert "OFF" in output + assert "DISABLED" in output + # make sure no config.toml exists in project (it is not created if it was not already there) + project_dot = os.path.join("project", DOT_DLT) + assert not test_storage.has_folder(project_dot) + # enable telemetry + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + # force the mock config.toml provider + container[ConfigProvidersContext] = glob_ctx + change_telemetry_status_command(True) + # enable 
global flag in providers (tests have global flag disabled) + glob_ctx = ConfigProvidersContext() + glob_ctx.providers = [ConfigTomlProvider(add_global_config=True)] + with Container().injectable_context(glob_ctx): + telemetry_status_command() + output = buf.getvalue() + assert "ON" in output + assert "ENABLED" in output + # create config toml in project dir + test_storage.create_folder(project_dot) + test_storage.save(os.path.join("project", DOT_DLT, CONFIG_TOML), "# empty") + # disable telemetry + with io.StringIO() as buf, contextlib.redirect_stdout(buf): + # force the mock config.toml provider + container[ConfigProvidersContext] = glob_ctx + # this command reload providers + change_telemetry_status_command(False) + # so the change is visible (because it is written to project config so we do not need to look into global like before) telemetry_status_command() output = buf.getvalue() assert "OFF" in output assert "DISABLED" in output - # make sure no config.toml exists in project (it is not created if it was not already there) - project_dot = os.path.join("project", DOT_DLT) - assert not test_storage.has_folder(project_dot) - # enable telemetry - with io.StringIO() as buf, contextlib.redirect_stdout(buf): - change_telemetry_status_command(True) - # enable global flag in providers (tests have global flag disabled) - glob_ctx = ConfigProvidersContext() - glob_ctx.providers = [ConfigTomlProvider(add_global_config=True)] - with Container().injectable_context(glob_ctx): - telemetry_status_command() - output = buf.getvalue() - assert "ON" in output - assert "ENABLED" in output - # create config toml in project dir - test_storage.create_folder(project_dot) - test_storage.save(os.path.join("project", DOT_DLT, CONFIG_TOML), "# empty") - # disable telemetry - with io.StringIO() as buf, contextlib.redirect_stdout(buf): - # this command reload providers - change_telemetry_status_command(False) - # so the change is visible (because it is written to project config so we do not need to look into global like before) - telemetry_status_command() - output = buf.getvalue() - assert "OFF" in output - assert "DISABLED" in output + finally: + # delete current config provider after the patched init ctx is out of scope + if ConfigProvidersContext in container: + del container[ConfigProvidersContext] def test_command_instrumentation() -> None: diff --git a/tests/cli/test_deploy_command.py b/tests/cli/test_deploy_command.py index 78a14ee914..5d9163679a 100644 --- a/tests/cli/test_deploy_command.py +++ b/tests/cli/test_deploy_command.py @@ -135,11 +135,11 @@ def test_deploy_command( test_storage.atomic_rename(".dlt/secrets.toml.ci", ".dlt/secrets.toml") # reset toml providers to (1) CWD (2) non existing dir so API_KEY is not found - for project_dir, api_key in [ + for settings_dir, api_key in [ (None, "api_key_9x3ehash"), (".", "please set me up!"), ]: - with reset_providers(project_dir=project_dir): + with reset_providers(settings_dir=settings_dir): # this time script will run venv.run_script("debug_pipeline.py") with echo.always_choose(False, always_choose_value=True): diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index f76dc2f053..35c68ecfb4 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -19,15 +19,13 @@ import dlt from dlt.common import git -from dlt.common.configuration.paths import make_dlt_settings_path from dlt.common.configuration.providers import CONFIG_TOML, SECRETS_TOML, SecretsTomlProvider from dlt.common.runners import Venv from 
dlt.common.storages.file_storage import FileStorage -from dlt.common.source import _SOURCES from dlt.common.utils import set_working_dir -from dlt.cli import init_command, echo +from dlt.cli import init_command, echo, utils from dlt.cli.init_command import ( SOURCES_MODULE_NAME, DEFAULT_VERIFIED_SOURCES_REPO, @@ -60,7 +58,7 @@ CORE_SOURCES = ["filesystem", "rest_api", "sql_database"] # we also hardcode all the templates here for testing -TEMPLATES = ["debug", "default", "arrow", "requests", "dataframe", "intro"] +TEMPLATES = ["debug", "default", "arrow", "requests", "dataframe", "fruitshop", "github_api"] # a few verified sources we know to exist SOME_KNOWN_VERIFIED_SOURCES = ["chess", "google_sheets", "pipedrive"] @@ -83,7 +81,7 @@ def test_init_command_pipeline_default_template(repo_dir: str, project_files: Fi init_command.init_command("some_random_name", "redshift", repo_dir) visitor = assert_init_files(project_files, "some_random_name_pipeline", "redshift") # multiple resources - assert len(visitor.known_resource_calls) > 1 + assert len(visitor.known_resource_calls) == 1 def test_default_source_file_selection() -> None: @@ -247,7 +245,7 @@ def test_custom_destination_note(repo_dir: str, project_files: FileStorage): @pytest.mark.parametrize("omit", [True, False]) # this will break if we have new core sources that are not in verified sources anymore -@pytest.mark.parametrize("source", CORE_SOURCES) +@pytest.mark.parametrize("source", set(CORE_SOURCES) - {"rest_api"}) def test_omit_core_sources( source: str, omit: bool, project_files: FileStorage, repo_dir: str ) -> None: @@ -527,17 +525,16 @@ def test_init_requirements_text(repo_dir: str, project_files: FileStorage) -> No assert "pip3 install" in _out -@pytest.mark.skip("Why is this not working??") -def test_pipeline_template_sources_in_single_file( - repo_dir: str, project_files: FileStorage -) -> None: - init_command.init_command("debug", "bigquery", repo_dir) - # _SOURCES now contains the sources from pipeline.py which simulates loading from two places - with pytest.raises(CliCommandException) as cli_ex: - init_command.init_command("arrow", "redshift", repo_dir) - assert "In init scripts you must declare all sources and resources in single file." in str( - cli_ex.value - ) +# def test_pipeline_template_sources_in_single_file( +# repo_dir: str, project_files: FileStorage +# ) -> None: +# init_command.init_command("debug", "bigquery", repo_dir) +# # SourceReference.SOURCES now contains the sources from pipeline.py which simulates loading from two places +# with pytest.raises(CliCommandException) as cli_ex: +# init_command.init_command("arrow", "redshift", repo_dir) +# assert "In init scripts you must declare all sources and resources in single file." 
in str( +# cli_ex.value +# ) def test_incompatible_dlt_version_warning(repo_dir: str, project_files: FileStorage) -> None: @@ -624,8 +621,8 @@ def assert_common_files( ) -> Tuple[PipelineScriptVisitor, SecretsTomlProvider]: # cwd must be project files - otherwise assert won't work assert os.getcwd() == project_files.storage_path - assert project_files.has_file(make_dlt_settings_path(SECRETS_TOML)) - assert project_files.has_file(make_dlt_settings_path(CONFIG_TOML)) + assert project_files.has_file(utils.make_dlt_settings_path(SECRETS_TOML)) + assert project_files.has_file(utils.make_dlt_settings_path(CONFIG_TOML)) assert project_files.has_file(".gitignore") assert project_files.has_file(pipeline_script) # inspect script diff --git a/tests/cli/utils.py b/tests/cli/utils.py index 998885375f..d1ac762b69 100644 --- a/tests/cli/utils.py +++ b/tests/cli/utils.py @@ -6,9 +6,10 @@ from dlt.common import git from dlt.common.pipeline import get_dlt_repos_dir from dlt.common.storages.file_storage import FileStorage -from dlt.common.source import _SOURCES from dlt.common.utils import set_working_dir, uniq_id +from dlt.sources import SourceReference + from dlt.cli import echo from dlt.cli.init_command import DEFAULT_VERIFIED_SOURCES_REPO @@ -58,14 +59,14 @@ def get_repo_dir(cloned_init_repo: FileStorage) -> str: def get_project_files(clear_all_sources: bool = True) -> FileStorage: # we only remove sources registered outside of dlt core - for name, source in _SOURCES.copy().items(): + for name, source in SourceReference.SOURCES.copy().items(): if not source.module.__name__.startswith( "dlt.sources" ) and not source.module.__name__.startswith("default_pipeline"): - _SOURCES.pop(name) + SourceReference.SOURCES.pop(name) if clear_all_sources: - _SOURCES.clear() + SourceReference.SOURCES.clear() # project dir return FileStorage(PROJECT_DIR, makedirs=True) diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 4665386af4..a8049cd49f 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -7,6 +7,7 @@ Final, Generic, List, + Literal, Mapping, MutableMapping, NewType, @@ -25,9 +26,12 @@ from dlt.common.utils import custom_environ, get_exception_trace, get_exception_trace_chain from dlt.common.typing import ( AnyType, + CallableAny, ConfigValue, DictStrAny, + SecretSentinel, StrAny, + TSecretStrValue, TSecretValue, extract_inner_type, ) @@ -1090,7 +1094,7 @@ def test_do_not_resolve_twice(environment: Any) -> None: c = resolve.resolve_configuration(SecretConfiguration()) assert c.secret_value == "password" c2 = SecretConfiguration() - c2.secret_value = "other" # type: ignore[assignment] + c2.secret_value = "other" c2.__is_resolved__ = True assert c2.is_resolved() # will not overwrite with env @@ -1103,7 +1107,7 @@ def test_do_not_resolve_twice(environment: Any) -> None: assert c4.secret_value == "password" assert c2 is c3 is c4 # also c is resolved so - c.secret_value = "else" # type: ignore[assignment] + c.secret_value = "else" assert resolve.resolve_configuration(c).secret_value == "else" @@ -1112,7 +1116,7 @@ def test_do_not_resolve_embedded(environment: Any) -> None: c = resolve.resolve_configuration(EmbeddedSecretConfiguration()) assert c.secret.secret_value == "password" c2 = SecretConfiguration() - c2.secret_value = "other" # type: ignore[assignment] + c2.secret_value = "other" c2.__is_resolved__ = True embed_c = EmbeddedSecretConfiguration() embed_c.secret = c2 @@ -1210,13 
+1214,23 @@ def test_extract_inner_hint() -> None: # extracts new types assert resolve.extract_inner_hint(TSecretValue) is AnyType # preserves new types on extract - assert resolve.extract_inner_hint(TSecretValue, preserve_new_types=True) is TSecretValue + assert resolve.extract_inner_hint(CallableAny, preserve_new_types=True) is CallableAny + # extracts and preserves annotated + assert resolve.extract_inner_hint(Optional[Annotated[int, "X"]]) is int # type: ignore[arg-type] + TAnnoInt = Annotated[int, "X"] + assert resolve.extract_inner_hint(Optional[TAnnoInt], preserve_annotated=True) is TAnnoInt # type: ignore[arg-type] + # extracts and preserves literals + TLit = Literal["a", "b"] + TAnnoLit = Annotated[TLit, "X"] + assert resolve.extract_inner_hint(TAnnoLit, preserve_literal=True) is TLit # type: ignore[arg-type] + assert resolve.extract_inner_hint(TAnnoLit, preserve_literal=False) is str # type: ignore[arg-type] def test_is_secret_hint() -> None: assert resolve.is_secret_hint(GcpServiceAccountCredentialsWithoutDefaults) is True assert resolve.is_secret_hint(Optional[GcpServiceAccountCredentialsWithoutDefaults]) is True # type: ignore[arg-type] assert resolve.is_secret_hint(TSecretValue) is True + assert resolve.is_secret_hint(TSecretStrValue) is True assert resolve.is_secret_hint(Optional[TSecretValue]) is True # type: ignore[arg-type] assert resolve.is_secret_hint(InstrumentedConfiguration) is False # do not recognize new types @@ -1232,9 +1246,8 @@ def test_is_secret_hint() -> None: def test_is_secret_hint_custom_type() -> None: - # any new type named TSecretValue is a secret - assert resolve.is_secret_hint(NewType("TSecretValue", int)) is True - assert resolve.is_secret_hint(NewType("TSecretValueX", int)) is False + # any type annotated with SecretSentinel is secret + assert resolve.is_secret_hint(Annotated[int, SecretSentinel]) is True # type: ignore[arg-type] def coerce_single_value(key: str, value: str, hint: Type[Any]) -> Any: diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index 0dc7e53357..5908c1ef4a 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -32,7 +32,14 @@ from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.reflection.spec import _get_spec_name_from_f -from dlt.common.typing import StrAny, TSecretStrValue, TSecretValue, is_newtype_type +from dlt.common.typing import ( + StrAny, + TSecretStrValue, + TSecretValue, + is_annotated, + is_newtype_type, + is_subclass, +) from tests.utils import preserve_environ from tests.common.configuration.utils import environment, toml_providers @@ -199,7 +206,7 @@ def f_custom_secret_type( f_type = spec.__dataclass_fields__[f].type assert is_secret_hint(f_type) assert cfg.get_resolvable_fields()[f] is f_type - assert is_newtype_type(f_type) + assert is_annotated(f_type) environment["_DICT"] = '{"a":1}' environment["_INT"] = "1234" diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index b1e316734d..de670b7bf5 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -9,7 +9,7 @@ from dlt.common.configuration.specs import CredentialsConfiguration, BaseConfiguration from dlt.common.configuration import configspec, resolve_configuration from 
dlt.common.configuration.specs.gcp_credentials import GcpServiceAccountCredentials -from dlt.common.typing import TSecretValue +from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.connection_string_credentials import ConnectionStringCredentials from dlt.common.configuration.resolve import initialize_credentials from dlt.common.configuration.specs.exceptions import NativeValueError @@ -27,14 +27,14 @@ def auth(self): @configspec class ZenEmailCredentials(ZenCredentials): email: str = None - password: TSecretValue = None + password: TSecretStrValue = None def parse_native_representation(self, native_value: Any) -> None: assert isinstance(native_value, str) if native_value.startswith("email:"): parts = native_value.split(":") self.email = parts[-2] - self.password = parts[-1] # type: ignore[assignment] + self.password = parts[-1] else: raise NativeValueError(self.__class__, native_value, "invalid email NV") @@ -45,14 +45,14 @@ def auth(self): @configspec class ZenApiKeyCredentials(ZenCredentials): api_key: str = None - api_secret: TSecretValue = None + api_secret: TSecretStrValue = None def parse_native_representation(self, native_value: Any) -> None: assert isinstance(native_value, str) if native_value.startswith("secret:"): parts = native_value.split(":") self.api_key = parts[-2] - self.api_secret = parts[-1] # type: ignore[assignment] + self.api_secret = parts[-1] else: raise NativeValueError(self.__class__, native_value, "invalid secret NV") @@ -201,10 +201,10 @@ class GoogleAnalyticsCredentialsOAuth(GoogleAnalyticsCredentialsBase): """ client_id: str = None - client_secret: TSecretValue = None - project_id: TSecretValue = None - refresh_token: TSecretValue = None - access_token: Optional[TSecretValue] = None + client_secret: TSecretStrValue = None + project_id: TSecretStrValue = None + refresh_token: TSecretStrValue = None + access_token: Optional[TSecretStrValue] = None @dlt.source(max_table_nesting=2) diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index a19aea8796..ca95e46810 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -16,6 +16,7 @@ CONFIG_TOML, BaseDocProvider, CustomLoaderDocProvider, + SettingsTomlProvider, SecretsTomlProvider, ConfigTomlProvider, StringTomlProvider, @@ -246,7 +247,7 @@ def test_toml_get_key_as_section(toml_providers: ConfigProvidersContext) -> None def test_toml_read_exception() -> None: pipeline_root = "./tests/common/cases/configuration/.wrong.dlt" with pytest.raises(TomlProviderReadException) as py_ex: - ConfigTomlProvider(project_dir=pipeline_root) + ConfigTomlProvider(settings_dir=pipeline_root) assert py_ex.value.file_name == "config.toml" @@ -288,7 +289,7 @@ def test_toml_global_config() -> None: def test_write_value(toml_providers: ConfigProvidersContext) -> None: - provider: BaseDocProvider + provider: SettingsTomlProvider for provider in toml_providers.providers: # type: ignore[assignment] if not provider.is_writable: continue @@ -351,9 +352,10 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: "dict_test.deep_dict.embed.inner_2", ) # write a dict over non dict - provider.set_value("deep_list", test_d1, None, "deep", "deep", "deep") + ovr_dict = {"ovr": 1, "ocr": {"ovr": 2}} + provider.set_value("deep_list", ovr_dict, None, "deep", "deep", "deep") assert provider.get_value("deep_list", TAny, None, "deep", "deep", "deep") == ( - test_d1, + ovr_dict, 
"deep.deep.deep.deep_list", ) # merge dicts @@ -368,7 +370,8 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: test_m_d1_d2, "dict_test.deep_dict", ) - # print(provider.get_value("deep_dict", Any, None, "dict_test")) + # compare toml and doc repr + assert provider._config_doc == provider._config_toml.unwrap() # write configuration pool = PoolRunnerConfiguration(pool_type="none", workers=10) @@ -403,7 +406,7 @@ def test_set_spec_value(toml_providers: ConfigProvidersContext) -> None: def test_set_fragment(toml_providers: ConfigProvidersContext) -> None: - provider: BaseDocProvider + provider: SettingsTomlProvider for provider in toml_providers.providers: # type: ignore[assignment] if not isinstance(provider, BaseDocProvider): continue diff --git a/tests/common/runtime/test_run_context_data_dir.py b/tests/common/runtime/test_run_context_data_dir.py new file mode 100644 index 0000000000..f8759a2809 --- /dev/null +++ b/tests/common/runtime/test_run_context_data_dir.py @@ -0,0 +1,13 @@ +import os + +import dlt + +# import auto fixture that sets global and data dir to TEST_STORAGE +from dlt.common.runtime.run_context import DOT_DLT +from tests.utils import TEST_STORAGE_ROOT, patch_home_dir + + +def test_data_dir_test_storage() -> None: + run_context = dlt.current.run() + assert run_context.global_dir.endswith(os.path.join(TEST_STORAGE_ROOT, DOT_DLT)) + assert run_context.global_dir == run_context.data_dir diff --git a/tests/common/runtime/test_run_context_random_data_dir.py b/tests/common/runtime/test_run_context_random_data_dir.py new file mode 100644 index 0000000000..fb13f16e6f --- /dev/null +++ b/tests/common/runtime/test_run_context_random_data_dir.py @@ -0,0 +1,11 @@ +import dlt + +# import auto fixture that sets global and data dir to TEST_STORAGE + random folder +from tests.utils import TEST_STORAGE_ROOT, patch_random_home_dir + + +def test_data_dir_test_storage() -> None: + run_context = dlt.current.run() + assert TEST_STORAGE_ROOT in run_context.global_dir + assert "global_" in run_context.global_dir + assert run_context.global_dir == run_context.data_dir diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index 3a9e320040..2749e3ebb1 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -1,3 +1,4 @@ +import pytest from dataclasses import dataclass from typing import ( Any, @@ -20,6 +21,7 @@ from uuid import UUID +from dlt import TSecretValue from dlt.common.configuration.specs.base_configuration import ( BaseConfiguration, get_config_if_union_hint, @@ -27,6 +29,7 @@ from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults from dlt.common.typing import ( StrAny, + TSecretStrValue, extract_inner_type, extract_union_types, get_all_types_of_class_in_union, @@ -270,3 +273,23 @@ def test_get_all_types_of_class_in_union() -> None: assert get_all_types_of_class_in_union( Union[BaseConfiguration, str], Incremental[float], with_superclass=True ) == [BaseConfiguration] + + +def test_secret_type() -> None: + # typing must be ok + val: TSecretValue = 1 # noqa + val_2: TSecretValue = b"ABC" # noqa + + # must evaluate to self at runtime + assert TSecretValue("a") == "a" + assert TSecretValue(b"a") == b"a" + assert TSecretValue(7) == 7 + assert isinstance(TSecretValue(7), int) + + # secret str evaluates to str + val_str: TSecretStrValue = "x" # noqa + # here we expect ignore! 
+ val_str_err: TSecretStrValue = 1 # type: ignore[assignment] # noqa + + assert TSecretStrValue("x_str") == "x_str" + assert TSecretStrValue({}) == "{}" diff --git a/tests/common/utils.py b/tests/common/utils.py index 553f67995e..9b5e6bccce 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -8,7 +8,7 @@ import datetime # noqa: 251 from dlt.common import json -from dlt.common.typing import StrAny +from dlt.common.typing import StrAny, TSecretStrValue from dlt.common.schema import utils, Schema from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration.providers import environ as environ_provider @@ -64,9 +64,7 @@ def restore_secret_storage_path() -> None: def load_secret(name: str) -> str: environ_provider.SECRET_STORAGE_PATH = "./tests/common/cases/secrets/%s" - secret, _ = environ_provider.EnvironProvider().get_value( - name, environ_provider.TSecretValue, None - ) + secret, _ = environ_provider.EnvironProvider().get_value(name, TSecretStrValue, None) if not secret: raise FileNotFoundError(environ_provider.SECRET_STORAGE_PATH % name) return secret diff --git a/tests/conftest.py b/tests/conftest.py index 6c0384ea8a..74e6388eca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,8 +24,8 @@ def initial_providers() -> List[ConfigProvider]: # do not read the global config return [ EnvironProvider(), - SecretsTomlProvider(project_dir="tests/.dlt", add_global_config=False), - ConfigTomlProvider(project_dir="tests/.dlt", add_global_config=False), + SecretsTomlProvider(settings_dir="tests/.dlt", add_global_config=False), + ConfigTomlProvider(settings_dir="tests/.dlt", add_global_config=False), ] diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index 73286678b5..92900a0329 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -12,9 +12,9 @@ from dlt.common.configuration.inject import get_fun_spec from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs.config_section_context import ConfigSectionContext +from dlt.common.configuration.specs.pluggable_run_context import PluggableRunContext from dlt.common.exceptions import ArgumentsOverloadException, DictValidationException from dlt.common.pipeline import StateInjectableContext, TPipelineState -from dlt.common.source import _SOURCES from dlt.common.schema import Schema from dlt.common.schema.utils import new_table, new_column from dlt.common.schema.typing import TTableSchemaColumns @@ -23,9 +23,12 @@ from dlt.cli.source_detection import detect_source_configs from dlt.common.utils import custom_environ +from dlt.extract.decorators import DltSourceFactoryWrapper +from dlt.extract.source import SourceReference from dlt.extract import DltResource, DltSource from dlt.extract.exceptions import ( DynamicNameNotStandaloneResource, + ExplicitSourceNameInvalid, InvalidResourceDataTypeFunctionNotAGenerator, InvalidResourceDataTypeIsNone, InvalidResourceDataTypeMultiplePipes, @@ -39,10 +42,12 @@ SourceNotAFunction, CurrentSourceSchemaNotAvailable, InvalidParallelResourceDataType, + UnknownSourceReference, ) from dlt.extract.items import TableNameMeta from tests.common.utils import load_yml_case +from tests.utils import MockableRunContext def test_default_resource() -> None: @@ -660,6 +665,201 @@ def schema_test(): assert "table" not in s.discover_schema().tables +@dlt.source(name="shorthand", section="shorthand") +def with_shorthand_registry(data): + return dlt.resource(data, name="alpha") + + +@dlt.source 
+def test_decorators(): + return dlt.resource(["A", "B"], name="alpha") + + +@dlt.resource +def res_reg_with_secret(secretz: str = dlt.secrets.value): + yield [secretz] * 3 + + +def test_source_reference() -> None: + # shorthand works when name == section + ref = SourceReference.from_reference("shorthand") + assert list(ref(["A", "B"])) == ["A", "B"] + ref = SourceReference.from_reference("shorthand.shorthand") + assert list(ref(["A", "B"])) == ["A", "B"] + # same for test_decorators + ref = SourceReference.from_reference("test_decorators") + assert list(ref()) == ["A", "B"] + ref = SourceReference.from_reference("test_decorators.test_decorators") + assert list(ref()) == ["A", "B"] + + # inner sources are registered + @dlt.source + def _inner_source(): + return dlt.resource(["C", "D"], name="beta") + + ref = SourceReference.from_reference("test_decorators._inner_source") + assert list(ref()) == ["C", "D"] + + # duplicate section / name will replace registration + @dlt.source(name="_inner_source") + def _inner_source_2(): + return dlt.resource(["E", "F"], name="beta") + + ref = SourceReference.from_reference("test_decorators._inner_source") + assert list(ref()) == ["E", "F"] + + # unknown reference + with pytest.raises(UnknownSourceReference) as ref_ex: + SourceReference.from_reference("$ref") + assert ref_ex.value.ref == ["dlt.$ref.$ref"] + + @dlt.source(section="special") + def absolute_config(init: int, mark: str = dlt.config.value, secret: str = dlt.secrets.value): + # will need to bind secret + return (res_reg_with_secret, dlt.resource([init, mark, secret], name="dump")) + + ref = SourceReference.from_reference("special.absolute_config") + os.environ["SOURCES__SPECIAL__MARK"] = "ma" + os.environ["SOURCES__SPECIAL__SECRET"] = "sourse" + # resource when in source adopts source section + os.environ["SOURCES__SPECIAL__RES_REG_WITH_SECRET__SECRETZ"] = "resourse" + source = ref(init=100) + assert list(source) == ["resourse", "resourse", "resourse", 100, "ma", "sourse"] + + +def test_source_reference_with_context() -> None: + ctx = PluggableRunContext() + mock = MockableRunContext.from_context(ctx.context) + mock._name = "mock" + ctx.context = mock + + with Container().injectable_context(ctx): + # should be able to import things from dlt package + ref = SourceReference.from_reference("shorthand") + assert list(ref(["A", "B"])) == ["A", "B"] + ref = SourceReference.from_reference("shorthand.shorthand") + assert list(ref(["A", "B"])) == ["A", "B"] + # unknown reference + with pytest.raises(UnknownSourceReference) as ref_ex: + SourceReference.from_reference("$ref") + assert ref_ex.value.ref == ["mock.$ref.$ref", "dlt.$ref.$ref"] + with pytest.raises(UnknownSourceReference) as ref_ex: + SourceReference.from_reference("mock.$ref.$ref") + assert ref_ex.value.ref == ["mock.$ref.$ref"] + + # create a "shorthand" source in this context + @dlt.source(name="shorthand", section="shorthand") + def with_shorthand_registry(data): + return dlt.resource(list(reversed(data)), name="alpha") + + ref = SourceReference.from_reference("shorthand") + assert list(ref(["C", "x"])) == ["x", "C"] + ref = SourceReference.from_reference("mock.shorthand.shorthand") + assert list(ref(["C", "x"])) == ["x", "C"] + # from dlt package + ref = SourceReference.from_reference("dlt.shorthand.shorthand") + assert list(ref(["C", "x"])) == ["C", "x"] + + +def test_source_reference_from_module() -> None: + ref = SourceReference.from_reference("tests.extract.test_decorators.with_shorthand_registry") + assert list(ref(["C", "x"])) == 
["C", "x"] + + # module exists but attr is not a factory + with pytest.raises(UnknownSourceReference) as ref_ex: + SourceReference.from_reference( + "tests.extract.test_decorators.test_source_reference_from_module" + ) + assert ref_ex.value.ref == ["tests.extract.test_decorators.test_source_reference_from_module"] + + # wrong module + with pytest.raises(UnknownSourceReference) as ref_ex: + SourceReference.from_reference( + "test.extract.test_decorators.test_source_reference_from_module" + ) + assert ref_ex.value.ref == ["test.extract.test_decorators.test_source_reference_from_module"] + + +def test_source_factory_with_args() -> None: + # check typing - no type ignore below! + factory = with_shorthand_registry.with_args + # do not override anything + source = factory()(data=["AXA"]) + assert list(source) == ["AXA"] + + # there are some overrides from decorator + assert with_shorthand_registry.name == "shorthand" # type: ignore + assert with_shorthand_registry.section == "shorthand" # type: ignore + + # with_args creates clones + source_f_1: DltSourceFactoryWrapper[Any, DltSource] = factory( # type: ignore + max_table_nesting=1, root_key=True + ) + source_f_2: DltSourceFactoryWrapper[Any, DltSource] = factory( # type: ignore + max_table_nesting=1, root_key=False, schema_contract="discard_value" + ) + assert source_f_1 is not source_f_2 + + # check if props are set + assert source_f_1.name == source_f_2.name == "shorthand" + assert source_f_1.section == source_f_2.section == "shorthand" + assert source_f_1.max_table_nesting == source_f_2.max_table_nesting == 1 + assert source_f_1.root_key is True + assert source_f_2.root_key is False + assert source_f_2.schema_contract == "discard_value" + + # check if props are preserved when not set + incompat_schema = Schema("incompat") + with pytest.raises(ExplicitSourceNameInvalid): + source_f_1.with_args( + section="special", schema=incompat_schema, parallelized=True, schema_contract="evolve" + ) + + compat_schema = Schema("shorthand") + compat_schema.tables["alpha"] = new_table("alpha") + source_f_3 = source_f_1.with_args( + section="special", schema=compat_schema, parallelized=True, schema_contract="evolve" + ) + assert source_f_3.name == "shorthand" + assert source_f_3.section == "special" + assert source_f_3.max_table_nesting == 1 + assert source_f_3.root_key is True + assert source_f_3.schema is compat_schema + assert source_f_3.parallelized is True + assert source_f_3.schema_contract == "evolve" + source_f_3 = source_f_3.with_args() + assert source_f_3.name == "shorthand" + assert source_f_3.section == "special" + assert source_f_3.max_table_nesting == 1 + assert source_f_3.root_key is True + assert source_f_3.schema is compat_schema + assert source_f_3.parallelized is True + assert source_f_3.schema_contract == "evolve" + + # create source + source = source_f_3(["A", "X"]) + assert source.root_key is True + assert source.schema.tables["alpha"] == compat_schema.tables["alpha"] + assert source.name == "shorthand" + assert source.section == "special" + assert source.max_table_nesting == 1 + assert source.schema_contract == "evolve" + + # when section / name are changed, config location follows + @dlt.source + def absolute_config(init: int, mark: str = dlt.config.value, secret: str = dlt.secrets.value): + # will need to bind secret + return (res_reg_with_secret, dlt.resource([init, mark, secret], name="dump")) + + absolute_config = absolute_config.with_args(name="absolute", section="special") + os.environ["SOURCES__SPECIAL__ABSOLUTE__MARK"] = "ma" + 
os.environ["SOURCES__SPECIAL__ABSOLUTE__SECRET"] = "sourse" + # resource when in source adopts source section + os.environ["SOURCES__SPECIAL__RES_REG_WITH_SECRET__SECRETZ"] = "resourse" + source = absolute_config(init=100) + assert list(source) == ["resourse", "resourse", "resourse", 100, "ma", "sourse"] + + @dlt.resource def standalone_resource(secret=dlt.secrets.value, config=dlt.config.value, opt: str = "A"): yield 1 @@ -701,7 +901,7 @@ def inner_standalone_resource( def inner_source(secret=dlt.secrets.value, config=dlt.config.value, opt: str = "A"): return standalone_resource - SPEC = _SOURCES[inner_source.__qualname__].SPEC + SPEC = SourceReference.find("test_decorators.inner_source").SPEC fields = SPEC.get_resolvable_fields() assert {"secret", "config", "opt"} == set(fields.keys()) @@ -717,21 +917,23 @@ def no_args(): return dlt.resource([1, 2], name="data") # there is a spec even if no arguments - SPEC = _SOURCES[no_args.__qualname__].SPEC + SPEC = SourceReference.find("dlt.test_decorators.no_args").SPEC assert SPEC - _, _, checked = detect_source_configs(_SOURCES, "", ()) - assert no_args.__qualname__ in checked - SPEC = _SOURCES[no_args.__qualname__].SPEC + # source names are used to index detected sources + _, _, checked = detect_source_configs(SourceReference.SOURCES, "", ()) + assert "no_args" in checked + + SPEC = SourceReference.find("dlt.test_decorators.not_args_r").SPEC assert SPEC - _, _, checked = detect_source_configs(_SOURCES, "", ()) - assert not_args_r.__qualname__ in checked + _, _, checked = detect_source_configs(SourceReference.SOURCES, "", ()) + assert "not_args_r" in checked @dlt.resource def not_args_r_i(): yield from [1, 2, 3] - assert not_args_r_i.__qualname__ not in _SOURCES + assert "dlt.test_decorators.not_args_r_i" not in SourceReference.SOURCES # you can call those assert list(no_args()) == [1, 2] @@ -764,6 +966,7 @@ def users(mode: str): return users s = all_users() + assert isinstance(s, TypedSource) assert list(s.users("group")) == ["group"] @@ -853,18 +1056,109 @@ def many_instances(): assert list(standalone_signature(1)) == [1, 2, 3, 4] +@pytest.mark.parametrize("res", (standalone_signature, regular_signature)) +def test_reference_registered_resource(res: DltResource) -> None: + if isinstance(res, DltResource): + ref = res(1, 2).name + # find reference + res_ref = SourceReference.find(f"test_decorators.{ref}") + assert res_ref.SPEC is res.SPEC + else: + ref = res.__name__ + # create source with single res. 
+ factory = SourceReference.from_reference(f"test_decorators.{ref}") + # pass explicit config + source = factory(init=1, secret_end=3) + assert source.name == ref + assert source.section == "" + assert ref in source.resources + assert list(source) == [1, 2] + + # use regular config + os.environ[f"SOURCES__TEST_DECORATORS__{ref.upper()}__SECRET_END"] = "5" + source = factory(init=1) + assert list(source) == [1, 2, 3, 4] + + # use config with override + # os.environ["SOURCES__SECTION__SIGNATURE__INIT"] = "-1" + os.environ["SOURCES__SECTION__SIGNATURE__SECRET_END"] = "7" + source = factory.with_args( + name="signature", + section="section", + max_table_nesting=1, + root_key=True, + schema_contract="freeze", + parallelized=True, + )(-1) + assert list(source) == [-1, 0, 1, 2, 3, 4, 5, 6] + # use renamed name + resource = source.signature + assert resource.section == "section" + assert resource.name == "signature" + assert resource.max_table_nesting == 1 + assert resource.schema_contract == "freeze" + + +def test_inner_resource_not_registered() -> None: + # inner resources are not registered + @dlt.resource(standalone=True) + def inner_data_std(): + yield [1, 2, 3] + + with pytest.raises(UnknownSourceReference): + SourceReference.from_reference("test_decorators.inner_data_std") + + @dlt.resource() + def inner_data_reg(): + yield [1, 2, 3] + + with pytest.raises(UnknownSourceReference): + SourceReference.from_reference("test_decorators.inner_data_reg") + + @dlt.transformer(standalone=True) def standalone_transformer(item: TDataItem, init: int, secret_end: int = dlt.secrets.value): """Has fine transformer docstring""" yield from range(item + init, secret_end) +@dlt.transformer +def regular_transformer(item: TDataItem, init: int, secret_end: int = dlt.secrets.value): + yield from range(item + init, secret_end) + + @dlt.transformer(standalone=True) def standalone_transformer_returns(item: TDataItem, init: int = dlt.config.value): """Has fine transformer docstring""" return "A" * item * init +@pytest.mark.parametrize("ref", ("standalone_transformer", "regular_transformer")) +def test_reference_registered_transformer(ref: str) -> None: + factory = SourceReference.from_reference(f"test_decorators.{ref}") + bound_tx = standalone_signature(1, 3) | factory(5, 10).resources.detach() + print(bound_tx) + assert list(bound_tx) == [6, 7, 7, 8, 8, 9, 9] + + # use regular config + os.environ[f"SOURCES__TEST_DECORATORS__{ref.upper()}__SECRET_END"] = "7" + bound_tx = standalone_signature(1, 3) | factory(5).resources.detach() + assert list(bound_tx) == [6] + + # use config with override + os.environ["SOURCES__SECTION__SIGNATURE__SECRET_END"] = "8" + source = factory.with_args( + name="signature", + section="section", + max_table_nesting=1, + root_key=True, + schema_contract="freeze", + parallelized=True, + )(5) + bound_tx = standalone_signature(1, 3) | source.resources.detach() + assert list(bound_tx) == [6, 7, 7] + + @pytest.mark.parametrize("next_item_mode", ["fifo", "round_robin"]) def test_standalone_transformer(next_item_mode: str) -> None: os.environ["EXTRACT__NEXT_ITEM_MODE"] = next_item_mode @@ -1067,7 +1361,6 @@ async def source_yield_with_parens(reverse: bool = False): async def _assert_source(source_coro_f, expected_data) -> None: # test various forms of source decorator, parens, no parens, yield, return source_coro = source_coro_f() - assert inspect.iscoroutinefunction(source_coro_f) assert inspect.iscoroutine(source_coro) source = await source_coro assert "data" in source.resources diff --git 
a/tests/extract/test_sources.py b/tests/extract/test_sources.py index d111548db0..9bfeec1cb4 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -6,6 +6,7 @@ import dlt, os from dlt.common.configuration.container import Container +from dlt.common.configuration.specs import BaseConfiguration from dlt.common.exceptions import DictValidationException, PipelineStateNotAvailable from dlt.common.pipeline import StateInjectableContext, source_state from dlt.common.schema import Schema @@ -1104,6 +1105,21 @@ def multiplier(number, mul): assert bound_pipe._pipe.parent is pipe._pipe.parent +@dlt.resource(selected=False) +def number_gen_ext(max_r=3): + yield from range(1, max_r) + + +def test_clone_resource_with_rename(): + assert number_gen_ext.SPEC is not BaseConfiguration + gene_r = number_gen_ext.with_name("gene") + assert number_gen_ext.name == "number_gen_ext" + assert gene_r.name == "gene" + assert number_gen_ext.section == gene_r.section + assert gene_r.SPEC is number_gen_ext.SPEC + assert gene_r.selected == number_gen_ext.selected is False + + def test_source_multiple_iterations() -> None: def some_data(): yield [1, 2, 3] diff --git a/tests/libs/test_deltalake.py b/tests/libs/test_deltalake.py index e18fb1abd7..77bf80ea7e 100644 --- a/tests/libs/test_deltalake.py +++ b/tests/libs/test_deltalake.py @@ -51,8 +51,8 @@ def test_deltalake_storage_options() -> None: # yes credentials, yes deltalake_storage_options: no shared keys creds = AwsCredentials( aws_access_key_id="dummy_key_id", - aws_secret_access_key="dummy_acces_key", # type: ignore[arg-type] - aws_session_token="dummy_session_token", # type: ignore[arg-type] + aws_secret_access_key="dummy_acces_key", + aws_session_token="dummy_session_token", region_name="dummy_region_name", ) config.credentials = creds diff --git a/tests/load/clickhouse/test_clickhouse_configuration.py b/tests/load/clickhouse/test_clickhouse_configuration.py index 2b74922c34..ad33062f11 100644 --- a/tests/load/clickhouse/test_clickhouse_configuration.py +++ b/tests/load/clickhouse/test_clickhouse_configuration.py @@ -3,7 +3,7 @@ import pytest from dlt.common.configuration.resolve import resolve_configuration -from dlt.common.libs.sql_alchemy_shims import make_url +from dlt.common.libs.sql_alchemy_compat import make_url from dlt.common.utils import digest128 from dlt.destinations.impl.clickhouse.clickhouse import ClickHouseClient from dlt.destinations.impl.clickhouse.configuration import ( diff --git a/tests/load/filesystem/test_azure_credentials.py b/tests/load/filesystem/test_azure_credentials.py index 2353491737..64da35d9be 100644 --- a/tests/load/filesystem/test_azure_credentials.py +++ b/tests/load/filesystem/test_azure_credentials.py @@ -38,7 +38,7 @@ def az_service_principal_config() -> Optional[FilesystemConfiguration]: credentials = AzureServicePrincipalCredentialsWithoutDefaults( azure_tenant_id=dlt.config.get("tests.az_sp_tenant_id", str), azure_client_id=dlt.config.get("tests.az_sp_client_id", str), - azure_client_secret=dlt.config.get("tests.az_sp_client_secret", str), # type: ignore[arg-type] + azure_client_secret=dlt.config.get("tests.az_sp_client_secret", str), azure_storage_account_name=dlt.config.get("tests.az_sp_storage_account_name", str), ) # diff --git a/tests/load/filesystem/test_object_store_rs_credentials.py b/tests/load/filesystem/test_object_store_rs_credentials.py index 90530218d9..c69521f6ea 100644 --- a/tests/load/filesystem/test_object_store_rs_credentials.py +++ 
b/tests/load/filesystem/test_object_store_rs_credentials.py @@ -1,13 +1,12 @@ """Tests translation of `dlt` credentials into `object_store` Rust crate credentials.""" -from typing import Any, Dict, cast +from typing import Any, Dict import pytest from deltalake import DeltaTable from deltalake.exceptions import TableNotFoundError import dlt -from dlt.common.typing import TSecretStrValue from dlt.common.configuration import resolve_configuration from dlt.common.configuration.specs import ( AnyAzureCredentials, @@ -144,8 +143,8 @@ def test_aws_object_store_rs_credentials(driver: str) -> None: sess_creds = creds.to_session_credentials() creds = AwsCredentials( aws_access_key_id=sess_creds["aws_access_key_id"], - aws_secret_access_key=cast(TSecretStrValue, sess_creds["aws_secret_access_key"]), - aws_session_token=cast(TSecretStrValue, sess_creds["aws_session_token"]), + aws_secret_access_key=sess_creds["aws_secret_access_key"], + aws_session_token=sess_creds["aws_session_token"], region_name=fs_creds["region_name"], ) assert creds.aws_session_token is not None @@ -156,8 +155,8 @@ def test_aws_object_store_rs_credentials(driver: str) -> None: # AwsCredentialsWithoutDefaults: user-provided session token creds = AwsCredentialsWithoutDefaults( aws_access_key_id=sess_creds["aws_access_key_id"], - aws_secret_access_key=cast(TSecretStrValue, sess_creds["aws_secret_access_key"]), - aws_session_token=cast(TSecretStrValue, sess_creds["aws_session_token"]), + aws_secret_access_key=sess_creds["aws_secret_access_key"], + aws_session_token=sess_creds["aws_session_token"], region_name=fs_creds["region_name"], ) assert creds.aws_session_token is not None diff --git a/tests/load/filesystem/test_sql_client.py b/tests/load/filesystem/test_sql_client.py new file mode 100644 index 0000000000..a5344e14e1 --- /dev/null +++ b/tests/load/filesystem/test_sql_client.py @@ -0,0 +1,330 @@ +"""Test the duckdb supported sql client for special internal features""" + + +from typing import Any + +import pytest +import dlt +import os + +from dlt import Pipeline +from dlt.common.utils import uniq_id + +from tests.load.utils import ( + destinations_configs, + DestinationTestConfiguration, + GCS_BUCKET, + SFTP_BUCKET, + MEMORY_BUCKET, +) +from dlt.destinations import filesystem +from tests.utils import TEST_STORAGE_ROOT +from dlt.destinations.exceptions import DatabaseUndefinedRelation + + +def _run_dataset_checks( + pipeline: Pipeline, + destination_config: DestinationTestConfiguration, + table_format: Any = None, + alternate_access_pipeline: Pipeline = None, +) -> None: + total_records = 200 + + TEST_SECRET_NAME = "TEST_SECRET" + uniq_id() + + # only some buckets have support for persistent secrets + needs_persistent_secrets = ( + destination_config.bucket_url.startswith("s3") + or destination_config.bucket_url.startswith("az") + or destination_config.bucket_url.startswith("abfss") + ) + + unsupported_persistent_secrets = destination_config.bucket_url.startswith("gs") + + @dlt.source() + def source(): + @dlt.resource( + table_format=table_format, + write_disposition="replace", + ) + def items(): + yield from [ + { + "id": i, + "children": [{"id": i + 100}, {"id": i + 1000}], + } + for i in range(total_records) + ] + + @dlt.resource( + table_format=table_format, + write_disposition="replace", + ) + def double_items(): + yield from [ + { + "id": i, + "double_id": i * 2, + } + for i in range(total_records) + ] + + return [items, double_items] + + # run source + pipeline.run(source(), 
loader_file_format=destination_config.file_format)
+
+    if alternate_access_pipeline:
+        pipeline.destination = alternate_access_pipeline.destination
+
+    import duckdb
+    from duckdb import HTTPException, IOException, InvalidInputException
+    from dlt.destinations.impl.filesystem.sql_client import (
+        FilesystemSqlClient,
+        DuckDbCredentials,
+    )
+
+    # check we can create new tables from the views
+    with pipeline.sql_client() as c:
+        c.execute_sql(
+            "CREATE TABLE items_joined AS (SELECT i.id, di.double_id FROM items as i JOIN"
+            " double_items as di ON (i.id = di.id));"
+        )
+        with c.execute_query("SELECT * FROM items_joined ORDER BY id ASC;") as cursor:
+            joined_table = cursor.fetchall()
+            assert len(joined_table) == total_records
+            assert list(joined_table[0]) == [0, 0]
+            assert list(joined_table[5]) == [5, 10]
+            assert list(joined_table[10]) == [10, 20]
+
+    # inserting values into a view should fail gracefully
+    with pipeline.sql_client() as c:
+        try:
+            c.execute_sql("INSERT INTO double_items VALUES (1, 2)")
+        except Exception as exc:
+            assert "double_items is not an table" in str(exc)
+
+    # check that no automated views are created for a schema different than
+    # the known one
+    with pipeline.sql_client() as c:
+        c.execute_sql("CREATE SCHEMA other_schema;")
+        with pytest.raises(DatabaseUndefinedRelation):
+            with c.execute_query("SELECT * FROM other_schema.items ORDER BY id ASC;") as cursor:
+                pass
+        # correct dataset view works
+        with c.execute_query(f"SELECT * FROM {c.dataset_name}.items ORDER BY id ASC;") as cursor:
+            table = cursor.fetchall()
+            assert len(table) == total_records
+        # no dataset prefix works
+        with c.execute_query("SELECT * FROM items ORDER BY id ASC;") as cursor:
+            table = cursor.fetchall()
+            assert len(table) == total_records
+
+    #
+    # tests with external duckdb instance
+    #
+
+    duck_db_location = TEST_STORAGE_ROOT + "/" + uniq_id()
+
+    def _external_duckdb_connection() -> duckdb.DuckDBPyConnection:
+        external_db = duckdb.connect(duck_db_location)
+        # the line below solves problems with certificate path lookup on linux, see duckdb docs
+        external_db.sql("SET azure_transport_option_type = 'curl';")
+        return external_db
+
+    def _fs_sql_client_for_external_db(
+        connection: duckdb.DuckDBPyConnection,
+    ) -> FilesystemSqlClient:
+        return FilesystemSqlClient(
+            dataset_name="second",
+            fs_client=pipeline.destination_client(),  # type: ignore
+            credentials=DuckDbCredentials(connection),
+        )
+
+    # we create a duckdb with a table and see whether we can add more views from the fs client
+    external_db = _external_duckdb_connection()
+    external_db.execute("CREATE SCHEMA first;")
+    external_db.execute("CREATE SCHEMA second;")
+    external_db.execute("CREATE TABLE first.items AS SELECT i FROM range(0, 3) t(i)")
+    assert len(external_db.sql("SELECT * FROM first.items").fetchall()) == 3
+
+    fs_sql_client = _fs_sql_client_for_external_db(external_db)
+    with fs_sql_client as sql_client:
+        sql_client.create_views_for_tables(
+            {"items": "referenced_items", "_dlt_loads": "_dlt_loads"}
+        )
+
+    # views exist
+    assert len(external_db.sql("SELECT * FROM second.referenced_items").fetchall()) == total_records
+    assert len(external_db.sql("SELECT * FROM first.items").fetchall()) == 3
+    external_db.close()
+
+    # in case we are not connecting to a bucket, views should still be here after connection reopen
+    if not needs_persistent_secrets and not unsupported_persistent_secrets:
+        external_db = _external_duckdb_connection()
+        assert (
+            len(external_db.sql("SELECT * FROM second.referenced_items").fetchall())
+            == total_records
+        )
+        external_db.close()
+        return
+
+    # in other cases secrets are not available and this should fail
+    external_db = _external_duckdb_connection()
+    with pytest.raises((HTTPException, IOException, InvalidInputException)):
+        assert (
+            len(external_db.sql("SELECT * FROM second.referenced_items").fetchall())
+            == total_records
+        )
+    external_db.close()
+
+    # gs does not support persistent secrets, so we can't do further checks
+    if unsupported_persistent_secrets:
+        return
+
+    # create secret
+    external_db = _external_duckdb_connection()
+    fs_sql_client = _fs_sql_client_for_external_db(external_db)
+    with fs_sql_client as sql_client:
+        fs_sql_client.create_authentication(persistent=True, secret_name=TEST_SECRET_NAME)
+    external_db.close()
+
+    # now this should work
+    external_db = _external_duckdb_connection()
+    assert len(external_db.sql("SELECT * FROM second.referenced_items").fetchall()) == total_records
+
+    # NOTE: when running this on CI, there seems to be some kind of race condition that prevents
+    # secrets from being removed as it does not find the file... We'll need to investigate this.
+    return
+
+    # now drop the secrets again
+    fs_sql_client = _fs_sql_client_for_external_db(external_db)
+    with fs_sql_client as sql_client:
+        fs_sql_client.drop_authentication(TEST_SECRET_NAME)
+    external_db.close()
+
+    # fails again
+    external_db = _external_duckdb_connection()
+    with pytest.raises((HTTPException, IOException, InvalidInputException)):
+        assert (
+            len(external_db.sql("SELECT * FROM second.referenced_items").fetchall())
+            == total_records
+        )
+    external_db.close()
+
+
+@pytest.mark.essential
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(
+        local_filesystem_configs=True,
+        all_buckets_filesystem_configs=True,
+        bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET],
+    ),  # TODO: make SFTP work
+    ids=lambda x: x.name,
+)
+def test_read_interfaces_filesystem(destination_config: DestinationTestConfiguration) -> None:
+    # we force multiple files per table, they may only hold 700 items
+    os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700"
+
+    if destination_config.file_format not in ["parquet", "jsonl"]:
+        pytest.skip(
+            f"Test only works for jsonl and parquet, given: {destination_config.file_format}"
+        )
+
+    pipeline = destination_config.setup_pipeline(
+        "read_pipeline",
+        dataset_name="read_test",
+        dev_mode=True,
+    )
+
+    _run_dataset_checks(pipeline, destination_config)
+
+    # for gcs buckets we additionally test the s3 compat layer
+    if destination_config.bucket_url == GCS_BUCKET:
+        gcp_bucket = filesystem(
+            GCS_BUCKET.replace("gs://", "s3://"), destination_name="filesystem_s3_gcs_comp"
+        )
+        pipeline = destination_config.setup_pipeline(
+            "read_pipeline", dataset_name="read_test", dev_mode=True, destination=gcp_bucket
+        )
+        _run_dataset_checks(pipeline, destination_config)
+
+
+@pytest.mark.essential
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(
+        table_format_filesystem_configs=True,
+        with_table_format="delta",
+        bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET],
+        # NOTE: delta does not work on memory buckets
+    ),
+    ids=lambda x: x.name,
+)
+def test_delta_tables(destination_config: DestinationTestConfiguration) -> None:
+    os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700"
+
+    pipeline = destination_config.setup_pipeline(
+        "read_pipeline",
+        dataset_name="read_test",
+    )
+
+    # in case of gcs we use the s3 compat layer for reading
+    # for writing we still need to use the gc
authentication, as delta_rs seems to use + # methods on the s3 interface that are not implemented by gcs + access_pipeline = pipeline + if destination_config.bucket_url == GCS_BUCKET: + gcp_bucket = filesystem( + GCS_BUCKET.replace("gs://", "s3://"), destination_name="filesystem_s3_gcs_comp" + ) + access_pipeline = destination_config.setup_pipeline( + "read_pipeline", dataset_name="read_test", destination=gcp_bucket + ) + + _run_dataset_checks( + pipeline, + destination_config, + table_format="delta", + alternate_access_pipeline=access_pipeline, + ) + + +@pytest.mark.essential +@pytest.mark.parametrize( + "destination_config", + destinations_configs(local_filesystem_configs=True), + ids=lambda x: x.name, +) +def test_evolving_filesystem(destination_config: DestinationTestConfiguration) -> None: + """test that files with unequal schemas still work together""" + + if destination_config.file_format not in ["parquet", "jsonl"]: + pytest.skip( + f"Test only works for jsonl and parquet, given: {destination_config.file_format}" + ) + + @dlt.resource(table_name="items") + def items(): + yield from [{"id": i} for i in range(20)] + + pipeline = destination_config.setup_pipeline( + "read_pipeline", + dataset_name="read_test", + dev_mode=True, + ) + + pipeline.run([items()], loader_file_format=destination_config.file_format) + + df = pipeline._dataset().items.df() + assert len(df.index) == 20 + + @dlt.resource(table_name="items") + def items2(): + yield from [{"id": i, "other_value": "Blah"} for i in range(20, 50)] + + pipeline.run([items2()], loader_file_format=destination_config.file_format) + + # check df and arrow access + assert len(pipeline._dataset().items.df().index) == 50 + assert pipeline._dataset().items.arrow().num_rows == 50 diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 2225d0001c..e802cde693 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -2,7 +2,12 @@ import os from dlt.common.utils import uniq_id -from tests.load.utils import DestinationTestConfiguration, destinations_configs, AZ_BUCKET +from tests.load.utils import ( + GCS_BUCKET, + DestinationTestConfiguration, + destinations_configs, + AZ_BUCKET, +) from tests.pipeline.utils import assert_load_info @@ -13,7 +18,7 @@ @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET), subset=("databricks",) + default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) ), ids=lambda x: x.name, ) @@ -62,7 +67,7 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message ) - # # should fail on non existing stored credentials + # should fail on non existing stored credentials bricks = databricks(is_staging_external_location=False, staging_credentials_name="CREDENTIAL_X") pipeline = destination_config.setup_pipeline( "test_databricks_external_location", @@ -90,3 +95,53 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu assert ( "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message ) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), + ids=lambda x: x.name, +) +def test_databricks_gcs_external_location(destination_config: DestinationTestConfiguration) 
-> None: + # do not interfere with state + os.environ["RESTORE_FROM_DESTINATION"] = "False" + # let the package complete even with failed jobs + os.environ["RAISE_ON_FAILED_JOBS"] = "false" + + dataset_name = "test_databricks_gcs_external_location" + uniq_id() + + # swap AZ bucket for GCS_BUCKET + from dlt.destinations import databricks, filesystem + + stage = filesystem(GCS_BUCKET) + + # explicit cred handover should fail + bricks = databricks() + pipeline = destination_config.setup_pipeline( + "test_databricks_gcs_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is True + assert ( + "You need to use Databricks named credential" + in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) + + # should fail on non existing stored credentials + bricks = databricks(is_staging_external_location=False, staging_credentials_name="CREDENTIAL_X") + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is True + assert ( + "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index d55c81e998..9e4afa6531 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -24,7 +24,12 @@ def dbt_venv() -> Iterator[Venv]: # context manager will delete venv at the end # yield Venv.restore_current() # NOTE: we limit the max version of dbt to allow all dbt adapters to run. ie. 
sqlserver does not work on 1.8 - with create_venv(tempfile.mkdtemp(), list(ACTIVE_SQL_DESTINATIONS), dbt_version="<1.8") as venv: + # TODO: pytest marking below must be fixed + dbt_configs = set( + c.values[0].destination_type # type: ignore[attr-defined] + for c in destinations_configs(default_sql_configs=True, supports_dbt=True) + ) + with create_venv(tempfile.mkdtemp(), list(dbt_configs), dbt_version="<1.8") as venv: yield venv diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 050636c491..51cb392b29 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -674,6 +674,10 @@ def some_data(param: str) -> Any: # nevertheless this is potentially dangerous situation 🤷 assert ra_production_p.state == prod_state + # for now skip sql client tests for filesystem + if destination_config.destination_type == "filesystem": + return + # get all the states, notice version 4 twice (one from production, the other from local) try: with p.sql_client() as client: diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index 3e08b792ed..2a5b9ed296 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -52,13 +52,22 @@ def strip_timezone(ts: TAnyDateTime) -> pendulum.DateTime: def get_table( - pipeline: dlt.Pipeline, table_name: str, sort_column: str = None, include_root_id: bool = True + pipeline: dlt.Pipeline, + table_name: str, + sort_column: str = None, + include_root_id: bool = True, + ts_columns: Optional[List[str]] = None, ) -> List[Dict[str, Any]]: """Returns destination table contents as list of dictionaries.""" + ts_columns = ts_columns or [] table = [ { - k: strip_timezone(v) if isinstance(v, datetime) else v + k: ( + strip_timezone(v) + if isinstance(v, datetime) or (k in ts_columns and v is not None) + else v + ) for k, v in r.items() if not k.startswith("_dlt") or k in DEFAULT_VALIDITY_COLUMN_NAMES @@ -128,7 +137,7 @@ def r(data): # assert load results ts_1 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ { from_: ts_1, to: None, @@ -153,7 +162,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_2 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ { from_: ts_1, to: None, @@ -178,7 +187,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_3 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ {from_: ts_1, to: ts_3, "nk": 2, "c1": "bar", "c2__nc1": "bar"}, {from_: ts_1, to: ts_2, "nk": 1, "c1": "foo", "c2__nc1": "foo"}, { @@ -198,7 +207,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_4 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ {from_: ts_1, to: ts_3, "nk": 2, "c1": "bar", "c2__nc1": "bar"}, { from_: ts_4, @@ -242,7 +251,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_1 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c1") == [ + 
assert get_table(p, "dim_test", "c1", ts_columns=[FROM, TO]) == [ {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"}, {FROM: ts_1, TO: None, "nk": 1, "c1": "foo"}, ] @@ -261,7 +270,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_2 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c1") == [ + assert get_table(p, "dim_test", "c1", ts_columns=[FROM, TO]) == [ {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"}, {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"}, # updated {FROM: ts_2, TO: None, "nk": 1, "c1": "foo_updated"}, # new @@ -289,7 +298,7 @@ def r(data): ts_3 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test"), + get_table(p, "dim_test", ts_columns=[FROM, TO]), [ {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"}, {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"}, @@ -315,7 +324,7 @@ def r(data): ts_4 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test"), + get_table(p, "dim_test", ts_columns=[FROM, TO]), [ {FROM: ts_1, TO: ts_4, "nk": 2, "c1": "bar"}, # updated {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"}, @@ -336,7 +345,7 @@ def r(data): ts_5 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test"), + get_table(p, "dim_test", ts_columns=[FROM, TO]), [ {FROM: ts_1, TO: ts_4, "nk": 2, "c1": "bar"}, {FROM: ts_5, TO: None, "nk": 3, "c1": "baz"}, # new @@ -502,7 +511,7 @@ def r(data): {**{FROM: ts_3, TO: None}, **r1_no_child}, {**{FROM: ts_1, TO: None}, **r2_no_child}, ] - assert_records_as_set(get_table(p, "dim_test"), expected) + assert_records_as_set(get_table(p, "dim_test", ts_columns=[FROM, TO]), expected) # assert child records expected = [ @@ -739,7 +748,10 @@ def dim_test(data): assert load_table_counts(p, "dim_test")["dim_test"] == 3 ts3 = get_load_package_created_at(p, info) # natural key 1 should now have two records (one retired, one active) - actual = [{k: v for k, v in row.items() if k in ("nk", TO)} for row in get_table(p, "dim_test")] + actual = [ + {k: v for k, v in row.items() if k in ("nk", TO)} + for row in get_table(p, "dim_test", ts_columns=[FROM, TO]) + ] expected = [{"nk": 1, TO: ts3}, {"nk": 1, TO: None}, {"nk": 2, TO: None}] assert_records_as_set(actual, expected) # type: ignore[arg-type] @@ -753,7 +765,10 @@ def dim_test(data): assert load_table_counts(p, "dim_test")["dim_test"] == 4 ts4 = get_load_package_created_at(p, info) # natural key 1 should now have three records (two retired, one active) - actual = [{k: v for k, v in row.items() if k in ("nk", TO)} for row in get_table(p, "dim_test")] + actual = [ + {k: v for k, v in row.items() if k in ("nk", TO)} + for row in get_table(p, "dim_test", ts_columns=[FROM, TO]) + ] expected = [{"nk": 1, TO: ts3}, {"nk": 1, TO: ts4}, {"nk": 1, TO: None}, {"nk": 2, TO: None}] assert_records_as_set(actual, expected) # type: ignore[arg-type] @@ -805,7 +820,7 @@ def dim_test_compound(data): # "Doe" should now have two records (one retired, one active) actual = [ {k: v for k, v in row.items() if k in ("first_name", "last_name", TO)} - for row in get_table(p, "dim_test_compound") + for row in get_table(p, "dim_test_compound", ts_columns=[FROM, TO]) ] expected = [ {"first_name": first_name, "last_name": "Doe", TO: ts3}, @@ -869,7 +884,7 @@ def dim_test(data): ts2 = get_load_package_created_at(p, info) actual = [ {k: v for k, v in row.items() if k in ("date", "name", TO)} - for row in get_table(p, "dim_test") + 
for row in get_table(p, "dim_test", ts_columns=[TO]) ] expected = [ {"date": "2024-01-01", "name": "a", TO: None}, diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index f7d915903e..fad244fa71 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -128,11 +128,15 @@ def source(): # schemaless destinations allow adding of root key without the pipeline failing # they do not mind adding NOT NULL columns to tables with existing data (id NOT NULL is supported at all) # doing this will result in somewhat useless behavior - destination_allows_adding_root_key = destination_config.destination_type in [ - "dremio", - "clickhouse", - "athena", - ] + destination_allows_adding_root_key = ( + destination_config.destination_type + in [ + "dremio", + "clickhouse", + "athena", + ] + or destination_config.destination_name == "sqlalchemy_mysql" + ) if destination_allows_adding_root_key and not with_root_key: pipeline.run( diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index f692b7ae92..21973025c7 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -8,7 +8,7 @@ pytest.importorskip("snowflake") -from dlt.common.libs.sql_alchemy_shims import make_url +from dlt.common.libs.sql_alchemy_compat import make_url from dlt.common.configuration.resolve import resolve_configuration from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.utils import digest128 @@ -152,8 +152,8 @@ def test_overwrite_query_value_from_explicit() -> None: def test_to_connector_params_private_key() -> None: creds = SnowflakeCredentials() - creds.private_key = PKEY_PEM_STR # type: ignore[assignment] - creds.private_key_passphrase = PKEY_PASSPHRASE # type: ignore[assignment] + creds.private_key = PKEY_PEM_STR + creds.private_key_passphrase = PKEY_PASSPHRASE creds.username = "user1" creds.database = "db1" creds.host = "host1" @@ -177,8 +177,8 @@ def test_to_connector_params_private_key() -> None: ) creds = SnowflakeCredentials() - creds.private_key = PKEY_DER_STR # type: ignore[assignment] - creds.private_key_passphrase = PKEY_PASSPHRASE # type: ignore[assignment] + creds.private_key = PKEY_DER_STR + creds.private_key_passphrase = PKEY_PASSPHRASE creds.username = "user1" creds.database = "db1" creds.host = "host1" diff --git a/tests/load/sqlalchemy/docker-compose.yml b/tests/load/sqlalchemy/docker-compose.yml new file mode 100644 index 0000000000..29375a0f2e --- /dev/null +++ b/tests/load/sqlalchemy/docker-compose.yml @@ -0,0 +1,16 @@ +# Use root/example as user/password credentials +version: '3.1' + +services: + + db: + image: mysql:8 + restart: always + environment: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: dlt_data + MYSQL_USER: loader + MYSQL_PASSWORD: loader + ports: + - 3306:3306 + # (this is just an example, not intended to be a production configuration) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py new file mode 100644 index 0000000000..e093e4d670 --- /dev/null +++ b/tests/load/test_read_interfaces.py @@ -0,0 +1,302 @@ +from typing import Any + +import pytest +import dlt +import os + +from dlt import Pipeline +from dlt.common import Decimal +from dlt.common.utils import uniq_id + +from typing import List +from functools import reduce + +from tests.load.utils import ( + 
destinations_configs, + DestinationTestConfiguration, + GCS_BUCKET, + SFTP_BUCKET, + MEMORY_BUCKET, +) +from dlt.destinations import filesystem +from tests.utils import TEST_STORAGE_ROOT + + +def _run_dataset_checks( + pipeline: Pipeline, + destination_config: DestinationTestConfiguration, + table_format: Any = None, + alternate_access_pipeline: Pipeline = None, +) -> None: + destination_type = pipeline.destination_client().config.destination_type + + skip_df_chunk_size_check = False + expected_columns = ["id", "decimal", "other_decimal", "_dlt_load_id", "_dlt_id"] + if destination_type == "bigquery": + chunk_size = 50 + total_records = 80 + elif destination_type == "mssql": + chunk_size = 700 + total_records = 1000 + else: + chunk_size = 2048 + total_records = 3000 + + # on filesystem one chunk is one file and not the default vector size + if destination_type == "filesystem": + skip_df_chunk_size_check = True + + # we always expect 2 chunks based on the above setup + expected_chunk_counts = [chunk_size, total_records - chunk_size] + + @dlt.source() + def source(): + @dlt.resource( + table_format=table_format, + write_disposition="replace", + columns={ + "id": {"data_type": "bigint"}, + # we add a decimal with precision to see wether the hints are preserved + "decimal": {"data_type": "decimal", "precision": 10, "scale": 3}, + "other_decimal": {"data_type": "decimal", "precision": 12, "scale": 3}, + }, + ) + def items(): + yield from [ + { + "id": i, + "children": [{"id": i + 100}, {"id": i + 1000}], + "decimal": Decimal("10.433"), + "other_decimal": Decimal("10.433"), + } + for i in range(total_records) + ] + + @dlt.resource( + table_format=table_format, + write_disposition="replace", + columns={ + "id": {"data_type": "bigint"}, + "double_id": {"data_type": "bigint"}, + }, + ) + def double_items(): + yield from [ + { + "id": i, + "double_id": i * 2, + } + for i in range(total_records) + ] + + return [items, double_items] + + # run source + s = source() + pipeline.run(s, loader_file_format=destination_config.file_format) + + if alternate_access_pipeline: + pipeline.destination = alternate_access_pipeline.destination + + # access via key + table_relationship = pipeline._dataset()["items"] + + # full frame + df = table_relationship.df() + assert len(df.index) == total_records + + # + # check dataframes + # + + # chunk + df = table_relationship.df(chunk_size=chunk_size) + if not skip_df_chunk_size_check: + assert len(df.index) == chunk_size + # lowercase results for the snowflake case + assert set(df.columns.values) == set(expected_columns) + + # iterate all dataframes + frames = list(table_relationship.iter_df(chunk_size=chunk_size)) + if not skip_df_chunk_size_check: + assert [len(df.index) for df in frames] == expected_chunk_counts + + # check all items are present + ids = reduce(lambda a, b: a + b, [f[expected_columns[0]].to_list() for f in frames]) + assert set(ids) == set(range(total_records)) + + # access via prop + table_relationship = pipeline._dataset().items + + # + # check arrow tables + # + + # full table + table = table_relationship.arrow() + assert table.num_rows == total_records + + # chunk + table = table_relationship.arrow(chunk_size=chunk_size) + assert set(table.column_names) == set(expected_columns) + assert table.num_rows == chunk_size + + # check frame amount and items counts + tables = list(table_relationship.iter_arrow(chunk_size=chunk_size)) + assert [t.num_rows for t in tables] == expected_chunk_counts + + # check all items are present + ids = reduce(lambda a, b: a + 
b, [t.column(expected_columns[0]).to_pylist() for t in tables]) + assert set(ids) == set(range(total_records)) + + # check fetch accessors + table_relationship = pipeline._dataset().items + + # check accessing one item + one = table_relationship.fetchone() + assert one[0] in range(total_records) + + # check fetchall + fall = table_relationship.fetchall() + assert len(fall) == total_records + assert {item[0] for item in fall} == set(range(total_records)) + + # check fetchmany + many = table_relationship.fetchmany(chunk_size) + assert len(many) == chunk_size + + # check iterfetchmany + chunks = list(table_relationship.iter_fetch(chunk_size=chunk_size)) + assert [len(chunk) for chunk in chunks] == expected_chunk_counts + ids = reduce(lambda a, b: a + b, [[item[0] for item in chunk] for chunk in chunks]) + assert set(ids) == set(range(total_records)) + + # check that hints are carried over to arrow table + expected_decimal_precision = 10 + expected_decimal_precision_2 = 12 + if destination_config.destination_type == "bigquery": + # bigquery does not allow precision configuration.. + expected_decimal_precision = 38 + expected_decimal_precision_2 = 38 + assert ( + table_relationship.arrow().schema.field("decimal").type.precision + == expected_decimal_precision + ) + assert ( + table_relationship.arrow().schema.field("other_decimal").type.precision + == expected_decimal_precision_2 + ) + + # simple check that query also works + tname = pipeline.sql_client().make_qualified_table_name("items") + query_relationship = pipeline._dataset()(f"select * from {tname} where id < 20") + + # we selected the first 20 + table = query_relationship.arrow() + assert table.num_rows == 20 + + # check join query + tdname = pipeline.sql_client().make_qualified_table_name("double_items") + query = ( + f"SELECT i.id, di.double_id FROM {tname} as i JOIN {tdname} as di ON (i.id = di.id) WHERE" + " i.id < 20 ORDER BY i.id ASC" + ) + join_relationship = pipeline._dataset()(query) + table = join_relationship.fetchall() + assert len(table) == 20 + assert list(table[0]) == [0, 0] + assert list(table[5]) == [5, 10] + assert list(table[10]) == [10, 20] + + # check loads table access + loads_table = pipeline._dataset()[pipeline.default_schema.loads_table_name] + loads_table.fetchall() + + +@pytest.mark.essential +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) -> None: + pipeline = destination_config.setup_pipeline( + "read_pipeline", dataset_name="read_test", dev_mode=True + ) + _run_dataset_checks(pipeline, destination_config) + + +@pytest.mark.essential +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + local_filesystem_configs=True, + all_buckets_filesystem_configs=True, + bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET], + ), # TODO: make SFTP work + ids=lambda x: x.name, +) +def test_read_interfaces_filesystem(destination_config: DestinationTestConfiguration) -> None: + # we force multiple files per table, they may only hold 700 items + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700" + + if destination_config.file_format not in ["parquet", "jsonl"]: + pytest.skip( + f"Test only works for jsonl and parquet, given: {destination_config.file_format}" + ) + + pipeline = destination_config.setup_pipeline( + "read_pipeline", + dataset_name="read_test", + dev_mode=True, + ) + + _run_dataset_checks(pipeline, destination_config) + + # for gcs buckets we 
additionally test the s3 compat layer + if destination_config.bucket_url == GCS_BUCKET: + gcp_bucket = filesystem( + GCS_BUCKET.replace("gs://", "s3://"), destination_name="filesystem_s3_gcs_comp" + ) + pipeline = destination_config.setup_pipeline( + "read_pipeline", dataset_name="read_test", dev_mode=True, destination=gcp_bucket + ) + _run_dataset_checks(pipeline, destination_config) + + +@pytest.mark.essential +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + with_table_format="delta", + bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET], + ), + ids=lambda x: x.name, +) +def test_delta_tables(destination_config: DestinationTestConfiguration) -> None: + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700" + + pipeline = destination_config.setup_pipeline( + "read_pipeline", + dataset_name="read_test", + ) + + # in case of gcs we use the s3 compat layer for reading + # for writing we still need to use the gc authentication, as delta_rs seems to use + # methods on the s3 interface that are not implemented by gcs + access_pipeline = pipeline + if destination_config.bucket_url == GCS_BUCKET: + gcp_bucket = filesystem( + GCS_BUCKET.replace("gs://", "s3://"), destination_name="filesystem_s3_gcs_comp" + ) + access_pipeline = destination_config.setup_pipeline( + "read_pipeline", dataset_name="read_test", destination=gcp_bucket + ) + + _run_dataset_checks( + pipeline, + destination_config, + table_format="delta", + alternate_access_pipeline=access_pipeline, + ) diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 199b4b83b7..3636b3e53a 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -347,9 +347,13 @@ def test_execute_df(client: SqlJobClientBase) -> None: f"SELECT * FROM {f_q_table_name} ORDER BY col ASC" ) as curr: # be compatible with duckdb vector size - df_1 = curr.df(chunk_size=chunk_size) - df_2 = curr.df(chunk_size=chunk_size) - df_3 = curr.df(chunk_size=chunk_size) + iterator = curr.iter_df(chunk_size) + df_1 = next(iterator) + df_2 = next(iterator) + try: + df_3 = next(iterator) + except StopIteration: + df_3 = None # Force lower case df columns, snowflake has all cols uppercase for df in [df_1, df_2, df_3]: if df is not None: diff --git a/tests/load/utils.py b/tests/load/utils.py index 19601f2cf1..575938af15 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -162,7 +162,7 @@ class DestinationTestConfiguration: supports_dbt: bool = True disable_compression: bool = False dev_mode: bool = False - credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any]]] = None + credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any], str]] = None env_vars: Optional[Dict[str, str]] = None destination_name: Optional[str] = None @@ -215,8 +215,11 @@ def setup(self) -> None: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" if self.credentials is not None: - for key, value in dict(self.credentials).items(): - os.environ[f"DESTINATION__CREDENTIALS__{key.upper()}"] = str(value) + if isinstance(self.credentials, str): + os.environ["DESTINATION__CREDENTIALS"] = self.credentials + else: + for key, value in dict(self.credentials).items(): + os.environ[f"DESTINATION__CREDENTIALS__{key.upper()}"] = str(value) if self.env_vars is not None: for k, v in self.env_vars.items(): @@ -331,15 +334,19 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration( destination_type="sqlalchemy", - supports_merge=False, + supports_merge=True, 
supports_dbt=False, destination_name="sqlalchemy_mysql", + credentials=( # Use root cause we need to create databases, + "mysql://root:root@127.0.0.1:3306/dlt_data" + ), ), DestinationTestConfiguration( destination_type="sqlalchemy", - supports_merge=False, + supports_merge=True, supports_dbt=False, destination_name="sqlalchemy_sqlite", + credentials="sqlite:///_storage/dl_data.sqlite", ), ] @@ -589,6 +596,7 @@ def destinations_configs( bucket_url=bucket, extra_info=bucket, supports_merge=False, + file_format="parquet", ) ] @@ -671,6 +679,7 @@ def destinations_configs( # add marks destination_configs = [ + # TODO: fix this, probably via pytest plugin that processes parametrize params cast( DestinationTestConfiguration, pytest.param( @@ -680,7 +689,6 @@ def destinations_configs( ) for conf in destination_configs ] - return destination_configs diff --git a/.github/weaviate-compose.yml b/tests/load/weaviate/docker-compose.yml similarity index 100% rename from .github/weaviate-compose.yml rename to tests/load/weaviate/docker-compose.yml diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 98ac7a3728..c7a8832214 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -13,7 +13,6 @@ from dlt.common.runners import Venv from dlt.common.storages.exceptions import StorageMigrationError from dlt.common.utils import custom_environ, set_working_dir -from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.storages import FileStorage from dlt.common.schema.typing import ( LOADS_TABLE_NAME, @@ -77,7 +76,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -222,7 +221,7 @@ def test_filesystem_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): # create virtual env with (0.4.9) where filesystem started to store state with Venv.create(tempfile.mkdtemp(), ["dlt==0.4.9"]) as venv: try: @@ -294,7 +293,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -369,7 +368,7 @@ def test_normalize_package_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -404,7 +403,7 @@ def 
test_scd2_pipeline_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: dlt.current.run().data_dir}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index 11c45d72cc..303d2fdb6f 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -11,14 +11,14 @@ ) from dlt.common.schema import Schema from dlt.common.schema.utils import pipeline_state_table -from dlt.common.source import get_current_pipe_name +from dlt.common.pipeline import get_current_pipe_name from dlt.common.storages import FileStorage from dlt.common import pipeline as state_module from dlt.common.storages.load_package import TPipelineStateDoc from dlt.common.utils import uniq_id from dlt.common.destination.reference import Destination, StateInfo - from dlt.common.validation import validate_dict + from dlt.destinations.utils import get_pipeline_state_query_columns from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/plugins/dlt_example_plugin/Makefile b/tests/plugins/dlt_example_plugin/Makefile new file mode 100644 index 0000000000..cdf97bb120 --- /dev/null +++ b/tests/plugins/dlt_example_plugin/Makefile @@ -0,0 +1,8 @@ + +uninstall-example-plugin: + pip uninstall example_plugin -y + +install-example-plugin: uninstall-example-plugin + # this builds and installs the example plugin + poetry build + pip install dist/example_plugin-0.1.0-py3-none-any.whl \ No newline at end of file diff --git a/tests/plugins/dlt_example_plugin/README.md b/tests/plugins/dlt_example_plugin/README.md new file mode 100644 index 0000000000..d1cce015be --- /dev/null +++ b/tests/plugins/dlt_example_plugin/README.md @@ -0,0 +1,4 @@ +# Example DLT Plugin +1. Plugin name must start with dlt- to be recognized at run time +2. Export the module that registers plugin in an entry point +3. 
Use pluggy hookspecs thst you can find here and there in the dlt \ No newline at end of file diff --git a/tests/plugins/dlt_example_plugin/dlt_example_plugin/__init__.py b/tests/plugins/dlt_example_plugin/dlt_example_plugin/__init__.py new file mode 100644 index 0000000000..345559e701 --- /dev/null +++ b/tests/plugins/dlt_example_plugin/dlt_example_plugin/__init__.py @@ -0,0 +1,29 @@ +import os +from typing import ClassVar + +from dlt.common.configuration import plugins +from dlt.common.configuration.specs.pluggable_run_context import SupportsRunContext +from dlt.common.runtime.run_context import RunContext, DOT_DLT + +from tests.utils import TEST_STORAGE_ROOT + + +class RunContextTest(RunContext): + CONTEXT_NAME: ClassVar[str] = "dlt-test" + + @property + def run_dir(self) -> str: + return os.path.abspath("tests") + + @property + def settings_dir(self) -> str: + return os.path.join(self.run_dir, DOT_DLT) + + @property + def data_dir(self) -> str: + return os.path.abspath(TEST_STORAGE_ROOT) + + +@plugins.hookimpl(specname="plug_run_context") +def plug_run_context_impl() -> SupportsRunContext: + return RunContextTest() diff --git a/tests/plugins/dlt_example_plugin/pyproject.toml b/tests/plugins/dlt_example_plugin/pyproject.toml new file mode 100644 index 0000000000..475254e591 --- /dev/null +++ b/tests/plugins/dlt_example_plugin/pyproject.toml @@ -0,0 +1,20 @@ +[tool.poetry] +name = "dlt-example-plugin" +version = "0.1.0" +description = "" +authors = ["dave "] +readme = "README.md" +packages = [ + { include = "dlt_example_plugin" }, +] + +[tool.poetry.plugins.dlt] +dlt-example-plugin = "dlt_example_plugin" + +[tool.poetry.dependencies] +python = ">=3.8.1,<3.13" +dlt={"path"="../../../"} + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/tests/plugins/test_plugin_discovery.py b/tests/plugins/test_plugin_discovery.py new file mode 100644 index 0000000000..3fe18860d7 --- /dev/null +++ b/tests/plugins/test_plugin_discovery.py @@ -0,0 +1,53 @@ +from subprocess import CalledProcessError +import pytest +import os +import sys +import tempfile +import shutil +import importlib + +from dlt.common.configuration.container import Container +from dlt.common.runners import Venv +from dlt.common.configuration import plugins +from dlt.common.runtime import run_context +from tests.utils import TEST_STORAGE_ROOT + + +@pytest.fixture(scope="module", autouse=True) +def plugin_install(): + # install plugin into temp dir + temp_dir = tempfile.mkdtemp() + venv = Venv.restore_current() + try: + print( + venv.run_module( + "pip", "install", "tests/plugins/dlt_example_plugin", "--target", temp_dir + ) + ) + except CalledProcessError as c_err: + print(c_err.stdout) + print(c_err.stderr) + raise + sys.path.insert(0, temp_dir) + + # remove current plugin manager + container = Container() + if plugins.PluginContext in container: + del container[plugins.PluginContext] + + # reload metadata module + importlib.reload(importlib.metadata) + + yield + + # remove distribution search, temp package and plugin manager + sys.path.remove(temp_dir) + shutil.rmtree(temp_dir) + importlib.reload(importlib.metadata) + del container[plugins.PluginContext] + + +def test_example_plugin() -> None: + context = run_context.current() + assert context.name == "dlt-test" + assert context.data_dir == os.path.abspath(TEST_STORAGE_ROOT) diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index 39e3d767a0..49a6275536 100644 --- 
a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -380,6 +380,33 @@ def test_update_state(self): paginator.update_state(response, data=NON_EMPTY_PAGE) assert paginator.has_next_page is False + def test_init_request(self): + paginator = PageNumberPaginator(base_page=1, total_path=None) + request = Mock(Request) + request.params = {} + response = Mock(Response, json=lambda: "OK") + + assert paginator.current_value == 1 + assert paginator.has_next_page is True + paginator.init_request(request) + + paginator.update_state(response, data=NON_EMPTY_PAGE) + paginator.update_request(request) + + assert paginator.current_value == 2 + assert paginator.has_next_page is True + assert request.params["page"] == 2 + + paginator.update_state(response, data=None) + paginator.update_request(request) + + assert paginator.current_value == 2 + assert paginator.has_next_page is False + + paginator.init_request(request) + assert paginator.current_value == 1 + assert paginator.has_next_page is True + def test_update_state_with_string_total_pages(self): paginator = PageNumberPaginator(base_page=1, page=1) response = Mock(Response, json=lambda: {"total": "3"}) diff --git a/tests/sources/rest_api/configurations/source_configs.py b/tests/sources/rest_api/configurations/source_configs.py index 8e26a4183b..fb24a0ad49 100644 --- a/tests/sources/rest_api/configurations/source_configs.py +++ b/tests/sources/rest_api/configurations/source_configs.py @@ -4,6 +4,7 @@ import requests import dlt import dlt.common +from dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.typing import TSecretStrValue from dlt.common.exceptions import DictValidationException from dlt.common.configuration.specs import configspec @@ -11,7 +12,7 @@ import dlt.sources.helpers import dlt.sources.helpers.requests from dlt.sources.helpers.rest_client.paginators import HeaderLinkPaginator -from dlt.sources.helpers.rest_client.auth import OAuth2AuthBase +from dlt.sources.helpers.rest_client.auth import OAuth2AuthBase, APIKeyAuth from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator from dlt.sources.helpers.rest_client.auth import HttpBasicAuth @@ -32,6 +33,47 @@ exception=DictValidationException, config={"resources": []}, ), + # expect missing api_key at the right config section coming from the shorthand auth notation + ConfigTest( + expected_message="SOURCES__REST_API__INVALID_CONFIG__CREDENTIALS__API_KEY", + exception=ConfigFieldMissingException, + config={ + "client": { + "base_url": "https://api.example.com", + "auth": "api_key", + }, + "resources": ["posts"], + }, + ), + # expect missing api_key at the right config section coming from the explicit auth config base + ConfigTest( + expected_message="SOURCES__REST_API__INVALID_CONFIG__CREDENTIALS__API_KEY", + exception=ConfigFieldMissingException, + config={ + "client": { + "base_url": "https://api.example.com", + "auth": APIKeyAuth(), + }, + "resources": ["posts"], + }, + ), + # expect missing api_key at the right config section coming from the dict notation + # TODO: currently this test fails on validation, api_key is necessary. 
validation happens + # before secrets are bound, this must be changed + ConfigTest( + expected_message=( + "For ApiKeyAuthConfig: In path ./client/auth: following required fields are missing" + " {'api_key'}" + ), + exception=DictValidationException, + config={ + "client": { + "base_url": "https://api.example.com", + "auth": {"type": "api_key", "location": "header"}, + }, + "resources": ["posts"], + }, + ), ConfigTest( expected_message="In path ./client: following fields are unexpected {'invalid_key'}", exception=DictValidationException, diff --git a/tests/sources/rest_api/configurations/test_configuration.py b/tests/sources/rest_api/configurations/test_configuration.py index 6adbfc5175..ca84479a0d 100644 --- a/tests/sources/rest_api/configurations/test_configuration.py +++ b/tests/sources/rest_api/configurations/test_configuration.py @@ -44,7 +44,7 @@ @pytest.mark.parametrize("expected_message, exception, invalid_config", INVALID_CONFIGS) def test_invalid_configurations(expected_message, exception, invalid_config): with pytest.raises(exception, match=expected_message): - rest_api_source(invalid_config) + rest_api_source(invalid_config, name="invalid_config") @pytest.mark.parametrize("valid_config", VALID_CONFIGS) diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py index a0ca7ce890..d3d9308df1 100644 --- a/tests/sources/rest_api/configurations/test_resolve_config.py +++ b/tests/sources/rest_api/configurations/test_resolve_config.py @@ -88,32 +88,34 @@ def test_bind_path_param() -> None: def test_process_parent_data_item() -> None: - resolve_param = ResolvedParam( - "id", {"field": "obj_id", "resource": "issues", "type": "resolve"} - ) + resolve_params = [ + ResolvedParam("id", {"field": "obj_id", "resource": "issues", "type": "resolve"}) + ] bound_path, parent_record = process_parent_data_item( - "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_param, None + "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, None ) assert bound_path == "dlt-hub/dlt/issues/12345/comments" assert parent_record == {} bound_path, parent_record = process_parent_data_item( - "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_param, ["obj_id"] + "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, ["obj_id"] ) assert parent_record == {"_issues_obj_id": 12345} bound_path, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345, "obj_node": "node_1"}, - resolve_param, + resolve_params, ["obj_id", "obj_node"], ) assert parent_record == {"_issues_obj_id": 12345, "_issues_obj_node": "node_1"} # test nested data - resolve_param_nested = ResolvedParam( - "id", {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"} - ) + resolve_param_nested = [ + ResolvedParam( + "id", {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"} + ) + ] item = {"some_results": {"obj_id": 12345}} bound_path, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", item, resolve_param_nested, None @@ -123,7 +125,7 @@ def test_process_parent_data_item() -> None: # param path not found with pytest.raises(ValueError) as val_ex: bound_path, parent_record = process_parent_data_item( - "dlt-hub/dlt/issues/{id}/comments", {"_id": 12345}, resolve_param, None + "dlt-hub/dlt/issues/{id}/comments", {"_id": 12345}, resolve_params, None ) assert "Transformer expects a field 'obj_id'" in str(val_ex.value) @@ -132,11 
+134,36 @@ def test_process_parent_data_item() -> None: bound_path, parent_record = process_parent_data_item( "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345, "obj_node": "node_1"}, - resolve_param, + resolve_params, ["obj_id", "node"], ) assert "in order to include it in child records under _issues_node" in str(val_ex.value) + # Resolve multiple parameters from a single record + multi_resolve_params = [ + ResolvedParam("issue_id", {"field": "issue", "resource": "comments", "type": "resolve"}), + ResolvedParam("id", {"field": "id", "resource": "comments", "type": "resolve"}), + ] + + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{issue_id}/comments/{id}", + {"issue": 12345, "id": 56789}, + multi_resolve_params, + None, + ) + assert bound_path == "dlt-hub/dlt/issues/12345/comments/56789" + assert parent_record == {} + + # param path not found with multiple parameters + with pytest.raises(ValueError) as val_ex: + bound_path, parent_record = process_parent_data_item( + "dlt-hub/dlt/issues/{issue_id}/comments/{id}", + {"_issue": 12345, "id": 56789}, + multi_resolve_params, + None, + ) + assert "Transformer expects a field 'issue'" in str(val_ex.value) + def test_two_resources_can_depend_on_one_parent_resource() -> None: user_id = { @@ -173,7 +200,7 @@ def test_two_resources_can_depend_on_one_parent_resource() -> None: assert resources["user_details"]._pipe.parent.name == "users" -def test_dependent_resource_cannot_bind_multiple_parameters() -> None: +def test_dependent_resource_can_bind_multiple_parameters() -> None: config: RESTAPIConfig = { "client": { "base_url": "https://api.example.com", @@ -200,15 +227,9 @@ def test_dependent_resource_cannot_bind_multiple_parameters() -> None: }, ], } - with pytest.raises(ValueError) as e: - rest_api_resources(config) - error_part_1 = re.escape( - "Multiple resolved params for resource user_details: [ResolvedParam(param_name='user_id'" - ) - error_part_2 = re.escape("ResolvedParam(param_name='group_id'") - assert e.match(error_part_1) - assert e.match(error_part_2) + resources = rest_api_source(config).resources + assert resources["user_details"]._pipe.parent.name == "users" def test_one_resource_cannot_bind_two_parents() -> None: @@ -244,7 +265,7 @@ def test_one_resource_cannot_bind_two_parents() -> None: rest_api_resources(config) error_part_1 = re.escape( - "Multiple resolved params for resource user_details: [ResolvedParam(param_name='user_id'" + "Multiple parent resources for user_details: [ResolvedParam(param_name='user_id'" ) error_part_2 = re.escape("ResolvedParam(param_name='group_id'") assert e.match(error_part_1) diff --git a/tests/sources/rest_api/test_rest_api_pipeline_template.py b/tests/sources/rest_api/test_rest_api_pipeline_template.py index cd5cca0b10..b397984d9f 100644 --- a/tests/sources/rest_api/test_rest_api_pipeline_template.py +++ b/tests/sources/rest_api/test_rest_api_pipeline_template.py @@ -18,6 +18,6 @@ def test_all_examples(example_name: str) -> None: github_token: TSecretStrValue = dlt.secrets.get("sources.github.access_token") if not github_token: # try to get GITHUB TOKEN which is available on github actions, fallback to None if not available - github_token = os.environ.get("GITHUB_TOKEN", None) # type: ignore + github_token = os.environ.get("GITHUB_TOKEN", None) dlt.secrets["sources.rest_api_pipeline.github.access_token"] = github_token getattr(rest_api_pipeline, example_name)() diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py 
index f6b97a7f47..153d35416f 100644 --- a/tests/sources/rest_api/test_rest_api_source.py +++ b/tests/sources/rest_api/test_rest_api_source.py @@ -1,9 +1,13 @@ import dlt import pytest + +from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext + from dlt.sources.rest_api.typing import RESTAPIConfig from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator +from dlt.sources.rest_api import rest_api_source, rest_api -from dlt.sources.rest_api import rest_api_source +from tests.common.configuration.utils import environment, toml_providers from tests.utils import ALL_DESTINATIONS, assert_load_info, load_table_counts @@ -16,8 +20,32 @@ def _make_pipeline(destination_name: str): ) +def test_rest_api_config_provider(toml_providers: ConfigProvidersContext) -> None: + # mock dicts in toml provider + dlt.config["client"] = { + "base_url": "https://pokeapi.co/api/v2/", + } + dlt.config["resources"] = [ + { + "name": "pokemon_list", + "endpoint": { + "path": "pokemon", + "paginator": SinglePagePaginator(), + "data_selector": "results", + "params": { + "limit": 10, + }, + }, + } + ] + pipeline = _make_pipeline("duckdb") + load_info = pipeline.run(rest_api()) + print(load_info) + + @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) -def test_rest_api_source(destination_name: str) -> None: +@pytest.mark.parametrize("invocation_type", ("deco", "factory")) +def test_rest_api_source(destination_name: str, invocation_type: str) -> None: config: RESTAPIConfig = { "client": { "base_url": "https://pokeapi.co/api/v2/", @@ -38,7 +66,10 @@ def test_rest_api_source(destination_name: str) -> None: "location", ], } - data = rest_api_source(config) + if invocation_type == "deco": + data = rest_api(**config) + else: + data = rest_api_source(config) pipeline = _make_pipeline(destination_name) load_info = pipeline.run(data) print(load_info) @@ -54,7 +85,8 @@ def test_rest_api_source(destination_name: str) -> None: @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS) -def test_dependent_resource(destination_name: str) -> None: +@pytest.mark.parametrize("invocation_type", ("deco", "factory")) +def test_dependent_resource(destination_name: str, invocation_type: str) -> None: config: RESTAPIConfig = { "client": { "base_url": "https://pokeapi.co/api/v2/", @@ -95,7 +127,10 @@ def test_dependent_resource(destination_name: str) -> None: ], } - data = rest_api_source(config) + if invocation_type == "deco": + data = rest_api(**config) + else: + data = rest_api_source(config) pipeline = _make_pipeline(destination_name) load_info = pipeline.run(data) assert_load_info(load_info) diff --git a/tests/sources/sql_database/test_arrow_helpers.py b/tests/sources/sql_database/test_arrow_helpers.py index 8328bed89b..abd063889c 100644 --- a/tests/sources/sql_database/test_arrow_helpers.py +++ b/tests/sources/sql_database/test_arrow_helpers.py @@ -65,7 +65,7 @@ def test_row_tuples_to_arrow_unknown_types(all_unknown: bool) -> None: col.pop("data_type", None) # Call the function - result = row_tuples_to_arrow(rows, columns, tz="UTC") # type: ignore[arg-type] + result = row_tuples_to_arrow(rows, columns=columns, tz="UTC") # type: ignore # Result is arrow table containing all columns in original order with correct types assert result.num_columns == len(columns) @@ -98,7 +98,7 @@ def test_row_tuples_to_arrow_detects_range_type() -> None: (IntRange(3, 30),), ] result = row_tuples_to_arrow( - rows=rows, # type: ignore[arg-type] + rows=rows, columns={"range_col": {"name": 
"range_col", "nullable": False}}, tz="UTC", ) diff --git a/tests/sources/test_pipeline_templates.py b/tests/sources/test_pipeline_templates.py index 0743a21fef..a83ccff67f 100644 --- a/tests/sources/test_pipeline_templates.py +++ b/tests/sources/test_pipeline_templates.py @@ -1,61 +1,20 @@ import pytest +import importlib @pytest.mark.parametrize( - "example_name", - ("load_all_datatypes",), + "template_name,examples", + [ + ("debug_pipeline", ("load_all_datatypes",)), + ("default_pipeline", ("load_api_data", "load_sql_data", "load_pandas_data")), + ("arrow_pipeline", ("load_arrow_tables",)), + ("dataframe_pipeline", ("load_dataframe",)), + ("requests_pipeline", ("load_chess_data",)), + ("github_api_pipeline", ("run_source",)), + ("fruitshop_pipeline", ("load_shop",)), + ], ) -def test_debug_pipeline(example_name: str) -> None: - from dlt.sources.pipeline_templates import debug_pipeline - - getattr(debug_pipeline, example_name)() - - -@pytest.mark.parametrize( - "example_name", - ("load_arrow_tables",), -) -def test_arrow_pipeline(example_name: str) -> None: - from dlt.sources.pipeline_templates import arrow_pipeline - - getattr(arrow_pipeline, example_name)() - - -@pytest.mark.parametrize( - "example_name", - ("load_dataframe",), -) -def test_dataframe_pipeline(example_name: str) -> None: - from dlt.sources.pipeline_templates import dataframe_pipeline - - getattr(dataframe_pipeline, example_name)() - - -@pytest.mark.parametrize( - "example_name", - ("load_stuff",), -) -def test_default_pipeline(example_name: str) -> None: - from dlt.sources.pipeline_templates import default_pipeline - - getattr(default_pipeline, example_name)() - - -@pytest.mark.parametrize( - "example_name", - ("load_chess_data",), -) -def test_requests_pipeline(example_name: str) -> None: - from dlt.sources.pipeline_templates import requests_pipeline - - getattr(requests_pipeline, example_name)() - - -@pytest.mark.parametrize( - "example_name", - ("load_api_data", "load_sql_data", "load_pandas_data"), -) -def test_intro_pipeline(example_name: str) -> None: - from dlt.sources.pipeline_templates import intro_pipeline - - getattr(intro_pipeline, example_name)() +def test_debug_pipeline(template_name: str, examples: str) -> None: + demo_module = importlib.import_module(f"dlt.sources.pipeline_templates.{template_name}") + for example_name in examples: + getattr(demo_module, example_name)() diff --git a/tests/utils.py b/tests/utils.py index 813deea69f..876737bd6a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,6 +12,7 @@ from requests import Response import dlt +from dlt.common import known_env from dlt.common.configuration.container import Container from dlt.common.configuration.providers import ( DictionaryProvider, @@ -24,8 +25,13 @@ from dlt.common.configuration.specs.config_providers_context import ( ConfigProvidersContext, ) +from dlt.common.configuration.specs.pluggable_run_context import ( + PluggableRunContext, + SupportsRunContext, +) from dlt.common.pipeline import LoadInfo, PipelineContext, SupportsPipeline from dlt.common.runtime.init import init_logging +from dlt.common.runtime.run_context import DOT_DLT, RunContext from dlt.common.runtime.telemetry import start_telemetry, stop_telemetry from dlt.common.schema import Schema from dlt.common.storages import FileStorage @@ -164,19 +170,66 @@ def duckdb_pipeline_location() -> Iterator[None]: yield +class MockableRunContext(RunContext): + @property + def name(self) -> str: + return self._name + + @property + def global_dir(self) -> str: + return 
self._global_dir + + @property + def run_dir(self) -> str: + return os.environ.get(known_env.DLT_PROJECT_DIR, self._run_dir) + + # @property + # def settings_dir(self) -> str: + # return self._settings_dir + + @property + def data_dir(self) -> str: + return os.environ.get(known_env.DLT_DATA_DIR, self._data_dir) + + _name: str + _global_dir: str + _run_dir: str + _settings_dir: str + _data_dir: str + + @classmethod + def from_context(cls, ctx: SupportsRunContext) -> "MockableRunContext": + cls_ = cls() + cls_._name = ctx.name + cls_._global_dir = ctx.global_dir + cls_._run_dir = ctx.run_dir + cls_._settings_dir = ctx.settings_dir + cls_._data_dir = ctx.data_dir + return cls_ + + @pytest.fixture(autouse=True) def patch_home_dir() -> Iterator[None]: - with patch("dlt.common.configuration.paths._get_user_home_dir") as _get_home_dir: - _get_home_dir.return_value = os.path.abspath(TEST_STORAGE_ROOT) + ctx = PluggableRunContext() + mock = MockableRunContext.from_context(ctx.context) + mock._global_dir = mock._data_dir = os.path.join(os.path.abspath(TEST_STORAGE_ROOT), DOT_DLT) + ctx.context = mock + + with Container().injectable_context(ctx): yield @pytest.fixture(autouse=True) def patch_random_home_dir() -> Iterator[None]: - global_dir = os.path.join(TEST_STORAGE_ROOT, "global_" + uniq_id()) - os.makedirs(global_dir, exist_ok=True) - with patch("dlt.common.configuration.paths._get_user_home_dir") as _get_home_dir: - _get_home_dir.return_value = os.path.abspath(global_dir) + ctx = PluggableRunContext() + mock = MockableRunContext.from_context(ctx.context) + mock._global_dir = mock._data_dir = os.path.join( + os.path.join(TEST_STORAGE_ROOT, "global_" + uniq_id()), DOT_DLT + ) + ctx.context = mock + + os.makedirs(mock.global_dir, exist_ok=True) + with Container().injectable_context(ctx): yield @@ -391,16 +444,16 @@ def assert_query_data( @contextlib.contextmanager -def reset_providers(project_dir: str) -> Iterator[ConfigProvidersContext]: - """Context manager injecting standard set of providers where toml providers are initialized from `project_dir`""" - return _reset_providers(project_dir) +def reset_providers(settings_dir: str) -> Iterator[ConfigProvidersContext]: + """Context manager injecting standard set of providers where toml providers are initialized from `settings_dir`""" + return _reset_providers(settings_dir) -def _reset_providers(project_dir: str) -> Iterator[ConfigProvidersContext]: +def _reset_providers(settings_dir: str) -> Iterator[ConfigProvidersContext]: ctx = ConfigProvidersContext() ctx.providers.clear() ctx.add_provider(EnvironProvider()) - ctx.add_provider(SecretsTomlProvider(project_dir=project_dir)) - ctx.add_provider(ConfigTomlProvider(project_dir=project_dir)) + ctx.add_provider(SecretsTomlProvider(settings_dir=settings_dir)) + ctx.add_provider(ConfigTomlProvider(settings_dir=settings_dir)) with Container().injectable_context(ctx): yield ctx
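
For orientation, the following is a minimal usage sketch of the dataset read interface that the new tests/load/test_read_interfaces.py exercises. It only strings together calls that appear in those tests (`_dataset()`, the `df`/`iter_df` and `arrow`/`iter_arrow` accessors, the DB-API style `fetchone`/`fetchall`/`fetchmany`/`iter_fetch`, and raw SQL via the callable dataset); the accessor is still underscore-prefixed, so treat the surface as provisional. The pipeline name, duckdb destination, and sample data below are illustrative assumptions, not part of the patch.

```python
import dlt

# Illustrative only: a local duckdb pipeline with a small "items" table
# (assumes the duckdb extra is installed; names here are not from the patch).
pipeline = dlt.pipeline("read_demo", destination="duckdb", dataset_name="read_test")
pipeline.run([{"id": i, "value": i * 2} for i in range(100)], table_name="items")

dataset = pipeline._dataset()

# table access by key or attribute returns a relationship object
items = dataset["items"]  # equivalent to dataset.items

# dataframe accessors: full frame or an iterator of chunks
df = items.df()
assert len(df.index) == 100
assert sum(len(frame.index) for frame in items.iter_df(chunk_size=40)) == 100

# arrow accessors mirror the dataframe ones
assert items.arrow().num_rows == 100
assert sum(t.num_rows for t in items.iter_arrow(chunk_size=40)) == 100

# DB-API style fetching
first_row = items.fetchone()
assert first_row is not None
all_rows = items.fetchall()
some_rows = items.fetchmany(25)
assert len(all_rows) == 100 and len(some_rows) == 25
assert sum(len(chunk) for chunk in items.iter_fetch(chunk_size=40)) == 100

# arbitrary SQL goes through the same relationship interface
qualified = pipeline.sql_client().make_qualified_table_name("items")
small = dataset(f"SELECT * FROM {qualified} WHERE id < 20").arrow()
assert small.num_rows == 20
```

As the tests note for filesystem destinations, chunk boundaries there follow the underlying files rather than the requested chunk_size, so exact per-chunk sizes should not be relied on outside SQL destinations.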