diff --git a/docs/agents.md b/docs/agents.md index 19da9f3a2..d65548da5 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -327,7 +327,7 @@ Running `mypy` on this will give the following output: ```bash ➤ uv run mypy type_mistakes.py -type_mistakes.py:18: error: Argument 1 to "system_prompt" of "Agent" has incompatible type "Callable[[RunContext[str]], str]"; expected "Callable[[RunContext[User]], str]" [arg-type] +type_mistakes.py:18: error: Argument 1 to "system_prompt" of "Agent" has incompatible type "Callable[[RunContext[str]], str | None]"; expected "Callable[[RunContext[User]], str | None]" [arg-type] type_mistakes.py:28: error: Argument 1 to "foobar" has incompatible type "bool"; expected "bytes" [arg-type] Found 2 errors in 1 file (checked 1 source file) ``` @@ -344,6 +344,7 @@ Generally, system prompts fall into two categories: 2. **Dynamic system prompts**: These depend in some way on context that isn't known until runtime, and should be defined via functions decorated with [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt]. You can add both to a single agent; they're appended in the order they're defined at runtime. +If a dynamic system prompt function returns `None`, or any empty value, its prompt part won't be added to the messages. Here's an example using both types of system prompts: diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 60a5b3f97..3afbaa14d 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -172,15 +172,18 @@ async def _reevaluate_dynamic_prompts( # Look up the runner by its ref if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref): updated_part_content = await runner.run(run_context) - msg.parts[i] = _messages.SystemPromptPart( - updated_part_content, dynamic_ref=part.dynamic_ref - ) + if updated_part_content: + msg.parts[i] = _messages.SystemPromptPart( + updated_part_content, dynamic_ref=part.dynamic_ref + ) async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]: """Build the initial messages for the conversation.""" messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts] for sys_prompt_runner in self.system_prompt_functions: prompt = await sys_prompt_runner.run(run_context) + if not prompt: + continue if sys_prompt_runner.dynamic: messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__)) else: diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index aca308001..2520aff30 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -20,15 +20,15 @@ def __post_init__(self): self._takes_ctx = len(inspect.signature(self.function).parameters) > 0 self._is_async = inspect.iscoroutinefunction(self.function) - async def run(self, run_context: RunContext[AgentDepsT]) -> str: + async def run(self, run_context: RunContext[AgentDepsT]) -> str | None: if self._takes_ctx: args = (run_context,) else: args = () if self._is_async: - function = cast(Callable[[Any], Awaitable[str]], self.function) + function = cast(Callable[[Any], Awaitable['str | None']], self.function) return await function(*args) else: - function = cast(Callable[[Any], str], self.function) + function = cast(Callable[[Any], 'str | None'], self.function) return await _utils.run_in_executor(function, *args) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3501833d2..5c5e98700 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -645,19 +645,27 @@ def override( @overload def system_prompt( - self, func: Callable[[RunContext[AgentDepsT]], str], / - ) -> Callable[[RunContext[AgentDepsT]], str]: ... + self, + func: Callable[[RunContext[AgentDepsT]], str | None], + /, + ) -> Callable[[RunContext[AgentDepsT]], str | None]: ... @overload def system_prompt( - self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], / - ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ... + self, + func: Callable[[RunContext[AgentDepsT]], Awaitable[str | None]], + /, + ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str | None]]: ... @overload - def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ... + def system_prompt(self, func: Callable[[], str | None], /) -> Callable[[], str | None]: ... @overload - def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ... + def system_prompt( + self, + func: Callable[[], Awaitable[str | None]], + /, + ) -> Callable[[], Awaitable[str | None]]: ... @overload def system_prompt( diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 0b9a3660e..945c7dcf2 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -72,10 +72,10 @@ def replace_with( """Retrieval function param spec.""" SystemPromptFunc = Union[ - Callable[[RunContext[AgentDepsT]], str], - Callable[[RunContext[AgentDepsT]], Awaitable[str]], - Callable[[], str], - Callable[[], Awaitable[str]], + Callable[[RunContext[AgentDepsT]], Union[str, None]], + Callable[[RunContext[AgentDepsT]], Awaitable[Union[str, None]]], + Callable[[], Union[str, None]], + Callable[[], Awaitable[Union[str, None]]], ] """A function that may or maybe not take `RunContext` as an argument, and may or may not be async. diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index f30a5c14b..d0ec1d08b 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -3,6 +3,7 @@ from collections.abc import AsyncIterator from dataclasses import asdict from datetime import timezone +from typing import Union import pydantic_core import pytest @@ -312,10 +313,20 @@ def quz(x) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingPa @agent_all.system_prompt -def spam() -> str: +def spam() -> Union[str, None]: return 'foobar' +@agent_all.system_prompt +def empty1() -> Union[str, None]: + return None + + +@agent_all.system_prompt +def empty2() -> Union[str, None]: + return '' + + def test_register_all(): async def f(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse( diff --git a/tests/test_agent.py b/tests/test_agent.py index 1a1091959..2ba219945 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1330,9 +1330,13 @@ def test_dynamic_false_no_reevaluate(): dynamic_value = 'A' @agent.system_prompt - async def func() -> str: + async def func() -> Union[str, None]: return dynamic_value + @agent.system_prompt + async def empty_func() -> Union[str, None]: + return None + res = agent.run_sync('Hello') assert res.all_messages() == snapshot( @@ -1405,6 +1409,10 @@ def test_dynamic_true_reevaluate_system_prompt(): async def func(): return dynamic_value + @agent.system_prompt(dynamic=True) + async def empty_func(): + return None + res = agent.run_sync('Hello') assert res.all_messages() == snapshot( diff --git a/tests/typed_agent.py b/tests/typed_agent.py index fdf9f1a25..fdf158195 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -23,7 +23,7 @@ class MyDeps: @typed_agent.system_prompt -async def system_prompt_ok1(ctx: RunContext[MyDeps]) -> str: +async def system_prompt_ok1(ctx: RunContext[MyDeps]) -> Union[str, None]: return f'{ctx.deps}' @@ -32,9 +32,15 @@ def system_prompt_ok2() -> str: return 'foobar' +@typed_agent.system_prompt +def system_prompt_ok3() -> Union[str, None]: + return None + + # we have overloads for every possible signature of system_prompt, so the type of decorated functions is correct -assert_type(system_prompt_ok1, Callable[[RunContext[MyDeps]], Awaitable[str]]) -assert_type(system_prompt_ok2, Callable[[], str]) +assert_type(system_prompt_ok1, Callable[[RunContext[MyDeps]], Awaitable[Union[str, None]]]) +assert_type(system_prompt_ok2, Callable[[], Union[str, None]]) +assert_type(system_prompt_ok3, Callable[[], Union[str, None]]) @contextmanager