
Graph Support #528

Draft · wants to merge 45 commits into main

Conversation

samuelcolvin
Member

@samuelcolvin samuelcolvin commented Dec 23, 2024

TODO:

  • nodes via decorator impossible without HKT
  • infer graph name
  • tests
  • docs
  • examples

This is a work in progress, it's the result of a lot of discussion with @dmontagu.

The idea is to provide a graph/state machine library to use with PydanticAI that is as type-safe as possible in python.

NOTE: the vast majority of multi-agent examples I've seen do not need a graph or state machine, and would be more complex to write and understand if written using one. You should only use this functionality if:

  1. you understand how to use agents as tools
  2. you've tried using standard programming techniques to link agents
  3. after that, you're still sure you need a graph/state machine library

In particular, this means we define edges (which nodes in a graph can be reached from any given node) using type annotations, rather than some separate set_edges mechanism.

To do this we define nodes as types (that must inherit from BaseNode). To route the graph to (say) NodeB, a node's run method returns an instance of NodeB, which holds the input data for NodeB. Similarly, to end a run, nodes return End.

We inspect the return annotation of the run method on nodes to build the graph.

Here's a minimal example:

Code
from __future__ import annotations as _annotations

from dataclasses import dataclass

from pydantic_ai_graph import Graph, BaseNode, End


@dataclass
class NodeA(BaseNode[None]):
    apple: int

    async def run(self, ctx) -> NodeB:
        return NodeB(self.apple / 2)


@dataclass
class NodeB(BaseNode[None]):
    banana: float

    async def run(self, ctx) -> NodeC:
        return NodeC((int(self.banana), int(self.banana) + 4))


@dataclass
class NodeC(BaseNode[None, float]):
    pair: tuple[int, int]

    async def run(self, ctx) -> NodeA | End[float]:
        v1, v2 = self.pair
        if v1 + v2 > 10:
            return End((v1 + v2) / 3)
        else:
            return NodeA(v1 + v2)


graph = Graph(nodes=(NodeA, NodeB, NodeC))
runner_node_a = graph.get_runner(NodeA)
print(runner_node_a.mermaid_code())


async def main():
    result, history = await runner_node_a.run(None, 8)
    print(result)
    print('\n'.join(map(str, history)))
    result, history = await runner_node_a.run(None, 7)
    print(result)
    print('\n'.join(map(str, history)))


if __name__ == '__main__':
    import asyncio
    asyncio.run(main())

The mermaid chart printed in the example looks like this:

graph TD
  START --> NodeA
  NodeA --> NodeB
  NodeB --> NodeC
  NodeC --> NodeA
  NodeC --> END

And the rest of the output is:

4.0
NodeA -> NodeB
NodeB -> NodeC
NodeC -> END
4.666666666666667
NodeA -> NodeB
NodeB -> NodeC
NodeC -> NodeA
NodeA -> NodeB
NodeB -> NodeC
NodeC -> END

The graph library is completely independent of LLM use cases, but can relatively easily be used with pydantic-ai's Agent, see the examples/pydantic_ai_examples/email_extract_graph.py example.


cloudflare-workers-and-pages bot commented Dec 23, 2024

Deploying pydantic-ai with Cloudflare Pages

Latest commit: 5717bd5
Status: ✅  Deploy successful!
Preview URL: https://f58aa41d.pydantic-ai.pages.dev
Branch Preview URL: https://graph.pydantic-ai.pages.dev



@brettkromkamp

brettkromkamp commented Jan 7, 2025

In my opinion this feature is critical for adoption of the PydanticAI framework. Any timeframe when this will land in main, @samuelcolvin?

I very much like the approach of using type annotations and returns instead of a separate set_edge mechanism (described above)... really gives me a LlamaIndex Workflows vibe as opposed to the more complex LangGraph approach.

@samuelcolvin
Member Author

@brettkromkamp we'll do our best to get something merged and released this week.

@samuelcolvin samuelcolvin force-pushed the graph branch 2 times, most recently from c52885b to e7b3949 on January 7, 2025 21:41
@samuelcolvin
Member Author

samuelcolvin commented Jan 9, 2025

~~I've added support for Interrupt to interrupt a run and continue it from the right place.~~

Update: I've removed Interrupt and instead added a next() method which provides the same functionality, but is more flexible and easier to understand.

Here's an example using it:

from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Annotated

import logfire

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_graph import Graph, BaseNode, End, GraphContext, AbstractState, Edge

logfire.configure()
ask_agent = Agent('openai:gpt-4o', result_type=str)


@dataclass
class QuestionState(AbstractState):
    ask_agent_messages: list[ModelMessage] | None = None

    def serialize(self) -> bytes | None:
        raise NotImplementedError('TODO')


@dataclass
class Ask(BaseNode[QuestionState]):
    """Generate a question to ask the user.

    Uses the GPT-4o model to generate a question.
    """
    async def run(self, ctx: GraphContext[QuestionState]) -> Annotated[Answer, Edge(label='ask the question')]:
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.', message_history=ctx.state.ask_agent_messages
        )
        if ctx.state.ask_agent_messages is None:
            ctx.state.ask_agent_messages = []
        ctx.state.ask_agent_messages += result.all_messages()
        return Answer(result.data)


@dataclass
class Answer(BaseNode[QuestionState]):
    question: str
    answer: str | None = None

    async def run(self, ctx: GraphContext[QuestionState]) -> Annotated[Evaluate, Edge(label='answer the question')]:
        assert self.answer is not None
        return Evaluate(self.question, self.answer)


@dataclass
class EvaluationResult:
    correct: bool
    comment: str


evaluate_agent = Agent(
    'openai:gpt-4o',
    result_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
    result_tool_name='evaluation',
)


@dataclass
class Evaluate(BaseNode[QuestionState]):
    question: str
    answer: str

    async def run(self, ctx: GraphContext[QuestionState]) -> Congratulate | Castigate:
        result = await evaluate_agent.run(format_as_xml({'question': self.question, 'answer': self.answer}))
        if result.data.correct:
            return Congratulate(result.data.comment)
        else:
            return Castigate(result.data.comment)


@dataclass
class Congratulate(BaseNode[QuestionState, None]):
    comment: str

    async def run(self, ctx: GraphContext[QuestionState]) -> End[None]:
        print(f'Correct answer! {self.comment}')
        return End(None)


@dataclass
class Castigate(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        return Ask()


graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Castigate))



print(graph.mermaid_code(start_node=Ask))
graph.mermaid_save('questions_graph.svg', start_node=Ask)


async def main():
    node = Ask()
    state = QuestionState()
    history = []
    with logfire.span('run questions graph'):
        while True:
            node = await graph.next(state, node, history)
            if isinstance(node, End):
                print('\n'.join(e.summary() for e in history))
                break
            elif isinstance(node, Answer):
                node.answer = input(f'{node.question} ')
            # otherwise just continue


if __name__ == '__main__':
    import asyncio

    asyncio.run(main())

Which has the following graph:

stateDiagram-v2
  [*] --> Ask
  Ask --> Answer: ask the question
  note right of Ask
    Generate a question to ask the user.
    Uses the GPT-4o model to generate a question.
  end note
  Answer --> Evaluate: answer the question
  Evaluate --> Congratulate
  Evaluate --> Castigate
  Congratulate --> [*]
  Castigate --> Ask

You'll see that the Answer node is square rather than rounded, to identify a point where the graph may restart after interruption. (This is no longer included; I don't think it matters.)

@brettkromkamp

Just wondering if the new interrupt mechanism can be used for HITL-purposes? Or, is it more for retrying steps in case of failures. It could also be a general mechanism for all kinds of purposes. I'll take a closer look... definitely exciting to see how this feature is developing, though.

@dmontagu
Contributor

dmontagu commented Jan 9, 2025

Just wondering if the new interrupt mechanism can be used for HITL-purposes? Or, is it more for retrying steps in case of failures. It could also be a general mechanism for all kinds of purposes. I'll take a closer look... definitely exciting to see how this feature is developing, though.

It is definitely explicitly and primarily intended for facilitating HITL; if it's useful for other purposes then of course that's great but most of the discussion we've been having about the feature has been oriented around how to use it for human feedback.

@@ -25,15 +25,19 @@ def serialize(self) -> bytes | None:
"""Serialize the state object."""
raise NotImplementedError
Contributor

@dmontagu dmontagu Jan 9, 2025


I wonder if we can eliminate this AbstractState type by moving the serialization and/or copying logic to be kwargs of the graph, and if not provided, use copy.deepcopy (or noop if None as you've done) for copying, and pydantic_core.to_json for serialization. That would let you use a typical basemodel/dataclass/typeddict as state with minimal boilerplate.

(Because the graph is aware of the state type, we can still use type hints on the kwargs like serializer: Callable[[StateT], bytes] to get the same type safety you'd get from a method.)
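A rough sketch of that suggestion, with stdlib `json.dumps` standing in for `pydantic_core.to_json` and a toy `Graph` class (not the real `pydantic_graph.Graph`; all names here are illustrative):

```python
import copy
import json
from dataclasses import asdict, dataclass
from typing import Callable, Generic, Optional, TypeVar

StateT = TypeVar("StateT")


@dataclass
class QuizState:
    """An ordinary dataclass used as state, with no AbstractState base."""
    score: int = 0


class Graph(Generic[StateT]):
    """Toy graph accepting copy/serialization logic as kwargs."""

    def __init__(
        self,
        *,
        copier: Optional[Callable[[StateT], StateT]] = None,
        serializer: Optional[Callable[[StateT], bytes]] = None,
    ):
        # Fall back to deepcopy and JSON when not provided, as proposed
        self.copier = copier or copy.deepcopy
        self.serializer = serializer or (lambda s: json.dumps(asdict(s)).encode())

    def snapshot(self, state: StateT) -> StateT:
        return self.copier(state)

    def serialize(self, state: StateT) -> bytes:
        return self.serializer(state)
```

With defaults like these, a plain dataclass works as state with no boilerplate, while a custom `serializer=` kwarg keeps the type safety a method would give.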

@samuelcolvin
Member Author

Just wondering if the new interrupt mechanism can be used for HITL-purposes? Or, is it more for retrying steps in case of failures. It could also be a general mechanism for all kinds of purposes. I'll take a closer look... definitely exciting to see how this feature is developing, though.

@brettkromkamp I've removed Interrupt and replaced with what we think is a better API, I've updated my example above.

@ME-Msc
Contributor

ME-Msc commented Jan 10, 2025

Hi, team. I have some questions about graph support.

  1. It seems that a node's multiple out-edges refer to different transition conditions. Will it be possible to see the condition on the mermaid graph?
  2. Is there any support for running the subsequent nodes in parallel (mentioned in Functionality to Define Multi-Agent Graphs and Workflows #529 by @izzyacademy ) ?

Contributor

@hyperlint-ai hyperlint-ai bot left a comment


The style guide flagged several spelling errors that seemed like false positives. We skipped posting inline suggestions for the following words:

  • Pydantic

@samuelcolvin
Member Author

  1. It seems that a node's multiple out-edges refer to different transition conditions. Will it be possible to see the condition on the mermaid graph?

Hi @ME-Msc, I'm not exactly sure what you mean here?

I'm going to provide a way to label an edge, but you won't be able to "see the logic" that leads to an edge being followed, as that's just procedural python code.

  1. Is there any support for running the subsequent nodes in parallel (mentioned in Functionality to Define Multi-Agent Graphs and Workflows #529 by @izzyacademy ) ?

Not yet, we might add it in future.

@izzyacademy

@samuelcolvin I think I have an idea of what he is asking. I had similar thoughts earlier.

It seems @ME-Msc is looking for a mechanism to annotate (within the docstring) the pydantic_graph.BaseNode.run() method with a small note/comment to indicate what condition causes this node to route to the next node returned by this node. It looks like we could parse the docstring for a special tag or something from BaseNode.run() to get a list of conditions and then inject this into the mermaid code generated so that it shows up in the graph image generated.

@ME-Msc I think the goal of the project is to avoid fancy syntax that does not give you visibility into how the parallel nodes/tasks are run. What I would recommend is to dedicate a node that aggregates all the parallel tasks, and spin up async tasks in that node using regular Python code, so that you have 100% visibility and control and can see the exceptions, cancellations, etc. without getting stressed when things deviate from the happy path. I hope this helps.

I am working on an example for this because I think many users will have similar questions/needs, based on how they are using other frameworks with custom syntax for routing to parallel nodes in graph transitions.


@izzyacademy izzyacademy left a comment


Thank you for the hard work on implementing this. A lot of folks in the community have been waiting for this.

pydantic_graph/pydantic_graph/graph.py
@ME-Msc
Contributor

ME-Msc commented Jan 11, 2025

@samuelcolvin Thanks for your reply. I'm sure the new example code has solved my first question. What I want is just the annotation of the edge.

async def run(self, ctx: GraphContext[QuestionState]) -> Annotated[Evaluate, Edge(label='answer the question')]:
   assert self.answer is not None
   return Evaluate(self.question, self.answer)

@ME-Msc
Contributor

ME-Msc commented Jan 11, 2025

@izzyacademy Thank you for your explanation. As for the second question, I am afraid that aggregating all the parallel tasks in one node with only regular Python code (asyncio, multi-threading, etc.) cannot solve it elegantly. There are 3 questions I would like to suggest you consider in the example.

  1. What if the parallel tasks are executed by different PydanticAI agents?
  2. How should you manage it if each parallel task has a different number of steps?
  3. How do you handle the situation where one step of a parallel task needs to jump to another node (or another step of another node) outside of this node?

I'm looking forward to your example very much!

@samuelcolvin
Member Author

@izzyacademy Thank you for your explanation. As for the second question, I am afraid that aggregating all the parallel tasks in one node only by regular Python code (asyncio, multi-threads, etc.) cannot fix it elegantly. There are 3 questions I would like to suggest you consider in the example.

We might add support for running multiple nodes at once in future, but I think what we have now is already pretty powerful.

  1. What if the parallel tasks are executed by different pydanticAI-agents?

That should be fine, you can run multiple agents simultaneously just as easily as running the same agent simultaneously.

  1. How should you manage it if each parallel task has a different number of steps?

You can use all the existing tools in the Python toolbox to run multiple tasks in parallel.

  1. How do you handle the situation where one step of a parallel task needs to jump to another node(or another step of another node) outside of this node?

That won't work yet.

@izzyacademy

@ME-Msc Thanks for your follow up. @samuelcolvin thanks for responding.

In my personal opinion, it is always best to keep things simple in the design. It makes my life easier as an architect/engineer in production.

While a lot of frameworks may try to give you cool syntax to run nodes in your graph concurrently, in the end what they end up doing is using thread pools to run the CPU-bound tasks or asyncio to run the I/O-bound tasks. The only downside is that you do not have much visibility into what is happening or how it is implemented, and in production that is generally not a good idea.

If you also take a close look at the scenarios, they all gather at some point, waiting for every node in the parallel execution to complete or fail before moving on, and I don't see hops or jumps in a well-designed use case.

I believe that the current mechanism of implementing state transitions with pydantic-graph still allows you to run multiple nodes in parallel, even if you have multiple agents running concurrently or only included when certain conditions are true.

So my approach would be to design your graph transitions such that, if you need to run multiple "nodes" in parallel (a fixed or variable list), you compile a dynamic/static list of coroutines, gather them with asyncio until they complete or fail, and then decide what to do next from there (at the gathering-point node). This also allows you to cancel any task you no longer need, with full control and visibility.

If they are CPU-bound, you can use threads or other mechanisms to execute these "nodes" within the gathering-point node and then proceed from there.

My examples will illustrate this.

I have travel agent scenarios where my agent(s) will need to (after my airline or train reservation is booked to a distant city) make inquiries about availability for car rental, hotels, spas and city tours concurrently.

This variable list of tasks varies depending on the customer profile/preferences.

The customer may only need one or all of the searches.

All these searches must be done in parallel, and once we know their results (pass or fail) we can proceed to charging the customer's payment method for the total price and then sending out a confirmation message.

My point, at this time, is to keep things simple, avoid complex/fancy syntax and have full visibility/control into the inner workings of your code and flow so that you don't have to shave your head while maintaining the app in production :)

I hope this helps.

Thanks for the questions.

I am very excited about this addition to the project and it will definitely make a lot of lives easier when it lands in main and is released to the community.

Thanks Pydantic Team for an amazing effort in the design and implementation of this new capability.

Contributor

@hyperlint-ai hyperlint-ai bot left a comment


The style guide flagged several spelling errors that seemed like false positives. We skipped posting inline suggestions for the following words:

  • FSMs

Contributor

@hyperlint-ai hyperlint-ai bot left a comment


9 files reviewed, 1 total issue(s) found.

docs/graph.md
10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`.
11. If the balance is insufficient, go to `InsertCoin` to prompt the user to insert more coins.
12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again.
13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagramss](#mermaid-diagrams) are displayed.
Contributor


Suggested change
13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagramss](#mermaid-diagrams) are displayed.
13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed.

Issues:

  • Style Guide - (Spelling-error) Did you really mean 'diagramss'?

Fix Explanation:

Corrected the spelling error from 'diagramss' to 'diagrams'.

Contributor

@hyperlint-ai hyperlint-ai bot left a comment


9 files reviewed, 2 total issue(s) found.

The style guide flagged several spelling errors that seemed like false positives. We skipped posting inline suggestions for the following words:

  • [Aa]sync
  • [Dd]ataclass

Contributor

@hyperlint-ai hyperlint-ai bot left a comment


11 files reviewed, 1 total issue(s) found.
