Skip to content

Commit

Permalink
chore: small touch up to default lifecycle derivation
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Sep 1, 2024
1 parent 3d0c8bd commit 5d77103
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 27 deletions.
6 changes: 4 additions & 2 deletions src/cdf/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ def invoke(func_or_cls: t.Callable, *args: t.Any, **kwargs: t.Any) -> t.Any:
return workspace.invoke(func_or_cls, *args, **kwargs)


def get_default_callable_lifecycle() -> t.Optional["Lifecycle"]:
def get_default_callable_lifecycle() -> "Lifecycle":
"""Get the default lifecycle for callables when otherwise unspecified."""
return _DEFAULT_CALLABLE_LIFECYCLE.get()
from cdf.core.injector import Lifecycle

return _DEFAULT_CALLABLE_LIFECYCLE.get() or Lifecycle.SINGLETON


def set_default_callable_lifecycle(lifecycle: t.Optional["Lifecycle"]) -> Token:
Expand Down
33 changes: 12 additions & 21 deletions src/cdf/core/injector/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import ParamSpec, Self

import cdf.core.configuration as conf
from cdf.core.context import get_default_callable_lifecycle
from cdf.core.injector.errors import DependencyCycleError, DependencyMutationError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,6 +67,13 @@ def is_deferred(self) -> bool:
def __str__(self) -> str:
return self.name.lower()

@classmethod
def default_for(cls, obj: t.Any) -> "Lifecycle":
"""Get the default lifecycle."""
if callable(obj):
return get_default_callable_lifecycle()
return cls.INSTANCE


class TypedKey(t.NamedTuple):
"""A key which is a tuple of a name and a type."""
Expand Down Expand Up @@ -263,17 +271,10 @@ def _apply_spec(self) -> Self:
@classmethod
def _ensure_lifecycle(cls, data: t.Any) -> t.Any:
"""Ensure a valid lifecycle is set for the dependency."""
from cdf.core.context import get_default_callable_lifecycle

if isinstance(data, dict):
factory = data["factory"]
default_callable_lc = (
get_default_callable_lifecycle() or Lifecycle.SINGLETON
)
lc = data.get(
"lifecycle",
default_callable_lc if callable(factory) else Lifecycle.INSTANCE,
)
lc = data.get("lifecycle", Lifecycle.default_for(factory))
if isinstance(lc, str):
lc = Lifecycle[lc.upper()]
if not isinstance(lc, Lifecycle):
Expand Down Expand Up @@ -360,14 +361,9 @@ def wrap(cls, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> Self:
A new Dependency object with the object as the factory.
"""
if callable(obj):
from cdf.core.context import get_default_callable_lifecycle

if args or kwargs:
obj = partial(obj, *args, **kwargs)
default_callable_lc = (
get_default_callable_lifecycle() or Lifecycle.SINGLETON
)
return cls(factory=obj, lifecycle=default_callable_lc)
return cls(factory=obj, lifecycle=get_default_callable_lifecycle())
return cls(factory=obj, lifecycle=Lifecycle.INSTANCE)

def map_value(self, func: t.Callable[[T], T]) -> Self:
Expand Down Expand Up @@ -541,12 +537,7 @@ def add(

# Assume singleton lifecycle if the value is callable unless set in context
if lifecycle is None:
from cdf.core.context import get_default_callable_lifecycle

default_callable_lc = (
get_default_callable_lifecycle() or Lifecycle.SINGLETON
)
lifecycle = default_callable_lc if callable(value) else Lifecycle.INSTANCE
lifecycle = Lifecycle.default_for(value)

# If the value is callable and has initialization args, bind them early so
# we don't need to schlepp them around
Expand Down Expand Up @@ -762,7 +753,7 @@ def __len__(self) -> int:
return len(self.dependencies)

def __repr__(self) -> str:
return f"<DependencyRegistry {self.dependencies.keys()}>"
return f"DependencyRegistry(<{list(self.dependencies.keys())}>)"

def __str__(self) -> str:
return repr(self)
Expand Down
8 changes: 4 additions & 4 deletions src/cdf/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,16 @@ def operations(self) -> t.Dict[str, cmp.Operation]:
@t.overload
def get_sqlmesh_context(
self,
gateway: t.Optional[str] = ...,
must_exist: t.Literal[False] = False,
gateway: t.Optional[str],
must_exist: t.Literal[False],
**kwargs: t.Any,
) -> t.Optional["sqlmesh.Context"]: ...

@t.overload
def get_sqlmesh_context(
self,
gateway: t.Optional[str] = ...,
must_exist: t.Literal[True] = True,
gateway: t.Optional[str],
must_exist: t.Literal[True],
**kwargs: t.Any,
) -> "sqlmesh.Context": ...

Expand Down

0 comments on commit 5d77103

Please sign in to comment.