Skip to content

Commit

Permalink
Fixed possible bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius committed Sep 29, 2024
1 parent 5618873 commit 3f3e913
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
9 changes: 6 additions & 3 deletions taskiq_dependencies/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class BaseResolveContext:
def __init__(
self,
graph: "DependencyGraph",
main_graph: "DependencyGraph",
initial_cache: Optional[Dict[Any, Any]] = None,
exception_propagation: bool = True,
) -> None:
self.graph = graph
# Main graph that contains all the subgraphs.
self.main_graph = main_graph
self.opened_dependencies: List[Any] = []
self.sub_contexts: "List[Any]" = []
self.initial_cache = initial_cache or {}
Expand Down Expand Up @@ -91,7 +94,7 @@ def traverse_deps( # noqa: C901
if subdep.dependency == ParamInfo:
kwargs[subdep.param_name] = ParamInfo(
dep.param_name,
self.graph,
self.main_graph,
dep.signature,
)
continue
Expand Down Expand Up @@ -201,7 +204,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
:return: dict with resolved kwargs.
"""
if getattr(executed_func, "dep_graph", False):
ctx = SyncResolveContext(executed_func, initial_cache)
ctx = SyncResolveContext(executed_func, self.main_graph, initial_cache)
self.sub_contexts.append(ctx)
sub_result = ctx.resolve_kwargs()
elif inspect.isgenerator(executed_func):
Expand Down Expand Up @@ -329,7 +332,7 @@ async def resolver(
:return: dict with resolved kwargs.
"""
if getattr(executed_func, "dep_graph", False):
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
ctx = AsyncResolveContext(executed_func, self.main_graph, initial_cache) # type: ignore
self.sub_contexts.append(ctx)
sub_result = await ctx.resolve_kwargs()
elif inspect.isgenerator(executed_func):
Expand Down
2 changes: 2 additions & 0 deletions taskiq_dependencies/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def async_ctx(
if replaced_deps:
graph = DependencyGraph(self.target, replaced_deps)
return AsyncResolveContext(
graph,
graph,
initial_cache,
exception_propagation,
Expand All @@ -90,6 +91,7 @@ def sync_ctx(
if replaced_deps:
graph = DependencyGraph(self.target, replaced_deps)
return SyncResolveContext(
graph,
graph,
initial_cache,
exception_propagation,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,3 +867,27 @@ def target(acm: TestACM = Depends(get_test_acm)) -> None:
kwargs = await ctx.resolve_kwargs()
assert kwargs["acm"] == test_acm
assert not test_acm.opened


def test_param_info_subgraph() -> None:
"""
Test subgraphs for ParamInfo.
Test that correct graph is stored in ParamInfo
even if evaluated from subgraphs.
"""

def inner_dep(info: ParamInfo = Depends()) -> ParamInfo:
return info

def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
return None

graph = DependencyGraph(target=target)
with graph.sync_ctx() as g:
kwargs = g.resolve_kwargs()

info: ParamInfo = kwargs["info"]
assert info.name == ""
assert info.definition is None
assert info.graph == graph

0 comments on commit 3f3e913

Please sign in to comment.