Skip to content

Commit

Permalink
chore: attempt to resolve conf or deps once and make converter applic…
Browse files Browse the repository at this point in the history
…ation safe
  • Loading branch information
z3z1ma committed Aug 19, 2024
1 parent 94171d7 commit cb259a6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/cdf/core/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ def _parse_metadata(cls, data: t.Any) -> t.Any:

@pydantic.field_validator("main", mode="before")
@classmethod
def _ensure_dependency(cls, value: t.Any) -> t.Any:
def _ensure_dependency(cls, value: t.Any, info: pydantic.ValidationInfo) -> t.Any:
"""Ensure the main function is a dependency."""
value = _unwrap_entrypoint(value)
if isinstance(value, (dict, injector.Dependency)):
parsed_dep = injector.Dependency.model_validate(value)
parsed_dep = injector.Dependency.model_validate(value, context=info.context)
else:
parsed_dep = injector.Dependency.wrap(value)
# NOTE: We do this extra round-trip to bypass the unnecessary Generic type check in pydantic
Expand Down
26 changes: 22 additions & 4 deletions src/cdf/core/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def bar(key: str, _cdf_resolve={"key": "api.key"}) -> None:
import string
import typing as t
from collections import ChainMap
from contextlib import suppress
from pathlib import Path

import pydantic
Expand Down Expand Up @@ -480,23 +481,32 @@ def resolve_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
return func_or_cls

sig = inspect.signature(func_or_cls)
is_resolved_sentinel = "__config_resolved__"

resolver_hint = getattr(
inspect.unwrap(func_or_cls),
RESOLVER_HINT,
self._parse_hint_from_params(func_or_cls, sig),
)

if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)):
return func_or_cls

@functools.wraps(func_or_cls)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()

# Apply converters to string literal arguments
for arg_name, arg_value in bound_args.arguments.items():
# The simplest case: a string argument
if isinstance(arg_value, str):
bound_args.arguments[arg_name] = self.apply_converters(
arg_value, self
)
with suppress(Exception):
bound_args.arguments[arg_name] = self.apply_converters(
arg_value,
self,
)

# Resolve configuration values
for name, param in sig.parameters.items():
value = _MISSING
if not self.is_resolvable(param):
Expand All @@ -522,6 +532,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:

return func_or_cls(*bound_args.args, **bound_args.kwargs)

setattr(wrapper, is_resolved_sentinel, True)
return wrapper

def is_resolvable(self, param: inspect.Parameter) -> bool:
Expand Down Expand Up @@ -552,3 +563,10 @@ def __get_pydantic_core_schema__(
keys_schema=pydantic_core.core_schema.str_schema(),
values_schema=pydantic_core.core_schema.any_schema(),
)


def _iter_wrapped(f: t.Callable):
yield f
f_w = inspect.unwrap(f)
if f_w is not f:
yield from _iter_wrapped(f_w)
12 changes: 12 additions & 0 deletions src/cdf/core/injector/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,10 @@ def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
return func_or_cls

sig = inspect.signature(func_or_cls)
is_resolved_sentinel = "__deps_resolved__"

if any(hasattr(f, is_resolved_sentinel) for f in _iter_wrapped(func_or_cls)):
return func_or_cls

@wraps(func_or_cls)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
Expand All @@ -701,6 +705,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound_args.arguments[name] = dep
return func_or_cls(*bound_args.args, **bound_args.kwargs)

setattr(wrapper, is_resolved_sentinel, True)
return wrapper

def __call__(
Expand Down Expand Up @@ -786,5 +791,12 @@ def __get_pydantic_core_schema__(
)


def _iter_wrapped(f: t.Callable):
yield f
f_w = inspect.unwrap(f)
if f_w is not f:
yield from _iter_wrapped(f_w)


GLOBAL_REGISTRY = DependencyRegistry()
"""A global dependency registry."""
12 changes: 7 additions & 5 deletions src/cdf/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,17 @@ def run_pipeline(

# Run the pipeline
start = time.time()
click.echo((info := pipeline_definition()) or "No load info returned.")
jobs = pipeline_definition()
click.echo(
f"Pipeline process finished in {time.time() - start:.2f} seconds.",
err=True,
)

# Check for failed jobs
if info and info.has_failed_jobs:
ctx.fail("Pipeline failed.")
for job in jobs:
if job.has_failed_jobs:
ctx.fail("Pipeline failed.")

ctx.exit(0)

@cli.command("run-publisher")
Expand Down Expand Up @@ -389,7 +391,7 @@ def test_pipeline(

def run():
print("Running pipeline")
load_info = pipeline.run(source_a())
load = pipeline.run(source_a())
print("Pipeline finished")
with pipeline.sql_client() as client:
print("Querying DuckDB in " + cdf_environment)
Expand All @@ -398,7 +400,7 @@ def run():
"SELECT * FROM some_pipeline_dataset.test_resource"
)
)
return load_info
return load

return pipeline, run

Expand Down

0 comments on commit cb259a6

Please sign in to comment.