Skip to content

Commit

Permalink
Merge pull request #14 from conversence/feature/avoid_ensure_metadata
Browse files Browse the repository at this point in the history
Optimizations and abstractions
  • Loading branch information
maparent authored Sep 23, 2023
2 parents f696edf + 793a1ab commit 1242d28
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 141 deletions.
2 changes: 1 addition & 1 deletion agentmemory/check_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _download(url: str, fname: Path, chunk_size: int = 1024) -> None:
size = file.write(data)
bar.update(size)

default_model_path = str(Path.home() / ".cache" / "onnx_models")
default_model_path = Path.home() / ".cache" / "onnx_models"

def check_model(model_name = "all-MiniLM-L6-v2", model_path = default_model_path) -> str:
DOWNLOAD_PATH = Path(model_path) / model_name
Expand Down
73 changes: 72 additions & 1 deletion agentmemory/chroma_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,78 @@

import chromadb

from .client import CollectionMemory, AgentMemory

class ChromaCollectionMemory(CollectionMemory):
def __init__(self, collection, metadata=None) -> None:
self.collection = collection

def count(self):
return self.collection.count()

def add(self, ids, documents=None, metadatas=None, embeddings=None):
return self.collection.add(ids, documents, metadatas, embeddings)

def get(
self,
ids=None,
where=None,
limit=None,
offset=None,
where_document=None,
include=["metadatas", "documents"],
):
return self.collection.get(ids, where, limit, offset, where_document, include)

def peek(self, limit=10):
return self.collection.peek(limit)

def query(
self,
query_embeddings=None,
query_texts=None,
n_results=10,
where=None,
where_document=None,
include=["metadatas", "documents", "distances"],
):
return self.collection.query(query_embeddings, query_texts, n_results, where, where_document, include)

def update(self, ids, documents=None, metadatas=None, embeddings=None):
return self.collection.update(ids, embeddings, metadatas, documents)

def upsert(self, ids, documents=None, metadatas=None, embeddings=None):
# if no id is provided, generate one based on count of documents in collection
if any(id is None for id in ids):
origin = self.count()
# pad the id with zeros to make it 16 digits long
ids = [str(id_).zfill(16) for id_ in range(origin, origin+len(documents))]

return self.collection.upsert(ids, embeddings, metadatas, documents)

def delete(self, ids=None, where=None, where_document=None):
return self.collection.delete(ids, where, where_document)


class ChromaMemory(AgentMemory):
def __init__(self, path) -> None:
self.chroma = chromadb.PersistentClient(path=path)

def get_or_create_collection(self, category, metadata=None) -> CollectionMemory:
memory = self.chroma.get_or_create_collection(category)
return ChromaCollectionMemory(memory, metadata)

def get_collection(self, category) -> CollectionMemory:
memory = self.chroma.get_collection(category)
return ChromaCollectionMemory(memory)

def delete_collection(self, category):
self.chroma.delete_collection(category)

def list_collections(self):
return self.chroma.list_collections()


def create_client():
STORAGE_PATH = os.environ.get("STORAGE_PATH", "./memory")
return chromadb.PersistentClient(path=STORAGE_PATH)
return ChromaMemory(path=STORAGE_PATH)
2 changes: 1 addition & 1 deletion agentmemory/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class AgentCollection():

class AgentMemory(ABC):
@abstractmethod
def get_or_create_collection(self, category) -> CollectionMemory:
def get_or_create_collection(self, category, metadata=None) -> CollectionMemory:
raise NotImplementedError()

@abstractmethod
Expand Down
8 changes: 1 addition & 7 deletions agentmemory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@ def create_memory(category, text, metadata={}, embedding=None, id=None):
metadata["created_at"] = datetime.datetime.now().timestamp()
metadata["updated_at"] = datetime.datetime.now().timestamp()

# if no id is provided, generate one based on count of documents in collection
if id is None:
id = str(memories.count())
# pad the id with zeros to make it 16 digits long
id = id.zfill(16)

# for each field in metadata...
# if the field is a boolean, convert it to a string
for key, value in metadata.items():
Expand All @@ -52,7 +46,7 @@ def create_memory(category, text, metadata={}, embedding=None, id=None):

# insert the document into the collection
memories.upsert(
ids=[str(id)],
ids=[id],
documents=[text],
metadatas=[metadata],
embeddings=[embedding] if embedding is not None else None,
Expand Down
Loading

0 comments on commit 1242d28

Please sign in to comment.