Skip to content

Commit

Permalink
Add post-init injection
Browse files Browse the repository at this point in the history
  • Loading branch information
Aegdesil committed Nov 3, 2023
1 parent 00c6178 commit 6a47317
Show file tree
Hide file tree
Showing 21 changed files with 373 additions and 168 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Opyoid follows [semver guidelines](https://semver.org) for versioning.

## Unreleased
## 2.1.0
### Features
- Added `__opyoid_post_init__` method support to allow breaking dependency loops, check the README for more details

## 2.0.0
### Breaking changes
- Remove support for Python < 3.8
Expand Down
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,38 @@ assert my_instance is injected_instance
```


### Environment Variables

You can use environment variables to easily override bindings in your application.

#### Supported types

Supported types are `str`, `int`, `float`, and `bool`.

Environment variables are only used when loading ClassBindings, ProviderBindings or SelfBindings, not InstanceBindings

If the corresponding environment variable exists, it will override the existing default value and bindings for the
parameter.

#### Environment variable name

The environment variable should be named `UPPER_CLASS_NAME_UPPER_PARAMETER_NAME`
In this example, the environment variable to set is `MY_CLASS_MY_PARAMETER`:
```python
@dataclass
class MyClass:
my_parameter: int
```

#### Value conversion

For types other than str, an automatic conversion is made:
- ints and floats are converted using int() and float()
- for booleans, authorized values are:
- "0", "false" and "False", will be converted to `False`
- "1", "true" and "True", will be converted to `True`


### Binding scopes
When binding a class, you can choose the scope in which it will be instantiated.
This will only have an effect when binding classes, not instances.
Expand Down
30 changes: 0 additions & 30 deletions docs/environment_variables.md

This file was deleted.

36 changes: 36 additions & 0 deletions docs/post_init.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Post init

You can sometimes end up having a dependency loop: A requires B and B requires A.
Most of the time this is due to a code structuring issue and needs refactoring, but it can also be the only way to solve
a problem.

The `__opyoid_post_init__` method helps to break this dependency loop by adding additional attributes to the instances
after they are instantiated.

```python
from opyoid import Injector, SelfBinding
from typing import Optional

class ClassA:
def __init__(self, my_arg: "ClassB"):
self.my_arg = my_arg

class ClassB:
def __init__(
self,
) -> None:
self.my_arg: Optional[ClassA] = None

def __opyoid_post_init__(self, my_arg: ClassA) -> None:
self.my_arg = my_arg

injector = Injector(bindings=[SelfBinding(ClassA), SelfBinding(ClassB)])

instance_a = injector.inject(ClassA)
instance_b = injector.inject(ClassB)

assert isinstance(instance_a, ClassA)
assert isinstance(instance_b, ClassB)
assert instance_a.my_arg is instance_b
assert instance_b.my_arg is instance_a
```
110 changes: 10 additions & 100 deletions opyoid/bindings/self_binding/callable_to_provider_adapter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import logging
from inspect import Parameter, signature
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Callable, Optional, Type

from opyoid.bindings.instance_binding import FromInstanceProvider
from opyoid.exceptions import NoBindingFound, NonInjectableTypeError
from opyoid.injection_context import InjectionContext
from opyoid.provider import Provider
from opyoid.target import Target
from opyoid.type_checker import TypeChecker
from opyoid.utils import EMPTY, get_class_full_name, InjectedT
from opyoid.utils import InjectedT
from .from_callable_provider import FromCallableProvider
from .parameters_provider import ParametersProvider
from ...scopes import Scope


Expand All @@ -18,6 +16,9 @@ class CallableToProviderAdapter:

logger = logging.getLogger(__name__)

def __init__(self) -> None:
self._parameters_provider = ParametersProvider()

def create(
self,
type_or_function: Callable[..., InjectedT],
Expand All @@ -29,35 +30,16 @@ def create(
)
if cached_provider:
return cached_provider
if isinstance(type_or_function, type):
parameters = list(signature(type_or_function.__init__).parameters.values())[1:]
else:
parameters = signature(type_or_function).parameters.values()
positional_providers: List[Provider[Any]] = []
args_provider: Optional[Provider[List[Any]]] = None
keyword_providers: Dict[str, Provider[Any]] = {}
for parameter in parameters:
context.current_parameter = parameter
# Ignore '**kwargs'
if parameter.kind == Parameter.VAR_KEYWORD:
continue

if parameter.kind == Parameter.VAR_POSITIONAL:
# *args
args_provider = self._get_positional_parameter_provider(parameter, type_or_function, context)
continue
parameter_provider = self._get_parameter_provider(parameter, type_or_function, context)
if parameter.kind == Parameter.KEYWORD_ONLY:
# After *args
keyword_providers[parameter.name] = parameter_provider
else:
# Before *args
positional_providers.append(parameter_provider)
positional_providers, args_provider, keyword_providers = self._parameters_provider.get_parameters_provider(
type_or_function, context
)
unscoped_provider = FromCallableProvider(
type_or_function,
positional_providers,
args_provider,
keyword_providers,
context,
)
scope_context: InjectionContext[Scope] = context.get_child_context(Target(scope))
try:
Expand All @@ -69,75 +51,3 @@ def create(
provider = scope_provider.get().get_scoped_provider(unscoped_provider)
context.injection_state.provider_registry.set_provider(context.target, provider)
return provider

def _get_parameter_provider(
self, parameter: Parameter, type_or_function: Callable[..., InjectedT], context: InjectionContext[InjectedT]
) -> Provider[InjectedT]:
default_value = parameter.default if parameter.default is not Parameter.empty else EMPTY
if parameter.annotation is not Parameter.empty:
if TypeChecker.is_named(parameter.annotation):
provider: Optional[Provider[InjectedT]] = self._get_provider(
[Target(parameter.annotation.original_type, parameter.annotation.name, default_value)], context
)
else:
provider = self._get_provider(
[
Target(parameter.annotation, parameter.name, default_value),
Target(parameter.annotation, None, default_value),
],
context,
)
if provider:
return provider
if parameter.default is not Parameter.empty:
return FromInstanceProvider(parameter.default)
raise NonInjectableTypeError(
f"Could not find a binding or a default value for {parameter.name}: "
f"{get_class_full_name(parameter.annotation)} required by {type_or_function}"
)

def _get_positional_parameter_provider(
self, parameter: Parameter, type_or_function: Callable[..., InjectedT], context: InjectionContext[InjectedT]
) -> Provider[List[InjectedT]]:
if parameter.annotation is Parameter.empty:
return FromInstanceProvider([])
if TypeChecker.is_named(parameter.annotation):
provider: Optional[Provider[List[InjectedT]]] = self._get_provider(
[
Target(
List[parameter.annotation.original_type], # type: ignore[name-defined]
parameter.annotation.name,
default=[],
)
],
context,
)
else:
provider = self._get_provider(
[
Target(List[parameter.annotation], parameter.name, default=[]), # type: ignore[name-defined]
Target(List[parameter.annotation], default=[]), # type: ignore[name-defined]
],
context,
)
if provider:
return provider
self.logger.debug(
f"Could not find a binding for *{parameter.name}: {parameter.annotation} required by "
f"{type_or_function}, will inject nothing"
)
return FromInstanceProvider([])

@staticmethod
def _get_provider(
targets: List[Target[InjectedT]], parent_context: InjectionContext[Any]
) -> Optional[Provider[InjectedT]]:
for target_index, target in enumerate(targets):
context = parent_context.get_child_context(target, allow_jit_provider=target_index == len(targets) - 1)
context.current_class = parent_context.current_class
context.current_parameter = parent_context.current_parameter
try:
return context.get_provider()
except NoBindingFound:
pass
return None
29 changes: 28 additions & 1 deletion opyoid/bindings/self_binding/from_callable_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Callable, Dict, List, Optional

from opyoid.bindings.self_binding.parameters_provider import ParametersProvider
from opyoid.constants import OPYOID_POST_INIT
from opyoid.injection_context import InjectionContext
from opyoid.provider import Provider
from opyoid.utils import InjectedT

Expand All @@ -11,18 +14,42 @@ def __init__(
positional_providers: List[Provider[Any]],
args_provider: Optional[Provider[List[Any]]],
keyword_providers: Dict[str, Provider[Any]],
injection_context: InjectionContext[InjectedT],
) -> None:
self._injected_callable = injected_callable
self._positional_providers = positional_providers
self._args_provider = args_provider
self._keyword_providers = keyword_providers
self._injection_context = injection_context
self._parameters_provider = ParametersProvider()

def get(self) -> InjectedT:
args = [positional_provider.get() for positional_provider in self._positional_providers]
if self._args_provider:
args += self._args_provider.get()
kwargs = {arg_name: keyword_provider.get() for arg_name, keyword_provider in self._keyword_providers.items()}
return self._injected_callable(
result = self._injected_callable(
*args,
**kwargs,
)
if hasattr(self._injected_callable, OPYOID_POST_INIT):
self._injection_context.injection_state.add_post_init_callback(lambda: self._run_post_init(result))
return result

def _run_post_init(self, instance: InjectedT) -> None:
(
post_init_positional_providers,
post_init_args_provider,
post_init_keyword_providers,
) = self._parameters_provider.get_parameters_provider(
getattr(instance, OPYOID_POST_INIT),
self._injection_context,
)

post_init_args = [positional_provider.get() for positional_provider in post_init_positional_providers]
if post_init_args_provider:
post_init_args += post_init_args_provider.get()
post_init_kwargs = {
arg_name: keyword_provider.get() for arg_name, keyword_provider in post_init_keyword_providers.items()
}
getattr(instance, OPYOID_POST_INIT)(*post_init_args, **post_init_kwargs)
Loading

0 comments on commit 6a47317

Please sign in to comment.