Skip to content

Commit

Permalink
Add GraphRun object
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 1, 2025
1 parent 1443b49 commit 6ce755e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def main():
result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)

# Build the graph
graph = _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)
graph = self._build_graph(result_type)

# Build the initial state
state = _agent_graph.GraphAgentState(
Expand Down
89 changes: 71 additions & 18 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import inspect
import types
from collections.abc import Sequence
from collections.abc import AsyncGenerator, Sequence
from contextlib import ExitStack
from dataclasses import dataclass, field
from functools import cached_property
Expand Down Expand Up @@ -170,7 +170,7 @@ async def main():
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

history: list[HistoryStep[StateT, T]] = []
graph_run = GraphRun[StateT, DepsT, T](self, state=state, deps=deps)
with ExitStack() as stack:
run_span: logfire_api.LogfireSpan | None = None
if self._auto_instrument:
Expand All @@ -181,22 +181,14 @@ async def main():
start=start_node,
)
)
while True:
next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False)
if isinstance(next_node, End):
history.append(EndStep(result=next_node))
if run_span is not None:
run_span.set_attribute('history', history)
return next_node.data, history
elif isinstance(next_node, BaseNode):
start_node = next_node
else:
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
else:
raise exceptions.GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)
next_node = start_node
while True:
next_node = await graph_run.next(next_node)
if isinstance(next_node, End):
history = graph_run.history
if run_span is not None:
run_span.set_attribute('history', history)
return next_node.data, history

def run_sync(
self: Graph[StateT, DepsT, T],
Expand Down Expand Up @@ -510,3 +502,64 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
if item is self:
self.name = name
return


class GraphRun(Generic[StateT, DepsT, RunEndT]):
def __init__(
self,
graph: Graph[StateT, DepsT, RunEndT],
*,
state: StateT = None,
deps: DepsT = None,
):
self.graph = graph
self.state = state
self.deps = deps

self.history: list[HistoryStep[StateT, RunEndT]] = []
self.final_result: End[RunEndT] | None = None

self._agen: (
AsyncGenerator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT], BaseNode[StateT, DepsT, RunEndT]] | None
) = None

async def next(
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
) -> BaseNode[StateT, DepsT, Any] | End[T]:
agen = await self._get_primed_agen()
return await agen.asend(node)

async def _get_primed_agen(
self: GraphRun[StateT, DepsT, T],
) -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
graph = self.graph
state = self.state
deps = self.deps
history = self.history

if self._agen is None:

async def _agen() -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
next_node = yield # pyright: ignore[reportReturnType] # we prime the generator immediately below
while True:
next_node = await graph.next(next_node, history, state=state, deps=deps, infer_name=False)
if isinstance(next_node, End):
history.append(EndStep(result=next_node))
self.final_result = next_node
yield next_node
return
elif isinstance(next_node, BaseNode):
next_node = yield next_node # Give user a chance to modify the next node
else:
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
else:
raise exceptions.GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)

agen = _agen()
await agen.__anext__() # prime the generator

self._agen = agen
return self._agen

0 comments on commit 6ce755e

Please sign in to comment.