Skip to content

Commit

Permalink
chore: more polish of injector interface
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Jul 17, 2024
1 parent fd3383b commit 628c426
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/cdf/injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from cdf.injector.errors import SetChildConfigError as SetChildConfigError
from cdf.injector.specs import Forward as Forward
from cdf.injector.specs import GlobalInput as GlobalInput
from cdf.injector.specs import Instance as Instance
from cdf.injector.specs import LocalInput as LocalInput
from cdf.injector.specs import Object as Object
from cdf.injector.specs import Prototype as Prototype
from cdf.injector.specs import PrototypeMixin as PrototypeMixin
from cdf.injector.specs import Singleton as Singleton
Expand All @@ -38,7 +38,7 @@
"Forward",
"GlobalInput",
"LocalInput",
"Object",
"Instance",
"Prototype",
"PrototypeMixin",
"Singleton",
Expand Down
10 changes: 5 additions & 5 deletions src/cdf/injector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ConfigSpec(injector_specs.Spec[TC]):
"local_inputs",
]

def __init__(self, cls: type[TC], **local_inputs: t.Any) -> None:
def __init__(self, cls: t.Type[TC], **local_inputs: t.Any) -> None:
super().__init__()
self.cls = cls
self.local_inputs = local_inputs
Expand Down Expand Up @@ -144,7 +144,7 @@ def _process_input(
spec: injector_specs._Input,
inputs: dict[str, t.Any],
desc: str,
) -> injector_specs._Object:
) -> injector_specs._Instance:
"""Convert Input spec to Object spec."""
try:
value = inputs[key]
Expand All @@ -159,7 +159,7 @@ def _process_input(
injector_utils.check_type(value, spec.type_, desc=desc)

# Preserve old spec id
return injector_specs._Object(value, spec_id=spec.spec_id)
return injector_specs._Instance(value, spec_id=spec.spec_id)

def _load(self, **local_inputs: t.Any) -> None:
"""Transfer class variables to instance."""
Expand Down Expand Up @@ -276,7 +276,7 @@ def __setattr__(self, key: str, value: t.Any) -> None:

# Automatically wrap input if user hasn't done so
if not isinstance(value, injector_specs.Spec):
value = injector_specs.Object(value)
value = injector_specs.Instance(value)

self._specs[key] = value

Expand Down Expand Up @@ -316,6 +316,6 @@ def get(self, config_spec: ConfigSpec) -> Config:
return t.cast(Config, config)


def get_config(config_cls: type[TC], **global_inputs: t.Any) -> TC:
def get_config(config_cls: t.Type[TC], **global_inputs: t.Any) -> TC:
"""More type-safe alternative to getting config objs."""
return config_cls().get(**global_inputs) # type: ignore[no-any-return]
4 changes: 2 additions & 2 deletions src/cdf/injector/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _process_arg_spec(
elif isinstance(arg, injector_specs._Callable):
# Anonymous prototype or singleton
result = self._materialize_callable_spec(config, arg).instantiate()
elif isinstance(arg, injector_specs._Object):
elif isinstance(arg, injector_specs._Instance):
return arg.obj
else:
for child_config in config._child_configs.values():
Expand Down Expand Up @@ -118,7 +118,7 @@ def _materialize_callable_spec(
def _get(self, config: injector_config.Config, key: str) -> t.Any:
"""Get instance represented by key in given config."""
spec = getattr(config, key)
if isinstance(spec, injector_specs._Object):
if isinstance(spec, injector_specs._Instance):
return spec.obj
elif isinstance(spec, injector_specs._Singleton):
try:
Expand Down
46 changes: 26 additions & 20 deletions src/cdf/injector/specs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""cdf.di specs.
"""cdf.injector specs.
NB: The cdf.di.{Object,Singleton,...} functions follow the same
NB: The cdf.injector.{Object,Singleton,...} functions follow the same
pattern as dataclasses.field() vs dataclasses.Field:
in order for typing to work for the user, we have dummy functions
that mimic expected typing behavior.
Expand Down Expand Up @@ -41,7 +41,7 @@ def instantiate(cls: type[T], *args: t.Any, **kwargs: t.Any) -> T:
class AttrFuture:
"""Future representing attr access on a Spec by its spec id."""

def __init__(self, root_spec_id: SpecID, attrs: list[str]) -> None:
def __init__(self, root_spec_id: SpecID, attrs: t.List[str]) -> None:
self.root_spec_id = root_spec_id
self.attrs = attrs

Expand Down Expand Up @@ -93,7 +93,7 @@ def _get_next_spec_id(cls) -> SpecID:
return result


class _Object(Spec[T]):
class _Instance(Spec[T]):
"""Represents fully-instantiated object to pass through."""

_INTERNAL_FIELDS = Spec._INTERNAL_FIELDS + ["obj"]
Expand All @@ -103,14 +103,14 @@ def __init__(self, obj: T, spec_id: SpecID | None = None) -> None:
self.obj = obj


def Object(obj: T) -> T: # noqa: N802
def Instance(obj: T) -> T: # noqa: N802
"""Spec to pass through a fully-instantiated object.
Args:
obj: Fully-instantiated object to pass through.
"""
# Cast because the return type will act like a T
return t.cast(T, _Object(obj))
return t.cast(T, _Instance(obj))


class _Input(Spec[T]):
Expand All @@ -135,6 +135,9 @@ def GlobalInput( # noqa: N802
) -> T:
"""Spec to use user input passed in at config instantiation.
This is to say, when the config is instantiated via get_config or through
the type contructor, you may pass in a value to override the default.
Args:
type_: Expected type of input, for both static and runtime check.
default: Default value if no input is provided.
Expand All @@ -154,6 +157,9 @@ def LocalInput( # noqa: N802
) -> T:
"""Spec to use user input passed in at config declaration.
This is to say, whenever the config is declared as a class field,
you may pass in a value to override the default.
Args:
type_: Expected type of input, for both static and runtime check.
default: Default value if no input is provided.
Expand Down Expand Up @@ -240,53 +246,53 @@ def Singleton( # noqa: N802
return t.cast(T, _Singleton(func_or_type, *args, **kwargs))


def SingletonTuple(*args: T) -> tuple[T]: # noqa: N802
def SingletonTuple(*args: T) -> t.Tuple[T]: # noqa: N802
"""Spec to create tuple with args and caching per config field."""
# Cast because the return type will act like a tuple of T
return t.cast("tuple[T]", _Singleton(tuple, args))
return t.cast("t.Tuple[T]", _Singleton(tuple, args))


def SingletonList(*args: T) -> list[T]: # noqa: N802
def SingletonList(*args: T) -> t.List[T]: # noqa: N802
"""Spec to create list with args and caching per config field."""
# Cast because the return type will act like a list of T
return t.cast("list[T]", _Singleton(list, args))
return t.cast("t.List[T]", _Singleton(list, args))


def SingletonDict( # noqa: N802
values: dict[t.Any, T] = MISSING_DICT, # noqa
values: t.Dict[t.Any, T] = MISSING_DICT, # noqa
/,
**kwargs: T,
) -> dict[t.Any, T]:
) -> t.Dict[t.Any, T]:
"""Spec to create dict with args and caching per config field.
Can specify either by pointing to a dict, passing in kwargs,
or unioning both.
>>> import cdf.di
>>> spec0 = cdf.di.Object(1); spec1 = cdf.di.Object(2)
>>> cdf.di.SingletonDict({"x": spec0, "y": spec1}) is not None
>>> import cdf.injector
>>> spec0 = cdf.injector.Object(1); spec1 = cdf.injector.Object(2)
>>> cdf.injector.SingletonDict({"x": spec0, "y": spec1}) is not None
True
Or, alternatively:
>>> cdf.di.SingletonDict(x=spec0, y=spec1) is not None
>>> cdf.injector.SingletonDict(x=spec0, y=spec1) is not None
True
"""
if values is MISSING_DICT:
# Cast because the return type will act like a dict of T
return t.cast("dict[t.Any, T]", _Singleton(dict, **kwargs))
return t.cast("t.Dict[t.Any, T]", _Singleton(dict, **kwargs))
else:
# Cast because the return type will act like a dict of T
return t.cast(
"dict[t.Any, T]",
"t.Dict[t.Any, T]",
_Singleton(_union_dict_and_kwargs, values, **kwargs),
)


class PrototypeMixin:
"""Helper class for Prototype to ease syntax in Config.
Equivalent to cdf.di.Prototype(cls, ...).
Equivalent to cdf.injector.Prototype(cls, ...).
"""

def __new__(
Expand All @@ -301,7 +307,7 @@ def __new__(
class SingletonMixin:
"""Helper class for Singleton to ease syntax in Config.
Equivalent to cdf.di.Singleton(cls, ...).
Equivalent to cdf.injector.Singleton(cls, ...).
"""

def __new__(
Expand Down
6 changes: 3 additions & 3 deletions src/cdf/injector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from . import errors as injector_errors

PRIMITIVE_TYPES: t.Final[tuple[type, ...]] = (
PRIMITIVE_TYPES: t.Final[t.Tuple[t.Type, ...]] = (
type(None),
bool,
int,
Expand All @@ -27,9 +27,9 @@ def check_type(
type_: Type to check against.
desc: Description for error.
>>> import pytest; import cdf.di
>>> import pytest; import cdf.injector
>>> check_type("abc", str)
>>> with pytest.raises(cdf.di.InputConfigError):
>>> with pytest.raises(cdf.injector.InputConfigError):
... check_type("abc", int)
"""
if type_ is None:
Expand Down

0 comments on commit 628c426

Please sign in to comment.