From 66a1ecca00db5e40a75f1b076a7108a1db14084b Mon Sep 17 00:00:00 2001 From: Dev-Khant Date: Mon, 23 Sep 2024 16:02:04 +0530 Subject: [PATCH] fixes mypy issues --- src/crewai/memory/memory.py | 12 +++++++-- .../memory/short_term/short_term_memory.py | 12 +++++++-- src/crewai/memory/storage/interface.py | 4 ++- src/crewai/memory/storage/rag_storage.py | 6 ++--- src/crewai/memory/user/user_memory.py | 26 ++++++++++++++----- 5 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 9df09d3c7f..2a10066db3 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -23,5 +23,13 @@ def save( self.storage.save(value, metadata) - def search(self, query: str) -> Dict[str, Any]: - return self.storage.search(query) + def search( + self, + query: str, + limit: int = 3, + filters: dict = {}, + score_threshold: float = 0.35, + ) -> Dict[str, Any]: + return self.storage.search( + query=query, limit=limit, filters=filters, score_threshold=score_threshold + ) diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 02fb2d00cf..d774fdb8f0 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -42,8 +42,16 @@ def save( super().save(value=item.data, metadata=item.metadata, agent=item.agent) - def search(self, query: str, score_threshold: float = 0.35): - return self.storage.search(query=query, score_threshold=score_threshold) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters + def search( + self, + query: str, + limit: int = 3, + filters: dict = {}, + score_threshold: float = 0.35, + ): + return self.storage.search( + query=query, limit=limit, filters=filters, score_threshold=score_threshold + ) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters def reset(self) -> None: try: diff --git a/src/crewai/memory/storage/interface.py b/src/crewai/memory/storage/interface.py index 0ffc1de162..ebcda9153b 100644 --- a/src/crewai/memory/storage/interface.py +++ b/src/crewai/memory/storage/interface.py @@ -7,7 +7,9 @@ class Storage: def save(self, value: Any, metadata: Dict[str, Any]) -> None: pass - def search(self, key: str) -> Dict[str, Any]: # type: ignore + def search( + self, query: str, limit: int, filters: Dict, score_threshold: float + ) -> Dict[str, Any]: # type: ignore pass def reset(self) -> None: diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 6af1963709..d6d31582d6 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -92,14 +92,14 @@ def search( # type: ignore # BUG?: Signature of "search" incompatible with supe self, query: str, limit: int = 3, - filter: Optional[dict] = None, + filters: Optional[dict] = None, score_threshold: float = 0.35, ) -> List[Any]: with suppress_logging(): try: results = ( - self.app.search(query, limit, where=filter) - if filter + self.app.search(query, limit, where=filters) + if filters else self.app.search(query, limit) ) except InvalidDimensionException: diff --git a/src/crewai/memory/user/user_memory.py b/src/crewai/memory/user/user_memory.py index c6aaad8134..60f4ddf06b 100644 --- a/src/crewai/memory/user/user_memory.py +++ b/src/crewai/memory/user/user_memory.py @@ -1,5 +1,6 @@ +from typing import Any, Dict, Optional + from crewai.memory.memory import Memory -from crewai.memory.user.user_memory_item import UserMemoryItem from crewai.memory.storage.mem0_storage import Mem0Storage @@ -15,9 +16,22 @@ def __init__(self, crew=None): storage = Mem0Storage(type="user", crew=crew) super().__init__(storage) - def save(self, item: UserMemoryItem) -> None: - data = f"Remember the details about the user: {item.data}" - super().save(data, item.metadata, user=item.user) + def save( + self, + value, + metadata: Optional[Dict[str, Any]] = None, + agent: Optional[str] = None, + ) -> None: + data = f"Remember the details about the user: {value}" + super().save(data, metadata) - def search(self, query: str, score_threshold: float = 0.35): - return self.storage.search(query=query, score_threshold=score_threshold) + def search( + self, + query: str, + limit: int = 3, + filters: dict = {}, + score_threshold: float = 0.35, + ): + return super().search( + query=query, limit=limit, filters=filters, score_threshold=score_threshold + )