Skip to content

Commit

Permalink
extracted FunctionSchema to public module tools
Browse files Browse the repository at this point in the history
  • Loading branch information
jonchun committed Feb 13, 2025
1 parent 17df4fc commit 0a3a97d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 27 deletions.
32 changes: 10 additions & 22 deletions pydantic_ai_slim/pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations as _annotations

from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
from typing import TYPE_CHECKING, Any, Callable, cast, get_origin

from pydantic import ConfigDict
from pydantic._internal import _decorators, _generate_schema, _typing_extra
Expand All @@ -20,24 +20,12 @@
from ._utils import check_object_json_schema, is_model_like

if TYPE_CHECKING:
from .tools import DocstringFormat, ObjectJsonSchema
from .tools import DocstringFormat, FunctionSchema


__all__ = ('function_schema',)


class FunctionSchema(TypedDict):
"""Internal information about a function schema."""

description: str
validator: SchemaValidator
json_schema: ObjectJsonSchema
# if not None, the function takes a single by that name (besides potentially `info`)
single_arg_name: str | None
positional_fields: list[str]
var_positional_field: str | None


def function_schema( # noqa: C901
function: Callable[..., Any],
takes_ctx: bool,
Expand Down Expand Up @@ -161,14 +149,14 @@ def function_schema( # noqa: C901
# and set it on the tool
description = json_schema.pop('description', None)

return FunctionSchema(
description=description,
validator=schema_validator,
json_schema=check_object_json_schema(json_schema),
single_arg_name=single_arg_name,
positional_fields=positional_fields,
var_positional_field=var_positional_field,
)
return {
'description': description,
'validator': schema_validator,
'json_schema': check_object_json_schema(json_schema),
'single_arg_name': single_arg_name,
'positional_fields': positional_fields,
'var_positional_field': var_positional_field,
}


def takes_ctx(function: Callable[..., Any]) -> bool:
Expand Down
15 changes: 14 additions & 1 deletion pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypedDict, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
Expand All @@ -19,6 +19,7 @@
__all__ = (
'AgentDepsT',
'DocstringFormat',
'FunctionSchema',
'RunContext',
'SystemPromptFunc',
'ToolFuncContext',
Expand All @@ -35,6 +36,18 @@
"""Type variable for agent dependencies."""


class FunctionSchema(TypedDict):
"""Internal information about a function schema."""

description: str
validator: SchemaValidator
json_schema: ObjectJsonSchema
# if not None, the function takes a single by that name (besides potentially `info`)
single_arg_name: str | None
positional_fields: list[str]
var_positional_field: str | None


@dataclasses.dataclass
class RunContext(Generic[AgentDepsT]):
"""Information about the current call."""
Expand Down
8 changes: 4 additions & 4 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,10 @@ def plain_tool(x: int) -> int:

def test_init_tool_with_function_schema():
def x_tool(x: int) -> None:
pass
raise NotImplementedError

def y_tool(y: str) -> None:
pass
raise NotImplementedError

y_fs = _pydantic.function_schema(
y_tool, takes_ctx=False, docstring_format='auto', require_parameter_descriptions=False
Expand All @@ -370,10 +370,10 @@ def y_tool(y: str) -> None:

def test_init_tool_ctx_with_function_schema():
def x_tool(ctx: RunContext[int], x: int) -> None:
pass
raise NotImplementedError

def y_tool(ctx: RunContext[int], y: str) -> None:
pass
raise NotImplementedError

y_fs = _pydantic.function_schema(
y_tool, takes_ctx=True, docstring_format='auto', require_parameter_descriptions=False
Expand Down

0 comments on commit 0a3a97d

Please sign in to comment.