Skip to content

Commit

Permalink
Remove state generic from graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Jan 31, 2025
1 parent 6c00d1d commit 48e15fc
Show file tree
Hide file tree
Showing 14 changed files with 322 additions and 457 deletions.
11 changes: 0 additions & 11 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,6 @@ docs-insiders: .docs-insiders-install ## Build the documentation using insiders
docs-serve-insiders: .docs-insiders-install ## Build and serve the documentation using insiders packages
uv run --no-sync mkdocs serve -f mkdocs.insiders.yml

.PHONY: cf-pages-build
cf-pages-build: ## Install uv, install dependencies and build the docs, used on CloudFlare Pages
curl -LsSf https://astral.sh/uv/install.sh | sh
uv python install 3.12
uv sync --python 3.12 --frozen --group docs
uv pip install --reinstall --no-deps \
--extra-index-url https://pydantic:${PPPR_TOKEN}@pppr.pydantic.dev/simple/ \
mkdocs-material mkdocstrings-python
uv pip freeze
uv run --no-sync mkdocs build -f mkdocs.insiders.yml

.PHONY: all
all: format lint typecheck testcov ## Run code formatting, linting, static type checks, and tests with coverage report generation

Expand Down
83 changes: 44 additions & 39 deletions examples/pydantic_ai_examples/question_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import Annotated
from typing import Annotated, Union, cast

import logfire
from devtools import debug
Expand All @@ -22,7 +22,7 @@
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')

ask_agent = Agent('openai:gpt-4o', result_type=str)
ask_agent = Agent('openai:gpt-4o')


@dataclass
Expand All @@ -33,24 +33,27 @@ class QuestionState:


@dataclass
class Ask(BaseNode[QuestionState]):
async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
class Ask(BaseNode):
state: QuestionState

async def run(self, ctx: GraphRunContext) -> Answer:
result = await ask_agent.run(
'Ask a simple question with a single correct answer.',
message_history=ctx.state.ask_agent_messages,
message_history=self.state.ask_agent_messages,
)
ctx.state.ask_agent_messages += result.all_messages()
ctx.state.question = result.data
return Answer()
self.state.ask_agent_messages += result.all_messages()
self.state.question = result.data
return Answer(self.state)


@dataclass
class Answer(BaseNode[QuestionState]):
class Answer(BaseNode):
state: QuestionState
answer: str | None = None

async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
async def run(self, ctx: GraphRunContext) -> Evaluate:
assert self.answer is not None
return Evaluate(self.answer)
return Evaluate(self.state, self.answer)


@dataclass
Expand All @@ -67,61 +70,60 @@ class EvaluationResult:


@dataclass
class Evaluate(BaseNode[QuestionState]):
class Evaluate(BaseNode):
state: QuestionState
answer: str

async def run(
self,
ctx: GraphRunContext[QuestionState],
ctx: GraphRunContext,
) -> Congratulate | Reprimand:
assert ctx.state.question is not None
assert self.state.question is not None
result = await evaluate_agent.run(
format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
message_history=ctx.state.evaluate_agent_messages,
format_as_xml({'question': self.state.question, 'answer': self.answer}),
message_history=self.state.evaluate_agent_messages,
)
ctx.state.evaluate_agent_messages += result.all_messages()
self.state.evaluate_agent_messages += result.all_messages()
if result.data.correct:
return Congratulate(result.data.comment)
return Congratulate(self.state, result.data.comment)
else:
return Reprimand(result.data.comment)
return Reprimand(self.state, result.data.comment)


@dataclass
class Congratulate(BaseNode[QuestionState, None, None]):
class Congratulate(BaseNode[None, None]):
state: QuestionState
comment: str

async def run(
self, ctx: GraphRunContext[QuestionState]
) -> Annotated[End, Edge(label='success')]:
async def run(self, ctx: GraphRunContext) -> Annotated[End, Edge(label='success')]:
print(f'Correct answer! {self.comment}')
return End(None)


@dataclass
class Reprimand(BaseNode[QuestionState]):
class Reprimand(BaseNode):
state: QuestionState
comment: str

async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
async def run(self, ctx: GraphRunContext) -> Ask:
print(f'Comment: {self.comment}')
# > Comment: Vichy is no longer the capital of France.
ctx.state.question = None
return Ask()
self.state.question = None
return Ask(self.state)


question_graph = Graph(
nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand), state_type=QuestionState
)
question_graph = Graph(nodes=(Ask, Answer, Evaluate, Congratulate, Reprimand))


async def run_as_continuous():
state = QuestionState()
node = Ask()
history: list[HistoryStep[QuestionState, None]] = []
node = Ask(state)
history: list[HistoryStep[None]] = []
with logfire.span('run questions graph'):
while True:
node = await question_graph.next(node, history, state=state)
node = await question_graph.next(node, history)
if isinstance(node, End):
debug([e.data_snapshot() for e in history])
debug([e.data for e in history])
break
elif isinstance(node, Answer):
assert state.question
Expand All @@ -140,19 +142,22 @@ async def run_as_cli(answer: str | None):
if history:
last = history[-1]
assert last.kind == 'node', 'expected last step to be a node'
state = last.state
last_node = cast(
Union[Ask, Answer, Evaluate, Congratulate, Reprimand], last.node
)
state = last_node.state
assert answer is not None, 'answer is required to continue from history'
node = Answer(answer)
node = Answer(state, answer)
else:
state = QuestionState()
node = Ask()
node = Ask(state)
debug(state, node)

with logfire.span('run questions graph'):
while True:
node = await question_graph.next(node, history, state=state)
node = await question_graph.next(node, history)
if isinstance(node, End):
debug([e.data_snapshot() for e in history])
debug([e.data for e in history])
print('Finished!')
break
elif isinstance(node, Answer):
Expand Down
Loading

0 comments on commit 48e15fc

Please sign in to comment.