Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memories are async classes #1026

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 17 additions & 38 deletions src/marvin/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def configure(self, memory_key: str) -> None:
"""Configure the provider for a specific memory."""

@abc.abstractmethod
def add(self, memory_key: str, content: str) -> str:
async def add(self, memory_key: str, content: str) -> str:
"""Create a new memory and return its ID."""

@abc.abstractmethod
def delete(self, memory_key: str, memory_id: str) -> None:
async def delete(self, memory_key: str, memory_id: str) -> None:
"""Delete a memory by its ID."""

@abc.abstractmethod
def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]:
async def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]:
"""Search for n memories using a string query."""


Expand Down Expand Up @@ -92,14 +92,14 @@ def __post_init__(self):
# Configure provider
self.provider.configure(self.key)

def add(self, content: str) -> str:
return self.provider.add(self.key, content)
async def add(self, content: str) -> str:
return await self.provider.add(self.key, content)

def delete(self, memory_id: str) -> None:
self.provider.delete(self.key, memory_id)
async def delete(self, memory_id: str) -> None:
await self.provider.delete(self.key, memory_id)

def search(self, query: str, n: int = 20) -> dict[str, str]:
return self.provider.search(self.key, query, n)
async def search(self, query: str, n: int = 20) -> dict[str, str]:
return await self.provider.search(self.key, query, n)

def friendly_name(self) -> str:
return f"Memory: {self.key}"
Expand Down Expand Up @@ -130,46 +130,25 @@ def get_memory_provider(provider: str) -> MemoryProvider:
# --- CHROMA ---

if provider.startswith("chroma"):
try:
import chromadb # noqa: F401
except ImportError:
raise ImportError(
"To use Chroma as a memory provider, please install the `chromadb` package.",
)

import marvin.memory.providers.chroma as chroma_providers
import marvin.memory.providers.chroma as chroma_provider

if provider == "chroma-ephemeral":
return chroma_providers.ChromaEphemeralMemory()
return chroma_provider.ChromaEphemeralMemory()
if provider == "chroma-db":
return chroma_providers.ChromaPersistentMemory()
return chroma_provider.ChromaPersistentMemory()
if provider == "chroma-cloud":
return chroma_providers.ChromaCloudMemory()
return chroma_provider.ChromaCloudMemory()

# --- LanceDB ---

elif provider.startswith("lancedb"):
try:
import lancedb # noqa: F401
except ImportError:
raise ImportError(
"To use LanceDB as a memory provider, please install the `lancedb` package.",
)
import marvin.memory.providers.lance as lance_provider

import marvin.memory.providers.lance as lance_providers

return lance_providers.LanceMemory()
return lance_provider.LanceMemory()

# --- Postgres ---
elif provider.startswith("postgres"):
try:
import sqlalchemy # noqa: F401
except ImportError:
raise ImportError(
"To use Postgres as a memory provider, please install the `sqlalchemy` package.",
)

import marvin.memory.providers.postgres as postgres_providers
import marvin.memory.providers.postgres as postgres_provider

return postgres_providers.PostgresMemory()
return postgres_provider.PostgresMemory()
raise ValueError(f'Memory provider "{provider}" could not be loaded from a string.')
15 changes: 10 additions & 5 deletions src/marvin/memory/providers/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
from dataclasses import dataclass, field
from typing import Any

import chromadb

import marvin
from marvin.memory.memory import MemoryProvider

try:
import chromadb # noqa: F401
except ImportError:
raise ImportError(
"To use Chroma as a memory provider, please install the `chromadb` package."
)


@dataclass(kw_only=True)
class ChromaMemory(MemoryProvider):
Expand All @@ -32,7 +37,7 @@ def get_collection(self, memory_key: str) -> chromadb.Collection:
self.collection_name.format(key=memory_key),
)

def add(self, memory_key: str, content: str) -> str:
async def add(self, memory_key: str, content: str) -> str:
collection = self.get_collection(memory_key)
memory_id = str(uuid.uuid4())
collection.add(
Expand All @@ -42,11 +47,11 @@ def add(self, memory_key: str, content: str) -> str:
)
return memory_id

def delete(self, memory_key: str, memory_id: str) -> None:
async def delete(self, memory_key: str, memory_id: str) -> None:
collection = self.get_collection(memory_key)
collection.delete(ids=[memory_id])

def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]:
async def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]:
results = self.get_collection(memory_key).query(
query_texts=[query],
n_results=n,
Expand Down
8 changes: 4 additions & 4 deletions src/marvin/memory/providers/lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lancedb.pydantic import LanceModel, Vector
except ImportError:
raise ImportError(
"LanceDB is not installed. Please install it with `pip install lancedb`.",
"To use LanceDB as a memory provider, please install the `lancedb` package."
)

from pydantic import Field
Expand Down Expand Up @@ -70,17 +70,17 @@ def get_table(self, memory_key: str) -> lancedb.table.Table:
except FileNotFoundError:
return db.create_table(table_name, schema=model)

def add(self, memory_key: str, content: str) -> str:
async def add(self, memory_key: str, content: str) -> str:
memory_id = str(uuid.uuid4())
table = self.get_table(memory_key)
table.add([{"id": memory_id, "text": content}])
return memory_id

def delete(self, memory_key: str, memory_id: str) -> None:
async def delete(self, memory_key: str, memory_id: str) -> None:
table = self.get_table(memory_key)
table.delete(f'id = "{memory_id}"')

def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]:
async def search(self, memory_key: str, query: str, n: int = 20) -> dict[str, str]:
table = self.get_table(memory_key)
results = table.search(query).limit(n).to_pydantic(self.get_model())
return {r.id: r.text for r in results}
Loading
Loading