From 265f8df23da9d861bc9e1b6337df81e2705553f1 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 21 Jan 2025 10:20:59 -0500 Subject: [PATCH] Memories are async by default --- src/marvin/memory/memory.py | 55 ++-- src/marvin/memory/providers/chroma.py | 15 +- src/marvin/memory/providers/lance.py | 8 +- src/marvin/memory/providers/postgres.py | 324 +++++++++++++----------- src/marvin/utilities/tools.py | 19 +- tests/basic/utilities/test_tools.py | 90 +++++++ 6 files changed, 320 insertions(+), 191 deletions(-) create mode 100644 tests/basic/utilities/test_tools.py diff --git a/src/marvin/memory/memory.py b/src/marvin/memory/memory.py index 9337cdd2e..bb437473b 100644 --- a/src/marvin/memory/memory.py +++ b/src/marvin/memory/memory.py @@ -23,15 +23,15 @@ def configure(self, memory_key: str) -> None: """Configure the provider for a specific memory.""" @abc.abstractmethod - def add(self, memory_key: str, content: str) -> str: + async def add(self, memory_key: str, content: str) -> str: """Create a new memory and return its ID.""" @abc.abstractmethod - def delete(self, memory_key: str, memory_id: str) -> None: + async def delete(self, memory_key: str, memory_id: str) -> None: """Delete a memory by its ID.""" @abc.abstractmethod - def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: + async def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: """Search for n memories using a string query.""" @@ -92,14 +92,14 @@ def __post_init__(self): # Configure provider self.provider.configure(self.key) - def add(self, content: str) -> str: - return self.provider.add(self.key, content) + async def add(self, content: str) -> str: + return await self.provider.add(self.key, content) - def delete(self, memory_id: str) -> None: - self.provider.delete(self.key, memory_id) + async def delete(self, memory_id: str) -> None: + await self.provider.delete(self.key, memory_id) - def search(self, query: str, n: int = 20) -> dict[str, str]: - return self.provider.search(self.key, query, n) + async def search(self, query: str, n: int = 20) -> dict[str, str]: + return await self.provider.search(self.key, query, n) def friendly_name(self) -> str: return f"Memory: {self.key}" @@ -130,46 +130,25 @@ def get_memory_provider(provider: str) -> MemoryProvider: # --- CHROMA --- if provider.startswith("chroma"): - try: - import chromadb # noqa: F401 - except ImportError: - raise ImportError( - "To use Chroma as a memory provider, please install the `chromadb` package.", - ) - - import marvin.memory.providers.chroma as chroma_providers + import marvin.memory.providers.chroma as chroma_provider if provider == "chroma-ephemeral": - return chroma_providers.ChromaEphemeralMemory() + return chroma_provider.ChromaEphemeralMemory() if provider == "chroma-db": - return chroma_providers.ChromaPersistentMemory() + return chroma_provider.ChromaPersistentMemory() if provider == "chroma-cloud": - return chroma_providers.ChromaCloudMemory() + return chroma_provider.ChromaCloudMemory() # --- LanceDB --- elif provider.startswith("lancedb"): - try: - import lancedb # noqa: F401 - except ImportError: - raise ImportError( - "To use LanceDB as a memory provider, please install the `lancedb` package.", - ) + import marvin.memory.providers.lance as lance_provider - import marvin.memory.providers.lance as lance_providers - - return lance_providers.LanceMemory() + return lance_provider.LanceMemory() # --- Postgres --- elif provider.startswith("postgres"): - try: - import sqlalchemy # noqa: F401 - except ImportError: - raise ImportError( - "To use Postgres as a memory provider, please install the `sqlalchemy` package.", - ) - - import marvin.memory.providers.postgres as postgres_providers + import marvin.memory.providers.postgres as postgres_provider - return postgres_providers.PostgresMemory() + return postgres_provider.PostgresMemory() raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.') diff --git a/src/marvin/memory/providers/chroma.py b/src/marvin/memory/providers/chroma.py index ccae29e62..eaf4e3df8 100644 --- a/src/marvin/memory/providers/chroma.py +++ b/src/marvin/memory/providers/chroma.py @@ -2,11 +2,16 @@ from dataclasses import dataclass, field from typing import Any -import chromadb - import marvin from marvin.memory.memory import MemoryProvider +try: + import chromadb # noqa: F401 +except ImportError: + raise ImportError( + "To use Chroma as a memory provider, please install the `chromadb` package." + ) + @dataclass(kw_only=True) class ChromaMemory(MemoryProvider): @@ -32,7 +37,7 @@ def get_collection(self, memory_key: str) -> chromadb.Collection: self.collection_name.format(key=memory_key), ) - def add(self, memory_key: str, content: str) -> str: + async def add(self, memory_key: str, content: str) -> str: collection = self.get_collection(memory_key) memory_id = str(uuid.uuid4()) collection.add( @@ -42,11 +47,11 @@ def add(self, memory_key: str, content: str) -> str: ) return memory_id - def delete(self, memory_key: str, memory_id: str) -> None: + async def delete(self, memory_key: str, memory_id: str) -> None: collection = self.get_collection(memory_key) collection.delete(ids=[memory_id]) - def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: + async def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: results = self.get_collection(memory_key).query( query_texts=[query], n_results=n, diff --git a/src/marvin/memory/providers/lance.py b/src/marvin/memory/providers/lance.py index 7c5778546..a537f9b1c 100644 --- a/src/marvin/memory/providers/lance.py +++ b/src/marvin/memory/providers/lance.py @@ -10,7 +10,7 @@ from lancedb.pydantic import LanceModel, Vector except ImportError: raise ImportError( - "LanceDB is not installed. Please install it with `pip install lancedb`.", + "To use LanceDB as a memory provider, please install the `lancedb` package." ) from pydantic import Field @@ -70,17 +70,17 @@ def get_table(self, memory_key: str) -> lancedb.table.Table: except FileNotFoundError: return db.create_table(table_name, schema=model) - def add(self, memory_key: str, content: str) -> str: + async def add(self, memory_key: str, content: str) -> str: memory_id = str(uuid.uuid4()) table = self.get_table(memory_key) table.add([{"id": memory_id, "text": content}]) return memory_id - def delete(self, memory_key: str, memory_id: str) -> None: + async def delete(self, memory_key: str, memory_id: str) -> None: table = self.get_table(memory_key) table.delete(f'id = "{memory_id}"') - def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: + async def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: table = self.get_table(memory_key) results = table.search(query).limit(n).to_pydantic(self.get_model()) return {r.id: r.text for r in results} diff --git a/src/marvin/memory/providers/postgres.py b/src/marvin/memory/providers/postgres.py index 6f688c1d1..a7730a5f1 100644 --- a/src/marvin/memory/providers/postgres.py +++ b/src/marvin/memory/providers/postgres.py @@ -1,23 +1,40 @@ import uuid -from collections.abc import Callable -from dataclasses import dataclass, field +from typing import Callable, Dict, Optional +# async pg +import anyio import sqlalchemy -from pgvector.sqlalchemy import Vector +from pydantic import Field from sqlalchemy import Column, String, select, text from sqlalchemy.exc import ProgrammingError -from sqlalchemy.orm import Session, declarative_base, sessionmaker -from sqlalchemy_utils import create_database, database_exists +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import declarative_base from marvin.memory.memory import MemoryProvider +try: + from sqlalchemy_utils import create_database, database_exists +except ImportError: + raise ImportError( + "To use Postgres as a memory provider, please install the `sqlalchemy_utils` package." + ) +try: + from pgvector.sqlalchemy import Vector +except ImportError: + raise ImportError( + "To use Postgres as a memory provider, please install the `pgvector` package." + ) try: # For embeddings, we can use langchain_openai or any other library: from langchain_openai import OpenAIEmbeddings except ImportError: raise ImportError( - "To use an embedding function similar to LanceDB's default, " - "please install lancedb with: pip install lancedb", + "To use Langchain OpenAI as an embedding function, please install the `langchain-openai` package." ) # SQLAlchemy base class for declarative models @@ -25,7 +42,8 @@ class SQLMemoryTable(Base): - """A simple declarative model that represents a memory record. + """ + A simple declarative model that represents a memory record. We'll dynamically set the __tablename__ at runtime. """ @@ -37,144 +55,137 @@ class SQLMemoryTable(Base): # vector = Column(Vector(dim=1536)) # Adjust dimension to match your embedding model -@dataclass(kw_only=True) class PostgresMemory(MemoryProvider): - """A Marvin MemoryProvider that stores text + embeddings in PostgreSQL - using SQLAlchemy and pg_vector. Each Memory module gets its own table. + """ + An async MemoryProvider storing text + embeddings in PostgreSQL + using SQLAlchemy + pg_vector, but with full async support. """ - database_url: str = field( - default="postgresql://user:password@localhost:5432/your_database", - metadata={ - "description": "SQLAlchemy-compatible database URL to a Postgres instance with pgvector.", - }, + database_url: str = Field( + default="postgresql+asyncpg://user:password@localhost:5432/your_database", + description="Async Postgres URL with the asyncpg driver, e.g. " + "'postgresql+asyncpg://user:pass@host:5432/dbname'.", ) - table_name: str = field( - default="memory_{key}", - metadata={ - "description": """ - Name of the table to store this memory partition. "{key}" will be replaced - by the memory's key attribute. - """, - }, + + table_name: str = Field( + "memory_{key}", + description=""" + Name of the table for this memory partition. "{key}" gets replaced by the memory key. + """, ) - embedding_dimension: int = field( + embedding_dimension: int = Field( default=1536, - metadata={ - "description": "Dimension of the embedding vectors. Match your model's output.", - }, + description="Dimension of the embedding vectors. Must match your model output size.", ) - embedding_fn: Callable = field( - default_factory=lambda: OpenAIEmbeddings( - model="text-embedding-ada-002", - ), - metadata={"description": "A function that turns a string into a vector."}, + embedding_fn: Callable = Field( + default_factory=lambda: OpenAIEmbeddings(model="text-embedding-ada-002"), + description="Function that turns a string into a numeric vector.", ) - # Connection pool settings - pool_size: int = field( - default=5, - metadata={"description": "Number of connections to keep open in the pool."}, + # -- Pool / Engine settings (SQLAlchemy will do the pooling) + pool_size: int = Field( + 5, description="Number of permanent connections in the async pool." ) - - max_overflow: int = field( - default=10, - metadata={ - "description": "Number of connections to allow that can overflow the pool.", - }, + max_overflow: int = Field( + 10, description="Number of 'overflow' connections if the pool is full." ) - - pool_timeout: int = field( - default=30, - metadata={ - "description": "Number of seconds to wait before giving up on getting a connection.", - }, + pool_timeout: int = Field( + 30, description="Seconds to wait for a connection before raising an error." ) - - pool_recycle: int = field( - default=1800, - metadata={ - "description": "Number of seconds a connection can be idle before being recycled.", - }, + pool_recycle: int = Field( + 1800, + description="Recycle connections after N seconds to avoid stale connections.", ) - - pool_pre_ping: bool = field( - default=True, - metadata={"description": "Check the connection health upon checkout."}, + pool_pre_ping: bool = Field( + True, description="Check connection health before using from the pool." ) - # Internal: keep a cached Session maker - _SessionLocal: sessionmaker | None = None - - # This dict will map "table_name" -> "model class" - _table_class_cache: dict[str, Base] = {} - - def configure(self, memory_key: str) -> None: - """Configure a SQLAlchemy session w/connection pooling and ensure the table for this - memory partition is created if it does not already exist. - """ - engine = sqlalchemy.create_engine( - self.database_url, - pool_size=self.pool_size, - max_overflow=self.max_overflow, - pool_timeout=self.pool_timeout, - pool_recycle=self.pool_recycle, - pool_pre_ping=self.pool_pre_ping, - ) - - # 2) If DB doesn't exist, create it! - if not database_exists(engine.url): - create_database(engine.url) - - with engine.connect() as conn: - conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) - conn.commit() + # We'll store an async engine + session factory: + _engine: Optional[AsyncEngine] = None + _SessionLocal: Optional[async_sessionmaker[AsyncSession]] = None - self._SessionLocal = sessionmaker(bind=engine) + # Cache for dynamically generated table classes + _table_class_cache: Dict[str, Base] = {} - # Dynamically create a specialized table model for this memory_key - table_name = self.table_name.format(key=memory_key) + _configured: bool = False - # 1) Check if table already in metadata - if table_name not in Base.metadata.tables: - # 2) Create the dynamic class + table - memory_model = type( - f"SQLMemoryTable_{memory_key}", - (SQLMemoryTable,), - { - "__tablename__": table_name, - "vector": Column(Vector(dim=self.embedding_dimension)), - }, + async def configure(self, memory_key: str) -> None: + """ + 1) Create an async engine. + 2) Optionally create the DB if it doesn't exist (requires sync workaround). + 3) Install pgvector extension. + 4) Generate the memory table if missing. + 5) Initialize the async sessionmaker. + """ + if self._configured: + return + # 1) Create an async engine. Use the asyncpg dialect. + # The pool settings are configured in 'create_async_engine' with 'pool_size', etc. + else: + self._engine = create_async_engine( + self.database_url, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + pool_timeout=self.pool_timeout, + pool_recycle=self.pool_recycle, + pool_pre_ping=self.pool_pre_ping, ) - try: - Base.metadata.create_all(engine, tables=[memory_model.__table__]) - # Store it in the cache - self._table_class_cache[table_name] = memory_model - except ProgrammingError as e: - raise RuntimeError(f"Failed to create table {table_name}: {e}") - - def _get_session(self) -> Session: - if not self._SessionLocal: - raise RuntimeError( - "Session is not initialized. Make sure to call configure() first.", + exists = await anyio.to_thread.run_sync(database_exists, self.database_url) + if not exists: + await anyio.to_thread.run_sync(create_database, self.database_url) + + # 3) Run migrations / create extension in an async context: + async with self._engine.begin() as conn: + # Create the pgvector extension if not exists + await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) + # We'll create the table for the memory_key specifically + # (1) Build the dynamic table class + table_name = self.table_name.format(key=memory_key) + if table_name not in Base.metadata.tables: + memory_model = type( + f"SQLMemoryTable_{memory_key}", + (SQLMemoryTable,), + { + "__tablename__": table_name, + "vector": Column(Vector(dim=self.embedding_dimension)), + }, + ) + self._table_class_cache[table_name] = memory_model + + # (2) Actually create it (async): + def _sync_create(connection): + """Helper function to run table creation in sync context.""" + Base.metadata.create_all( + connection, tables=[memory_model.__table__] + ) + + try: + await conn.run_sync(_sync_create) + except ProgrammingError as e: + raise RuntimeError( + f"Failed to create table '{table_name}': {e}" + ) + + # 4) Now that the DB and table are ready, create a session factory + self._SessionLocal = async_sessionmaker( + self._engine, + expire_on_commit=False, ) - return self._SessionLocal() + + self._configured = True def _get_table(self, memory_key: str) -> Base: - """Return a dynamically generated declarative model class - mapped to the memory_{key} table. Each memory partition - has a separate table. + """ + Return or create the dynamic model class for 'memory_{key}' table. """ table_name = self.table_name.format(key=memory_key) - - # Return the cached class if already built if table_name in self._table_class_cache: return self._table_class_cache[table_name] - # If for some reason it's not there, create it now (or raise error): + # If not found, define it at runtime (we won't auto-create it here though) memory_model = type( f"SQLMemoryTable_{memory_key}", (SQLMemoryTable,), @@ -186,46 +197,77 @@ def _get_table(self, memory_key: str) -> Base: self._table_class_cache[table_name] = memory_model return memory_model - def add(self, memory_key: str, content: str) -> str: - """Insert a new memory record into the Postgres table, - generating an embedding and storing it in a vector column. - Returns the memory's ID (uuid). + async def add(self, memory_key: str, content: str) -> str: + """ + Insert a new record with an embedding vector. + Returns the inserted record's UUID. """ + # lazy config + if not self._configured: + await self.configure(memory_key) + + if not self._SessionLocal: + raise RuntimeError("Call 'configure(...)' before using this provider.") + memory_id = str(uuid.uuid4()) model_cls = self._get_table(memory_key) - - # Generate an embedding for the content embedding = self.embedding_fn.embed_query(content) - with self._get_session() as session: - record = model_cls(id=memory_id, text=content, vector=embedding) + async with self._SessionLocal() as session: + record = model_cls( + id=memory_id, + text=content, + vector=embedding, + ) session.add(record) - session.commit() + await session.commit() return memory_id - def delete(self, memory_key: str, memory_id: str) -> None: - """Delete a memory record by its UUID.""" + async def delete(self, memory_key: str, memory_id: str) -> None: + """ + Delete a record by UUID. + """ + # lazy config + if not self._configured: + await self.configure(memory_key) + + if not self._SessionLocal: + raise RuntimeError("Not configured. Call 'configure(...)' first.") + model_cls = self._get_table(memory_key) - with self._get_session() as session: - session.query(model_cls).filter(model_cls.id == memory_id).delete() - session.commit() + async with self._SessionLocal() as session: + await session.execute( + sqlalchemy.delete(model_cls).where(model_cls.id == memory_id) + ) + await session.commit() - def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]: - """Uses pgvector's approximate nearest neighbor search with the `<->` operator to find - the top N matching records for the embedded query. Returns a dict of {id: text}. + async def search(self, memory_key: str, query: str, n: int = 20) -> Dict[str, str]: + """ + Async nearest-neighbor search via pgvector <-> operator or .l2_distance(), + returning up to N results as {id: text}. """ + + # lazy config + if not self._configured: + await self.configure(memory_key) + + if not self._SessionLocal: + raise RuntimeError("Not configured. Call 'configure(...)' first.") + model_cls = self._get_table(memory_key) - # Generate embedding for the query - query_embedding = self.embedding_fn.embed_query(query) + embedding = self.embedding_fn.embed_query(query) embedding_col = model_cls.vector - with self._get_session() as session: - results = session.execute( + async with self._SessionLocal() as session: + # Example using l2_distance: + results = await session.execute( select(model_cls.id, model_cls.text) - .order_by(embedding_col.l2_distance(query_embedding)) - .limit(n), - ).all() + .order_by(embedding_col.l2_distance(embedding)) + .limit(n) + ) + rows = results.all() - return {row.id: row.text for row in results} + # Convert list of Row objects -> dict + return {row.id: row.text for row in rows} diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index 49eddd85e..3a2907bd1 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -1,3 +1,4 @@ +import inspect from collections.abc import Callable from dataclasses import dataclass from functools import wraps @@ -57,12 +58,24 @@ def add_stuff(x): return x + 1 new_fn = update_fn(add_stuff, name='add_stuff_123', description='Adds stuff') + # Works with async functions too: + @update_fn('async_hello') + async def my_async_fn(x): + return x + """ def apply(func: Callable[..., T], new_name: str) -> Callable[..., T]: - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> T: - return func(*args, **kwargs) + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> T: + return await func(*args, **kwargs) + else: + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + return func(*args, **kwargs) wrapper.__name__ = new_name if description is not None: diff --git a/tests/basic/utilities/test_tools.py b/tests/basic/utilities/test_tools.py new file mode 100644 index 000000000..610fb3094 --- /dev/null +++ b/tests/basic/utilities/test_tools.py @@ -0,0 +1,90 @@ +import pytest + +from marvin.utilities.tools import update_fn + + +def test_update_fn_sync_decorator_positional(): + """Test update_fn as decorator with positional argument""" + + @update_fn("new_name") + def my_fn(x: int) -> int: + return x + 1 + + assert my_fn.__name__ == "new_name" + assert my_fn(1) == 2 + + +def test_update_fn_sync_decorator_keyword(): + """Test update_fn as decorator with keyword arguments""" + + @update_fn(name="another_name", description="adds stuff") + def another_fn(x: int) -> int: + return x + 2 + + assert another_fn.__name__ == "another_name" + assert another_fn.__doc__ == "adds stuff" + assert another_fn(1) == 3 + + +def test_update_fn_sync_direct(): + """Test update_fn called directly on a function""" + + def third_fn(x: int) -> int: + return x + 3 + + renamed = update_fn(third_fn, name="third_name") + assert renamed.__name__ == "third_name" + assert renamed(1) == 4 + + +async def test_update_fn_async_decorator_positional(): + """Test update_fn as decorator with positional argument on async function""" + + @update_fn("async_name") + async def my_async_fn(x: int) -> int: + return x + 1 + + assert my_async_fn.__name__ == "async_name" + assert await my_async_fn(1) == 2 + + +async def test_update_fn_async_decorator_keyword(): + """Test update_fn as decorator with keyword arguments on async function""" + + @update_fn(name="another_async", description="adds stuff async") + async def another_async_fn(x: int) -> int: + return x + 2 + + assert another_async_fn.__name__ == "another_async" + assert another_async_fn.__doc__ == "adds stuff async" + assert await another_async_fn(1) == 3 + + +async def test_update_fn_async_direct(): + """Test update_fn called directly on async function""" + + async def third_async_fn(x: int) -> int: + return x + 3 + + renamed = update_fn(third_async_fn, name="third_async") + assert renamed.__name__ == "third_async" + assert await renamed(1) == 4 + + +def test_update_fn_validation_missing_name_direct(): + """Test update_fn validation when name is missing in direct call""" + with pytest.raises( + ValueError, match="name must be provided when used as a function" + ): + update_fn(lambda x: x) + + +def test_update_fn_validation_missing_name_decorator(): + """Test update_fn validation when name is missing in decorator""" + with pytest.raises( + ValueError, match="name must be provided either as argument or keyword" + ): + + @update_fn() + def my_fn(x): + return x