Skip to content

Commit 3f3e913

Browse files
committed
Fixed possible bug.
1 parent 5618873 commit 3f3e913

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

taskiq_dependencies/ctx.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ class BaseResolveContext:
2020
def __init__(
2121
self,
2222
graph: "DependencyGraph",
23+
main_graph: "DependencyGraph",
2324
initial_cache: Optional[Dict[Any, Any]] = None,
2425
exception_propagation: bool = True,
2526
) -> None:
2627
self.graph = graph
28+
# Main graph that contains all the subgraphs.
29+
self.main_graph = main_graph
2730
self.opened_dependencies: List[Any] = []
2831
self.sub_contexts: "List[Any]" = []
2932
self.initial_cache = initial_cache or {}
@@ -91,7 +94,7 @@ def traverse_deps( # noqa: C901
9194
if subdep.dependency == ParamInfo:
9295
kwargs[subdep.param_name] = ParamInfo(
9396
dep.param_name,
94-
self.graph,
97+
self.main_graph,
9598
dep.signature,
9699
)
97100
continue
@@ -201,7 +204,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
201204
:return: dict with resolved kwargs.
202205
"""
203206
if getattr(executed_func, "dep_graph", False):
204-
ctx = SyncResolveContext(executed_func, initial_cache)
207+
ctx = SyncResolveContext(executed_func, self.main_graph, initial_cache)
205208
self.sub_contexts.append(ctx)
206209
sub_result = ctx.resolve_kwargs()
207210
elif inspect.isgenerator(executed_func):
@@ -329,7 +332,7 @@ async def resolver(
329332
:return: dict with resolved kwargs.
330333
"""
331334
if getattr(executed_func, "dep_graph", False):
332-
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
335+
ctx = AsyncResolveContext(executed_func, self.main_graph, initial_cache) # type: ignore
333336
self.sub_contexts.append(ctx)
334337
sub_result = await ctx.resolve_kwargs()
335338
elif inspect.isgenerator(executed_func):

taskiq_dependencies/graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def async_ctx(
6464
if replaced_deps:
6565
graph = DependencyGraph(self.target, replaced_deps)
6666
return AsyncResolveContext(
67+
graph,
6768
graph,
6869
initial_cache,
6970
exception_propagation,
@@ -90,6 +91,7 @@ def sync_ctx(
9091
if replaced_deps:
9192
graph = DependencyGraph(self.target, replaced_deps)
9293
return SyncResolveContext(
94+
graph,
9395
graph,
9496
initial_cache,
9597
exception_propagation,

tests/test_graph.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,3 +867,27 @@ def target(acm: TestACM = Depends(get_test_acm)) -> None:
867867
kwargs = await ctx.resolve_kwargs()
868868
assert kwargs["acm"] == test_acm
869869
assert not test_acm.opened
870+
871+
872+
def test_param_info_subgraph() -> None:
873+
"""
874+
Test subgraphs for ParamInfo.
875+
876+
Test that correct graph is stored in ParamInfo
877+
even if evaluated from subgraphs.
878+
"""
879+
880+
def inner_dep(info: ParamInfo = Depends()) -> ParamInfo:
881+
return info
882+
883+
def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
884+
return None
885+
886+
graph = DependencyGraph(target=target)
887+
with graph.sync_ctx() as g:
888+
kwargs = g.resolve_kwargs()
889+
890+
info: ParamInfo = kwargs["info"]
891+
assert info.name == ""
892+
assert info.definition is None
893+
assert info.graph == graph

0 commit comments

Comments
 (0)