diff --git a/src/marvin/fns/cast.py b/src/marvin/fns/cast.py index 931b95441..27c80015b 100644 --- a/src/marvin/fns/cast.py +++ b/src/marvin/fns/cast.py @@ -33,7 +33,7 @@ async def cast_async( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> T: """Asynchronously transforms input data into a specific type using a language model. diff --git a/src/marvin/fns/classify.py b/src/marvin/fns/classify.py index e5ca1529a..587c8f1ab 100644 --- a/src/marvin/fns/classify.py +++ b/src/marvin/fns/classify.py @@ -28,7 +28,7 @@ async def classify_async( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> T: ... @@ -41,7 +41,7 @@ async def classify_async( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> list[T]: ... @@ -52,7 +52,7 @@ async def classify_async( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> T | list[T]: """Asynchronously classifies input data into one or more predefined labels using a language model. diff --git a/src/marvin/fns/extract.py b/src/marvin/fns/extract.py index a993a54b9..f91a4898a 100644 --- a/src/marvin/fns/extract.py +++ b/src/marvin/fns/extract.py @@ -80,7 +80,7 @@ def extract( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> list[T]: """Extracts entities of a specific type from the provided data. diff --git a/src/marvin/fns/generate.py b/src/marvin/fns/generate.py index 83171c3e8..435c0aee1 100644 --- a/src/marvin/fns/generate.py +++ b/src/marvin/fns/generate.py @@ -1,4 +1,4 @@ -from typing import TypeVar, cast +from typing import Any, TypeVar, cast from pydantic import conlist @@ -34,7 +34,7 @@ async def generate_async( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> list[T]: """Generates examples of a specific type or matching a description asynchronously. @@ -82,7 +82,7 @@ def generate( instructions: str | None = None, agent: Agent | None = None, thread: Thread | str | None = None, - context: dict | None = None, + context: dict[str, Any] | None = None, ) -> list[T]: """Generates examples of a specific type or matching a description. diff --git a/src/marvin/settings.py b/src/marvin/settings.py index 9674b478a..259b79b18 100644 --- a/src/marvin/settings.py +++ b/src/marvin/settings.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Literal -from pydantic import Field, field_validator, model_validator +from pydantic import Field, ValidationInfo, field_validator, model_validator from pydantic_ai.models import KnownModelName from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self @@ -21,7 +21,7 @@ class Settings(BaseSettings): case_sensitive=False, env_file=".env", env_file_encoding="utf-8", - extra="forbid", + extra="ignore", validate_assignment=True, ) @@ -45,31 +45,27 @@ def validate_home_path(cls, v: Path) -> Path: description="Path to the database file. Defaults to `home_path / 'marvin.db'`.", ) - @model_validator(mode="after") - def validate_database_url(self) -> Self: + @field_validator("database_url") + @classmethod + def validate_database_url(cls, v: str | None, info: ValidationInfo) -> str: """Set and validate the database path.""" + home_path = info.data.get("home_path") + # Set default if not provided - if self.database_url is None: - self.__dict__["database_url"] = str(self.home_path / "marvin.db") - return self + if v is None: + if not home_path: + raise ValueError("home_path must be set before database_url") + return str(home_path / "marvin.db") # Handle in-memory database - if self.database_url == ":memory:": - return self - - # Convert to Path for validation - path = Path(self.database_url) + if v == ":memory:": + return v - # Expand user and resolve to absolute path - path = path.expanduser().resolve() - - # Ensure parent directory exists + # Convert to Path for validation and ensure parent directory exists + path = Path(v).expanduser().resolve() path.parent.mkdir(parents=True, exist_ok=True) - # Store result as string - self.__dict__["database_url"] = str(path) - - return self + return str(path) # ------------ Logging settings ------------ diff --git a/tests/settings/test_settings_object.py b/tests/settings/test_settings_object.py new file mode 100644 index 000000000..22ada7e36 --- /dev/null +++ b/tests/settings/test_settings_object.py @@ -0,0 +1,35 @@ +import os + +import pytest + +from marvin.settings import Settings + + +@pytest.fixture +def current_user() -> str: + user = os.getenv("USER") + assert user is not None, "USER environment variable must be set" + return user + + +def test_database_url_default(current_user: str): + settings = Settings() + assert settings.database_url == f"/Users/{current_user}/.marvin/marvin.db" + + +@pytest.mark.parametrize( + "env_var_value, expected_database_url", + [ + (":memory:", ":memory:"), + ("~/.marvin/test.db", "/Users/{user}/.marvin/test.db"), + ], +) +def test_database_url_set_from_env_var( + monkeypatch: pytest.MonkeyPatch, + env_var_value: str, + expected_database_url: str, + current_user: str, +): + monkeypatch.setenv("MARVIN_DATABASE_URL", env_var_value) + settings = Settings() + assert settings.database_url == expected_database_url.format(user=current_user)