diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 0156bd77b..c57598949 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -30,7 +30,7 @@ permissions: jobs: run_tests: - name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} on ${{ matrix.os }} + name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} | pydantic ${{ matrix.pydantic_version }} on ${{ matrix.os }} timeout-minutes: 15 strategy: matrix: @@ -38,12 +38,26 @@ jobs: python-version: ['3.9', '3.10', '3.11'] test-type: ['not llm'] llm_model: ['openai/gpt-3.5-turbo'] + pydantic_version: ['>=2.4.2'] include: - python-version: '3.9' os: 'ubuntu-latest' test-type: 'llm' llm_model: 'openai/gpt-3.5-turbo' + pydantic_version: '>=2.4.2' + + - python-version: '3.9' + os: 'ubuntu-latest' + test-type: 'llm' + llm_model: 'openai/gpt-3.5-turbo' + pydantic_version: '<2' + + - python-version: '3.9' + os: 'ubuntu-latest' + test-type: 'not llm' + llm_model: 'openai/gpt-3.5-turbo' + pydantic_version: '<2' runs-on: ${{ matrix.os }} @@ -60,6 +74,11 @@ jobs: cache: "pip" - name: Install Marvin run: pip install ".[tests]" + + - name: Install pydantic + run: pip install "pydantic${{ matrix.pydantic_version }}" + + - name: Run ${{ matrix.test-type }} tests (${{ matrix.llm_model }}) run: pytest -vv -m "${{ matrix.test-type }}" env: diff --git a/pyproject.toml b/pyproject.toml index f5405dbb6..3354339a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dev = [ "mkdocstrings[python]~=0.22", "pdbpp~=0.10", "pre-commit>=2.21,<4.0", + "pydantic", "ruff", ] tests = [ @@ -48,8 +49,8 @@ tests = [ "pytest-rerunfailures>=10,<13", "pytest-sugar~=0.9", "pytest~=7.3.1", - "pydantic-settings>=2.0.0", ] + framework = [ "aiosqlite>=0.19.0", "alembic>=1.11.1", diff --git a/src/marvin/_compat.py b/src/marvin/_compat.py index c09e9ba7c..98066d76b 100644 --- a/src/marvin/_compat.py +++ b/src/marvin/_compat.py @@ -10,36 +10,45 @@ get_origin, ) -from pydantic import BaseModel, create_model from pydantic.version import VERSION as PYDANTIC_VERSION -_ModelT = TypeVar("_ModelT", bound="BaseModel") - PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if PYDANTIC_V2: from pydantic.v1 import ( BaseSettings, - Field, + PrivateAttr, SecretStr, + ValidationError, validate_arguments, ) SettingsConfigDict = BaseSettings.Config - from pydantic import field_validator # noqa # type: ignore + from pydantic import ( + BaseModel, + Field, + create_model, + field_validator, + ) else: from pydantic import ( # noqa # type: ignore BaseSettings, + BaseModel, + create_model, Field, SecretStr, validate_arguments, validator as field_validator, + ValidationError, + PrivateAttr, ) SettingsConfigDict = BaseSettings.Config +_ModelT = TypeVar("_ModelT", bound=BaseModel) + def model_dump(model: _ModelT, **kwargs: Any) -> dict[str, Any]: if PYDANTIC_V2 and hasattr(model, "model_dump"): diff --git a/src/marvin/components/ai_application.py b/src/marvin/components/ai_application.py index 4cdbedebc..bb495e959 100644 --- a/src/marvin/components/ai_application.py +++ b/src/marvin/components/ai_application.py @@ -3,10 +3,9 @@ from typing import Any, Callable, Optional, Union from jsonpatch import JsonPatch -from pydantic import BaseModel, Field, validator import marvin -from marvin._compat import PYDANTIC_V2, model_dump +from marvin._compat import PYDANTIC_V2, BaseModel, Field, field_validator, model_dump from marvin.core.ChatCompletion.providers.openai import get_context_size from marvin.openai import ChatCompletion from marvin.prompts import library as prompt_library @@ -218,17 +217,19 @@ class AIApplication(LoggerMixin, MarvinBaseModel): state_enabled: bool = True plan_enabled: bool = True - @validator("description") + @field_validator("description") def validate_description(cls, v): return inspect.cleandoc(v) - @validator("additional_prompts") + @field_validator("additional_prompts") def validate_additional_prompts(cls, v): if v is None: v = [] return v - @validator("tools", pre=True, always=True) + @field_validator( + "tools", **(dict(pre=True, always=True) if not PYDANTIC_V2 else {}) + ) def validate_tools(cls, v): if v is None: v = [] @@ -245,7 +246,7 @@ def validate_tools(cls, v): raise ValueError(f"Tool {tool} is not a `Tool` or callable.") return tools - @validator("name", always=True) + @field_validator("name") def validate_name(cls, v): if v is None: v = cls.__name__ @@ -351,8 +352,8 @@ class JSONPatchModel( op: str path: str - value: Union[str, float, int, bool, list, dict] = None - from_: str = Field(None, alias="from") + value: Union[str, float, int, bool, list, dict, None] = None + from_: Optional[str] = Field(None, alias="from") class UpdateState(Tool): diff --git a/src/marvin/components/ai_classifier.py b/src/marvin/components/ai_classifier.py index 98e0a69d2..88260b7a4 100644 --- a/src/marvin/components/ai_classifier.py +++ b/src/marvin/components/ai_classifier.py @@ -4,9 +4,9 @@ from functools import partial from typing import Any, Callable, Literal, Optional, TypeVar -from pydantic import BaseModel, Field from typing_extensions import ParamSpec, Self +from marvin._compat import BaseModel, Field from marvin.core.ChatCompletion import ChatCompletion from marvin.core.ChatCompletion.abstract import AbstractChatCompletion from marvin.prompts import Prompt, prompt_fn diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index dd1287a21..1a2aca2ff 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -3,9 +3,9 @@ from functools import partial from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, Union -from pydantic import BaseModel, Field from typing_extensions import ParamSpec, Self +from marvin._compat import BaseModel, Field from marvin.core.ChatCompletion import ChatCompletion from marvin.core.ChatCompletion.abstract import AbstractChatCompletion from marvin.prompts import Prompt, prompt_fn @@ -42,7 +42,7 @@ def prompt_wrapper(*args: P.args, **kwargs: P.kwargs) -> None: # type: ignore # {{'def' + ''.join(inspect.getsource(func).split('def')[1:])}} The user will provide function inputs (if any) and you must respond with - the most likely result. + the most likely result, which must be valid, double-quoted JSON. User: The function was called with the following inputs: {% set sig = inspect.signature(func) %} @@ -118,20 +118,29 @@ def call( *args: P.args, **kwargs: P.kwargs, ) -> Any: - return getattr( - self.as_chat_completion(*args, **kwargs).create().to_model(), - self.response_model_field_name or "output", - ) + model_instance = self.as_chat_completion(*args, **kwargs).create().to_model() + response_model_field_name = self.response_model_field_name or "output" + + if not (output := getattr(model_instance, response_model_field_name, None)): + return model_instance + + return output async def acall( self, *args: P.args, **kwargs: P.kwargs, ) -> Any: - return getattr( - (await self.as_chat_completion(*args, **kwargs).acreate()).to_model(), - self.response_model_field_name or "output", - ) + model_instance = ( + await self.as_chat_completion(*args, **kwargs).acreate() + ).to_model() + + response_model_field_name = self.response_model_field_name or "output" + + if not (output := getattr(model_instance, response_model_field_name, None)): + return model_instance + + return output def map(self, *map_args: list[Any], **map_kwargs: list[Any]): """ diff --git a/src/marvin/components/ai_model.py b/src/marvin/components/ai_model.py index 322e2758a..230172c00 100644 --- a/src/marvin/components/ai_model.py +++ b/src/marvin/components/ai_model.py @@ -3,9 +3,9 @@ from functools import partial from typing import Any, Callable, Optional, TypeVar -from pydantic import BaseModel from typing_extensions import ParamSpec, Self +from marvin._compat import BaseModel from marvin.core.ChatCompletion import ChatCompletion from marvin.core.ChatCompletion.abstract import AbstractChatCompletion from marvin.prompts import Prompt, prompt_fn diff --git a/src/marvin/components/ai_model_factory.py b/src/marvin/components/ai_model_factory.py index 6c39359f9..33a6112a3 100644 --- a/src/marvin/components/ai_model_factory.py +++ b/src/marvin/components/ai_model_factory.py @@ -1,7 +1,6 @@ from typing import Optional -from pydantic import BaseModel - +from marvin._compat import BaseModel from marvin.components.ai_model import ai_model diff --git a/src/marvin/components/library/ai_models.py b/src/marvin/components/library/ai_models.py index 0b88ce2e8..cc149cab4 100644 --- a/src/marvin/components/library/ai_models.py +++ b/src/marvin/components/library/ai_models.py @@ -2,11 +2,10 @@ from typing import Optional import httpx -from pydantic import BaseModel from typing_extensions import Self from marvin import ai_model -from marvin._compat import Field, SecretStr, field_validator +from marvin._compat import BaseModel, Field, SecretStr, field_validator from marvin.settings import MarvinBaseSettings diff --git a/src/marvin/core/ChatCompletion/__init__.py b/src/marvin/core/ChatCompletion/__init__.py index b2e08a291..068c171ce 100644 --- a/src/marvin/core/ChatCompletion/__init__.py +++ b/src/marvin/core/ChatCompletion/__init__.py @@ -1,6 +1,7 @@ from typing import Optional, Any, TypeVar -from pydantic import BaseModel from .abstract import AbstractChatCompletion + +from marvin._compat import BaseModel from marvin.settings import settings T = TypeVar( diff --git a/src/marvin/core/ChatCompletion/abstract.py b/src/marvin/core/ChatCompletion/abstract.py index 2177e47c4..a01c50d22 100644 --- a/src/marvin/core/ChatCompletion/abstract.py +++ b/src/marvin/core/ChatCompletion/abstract.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Generic, Optional, TypeVar -from marvin._compat import model_copy, model_dump +from marvin._compat import BaseModel, Field, model_copy, model_dump from marvin.utilities.messages import Message -from pydantic import BaseModel, Field from typing_extensions import Self from .handlers import Request, Response, Turn diff --git a/src/marvin/core/ChatCompletion/handlers.py b/src/marvin/core/ChatCompletion/handlers.py index 5a8e73015..4b8ced16a 100644 --- a/src/marvin/core/ChatCompletion/handlers.py +++ b/src/marvin/core/ChatCompletion/handlers.py @@ -12,11 +12,10 @@ overload, ) -from marvin._compat import cast_to_json, model_dump +from marvin._compat import BaseModel, Field, cast_to_json, model_dump from marvin.utilities.async_utils import run_sync from marvin.utilities.logging import get_logger from marvin.utilities.messages import Message, Role -from pydantic import BaseModel, Field from typing_extensions import ParamSpec from .utils import parse_raw @@ -84,6 +83,9 @@ class Choice(BaseModel): index: int finish_reason: str + class Config: + arbitrary_types_allowed = True + class Usage(BaseModel): prompt_tokens: int diff --git a/src/marvin/core/ChatCompletion/providers/openai.py b/src/marvin/core/ChatCompletion/providers/openai.py index e9f244f76..6675c3ee0 100644 --- a/src/marvin/core/ChatCompletion/providers/openai.py +++ b/src/marvin/core/ChatCompletion/providers/openai.py @@ -1,14 +1,13 @@ import inspect from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union -from marvin._compat import cast_to_json, model_dump +from marvin._compat import BaseModel, cast_to_json, model_dump from marvin.settings import settings from marvin.types import Function from marvin.utilities.async_utils import create_task from marvin.utilities.messages import Message from marvin.utilities.streaming import StreamHandler from openai.openai_object import OpenAIObject -from pydantic import BaseModel from ..abstract import AbstractChatCompletion from ..handlers import Request, Response, Usage diff --git a/src/marvin/openai/ChatCompletion/__init__.py b/src/marvin/openai/ChatCompletion/__init__.py index c39421a62..bd41d1135 100644 --- a/src/marvin/openai/ChatCompletion/__init__.py +++ b/src/marvin/openai/ChatCompletion/__init__.py @@ -1,13 +1,10 @@ -from pydantic import BaseModel, Field, validator, Extra, BaseSettings, root_validator from pydantic.main import ModelMetaclass -from typing import Any, Callable, List, Optional, Type, Union, Literal -from marvin import settings -from marvin.types import Function +from typing import Any, Callable, Optional from operator import itemgetter -from marvin.utilities.module_loading import import_string -import warnings -import copy + +from marvin import settings +from marvin._compat import BaseModel, Extra, Field from marvin.types.request import Request as BaseRequest from marvin.engine import ChatCompletionBase diff --git a/src/marvin/openai/Function/Registry/__init__.py b/src/marvin/openai/Function/Registry/__init__.py index b5f76533b..2f74ea661 100644 --- a/src/marvin/openai/Function/Registry/__init__.py +++ b/src/marvin/openai/Function/Registry/__init__.py @@ -1,8 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Type, Union -from fastapi.routing import APIRouter -from pydantic import BaseModel, validate_arguments -from marvin.utilities.types import function_to_model -from marvin.utilities.messages import Message +from typing import Any from marvin.openai.Function import openai_fn from openai.openai_object import OpenAIObject from marvin.functions import FunctionRegistry diff --git a/src/marvin/settings.py b/src/marvin/settings.py index abeb05904..42372eeed 100644 --- a/src/marvin/settings.py +++ b/src/marvin/settings.py @@ -5,7 +5,6 @@ from ._compat import ( BaseSettings, - Field, SecretStr, model_dump, ) @@ -155,22 +154,19 @@ class Settings(MarvinBaseSettings): azure_openai: AzureOpenAI = AzureOpenAI() # SLACK - slack_api_token: Optional[SecretStr] = Field( - default=None, - description="The Slack API token to use for the Slack client", - ) + slack_api_token: Optional[SecretStr] = None # TOOLS # chroma - chroma_server_host: Optional[str] = Field(default=None) - chroma_server_http_port: Optional[int] = Field(default=None) + chroma_server_host: Optional[str] = None + chroma_server_http_port: Optional[int] = None # github - github_token: Optional[SecretStr] = Field(default=None) + github_token: Optional[SecretStr] = None # wolfram - wolfram_app_id: Optional[SecretStr] = Field(default=None) + wolfram_app_id: Optional[SecretStr] = None def get_defaults(self, provider: Optional[str] = None) -> dict[str, Any]: response: dict[str, Any] = {} diff --git a/src/marvin/types/function.py b/src/marvin/types/function.py index d898bbefe..a9da6cd25 100644 --- a/src/marvin/types/function.py +++ b/src/marvin/types/function.py @@ -3,9 +3,7 @@ import re from typing import Callable, Optional, Type -from pydantic import BaseModel - -from marvin._compat import validate_arguments +from marvin._compat import BaseModel, validate_arguments extraneous_fields = [ "args", diff --git a/src/marvin/utilities/messages.py b/src/marvin/utilities/messages.py index 2807df5ab..ef70824cc 100644 --- a/src/marvin/utilities/messages.py +++ b/src/marvin/utilities/messages.py @@ -5,10 +5,9 @@ from typing import Any, Optional from zoneinfo import ZoneInfo -from pydantic import BaseModel, Field from typing_extensions import Self -from marvin._compat import field_validator +from marvin._compat import BaseModel, Field, field_validator from marvin.utilities.strings import split_text_by_tokens from marvin.utilities.types import MarvinBaseModel @@ -34,6 +33,10 @@ class FunctionCall(BaseModel): arguments: str +def utcnow(): + return datetime.now(ZoneInfo("UTC")) + + class Message(MarvinBaseModel): role: Role content: Optional[str] = Field(default=None, description="The message content") @@ -49,7 +52,7 @@ class Message(MarvinBaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4, exclude=True) data: Optional[dict[str, Any]] = Field(default_factory=dict, exclude=True) timestamp: datetime = Field( - default_factory=lambda: datetime.now(ZoneInfo("UTC")), + default_factory=utcnow, exclude=True, ) diff --git a/src/marvin/utilities/types.py b/src/marvin/utilities/types.py index d5262b713..dbd7f1456 100644 --- a/src/marvin/utilities/types.py +++ b/src/marvin/utilities/types.py @@ -3,9 +3,7 @@ from types import GenericAlias from typing import Any, Callable, _SpecialForm -import pydantic -from pydantic import BaseModel, PrivateAttr - +from marvin._compat import BaseModel, PrivateAttr, create_model from marvin.utilities.logging import get_logger @@ -54,7 +52,7 @@ def function_to_model( # Create Pydantic model try: - Model = pydantic.create_model(name or function.__name__, **fields) + Model = create_model(name or function.__name__, **fields) except RuntimeError as exc: if "see `arbitrary_types_allowed` " in str(exc): raise ValueError( @@ -85,7 +83,7 @@ def safe_issubclass(type_, classes): def type_to_schema(type_, set_root_type: bool = True) -> dict: - if safe_issubclass(type_, pydantic.BaseModel): + if safe_issubclass(type_, BaseModel): schema = type_.schema() # if the docstring was updated at runtime, make it the description if type_.__doc__ and type_.__doc__ != schema.get("description"): @@ -94,13 +92,13 @@ def type_to_schema(type_, set_root_type: bool = True) -> dict: elif set_root_type: - class Model(pydantic.BaseModel): + class Model(BaseModel): __root__: type_ return Model.schema() else: - class Model(pydantic.BaseModel): + class Model(BaseModel): data: type_ return Model.schema() diff --git a/tests/test_components/test_ai_functions.py b/tests/test_components/test_ai_functions.py index 7019dca50..552bfab5c 100644 --- a/tests/test_components/test_ai_functions.py +++ b/tests/test_components/test_ai_functions.py @@ -1,7 +1,8 @@ import inspect import pytest -from marvin.components.ai_function import ai_fn +from marvin import ai_fn +from pydantic import BaseModel from tests.utils.mark import pytest_mark_class @@ -44,6 +45,19 @@ def list_fruit(n: int) -> list[str]: result = list_fruit(3) assert len(result) == 3 + def test_basemodel_response(self): + class Fruit(BaseModel): + name: str + color: str + + @ai_fn + def get_fruit(description: str) -> Fruit: + """Returns a fruit with the provided description""" + + fruit = get_fruit("loved by monkeys") + assert fruit.name.lower() == "banana" + assert fruit.color.lower() == "yellow" + @pytest_mark_class("llm") class TestAIFunctionsMap: