diff --git a/src/cdf/core/component/base.py b/src/cdf/core/component/base.py index 3412ecc..1f8ffe6 100644 --- a/src/cdf/core/component/base.py +++ b/src/cdf/core/component/base.py @@ -208,11 +208,11 @@ def _parse_metadata(cls, data: t.Any) -> t.Any: @pydantic.field_validator("main", mode="before") @classmethod - def _ensure_dependency(cls, value: t.Any) -> t.Any: + def _ensure_dependency(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any: """Ensure the main function is a dependency.""" value = _unwrap_entrypoint(value) if isinstance(value, (dict, injector.Dependency)): - parsed_dep = injector.Dependency.model_validate(value) + parsed_dep = injector.Dependency.model_validate(value, context=info.context) else: parsed_dep = injector.Dependency.wrap(value) # NOTE: We do this extra round-trip to bypass the unecessary Generic type check in pydantic diff --git a/src/cdf/core/configuration.py b/src/cdf/core/configuration.py index 8896fc9..386a7f0 100644 --- a/src/cdf/core/configuration.py +++ b/src/cdf/core/configuration.py @@ -59,6 +59,7 @@ def bar(key: str, _cdf_resolve={"key": "api.key"}) -> None: import string import typing as t from collections import ChainMap +from contextlib import suppress from pathlib import Path import pydantic @@ -480,6 +481,7 @@ def resolve_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: return func_or_cls sig = inspect.signature(func_or_cls) + is_resolved_sentinel = "__config_resolved__" resolver_hint = getattr( inspect.unwrap(func_or_cls), @@ -487,16 +489,24 @@ def resolve_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: self._parse_hint_from_params(func_or_cls, sig), ) + if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)): + return func_or_cls + @functools.wraps(func_or_cls) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: bound_args = sig.bind_partial(*args, **kwargs) bound_args.apply_defaults() + + # Apply converters to string literal arguments for arg_name, arg_value in bound_args.arguments.items(): - # The simplest case: a string argument if isinstance(arg_value, str): - bound_args.arguments[arg_name] = self.apply_converters( - arg_value, self - ) + with suppress(Exception): + bound_args.arguments[arg_name] = self.apply_converters( + arg_value, + self, + ) + + # Resolve configuration values for name, param in sig.parameters.items(): value = _MISSING if not self.is_resolvable(param): @@ -522,6 +532,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return func_or_cls(*bound_args.args, **bound_args.kwargs) + setattr(wrapper, is_resolved_sentinel, True) return wrapper def is_resolvable(self, param: inspect.Parameter) -> bool: @@ -552,3 +563,10 @@ def __get_pydantic_core_schema__( keys_schema=pydantic_core.core_schema.str_schema(), values_schema=pydantic_core.core_schema.any_schema(), ) + + +def _iter_wrapped(f: t.Callable): + yield f + f_w = inspect.unwrap(f) + if f_w is not f: + yield from _iter_wrapped(f_w) diff --git a/src/cdf/core/injector/registry.py b/src/cdf/core/injector/registry.py index 737c088..8801e61 100644 --- a/src/cdf/core/injector/registry.py +++ b/src/cdf/core/injector/registry.py @@ -679,6 +679,10 @@ def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]: return func_or_cls sig = inspect.signature(func_or_cls) + is_resolved_sentinel = "__deps_resolved__" + + if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)): + return func_or_cls @wraps(func_or_cls) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: @@ -701,6 +705,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: bound_args.arguments[name] = dep return func_or_cls(*bound_args.args, **bound_args.kwargs) + setattr(wrapper, is_resolved_sentinel, True) return wrapper def __call__( @@ -786,5 +791,12 @@ def __get_pydantic_core_schema__( ) +def _iter_wrapped(f: t.Callable): + yield f + f_w = inspect.unwrap(f) + if f_w is not f: + yield from _iter_wrapped(f_w) + + GLOBAL_REGISTRY = DependencyRegistry() """A global dependency registry.""" diff --git a/src/cdf/core/workspace.py b/src/cdf/core/workspace.py index 2cd4b32..2f65fef 100644 --- a/src/cdf/core/workspace.py +++ b/src/cdf/core/workspace.py @@ -251,15 +251,17 @@ def run_pipeline( # Run the pipeline start = time.time() - click.echo((info := pipeline_definition()) or "No load info returned.") + jobs = pipeline_definition() click.echo( f"Pipeline process finished in {time.time() - start:.2f} seconds.", err=True, ) # Check for failed jobs - if info and info.has_failed_jobs: - ctx.fail("Pipeline failed.") + for job in jobs: + if job.has_failed_jobs: + ctx.fail("Pipeline failed.") + ctx.exit(0) @cli.command("run-publisher") @@ -389,7 +391,7 @@ def test_pipeline( def run(): print("Running pipeline") - load_info = pipeline.run(source_a()) + load = pipeline.run(source_a()) print("Pipeline finished") with pipeline.sql_client() as client: print("Querying DuckDB in " + cdf_environment) @@ -398,7 +400,7 @@ def run(): "SELECT * FROM some_pipeline_dataset.test_resource" ) ) - return load_info + return load return pipeline, run