Skip to content

Commit 8ac8981

Browse files
heyitsaamirCopilot
andauthored
Split messages into different types (#66)
- Originally messages were just one type - "Messages". But "Messages" can be used for different purposes (or not used for other purposes). - We don't want to take internal messages into consideration when extracting memories. We only want those to be public messages. - For that, we have split messages into User, Assistant and Internal. The Message type is now a union of these types. - The input for creating these was also created to simplify the public facing APIs (not in this PR). --------- Co-authored-by: Copilot <[email protected]>
1 parent c17359f commit 8ac8981

14 files changed

+341
-113
lines changed

packages/memory_module/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
11
from memory_module.config import LLMConfig, MemoryModuleConfig
22
from memory_module.core.memory_module import MemoryModule
33
from memory_module.interfaces.types import (
4+
AssistantMessage,
5+
AssistantMessageInput,
6+
InternalMessage,
7+
InternalMessageInput,
48
Memory,
59
Message,
10+
MessageInput,
611
ShortTermMemoryRetrievalConfig,
12+
UserMessage,
13+
UserMessageInput,
714
)
815

916
__all__ = [
1017
"MemoryModule",
1118
"MemoryModuleConfig",
1219
"LLMConfig",
1320
"Memory",
21+
"InternalMessage",
22+
"InternalMessageInput",
23+
"UserMessageInput",
24+
"UserMessage",
1425
"Message",
26+
"MessageInput",
27+
"AssistantMessage",
28+
"AssistantMessageInput",
1529
"ShortTermMemoryRetrievalConfig",
1630
]

packages/memory_module/core/memory_core.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Memory,
1515
MemoryType,
1616
Message,
17+
MessageInput,
1718
ShortTermMemoryRetrievalConfig,
1819
)
1920
from memory_module.services.llm_service import LLMService
@@ -94,7 +95,7 @@ def __init__(
9495
storage: Optional storage implementation for memory persistence
9596
"""
9697
self.lm = llm_service
97-
self.storage = storage or (
98+
self.storage: BaseMemoryStorage = storage or (
9899
SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage()
99100
)
100101

@@ -200,8 +201,11 @@ async def _extract_semantic_fact_from_messages(
200201
for idx, message in enumerate(messages):
201202
if message.type == "user":
202203
messages_str += f"{idx}. User: {message.content}\n"
203-
else:
204+
elif message.type == "assistant":
204205
messages_str += f"{idx}. Assistant: {message.content}\n"
206+
else:
207+
# we explicitly ignore internal messages
208+
continue
205209

206210
system_message = f"""You are a semantic memory management agent. Your goal is to extract meaningful, facts and preferences from user messages. Focus on recognizing general patterns and interests
207211
that will remain relevant over time, even if the user is mentioning short-term plans or events.
@@ -262,8 +266,8 @@ async def _extract_episodic_memory_from_messages(self, messages: List[Message])
262266

263267
return await self.lm.completion(messages=messages, response_model=EpisodicMemoryExtraction)
264268

265-
async def add_short_term_memory(self, message: Message) -> None:
266-
await self.storage.store_short_term_memory(message)
269+
async def add_short_term_memory(self, message: MessageInput) -> Message:
270+
return await self.storage.store_short_term_memory(message)
267271

268272
async def retrieve_chat_history(
269273
self, conversation_ref: str, config: ShortTermMemoryRetrievalConfig

packages/memory_module/core/memory_module.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from memory_module.interfaces.base_memory_core import BaseMemoryCore
77
from memory_module.interfaces.base_memory_module import BaseMemoryModule
88
from memory_module.interfaces.base_message_queue import BaseMessageQueue
9-
from memory_module.interfaces.types import Memory, Message, ShortTermMemoryRetrievalConfig
9+
from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
1010
from memory_module.services.llm_service import LLMService
1111

1212

@@ -31,13 +31,17 @@ def __init__(
3131
self.config = config
3232

3333
self.llm_service = llm_service or LLMService(config=config.llm)
34-
self.memory_core = memory_core or MemoryCore(config=config, llm_service=self.llm_service)
35-
self.message_queue = message_queue or MessageQueue(config=config, memory_core=self.memory_core)
34+
self.memory_core: BaseMemoryCore = memory_core or MemoryCore(config=config, llm_service=self.llm_service)
35+
self.message_queue: BaseMessageQueue = message_queue or MessageQueue(
36+
config=config, memory_core=self.memory_core
37+
)
3638

37-
async def add_message(self, message: Message) -> None:
39+
async def add_message(self, message: MessageInput) -> Message:
3840
"""Add a message to be processed into memory."""
39-
await self.memory_core.add_short_term_memory(message)
40-
await self.message_queue.enqueue(message)
41+
message_res = await self.memory_core.add_short_term_memory(message)
42+
await self.message_queue.enqueue(message_res)
43+
44+
return message_res
4145

4246
async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
4347
"""Retrieve relevant memories based on a query."""

packages/memory_module/interfaces/base_memory_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from typing import Dict, List, Optional
33

4-
from memory_module.interfaces.types import Memory, Message, ShortTermMemoryRetrievalConfig
4+
from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
55

66

77
class BaseMemoryCore(ABC):
@@ -43,7 +43,7 @@ async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]:
4343
pass
4444

4545
@abstractmethod
46-
async def add_short_term_memory(self, message: Message) -> None:
46+
async def add_short_term_memory(self, message: MessageInput) -> Message:
4747
"""Add a short-term memory entry."""
4848
pass
4949

packages/memory_module/interfaces/base_memory_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from abc import ABC, abstractmethod
22
from typing import Dict, List, Optional
33

4-
from memory_module.interfaces.types import Memory, Message, ShortTermMemoryRetrievalConfig
4+
from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
55

66

77
class BaseMemoryModule(ABC):
88
"""Base class for the memory module interface."""
99

1010
@abstractmethod
11-
async def add_message(self, message: Message) -> None:
11+
async def add_message(self, message: MessageInput) -> Message:
1212
"""Add a message to be processed into memory."""
1313
pass
1414

packages/memory_module/interfaces/base_memory_storage.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from abc import ABC, abstractmethod
22
from typing import Dict, List, Optional
33

4-
from memory_module.interfaces.types import BaseMemoryInput, EmbedText, Memory, Message, ShortTermMemoryRetrievalConfig
4+
from memory_module.interfaces.types import (
5+
BaseMemoryInput,
6+
EmbedText,
7+
Memory,
8+
Message,
9+
MessageInput,
10+
ShortTermMemoryRetrievalConfig,
11+
)
512

613

714
class BaseMemoryStorage(ABC):
@@ -30,7 +37,7 @@ async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_
3037
pass
3138

3239
@abstractmethod
33-
async def store_short_term_memory(self, message: Message) -> None:
40+
async def store_short_term_memory(self, message: MessageInput) -> Message:
3441
"""Store a short-term memory entry.
3542
3643
Args:

packages/memory_module/interfaces/types.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from abc import ABC
12
from datetime import datetime
23
from decimal import Decimal
34
from enum import Enum
4-
from typing import List, Literal, Optional
5+
from typing import ClassVar, List, Optional
56

67
from pydantic import BaseModel, ConfigDict, Field, model_validator
78

@@ -12,18 +13,78 @@ class User(BaseModel):
1213
id: str
1314

1415

15-
class Message(BaseModel):
16-
"""Represents a message in a conversation."""
16+
class BaseMessageInput(ABC, BaseModel):
17+
content: str
18+
author_id: str
19+
conversation_ref: str
20+
21+
22+
class InternalMessageInput(BaseMessageInput):
23+
"""
24+
Input parameter for an internal message. Used when creating a new message.
25+
"""
26+
27+
model_config = ConfigDict(from_attributes=True)
28+
type: ClassVar = "internal"
29+
created_at: Optional[datetime] = None
30+
31+
32+
class InternalMessage(InternalMessageInput):
33+
"""
34+
Represents a message that is not meant to be shown to the user.
35+
Useful for keeping agentic transcript state.
36+
These are not used as part of memory extraction
37+
"""
1738

1839
model_config = ConfigDict(from_attributes=True)
1940

2041
id: str
21-
content: str
22-
author_id: Optional[str]
23-
conversation_ref: str
42+
created_at: datetime # type: ignore Ignoring because this will exist in the concrete class
43+
44+
45+
class UserMessageInput(BaseMessageInput):
46+
"""
47+
Input parameter for a user message. Used when creating a new message.
48+
"""
49+
50+
model_config = ConfigDict(from_attributes=True)
51+
id: str
52+
type: ClassVar = "user"
53+
deep_link: Optional[str] = None
2454
created_at: datetime
25-
type: Literal["user", "assistant", "internal"] | None = None
55+
56+
57+
class UserMessage(UserMessageInput):
58+
"""
59+
Represents a message that was sent by the user.
60+
"""
61+
62+
model_config = ConfigDict(from_attributes=True)
63+
64+
65+
class AssistantMessageInput(BaseMessageInput):
66+
"""
67+
Input parameter for an assistant message. Used when creating a new message.
68+
"""
69+
70+
model_config = ConfigDict(from_attributes=True)
71+
id: str
72+
type: ClassVar = "assistant"
2673
deep_link: Optional[str] = None
74+
created_at: Optional[datetime] = None
75+
76+
77+
class AssistantMessage(AssistantMessageInput):
78+
"""
79+
Represents a message that was sent by the assistant.
80+
"""
81+
82+
model_config = ConfigDict(from_attributes=True)
83+
created_at: datetime # type: ignore Ignoring because this will exist in the concrete class
84+
85+
86+
type MessageInput = InternalMessageInput | UserMessageInput | AssistantMessageInput
87+
type Message = InternalMessage | UserMessage | AssistantMessage
2788

2889

2990
class MemoryAttribution(BaseModel):

packages/memory_module/storage/in_memory_storage.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
2+
import uuid
23
from collections import defaultdict
3-
from typing import Dict, List, Optional
4+
from typing import Dict, List, Optional, TypedDict
45

56
import numpy as np
67
from memory_module.interfaces.base_memory_storage import BaseMemoryStorage
@@ -9,12 +10,32 @@
910
)
1011
from memory_module.interfaces.base_scheduled_events_service import Event
1112
from memory_module.interfaces.base_scheduled_events_storage import BaseScheduledEventsStorage
12-
from memory_module.interfaces.types import BaseMemoryInput, EmbedText, Memory, Message, ShortTermMemoryRetrievalConfig
13+
from memory_module.interfaces.types import (
14+
AssistantMessage,
15+
AssistantMessageInput,
16+
BaseMemoryInput,
17+
EmbedText,
18+
InternalMessage,
19+
InternalMessageInput,
20+
Memory,
21+
Message,
22+
MessageInput,
23+
ShortTermMemoryRetrievalConfig,
24+
UserMessage,
25+
UserMessageInput,
26+
)
27+
28+
29+
class InMemoryInternalStore(TypedDict):
30+
memories: Dict[str, Memory]
31+
embeddings: Dict[str, List[List[float]]]
32+
buffered_messages: Dict[str, List[Message]]
33+
scheduled_events: Dict[str, Event]
1334

1435

1536
class InMemoryStorage(BaseMemoryStorage, BaseMessageBufferStorage, BaseScheduledEventsStorage):
1637
def __init__(self):
17-
self.storage: Dict = {
38+
self.storage: InMemoryInternalStore = {
1839
"embeddings": {},
1940
"buffered_messages": defaultdict(list),
2041
"scheduled_events": {},
@@ -28,7 +49,8 @@ async def store_memory(
2849
embedding_vectors: List[List[float]],
2950
) -> str | None:
3051
memory_id = str(len(self.storage["memories"]) + 1)
31-
self.storage["memories"][memory_id] = memory
52+
memory_obj = Memory(**memory.model_dump(), id=memory_id)
53+
self.storage["memories"][memory_id] = memory_obj
3254
self.storage["embeddings"][memory_id] = embedding_vectors
3355
return memory_id
3456

@@ -37,8 +59,48 @@ async def update_memory(self, memory_id: str, updated_memory: str, *, embedding_
3759
self.storage["memories"][memory_id].content = updated_memory
3860
self.storage["embeddings"][memory_id] = embedding_vectors
3961

40-
async def store_short_term_memory(self, message: Message) -> None:
41-
await self.store_buffered_message(message)
62+
async def store_short_term_memory(self, message: MessageInput) -> Message:
63+
if isinstance(message, InternalMessageInput):
64+
id = str(uuid.uuid4())
65+
else:
66+
id = message.id
67+
68+
created_at = message.created_at or datetime.datetime.now()
69+
70+
if isinstance(message, InternalMessageInput):
71+
deep_link = None
72+
else:
73+
deep_link = message.deep_link
74+
75+
if isinstance(message, UserMessageInput):
76+
message_obj = UserMessage(
77+
id=id,
78+
content=message.content,
79+
created_at=created_at,
80+
conversation_ref=message.conversation_ref,
81+
deep_link=deep_link,
82+
author_id=message.author_id,
83+
)
84+
elif isinstance(message, AssistantMessageInput):
85+
message_obj = AssistantMessage(
86+
id=id,
87+
content=message.content,
88+
created_at=created_at,
89+
conversation_ref=message.conversation_ref,
90+
deep_link=deep_link,
91+
author_id=message.author_id,
92+
)
93+
else:
94+
message_obj = InternalMessage(
95+
id=id,
96+
content=message.content,
97+
created_at=created_at,
98+
conversation_ref=message.conversation_ref,
99+
author_id=message.author_id,
100+
)
101+
102+
await self.store_buffered_message(message_obj)
103+
return message_obj
42104

43105
async def retrieve_memories(
44106
self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None
@@ -66,7 +128,7 @@ async def retrieve_memories(
66128
)
67129

68130
sorted_memories.sort(key=lambda x: x["distance"], reverse=True)
69-
return [Memory(id=item["id"], **item["memory"].__dict__) for item in sorted_memories[:limit]]
131+
return [Memory(**item["memory"].__dict__) for item in sorted_memories[:limit]]
70132

71133
async def get_memories(self, memory_ids: List[str]) -> List[Memory]:
72134
return [
@@ -81,7 +143,7 @@ async def get_messages(self, memory_ids: List[str]) -> Dict[str, List[Message]]:
81143
str_id = memory_id
82144
if str_id in self.storage["memories"]:
83145
memory = self.storage["memories"][str_id]
84-
if hasattr(memory, "message_attributions"):
146+
if memory.message_attributions:
85147
messages = []
86148
for msg_id in memory.message_attributions:
87149
# Search through buffered messages to find matching message
@@ -104,7 +166,7 @@ async def clear_memories(self, user_id: str) -> None:
104166
self.storage["embeddings"].pop(memory_id, None)
105167
self.storage["memories"].pop(memory_id, None)
106168

107-
async def get_memory(self, memory_id: int) -> Optional[Memory]:
169+
async def get_memory(self, memory_id: str) -> Optional[Memory]:
108170
return self.storage["memories"].get(memory_id)
109171

110172
async def get_all_memories(self, limit: Optional[int] = None) -> List[Memory]:

0 commit comments

Comments
 (0)