diff --git a/src/cdf/core/component/base.py b/src/cdf/core/component/base.py index d7b9a47..fb9a40a 100644 --- a/src/cdf/core/component/base.py +++ b/src/cdf/core/component/base.py @@ -195,14 +195,18 @@ def __call__(self) -> T: @classmethod def _parse_func(cls, data: t.Any) -> t.Any: """Parse node metadata.""" - if inspect.isfunction(data): + if inspect.isfunction(data) or isinstance(data, injector.Dependency): data = {"main": data} if isinstance(data, dict): dep = data["main"] if isinstance(dep, dict): func = dep["factory"] + if dep.get("alias"): + data.setdefault("name", dep["alias"]) elif isinstance(dep, injector.Dependency): func = dep.factory + if dep.alias: + data.setdefault("name", dep.alias) else: func = dep return {**_parse_metadata_from_callable(func), **data} diff --git a/src/cdf/core/injector/registry.py b/src/cdf/core/injector/registry.py index 8801e61..429c8b2 100644 --- a/src/cdf/core/injector/registry.py +++ b/src/cdf/core/injector/registry.py @@ -241,6 +241,8 @@ class Dependency(pydantic.BaseModel, t.Generic[T]): conf_spec: t.Optional[t.Union[t.Tuple[str, ...], t.Dict[str, str]]] = None """A hint for configuration values.""" + alias: t.Optional[str] = None + """Used as an alternative to inferring the name from the factory.""" _instance: t.Optional[T] = None """The instance of the dependency once resolved.""" @@ -450,7 +452,17 @@ def try_infer_type(self) -> t.Optional[t.Type[T]]: if self._is_resolved: return _unwrap_type(type(self._instance)) - def generate_key(self, name: DependencyKey) -> t.Union[str, TypedKey]: + def try_infer_name(self) -> t.Optional[str]: + """Infer the name of the dependency from the factory.""" + if self.alias: + return self.alias + if inspect.isfunction(self.factory): + return self.factory.__name__ + if inspect.isclass(self.factory): + return self.factory.__name__ + return getattr(self.factory, "name", None) + + def generate_key(self, name: t.Optional[DependencyKey]) -> t.Union[str, TypedKey]: """Generate a typed key for the dependency. Args: @@ -459,6 +471,12 @@ def generate_key(self, name: DependencyKey) -> t.Union[str, TypedKey]: Returns: A typed key if the type can be inferred, else the name. """ + if not name: + name = self.try_infer_name() + if not name: + raise ValueError( + "Cannot infer name for dependency and no name or alias provided" + ) if isinstance(name, TypedKey): return name elif isinstance(name, tuple): @@ -548,7 +566,10 @@ def add( add_instance = partialmethod(add, lifecycle=Lifecycle.INSTANCE) def add_from_dependency( - self, key: DependencyKey, dependency: Dependency, override: bool = False + self, + dependency: Dependency, + key: t.Optional[DependencyKey] = None, + override: bool = False, ) -> None: """Add a Dependency object to the container. diff --git a/src/cdf/core/workspace.py b/src/cdf/core/workspace.py index 60425ea..ea83968 100644 --- a/src/cdf/core/workspace.py +++ b/src/cdf/core/workspace.py @@ -82,27 +82,27 @@ def _setup(self) -> Self: self.conf_resolver.import_source(source) self.conf_resolver.set_environment(self.environment) self.container.add_from_dependency( - "cdf_workspace", injector.Dependency.instance(self), + key="cdf_workspace", override=True, ) self.container.add_from_dependency( - "cdf_environment", injector.Dependency.instance(self.environment), + key="cdf_environment", override=True, ) self.container.add_from_dependency( - "cdf_config", injector.Dependency.instance(self.conf_resolver), + key="cdf_config", override=True, ) self.container.add_from_dependency( - "cdf_transform", injector.Dependency.singleton(self.get_sqlmesh_context_or_raise), + key="cdf_transform", override=True, ) for service in self.services.values(): - self.container.add_from_dependency(service.name, service.main) + self.container.add_from_dependency(service.main, key=service.name) self.activate() return self @@ -396,9 +396,8 @@ def run(): return pipeline, run, [] - # Switch statement on environment - # to scaffold a FF provider, which is hereforward dictated by the user - # instead of implicit? + def ff_provider(): + return 1 # Define a workspace datateam = Workspace( @@ -441,6 +440,7 @@ def run(): ), owner="RevOps", ), + injector.Dependency[int](factory=ff_provider, alias="ff_main"), ], pipeline_definitions=[ cmp.DataPipeline( diff --git a/tests/core/test_workspace.py b/tests/core/test_workspace.py index 24fbe75..949f704 100644 --- a/tests/core/test_workspace.py +++ b/tests/core/test_workspace.py @@ -61,7 +61,7 @@ def c(secret_number: int, sfdc: str) -> int: return secret_number * 10 # Imperatively add dependencies or config if needed - datateam.container.add_from_dependency("c", injector.Dependency.prototype(c)) + datateam.container.add_from_dependency(injector.Dependency.prototype(c)) datateam.conf_resolver.import_source({"a.b.c": 10}) def source_a(a: int, prod_bigquery: str):