From 9b854b83128d436b44c0196774a41dc1aa2a1234 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Tue, 17 Dec 2024 15:47:01 -0800 Subject: [PATCH 01/13] Change memory_id to use strings --- packages/memory_module/core/memory_core.py | 2 +- packages/memory_module/core/memory_module.py | 2 +- .../interfaces/base_memory_core.py | 2 +- .../interfaces/base_memory_module.py | 2 +- .../interfaces/base_memory_storage.py | 4 ++-- packages/memory_module/interfaces/types.py | 4 ++-- .../memory_module/storage/in_memory_storage.py | 16 ++++++++-------- .../11_modify_memories_id_to_text.sql | 17 +++++++++++++++++ .../storage/sqlite_memory_storage.py | 16 +++++++++------- src/bot.py | 2 +- tests/memory_module/test_in_memory_storage.py | 0 tests/memory_module/test_memory_storage.py | 4 ++-- 12 files changed, 45 insertions(+), 26 deletions(-) create mode 100644 packages/memory_module/storage/migrations/11_modify_memories_id_to_text.sql create mode 100644 tests/memory_module/test_in_memory_storage.py diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 64059f67..9f9912c1 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -274,5 +274,5 @@ async def retrieve_chat_history( async def get_memories(self, memory_ids: List[str]) -> List[Memory]: return await self.storage.get_memories(memory_ids) - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: + async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: return await self.storage.get_messages(memory_ids) diff --git a/packages/memory_module/core/memory_module.py b/packages/memory_module/core/memory_module.py index 085d995a..1c743f39 100644 --- a/packages/memory_module/core/memory_module.py +++ b/packages/memory_module/core/memory_module.py @@ -46,7 +46,7 @@ async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Opt async def get_memories(self, memory_ids: 
List[str]) -> List[Memory]: return await self.memory_core.get_memories(memory_ids) - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: + async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: return await self.memory_core.get_messages(memory_ids) async def update_memory(self, memory_id: str, updated_memory: str) -> None: diff --git a/packages/memory_module/interfaces/base_memory_core.py b/packages/memory_module/interfaces/base_memory_core.py index 5d4e336d..101a6098 100644 --- a/packages/memory_module/interfaces/base_memory_core.py +++ b/packages/memory_module/interfaces/base_memory_core.py @@ -38,7 +38,7 @@ async def remove_memories(self, user_id: str) -> None: pass @abstractmethod - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: + async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: """Get messages based on memory ids.""" pass diff --git a/packages/memory_module/interfaces/base_memory_module.py b/packages/memory_module/interfaces/base_memory_module.py index e2905581..a6b7416c 100644 --- a/packages/memory_module/interfaces/base_memory_module.py +++ b/packages/memory_module/interfaces/base_memory_module.py @@ -30,6 +30,6 @@ async def get_memories(self, memory_ids: List[str]) -> List[Memory]: pass @abstractmethod - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: + async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: """Get messages based on memory ids.""" pass diff --git a/packages/memory_module/interfaces/base_memory_storage.py b/packages/memory_module/interfaces/base_memory_storage.py index fa6d49ec..47587c5f 100644 --- a/packages/memory_module/interfaces/base_memory_storage.py +++ b/packages/memory_module/interfaces/base_memory_storage.py @@ -15,7 +15,7 @@ async def store_memory( memory: BaseMemoryInput, *, embedding_vectors: List[List[float]], - ) -> int | None: + ) -> str | None: 
"""Store a memory in the storage system. Args: @@ -65,7 +65,7 @@ async def get_memories(self, memory_ids: List[str]) -> List[Memory]: pass @abstractmethod - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: + async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: """Get messages based on memory ids.""" pass diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index 137f50fa..cb5b8072 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -27,7 +27,7 @@ class Message(BaseModel): class MemoryAttribution(BaseModel): - memory_id: int + memory_id: str message_id: str @@ -51,7 +51,7 @@ class BaseMemoryInput(BaseModel): class Memory(BaseMemoryInput): """Represents a processed memory.""" - id: int + id: str class EmbedText(BaseModel): diff --git a/packages/memory_module/storage/in_memory_storage.py b/packages/memory_module/storage/in_memory_storage.py index eeffac01..a8d5fd0a 100644 --- a/packages/memory_module/storage/in_memory_storage.py +++ b/packages/memory_module/storage/in_memory_storage.py @@ -26,11 +26,11 @@ async def store_memory( memory: BaseMemoryInput, *, embedding_vectors: List[List[float]], - ) -> int | None: + ) -> str | None: memory_id = str(len(self.storage["memories"]) + 1) self.storage["memories"][memory_id] = memory self.storage["embeddings"][memory_id] = embedding_vectors - return int(memory_id) + return memory_id async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_vectors: List[List[float]]) -> None: if memory_id in self.storage["memories"]: @@ -75,10 +75,10 @@ async def get_memories(self, memory_ids: List[str]) -> List[Memory]: if memory_id in self.storage["memories"] ] - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: - messages_dict: Dict[int, List[Message]] = {} + async def get_messages(self, memory_ids: List[str]) -> Dict[str, 
List[Message]]: + messages_dict: Dict[str, List[Message]] = {} for memory_id in memory_ids: - str_id = str(memory_id) + str_id = memory_id if str_id in self.storage["memories"]: memory = self.storage["memories"][str_id] if hasattr(memory, "message_attributions"): @@ -101,11 +101,11 @@ async def clear_memories(self, user_id: str) -> None: ] # remove all memories for user for memory_id in memory_ids_for_user: - self.storage["embeddings"].pop(str(memory_id), None) - self.storage["memories"].pop(str(memory_id), None) + self.storage["embeddings"].pop(memory_id, None) + self.storage["memories"].pop(memory_id, None) async def get_memory(self, memory_id: int) -> Optional[Memory]: - return self.storage["memories"].get(str(memory_id)) + return self.storage["memories"].get(memory_id) async def get_all_memories(self, limit: Optional[int] = None) -> List[Memory]: return [value for key, value in self.storage["memories"].items()][:limit] diff --git a/packages/memory_module/storage/migrations/11_modify_memories_id_to_text.sql b/packages/memory_module/storage/migrations/11_modify_memories_id_to_text.sql new file mode 100644 index 00000000..294dd78e --- /dev/null +++ b/packages/memory_module/storage/migrations/11_modify_memories_id_to_text.sql @@ -0,0 +1,17 @@ +-- Convert memories.id to TEXT +CREATE TABLE memories_new ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + user_id TEXT, + memory_type TEXT NOT NULL DEFAULT 'semantic' +); + +-- Copy data from old table to new table, converting id to TEXT +INSERT INTO memories_new SELECT CAST(id AS TEXT), content, created_at, user_id, memory_type FROM memories; + +-- Drop the old table +DROP TABLE memories; + +-- Rename the new table to the original name +ALTER TABLE memories_new RENAME TO memories; \ No newline at end of file diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index a2af71ed..6b386062 100644 --- 
a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -1,4 +1,5 @@ import logging +import uuid from pathlib import Path from typing import Dict, List, Optional @@ -25,19 +26,22 @@ def __init__(self, db_path: Optional[str | Path] = None): self.db_path = db_path or DEFAULT_DB_PATH self.storage = SQLiteStorage(self.db_path) - async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List[List[float]]) -> int | None: + async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List[List[float]]) -> str: """Store a memory and its message attributions.""" serialized_embeddings = [ sqlite_vec.serialize_float32(embedding_vector) for embedding_vector in embedding_vectors ] + memory_id = str(uuid.uuid4()) + async with self.storage.transaction() as cursor: # Store the memory await cursor.execute( """INSERT INTO memories - (content, created_at, user_id, memory_type) - VALUES (?, ?, ?, ?)""", + (id, content, created_at, user_id, memory_type) + VALUES (?, ?, ?, ?, ?)""", ( + memory_id, memory.content, memory.created_at, memory.user_id, @@ -45,8 +49,6 @@ async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List ), ) - memory_id = cursor.lastrowid - # Store message attributions if memory.message_attributions: await cursor.executemany( @@ -344,7 +346,7 @@ async def get_memories(self, memory_ids: List[str]) -> List[Memory]: return [Memory(**memory_data) for memory_data in memories_dict.values()] - async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: + async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: """Get messages based on memory ids.""" query = """ SELECT ma.memory_id, m.* @@ -357,7 +359,7 @@ async def get_messages(self, memory_ids: List[int]) -> Dict[int, List[Message]]: messages_dict = {} for row in rows: - memory_id = int(row["memory_id"]) + memory_id = row["memory_id"] if memory_id not in 
messages_dict: messages_dict[memory_id] = [] diff --git a/src/bot.py b/src/bot.py index fe549149..83dd151a 100644 --- a/src/bot.py +++ b/src/bot.py @@ -159,7 +159,7 @@ async def confirm_memorized_fields(fields_to_confirm: ConfirmMemorizedFields, co # group memories by field name user_details_with_memories: List[tuple[UserDetail, Memory | None]] = [] for user_detail in fields_to_confirm.fields_to_confirm: - memories_for_user_detail = [memory for memory in memories if str(memory.id) in user_detail.memory_ids] + memories_for_user_detail = [memory for memory in memories if memory.id in user_detail.memory_ids] # just take the first one into account for citation (for now) user_details_with_memories.append( (user_detail, memories_for_user_detail[0] if memories_for_user_detail else None) diff --git a/tests/memory_module/test_in_memory_storage.py b/tests/memory_module/test_in_memory_storage.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/memory_module/test_memory_storage.py b/tests/memory_module/test_memory_storage.py index 999747a4..1a061e10 100644 --- a/tests/memory_module/test_memory_storage.py +++ b/tests/memory_module/test_memory_storage.py @@ -73,7 +73,7 @@ async def test_update_memory(memory_storage, sample_memory_input, sample_embeddi # Update memory updated_content = "Updated memory content" - await memory_storage.update_memory(str(memory_id), updated_content, embedding_vectors=sample_embedding) + await memory_storage.update_memory(memory_id, updated_content, embedding_vectors=sample_embedding) # Verify update updated_memory = await memory_storage.get_memory(memory_id) @@ -173,7 +173,7 @@ async def test_get_memories_by_ids(memory_storage, sample_memory_input, sample_e memory_id = await memory_storage.store_memory(sample_memory_input, embedding_vectors=sample_embedding) # Retrieve by ID - memories = await memory_storage.get_memories([str(memory_id)]) + memories = await memory_storage.get_memories([memory_id]) assert len(memories) == 1 assert 
memories[0].content == sample_memory_input.content From 80e2b93e7453bfb2f7085f464ebca1c452b51b99 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Tue, 17 Dec 2024 16:18:32 -0800 Subject: [PATCH 02/13] Fix tests --- tests/memory_module/test_memory_module.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index a0459d6d..5dd160a6 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -148,7 +148,7 @@ async def test_simple_conversation(memory_module): result = await memory_module.retrieve_memories("apple pie", "", 1) assert len(result) == 1 - assert result[0].id == 2 + assert result[0].id == next(memory.id for memory in stored_memories if "apple pie" in memory.content) @pytest.mark.asyncio @@ -239,8 +239,9 @@ async def test_update_memory(memory_module): stored_memories = await memory_module.memory_core.storage.get_all_memories() assert len(stored_memories) >= 1 - await memory_module.update_memory(1, "The user like San Diego city") - updated_message = await memory_module.memory_core.storage.get_memory(1) + memory_id = next(memory.id for memory in stored_memories if "Seattle" in memory.content) + await memory_module.update_memory(memory_id, "The user like San Diego city") + updated_message = await memory_module.memory_core.storage.get_memory(memory_id) assert "San Diego" in updated_message.content From 9e892ef9041eab28dc141f0e8deada94cdaa1e9a Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Tue, 17 Dec 2024 16:19:26 -0800 Subject: [PATCH 03/13] Add type to message --- packages/evals/benchmark_memory_module.py | 8 +++--- packages/memory_module/core/memory_core.py | 2 +- packages/memory_module/interfaces/types.py | 4 +-- .../migrations/12_change_message_type.sql | 25 +++++++++++++++++++ .../storage/sqlite_memory_storage.py | 4 +-- .../storage/sqlite_message_buffer_storage.py | 2 +- src/bot.py | 10 ++++---- 
tests/memory_module/test_memory_module.py | 8 ++++-- tests/memory_module/test_memory_storage.py | 6 ++--- tests/memory_module/utils.py | 1 + 10 files changed, 50 insertions(+), 20 deletions(-) create mode 100644 packages/memory_module/storage/migrations/12_change_message_type.sql diff --git a/packages/evals/benchmark_memory_module.py b/packages/evals/benchmark_memory_module.py index 462d731b..9f88b49b 100644 --- a/packages/evals/benchmark_memory_module.py +++ b/packages/evals/benchmark_memory_module.py @@ -52,16 +52,16 @@ def create_message(**kwargs): return Message( id=str(uuid.uuid4()), content=kwargs["content"], - is_assistant_message=kwargs["is_assistant_message"], - # author_id="user" if not kwargs["is_assistant_message"] else None, + type=kwargs["type"], + # author_id="user" if kwargs["type"] == "user" else None, author_id="user", created_at=datetime.now(), conversation_ref="conversation_ref", ) for message in messages: - is_assistant_message = message["role"] == "assistant" - msg = create_message(content=message["content"], is_assistant_message=is_assistant_message) + type = "assistant" if message["role"] == "assistant" else "user" + msg = create_message(content=message["content"], type=type) await memory_module.add_message(msg) diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index 9f9912c1..a5349f73 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -198,7 +198,7 @@ async def _extract_semantic_fact_from_messages( logger.info("Extracting semantic facts from messages") messages_str = "" for idx, message in enumerate(messages): - if not message.is_assistant_message: + if message.type == "user": messages_str += f"{idx}. User: {message.content}\n" else: messages_str += f"{idx}. 
Assistant: {message.content}\n" diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index cb5b8072..250e964e 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -1,7 +1,7 @@ from datetime import datetime from decimal import Decimal from enum import Enum -from typing import List, Optional +from typing import List, Literal, Optional from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -22,7 +22,7 @@ class Message(BaseModel): author_id: Optional[str] conversation_ref: str created_at: datetime - is_assistant_message: bool = False + type: Literal["user", "assistant", "internal"] | None = None deep_link: Optional[str] = None diff --git a/packages/memory_module/storage/migrations/12_change_message_type.sql b/packages/memory_module/storage/migrations/12_change_message_type.sql new file mode 100644 index 00000000..0ef9dd83 --- /dev/null +++ b/packages/memory_module/storage/migrations/12_change_message_type.sql @@ -0,0 +1,25 @@ +-- Create new table with desired schema +CREATE TABLE messages_new ( + id TEXT NOT NULL, + content TEXT NOT NULL, + author_id TEXT NOT NULL, + conversation_ref TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + deep_link TEXT, + type TEXT NOT NULL +); + +-- Copy data from old table to new table +INSERT INTO messages_new (id, content, author_id, conversation_ref, created_at, deep_link, type) +SELECT id, content, author_id, conversation_ref, created_at, deep_link, + CASE + WHEN is_assistant_message = 1 THEN 'assistant' + ELSE 'user' + END +FROM messages; + +-- Drop old table +DROP TABLE messages; + +-- Rename new table to original name +ALTER TABLE messages_new RENAME TO messages; \ No newline at end of file diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index 6b386062..3590eedf 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ 
b/packages/memory_module/storage/sqlite_memory_storage.py @@ -278,7 +278,7 @@ async def store_short_term_memory(self, message: Message) -> None: author_id, conversation_ref, created_at, - is_assistant_message, + type, deep_link ) VALUES (?, ?, ?, ?, ?, ?, ?)""", ( @@ -287,7 +287,7 @@ async def store_short_term_memory(self, message: Message) -> None: message.author_id, message.conversation_ref, message.created_at, - message.is_assistant_message, + message.type, message.deep_link, ), ) diff --git a/packages/memory_module/storage/sqlite_message_buffer_storage.py b/packages/memory_module/storage/sqlite_message_buffer_storage.py index bce5699c..0c71432e 100644 --- a/packages/memory_module/storage/sqlite_message_buffer_storage.py +++ b/packages/memory_module/storage/sqlite_message_buffer_storage.py @@ -50,7 +50,7 @@ async def get_buffered_messages(self, conversation_ref: str) -> List[Message]: m.author_id, m.conversation_ref, m.created_at, - m.is_assistant_message, + m.type, m.deep_link FROM buffered_messages b JOIN messages m ON b.message_id = m.id diff --git a/src/bot.py b/src/bot.py index 83dd151a..bf904231 100644 --- a/src/bot.py +++ b/src/bot.py @@ -67,7 +67,7 @@ class TaskConfig(BaseModel): task_name="troubleshoot_device_issue", required_fields=["OS", "Device Type", "Year"] ), "troubleshoot_connectivity_issue": TaskConfig( - task_name="troubleshoot_connectivity_issue", required_fields=["OS", "Device Type", "Year"] + task_name="troubleshoot_connectivity_issue", required_fields=["OS", "Device Type", "Router Location"] ), "troubleshoot_access_issue": TaskConfig( task_name="troubleshoot_access_issue", required_fields=["OS", "Device Type", "Year"] @@ -288,7 +288,7 @@ def build_deep_link(context: TurnContext, message_id: str): async def add_message( context: TurnContext, content: str, - is_assistant_message: bool, + message_type: Literal["user", "assistant", "internal"] = "user", created_at: datetime.datetime | None = None, override_message_id: bool = False, ): @@ -314,7 
+314,7 @@ async def add_message( author_id=user_aad_object_id, conversation_ref=conversation_ref_dict.conversation.id, created_at=created_at or datetime.datetime.now(datetime.timezone.utc), - is_assistant_message=is_assistant_message, + type=type, deep_link=build_deep_link(context, context.activity.id), ) ) @@ -344,7 +344,7 @@ async def on_message(context: TurnContext, state: TurnState): Note: Step 2 - Gather necessary information for the selected task. To gather missing fields for the task: Step 2a: Use the "get_memorized_fields" function to check if any required fields are already known. - Step 2b (If necessary): Use the "confirm_memorized_fields" function to confirm the fields if they are already known. You should only call this if the user has alraedy previously provided this information. + Step 2b (If necessary): Use the "confirm_memorized_fields" function to confirm the fields if they are already known. Step 2c (If necessary): For each missing field, prompt the user to provide the required information. Note: Step 3 - Execute the task. 
@@ -373,7 +373,7 @@ async def on_message(context: TurnContext, state: TurnState): }, *[ { - "role": "assistant" if message.is_assistant_message else "user", + "role": "user" if message.type == "user" else "assistant", "content": message.content, } for message in messages diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index 5dd160a6..0e29e887 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -126,6 +126,7 @@ async def test_simple_conversation(memory_module): author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), + type="user", ), Message( id=str(uuid4()), @@ -133,6 +134,7 @@ async def test_simple_conversation(memory_module): author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), + type="user", ), ] @@ -206,7 +208,7 @@ async def mock_extract_episodic(*args, **kwargs): author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - role="user", + type="user", ) for i in range(3) ] @@ -229,6 +231,7 @@ async def test_update_memory(memory_module): author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), + type="user", ), ] @@ -256,6 +259,7 @@ async def test_remove_memory(memory_module): author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), + type="user", ), ] @@ -282,7 +286,7 @@ async def test_short_term_memory(memory_module): author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - role="user", + type="user", ) for i in range(3) ] diff --git a/tests/memory_module/test_memory_storage.py b/tests/memory_module/test_memory_storage.py index 1a061e10..8e8d663f 100644 --- a/tests/memory_module/test_memory_storage.py +++ b/tests/memory_module/test_memory_storage.py @@ -41,7 +41,7 @@ def sample_message(): author_id="user1", conversation_ref="conv1", created_at=datetime.now(), - is_assistant_message=False, + 
type="user", ) @@ -188,7 +188,7 @@ async def test_get_messages(memory_storage): author_id="user1", conversation_ref="conv1", created_at=datetime.now(), - is_assistant_message=False, + type="user", deep_link="link1", ), Message( @@ -197,7 +197,7 @@ async def test_get_messages(memory_storage): author_id="user1", conversation_ref="conv1", created_at=datetime.now(), - is_assistant_message=True, + type="assistant", deep_link="link2", ), ] diff --git a/tests/memory_module/utils.py b/tests/memory_module/utils.py index 188973ff..72606b61 100644 --- a/tests/memory_module/utils.py +++ b/tests/memory_module/utils.py @@ -19,6 +19,7 @@ def create_test_message(content: str): conversation_ref="123", created_at=datetime.now(), content=content, + type="user", ) From 0aceaeff65ac9ae51bfe779fde84159a29d4b489 Mon Sep 17 00:00:00 2001 From: Aamir <48929123+heyitsaamir@users.noreply.github.com> Date: Tue, 17 Dec 2024 19:39:29 -0800 Subject: [PATCH 04/13] Update src/bot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bot.py b/src/bot.py index bf904231..581d3572 100644 --- a/src/bot.py +++ b/src/bot.py @@ -314,7 +314,7 @@ async def add_message( author_id=user_aad_object_id, conversation_ref=conversation_ref_dict.conversation.id, created_at=created_at or datetime.datetime.now(datetime.timezone.utc), - type=type, + type=message_type, deep_link=build_deep_link(context, context.activity.id), ) ) From e4c6e1294f314e38872f0726f03be658767dc34c Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Wed, 18 Dec 2024 12:07:57 -0800 Subject: [PATCH 05/13] Split message types --- packages/memory_module/__init__.py | 14 +++ packages/memory_module/core/memory_core.py | 12 +- packages/memory_module/core/memory_module.py | 16 ++- .../interfaces/base_memory_core.py | 4 +- .../interfaces/base_memory_module.py | 4 +- .../interfaces/base_memory_storage.py | 11 +- packages/memory_module/interfaces/types.py | 
49 ++++++-- .../storage/in_memory_storage.py | 80 +++++++++++-- .../storage/migrations_manager.py | 1 + .../storage/sqlite_memory_storage.py | 67 +++++++---- .../storage/sqlite_message_buffer_storage.py | 11 +- packages/memory_module/storage/utils.py | 16 +++ src/bot.py | 107 +++++++++++++----- tests/memory_module/test_memory_storage.py | 13 +-- 14 files changed, 306 insertions(+), 99 deletions(-) create mode 100644 packages/memory_module/storage/utils.py diff --git a/packages/memory_module/__init__.py b/packages/memory_module/__init__.py index ff6b4f3d..f7eea423 100644 --- a/packages/memory_module/__init__.py +++ b/packages/memory_module/__init__.py @@ -1,9 +1,16 @@ from memory_module.config import LLMConfig, MemoryModuleConfig from memory_module.core.memory_module import MemoryModule from memory_module.interfaces.types import ( + AssistantMessage, + AssistantMessageInput, + InternalMessage, + InternalMessageInput, Memory, Message, + MessageInput, ShortTermMemoryRetrievalConfig, + UserMessage, + UserMessageInput, ) __all__ = [ @@ -11,6 +18,13 @@ "MemoryModuleConfig", "LLMConfig", "Memory", + "InternalMessage", + "InternalMessageInput", + "UserMessageInput", + "UserMessage", "Message", + "MessageInput", + "AssistantMessage", + "AssistantMessageInput", "ShortTermMemoryRetrievalConfig", ] diff --git a/packages/memory_module/core/memory_core.py b/packages/memory_module/core/memory_core.py index a5349f73..c38d0c20 100644 --- a/packages/memory_module/core/memory_core.py +++ b/packages/memory_module/core/memory_core.py @@ -14,6 +14,7 @@ Memory, MemoryType, Message, + MessageInput, ShortTermMemoryRetrievalConfig, ) from memory_module.services.llm_service import LLMService @@ -94,7 +95,7 @@ def __init__( storage: Optional storage implementation for memory persistence """ self.lm = llm_service - self.storage = storage or ( + self.storage: BaseMemoryStorage = storage or ( SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage() ) @@ 
-200,8 +201,11 @@ async def _extract_semantic_fact_from_messages( for idx, message in enumerate(messages): if message.type == "user": messages_str += f"{idx}. User: {message.content}\n" - else: + elif message.type == "assistant": messages_str += f"{idx}. Assistant: {message.content}\n" + else: + # we explicitly ignore internal messages + continue system_message = f"""You are a semantic memory management agent. Your goal is to extract meaningful, facts and preferences from user messages. Focus on recognizing general patterns and interests that will remain relevant over time, even if the user is mentioning short-term plans or events. @@ -262,8 +266,8 @@ async def _extract_episodic_memory_from_messages(self, messages: List[Message]) return await self.lm.completion(messages=messages, response_model=EpisodicMemoryExtraction) - async def add_short_term_memory(self, message: Message) -> None: - await self.storage.store_short_term_memory(message) + async def add_short_term_memory(self, message: MessageInput) -> Message: + return await self.storage.store_short_term_memory(message) async def retrieve_chat_history( self, conversation_ref: str, config: ShortTermMemoryRetrievalConfig diff --git a/packages/memory_module/core/memory_module.py b/packages/memory_module/core/memory_module.py index 1c743f39..97bd993f 100644 --- a/packages/memory_module/core/memory_module.py +++ b/packages/memory_module/core/memory_module.py @@ -6,7 +6,7 @@ from memory_module.interfaces.base_memory_core import BaseMemoryCore from memory_module.interfaces.base_memory_module import BaseMemoryModule from memory_module.interfaces.base_message_queue import BaseMessageQueue -from memory_module.interfaces.types import Memory, Message, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig from memory_module.services.llm_service import LLMService @@ -31,13 +31,17 @@ def __init__( self.config = config self.llm_service = 
llm_service or LLMService(config=config.llm) - self.memory_core = memory_core or MemoryCore(config=config, llm_service=self.llm_service) - self.message_queue = message_queue or MessageQueue(config=config, memory_core=self.memory_core) + self.memory_core: BaseMemoryCore = memory_core or MemoryCore(config=config, llm_service=self.llm_service) + self.message_queue: BaseMessageQueue = message_queue or MessageQueue( + config=config, memory_core=self.memory_core + ) - async def add_message(self, message: Message) -> None: + async def add_message(self, message: MessageInput) -> Message: """Add a message to be processed into memory.""" - await self.memory_core.add_short_term_memory(message) - await self.message_queue.enqueue(message) + message_res = await self.memory_core.add_short_term_memory(message) + await self.message_queue.enqueue(message_res) + + return message_res async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]: """Retrieve relevant memories based on a query.""" diff --git a/packages/memory_module/interfaces/base_memory_core.py b/packages/memory_module/interfaces/base_memory_core.py index 101a6098..3b743cde 100644 --- a/packages/memory_module/interfaces/base_memory_core.py +++ b/packages/memory_module/interfaces/base_memory_core.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import Memory, Message, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig class BaseMemoryCore(ABC): @@ -43,7 +43,7 @@ async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: pass @abstractmethod - async def add_short_term_memory(self, message: Message) -> None: + async def add_short_term_memory(self, message: MessageInput) -> Message: """Add a short-term memory entry.""" pass diff --git 
a/packages/memory_module/interfaces/base_memory_module.py b/packages/memory_module/interfaces/base_memory_module.py index a6b7416c..fdde341e 100644 --- a/packages/memory_module/interfaces/base_memory_module.py +++ b/packages/memory_module/interfaces/base_memory_module.py @@ -1,14 +1,14 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import Memory, Message, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig class BaseMemoryModule(ABC): """Base class for the memory module interface.""" @abstractmethod - async def add_message(self, message: Message) -> None: + async def add_message(self, message: MessageInput) -> Message: """Add a message to be processed into memory.""" pass diff --git a/packages/memory_module/interfaces/base_memory_storage.py b/packages/memory_module/interfaces/base_memory_storage.py index 47587c5f..a22a9360 100644 --- a/packages/memory_module/interfaces/base_memory_storage.py +++ b/packages/memory_module/interfaces/base_memory_storage.py @@ -1,7 +1,14 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from memory_module.interfaces.types import BaseMemoryInput, EmbedText, Memory, Message, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + BaseMemoryInput, + EmbedText, + Memory, + Message, + MessageInput, + ShortTermMemoryRetrievalConfig, +) class BaseMemoryStorage(ABC): @@ -30,7 +37,7 @@ async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_ pass @abstractmethod - async def store_short_term_memory(self, message: Message) -> None: + async def store_short_term_memory(self, message: MessageInput) -> Message: """Store a short-term memory entry. 
Args: diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index 250e964e..b1a3acf6 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -1,7 +1,8 @@ +from abc import ABC from datetime import datetime from decimal import Decimal from enum import Enum -from typing import List, Literal, Optional +from typing import ClassVar, List, Optional from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -12,18 +13,52 @@ class User(BaseModel): id: str -class Message(BaseModel): - """Represents a message in a conversation.""" +class BaseMessageInput(ABC, BaseModel): + content: str + author_id: str + conversation_ref: str + +class InternalMessageInput(BaseMessageInput): + model_config = ConfigDict(from_attributes=True) + type: ClassVar = "internal" + created_at: Optional[datetime] = None + + +class InternalMessage(InternalMessageInput): model_config = ConfigDict(from_attributes=True) id: str - content: str - author_id: Optional[str] - conversation_ref: str + created_at: datetime # type: ignore Ignoring because this will exist in the concrete class + + +class UserMessageInput(BaseMessageInput): + model_config = ConfigDict(from_attributes=True) + id: str + type: ClassVar = "user" + deep_link: Optional[str] = None created_at: datetime - type: Literal["user", "assistant", "internal"] | None = None + + +class UserMessage(UserMessageInput): + model_config = ConfigDict(from_attributes=True) + + +class AssistantMessageInput(BaseMessageInput): + model_config = ConfigDict(from_attributes=True) + id: str + type: ClassVar = "assistant" deep_link: Optional[str] = None + created_at: Optional[datetime] = None + + +class AssistantMessage(AssistantMessageInput): + model_config = ConfigDict(from_attributes=True) + created_at: datetime # type: ignore Ignoring because this will exist in the concrete class + + +type MessageInput = InternalMessageInput | UserMessageInput | 
AssistantMessageInput +type Message = InternalMessage | UserMessage | AssistantMessage class MemoryAttribution(BaseModel): diff --git a/packages/memory_module/storage/in_memory_storage.py b/packages/memory_module/storage/in_memory_storage.py index a8d5fd0a..b0e80a9b 100644 --- a/packages/memory_module/storage/in_memory_storage.py +++ b/packages/memory_module/storage/in_memory_storage.py @@ -1,6 +1,7 @@ import datetime +import uuid from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TypedDict import numpy as np from memory_module.interfaces.base_memory_storage import BaseMemoryStorage @@ -9,12 +10,32 @@ ) from memory_module.interfaces.base_scheduled_events_service import Event from memory_module.interfaces.base_scheduled_events_storage import BaseScheduledEventsStorage -from memory_module.interfaces.types import BaseMemoryInput, EmbedText, Memory, Message, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + AssistantMessage, + AssistantMessageInput, + BaseMemoryInput, + EmbedText, + InternalMessage, + InternalMessageInput, + Memory, + Message, + MessageInput, + ShortTermMemoryRetrievalConfig, + UserMessage, + UserMessageInput, +) + + +class InMemoryInternalStore(TypedDict): + memories: Dict[str, Memory] + embeddings: Dict[str, List[List[float]]] + buffered_messages: Dict[str, List[Message]] + scheduled_events: Dict[str, Event] class InMemoryStorage(BaseMemoryStorage, BaseMessageBufferStorage, BaseScheduledEventsStorage): def __init__(self): - self.storage: Dict = { + self.storage: InMemoryInternalStore = { "embeddings": {}, "buffered_messages": defaultdict(list), "scheduled_events": {}, @@ -28,7 +49,8 @@ async def store_memory( embedding_vectors: List[List[float]], ) -> str | None: memory_id = str(len(self.storage["memories"]) + 1) - self.storage["memories"][memory_id] = memory + memory_obj = Memory(**memory.model_dump(), id=memory_id) + self.storage["memories"][memory_id] 
= memory_obj self.storage["embeddings"][memory_id] = embedding_vectors return memory_id @@ -37,8 +59,48 @@ async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_ self.storage["memories"][memory_id].content = updated_memory self.storage["embeddings"][memory_id] = embedding_vectors - async def store_short_term_memory(self, message: Message) -> None: - await self.store_buffered_message(message) + async def store_short_term_memory(self, message: MessageInput) -> Message: + if isinstance(message, InternalMessageInput): + id = str(uuid.uuid4()) + else: + id = message.id + + created_at = message.created_at or datetime.datetime.now() + + if isinstance(message, InternalMessageInput): + deep_link = None + else: + deep_link = message.deep_link + + if isinstance(message, UserMessageInput): + message_obj = UserMessage( + id=id, + content=message.content, + created_at=created_at, + conversation_ref=message.conversation_ref, + deep_link=deep_link, + author_id=message.author_id, + ) + elif isinstance(message, AssistantMessageInput): + message_obj = AssistantMessage( + id=id, + content=message.content, + created_at=created_at, + conversation_ref=message.conversation_ref, + deep_link=deep_link, + author_id=message.author_id, + ) + else: + message_obj = InternalMessage( + id=id, + content=message.content, + created_at=created_at, + conversation_ref=message.conversation_ref, + author_id=message.author_id, + ) + + await self.store_buffered_message(message_obj) + return message_obj async def retrieve_memories( self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None @@ -66,7 +128,7 @@ async def retrieve_memories( ) sorted_memories.sort(key=lambda x: x["distance"], reverse=True) - return [Memory(id=item["id"], **item["memory"].__dict__) for item in sorted_memories[:limit]] + return [Memory(**item["memory"].__dict__) for item in sorted_memories[:limit]] async def get_memories(self, memory_ids: List[str]) -> List[Memory]: return [ @@ -81,7 
+143,7 @@ async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: str_id = memory_id if str_id in self.storage["memories"]: memory = self.storage["memories"][str_id] - if hasattr(memory, "message_attributions"): + if memory.message_attributions: messages = [] for msg_id in memory.message_attributions: # Search through buffered messages to find matching message @@ -104,7 +166,7 @@ async def clear_memories(self, user_id: str) -> None: self.storage["embeddings"].pop(memory_id, None) self.storage["memories"].pop(memory_id, None) - async def get_memory(self, memory_id: int) -> Optional[Memory]: + async def get_memory(self, memory_id: str) -> Optional[Memory]: return self.storage["memories"].get(memory_id) async def get_all_memories(self, limit: Optional[int] = None) -> List[Memory]: diff --git a/packages/memory_module/storage/migrations_manager.py b/packages/memory_module/storage/migrations_manager.py index 602e9866..5fbb7f8b 100644 --- a/packages/memory_module/storage/migrations_manager.py +++ b/packages/memory_module/storage/migrations_manager.py @@ -7,6 +7,7 @@ import sqlite_vec logger = logging.getLogger(__name__) +print(sqlite3.sqlite_version) class MigrationManager: diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index 3590eedf..52a6bf37 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -1,3 +1,4 @@ +import datetime import logging import uuid from pathlib import Path @@ -8,11 +9,14 @@ from memory_module.interfaces.types import ( BaseMemoryInput, EmbedText, + InternalMessageInput, Memory, Message, + MessageInput, ShortTermMemoryRetrievalConfig, ) from memory_module.storage.sqlite_storage import SQLiteStorage +from memory_module.storage.utils import build_message_from_dict logger = logging.getLogger(__name__) @@ -268,29 +272,44 @@ async def get_all_memories(self, limit: 
Optional[int] = None) -> List[Memory]: return [Memory(**memory_data) for memory_data in memories_dict.values()] - async def store_short_term_memory(self, message: Message) -> None: + async def store_short_term_memory(self, message: MessageInput) -> Message: """Store a short-term memory entry.""" - async with self.storage.transaction() as cursor: - await cursor.execute( - """INSERT INTO messages ( - id, - content, - author_id, - conversation_ref, - created_at, - type, - deep_link - ) VALUES (?, ?, ?, ?, ?, ?, ?)""", - ( - message.id, - message.content, - message.author_id, - message.conversation_ref, - message.created_at, - message.type, - message.deep_link, - ), - ) + if isinstance(message, InternalMessageInput): + id = str(uuid.uuid4()) + else: + id = message.id + + created_at = message.created_at or datetime.datetime.now() + + if isinstance(message, InternalMessageInput): + deep_link = None + else: + deep_link = message.deep_link + await self.storage.execute( + """INSERT INTO messages ( + id, + content, + author_id, + conversation_ref, + created_at, + type, + deep_link + ) VALUES (?, ?, ?, ?, ?, ?, ?)""", + ( + id, + message.content, + message.author_id, + message.conversation_ref, + created_at, + message.type, + deep_link, + ), + ) + + row = await self.storage.fetch_one("SELECT * FROM messages WHERE id = ?", (id,)) + if not row: + raise ValueError(f"Message with id {id} not found in storage") + return build_message_from_dict(row) async def retrieve_chat_history( self, conversation_ref: str, config: ShortTermMemoryRetrievalConfig @@ -309,7 +328,7 @@ async def retrieve_chat_history( rows = await self.storage.fetch_all(query, params) - return [Message(**row) for row in rows][::-1] + return [build_message_from_dict(row) for row in rows][::-1] async def get_memories(self, memory_ids: List[str]) -> List[Memory]: query = """ @@ -363,6 +382,6 @@ async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]: if memory_id not in messages_dict: 
messages_dict[memory_id] = [] - messages_dict[memory_id].append(Message(**row)) + messages_dict[memory_id].append(build_message_from_dict(row)) return messages_dict diff --git a/packages/memory_module/storage/sqlite_message_buffer_storage.py b/packages/memory_module/storage/sqlite_message_buffer_storage.py index 0c71432e..edd5b91e 100644 --- a/packages/memory_module/storage/sqlite_message_buffer_storage.py +++ b/packages/memory_module/storage/sqlite_message_buffer_storage.py @@ -7,6 +7,7 @@ ) from memory_module.interfaces.types import Message from memory_module.storage.sqlite_storage import SQLiteStorage +from memory_module.storage.utils import build_message_from_dict logger = logging.getLogger(__name__) @@ -45,20 +46,14 @@ async def get_buffered_messages(self, conversation_ref: str) -> List[Message]: """Retrieve all buffered messages for a conversation.""" query = """ SELECT - m.id, - m.content, - m.author_id, - m.conversation_ref, - m.created_at, - m.type, - m.deep_link + m.* FROM buffered_messages b JOIN messages m ON b.message_id = m.id WHERE b.conversation_ref = ? 
ORDER BY b.created_at ASC """ results = await self.storage.fetch_all(query, (conversation_ref,)) - return [Message(**row) for row in results] + return [build_message_from_dict(row) for row in results] async def clear_buffered_messages(self, conversation_ref: str) -> None: """Remove all buffered messages for a conversation.""" diff --git a/packages/memory_module/storage/utils.py b/packages/memory_module/storage/utils.py new file mode 100644 index 00000000..b0ee5fb0 --- /dev/null +++ b/packages/memory_module/storage/utils.py @@ -0,0 +1,16 @@ +from typing import Dict + +from memory_module.interfaces.types import AssistantMessage, InternalMessage, Message, UserMessage + + +def build_message_from_dict(row: Dict) -> Message: + """Build a message object from a dictionary which contains the message data.""" + + if row["type"] == "internal": + return InternalMessage(**row) + elif row["type"] == "user": + return UserMessage(**row) + elif row["type"] == "assistant": + return AssistantMessage(**row) + else: + raise ValueError(f"Invalid message type: {row['type']}") diff --git a/src/bot.py b/src/bot.py index 581d3572..57263cb5 100644 --- a/src/bot.py +++ b/src/bot.py @@ -3,7 +3,6 @@ import os import sys import traceback -import uuid from typing import List, Literal sys.path.append(os.path.join(os.path.dirname(__file__), "../packages")) @@ -11,7 +10,15 @@ from botbuilder.core import CardFactory, MemoryStorage, TurnContext from botbuilder.schema import Activity from litellm import acompletion -from memory_module import LLMConfig, Memory, MemoryModule, MemoryModuleConfig, Message +from memory_module import ( + AssistantMessageInput, + InternalMessageInput, + LLMConfig, + Memory, + MemoryModule, + MemoryModuleConfig, + UserMessageInput, +) from pydantic import BaseModel, Field from teams import Application, ApplicationOptions, TeamsAdapter from teams.ai.citations import AIEntity, Appearance, ClientCitation @@ -207,11 +214,13 @@ async def confirm_memorized_fields(fields_to_confirm: 
ConfirmMemorizedFields, co return json.dumps(fields_to_confirm.model_dump()) -async def send_string_message(context: TurnContext, message: str): +async def send_string_message(context: TurnContext, message: str) -> str | None: activity = Activity( type="message", text=message, entities=[AIEntity(additional_type=["AIGeneratedContent"], citation=[])] ) - await context.send_activity(activity) + res = await context.send_activity(activity) + if res: + return res.id async def execute_task(task_name: ExecuteTask, context: TurnContext) -> str: @@ -285,51 +294,97 @@ def build_deep_link(context: TurnContext, message_id: str): return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" -async def add_message( - context: TurnContext, - content: str, - message_type: Literal["user", "assistant", "internal"] = "user", - created_at: datetime.datetime | None = None, - override_message_id: bool = False, -): +async def add_user_message(context: TurnContext): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + content = context.activity.text + if not content: + print("content is not text, so ignoring...") + return False if conversation_ref_dict is None: print("conversation_ref_dict is None") return False if conversation_ref_dict.user is None: print("conversation_ref_dict.user is None") return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + user_aad_object_id = conversation_ref_dict.user.aad_object_id + message_id = context.activity.id + await memory_module.add_message( + UserMessageInput( + id=message_id, + content=context.activity.text, + author_id=user_aad_object_id, + conversation_ref=conversation_ref_dict.conversation.id, + created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), + deep_link=build_deep_link(context, context.activity.id), + ) + ) + return True + 
+ +async def add_agent_message(context: TurnContext, message_id: str, content: str): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + if not content: + print("content is not text, so ignoring...") + return False + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False if conversation_ref_dict.bot is None: print("conversation_ref_dict.bot is None") return False if conversation_ref_dict.conversation is None: print("conversation_ref_dict.conversation is None") return False - user_aad_object_id = conversation_ref_dict.user.aad_object_id - message_id = str(uuid.uuid4()) if override_message_id else context.activity.id await memory_module.add_message( - Message( + AssistantMessageInput( id=message_id, content=content, - author_id=user_aad_object_id, + author_id=conversation_ref_dict.bot.id, + conversation_ref=conversation_ref_dict.conversation.id, + deep_link=build_deep_link(context, message_id), + ) + ) + return True + + +async def add_internal_message(context: TurnContext, content: str): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + if not content: + print("content is not text, so ignoring...") + return False + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.bot is None: + print("conversation_ref_dict.bot is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + await memory_module.add_message( + InternalMessageInput( + content=content, + author_id=conversation_ref_dict.bot.id, conversation_ref=conversation_ref_dict.conversation.id, - created_at=created_at or datetime.datetime.now(datetime.timezone.utc), - type=message_type, - deep_link=build_deep_link(context, context.activity.id), ) ) + return True @bot_app.conversation_update("membersAdded") async def on_members_added(context: TurnContext, state: 
TurnState): - await send_string_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?") - await add_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?", True) + result = await send_string_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?") + if result: + await add_agent_message(context, result, "Hello! I'm your IT Support Assistant. How can I assist you today?") @bot_app.activity("message") async def on_message(context: TurnContext, state: TurnState): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - await add_message(context, context.activity.text, False) + await add_user_message(context) system_prompt = """ You are an IT Chat Bot that helps users troubleshoot tasks @@ -394,8 +449,9 @@ async def on_message(context: TurnContext, state: TurnState): message = response.choices[0].message if message.tool_calls is None and message.content is not None: - await add_message(context, message.content, True, datetime.datetime.now(datetime.timezone.utc), True) - await send_string_message(context, message.content) + agent_message_id = await send_string_message(context, message.content) + if agent_message_id: + await add_agent_message(context, agent_message_id, message.content) break elif message.tool_calls is None and message.content is None: print("No tool calls and no content") @@ -431,7 +487,7 @@ async def on_message(context: TurnContext, state: TurnState): } ) llm_messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(res)}) - await add_message( + await add_internal_message( context, json.dumps( { @@ -439,9 +495,6 @@ async def on_message(context: TurnContext, state: TurnState): "result": res, } ), - True, - datetime.datetime.now(datetime.timezone.utc), - True, ) else: break diff --git a/tests/memory_module/test_memory_storage.py b/tests/memory_module/test_memory_storage.py index 8e8d663f..06fd0562 100644 --- 
a/tests/memory_module/test_memory_storage.py +++ b/tests/memory_module/test_memory_storage.py @@ -3,11 +3,12 @@ import pytest from memory_module.interfaces.types import ( + AssistantMessageInput, BaseMemoryInput, EmbedText, MemoryType, - Message, ShortTermMemoryRetrievalConfig, + UserMessageInput, ) from memory_module.storage.in_memory_storage import InMemoryStorage from memory_module.storage.sqlite_memory_storage import SQLiteMemoryStorage @@ -35,13 +36,12 @@ def sample_memory_input(): @pytest.fixture def sample_message(): - return Message( + return UserMessageInput( id="msg1", content="Test message", author_id="user1", conversation_ref="conv1", created_at=datetime.now(), - type="user", ) @@ -182,22 +182,19 @@ async def test_get_memories_by_ids(memory_storage, sample_memory_input, sample_e async def test_get_messages(memory_storage): # Test data test_messages = [ - Message( + UserMessageInput( id="msg1", content="Test message 1", author_id="user1", conversation_ref="conv1", created_at=datetime.now(), - type="user", deep_link="link1", ), - Message( + AssistantMessageInput( id="msg2", content="Test message 2", author_id="user1", conversation_ref="conv1", - created_at=datetime.now(), - type="assistant", deep_link="link2", ), ] From 65a466a76f54b60736bc5e3a90e8a937a333ccf9 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Wed, 18 Dec 2024 12:12:06 -0800 Subject: [PATCH 06/13] Revert print statement --- packages/memory_module/storage/migrations_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/memory_module/storage/migrations_manager.py b/packages/memory_module/storage/migrations_manager.py index 5fbb7f8b..602e9866 100644 --- a/packages/memory_module/storage/migrations_manager.py +++ b/packages/memory_module/storage/migrations_manager.py @@ -7,7 +7,6 @@ import sqlite_vec logger = logging.getLogger(__name__) -print(sqlite3.sqlite_version) class MigrationManager: From d346d6608cb8cfc0f138296c6a872a24e1efb3e6 Mon Sep 17 00:00:00 2001 From: heyitsaamir 
Date: Wed, 18 Dec 2024 12:14:40 -0800 Subject: [PATCH 07/13] Fix tests --- tests/memory_module/test_memory_module.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index 0e29e887..12ae6963 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -18,7 +18,10 @@ SemanticMemoryExtraction, ) from memory_module.core.memory_module import MemoryModule -from memory_module.interfaces.types import Message, ShortTermMemoryRetrievalConfig +from memory_module.interfaces.types import ( + ShortTermMemoryRetrievalConfig, + UserMessageInput, +) from tests.memory_module.utils import build_llm_config @@ -120,21 +123,19 @@ async def test_simple_conversation(memory_module): """Test a simple conversation about pie.""" conversation_id = str(uuid4()) messages = [ - Message( + UserMessageInput( id=str(uuid4()), content="I love pie!", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - type="user", ), - Message( + UserMessageInput( id=str(uuid4()), content="Apple pie is the best!", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - type="user", ), ] @@ -202,13 +203,12 @@ async def mock_extract_episodic(*args, **kwargs): conversation_id = str(uuid4()) messages = [ - Message( + UserMessageInput( id=str(uuid4()), content=f"Message {i} about pie", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - type="user", ) for i in range(3) ] @@ -225,13 +225,12 @@ async def test_update_memory(memory_module): """Test memory update""" conversation_id = str(uuid4()) messages = [ - Message( + UserMessageInput( id=str(uuid4()), content="Seattle is my favorite city!", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - type="user", ), ] @@ -253,13 +252,12 @@ async def test_remove_memory(memory_module): 
"""Test a simple conversation removal based on user id.""" conversation_id = str(uuid4()) messages = [ - Message( + UserMessageInput( id=str(uuid4()), content="I like pho a lot!", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - type="user", ), ] @@ -280,13 +278,12 @@ async def test_short_term_memory(memory_module): """Test that messages are stored in short-term memory.""" conversation_id = str(uuid4()) messages = [ - Message( + UserMessageInput( id=str(uuid4()), content=f"Test message {i}", author_id="user-123", conversation_ref=conversation_id, created_at=datetime.now(), - type="user", ) for i in range(3) ] From c529bf1cba06d4e20947f54a2ccf21552d180f1c Mon Sep 17 00:00:00 2001 From: Aamir <48929123+heyitsaamir@users.noreply.github.com> Date: Wed, 18 Dec 2024 12:17:36 -0800 Subject: [PATCH 08/13] Update packages/memory_module/storage/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- packages/memory_module/storage/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/memory_module/storage/utils.py b/packages/memory_module/storage/utils.py index b0ee5fb0..77daafa1 100644 --- a/packages/memory_module/storage/utils.py +++ b/packages/memory_module/storage/utils.py @@ -13,4 +13,4 @@ def build_message_from_dict(row: Dict) -> Message: elif row["type"] == "assistant": return AssistantMessage(**row) else: - raise ValueError(f"Invalid message type: {row['type']}") + raise ValueError(f"Invalid message type: {row['type']}. 
Expected one of: 'internal', 'user', 'assistant'") From 6f12a9e4276ffbab3efd186455e22ac00b34e5b2 Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Wed, 18 Dec 2024 13:08:09 -0800 Subject: [PATCH 09/13] Fix test --- tests/memory_module/test_memory_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/memory_module/test_memory_module.py b/tests/memory_module/test_memory_module.py index 12ae6963..c82dde99 100644 --- a/tests/memory_module/test_memory_module.py +++ b/tests/memory_module/test_memory_module.py @@ -297,7 +297,6 @@ async def test_short_term_memory(memory_module): conversation_id, ShortTermMemoryRetrievalConfig(last_minutes=1) ) assert len(chat_history_messages) == 3 - assert all(msg in chat_history_messages for msg in messages) # Verify messages are in reverse order reversed_messages = messages[::-1] From b964acd5c4c2a78f629ab8c399a2868d5bffce2a Mon Sep 17 00:00:00 2001 From: heyitsaamir Date: Wed, 18 Dec 2024 15:26:23 -0800 Subject: [PATCH 10/13] Address feedback --- packages/memory_module/interfaces/types.py | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/packages/memory_module/interfaces/types.py b/packages/memory_module/interfaces/types.py index b1a3acf6..338658b2 100644 --- a/packages/memory_module/interfaces/types.py +++ b/packages/memory_module/interfaces/types.py @@ -20,12 +20,22 @@ class BaseMessageInput(ABC, BaseModel): class InternalMessageInput(BaseMessageInput): + """ + Input parameter for an internal message. Used when creating a new message. + """ + model_config = ConfigDict(from_attributes=True) type: ClassVar = "internal" created_at: Optional[datetime] = None class InternalMessage(InternalMessageInput): + """ + Represents a message that is not meant to be shown to the user. + Useful for keeping agentic transcript state. 
+ These are not used as part of memory extraction + """ + model_config = ConfigDict(from_attributes=True) id: str @@ -33,6 +43,10 @@ class InternalMessage(InternalMessageInput): class UserMessageInput(BaseMessageInput): + """ + Input parameter for a user message. Used when creating a new message. + """ + model_config = ConfigDict(from_attributes=True) id: str type: ClassVar = "user" @@ -41,10 +55,18 @@ class UserMessageInput(BaseMessageInput): class UserMessage(UserMessageInput): + """ + Represents a message that was sent by the user. + """ + model_config = ConfigDict(from_attributes=True) class AssistantMessageInput(BaseMessageInput): + """ + Input parameter for an assistant message. Used when creating a new message. + """ + model_config = ConfigDict(from_attributes=True) id: str type: ClassVar = "assistant" @@ -53,6 +75,10 @@ class AssistantMessageInput(BaseMessageInput): class AssistantMessage(AssistantMessageInput): + """ + Represents a message that was sent by the assistant. + """ + model_config = ConfigDict(from_attributes=True) created_at: datetime # type: ignore Ignoring because this will exist in the concrete class From f6d455e40e0b1eda06fbf281d47c35eb3185773d Mon Sep 17 00:00:00 2001 From: Aamir <48929123+heyitsaamir@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:57:19 -0800 Subject: [PATCH 11/13] Add memory module middleware (#67) - Adds a middleware for memory_module. Now, whenever a message from a user comes in, or a message is sent from the agent back to the user, it's automatically captured by the supplied memory_module. Discovered a number of bugs in botbuilder-python.
Added them: https://github.com/microsoft/botbuilder-python/issues/2197 https://github.com/microsoft/botbuilder-python/issues/2198 --- packages/memory_module/__init__.py | 2 + packages/memory_module/pyproject.toml | 1 + .../storage/sqlite_memory_storage.py | 4 +- .../storage/sqlite_message_buffer_storage.py | 2 +- .../utils/teams_bot_middlware.py | 114 ++++++++++++++++++ src/bot.py | 84 +------------ uv.lock | 2 + 7 files changed, 127 insertions(+), 82 deletions(-) create mode 100644 packages/memory_module/utils/teams_bot_middlware.py diff --git a/packages/memory_module/__init__.py b/packages/memory_module/__init__.py index f7eea423..a336ceee 100644 --- a/packages/memory_module/__init__.py +++ b/packages/memory_module/__init__.py @@ -12,6 +12,7 @@ UserMessage, UserMessageInput, ) +from memory_module.utils.teams_bot_middlware import MemoryMiddleware __all__ = [ "MemoryModule", @@ -27,4 +28,5 @@ "AssistantMessage", "AssistantMessageInput", "ShortTermMemoryRetrievalConfig", + "MemoryMiddleware", ] diff --git a/packages/memory_module/pyproject.toml b/packages/memory_module/pyproject.toml index 67a7d23a..74aa491a 100644 --- a/packages/memory_module/pyproject.toml +++ b/packages/memory_module/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "numpy", "sqlite-vec>=0.1.6", "litellm==1.54.1", + "botbuilder>=0.0.1", ] [tool.uv] diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index 52a6bf37..fa089a55 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -47,7 +47,7 @@ async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List ( memory_id, memory.content, - memory.created_at, + memory.created_at.isoformat(), memory.user_id, memory.memory_type.value, ), @@ -300,7 +300,7 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: message.content, message.author_id, 
message.conversation_ref, - created_at, + created_at.isoformat(), message.type, deep_link, ), diff --git a/packages/memory_module/storage/sqlite_message_buffer_storage.py b/packages/memory_module/storage/sqlite_message_buffer_storage.py index edd5b91e..1d46fc7e 100644 --- a/packages/memory_module/storage/sqlite_message_buffer_storage.py +++ b/packages/memory_module/storage/sqlite_message_buffer_storage.py @@ -38,7 +38,7 @@ async def store_buffered_message(self, message: Message) -> None: ( message.id, message.conversation_ref, - message.created_at, + message.created_at.isoformat(), ), ) diff --git a/packages/memory_module/utils/teams_bot_middlware.py b/packages/memory_module/utils/teams_bot_middlware.py new file mode 100644 index 00000000..da88ac91 --- /dev/null +++ b/packages/memory_module/utils/teams_bot_middlware.py @@ -0,0 +1,114 @@ +import datetime +from asyncio import gather +from typing import Awaitable, Callable, List + +from botbuilder.core import TurnContext +from botbuilder.core.middleware_set import Middleware +from botbuilder.schema import Activity, ResourceResponse +from memory_module.interfaces.base_memory_module import BaseMemoryModule +from memory_module.interfaces.types import ( + AssistantMessageInput, + UserMessageInput, +) + + +def build_deep_link(context: TurnContext, message_id: str): + conversation_ref = TurnContext.get_conversation_reference(context.activity) + if conversation_ref.conversation and conversation_ref.conversation.is_group: + deeplink_conversation_id = conversation_ref.conversation.id + elif conversation_ref.user and conversation_ref.bot: + user_aad_object_id = conversation_ref.user.aad_object_id + bot_id = conversation_ref.bot.id.replace("28:", "") + deeplink_conversation_id = f"19:{user_aad_object_id}_{bot_id}@unq.gbl.spaces" + else: + return None + return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" + + +class MemoryMiddleware(Middleware): + def 
__init__(self, memory_module: BaseMemoryModule): + self.memory_module = memory_module + + async def add_user_message(self, context: TurnContext): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + content = context.activity.text + if not content: + print("content is not text, so ignoring...") + return False + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.user is None: + print("conversation_ref_dict.user is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + user_aad_object_id = conversation_ref_dict.user.aad_object_id + message_id = context.activity.id + await self.memory_module.add_message( + UserMessageInput( + id=message_id, + content=context.activity.text, + author_id=user_aad_object_id, + conversation_ref=conversation_ref_dict.conversation.id, + created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), + deep_link=build_deep_link(context, context.activity.id), + ) + ) + return True + + async def add_agent_message( + self, context: TurnContext, activities: List[Activity], responses: List[ResourceResponse] + ): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.bot is None: + print("conversation_ref_dict.bot is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + + tasks = [] + for activity, response in zip(activities, responses, strict=False): + if activity.text: + tasks.append( + self.memory_module.add_message( + AssistantMessageInput( + id=response.id, + content=activity.text, + author_id=conversation_ref_dict.bot.id, + conversation_ref=conversation_ref_dict.conversation.id, + 
deep_link=build_deep_link(context, response.id), + ) + ) + ) + + if tasks: + await gather(*tasks) + return True + + async def on_turn(self, context: TurnContext, logic: Callable[[], Awaitable]): # type: ignore Bug in botbuilder-python https://github.com/microsoft/botbuilder-python/issues/2198 + # Handle incoming message + await self.add_user_message(context) + + # Store the original send_activities method + original_send_activities = context.send_activities + + # Create a wrapped version that captures the activities + # We need to do this because bot-framework has a bug with how + # _on_send_activities middleware is implemented + # https://github.com/microsoft/botbuilder-python/issues/2197 + async def wrapped_send_activities(activities: List[Activity]): + responses = await original_send_activities(activities) + await self.add_agent_message(context, activities, responses) + return responses + + # Replace the send_activities method + context.send_activities = wrapped_send_activities + + # Run the bot's logic + await logic() diff --git a/src/bot.py b/src/bot.py index 57263cb5..ce7a5b68 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,4 +1,3 @@ -import datetime import json import os import sys @@ -11,13 +10,12 @@ from botbuilder.schema import Activity from litellm import acompletion from memory_module import ( - AssistantMessageInput, InternalMessageInput, LLMConfig, Memory, + MemoryMiddleware, MemoryModule, MemoryModuleConfig, - UserMessageInput, ) from pydantic import BaseModel, Field from teams import Application, ApplicationOptions, TeamsAdapter @@ -63,6 +61,8 @@ ) ) +bot_app.adapter.use(MemoryMiddleware(memory_module)) + class TaskConfig(BaseModel): task_name: str @@ -281,75 +281,6 @@ def get_available_functions(): ] -def build_deep_link(context: TurnContext, message_id: str): - conversation_ref = TurnContext.get_conversation_reference(context.activity) - if conversation_ref.conversation and conversation_ref.conversation.is_group: - deeplink_conversation_id = 
conversation_ref.conversation.id - elif conversation_ref.user and conversation_ref.bot: - user_aad_object_id = conversation_ref.user.aad_object_id - bot_id = conversation_ref.bot.id.replace("28:", "") - deeplink_conversation_id = f"19:{user_aad_object_id}_{bot_id}@unq.gbl.spaces" - else: - return None - return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" - - -async def add_user_message(context: TurnContext): - conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - content = context.activity.text - if not content: - print("content is not text, so ignoring...") - return False - if conversation_ref_dict is None: - print("conversation_ref_dict is None") - return False - if conversation_ref_dict.user is None: - print("conversation_ref_dict.user is None") - return False - if conversation_ref_dict.conversation is None: - print("conversation_ref_dict.conversation is None") - return False - user_aad_object_id = conversation_ref_dict.user.aad_object_id - message_id = context.activity.id - await memory_module.add_message( - UserMessageInput( - id=message_id, - content=context.activity.text, - author_id=user_aad_object_id, - conversation_ref=conversation_ref_dict.conversation.id, - created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), - deep_link=build_deep_link(context, context.activity.id), - ) - ) - return True - - -async def add_agent_message(context: TurnContext, message_id: str, content: str): - conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - if not content: - print("content is not text, so ignoring...") - return False - if conversation_ref_dict is None: - print("conversation_ref_dict is None") - return False - if conversation_ref_dict.bot is None: - print("conversation_ref_dict.bot is None") - return False - if conversation_ref_dict.conversation is None: - 
print("conversation_ref_dict.conversation is None") - return False - await memory_module.add_message( - AssistantMessageInput( - id=message_id, - content=content, - author_id=conversation_ref_dict.bot.id, - conversation_ref=conversation_ref_dict.conversation.id, - deep_link=build_deep_link(context, message_id), - ) - ) - return True - - async def add_internal_message(context: TurnContext, content: str): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) if not content: @@ -376,15 +307,12 @@ async def add_internal_message(context: TurnContext, content: str): @bot_app.conversation_update("membersAdded") async def on_members_added(context: TurnContext, state: TurnState): - result = await send_string_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?") - if result: - await add_agent_message(context, result, "Hello! I'm your IT Support Assistant. How can I assist you today?") + await send_string_message(context, "Hello! I'm your IT Support Assistant. 
How can I assist you today?") @bot_app.activity("message") async def on_message(context: TurnContext, state: TurnState): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - await add_user_message(context) system_prompt = """ You are an IT Chat Bot that helps users troubleshoot tasks @@ -449,9 +377,7 @@ async def on_message(context: TurnContext, state: TurnState): message = response.choices[0].message if message.tool_calls is None and message.content is not None: - agent_message_id = await send_string_message(context, message.content) - if agent_message_id: - await add_agent_message(context, agent_message_id, message.content) + await send_string_message(context, message.content) break elif message.tool_calls is None and message.content is None: print("No tool calls and no content") diff --git a/uv.lock b/uv.lock index c6795b01..4c547a21 100644 --- a/uv.lock +++ b/uv.lock @@ -1298,6 +1298,7 @@ version = "0.0.0" source = { virtual = "packages/memory_module" } dependencies = [ { name = "aiosqlite" }, + { name = "botbuilder" }, { name = "instructor" }, { name = "litellm" }, { name = "numpy" }, @@ -1309,6 +1310,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, + { name = "botbuilder", specifier = ">=0.0.1" }, { name = "instructor", specifier = ">=1.6.4" }, { name = "litellm", specifier = "==1.54.1" }, { name = "numpy" }, From 48483212664edb420d91f4ce4d7c41df1d231ed7 Mon Sep 17 00:00:00 2001 From: Aamir <48929123+heyitsaamir@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:59:29 -0800 Subject: [PATCH 12/13] Revert "Add memory module middleware" (#68) Reverts microsoft/teams-memory-agents-py#67 --- packages/memory_module/__init__.py | 2 - packages/memory_module/pyproject.toml | 1 - .../storage/sqlite_memory_storage.py | 4 +- .../storage/sqlite_message_buffer_storage.py | 2 +- .../utils/teams_bot_middlware.py | 114 ------------------ src/bot.py | 84 ++++++++++++- uv.lock | 2 - 7 files 
changed, 82 insertions(+), 127 deletions(-) delete mode 100644 packages/memory_module/utils/teams_bot_middlware.py diff --git a/packages/memory_module/__init__.py b/packages/memory_module/__init__.py index a336ceee..f7eea423 100644 --- a/packages/memory_module/__init__.py +++ b/packages/memory_module/__init__.py @@ -12,7 +12,6 @@ UserMessage, UserMessageInput, ) -from memory_module.utils.teams_bot_middlware import MemoryMiddleware __all__ = [ "MemoryModule", @@ -28,5 +27,4 @@ "AssistantMessage", "AssistantMessageInput", "ShortTermMemoryRetrievalConfig", - "MemoryMiddleware", ] diff --git a/packages/memory_module/pyproject.toml b/packages/memory_module/pyproject.toml index 74aa491a..67a7d23a 100644 --- a/packages/memory_module/pyproject.toml +++ b/packages/memory_module/pyproject.toml @@ -12,7 +12,6 @@ dependencies = [ "numpy", "sqlite-vec>=0.1.6", "litellm==1.54.1", - "botbuilder>=0.0.1", ] [tool.uv] diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index fa089a55..52a6bf37 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -47,7 +47,7 @@ async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List ( memory_id, memory.content, - memory.created_at.isoformat(), + memory.created_at, memory.user_id, memory.memory_type.value, ), @@ -300,7 +300,7 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: message.content, message.author_id, message.conversation_ref, - created_at.isoformat(), + created_at, message.type, deep_link, ), diff --git a/packages/memory_module/storage/sqlite_message_buffer_storage.py b/packages/memory_module/storage/sqlite_message_buffer_storage.py index 1d46fc7e..edd5b91e 100644 --- a/packages/memory_module/storage/sqlite_message_buffer_storage.py +++ b/packages/memory_module/storage/sqlite_message_buffer_storage.py @@ -38,7 +38,7 @@ async def 
store_buffered_message(self, message: Message) -> None: ( message.id, message.conversation_ref, - message.created_at.isoformat(), + message.created_at, ), ) diff --git a/packages/memory_module/utils/teams_bot_middlware.py b/packages/memory_module/utils/teams_bot_middlware.py deleted file mode 100644 index da88ac91..00000000 --- a/packages/memory_module/utils/teams_bot_middlware.py +++ /dev/null @@ -1,114 +0,0 @@ -import datetime -from asyncio import gather -from typing import Awaitable, Callable, List - -from botbuilder.core import TurnContext -from botbuilder.core.middleware_set import Middleware -from botbuilder.schema import Activity, ResourceResponse -from memory_module.interfaces.base_memory_module import BaseMemoryModule -from memory_module.interfaces.types import ( - AssistantMessageInput, - UserMessageInput, -) - - -def build_deep_link(context: TurnContext, message_id: str): - conversation_ref = TurnContext.get_conversation_reference(context.activity) - if conversation_ref.conversation and conversation_ref.conversation.is_group: - deeplink_conversation_id = conversation_ref.conversation.id - elif conversation_ref.user and conversation_ref.bot: - user_aad_object_id = conversation_ref.user.aad_object_id - bot_id = conversation_ref.bot.id.replace("28:", "") - deeplink_conversation_id = f"19:{user_aad_object_id}_{bot_id}@unq.gbl.spaces" - else: - return None - return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" - - -class MemoryMiddleware(Middleware): - def __init__(self, memory_module: BaseMemoryModule): - self.memory_module = memory_module - - async def add_user_message(self, context: TurnContext): - conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - content = context.activity.text - if not content: - print("content is not text, so ignoring...") - return False - if conversation_ref_dict is None: - print("conversation_ref_dict is None") - return 
False - if conversation_ref_dict.user is None: - print("conversation_ref_dict.user is None") - return False - if conversation_ref_dict.conversation is None: - print("conversation_ref_dict.conversation is None") - return False - user_aad_object_id = conversation_ref_dict.user.aad_object_id - message_id = context.activity.id - await self.memory_module.add_message( - UserMessageInput( - id=message_id, - content=context.activity.text, - author_id=user_aad_object_id, - conversation_ref=conversation_ref_dict.conversation.id, - created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), - deep_link=build_deep_link(context, context.activity.id), - ) - ) - return True - - async def add_agent_message( - self, context: TurnContext, activities: List[Activity], responses: List[ResourceResponse] - ): - conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - if conversation_ref_dict is None: - print("conversation_ref_dict is None") - return False - if conversation_ref_dict.bot is None: - print("conversation_ref_dict.bot is None") - return False - if conversation_ref_dict.conversation is None: - print("conversation_ref_dict.conversation is None") - return False - - tasks = [] - for activity, response in zip(activities, responses, strict=False): - if activity.text: - tasks.append( - self.memory_module.add_message( - AssistantMessageInput( - id=response.id, - content=activity.text, - author_id=conversation_ref_dict.bot.id, - conversation_ref=conversation_ref_dict.conversation.id, - deep_link=build_deep_link(context, response.id), - ) - ) - ) - - if tasks: - await gather(*tasks) - return True - - async def on_turn(self, context: TurnContext, logic: Callable[[], Awaitable]): # type: ignore Bug in botbuilder-python https://github.com/microsoft/botbuilder-python/issues/2198 - # Handle incoming message - await self.add_user_message(context) - - # Store the original send_activities method - original_send_activities = 
context.send_activities - - # Create a wrapped version that captures the activities - # We need to do this because bot-framework has a bug with how - # _on_send_activities middleware is implemented - # https://github.com/microsoft/botbuilder-python/issues/2197 - async def wrapped_send_activities(activities: List[Activity]): - responses = await original_send_activities(activities) - await self.add_agent_message(context, activities, responses) - return responses - - # Replace the send_activities method - context.send_activities = wrapped_send_activities - - # Run the bot's logic - await logic() diff --git a/src/bot.py b/src/bot.py index ce7a5b68..57263cb5 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,3 +1,4 @@ +import datetime import json import os import sys @@ -10,12 +11,13 @@ from botbuilder.schema import Activity from litellm import acompletion from memory_module import ( + AssistantMessageInput, InternalMessageInput, LLMConfig, Memory, - MemoryMiddleware, MemoryModule, MemoryModuleConfig, + UserMessageInput, ) from pydantic import BaseModel, Field from teams import Application, ApplicationOptions, TeamsAdapter @@ -61,8 +63,6 @@ ) ) -bot_app.adapter.use(MemoryMiddleware(memory_module)) - class TaskConfig(BaseModel): task_name: str @@ -281,6 +281,75 @@ def get_available_functions(): ] +def build_deep_link(context: TurnContext, message_id: str): + conversation_ref = TurnContext.get_conversation_reference(context.activity) + if conversation_ref.conversation and conversation_ref.conversation.is_group: + deeplink_conversation_id = conversation_ref.conversation.id + elif conversation_ref.user and conversation_ref.bot: + user_aad_object_id = conversation_ref.user.aad_object_id + bot_id = conversation_ref.bot.id.replace("28:", "") + deeplink_conversation_id = f"19:{user_aad_object_id}_{bot_id}@unq.gbl.spaces" + else: + return None + return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" + + 
+async def add_user_message(context: TurnContext): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + content = context.activity.text + if not content: + print("content is not text, so ignoring...") + return False + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.user is None: + print("conversation_ref_dict.user is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + user_aad_object_id = conversation_ref_dict.user.aad_object_id + message_id = context.activity.id + await memory_module.add_message( + UserMessageInput( + id=message_id, + content=context.activity.text, + author_id=user_aad_object_id, + conversation_ref=conversation_ref_dict.conversation.id, + created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), + deep_link=build_deep_link(context, context.activity.id), + ) + ) + return True + + +async def add_agent_message(context: TurnContext, message_id: str, content: str): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + if not content: + print("content is not text, so ignoring...") + return False + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.bot is None: + print("conversation_ref_dict.bot is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + await memory_module.add_message( + AssistantMessageInput( + id=message_id, + content=content, + author_id=conversation_ref_dict.bot.id, + conversation_ref=conversation_ref_dict.conversation.id, + deep_link=build_deep_link(context, message_id), + ) + ) + return True + + async def add_internal_message(context: TurnContext, content: str): conversation_ref_dict = 
TurnContext.get_conversation_reference(context.activity) if not content: @@ -307,12 +376,15 @@ async def add_internal_message(context: TurnContext, content: str): @bot_app.conversation_update("membersAdded") async def on_members_added(context: TurnContext, state: TurnState): - await send_string_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?") + result = await send_string_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?") + if result: + await add_agent_message(context, result, "Hello! I'm your IT Support Assistant. How can I assist you today?") @bot_app.activity("message") async def on_message(context: TurnContext, state: TurnState): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + await add_user_message(context) system_prompt = """ You are an IT Chat Bot that helps users troubleshoot tasks @@ -377,7 +449,9 @@ async def on_message(context: TurnContext, state: TurnState): message = response.choices[0].message if message.tool_calls is None and message.content is not None: - await send_string_message(context, message.content) + agent_message_id = await send_string_message(context, message.content) + if agent_message_id: + await add_agent_message(context, agent_message_id, message.content) break elif message.tool_calls is None and message.content is None: print("No tool calls and no content") diff --git a/uv.lock b/uv.lock index 4c547a21..c6795b01 100644 --- a/uv.lock +++ b/uv.lock @@ -1298,7 +1298,6 @@ version = "0.0.0" source = { virtual = "packages/memory_module" } dependencies = [ { name = "aiosqlite" }, - { name = "botbuilder" }, { name = "instructor" }, { name = "litellm" }, { name = "numpy" }, @@ -1310,7 +1309,6 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, - { name = "botbuilder", specifier = ">=0.0.1" }, { name = "instructor", specifier = ">=1.6.4" }, { name = "litellm", specifier = "==1.54.1" }, { 
name = "numpy" }, From 23dcbf44bf7e7ed0cc2ba1462b7b1719fdd896e1 Mon Sep 17 00:00:00 2001 From: Aamir <48929123+heyitsaamir@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:00:01 -0800 Subject: [PATCH 13/13] Revert "Revert "Add memory module middleware" (#68)" This reverts commit 48483212664edb420d91f4ce4d7c41df1d231ed7. --- packages/memory_module/__init__.py | 2 + packages/memory_module/pyproject.toml | 1 + .../storage/sqlite_memory_storage.py | 4 +- .../storage/sqlite_message_buffer_storage.py | 2 +- .../utils/teams_bot_middlware.py | 114 ++++++++++++++++++ src/bot.py | 84 +------------ uv.lock | 2 + 7 files changed, 127 insertions(+), 82 deletions(-) create mode 100644 packages/memory_module/utils/teams_bot_middlware.py diff --git a/packages/memory_module/__init__.py b/packages/memory_module/__init__.py index f7eea423..a336ceee 100644 --- a/packages/memory_module/__init__.py +++ b/packages/memory_module/__init__.py @@ -12,6 +12,7 @@ UserMessage, UserMessageInput, ) +from memory_module.utils.teams_bot_middlware import MemoryMiddleware __all__ = [ "MemoryModule", @@ -27,4 +28,5 @@ "AssistantMessage", "AssistantMessageInput", "ShortTermMemoryRetrievalConfig", + "MemoryMiddleware", ] diff --git a/packages/memory_module/pyproject.toml b/packages/memory_module/pyproject.toml index 67a7d23a..74aa491a 100644 --- a/packages/memory_module/pyproject.toml +++ b/packages/memory_module/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "numpy", "sqlite-vec>=0.1.6", "litellm==1.54.1", + "botbuilder>=0.0.1", ] [tool.uv] diff --git a/packages/memory_module/storage/sqlite_memory_storage.py b/packages/memory_module/storage/sqlite_memory_storage.py index 52a6bf37..fa089a55 100644 --- a/packages/memory_module/storage/sqlite_memory_storage.py +++ b/packages/memory_module/storage/sqlite_memory_storage.py @@ -47,7 +47,7 @@ async def store_memory(self, memory: BaseMemoryInput, *, embedding_vectors: List ( memory_id, memory.content, - memory.created_at, + memory.created_at.isoformat(), 
memory.user_id, memory.memory_type.value, ), @@ -300,7 +300,7 @@ async def store_short_term_memory(self, message: MessageInput) -> Message: message.content, message.author_id, message.conversation_ref, - created_at, + created_at.isoformat(), message.type, deep_link, ), diff --git a/packages/memory_module/storage/sqlite_message_buffer_storage.py b/packages/memory_module/storage/sqlite_message_buffer_storage.py index edd5b91e..1d46fc7e 100644 --- a/packages/memory_module/storage/sqlite_message_buffer_storage.py +++ b/packages/memory_module/storage/sqlite_message_buffer_storage.py @@ -38,7 +38,7 @@ async def store_buffered_message(self, message: Message) -> None: ( message.id, message.conversation_ref, - message.created_at, + message.created_at.isoformat(), ), ) diff --git a/packages/memory_module/utils/teams_bot_middlware.py b/packages/memory_module/utils/teams_bot_middlware.py new file mode 100644 index 00000000..da88ac91 --- /dev/null +++ b/packages/memory_module/utils/teams_bot_middlware.py @@ -0,0 +1,114 @@ +import datetime +from asyncio import gather +from typing import Awaitable, Callable, List + +from botbuilder.core import TurnContext +from botbuilder.core.middleware_set import Middleware +from botbuilder.schema import Activity, ResourceResponse +from memory_module.interfaces.base_memory_module import BaseMemoryModule +from memory_module.interfaces.types import ( + AssistantMessageInput, + UserMessageInput, +) + + +def build_deep_link(context: TurnContext, message_id: str): + conversation_ref = TurnContext.get_conversation_reference(context.activity) + if conversation_ref.conversation and conversation_ref.conversation.is_group: + deeplink_conversation_id = conversation_ref.conversation.id + elif conversation_ref.user and conversation_ref.bot: + user_aad_object_id = conversation_ref.user.aad_object_id + bot_id = conversation_ref.bot.id.replace("28:", "") + deeplink_conversation_id = f"19:{user_aad_object_id}_{bot_id}@unq.gbl.spaces" + else: + return None + 
return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" + + +class MemoryMiddleware(Middleware): + def __init__(self, memory_module: BaseMemoryModule): + self.memory_module = memory_module + + async def add_user_message(self, context: TurnContext): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + content = context.activity.text + if not content: + print("content is not text, so ignoring...") + return False + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.user is None: + print("conversation_ref_dict.user is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + user_aad_object_id = conversation_ref_dict.user.aad_object_id + message_id = context.activity.id + await self.memory_module.add_message( + UserMessageInput( + id=message_id, + content=context.activity.text, + author_id=user_aad_object_id, + conversation_ref=conversation_ref_dict.conversation.id, + created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), + deep_link=build_deep_link(context, context.activity.id), + ) + ) + return True + + async def add_agent_message( + self, context: TurnContext, activities: List[Activity], responses: List[ResourceResponse] + ): + conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) + if conversation_ref_dict is None: + print("conversation_ref_dict is None") + return False + if conversation_ref_dict.bot is None: + print("conversation_ref_dict.bot is None") + return False + if conversation_ref_dict.conversation is None: + print("conversation_ref_dict.conversation is None") + return False + + tasks = [] + for activity, response in zip(activities, responses, strict=False): + if activity.text: + tasks.append( + self.memory_module.add_message( + 
AssistantMessageInput( + id=response.id, + content=activity.text, + author_id=conversation_ref_dict.bot.id, + conversation_ref=conversation_ref_dict.conversation.id, + deep_link=build_deep_link(context, response.id), + ) + ) + ) + + if tasks: + await gather(*tasks) + return True + + async def on_turn(self, context: TurnContext, logic: Callable[[], Awaitable]): # type: ignore Bug in botbuilder-python https://github.com/microsoft/botbuilder-python/issues/2198 + # Handle incoming message + await self.add_user_message(context) + + # Store the original send_activities method + original_send_activities = context.send_activities + + # Create a wrapped version that captures the activities + # We need to do this because bot-framework has a bug with how + # _on_send_activities middleware is implemented + # https://github.com/microsoft/botbuilder-python/issues/2197 + async def wrapped_send_activities(activities: List[Activity]): + responses = await original_send_activities(activities) + await self.add_agent_message(context, activities, responses) + return responses + + # Replace the send_activities method + context.send_activities = wrapped_send_activities + + # Run the bot's logic + await logic() diff --git a/src/bot.py b/src/bot.py index 57263cb5..ce7a5b68 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,4 +1,3 @@ -import datetime import json import os import sys @@ -11,13 +10,12 @@ from botbuilder.schema import Activity from litellm import acompletion from memory_module import ( - AssistantMessageInput, InternalMessageInput, LLMConfig, Memory, + MemoryMiddleware, MemoryModule, MemoryModuleConfig, - UserMessageInput, ) from pydantic import BaseModel, Field from teams import Application, ApplicationOptions, TeamsAdapter @@ -63,6 +61,8 @@ ) ) +bot_app.adapter.use(MemoryMiddleware(memory_module)) + class TaskConfig(BaseModel): task_name: str @@ -281,75 +281,6 @@ def get_available_functions(): ] -def build_deep_link(context: TurnContext, message_id: str): - conversation_ref = 
TurnContext.get_conversation_reference(context.activity) - if conversation_ref.conversation and conversation_ref.conversation.is_group: - deeplink_conversation_id = conversation_ref.conversation.id - elif conversation_ref.user and conversation_ref.bot: - user_aad_object_id = conversation_ref.user.aad_object_id - bot_id = conversation_ref.bot.id.replace("28:", "") - deeplink_conversation_id = f"19:{user_aad_object_id}_{bot_id}@unq.gbl.spaces" - else: - return None - return f"https://teams.microsoft.com/l/message/{deeplink_conversation_id}/{message_id}?context=%7B%22contextType%22%3A%22chat%22%7D" - - -async def add_user_message(context: TurnContext): - conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - content = context.activity.text - if not content: - print("content is not text, so ignoring...") - return False - if conversation_ref_dict is None: - print("conversation_ref_dict is None") - return False - if conversation_ref_dict.user is None: - print("conversation_ref_dict.user is None") - return False - if conversation_ref_dict.conversation is None: - print("conversation_ref_dict.conversation is None") - return False - user_aad_object_id = conversation_ref_dict.user.aad_object_id - message_id = context.activity.id - await memory_module.add_message( - UserMessageInput( - id=message_id, - content=context.activity.text, - author_id=user_aad_object_id, - conversation_ref=conversation_ref_dict.conversation.id, - created_at=context.activity.timestamp if context.activity.timestamp else datetime.datetime.now(), - deep_link=build_deep_link(context, context.activity.id), - ) - ) - return True - - -async def add_agent_message(context: TurnContext, message_id: str, content: str): - conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - if not content: - print("content is not text, so ignoring...") - return False - if conversation_ref_dict is None: - print("conversation_ref_dict is None") - return False - if 
conversation_ref_dict.bot is None: - print("conversation_ref_dict.bot is None") - return False - if conversation_ref_dict.conversation is None: - print("conversation_ref_dict.conversation is None") - return False - await memory_module.add_message( - AssistantMessageInput( - id=message_id, - content=content, - author_id=conversation_ref_dict.bot.id, - conversation_ref=conversation_ref_dict.conversation.id, - deep_link=build_deep_link(context, message_id), - ) - ) - return True - - async def add_internal_message(context: TurnContext, content: str): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) if not content: @@ -376,15 +307,12 @@ async def add_internal_message(context: TurnContext, content: str): @bot_app.conversation_update("membersAdded") async def on_members_added(context: TurnContext, state: TurnState): - result = await send_string_message(context, "Hello! I'm your IT Support Assistant. How can I assist you today?") - if result: - await add_agent_message(context, result, "Hello! I'm your IT Support Assistant. How can I assist you today?") + await send_string_message(context, "Hello! I'm your IT Support Assistant. 
How can I assist you today?") @bot_app.activity("message") async def on_message(context: TurnContext, state: TurnState): conversation_ref_dict = TurnContext.get_conversation_reference(context.activity) - await add_user_message(context) system_prompt = """ You are an IT Chat Bot that helps users troubleshoot tasks @@ -449,9 +377,7 @@ async def on_message(context: TurnContext, state: TurnState): message = response.choices[0].message if message.tool_calls is None and message.content is not None: - agent_message_id = await send_string_message(context, message.content) - if agent_message_id: - await add_agent_message(context, agent_message_id, message.content) + await send_string_message(context, message.content) break elif message.tool_calls is None and message.content is None: print("No tool calls and no content") diff --git a/uv.lock b/uv.lock index c6795b01..4c547a21 100644 --- a/uv.lock +++ b/uv.lock @@ -1298,6 +1298,7 @@ version = "0.0.0" source = { virtual = "packages/memory_module" } dependencies = [ { name = "aiosqlite" }, + { name = "botbuilder" }, { name = "instructor" }, { name = "litellm" }, { name = "numpy" }, @@ -1309,6 +1310,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "aiosqlite", specifier = ">=0.20.0" }, + { name = "botbuilder", specifier = ">=0.0.1" }, { name = "instructor", specifier = ">=1.6.4" }, { name = "litellm", specifier = "==1.54.1" }, { name = "numpy" },