Skip to content

Commit

Permalink
feat: associate name with dependency and bubble up to component in te…
Browse files Browse the repository at this point in the history
…rse definitions
  • Loading branch information
z3z1ma committed Aug 19, 2024
1 parent e42086e commit c1df14b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
6 changes: 5 additions & 1 deletion src/cdf/core/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,18 @@ def __call__(self) -> T:
@classmethod
def _parse_func(cls, data: t.Any) -> t.Any:
"""Parse node metadata."""
if inspect.isfunction(data):
if inspect.isfunction(data) or isinstance(data, injector.Dependency):
data = {"main": data}
if isinstance(data, dict):
dep = data["main"]
if isinstance(dep, dict):
func = dep["factory"]
if dep.get("alias"):
data.setdefault("name", dep["alias"])
elif isinstance(dep, injector.Dependency):
func = dep.factory
if dep.alias:
data.setdefault("name", dep.alias)
else:
func = dep
return {**_parse_metadata_from_callable(func), **data}
Expand Down
25 changes: 23 additions & 2 deletions src/cdf/core/injector/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ class Dependency(pydantic.BaseModel, t.Generic[T]):

conf_spec: t.Optional[t.Union[t.Tuple[str, ...], t.Dict[str, str]]] = None
"""A hint for configuration values."""
alias: t.Optional[str] = None
"""Used as an alternative to inferring the name from the factory."""

_instance: t.Optional[T] = None
"""The instance of the dependency once resolved."""
Expand Down Expand Up @@ -450,7 +452,17 @@ def try_infer_type(self) -> t.Optional[t.Type[T]]:
if self._is_resolved:
return _unwrap_type(type(self._instance))

def generate_key(self, name: DependencyKey) -> t.Union[str, TypedKey]:
def try_infer_name(self) -> t.Optional[str]:
"""Infer the name of the dependency from the factory."""
if self.alias:
return self.alias
if inspect.isfunction(self.factory):
return self.factory.__name__
if inspect.isclass(self.factory):
return self.factory.__name__
return getattr(self.factory, "name", None)

def generate_key(self, name: t.Optional[DependencyKey]) -> t.Union[str, TypedKey]:
"""Generate a typed key for the dependency.
Args:
Expand All @@ -459,6 +471,12 @@ def generate_key(self, name: DependencyKey) -> t.Union[str, TypedKey]:
Returns:
A typed key if the type can be inferred, else the name.
"""
if not name:
name = self.try_infer_name()
if not name:
raise ValueError(
"Cannot infer name for dependency and no name or alias provided"
)
if isinstance(name, TypedKey):
return name
elif isinstance(name, tuple):
Expand Down Expand Up @@ -548,7 +566,10 @@ def add(
add_instance = partialmethod(add, lifecycle=Lifecycle.INSTANCE)

def add_from_dependency(
self, key: DependencyKey, dependency: Dependency, override: bool = False
self,
dependency: Dependency,
key: t.Optional[DependencyKey] = None,
override: bool = False,
) -> None:
"""Add a Dependency object to the container.
Expand Down
16 changes: 8 additions & 8 deletions src/cdf/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,27 @@ def _setup(self) -> Self:
self.conf_resolver.import_source(source)
self.conf_resolver.set_environment(self.environment)
self.container.add_from_dependency(
"cdf_workspace",
injector.Dependency.instance(self),
key="cdf_workspace",
override=True,
)
self.container.add_from_dependency(
"cdf_environment",
injector.Dependency.instance(self.environment),
key="cdf_environment",
override=True,
)
self.container.add_from_dependency(
"cdf_config",
injector.Dependency.instance(self.conf_resolver),
key="cdf_config",
override=True,
)
self.container.add_from_dependency(
"cdf_transform",
injector.Dependency.singleton(self.get_sqlmesh_context_or_raise),
key="cdf_transform",
override=True,
)
for service in self.services.values():
self.container.add_from_dependency(service.name, service.main)
self.container.add_from_dependency(service.main, key=service.name)
self.activate()
return self

Expand Down Expand Up @@ -396,9 +396,8 @@ def run():

return pipeline, run, []

# Switch statement on environment
# to scaffold a FF provider, which is hereforward dictated by the user
# instead of implicit?
def ff_provider():
return 1

# Define a workspace
datateam = Workspace(
Expand Down Expand Up @@ -441,6 +440,7 @@ def run():
),
owner="RevOps",
),
injector.Dependency[int](factory=ff_provider, alias="ff_main"),
],
pipeline_definitions=[
cmp.DataPipeline(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def c(secret_number: int, sfdc: str) -> int:
return secret_number * 10

# Imperatively add dependencies or config if needed
datateam.container.add_from_dependency("c", injector.Dependency.prototype(c))
datateam.container.add_from_dependency(injector.Dependency.prototype(c))
datateam.conf_resolver.import_source({"a.b.c": 10})

def source_a(a: int, prod_bigquery: str):
Expand Down

0 comments on commit c1df14b

Please sign in to comment.