Skip to content

Commit

Permalink
Merge pull request #141 from dlt-hub/rfix/source-configs-improvements
Browse files Browse the repository at this point in the history
source decorator and config injection improvements
  • Loading branch information
rudolfix authored Feb 20, 2023
2 parents 24dfbd7 + 15c358e commit 6336096
Show file tree
Hide file tree
Showing 57 changed files with 1,079 additions and 312 deletions.
3 changes: 2 additions & 1 deletion dlt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
from dlt.common.configuration.specs import CredentialsConfiguration as _CredentialsConfiguration
from dlt.extract.decorators import source, resource, transformer, defer
from dlt.extract.source import with_table_name
from dlt.pipeline import pipeline as _pipeline, run, attach, Pipeline, dbt
from dlt.pipeline import pipeline as _pipeline, run, attach, Pipeline, dbt, current as _current
from dlt.pipeline.state import state

pipeline = _pipeline
current = _current

TSecretValue = _TSecretValue
"When typing source/resource function arguments it indicates that a given argument is a secret and should be taken from dlt.secrets."
Expand Down
1 change: 0 additions & 1 deletion dlt/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def _detect_required_configs(visitor: PipelineScriptVisitor) -> Tuple[Dict[str,

if val_store is not None:
# we are sure that all resources come from single file so we can put them in single section
# sections = () if len(known_imported_sources) == 1 else ("sources", source_name)
val_store[source_name + ":" + field_name] = WritableConfigValue(field_name, field_type, ())

return required_secrets, required_config
Expand Down
1 change: 1 addition & 0 deletions dlt/common/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
DOT_DLT = ".dlt"

from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint # noqa: F401
from .specs import known_sections # noqa: F401
from .resolve import resolve_configuration, inject_section # noqa: F401
from .inject import with_config, last_config, get_fun_spec # noqa: F401
from .utils import make_dot_dlt_path # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion dlt/common/configuration/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, field_name: str, field_value: Any, hint: type) -> None:
self.field_name = field_name
self.field_value = field_value
self.hint = hint
super().__init__('configured value for field %s cannot be coerced into type %s' % (str(hint), field_name))
super().__init__('Configured value for field %s cannot be coerced into type %s' % (field_name, str(hint)))


# class ConfigIntegrityException(ConfigurationException):
Expand Down
60 changes: 49 additions & 11 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,70 @@
from dlt.common.typing import DictStrAny, StrAny, TFun, AnyFun
from dlt.common.configuration.resolve import resolve_configuration, inject_section
from dlt.common.configuration.specs.base_configuration import BaseConfiguration
from dlt.common.configuration.specs.config_namespace_context import ConfigSectionContext
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.reflection.spec import spec_from_signature


_LAST_DLT_CONFIG = "_dlt_config"
_ORIGINAL_ARGS = "_dlt_orig_args"
TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration)
# keep a registry of all the decorated functions
_FUNC_SPECS: Dict[int, Type[BaseConfiguration]] = {}

TConfiguration = TypeVar("TConfiguration", bound=BaseConfiguration)


def get_fun_spec(f: AnyFun) -> Type[BaseConfiguration]:
return _FUNC_SPECS.get(id(f))


@overload
def with_config(func: TFun, /, spec: Type[BaseConfiguration] = None, auto_section: bool = False, only_kw: bool = False, sections: Tuple[str, ...] = ()) -> TFun:
def with_config(
func: TFun,
/,
spec: Type[BaseConfiguration] = None,
sections: Tuple[str, ...] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
only_kw: bool = False
) -> TFun:
...


@overload
def with_config(func: None = ..., /, spec: Type[BaseConfiguration] = None, auto_section: bool = False, only_kw: bool = False, sections: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]:
def with_config(
func: None = ...,
/,
spec: Type[BaseConfiguration] = None,
sections: Tuple[str, ...] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
only_kw: bool = False
) -> Callable[[TFun], TFun]:
...


def with_config(func: Optional[AnyFun] = None, /, spec: Type[BaseConfiguration] = None, auto_section: bool = False, only_kw: bool = False, sections: Tuple[str, ...] = ()) -> Callable[[TFun], TFun]:

def with_config(
func: Optional[AnyFun] = None,
/,
spec: Type[BaseConfiguration] = None,
sections: Tuple[str, ...] = (),
sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
auto_pipeline_section: bool = False,
only_kw: bool = False
) -> Callable[[TFun], TFun]:
"""Injects values into decorated function arguments following the specification in `spec` or by deriving one from function's signature.
Args:
func (Optional[AnyFun], optional): A function with arguments to be injected. Defaults to None.
spec (Type[BaseConfiguration], optional): A specification of injectable arguments. Defaults to None.
sections (Tuple[str, ...], optional): A set of config sections in which to look for arguments values. Defaults to ().
prefer_existing_sections: (bool, optional): When joining existing section context, the existing context will be preferred to the one in `sections`. Default: False
auto_pipeline_section (bool, optional): If True, a top level pipeline section will be added if `pipeline_name` argument is present . Defaults to False.
only_kw (bool, optional): If True and `spec` is not provided, one is synthesized from keyword only arguments ignoring any others. Defaults to False.
Returns:
Callable[[TFun], TFun]: A decorated function
"""
section_f: Callable[[StrAny], str] = None
# section may be a function from function arguments to section
if callable(sections):
Expand All @@ -45,7 +81,7 @@ def decorator(f: TFun) -> TFun:
kwargs_arg = next((p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None)
spec_arg: Parameter = None
pipeline_name_arg: Parameter = None
section_context = ConfigSectionContext()
section_context = ConfigSectionContext(sections=sections, merge_style=sections_merge_style)

if spec is None:
SPEC = spec_from_signature(f, sig, only_kw)
Expand All @@ -59,7 +95,7 @@ def decorator(f: TFun) -> TFun:
if p.annotation is SPEC:
# if any argument has type SPEC then us it to take initial value
spec_arg = p
if p.name == "pipeline_name" and auto_section:
if p.name == "pipeline_name" and auto_pipeline_section:
# if argument has name pipeline_name and auto_section is used, use it to generate section context
pipeline_name_arg = p
pipeline_name_arg_default = None if p.default == Parameter.empty else p.default
Expand All @@ -77,18 +113,20 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
# if section derivation function was provided then call it
nonlocal sections
if section_f:
sections = (section_f(bound_args.arguments), )
section_context.sections = (section_f(bound_args.arguments), )
# sections may be a string
if isinstance(sections, str):
sections = (sections,)
section_context.sections = (sections,)

# if one of arguments is spec the use it as initial value
if spec_arg:
config = bound_args.arguments.get(spec_arg.name, None)
# resolve SPEC, also provide section_context with pipeline_name
if pipeline_name_arg:
section_context.pipeline_name = bound_args.arguments.get(pipeline_name_arg.name, pipeline_name_arg_default)
with inject_section(section_context):
config = resolve_configuration(config or SPEC(), sections=sections, explicit_value=bound_args.arguments)
# print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}")
config = resolve_configuration(config or SPEC(), explicit_value=bound_args.arguments)
resolved_params = dict(config)
bound_args.apply_defaults()
# overwrite or add resolved params
Expand Down
39 changes: 33 additions & 6 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import itertools
from collections.abc import Mapping as C_Mapping
from typing import Any, Dict, ContextManager, List, Optional, Sequence, Tuple, Type, TypeVar

from dlt.common.configuration.providers.provider import ConfigProvider
from dlt.common.typing import AnyType, StrAny, TSecretValue, is_final_type, is_optional_type
from dlt.common.typing import AnyType, StrAny, TSecretValue, get_all_types_of_class_in_union, is_final_type, is_optional_type, is_union

from dlt.common.configuration.specs.base_configuration import BaseConfiguration, CredentialsConfiguration, is_secret_hint, extract_inner_hint, is_context_inner_hint, is_base_configuration_inner_hint
from dlt.common.configuration.specs.config_namespace_context import ConfigSectionContext
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.configuration.container import Container
from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContext
from dlt.common.configuration.utils import log_traces, deserialize_value
Expand Down Expand Up @@ -37,7 +38,10 @@ def inject_section(section_context: ConfigSectionContext, merge_existing: bool =
Args:
section_context (ConfigSectionContext): Instance providing a pipeline name and section context
merge_existing (bool, optional): Gets `pipeline_name` and `sections` from existing context if they are not provided in `section` argument. Defaults to True.
merge_existing (bool, optional): Merges existing section context with `section_context` in the arguments by executing `merge_style` function on `section_context`. Defaults to True.
Default Merge Style:
Gets `pipeline_name` and `sections` from existing context if they are not provided in `section_context` argument.
Yields:
Iterator[ConfigSectionContext]: Context manager with current section context
Expand All @@ -46,8 +50,7 @@ def inject_section(section_context: ConfigSectionContext, merge_existing: bool =
existing_context = container[ConfigSectionContext]

if merge_existing:
section_context.pipeline_name = section_context.pipeline_name or existing_context.pipeline_name
section_context.sections = section_context.sections or existing_context.sections
section_context.merge(existing_context)

return container.injectable_context(section_context)

Expand Down Expand Up @@ -118,6 +121,7 @@ def _resolve_config_fields(
for key, hint in fields.items():
# get default and explicit values
default_value = getattr(config, key, None)
traces: List[LookupTrace] = []

if explicit_values:
explicit_value = explicit_values.get(key)
Expand All @@ -128,7 +132,30 @@ def _resolve_config_fields(
else:
explicit_value = None

current_value, traces = _resolve_config_field(key, hint, default_value, explicit_value, config, config.__section__, explicit_sections, embedded_sections, accept_partial)
# if hint is union of configurations, any of them must be resolved
specs_in_union: List[Type[BaseConfiguration]] = []
current_value = None
if is_union(hint):
# print(f"HINT UNION?: {key}:{hint}")
specs_in_union = get_all_types_of_class_in_union(hint, BaseConfiguration)
if len(specs_in_union) > 1:
for idx, alt_spec in enumerate(specs_in_union):
# return first resolved config from an union
try:
current_value, traces = _resolve_config_field(key, alt_spec, default_value, explicit_value, config, config.__section__, explicit_sections, embedded_sections, accept_partial)
print(current_value)
break
except ConfigFieldMissingException as cfm_ex:
# add traces from unresolved union spec
# TODO: we should group traces per hint - currently user will see all options tried without the key info
traces.extend(list(itertools.chain(*cfm_ex.traces.values())))
except InvalidNativeValue:
# if none of specs in union parsed
if idx == len(specs_in_union) - 1:
raise
else:
current_value, traces = _resolve_config_field(key, hint, default_value, explicit_value, config, config.__section__, explicit_sections, embedded_sections, accept_partial)

# check if hint optional
is_optional = is_optional_type(hint)
# collect unresolved fields
Expand Down
2 changes: 1 addition & 1 deletion dlt/common/configuration/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
from .pool_runner_configuration import PoolRunnerConfiguration, TPoolType # noqa: F401
from .gcp_client_credentials import GcpClientCredentials, GcpClientCredentialsWithDefault # noqa: F401
from .postgres_credentials import PostgresCredentials, RedshiftCredentials, ConnectionStringCredentials # noqa: F401
from .config_namespace_context import ConfigSectionContext # noqa: F401
from .config_section_context import ConfigSectionContext # noqa: F401
15 changes: 11 additions & 4 deletions dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import inspect
import contextlib
import dataclasses
from types import FunctionType
from typing import Callable, List, Optional, Union, Any, Dict, Iterator, MutableMapping, Type, TYPE_CHECKING, get_args, get_origin, overload, ClassVar

if TYPE_CHECKING:
TDtcField = dataclasses.Field[Any]
else:
TDtcField = dataclasses.Field

from dlt.common.typing import TAnyClass, TSecretValue, extract_inner_type, is_optional_type
from dlt.common.typing import TAnyClass, TSecretValue, extract_inner_type, is_optional_type, is_union
from dlt.common.schema.utils import py_type_to_sc_type
from dlt.common.configuration.exceptions import ConfigFieldMissingTypeHintException, ConfigFieldTypeHintNotSupported

Expand All @@ -31,7 +32,7 @@ def is_credentials_inner_hint(inner_hint: Type[Any]) -> bool:


def get_config_if_union_hint(hint: Type[Any]) -> Type[Any]:
if get_origin(hint) is Union:
if is_union(hint):
return next((t for t in get_args(hint) if is_base_configuration_inner_hint(t)), None)
return None

Expand Down Expand Up @@ -82,7 +83,13 @@ def configspec(cls: None = ..., /, *, init: bool = False) -> Callable[[Type[TAny


def configspec(cls: Optional[Type[Any]] = None, /, *, init: bool = False) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]:
"""Converts (via derivation) any decorated class to a Python dataclass that may be used as a spec to resolve configurations
In comparison the Python dataclass, a spec implements full dictionary interface for its attributes, allows instance creation from ie. strings
or other types (parsing, deserialization) and control over configuration resolution process. See `BaseConfiguration` and CredentialsConfiguration` for
more information.
"""
def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]:
is_context = issubclass(cls, _F_ContainerInjectableContext)
# if type does not derive from BaseConfiguration then derive it
Expand All @@ -98,7 +105,7 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]:
# get all attributes without corresponding annotations
for att_name, att_value in cls.__dict__.items():
# skip callables, dunder names, class variables and some special names
if not callable(att_value) and not att_name.startswith(("__", "_abc_impl")):
if not callable(att_value) and not att_name.startswith(("__", "_abc_impl")) and not isinstance(att_value, (staticmethod, classmethod)):
if att_name not in cls.__annotations__:
raise ConfigFieldMissingTypeHintException(att_name, cls)
hint = cls.__annotations__[att_name]
Expand All @@ -121,7 +128,7 @@ class BaseConfiguration(MutableMapping[str, Any]):
__is_resolved__: bool = dataclasses.field(default = False, init=False, repr=False)
"""True when all config fields were resolved and have a specified value type"""
__section__: str = dataclasses.field(default = None, init=False, repr=False)
"""Section used by config providers when searching for keys"""
"""Obligatory section used by config providers when searching for keys, always present in the search path"""
__exception__: Exception = dataclasses.field(default = None, init=False, repr=False)
"""Holds the exception that prevented the full resolution"""
__config_gen_annotations__: ClassVar[List[str]] = None
Expand Down
14 changes: 0 additions & 14 deletions dlt/common/configuration/specs/config_namespace_context.py

This file was deleted.

57 changes: 57 additions & 0 deletions dlt/common/configuration/specs/config_section_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Callable, List, Optional, Tuple, TYPE_CHECKING
from dlt.common.configuration.specs import known_sections

from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext, configspec

@configspec(init=True)
class ConfigSectionContext(ContainerInjectableContext):

TMergeFunc = Callable[["ConfigSectionContext", "ConfigSectionContext"], None]

pipeline_name: Optional[str]
sections: Tuple[str, ...] = ()
merge_style: TMergeFunc = None


def merge(self, existing: "ConfigSectionContext") -> None:
"""Merges existing context into incoming using a merge style function"""
merge_style_f = self.merge_style or self.prefer_incoming
merge_style_f(self, existing)

def source_name(self) -> str:
"""Gets name of a source from `sections`"""
if self.sections and len(self.sections) == 3 and self.sections[0] == known_sections.SOURCES:
return self.sections[-1]
raise ValueError(self.sections)

def source_section(self) -> str:
"""Gets section of a source from `sections`"""
if self.sections and len(self.sections) == 3 and self.sections[0] == known_sections.SOURCES:
return self.sections[1]
raise ValueError(self.sections)

@staticmethod
def prefer_incoming(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None:
incoming.pipeline_name = incoming.pipeline_name or existing.pipeline_name
incoming.sections = incoming.sections or existing.sections

@staticmethod
def prefer_existing(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None:
"""Prefer existing section context when merging this context before injecting"""
incoming.pipeline_name = existing.pipeline_name or incoming.pipeline_name
incoming.sections = existing.sections or incoming.sections

@staticmethod
def resource_merge_style(incoming: "ConfigSectionContext", existing: "ConfigSectionContext") -> None:
"""If top level section is same and there are 3 sections it replaces second element (source module) from existing and keeps the 3rd element (name)"""
incoming.pipeline_name = incoming.pipeline_name or existing.pipeline_name
if len(incoming.sections) == 3 == len(existing.sections) and incoming.sections[0] == existing.sections[0]:
incoming.sections = (incoming.sections[0], existing.sections[1], incoming.sections[2])
else:
incoming.sections = incoming.sections or existing.sections


if TYPE_CHECKING:
# provide __init__ signature when type checking
def __init__(self, pipeline_name:str = None, sections: Tuple[str, ...] = (), merge_style: TMergeFunc = None) -> None:
...
Loading

0 comments on commit 6336096

Please sign in to comment.