Skip to content

Commit

Permalink
Add post-init injection
Browse files Browse the repository at this point in the history
  • Loading branch information
Aegdesil committed Oct 24, 2023
1 parent dd0ac02 commit 4440f92
Show file tree
Hide file tree
Showing 17 changed files with 243 additions and 116 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Opyoid follows [semver guidelines](https://semver.org) for versioning.

## Unreleased
## 2.1.0
### Features
- Added `__opyoid_post_init__` method support to allow breaking dependency loops, check the README for more details

## 2.0.0
### Breaking changes
- Remove support for Python < 3.8
Expand Down
110 changes: 10 additions & 100 deletions opyoid/bindings/self_binding/callable_to_provider_adapter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import logging
from inspect import Parameter, signature
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Callable, Optional, Type

from opyoid.bindings.instance_binding import FromInstanceProvider
from opyoid.exceptions import NoBindingFound, NonInjectableTypeError
from opyoid.injection_context import InjectionContext
from opyoid.provider import Provider
from opyoid.target import Target
from opyoid.type_checker import TypeChecker
from opyoid.utils import EMPTY, get_class_full_name, InjectedT
from opyoid.utils import InjectedT
from .from_callable_provider import FromCallableProvider
from .parameters_provider import ParametersProvider
from ...scopes import Scope


Expand All @@ -18,6 +16,9 @@ class CallableToProviderAdapter:

logger = logging.getLogger(__name__)

def __init__(self):
self._parameters_provider = ParametersProvider()

def create(
self,
type_or_function: Callable[..., InjectedT],
Expand All @@ -29,35 +30,16 @@ def create(
)
if cached_provider:
return cached_provider
if isinstance(type_or_function, type):
parameters = list(signature(type_or_function.__init__).parameters.values())[1:]
else:
parameters = signature(type_or_function).parameters.values()
positional_providers: List[Provider[Any]] = []
args_provider: Optional[Provider[List[Any]]] = None
keyword_providers: Dict[str, Provider[Any]] = {}
for parameter in parameters:
context.current_parameter = parameter
# Ignore '**kwargs'
if parameter.kind == Parameter.VAR_KEYWORD:
continue

if parameter.kind == Parameter.VAR_POSITIONAL:
# *args
args_provider = self._get_positional_parameter_provider(parameter, type_or_function, context)
continue
parameter_provider = self._get_parameter_provider(parameter, type_or_function, context)
if parameter.kind == Parameter.KEYWORD_ONLY:
# After *args
keyword_providers[parameter.name] = parameter_provider
else:
# Before *args
positional_providers.append(parameter_provider)
positional_providers, args_provider, keyword_providers = self._parameters_provider.get_parameters_provider(
type_or_function, context
)
unscoped_provider = FromCallableProvider(
type_or_function,
positional_providers,
args_provider,
keyword_providers,
context,
)
scope_context: InjectionContext[Scope] = context.get_child_context(Target(scope))
try:
Expand All @@ -69,75 +51,3 @@ def create(
provider = scope_provider.get().get_scoped_provider(unscoped_provider)
context.injection_state.provider_registry.set_provider(context.target, provider)
return provider

def _get_parameter_provider(
self, parameter: Parameter, type_or_function: Callable[..., InjectedT], context: InjectionContext[InjectedT]
) -> Provider[InjectedT]:
default_value = parameter.default if parameter.default is not Parameter.empty else EMPTY
if parameter.annotation is not Parameter.empty:
if TypeChecker.is_named(parameter.annotation):
provider: Optional[Provider[InjectedT]] = self._get_provider(
[Target(parameter.annotation.original_type, parameter.annotation.name, default_value)], context
)
else:
provider = self._get_provider(
[
Target(parameter.annotation, parameter.name, default_value),
Target(parameter.annotation, None, default_value),
],
context,
)
if provider:
return provider
if parameter.default is not Parameter.empty:
return FromInstanceProvider(parameter.default)
raise NonInjectableTypeError(
f"Could not find a binding or a default value for {parameter.name}: "
f"{get_class_full_name(parameter.annotation)} required by {type_or_function}"
)

def _get_positional_parameter_provider(
self, parameter: Parameter, type_or_function: Callable[..., InjectedT], context: InjectionContext[InjectedT]
) -> Provider[List[InjectedT]]:
if parameter.annotation is Parameter.empty:
return FromInstanceProvider([])
if TypeChecker.is_named(parameter.annotation):
provider: Optional[Provider[List[InjectedT]]] = self._get_provider(
[
Target(
List[parameter.annotation.original_type], # type: ignore[name-defined]
parameter.annotation.name,
default=[],
)
],
context,
)
else:
provider = self._get_provider(
[
Target(List[parameter.annotation], parameter.name, default=[]), # type: ignore[name-defined]
Target(List[parameter.annotation], default=[]), # type: ignore[name-defined]
],
context,
)
if provider:
return provider
self.logger.debug(
f"Could not find a binding for *{parameter.name}: {parameter.annotation} required by "
f"{type_or_function}, will inject nothing"
)
return FromInstanceProvider([])

@staticmethod
def _get_provider(
targets: List[Target[InjectedT]], parent_context: InjectionContext[Any]
) -> Optional[Provider[InjectedT]]:
for target_index, target in enumerate(targets):
context = parent_context.get_child_context(target, allow_jit_provider=target_index == len(targets) - 1)
context.current_class = parent_context.current_class
context.current_parameter = parent_context.current_parameter
try:
return context.get_provider()
except NoBindingFound:
pass
return None
31 changes: 30 additions & 1 deletion opyoid/bindings/self_binding/from_callable_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Callable, Dict, List, Optional

from opyoid.bindings.self_binding.parameters_provider import ParametersProvider
from opyoid.constants import OPYOID_POST_INIT
from opyoid.injection_context import InjectionContext
from opyoid.provider import Provider
from opyoid.utils import InjectedT

Expand All @@ -11,18 +14,44 @@ def __init__(
positional_providers: List[Provider[Any]],
args_provider: Optional[Provider[List[Any]]],
keyword_providers: Dict[str, Provider[Any]],
injection_context: InjectionContext,
) -> None:
self._injected_callable = injected_callable
self._positional_providers = positional_providers
self._args_provider = args_provider
self._keyword_providers = keyword_providers
self._injection_context = injection_context
self._parameters_provider = ParametersProvider()

def get(self) -> InjectedT:
args = [positional_provider.get() for positional_provider in self._positional_providers]
if self._args_provider:
args += self._args_provider.get()
kwargs = {arg_name: keyword_provider.get() for arg_name, keyword_provider in self._keyword_providers.items()}
return self._injected_callable(
result = self._injected_callable(
*args,
**kwargs,
)
if hasattr(self._injected_callable, OPYOID_POST_INIT):
self._injection_context.injection_state.add_post_init_callback(lambda: self._run_post_init(result))
# self._run_post_init(result)
# self._injection_context.add_post_init_callback(lambda: self._run_post_init(result))
return result

def _run_post_init(self, instance: InjectedT) -> None:
(
post_init_positional_providers,
post_init_args_provider,
post_init_keyword_providers,
) = self._parameters_provider.get_parameters_provider(
getattr(instance, OPYOID_POST_INIT),
self._injection_context,
)

post_init_args = [positional_provider.get() for positional_provider in post_init_positional_providers]
if post_init_args_provider:
post_init_args += post_init_args_provider.get()
post_init_kwargs = {
arg_name: keyword_provider.get() for arg_name, keyword_provider in post_init_keyword_providers.items()
}
getattr(instance, OPYOID_POST_INIT)(*post_init_args, **post_init_kwargs)
123 changes: 123 additions & 0 deletions opyoid/bindings/self_binding/parameters_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import inspect
import logging
import sys
from inspect import Parameter, signature
from typing import Any, Callable, Dict, List, Optional, Tuple

from opyoid.bindings.instance_binding import FromInstanceProvider
from opyoid.exceptions import NoBindingFound, NonInjectableTypeError
from opyoid.injection_context import InjectionContext
from opyoid.provider import Provider
from opyoid.target import Target
from opyoid.type_checker import TypeChecker
from opyoid.utils import EMPTY, get_class_full_name, InjectedT


class ParametersProvider:
"""Creates Providers from a callable."""

logger = logging.getLogger(__name__)

def get_parameters_provider(
self,
type_or_function: Callable[..., InjectedT],
context: InjectionContext[InjectedT],
) -> Tuple[List[Provider[Any]], Optional[Provider[List[Any]]], Dict[str, Provider[Any]]]:
if sys.version_info[:2] < (3, 9) and isinstance(type_or_function, type):
parameters = list(signature(type_or_function.__init__).parameters.values())[1:]
else:
parameters = list(signature(type_or_function).parameters.values())

positional_providers: List[Provider[Any]] = []
args_provider: Optional[Provider[List[Any]]] = None
keyword_providers: Dict[str, Provider[Any]] = {}
for parameter in parameters:
context.current_parameter = parameter
# Ignore '**kwargs'
if parameter.kind == Parameter.VAR_KEYWORD:
continue

if parameter.kind == Parameter.VAR_POSITIONAL:
# *args
args_provider = self._get_positional_parameter_provider(parameter, type_or_function, context)
continue
parameter_provider = self._get_parameter_provider(parameter, type_or_function, context)
if parameter.kind == Parameter.KEYWORD_ONLY:
# After *args
keyword_providers[parameter.name] = parameter_provider
else:
# Before *args
positional_providers.append(parameter_provider)
return positional_providers, args_provider, keyword_providers

def _get_parameter_provider(
self, parameter: Parameter, type_or_function: Callable[..., InjectedT], context: InjectionContext[InjectedT]
) -> Provider[InjectedT]:
default_value = parameter.default if parameter.default is not Parameter.empty else EMPTY
if parameter.annotation is not Parameter.empty:
if TypeChecker.is_named(parameter.annotation):
provider: Optional[Provider[InjectedT]] = self._get_provider(
[Target(parameter.annotation.original_type, parameter.annotation.name, default_value)], context
)
else:
provider = self._get_provider(
[
Target(parameter.annotation, parameter.name, default_value),
Target(parameter.annotation, None, default_value),
],
context,
)
if provider:
return provider
if parameter.default is not Parameter.empty:
return FromInstanceProvider(parameter.default)
raise NonInjectableTypeError(
f"Could not find a binding or a default value for {parameter.name}: "
f"{get_class_full_name(parameter.annotation)} required by {type_or_function}"
)

def _get_positional_parameter_provider(
self, parameter: Parameter, type_or_function: Callable[..., InjectedT], context: InjectionContext[InjectedT]
) -> Provider[List[InjectedT]]:
if parameter.annotation is Parameter.empty:
return FromInstanceProvider([])
if TypeChecker.is_named(parameter.annotation):
provider: Optional[Provider[List[InjectedT]]] = self._get_provider(
[
Target(
List[parameter.annotation.original_type], # type: ignore[name-defined]
parameter.annotation.name,
default=[],
)
],
context,
)
else:
provider = self._get_provider(
[
Target(List[parameter.annotation], parameter.name, default=[]), # type: ignore[name-defined]
Target(List[parameter.annotation], default=[]), # type: ignore[name-defined]
],
context,
)
if provider:
return provider
self.logger.debug(
f"Could not find a binding for *{parameter.name}: {parameter.annotation} required by "
f"{type_or_function}, will inject nothing"
)
return FromInstanceProvider([])

@staticmethod
def _get_provider(
targets: List[Target[InjectedT]], parent_context: InjectionContext[Any]
) -> Optional[Provider[InjectedT]]:
for target_index, target in enumerate(targets):
context = parent_context.get_child_context(target, allow_jit_provider=target_index == len(targets) - 1)
context.current_class = parent_context.current_class
context.current_parameter = parent_context.current_parameter
try:
return context.get_provider()
except NoBindingFound:
pass
return None
1 change: 1 addition & 0 deletions opyoid/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
OPYOID_POST_INIT = "__opyoid_post_init__"
2 changes: 1 addition & 1 deletion opyoid/injection_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from inspect import Parameter
from typing import Any, Generic, List, Optional, Type, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Generic, List, Optional, Type, TYPE_CHECKING, TypeVar

import attr

Expand Down
9 changes: 8 additions & 1 deletion opyoid/injection_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, TYPE_CHECKING
from typing import Callable, Dict, List, Optional, TYPE_CHECKING

import attr

Expand All @@ -18,3 +18,10 @@ class InjectionState:
parent_state: Optional["InjectionState"] = None
provider_registry: ProviderRegistry = attr.Factory(ProviderRegistry)
state_by_module: Dict[AbstractModule, "InjectionState"] = attr.Factory(dict)
post_init_callbacks: List[Callable[[], None]] = attr.Factory(list)

def add_post_init_callback(self, callback: Callable[[], None]) -> None:
if self.parent_state:
self.parent_state.add_post_init_callback(callback)
else:
self.post_init_callbacks.append(callback)
6 changes: 5 additions & 1 deletion opyoid/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .bindings import Binding
from .bindings.abstract_module import AbstractModule
from .bindings.root_module import RootModule
from .constants import OPYOID_POST_INIT
from .injection_context import InjectionContext
from .injection_state import InjectionState
from .injector_options import InjectorOptions
Expand Down Expand Up @@ -40,4 +41,7 @@ def __init__(

def inject(self, target_type: Union[Type[InjectedT], TypeVar, Any], *, named: Optional[str] = None) -> InjectedT:
injection_context: InjectionContext[InjectedT] = InjectionContext(Target(target_type, named), self._root_state)
return injection_context.get_provider().get()
result = injection_context.get_provider().get()
for post_init_callback in self._root_state.post_init_callbacks:
post_init_callback()
return result
4 changes: 3 additions & 1 deletion opyoid/providers/providers_factories/set_provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ def create(self, context: InjectionContext[InjectedT]) -> Provider[InjectedT]:
context.target.named,
)
new_context = context.get_child_context(new_target)
return FromCallableProvider(cast(Callable[..., InjectedT], set), [new_context.get_provider()], None, {})
return FromCallableProvider(
cast(Callable[..., InjectedT], set), [new_context.get_provider()], None, {}, context
)
raise IncompatibleProviderFactory
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ def create(self, context: InjectionContext[InjectedT]) -> Provider[InjectedT]:
context.target.named,
)
new_context = context.get_child_context(new_target)
return FromCallableProvider(cast(Callable[..., InjectedT], tuple), [new_context.get_provider()], None, {})
return FromCallableProvider(
cast(Callable[..., InjectedT], tuple), [new_context.get_provider()], None, {}, context
)
raise IncompatibleProviderFactory
Loading

0 comments on commit 4440f92

Please sign in to comment.