Skip to content

Commit

Permalink
actually ignore extra
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Jan 18, 2025
1 parent 6d7966e commit f1235a1
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def cast_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T:
"""Asynchronously transforms input data into a specific type using a language model.
Expand Down
6 changes: 3 additions & 3 deletions src/marvin/fns/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T: ...


Expand All @@ -41,7 +41,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]: ...


Expand All @@ -52,7 +52,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T | list[T]:
"""Asynchronously classifies input data into one or more predefined labels using a language model.
Expand Down
2 changes: 1 addition & 1 deletion src/marvin/fns/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def extract(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Extracts entities of a specific type from the provided data.
Expand Down
6 changes: 3 additions & 3 deletions src/marvin/fns/generate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, cast
from typing import Any, TypeVar, cast

from pydantic import conlist

Expand Down Expand Up @@ -34,7 +34,7 @@ async def generate_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Generates examples of a specific type or matching a description asynchronously.
Expand Down Expand Up @@ -82,7 +82,7 @@ def generate(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Generates examples of a specific type or matching a description.
Expand Down
36 changes: 16 additions & 20 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import Literal

from pydantic import Field, field_validator, model_validator
from pydantic import Field, ValidationInfo, field_validator, model_validator
from pydantic_ai.models import KnownModelName
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import Self
Expand All @@ -21,7 +21,7 @@ class Settings(BaseSettings):
case_sensitive=False,
env_file=".env",
env_file_encoding="utf-8",
extra="forbid",
extra="ignore",
validate_assignment=True,
)

Expand All @@ -45,31 +45,27 @@ def validate_home_path(cls, v: Path) -> Path:
description="Path to the database file. Defaults to `home_path / 'marvin.db'`.",
)

@model_validator(mode="after")
def validate_database_url(self) -> Self:
@field_validator("database_url")
@classmethod
def validate_database_url(cls, v: str | None, info: ValidationInfo) -> str:
"""Set and validate the database path."""
home_path = info.data.get("home_path")

# Set default if not provided
if self.database_url is None:
self.__dict__["database_url"] = str(self.home_path / "marvin.db")
return self
if v is None:
if not home_path:
raise ValueError("home_path must be set before database_url")
return str(home_path / "marvin.db")

# Handle in-memory database
if self.database_url == ":memory:":
return self

# Convert to Path for validation
path = Path(self.database_url)
if v == ":memory:":
return v

# Expand user and resolve to absolute path
path = path.expanduser().resolve()

# Ensure parent directory exists
# Convert to Path for validation and ensure parent directory exists
path = Path(v).expanduser().resolve()
path.parent.mkdir(parents=True, exist_ok=True)

# Store result as string
self.__dict__["database_url"] = str(path)

return self
return str(path)

# ------------ Logging settings ------------

Expand Down
35 changes: 35 additions & 0 deletions tests/settings/test_settings_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os

import pytest

from marvin.settings import Settings


@pytest.fixture
def current_user() -> str:
user = os.getenv("USER")
assert user is not None, "USER environment variable must be set"
return user


def test_database_url_default(current_user: str):
settings = Settings()
assert settings.database_url == f"/Users/{current_user}/.marvin/marvin.db"


@pytest.mark.parametrize(
"env_var_value, expected_database_url",
[
(":memory:", ":memory:"),
("~/.marvin/test.db", "/Users/{user}/.marvin/test.db"),
],
)
def test_database_url_set_from_env_var(
monkeypatch: pytest.MonkeyPatch,
env_var_value: str,
expected_database_url: str,
current_user: str,
):
monkeypatch.setenv("MARVIN_DATABASE_URL", env_var_value)
settings = Settings()
assert settings.database_url == expected_database_url.format(user=current_user)

0 comments on commit f1235a1

Please sign in to comment.