diff --git a/kink/container.py b/kink/container.py index 4de77a0..74816ec 100644 --- a/kink/container.py +++ b/kink/container.py @@ -1,13 +1,15 @@ from types import LambdaType -from typing import Any, Dict, Type, Union, Callable, List -from kink.typing_support import is_optional, unpack_optional +from typing import Any, Dict, Type, Union, Callable, List, overload, TypeVar from kink.errors.service_error import ServiceError - +from kink.typing_support import is_optional, unpack_optional _MISSING_SERVICE = object() +T = TypeVar("T") + + class Container: def __init__(self): self._memoized_services: Dict[Union[str, Type], Any] = {} @@ -29,7 +31,15 @@ def add_alias(self, name: Union[str, Type], target: Union[str, Type]): self._aliases[name] = [] self._aliases[name].append(target) - def __getitem__(self, key: Union[str, Type]) -> Any: + @overload + def __getitem__(self, key: str) -> Any: + ... + + @overload + def __getitem__(self, key: Type[T]) -> T: + ... + + def __getitem__(self, key): if key in self._factories: return self._factories[key](self) @@ -42,10 +52,10 @@ def __getitem__(self, key: Union[str, Type]) -> Any: return service if key in self._aliases: - unaliased_key = self._aliases[key][0] # By default return first aliased service + unaliased_key = self._aliases[key][0] # By default return first aliased service if unaliased_key in self._factories: return self._factories[unaliased_key](self) - service = self._get(unaliased_key) + service = self._get(unaliased_key) if service is not _MISSING_SERVICE: return service @@ -85,7 +95,6 @@ def __contains__(self, key) -> bool: if is_optional(key): return unpack_optional(key) in self - return False def _has_alias_list_for(self, key: Union[str, Type]) -> bool: