Skip to content

Commit

Permalink
Make graph.run(...) return an instance of GraphRun
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 3, 2025
1 parent 93260fe commit 39a6009
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 108 deletions.
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,13 @@ async def main():
)

# Actually run
end_result, _ = await graph.run(
graph_run = await graph.run(
start_node,
state=state,
deps=graph_deps,
infer_name=False,
)
end_result = graph_run.result

# Build final run result
# We don't do any advanced checking if the data is actually from a final result or not
Expand Down
4 changes: 2 additions & 2 deletions pydantic_graph/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class Increment(BaseNode):


fives_graph = Graph(nodes=[DivisibleBy5, Increment])
result, history = fives_graph.run_sync(DivisibleBy5(4))
print(result)
graph_run = fives_graph.run_sync(DivisibleBy5(4))
print(graph_run.result)
#> 5
# the full history is quite verbose (see below), so we'll just print the summary
print([item.data_snapshot() for item in history])
Expand Down
3 changes: 1 addition & 2 deletions pydantic_graph/pydantic_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .exceptions import GraphRuntimeError, GraphSetupError
from .graph import Graph, GraphRun, GraphRunner
from .graph import Graph, GraphRun
from .nodes import BaseNode, Edge, End, GraphRunContext
from .state import EndStep, HistoryStep, NodeStep

__all__ = (
'Graph',
'GraphRun',
'GraphRunner',
'BaseNode',
'End',
'GraphRunContext',
Expand Down
130 changes: 51 additions & 79 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)


__all__ = ('Graph', 'GraphRun', 'GraphRunner')
__all__ = ('Graph', 'GraphRun')

_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

Expand Down Expand Up @@ -133,7 +133,7 @@ def run(
state: StateT = None,
deps: DepsT = None,
infer_name: bool = True,
) -> GraphRunner[StateT, DepsT, T]:
) -> GraphRun[StateT, DepsT, T]:
"""Run the graph from a starting node until it ends.
Args:
Expand Down Expand Up @@ -170,7 +170,7 @@ async def main():
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

return GraphRunner[StateT, DepsT, T](
return GraphRun[StateT, DepsT, T](
self, start_node, history=[], state=state, deps=deps, auto_instrument=self._auto_instrument
)

Expand All @@ -181,7 +181,7 @@ def run_sync(
state: StateT = None,
deps: DepsT = None,
infer_name: bool = True,
) -> tuple[T, list[HistoryStep[StateT, T]]]:
) -> GraphRun[StateT, DepsT, T]:
"""Run the graph synchronously.
This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
Expand Down Expand Up @@ -499,11 +499,10 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
return


class GraphRunner(Generic[StateT, DepsT, RunEndT]):
"""An object that can be awaited to perform a graph run.
class GraphRun(Generic[StateT, DepsT, RunEndT]):
"""A stateful run of a graph.
This object can also be used as a contextmanager to get a handle to a specific graph run,
allowing you to iterate over nodes, and possibly perform modifications to the nodes as they are run.
After being entered, can be used like an async generator to listen to / modify nodes as the run is executed.
"""

def __init__(
Expand All @@ -517,84 +516,25 @@ def __init__(
auto_instrument: bool,
):
self.graph = graph
self.first_node = first_node
self.history = history
self.state = state
self.deps = deps

self._run: GraphRun[StateT, DepsT, RunEndT] | None = None

self._auto_instrument = auto_instrument
self._span: LogfireSpan | None = None

@property
def run(self) -> GraphRun[StateT, DepsT, RunEndT]:
if self._run is None:
raise exceptions.GraphRuntimeError('GraphRunner has not been awaited yet.')
return self._run

def __await__(self) -> Generator[Any, Any, tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]]:
"""Run the graph until it ends, and return the final result."""

async def _run() -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]:
async with self as run:
self._run = run
async for _next_node in run:
pass

return run.final_result, run.history

return _run().__await__()

async def __aenter__(self) -> GraphRun[StateT, DepsT, RunEndT]:
if self._run is not None:
raise exceptions.GraphRuntimeError('A GraphRunner can only start a GraphRun once.')

if self._auto_instrument:
self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
self._span.__enter__()

self._run = run = GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps)
return run

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._span is not None:
self._span.__exit__(exc_type, exc_val, exc_tb)
self._span = None # make it more obvious if you try to use it after exiting


class GraphRun(Generic[StateT, DepsT, RunEndT]):
"""A stateful run of a graph.
Can be used like an async generator to listen to / modify nodes as the run is executed.
"""

def __init__(
self,
graph: Graph[StateT, DepsT, RunEndT],
next_node: BaseNode[StateT, DepsT, RunEndT],
*,
history: list[HistoryStep[StateT, RunEndT]],
state: StateT,
deps: DepsT,
):
self.graph = graph
self.next_node = next_node
self.history = history
self.state = state
self.deps = deps

self._final_result: End[RunEndT] | None = None
self._next_node = first_node
self._started: bool = False
self._result: End[RunEndT] | None = None
self._span: LogfireSpan | None = None

@property
def is_ended(self):
return self._final_result is not None
return self._result is not None

@property
def final_result(self) -> RunEndT:
if self._final_result is None:
def result(self) -> RunEndT:
if self._result is None:
raise exceptions.GraphRuntimeError('GraphRun has not ended yet.')
return self._final_result.data
return self._result.data

async def next(
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
Expand All @@ -607,16 +547,48 @@ async def next(
next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False)

if isinstance(next_node, End):
self._final_result = next_node
self._result = next_node
else:
self.next_node = next_node
self._next_node = next_node
return next_node

def __await__(self) -> Generator[Any, Any, typing_extensions.Self]:
"""Run the graph until it ends, and return the final result."""

async def _run() -> typing_extensions.Self:
with self:
async for _next_node in self:
pass

return self

return _run().__await__()

def __enter__(self) -> typing_extensions.Self:
if self._started:
raise exceptions.GraphRuntimeError('A GraphRun can only be started once.')

if self._auto_instrument:
self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
self._span.__enter__()

self._started = True
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._span is not None:
self._span.__exit__(exc_type, exc_val, exc_tb)
self._span = None # make it more obvious if you try to use it after exiting

def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]:
return self

async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
"""Use the last returned node as the input to `Graph.next`."""
if self._final_result:
if self._result:
raise StopAsyncIteration
return await self.next(self.next_node)
if not self._started:
raise exceptions.GraphRuntimeError(
'You must enter the GraphRun as a contextmanager before you can iterate over it.'
)
return await self.next(self._next_node)
20 changes: 10 additions & 10 deletions tests/graph/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: #
assert my_graph.name is None
assert my_graph._get_state_type() is type(None)
assert my_graph._get_run_end_type() is int
result, history = await my_graph.run(Float2String(3.14))
graph_run = await my_graph.run(Float2String(3.14))
# len('3.14') * 2 == 8
assert result == 8
assert graph_run.result == 8
assert my_graph.name == 'my_graph'
assert history == snapshot(
assert graph_run.history == snapshot(
[
NodeStep(
state=None,
Expand All @@ -84,10 +84,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: #
EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)),
]
)
result, history = await my_graph.run(Float2String(3.14159))
graph_run = await my_graph.run(Float2String(3.14159))
# len('3.14159') == 7, 21 * 2 == 42
assert result == 42
assert history == snapshot(
assert graph_run.result == 42
assert graph_run.history == snapshot(
[
NodeStep(
state=None,
Expand Down Expand Up @@ -122,7 +122,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: #
EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)),
]
)
assert [e.data_snapshot() for e in history] == snapshot(
assert [e.data_snapshot() for e in graph_run.history] == snapshot(
[
Float2String(input_data=3.14159),
String2Length(input_data='3.14159'),
Expand Down Expand Up @@ -320,10 +320,10 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]:
return End(123)

g = Graph(nodes=(Foo, Bar))
result, history = await g.run(Foo(), deps=Deps(1, 2))
graph_run = await g.run(Foo(), deps=Deps(1, 2))

assert result == 123
assert history == snapshot(
assert graph_run.result == 123
assert graph_run.history == snapshot(
[
NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
Expand Down
10 changes: 5 additions & 5 deletions tests/graph/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]:
],
)
async def test_dump_load_history(graph: Graph[MyState, None, int]):
result, history = await graph.run(Foo(), state=MyState(1, ''))
assert result == snapshot(4)
assert history == snapshot(
graph_run = await graph.run(Foo(), state=MyState(1, ''))
assert graph_run.result == snapshot(4)
assert graph_run.history == snapshot(
[
NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
EndStep(result=End(4), ts=IsNow(tz=timezone.utc)),
]
)
history_json = graph.dump_history(history)
history_json = graph.dump_history(graph_run.history)
assert json.loads(history_json) == snapshot(
[
{
Expand All @@ -76,7 +76,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]):
]
)
history_loaded = graph.load_history(history_json)
assert history == history_loaded
assert graph_run.history == history_loaded

custom_history = [
{
Expand Down
6 changes: 3 additions & 3 deletions tests/graph/test_mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg


async def test_run_graph():
result, history = await graph1.run(Foo())
assert result is None
assert history == snapshot(
graph_run = await graph1.run(Foo())
assert graph_run.result is None
assert graph_run.history == snapshot(
[
NodeStep(
state=None,
Expand Down
6 changes: 3 additions & 3 deletions tests/graph/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]:
assert graph._get_state_type() is MyState
assert graph._get_run_end_type() is str
state = MyState(1, '')
result, history = await graph.run(Foo(), state=state)
assert result == snapshot('x=2 y=y')
assert history == snapshot(
graph_run = await graph.run(Foo(), state=state)
assert graph_run.result == snapshot('x=2 y=y')
assert graph_run.history == snapshot(
[
NodeStep(
state=MyState(x=2, y=''),
Expand Down
6 changes: 3 additions & 3 deletions tests/typed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,6 @@ def run_g5() -> None:
g5.run_sync(A()) # pyright: ignore[reportArgumentType]
g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType]
g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType]
answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y'))
assert_type(answer, int)
assert_type(history, list[HistoryStep[MyState, int]])
graph_run = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y'))
assert_type(graph_run.result, int)
assert_type(graph_run.history, list[HistoryStep[MyState, int]])

0 comments on commit 39a6009

Please sign in to comment.