From 48d7ecb22037a1841d7f1f9970528218b6c39f69 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 29 Nov 2024 21:59:15 -0800 Subject: [PATCH 01/26] initial base memroy impl --- .../src/autogen_agentchat/_base_memory.py | 140 ++++++++++++++++++ .../src/autogen_agentchat/_list_memroy.py | 109 ++++++++++++++ .../src/autogen_agentchat/memory/__init__.py | 0 3 files changed, 249 insertions(+) create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py new file mode 100644 index 000000000000..95272195b70a --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py @@ -0,0 +1,140 @@ +from datetime import datetime +from typing import Any, Dict, List, Protocol, Union, runtime_checkable + +from autogen_core.base import CancellationToken +from autogen_core.components import Image +from pydantic import BaseModel, ConfigDict, Field + +from .state import BaseState + + +class MemoryEntry(BaseModel): + """A memory entry containing content and metadata.""" + + content: Union[str, List[Union[str, Image]]] + """The content of the memory entry - can be text or multimodal.""" + + metadata: Dict[str, Any] = Field(default_factory=dict) + """Optional metadata associated with the memory entry.""" + + timestamp: datetime = Field(default_factory=datetime.now) + """When the memory was created.""" + + source: str | None = None + """Optional source identifier for the memory.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class MemoryQueryResult(BaseModel): + """Result from a memory query including the entry and its relevance score.""" + + entry: MemoryEntry + """The memory entry.""" + + score: float + """Relevance score for this result. Higher means more relevant.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class BaseMemoryState(BaseState): + """State for memory implementations.""" + + state_type: str + """Type identifier for the memory implementation.""" + + entries: List[MemoryEntry] + """List of memory entries.""" + + +@runtime_checkable +class Memory(Protocol): + """Protocol defining the interface for memory implementations.""" + + @property + def name(self) -> str: + """The name of this memory implementation.""" + ... + + async def query( + self, + query: Union[str, Image, List[Union[str, Image]]], + *, + k: int = 5, + score_threshold: float | None = None, + **kwargs: Any + ) -> List[MemoryQueryResult]: + """ + Query the memory store and return relevant entries. + + Args: + query: Text, image or multimodal query + k: Maximum number of results to return + score_threshold: Minimum relevance score threshold + **kwargs: Additional implementation-specific parameters + + Returns: + List of memory entries with relevance scores + """ + ... + + async def add( + self, + entry: MemoryEntry, + cancellation_token: CancellationToken | None = None + ) -> None: + """ + Add a new entry to memory. + + Args: + entry: The memory entry to add + cancellation_token: Optional token to cancel the operation + """ + ... + + async def clear(self) -> None: + """Clear all entries from memory.""" + ... + + async def save_state(self) -> BaseMemoryState: + """Save memory state for persistence.""" + ... + + async def load_state(self, state: BaseState) -> None: + """Load memory state from saved state.""" + ... + + +class BaseMemory: + """Base class providing common functionality for memory implementations.""" + + def __init__(self, name: str) -> None: + self._name = name + self._entries: List[MemoryEntry] = [] + + @property + def name(self) -> str: + return self._name + + async def clear(self) -> None: + """Clear all entries from memory.""" + self._entries = [] + + async def save_state(self) -> BaseMemoryState: + """Save memory state.""" + return BaseMemoryState( + state_type=self.__class__.__name__, + entries=self._entries.copy() + ) + + async def load_state(self, state: BaseState) -> None: + """Load memory state.""" + if not isinstance(state, BaseMemoryState): + raise ValueError(f"Expected BaseMemoryState, got {type(state)}") + + if state.state_type != self.__class__.__name__: + raise ValueError( + f"Cannot load {state.state_type} state into {self.__class__.__name__}") + + self._entries = state.entries.copy() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py b/python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py new file mode 100644 index 000000000000..7d55ac9c40dc --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py @@ -0,0 +1,109 @@ +from difflib import SequenceMatcher +from typing import Any, List, Union, cast + +from autogen_core.base import CancellationToken +from autogen_core.components import Image + +from ._base_memory import BaseMemory, MemoryEntry, MemoryQueryResult + + +class ListMemory(BaseMemory): + """A simple list-based memory implementation using text similarity matching.""" + + def __init__(self, name: str) -> None: + """Initialize list memory. + + Args: + name: Name of the memory instance + """ + super().__init__(name) + + async def add( + self, + entry: MemoryEntry, + cancellation_token: CancellationToken | None = None + ) -> None: + """Add a new entry to memory. + + Args: + entry: Memory entry to add + cancellation_token: Optional token to cancel operation + """ + self._entries.append(entry) + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """Calculate text similarity score using SequenceMatcher. + + Args: + text1: First text string + text2: Second text string + + Returns: + Similarity score between 0 and 1 + """ + return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() + + async def query( + self, + query: Union[str, Image, List[Union[str, Image]]], + *, + k: int = 5, + score_threshold: float | None = None, + **kwargs: Any + ) -> List[MemoryQueryResult]: + """Query memory entries based on text similarity. + + Args: + query: Query text or content + k: Maximum number of results to return + score_threshold: Minimum similarity score threshold + **kwargs: Additional query parameters (unused in this implementation) + + Returns: + List of memory entries with similarity scores + + Raises: + ValueError: If query contains unsupported content types + """ + # Handle different query types + if isinstance(query, str): + query_text = query + elif isinstance(query, list): + # Extract text from multimodal query + text_parts = [item for item in query if isinstance(item, str)] + if not text_parts: + raise ValueError( + "Query must contain at least one text element") + query_text = " ".join(text_parts) + else: + raise ValueError("Image-only queries not supported in ListMemory") + + # Calculate similarity scores for all entries + results: List[MemoryQueryResult] = [] + + for entry in self._entries: + if isinstance(entry.content, str): + content_text = entry.content + elif isinstance(entry.content, list): + # Extract text from multimodal content + text_parts = [ + item for item in entry.content if isinstance(item, str)] + if not text_parts: + continue + content_text = " ".join(text_parts) + else: + continue + + score = self._calculate_similarity(query_text, content_text) + + if score_threshold is None or score >= score_threshold: + results.append( + MemoryQueryResult( + entry=entry, + score=score + ) + ) + + # Sort by score and return top k results + results.sort(key=lambda x: x.score, reverse=True) + return results[:k] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From f70f61eb4c97420e0795f631cec49fb76b6f28e2 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 30 Nov 2024 16:20:23 -0800 Subject: [PATCH 02/26] update, add example with chromadb --- .../src/autogen_agentchat/_list_memroy.py | 109 --------- .../agents/_assistant_agent.py | 109 +++++++-- .../{ => memory}/_base_memory.py | 89 ++----- .../memory/_chroma_memory.py | 213 +++++++++++++++++ .../autogen_agentchat/memory/_list_memory.py | 132 +++++++++++ .../tutorial/memory.ipynb | 220 ++++++++++++++++++ 6 files changed, 676 insertions(+), 196 deletions(-) delete mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py rename python/packages/autogen-agentchat/src/autogen_agentchat/{ => memory}/_base_memory.py (51%) create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py create mode 100644 python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py b/python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py deleted file mode 100644 index 7d55ac9c40dc..000000000000 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/_list_memroy.py +++ /dev/null @@ -1,109 +0,0 @@ -from difflib import SequenceMatcher -from typing import Any, List, Union, cast - -from autogen_core.base import CancellationToken -from autogen_core.components import Image - -from ._base_memory import BaseMemory, MemoryEntry, MemoryQueryResult - - -class ListMemory(BaseMemory): - """A simple list-based memory implementation using text similarity matching.""" - - def __init__(self, name: str) -> None: - """Initialize list memory. - - Args: - name: Name of the memory instance - """ - super().__init__(name) - - async def add( - self, - entry: MemoryEntry, - cancellation_token: CancellationToken | None = None - ) -> None: - """Add a new entry to memory. - - Args: - entry: Memory entry to add - cancellation_token: Optional token to cancel operation - """ - self._entries.append(entry) - - def _calculate_similarity(self, text1: str, text2: str) -> float: - """Calculate text similarity score using SequenceMatcher. - - Args: - text1: First text string - text2: Second text string - - Returns: - Similarity score between 0 and 1 - """ - return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() - - async def query( - self, - query: Union[str, Image, List[Union[str, Image]]], - *, - k: int = 5, - score_threshold: float | None = None, - **kwargs: Any - ) -> List[MemoryQueryResult]: - """Query memory entries based on text similarity. - - Args: - query: Query text or content - k: Maximum number of results to return - score_threshold: Minimum similarity score threshold - **kwargs: Additional query parameters (unused in this implementation) - - Returns: - List of memory entries with similarity scores - - Raises: - ValueError: If query contains unsupported content types - """ - # Handle different query types - if isinstance(query, str): - query_text = query - elif isinstance(query, list): - # Extract text from multimodal query - text_parts = [item for item in query if isinstance(item, str)] - if not text_parts: - raise ValueError( - "Query must contain at least one text element") - query_text = " ".join(text_parts) - else: - raise ValueError("Image-only queries not supported in ListMemory") - - # Calculate similarity scores for all entries - results: List[MemoryQueryResult] = [] - - for entry in self._entries: - if isinstance(entry.content, str): - content_text = entry.content - elif isinstance(entry.content, list): - # Extract text from multimodal content - text_parts = [ - item for item in entry.content if isinstance(item, str)] - if not text_parts: - continue - content_text = " ".join(text_parts) - else: - continue - - score = self._calculate_similarity(query_text, content_text) - - if score_threshold is None or score >= score_threshold: - results.append( - MemoryQueryResult( - entry=entry, - score=score - ) - ) - - # Sort by score and return top k results - results.sort(key=lambda x: x.score, reverse=True) - return results[:k] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 1edf86f0061f..3237fd332619 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -29,6 +29,7 @@ ToolCallResultMessage, ) from ._base_chat_agent import BaseChatAgent +from ..memory._base_memory import Memory, MemoryQueryResult event_logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -60,10 +61,12 @@ def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]: else: name = values["name"] if not isinstance(name, str): - raise ValueError(f"Handoff name must be a string: {values['name']}") + raise ValueError( + f"Handoff name must be a string: {values['name']}") # Check if name is a valid identifier. if not name.isidentifier(): - raise ValueError(f"Handoff name must be a valid identifier: {values['name']}") + raise ValueError( + f"Handoff name must be a valid identifier: {values['name']}") if values.get("message") is None: values["message"] = ( f"Transferred to {values['target']}, adopting the role of {values['target']} immediately." @@ -203,14 +206,20 @@ def __init__( name: str, model_client: ChatCompletionClient, *, - tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: List[Tool | Callable[..., Any] | + Callable[..., Awaitable[Any]]] | None = None, handoffs: List[Handoff | str] | None = None, + memory: Memory | None = None, description: str = "An agent that provides assistance with ability to use tools.", system_message: str | None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", ): super().__init__(name=name, description=description) self._model_client = model_client + self._memory = memory + + self._system_messages: List[SystemMessage | UserMessage | + AssistantMessage | FunctionExecutionResultMessage] = [] if system_message is None: self._system_messages = [] else: @@ -218,7 +227,8 @@ def __init__( self._tools: List[Tool] = [] if tools is not None: if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling.") + raise ValueError( + "The model does not support function calling.") for tool in tools: if isinstance(tool, Tool): self._tools.append(tool) @@ -227,7 +237,8 @@ def __init__( description = tool.__doc__ else: description = "" - self._tools.append(FunctionTool(tool, description=description)) + self._tools.append(FunctionTool( + tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") # Check if tool names are unique. @@ -239,7 +250,8 @@ def __init__( self._handoffs: Dict[str, Handoff] = {} if handoffs is not None: if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling, which is needed for handoffs.") + raise ValueError( + "The model does not support function calling, which is needed for handoffs.") for handoff in handoffs: if isinstance(handoff, str): handoff = Handoff(target=handoff) @@ -247,11 +259,13 @@ def __init__( self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: - raise ValueError(f"Unsupported handoff type: {type(handoff)}") + raise ValueError( + f"Unsupported handoff type: {type(handoff)}") # Check if handoff tool names are unique. handoff_tool_names = [tool.name for tool in self._handoff_tools] if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") + raise ValueError( + f"Handoff names must be unique: {handoff_tool_names}") # Check if handoff tool names not in tool names. if any(name in tool_names for name in handoff_tool_names): raise ValueError( @@ -259,6 +273,19 @@ def __init__( ) self._model_context: List[LLMMessage] = [] + def _format_memory_context(self, results: List[MemoryQueryResult]) -> str: + if not results or not self._memory: # Guard against no memory + return "" + + context_lines = [] + for i, result in enumerate(results, 1): + context_lines.append( + self._memory.config.context_format.format( + i=i, content=result.entry.content, score=result.score) + ) + + return "".join(context_lines) + @property def produced_message_types(self) -> List[type[ChatMessage]]: """The types of messages that the assistant agent produces.""" @@ -270,44 +297,70 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async for message in self.on_messages_stream(messages, cancellation_token): if isinstance(message, Response): return message - raise AssertionError("The stream should have returned the final result.") + raise AssertionError( + "The stream should have returned the final result.") async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentMessage | Response, None]: + # Query memory if available with the last message + memory_context = "" + if self._memory is not None and messages: + try: + last_message = messages[-1] + # ensure the last message is a text message or multimodal message + if not isinstance(last_message, TextMessage) and not isinstance(last_message, MultiModalMessage): + raise ValueError( + "Memory query failed: Last message must be a text message or multimodal message.") + results: List[MemoryQueryResult] = await self._memory.query(messages[-1].content, cancellation_token=cancellation_token) + memory_context = self._format_memory_context(results) + except Exception as e: + event_logger.warning(f"Memory query failed: {e}") + # Add messages to the model context. for msg in messages: if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False: raise ValueError("The model does not support vision.") - self._model_context.append(UserMessage(content=msg.content, source=msg.source)) + self._model_context.append(UserMessage( + content=msg.content, source=msg.source)) # Inner messages. inner_messages: List[AgentMessage] = [] - # Generate an inference result based on the current model context. - llm_messages = self._system_messages + self._model_context + # Prepare messages for model with memory context if available + llm_messages = self._system_messages + if memory_context: + llm_messages = llm_messages + \ + [SystemMessage(content=memory_context)] + llm_messages = llm_messages + self._model_context + + # Generate inference result result = await self._model_client.create( llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token ) # Add the response to the model context. - self._model_context.append(AssistantMessage(content=result.content, source=self.name)) + self._model_context.append(AssistantMessage( + content=result.content, source=self.name)) # Run tool calls until the model produces a string response. while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): - tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage) + tool_call_msg = ToolCallMessage( + content=result.content, source=self.name, models_usage=result.usage) event_logger.debug(tool_call_msg) # Add the tool call message to the output. inner_messages.append(tool_call_msg) yield tool_call_msg # Execute the tool calls. - results = await asyncio.gather( + execution_results = await asyncio.gather( *[self._execute_tool_call(call, cancellation_token) for call in result.content] ) - tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name) + tool_call_result_msg = ToolCallResultMessage( + content=execution_results, source=self.name) event_logger.debug(tool_call_result_msg) - self._model_context.append(FunctionExecutionResultMessage(content=results)) + self._model_context.append( + FunctionExecutionResultMessage(content=execution_results)) inner_messages.append(tool_call_result_msg) yield tool_call_result_msg @@ -318,7 +371,8 @@ async def on_messages_stream( handoffs.append(self._handoffs[call.name]) if len(handoffs) > 0: if len(handoffs) > 1: - raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") + raise ValueError( + f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}") # Return the output messages to signal the handoff. yield Response( chat_message=HandoffMessage( @@ -329,15 +383,22 @@ async def on_messages_stream( return # Generate an inference result based on the current model context. - llm_messages = self._system_messages + self._model_context + llm_messages = ( + self._system_messages + + ([SystemMessage(content=memory_context)] + if memory_context else []) + + self._model_context + ) result = await self._model_client.create( llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token ) - self._model_context.append(AssistantMessage(content=result.content, source=self.name)) + self._model_context.append(AssistantMessage( + content=result.content, source=self.name)) assert isinstance(result.content, str) yield Response( - chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage( + content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) @@ -348,9 +409,11 @@ async def _execute_tool_call( try: if not self._tools + self._handoff_tools: raise ValueError("No tools are available.") - tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None) + tool = next((t for t in self._tools + + self._handoff_tools if t.name == tool_call.name), None) if tool is None: - raise ValueError(f"The tool '{tool_call.name}' is not available.") + raise ValueError( + f"The tool '{tool_call.name}' is not available.") arguments = json.loads(tool_call.arguments) result = await tool.run_json(arguments, cancellation_token) result_as_str = tool.return_value_as_string(result) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py similarity index 51% rename from python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py rename to python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 95272195b70a..4a70d6162ed0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -5,7 +5,18 @@ from autogen_core.components import Image from pydantic import BaseModel, ConfigDict, Field -from .state import BaseState + +class BaseMemoryConfig(BaseModel): + """Base configuration for memory implementations.""" + + k: int = Field(default=5, description="Number of results to return") + score_threshold: float | None = Field(default=None, description="Minimum relevance score") + context_format: str = Field( + default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.", + description="Format string for memory results in prompt", + ) + + model_config = ConfigDict(arbitrary_types_allowed=True) class MemoryEntry(BaseModel): @@ -38,40 +49,32 @@ class MemoryQueryResult(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class BaseMemoryState(BaseState): - """State for memory implementations.""" - - state_type: str - """Type identifier for the memory implementation.""" - - entries: List[MemoryEntry] - """List of memory entries.""" - - @runtime_checkable class Memory(Protocol): """Protocol defining the interface for memory implementations.""" @property - def name(self) -> str: + def name(self) -> str | None: """The name of this memory implementation.""" ... + @property + def config(self) -> BaseMemoryConfig: + """The configuration for this memory implementation.""" + ... + async def query( self, query: Union[str, Image, List[Union[str, Image]]], - *, - k: int = 5, - score_threshold: float | None = None, - **kwargs: Any + cancellation_token: CancellationToken | None = None, + **kwargs: Any, ) -> List[MemoryQueryResult]: """ Query the memory store and return relevant entries. Args: query: Text, image or multimodal query - k: Maximum number of results to return - score_threshold: Minimum relevance score threshold + cancellation_token: Optional token to cancel operation **kwargs: Additional implementation-specific parameters Returns: @@ -79,17 +82,13 @@ async def query( """ ... - async def add( - self, - entry: MemoryEntry, - cancellation_token: CancellationToken | None = None - ) -> None: + async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: """ Add a new entry to memory. Args: entry: The memory entry to add - cancellation_token: Optional token to cancel the operation + cancellation_token: Optional token to cancel operation """ ... @@ -97,44 +96,6 @@ async def clear(self) -> None: """Clear all entries from memory.""" ... - async def save_state(self) -> BaseMemoryState: - """Save memory state for persistence.""" - ... - - async def load_state(self, state: BaseState) -> None: - """Load memory state from saved state.""" + async def cleanup(self) -> None: + """Clean up any resources used by the memory implementation.""" ... - - -class BaseMemory: - """Base class providing common functionality for memory implementations.""" - - def __init__(self, name: str) -> None: - self._name = name - self._entries: List[MemoryEntry] = [] - - @property - def name(self) -> str: - return self._name - - async def clear(self) -> None: - """Clear all entries from memory.""" - self._entries = [] - - async def save_state(self) -> BaseMemoryState: - """Save memory state.""" - return BaseMemoryState( - state_type=self.__class__.__name__, - entries=self._entries.copy() - ) - - async def load_state(self, state: BaseState) -> None: - """Load memory state.""" - if not isinstance(state, BaseMemoryState): - raise ValueError(f"Expected BaseMemoryState, got {type(state)}") - - if state.state_type != self.__class__.__name__: - raise ValueError( - f"Cannot load {state.state_type} state into {self.__class__.__name__}") - - self._entries = state.entries.copy() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py new file mode 100644 index 000000000000..d44605a75e5e --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py @@ -0,0 +1,213 @@ +from typing import Any, List, Optional, Union, Dict +from types import TracebackType +from datetime import datetime +import chromadb +from chromadb.types import Collection +import uuid +import logging +from autogen_core.base import CancellationToken +from autogen_core.components import Image +from pydantic import Field + +from ._base_memory import BaseMemoryConfig, Memory, MemoryEntry, MemoryQueryResult + +logger = logging.getLogger(__name__) + +# Type vars for ChromaDB results +ChromaMetadata = Dict[str, Union[str, float, int, bool]] +ChromaDistance = Union[float, List[float]] + + +class ChromaMemoryConfig(BaseMemoryConfig): + """Configuration for ChromaDB-based memory implementation.""" + + collection_name: str = Field( + default="memory_store", description="Name of the ChromaDB collection") + persistence_path: Optional[str] = Field( + default=None, description="Path for persistent storage. None for in-memory." + ) + distance_metric: str = Field( + default="cosine", description="Distance metric for similarity search") + + +class ChromaMemory(Memory): + """ChromaDB-based memory implementation using default embeddings.""" + + def __init__(self, name: Optional[str] = None, config: Optional[ChromaMemoryConfig] = None) -> None: + """Initialize ChromaMemory.""" + self._name = name or "default_chroma_memory" + self._config = config or ChromaMemoryConfig() + self._client: Optional[chromadb.Client] = None # type: ignore + self._collection: Collection | None = None # type: ignore + + @property + def name(self) -> Optional[str]: + return self._name + + @property + def config(self) -> ChromaMemoryConfig: + return self._config + + def _ensure_initialized(self) -> None: + """Ensure ChromaDB client and collection are initialized.""" + if self._client is None: + try: + self._client = ( + chromadb.PersistentClient( + path=self._config.persistence_path) + if self._config.persistence_path + else chromadb.Client() + ) + except Exception as e: + logger.error(f"Failed to initialize ChromaDB client: {e}") + raise + + if self._collection is None and self._client is not None: + try: + self._collection = self._client.get_or_create_collection( + name=self._config.collection_name, metadata={ + "distance_metric": self._config.distance_metric} + ) + except Exception as e: + logger.error(f"Failed to get/create collection: {e}") + raise + + def _extract_text(self, content: Union[str, List[Union[str, Image]]]) -> str: + """Extract text content from input.""" + if isinstance(content, str): + return content + + text_parts = [item for item in content if isinstance(item, str)] + if not text_parts: + raise ValueError("Content must contain at least one text element") + + return " ".join(text_parts) + + async def add(self, entry: MemoryEntry, cancellation_token: Optional[CancellationToken] = None) -> None: + """Add a memory entry to ChromaDB.""" + self._ensure_initialized() + if self._collection is None: + raise RuntimeError("Failed to initialize ChromaDB") + + try: + # Extract text + text = self._extract_text(entry.content) + + # Prepare metadata + metadata: ChromaMetadata = { + "timestamp": entry.timestamp.isoformat(), + "source": entry.source or "", + **entry.metadata, + } + + # Add to ChromaDB + self._collection.add(documents=[text], metadatas=[ + metadata], ids=[str(uuid.uuid4())]) + + except Exception as e: + logger.error(f"Failed to add entry to ChromaDB: {e}") + raise + + async def query( + self, + query: Union[str, Image, List[Union[str, Image]]], + cancellation_token: Optional[CancellationToken] = None, + **kwargs: Any, + ) -> List[MemoryQueryResult]: + """Query memory entries based on vector similarity.""" + self._ensure_initialized() + if self._collection is None: + raise RuntimeError("Failed to initialize ChromaDB") + + try: + # Extract text for query + if isinstance(query, Image): + raise ValueError("Image-only queries are not supported") + + query_text = self._extract_text( + query if isinstance(query, list) else [query]) + + # Query ChromaDB + results = self._collection.query( + query_texts=[query_text], n_results=self._config.k, **kwargs) + + # Convert results to MemoryQueryResults + memory_results: List[MemoryQueryResult] = [] + + if not results["documents"]: + return memory_results + + for doc, metadata, distance in zip( + results["documents"][0], results["metadatas"][0], results["distances"][0] + ): + # Extract stored metadata + entry_metadata = dict(metadata) + try: + timestamp_str = str(entry_metadata.pop("timestamp")) + timestamp = datetime.fromisoformat(timestamp_str) + except (KeyError, ValueError) as e: + logger.error(f"Invalid timestamp in metadata: {e}") + continue + + source = str(entry_metadata.pop("source")) + + # Create MemoryEntry + entry = MemoryEntry( + content=doc, metadata=entry_metadata, timestamp=timestamp, source=source or None) + + # Convert distance to similarity score (1 - normalized distance) + score = ( + 1.0 - (float(distance) / 2.0) + if self._config.distance_metric == "cosine" + else 1.0 / (1.0 + float(distance)) + ) + + memory_results.append( + MemoryQueryResult(entry=entry, score=score)) + + return memory_results + + except Exception as e: + logger.error(f"Failed to query ChromaDB: {e}") + raise + + async def clear(self) -> None: + """Clear all entries from memory.""" + self._ensure_initialized() + if self._collection is None: + raise RuntimeError("Failed to initialize ChromaDB") + + try: + self._collection.delete() + if self._client is not None: + self._collection = self._client.get_or_create_collection( + name=self._config.collection_name, metadata={ + "distance_metric": self._config.distance_metric} + ) + except Exception as e: + logger.error(f"Failed to clear ChromaDB collection: {e}") + raise + + async def __aenter__(self) -> "ChromaMemory": + """Context manager entry.""" + return self + + async def __aexit__( + self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: + """Context manager exit with cleanup.""" + await self.cleanup() + + async def cleanup(self) -> None: + """Clean up ChromaDB client.""" + if self._client is not None: + try: + if hasattr(self._client, "reset"): + self._client.reset() + self._client = None + self._collection = None + except Exception as e: + logger.error(f"Error during ChromaDB cleanup: {e}") + # Maybe don't raise here, just log the error + self._client = None + self._collection = None diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py new file mode 100644 index 000000000000..99b5b9f8da4a --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -0,0 +1,132 @@ +from difflib import SequenceMatcher +from typing import Any, List, Union + +from autogen_core.base import CancellationToken +from autogen_core.components import Image +from pydantic import Field + +from ._base_memory import BaseMemoryConfig, Memory, MemoryEntry, MemoryQueryResult + + +class ListMemoryConfig(BaseMemoryConfig): + """Configuration for list-based memory implementation.""" + + similarity_threshold: float = Field( + default=0.0, description="Minimum similarity score for text matching", ge=0.0, le=1.0 + ) + + +class ListMemory(Memory): + """Simple list-based memory using text similarity matching.""" + + def __init__(self, name: str | None = None, config: ListMemoryConfig | None = None) -> None: + """Initialize list memory. + + Args: + name: Name of the memory instance + config: Optional configuration, uses defaults if not provided + """ + self._name = name or "default_list_memory" + self._config = config or ListMemoryConfig() + self._entries: List[MemoryEntry] = [] + + @property + def name(self) -> str: + return self._name + + @property + def config(self) -> ListMemoryConfig: + return self._config + + def _calculate_similarity(self, text1: str, text2: str) -> float: + """Calculate text similarity score using SequenceMatcher. + + Args: + text1: First text string + text2: Second text string + + Returns: + Similarity score between 0 and 1 + """ + return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() + + def _extract_text(self, content: Union[str, List[Union[str, Image]]]) -> str: + """Extract searchable text from content. + + Args: + content: Content to extract text from + + Returns: + Extracted text string + + Raises: + ValueError: If no text content can be extracted + """ + if isinstance(content, str): + return content + + text_parts = [item for item in content if isinstance(item, str)] + if not text_parts: + raise ValueError("Content must contain at least one text element") + + return " ".join(text_parts) + + async def query( + self, + query: Union[str, Image, List[Union[str, Image]]], + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> List[MemoryQueryResult]: + """Query memory entries based on text similarity. + + Args: + query: Query text or content + cancellation_token: Optional token to cancel operation + **kwargs: Additional query parameters (unused) + + Returns: + List of memory entries with similarity scores + + Raises: + ValueError: If query contains unsupported content types + """ + if isinstance(query, (str, Image)): + query_content = [query] + else: + query_content = query + + try: + query_text = self._extract_text(query_content) + except ValueError: + raise ValueError("Query must contain text content") + + results: List[MemoryQueryResult] = [] + + for entry in self._entries: + try: + content_text = self._extract_text(entry.content) + except ValueError: + continue + + score = self._calculate_similarity(query_text, content_text) + + if score >= self._config.similarity_threshold and ( + self._config.score_threshold is None or score >= self._config.score_threshold + ): + results.append(MemoryQueryResult(entry=entry, score=score)) + + results.sort(key=lambda x: x.score, reverse=True) + return results[: self._config.k] + + async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: + """Add a new entry to memory. + + Args: + entry: Memory entry to add + cancellation_token: Optional token to cancel operation + """ + self._entries.append(entry) + + async def clear(self) -> None: + """Clear all entries from memory.""" + self._entries = [] diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb new file mode 100644 index 000000000000..3e8a7c24b8a8 --- /dev/null +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Memory \n", + "\n", + "There are several use cases where it is valuable to maintain a bank of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", + "\n", + "\n", + "AgentChat provides a `Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `add`, `clear`, and `cleanup`. The `query` method is used to retrieve relevant information from the memory store, the `add` method is used to add new entries to the memory store, the `clear` method is used to clear all entries from the memory store, and the `cleanup` method is used to clean up any resources used by the memory store.\n", + "\n", + "\n", + "## ListMemory\n", + "\n", + "ListMemory is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "What is the weather in New York?\n", + "---------- weather_agent ----------\n", + "[FunctionCall(id='call_oIYlxvmJ6k9JxfPx96JuDz1G', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[Prompt tokens: 130, Completion tokens: 19]\n", + "---------- weather_agent ----------\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_oIYlxvmJ6k9JxfPx96JuDz1G')]\n", + "---------- weather_agent ----------\n", + "The weather in New York is 23°C and sunny. TERMINATE\n", + "[Prompt tokens: 172, Completion tokens: 15]\n", + "---------- Summary ----------\n", + "Number of messages: 4\n", + "Finish reason: Text 'TERMINATE' mentioned\n", + "Total prompt tokens: 302\n", + "Total completion tokens: 34\n", + "Duration: 1.55 seconds\n" + ] + } + ], + "source": [ + " \n", + "from autogen_ext.models import OpenAIChatCompletionClient\n", + "from autogen_agentchat.agents import AssistantAgent \n", + "from autogen_agentchat.teams import RoundRobinGroupChat\n", + "from autogen_agentchat.memory._base_memory import MemoryEntry \n", + "from autogen_agentchat.task import Console, TextMentionTermination, MaxMessageTermination\n", + "from autogen_agentchat.memory._list_memory import ListMemory, MemoryEntry\n", + "\n", + "\n", + "# create a simple memory item \n", + "memory = ListMemory()\n", + "await memory.add(MemoryEntry(content=\"Whenever you are asked for the weather, return it in metric units.\"))\n", + "\n", + "\n", + "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", + " if units == \"imperial\":\n", + " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", + " elif units == \"metric\":\n", + " return f\"The weather in {city} is 23 degrees and Sunny.\" \n", + "\n", + "weather_agent = AssistantAgent(\n", + " name=\"weather_agent\",\n", + " model_client=OpenAIChatCompletionClient(\n", + " model=\"gpt-4o-2024-08-06\", \n", + " ),\n", + " tools=[get_weather], \n", + " memory=memory\n", + ")\n", + " \n", + "agent_team = RoundRobinGroupChat([weather_agent], termination_condition = TextMentionTermination(\"TERMINATE\"))\n", + "\n", + "# Run the team and stream messages to the console\n", + "stream = agent_team.run_stream(task=\"What is the weather in New York?\")\n", + "await Console(stream);\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vector DB Memory (ChromaDB)\n", + "\n", + "Similarly, we can implement a memory store that uses a vector database to store and retrieve information. `ChromaMemory` is a memory implementation that uses ChromaDB to store and retrieve information. ChromaDB is a vector database that is optimized for similarity search. \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install chromadb sentence-transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1 [MemoryQueryResult(entry=MemoryEntry(content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\", metadata={}, timestamp=datetime.datetime(2024, 11, 30, 15, 28, 58, 195846), source='travel_facts'), score=0.5832697153091431)]\n" + ] + } + ], + "source": [ + "\n", + "from autogen_core.base import CancellationToken\n", + "from autogen_ext.models import OpenAIChatCompletionClient\n", + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.messages import TextMessage\n", + "from autogen_agentchat.memory._base_memory import MemoryEntry\n", + "from autogen_agentchat.memory._chroma_memory import ChromaMemory, ChromaMemoryConfig\n", + "\n", + " \n", + "# Initialize memory\n", + "chroma_memory = ChromaMemory(\n", + " name=\"travel_memory\",\n", + " config=ChromaMemoryConfig(\n", + " collection_name=\"travel_facts\",\n", + " # Configure number of results to return instead of similarity threshold\n", + " k=1 \n", + " )\n", + ")\n", + "# Add some travel-related memories\n", + "await chroma_memory.add(MemoryEntry(\n", + " content=\"Paris is known for the Eiffel Tower and amazing cuisine.\",\n", + " source=\"travel_guide\"\n", + "))\n", + "\n", + "await chroma_memory.add(MemoryEntry(\n", + " content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\",\n", + " source=\"travel_facts\"\n", + "))\n", + "\n", + "results = await chroma_memory.query(\"Tell me about Tokyo.\")\n", + "print(len(results),results)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Tell me the most important thing about Tokyo.\n", + "---------- travel_agent ----------\n", + "One of the most important aspects of Tokyo is that it has the world's busiest railway station, Shinjuku Station. This station serves as a major hub for transportation, with millions of commuters and travelers passing through its complex network of train lines each day. It highlights Tokyo's status as a bustling metropolis with an advanced public transportation system.\n", + "[Prompt tokens: 72, Completion tokens: 66]\n", + "---------- Summary ----------\n", + "Number of messages: 2\n", + "Finish reason: Maximum number of messages 2 reached, current message count: 2\n", + "Total prompt tokens: 72\n", + "Total completion tokens: 66\n", + "Duration: 1.47 seconds\n" + ] + } + ], + "source": [ + "# Create agent with memory\n", + "agent = AssistantAgent(\n", + " name=\"travel_agent\",\n", + " model_client=OpenAIChatCompletionClient(\n", + " model=\"gpt-4o\",\n", + " # api_key=\"your_api_key\"\n", + " ),\n", + " memory=chroma_memory,\n", + " system_message=\"You are a travel expert\"\n", + ")\n", + "\n", + "agent_team = RoundRobinGroupChat([agent], termination_condition = MaxMessageTermination(max_messages=2))\n", + "stream = agent_team.run_stream(task=\"Tell me the most important thing about Tokyo.\")\n", + "await Console(stream);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "agnext", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 24fa68416042d0632e1697398eb0e7c8ceff9626 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sun, 1 Dec 2024 08:58:01 -0800 Subject: [PATCH 03/26] include mimetype consideration --- .../autogen_agentchat/memory/_base_memory.py | 124 +++++++++++++++++- 1 file changed, 123 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 4a70d6162ed0..95845f4f2fab 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -10,7 +10,8 @@ class BaseMemoryConfig(BaseModel): """Base configuration for memory implementations.""" k: int = Field(default=5, description="Number of results to return") - score_threshold: float | None = Field(default=None, description="Minimum relevance score") + score_threshold: float | None = Field( + default=None, description="Minimum relevance score") context_format: str = Field( default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.", description="Format string for memory results in prompt", @@ -99,3 +100,124 @@ async def clear(self) -> None: async def cleanup(self) -> None: """Clean up any resources used by the memory implementation.""" ... + + +# from datetime import datetime +# from enum import Enum +# from typing import Any, Dict, List, Protocol, Union, runtime_checkable +# from pydantic import BaseModel, ConfigDict, Field + +# class CommonContentType(Enum): +# """Common content types for memory entries.""" +# TEXT = "text/plain" +# JSON = "application/json" +# MARKDOWN = "text/markdown" +# IMAGE = "image/*" +# BINARY = "application/octet-stream" + +# class ContentItem(BaseModel): +# """A content item with type information.""" +# content: Any = Field( +# description="The actual content data" +# ) +# mime_type: str = Field( +# pattern="^[a-z]+/[a-z0-9.+-]+$", +# description="Content type - use CommonContentType values or any valid MIME type" +# ) + +# model_config = ConfigDict(arbitrary_types_allowed=True) + +# class BaseMemoryConfig(BaseModel): +# """Base configuration for memory implementations.""" +# k: int = Field(default=5, description="Number of results to return") +# score_threshold: float | None = Field( +# default=None, +# description="Minimum relevance score" +# ) +# context_format: str = Field( +# default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.", +# description="Format string for memory results in prompt" +# ) + +# model_config = ConfigDict(arbitrary_types_allowed=True) + +# class MemoryEntry(BaseModel): +# """A memory entry containing content and metadata.""" +# content: ContentItem +# """The content item with type information.""" + +# metadata: Dict[str, Any] = Field(default_factory=dict) +# """Optional metadata associated with the memory entry.""" + +# timestamp: datetime = Field(default_factory=datetime.now) +# """When the memory was created.""" + +# source: str | None = None +# """Optional source identifier for the memory.""" + +# model_config = ConfigDict(arbitrary_types_allowed=True) + +# class MemoryQueryResult(BaseModel): +# """Result from a memory query including the entry and its relevance score.""" +# entry: MemoryEntry +# """The memory entry.""" + +# score: float +# """Relevance score for this result. Higher means more relevant.""" + +# model_config = ConfigDict(arbitrary_types_allowed=True) + +# @runtime_checkable +# class Memory(Protocol): +# """Protocol defining the interface for memory implementations.""" + +# @property +# def name(self) -> str | None: +# """The name of this memory implementation.""" +# ... + +# @property +# def config(self) -> BaseMemoryConfig: +# """The configuration for this memory implementation.""" +# ... + +# async def query( +# self, +# query: ContentItem, +# cancellation_token: 'CancellationToken | None' = None, +# **kwargs: Any, +# ) -> List[MemoryQueryResult]: +# """ +# Query the memory store and return relevant entries. + +# Args: +# query: Query content item +# cancellation_token: Optional token to cancel operation +# **kwargs: Additional implementation-specific parameters + +# Returns: +# List of memory entries with relevance scores +# """ +# ... + +# async def add( +# self, +# entry: MemoryEntry, +# cancellation_token: 'CancellationToken | None' = None +# ) -> None: +# """ +# Add a new entry to memory. + +# Args: +# entry: The memory entry to add +# cancellation_token: Optional token to cancel operation +# """ +# ... + +# async def clear(self) -> None: +# """Clear all entries from memory.""" +# ... + +# async def cleanup(self) -> None: +# """Clean up any resources used by the memory implementation.""" +# ... From 0b7469e0a51cd6fdb7b84cae5e79e6c827fbebe5 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Thu, 19 Dec 2024 20:53:36 -0800 Subject: [PATCH 04/26] add transform method --- .../autogen_agentchat/memory/_base_memory.py | 188 +++++------------- .../autogen_agentchat/memory/_list_memory.py | 118 ++++++----- 2 files changed, 115 insertions(+), 191 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 95845f4f2fab..aa1445b01bbc 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -1,20 +1,40 @@ from datetime import datetime +from enum import Enum from typing import Any, Dict, List, Protocol, Union, runtime_checkable -from autogen_core.base import CancellationToken -from autogen_core.components import Image +from autogen_core import CancellationToken, Image from pydantic import BaseModel, ConfigDict, Field +from autogen_core.model_context import ( + ChatCompletionContext +) + + +class MimeType(Enum): + """Supported MIME types for memory content.""" + TEXT = "text/plain" + JSON = "application/json" + MARKDOWN = "text/markdown" + IMAGE = "image/*" + BINARY = "application/octet-stream" + + +ContentType = Union[str, bytes, dict, Image] + + +class ContentItem(BaseModel): + """A content item with type information.""" + content: ContentType + mime_type: MimeType + + model_config = ConfigDict(arbitrary_types_allowed=True) class BaseMemoryConfig(BaseModel): """Base configuration for memory implementations.""" - k: int = Field(default=5, description="Number of results to return") score_threshold: float | None = Field( - default=None, description="Minimum relevance score") - context_format: str = Field( - default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.", - description="Format string for memory results in prompt", + default=None, + description="Minimum relevance score" ) model_config = ConfigDict(arbitrary_types_allowed=True) @@ -22,9 +42,8 @@ class BaseMemoryConfig(BaseModel): class MemoryEntry(BaseModel): """A memory entry containing content and metadata.""" - - content: Union[str, List[Union[str, Image]]] - """The content of the memory entry - can be text or multimodal.""" + content: ContentItem + """The content item with type information.""" metadata: Dict[str, Any] = Field(default_factory=dict) """Optional metadata associated with the memory entry.""" @@ -40,7 +59,6 @@ class MemoryEntry(BaseModel): class MemoryQueryResult(BaseModel): """Result from a memory query including the entry and its relevance score.""" - entry: MemoryEntry """The memory entry.""" @@ -64,17 +82,32 @@ def config(self) -> BaseMemoryConfig: """The configuration for this memory implementation.""" ... + async def transform( + self, + model_context: ChatCompletionContext, + ) -> ChatCompletionContext: + """ + Transform the model context using relevant memory content. + + Args: + model_context: The context to transform + + Returns: + The transformed context + """ + ... + async def query( self, - query: Union[str, Image, List[Union[str, Image]]], - cancellation_token: CancellationToken | None = None, + query: ContentItem, + cancellation_token: 'CancellationToken | None' = None, **kwargs: Any, ) -> List[MemoryQueryResult]: """ Query the memory store and return relevant entries. Args: - query: Text, image or multimodal query + query: Query content item cancellation_token: Optional token to cancel operation **kwargs: Additional implementation-specific parameters @@ -83,7 +116,11 @@ async def query( """ ... - async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: + async def add( + self, + entry: MemoryEntry, + cancellation_token: 'CancellationToken | None' = None + ) -> None: """ Add a new entry to memory. @@ -100,124 +137,3 @@ async def clear(self) -> None: async def cleanup(self) -> None: """Clean up any resources used by the memory implementation.""" ... - - -# from datetime import datetime -# from enum import Enum -# from typing import Any, Dict, List, Protocol, Union, runtime_checkable -# from pydantic import BaseModel, ConfigDict, Field - -# class CommonContentType(Enum): -# """Common content types for memory entries.""" -# TEXT = "text/plain" -# JSON = "application/json" -# MARKDOWN = "text/markdown" -# IMAGE = "image/*" -# BINARY = "application/octet-stream" - -# class ContentItem(BaseModel): -# """A content item with type information.""" -# content: Any = Field( -# description="The actual content data" -# ) -# mime_type: str = Field( -# pattern="^[a-z]+/[a-z0-9.+-]+$", -# description="Content type - use CommonContentType values or any valid MIME type" -# ) - -# model_config = ConfigDict(arbitrary_types_allowed=True) - -# class BaseMemoryConfig(BaseModel): -# """Base configuration for memory implementations.""" -# k: int = Field(default=5, description="Number of results to return") -# score_threshold: float | None = Field( -# default=None, -# description="Minimum relevance score" -# ) -# context_format: str = Field( -# default="Context {i}: {content} (score: {score:.2f})\n Use this information to address relevant tasks.", -# description="Format string for memory results in prompt" -# ) - -# model_config = ConfigDict(arbitrary_types_allowed=True) - -# class MemoryEntry(BaseModel): -# """A memory entry containing content and metadata.""" -# content: ContentItem -# """The content item with type information.""" - -# metadata: Dict[str, Any] = Field(default_factory=dict) -# """Optional metadata associated with the memory entry.""" - -# timestamp: datetime = Field(default_factory=datetime.now) -# """When the memory was created.""" - -# source: str | None = None -# """Optional source identifier for the memory.""" - -# model_config = ConfigDict(arbitrary_types_allowed=True) - -# class MemoryQueryResult(BaseModel): -# """Result from a memory query including the entry and its relevance score.""" -# entry: MemoryEntry -# """The memory entry.""" - -# score: float -# """Relevance score for this result. Higher means more relevant.""" - -# model_config = ConfigDict(arbitrary_types_allowed=True) - -# @runtime_checkable -# class Memory(Protocol): -# """Protocol defining the interface for memory implementations.""" - -# @property -# def name(self) -> str | None: -# """The name of this memory implementation.""" -# ... - -# @property -# def config(self) -> BaseMemoryConfig: -# """The configuration for this memory implementation.""" -# ... - -# async def query( -# self, -# query: ContentItem, -# cancellation_token: 'CancellationToken | None' = None, -# **kwargs: Any, -# ) -> List[MemoryQueryResult]: -# """ -# Query the memory store and return relevant entries. - -# Args: -# query: Query content item -# cancellation_token: Optional token to cancel operation -# **kwargs: Additional implementation-specific parameters - -# Returns: -# List of memory entries with relevance scores -# """ -# ... - -# async def add( -# self, -# entry: MemoryEntry, -# cancellation_token: 'CancellationToken | None' = None -# ) -> None: -# """ -# Add a new entry to memory. - -# Args: -# entry: The memory entry to add -# cancellation_token: Optional token to cancel operation -# """ -# ... - -# async def clear(self) -> None: -# """Clear all entries from memory.""" -# ... - -# async def cleanup(self) -> None: -# """Clean up any resources used by the memory implementation.""" -# ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py index 99b5b9f8da4a..9c26d1c679f3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -1,11 +1,16 @@ from difflib import SequenceMatcher -from typing import Any, List, Union +from typing import Any, List -from autogen_core.base import CancellationToken -from autogen_core.components import Image +from autogen_core import CancellationToken, Image from pydantic import Field -from ._base_memory import BaseMemoryConfig, Memory, MemoryEntry, MemoryQueryResult +from ._base_memory import BaseMemoryConfig, ContentItem, Memory, MemoryEntry, MemoryQueryResult, MimeType +from autogen_core.model_context import ( + ChatCompletionContext +) +from autogen_core.models import ( + SystemMessage, +) class ListMemoryConfig(BaseMemoryConfig): @@ -20,12 +25,6 @@ class ListMemory(Memory): """Simple list-based memory using text similarity matching.""" def __init__(self, name: str | None = None, config: ListMemoryConfig | None = None) -> None: - """Initialize list memory. - - Args: - name: Name of the memory instance - config: Optional configuration, uses defaults if not provided - """ self._name = name or "default_list_memory" self._config = config or ListMemoryConfig() self._entries: List[MemoryEntry] = [] @@ -39,22 +38,14 @@ def config(self) -> ListMemoryConfig: return self._config def _calculate_similarity(self, text1: str, text2: str) -> float: - """Calculate text similarity score using SequenceMatcher. - - Args: - text1: First text string - text2: Second text string - - Returns: - Similarity score between 0 and 1 - """ + """Calculate text similarity score using SequenceMatcher.""" return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() - def _extract_text(self, content: Union[str, List[Union[str, Image]]]) -> str: - """Extract searchable text from content. + def _extract_text(self, content_item: ContentItem) -> str: + """Extract searchable text from ContentItem. Args: - content: Content to extract text from + content_item: ContentItem to extract text from Returns: Extracted text string @@ -62,41 +53,54 @@ def _extract_text(self, content: Union[str, List[Union[str, Image]]]) -> str: Raises: ValueError: If no text content can be extracted """ - if isinstance(content, str): - return content + content = content_item.content + + if content_item.mime_type in [MimeType.TEXT, MimeType.MARKDOWN]: + return str(content) + elif content_item.mime_type == MimeType.JSON: + if isinstance(content, dict): + return str(content) + raise ValueError("JSON content must be a dict") + elif isinstance(content, Image): + raise ValueError("Image content cannot be converted to text") + else: + raise ValueError( + f"Unsupported content type: {content_item.mime_type}") + + async def transform( + self, + model_context: ChatCompletionContext, + ) -> ChatCompletionContext: + """Transform the model context using relevant memory content.""" + messages = await model_context.get_messages() + if not messages: + return model_context - text_parts = [item for item in content if isinstance(item, str)] - if not text_parts: - raise ValueError("Content must contain at least one text element") + last_message = messages[-1] + query_text = getattr(last_message, "content", str(last_message)) + query = ContentItem(content=query_text, mime_type=MimeType.TEXT) - return " ".join(text_parts) + results = [] + query_results = await self.query(query) + for i, result in enumerate(query_results, 1): + results.append(f"{i}. {result.entry.content}") + + if results: + memory_context = "Results from memory query to consider include:\n" + \ + "\n".join(results) + await model_context.add_message(SystemMessage(content=memory_context)) + + return model_context async def query( self, - query: Union[str, Image, List[Union[str, Image]]], + query: ContentItem, cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryQueryResult]: - """Query memory entries based on text similarity. - - Args: - query: Query text or content - cancellation_token: Optional token to cancel operation - **kwargs: Additional query parameters (unused) - - Returns: - List of memory entries with similarity scores - - Raises: - ValueError: If query contains unsupported content types - """ - if isinstance(query, (str, Image)): - query_content = [query] - else: - query_content = query - + """Query memory entries based on text similarity.""" try: - query_text = self._extract_text(query_content) + query_text = self._extract_text(query) except ValueError: raise ValueError("Query must contain text content") @@ -118,15 +122,19 @@ async def query( results.sort(key=lambda x: x.score, reverse=True) return results[: self._config.k] - async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: - """Add a new entry to memory. - - Args: - entry: Memory entry to add - cancellation_token: Optional token to cancel operation - """ + async def add( + self, + entry: MemoryEntry, + cancellation_token: CancellationToken | None = None + ) -> None: + """Add a new entry to memory.""" self._entries.append(entry) async def clear(self) -> None: """Clear all entries from memory.""" self._entries = [] + + async def cleanup(self) -> None: + """Clean up any resources used by the memory implementation.""" + # No resources to clean up in this implementation + pass From 138ee05839158abae37afea70b7a75c3d3aa1719 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Thu, 19 Dec 2024 21:33:01 -0800 Subject: [PATCH 05/26] update to address feedback, will update after 4681 is merged --- .../autogen_agentchat/memory/_base_memory.py | 28 ++-- .../memory/_chroma_memory.py | 142 +++++++++--------- .../autogen_agentchat/memory/_list_memory.py | 28 ++-- .../tutorial/memory.ipynb | 52 +++---- 4 files changed, 121 insertions(+), 129 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index aa1445b01bbc..3e38fa099e56 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -4,13 +4,12 @@ from autogen_core import CancellationToken, Image from pydantic import BaseModel, ConfigDict, Field -from autogen_core.model_context import ( - ChatCompletionContext -) +from autogen_core.model_context import ChatCompletionContext class MimeType(Enum): """Supported MIME types for memory content.""" + TEXT = "text/plain" JSON = "application/json" MARKDOWN = "text/markdown" @@ -21,8 +20,9 @@ class MimeType(Enum): ContentType = Union[str, bytes, dict, Image] -class ContentItem(BaseModel): +class MemoryContent(BaseModel): """A content item with type information.""" + content: ContentType mime_type: MimeType @@ -31,18 +31,17 @@ class ContentItem(BaseModel): class BaseMemoryConfig(BaseModel): """Base configuration for memory implementations.""" + k: int = Field(default=5, description="Number of results to return") - score_threshold: float | None = Field( - default=None, - description="Minimum relevance score" - ) + score_threshold: float | None = Field(default=None, description="Minimum relevance score") model_config = ConfigDict(arbitrary_types_allowed=True) class MemoryEntry(BaseModel): """A memory entry containing content and metadata.""" - content: ContentItem + + content: MemoryContent """The content item with type information.""" metadata: Dict[str, Any] = Field(default_factory=dict) @@ -59,6 +58,7 @@ class MemoryEntry(BaseModel): class MemoryQueryResult(BaseModel): """Result from a memory query including the entry and its relevance score.""" + entry: MemoryEntry """The memory entry.""" @@ -99,8 +99,8 @@ async def transform( async def query( self, - query: ContentItem, - cancellation_token: 'CancellationToken | None' = None, + query: MemoryContent, + cancellation_token: "CancellationToken | None" = None, **kwargs: Any, ) -> List[MemoryQueryResult]: """ @@ -116,11 +116,7 @@ async def query( """ ... - async def add( - self, - entry: MemoryEntry, - cancellation_token: 'CancellationToken | None' = None - ) -> None: + async def add(self, entry: MemoryEntry, cancellation_token: "CancellationToken | None" = None) -> None: """ Add a new entry to memory. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py index d44605a75e5e..f6297810e7a3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py @@ -1,47 +1,44 @@ -from typing import Any, List, Optional, Union, Dict +from typing import Any, List, Optional, Dict from types import TracebackType from datetime import datetime import chromadb from chromadb.types import Collection import uuid import logging -from autogen_core.base import CancellationToken -from autogen_core.components import Image +from autogen_core import CancellationToken, Image from pydantic import Field -from ._base_memory import BaseMemoryConfig, Memory, MemoryEntry, MemoryQueryResult +from ._base_memory import BaseMemoryConfig, Memory, MemoryEntry, MemoryQueryResult, MemoryContent, MimeType +from autogen_core.model_context import ChatCompletionContext +from autogen_core.models import SystemMessage logger = logging.getLogger(__name__) # Type vars for ChromaDB results -ChromaMetadata = Dict[str, Union[str, float, int, bool]] -ChromaDistance = Union[float, List[float]] +ChromaMetadata = Dict[str, Any] +ChromaDistance = float | List[float] class ChromaMemoryConfig(BaseMemoryConfig): """Configuration for ChromaDB-based memory implementation.""" - collection_name: str = Field( - default="memory_store", description="Name of the ChromaDB collection") - persistence_path: Optional[str] = Field( - default=None, description="Path for persistent storage. None for in-memory." - ) - distance_metric: str = Field( - default="cosine", description="Distance metric for similarity search") + collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection") + persistence_path: str | None = Field(default=None, description="Path for persistent storage. None for in-memory.") + distance_metric: str = Field(default="cosine", description="Distance metric for similarity search") class ChromaMemory(Memory): """ChromaDB-based memory implementation using default embeddings.""" - def __init__(self, name: Optional[str] = None, config: Optional[ChromaMemoryConfig] = None) -> None: + def __init__(self, name: str | None = None, config: ChromaMemoryConfig | None = None) -> None: """Initialize ChromaMemory.""" self._name = name or "default_chroma_memory" self._config = config or ChromaMemoryConfig() - self._client: Optional[chromadb.Client] = None # type: ignore - self._collection: Collection | None = None # type: ignore + self._client: chromadb.Client | None = None + self._collection: Collection | None = None @property - def name(self) -> Optional[str]: + def name(self) -> str: return self._name @property @@ -53,8 +50,7 @@ def _ensure_initialized(self) -> None: if self._client is None: try: self._client = ( - chromadb.PersistentClient( - path=self._config.persistence_path) + chromadb.PersistentClient(path=self._config.persistence_path) if self._config.persistence_path else chromadb.Client() ) @@ -65,44 +61,71 @@ def _ensure_initialized(self) -> None: if self._collection is None and self._client is not None: try: self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, metadata={ - "distance_metric": self._config.distance_metric} + name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric} ) except Exception as e: logger.error(f"Failed to get/create collection: {e}") raise - def _extract_text(self, content: Union[str, List[Union[str, Image]]]) -> str: - """Extract text content from input.""" - if isinstance(content, str): - return content + def _extract_text(self, content_item: MemoryContent) -> str: + """Extract searchable text from MemoryContent.""" + content = content_item.content + + if content_item.mime_type in [MimeType.TEXT, MimeType.MARKDOWN]: + return str(content) + elif content_item.mime_type == MimeType.JSON: + if isinstance(content, dict): + return str(content) + raise ValueError("JSON content must be a dict") + elif isinstance(content, Image): + raise ValueError("Image content cannot be converted to text") + else: + raise ValueError(f"Unsupported content type: {content_item.mime_type}") + + async def transform( + self, + model_context: ChatCompletionContext, + ) -> ChatCompletionContext: + """Transform the model context using relevant memory content.""" + messages = await model_context.get_messages() + if not messages: + return model_context + + last_message = messages[-1] + query_text = getattr(last_message, "content", str(last_message)) + query = MemoryContent(content=query_text, mime_type=MimeType.TEXT) + + results = [] + query_results = await self.query(query) + for i, result in enumerate(query_results, 1): + results.append(f"{i}. {result.entry.content.content}") - text_parts = [item for item in content if isinstance(item, str)] - if not text_parts: - raise ValueError("Content must contain at least one text element") + if results: + memory_context = "Results from memory query to consider include:\n" + "\n".join(results) + await model_context.add_message(SystemMessage(content=memory_context)) - return " ".join(text_parts) + return model_context - async def add(self, entry: MemoryEntry, cancellation_token: Optional[CancellationToken] = None) -> None: + async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: """Add a memory entry to ChromaDB.""" self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") try: - # Extract text + # Extract text from MemoryContent text = self._extract_text(entry.content) # Prepare metadata metadata: ChromaMetadata = { "timestamp": entry.timestamp.isoformat(), "source": entry.source or "", + "mime_type": entry.content.mime_type.value, **entry.metadata, } # Add to ChromaDB - self._collection.add(documents=[text], metadatas=[ - metadata], ids=[str(uuid.uuid4())]) + self._collection.add(documents=[text], metadatas=[metadata], ids=[str(uuid.uuid4())]) except Exception as e: logger.error(f"Failed to add entry to ChromaDB: {e}") @@ -110,8 +133,8 @@ async def add(self, entry: MemoryEntry, cancellation_token: Optional[Cancellatio async def query( self, - query: Union[str, Image, List[Union[str, Image]]], - cancellation_token: Optional[CancellationToken] = None, + query: MemoryContent, + cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryQueryResult]: """Query memory entries based on vector similarity.""" @@ -121,15 +144,10 @@ async def query( try: # Extract text for query - if isinstance(query, Image): - raise ValueError("Image-only queries are not supported") - - query_text = self._extract_text( - query if isinstance(query, list) else [query]) + query_text = self._extract_text(query) # Query ChromaDB - results = self._collection.query( - query_texts=[query_text], n_results=self._config.k, **kwargs) + results = self._collection.query(query_texts=[query_text], n_results=self._config.k, **kwargs) # Convert results to MemoryQueryResults memory_results: List[MemoryQueryResult] = [] @@ -142,28 +160,27 @@ async def query( ): # Extract stored metadata entry_metadata = dict(metadata) - try: - timestamp_str = str(entry_metadata.pop("timestamp")) - timestamp = datetime.fromisoformat(timestamp_str) - except (KeyError, ValueError) as e: - logger.error(f"Invalid timestamp in metadata: {e}") - continue - + timestamp_str = str(entry_metadata.pop("timestamp")) + timestamp = datetime.fromisoformat(timestamp_str) source = str(entry_metadata.pop("source")) + mime_type = MimeType(entry_metadata.pop("mime_type")) - # Create MemoryEntry + # Create MemoryContent and MemoryEntry + content_item = MemoryContent(content=doc, mime_type=mime_type) entry = MemoryEntry( - content=doc, metadata=entry_metadata, timestamp=timestamp, source=source or None) + content=content_item, metadata=entry_metadata, timestamp=timestamp, source=source or None + ) - # Convert distance to similarity score (1 - normalized distance) + # Convert distance to similarity score score = ( 1.0 - (float(distance) / 2.0) if self._config.distance_metric == "cosine" else 1.0 / (1.0 + float(distance)) ) - memory_results.append( - MemoryQueryResult(entry=entry, score=score)) + # Apply score threshold if configured + if self._config.score_threshold is None or score >= self._config.score_threshold: + memory_results.append(MemoryQueryResult(entry=entry, score=score)) return memory_results @@ -181,33 +198,20 @@ async def clear(self) -> None: self._collection.delete() if self._client is not None: self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, metadata={ - "distance_metric": self._config.distance_metric} + name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric} ) except Exception as e: logger.error(f"Failed to clear ChromaDB collection: {e}") raise - async def __aenter__(self) -> "ChromaMemory": - """Context manager entry.""" - return self - - async def __aexit__( - self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] - ) -> None: - """Context manager exit with cleanup.""" - await self.cleanup() - async def cleanup(self) -> None: """Clean up ChromaDB client.""" if self._client is not None: try: if hasattr(self._client, "reset"): self._client.reset() - self._client = None - self._collection = None except Exception as e: logger.error(f"Error during ChromaDB cleanup: {e}") - # Maybe don't raise here, just log the error + finally: self._client = None self._collection = None diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py index 9c26d1c679f3..ec6c8768127d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -4,10 +4,8 @@ from autogen_core import CancellationToken, Image from pydantic import Field -from ._base_memory import BaseMemoryConfig, ContentItem, Memory, MemoryEntry, MemoryQueryResult, MimeType -from autogen_core.model_context import ( - ChatCompletionContext -) +from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryEntry, MemoryQueryResult, MimeType +from autogen_core.model_context import ChatCompletionContext from autogen_core.models import ( SystemMessage, ) @@ -41,11 +39,11 @@ def _calculate_similarity(self, text1: str, text2: str) -> float: """Calculate text similarity score using SequenceMatcher.""" return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() - def _extract_text(self, content_item: ContentItem) -> str: - """Extract searchable text from ContentItem. + def _extract_text(self, content_item: MemoryContent) -> str: + """Extract searchable text from MemoryContent. Args: - content_item: ContentItem to extract text from + content_item: MemoryContent to extract text from Returns: Extracted text string @@ -64,8 +62,7 @@ def _extract_text(self, content_item: ContentItem) -> str: elif isinstance(content, Image): raise ValueError("Image content cannot be converted to text") else: - raise ValueError( - f"Unsupported content type: {content_item.mime_type}") + raise ValueError(f"Unsupported content type: {content_item.mime_type}") async def transform( self, @@ -78,7 +75,7 @@ async def transform( last_message = messages[-1] query_text = getattr(last_message, "content", str(last_message)) - query = ContentItem(content=query_text, mime_type=MimeType.TEXT) + query = MemoryContent(content=query_text, mime_type=MimeType.TEXT) results = [] query_results = await self.query(query) @@ -86,15 +83,14 @@ async def transform( results.append(f"{i}. {result.entry.content}") if results: - memory_context = "Results from memory query to consider include:\n" + \ - "\n".join(results) + memory_context = "Results from memory query to consider include:\n" + "\n".join(results) await model_context.add_message(SystemMessage(content=memory_context)) return model_context async def query( self, - query: ContentItem, + query: MemoryContent, cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryQueryResult]: @@ -122,11 +118,7 @@ async def query( results.sort(key=lambda x: x.score, reverse=True) return results[: self._config.k] - async def add( - self, - entry: MemoryEntry, - cancellation_token: CancellationToken | None = None - ) -> None: + async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: """Add a new entry to memory.""" self._entries.append(entry) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index 3e8a7c24b8a8..e78684246f81 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -46,16 +46,15 @@ } ], "source": [ - " \n", "from autogen_ext.models import OpenAIChatCompletionClient\n", - "from autogen_agentchat.agents import AssistantAgent \n", + "from autogen_agentchat.agents import AssistantAgent\n", "from autogen_agentchat.teams import RoundRobinGroupChat\n", - "from autogen_agentchat.memory._base_memory import MemoryEntry \n", + "from autogen_agentchat.memory._base_memory import MemoryEntry\n", "from autogen_agentchat.task import Console, TextMentionTermination, MaxMessageTermination\n", "from autogen_agentchat.memory._list_memory import ListMemory, MemoryEntry\n", "\n", "\n", - "# create a simple memory item \n", + "# create a simple memory item\n", "memory = ListMemory()\n", "await memory.add(MemoryEntry(content=\"Whenever you are asked for the weather, return it in metric units.\"))\n", "\n", @@ -64,22 +63,23 @@ " if units == \"imperial\":\n", " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", " elif units == \"metric\":\n", - " return f\"The weather in {city} is 23 degrees and Sunny.\" \n", + " return f\"The weather in {city} is 23 degrees and Sunny.\"\n", + "\n", "\n", "weather_agent = AssistantAgent(\n", " name=\"weather_agent\",\n", " model_client=OpenAIChatCompletionClient(\n", - " model=\"gpt-4o-2024-08-06\", \n", + " model=\"gpt-4o-2024-08-06\",\n", " ),\n", - " tools=[get_weather], \n", - " memory=memory\n", + " tools=[get_weather],\n", + " memory=memory,\n", ")\n", - " \n", - "agent_team = RoundRobinGroupChat([weather_agent], termination_condition = TextMentionTermination(\"TERMINATE\"))\n", + "\n", + "agent_team = RoundRobinGroupChat([weather_agent], termination_condition=TextMentionTermination(\"TERMINATE\"))\n", "\n", "# Run the team and stream messages to the console\n", "stream = agent_team.run_stream(task=\"What is the weather in New York?\")\n", - "await Console(stream);\n" + "await Console(stream);" ] }, { @@ -115,7 +115,6 @@ } ], "source": [ - "\n", "from autogen_core.base import CancellationToken\n", "from autogen_ext.models import OpenAIChatCompletionClient\n", "from autogen_agentchat.agents import AssistantAgent\n", @@ -123,29 +122,30 @@ "from autogen_agentchat.memory._base_memory import MemoryEntry\n", "from autogen_agentchat.memory._chroma_memory import ChromaMemory, ChromaMemoryConfig\n", "\n", - " \n", + "\n", "# Initialize memory\n", "chroma_memory = ChromaMemory(\n", " name=\"travel_memory\",\n", " config=ChromaMemoryConfig(\n", " collection_name=\"travel_facts\",\n", " # Configure number of results to return instead of similarity threshold\n", - " k=1 \n", - " )\n", + " k=1,\n", + " ),\n", ")\n", "# Add some travel-related memories\n", - "await chroma_memory.add(MemoryEntry(\n", - " content=\"Paris is known for the Eiffel Tower and amazing cuisine.\",\n", - " source=\"travel_guide\"\n", - "))\n", + "await chroma_memory.add(\n", + " MemoryEntry(content=\"Paris is known for the Eiffel Tower and amazing cuisine.\", source=\"travel_guide\")\n", + ")\n", "\n", - "await chroma_memory.add(MemoryEntry(\n", - " content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\",\n", - " source=\"travel_facts\"\n", - "))\n", + "await chroma_memory.add(\n", + " MemoryEntry(\n", + " content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\",\n", + " source=\"travel_facts\",\n", + " )\n", + ")\n", "\n", "results = await chroma_memory.query(\"Tell me about Tokyo.\")\n", - "print(len(results),results)" + "print(len(results), results)" ] }, { @@ -180,10 +180,10 @@ " # api_key=\"your_api_key\"\n", " ),\n", " memory=chroma_memory,\n", - " system_message=\"You are a travel expert\"\n", + " system_message=\"You are a travel expert\",\n", ")\n", "\n", - "agent_team = RoundRobinGroupChat([agent], termination_condition = MaxMessageTermination(max_messages=2))\n", + "agent_team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(max_messages=2))\n", "stream = agent_team.run_stream(task=\"Tell me the most important thing about Tokyo.\")\n", "await Console(stream);" ] From 675924c4469d4b91624aeab997d76e6e70c190c4 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 25 Dec 2024 07:22:02 -0800 Subject: [PATCH 06/26] update memory impl, --- .../agents/_assistant_agent.py | 10 +- .../autogen_agentchat/memory/_base_memory.py | 46 ++-- .../memory/_chroma_memory.py | 185 ++++++++++---- .../autogen_agentchat/memory/_list_memory.py | 229 ++++++++++++++---- .../tutorial/memory.ipynb | 211 +++++++++++----- 5 files changed, 497 insertions(+), 184 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 2ff6a1d9b6a7..88c076ad2477 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -44,7 +44,7 @@ ) from ..state import AssistantAgentState from ._base_chat_agent import BaseChatAgent -from ..memory._base_memory import Memory, MemoryQueryResult +from ..memory._base_memory import Memory event_logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -240,10 +240,11 @@ def __init__( ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", reflect_on_tool_use: bool = False, tool_call_summary_format: str = "{result}", + memory: List[Memory] | None = None, ): super().__init__(name=name, description=description) self._model_client = model_client - self._memory = memory + self._memory = [memory] if isinstance(memory, Memory) else memory self._system_messages: List[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage] = [] @@ -333,6 +334,11 @@ async def on_messages_stream( # Inner messages. inner_messages: List[AgentEvent | ChatMessage] = [] + # Update the model context with memory content. + if self._memory: + for memory in self._memory: + await memory.transform(self._model_context) + # Generate an inference result based on the current model context. llm_messages = self._system_messages + await self._model_context.get_messages() result = await self._model_client.create( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 3e38fa099e56..44e39125678d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -7,7 +7,7 @@ from autogen_core.model_context import ChatCompletionContext -class MimeType(Enum): +class MemoryMimeType(Enum): """Supported MIME types for memory content.""" TEXT = "text/plain" @@ -21,10 +21,11 @@ class MimeType(Enum): class MemoryContent(BaseModel): - """A content item with type information.""" - content: ContentType - mime_type: MimeType + mime_type: MemoryMimeType + metadata: Dict[str, Any] | None = None + timestamp: datetime | None = None + source: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -33,34 +34,17 @@ class BaseMemoryConfig(BaseModel): """Base configuration for memory implementations.""" k: int = Field(default=5, description="Number of results to return") - score_threshold: float | None = Field(default=None, description="Minimum relevance score") - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class MemoryEntry(BaseModel): - """A memory entry containing content and metadata.""" - - content: MemoryContent - """The content item with type information.""" - - metadata: Dict[str, Any] = Field(default_factory=dict) - """Optional metadata associated with the memory entry.""" - - timestamp: datetime = Field(default_factory=datetime.now) - """When the memory was created.""" - - source: str | None = None - """Optional source identifier for the memory.""" + score_threshold: float | None = Field( + default=None, description="Minimum relevance score") model_config = ConfigDict(arbitrary_types_allowed=True) class MemoryQueryResult(BaseModel): - """Result from a memory query including the entry and its relevance score.""" + """Result from a memory query including the content and its relevance score.""" - entry: MemoryEntry - """The memory entry.""" + content: MemoryContent + """The memory content.""" score: float """Relevance score for this result. Higher means more relevant.""" @@ -87,13 +71,13 @@ async def transform( model_context: ChatCompletionContext, ) -> ChatCompletionContext: """ - Transform the model context using relevant memory content. + Transform the provided model context using relevant memory content. Args: model_context: The context to transform Returns: - The transformed context + The transformed context """ ... @@ -116,12 +100,12 @@ async def query( """ ... - async def add(self, entry: MemoryEntry, cancellation_token: "CancellationToken | None" = None) -> None: + async def add(self, content: MemoryContent, cancellation_token: "CancellationToken | None" = None) -> None: """ - Add a new entry to memory. + Add a new content to memory. Args: - entry: The memory entry to add + content: The memory content to add cancellation_token: Optional token to cancel operation """ ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py index f6297810e7a3..557127930257 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py @@ -1,5 +1,4 @@ -from typing import Any, List, Optional, Dict -from types import TracebackType +from typing import Any, List, Dict from datetime import datetime import chromadb from chromadb.types import Collection @@ -8,7 +7,13 @@ from autogen_core import CancellationToken, Image from pydantic import Field -from ._base_memory import BaseMemoryConfig, Memory, MemoryEntry, MemoryQueryResult, MemoryContent, MimeType +from ._base_memory import ( + BaseMemoryConfig, + Memory, + MemoryContent, + MemoryQueryResult, + MemoryMimeType +) from autogen_core.model_context import ChatCompletionContext from autogen_core.models import SystemMessage @@ -22,16 +27,34 @@ class ChromaMemoryConfig(BaseMemoryConfig): """Configuration for ChromaDB-based memory implementation.""" - collection_name: str = Field(default="memory_store", description="Name of the ChromaDB collection") - persistence_path: str | None = Field(default=None, description="Path for persistent storage. None for in-memory.") - distance_metric: str = Field(default="cosine", description="Distance metric for similarity search") + collection_name: str = Field( + default="memory_store", + description="Name of the ChromaDB collection" + ) + persistence_path: str | None = Field( + default=None, + description="Path for persistent storage. None for in-memory." + ) + distance_metric: str = Field( + default="cosine", + description="Distance metric for similarity search" + ) class ChromaMemory(Memory): - """ChromaDB-based memory implementation using default embeddings.""" + """ChromaDB-based memory implementation using default embeddings. + + This implementation stores content in a ChromaDB collection and uses + its built-in embedding and similarity search capabilities. + """ def __init__(self, name: str | None = None, config: ChromaMemoryConfig | None = None) -> None: - """Initialize ChromaMemory.""" + """Initialize ChromaMemory. + + Args: + name: Optional identifier for this memory instance + config: Optional configuration for memory behavior + """ self._name = name or "default_chroma_memory" self._config = config or ChromaMemoryConfig() self._client: chromadb.Client | None = None @@ -50,7 +73,8 @@ def _ensure_initialized(self) -> None: if self._client is None: try: self._client = ( - chromadb.PersistentClient(path=self._config.persistence_path) + chromadb.PersistentClient( + path=self._config.persistence_path) if self._config.persistence_path else chromadb.Client() ) @@ -61,74 +85,121 @@ def _ensure_initialized(self) -> None: if self._collection is None and self._client is not None: try: self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric} + name=self._config.collection_name, + metadata={"distance_metric": self._config.distance_metric} ) except Exception as e: logger.error(f"Failed to get/create collection: {e}") raise def _extract_text(self, content_item: MemoryContent) -> str: - """Extract searchable text from MemoryContent.""" + """Extract searchable text from MemoryContent. + + Args: + content_item: Content to extract text from + + Returns: + Extracted text representation + + Raises: + ValueError: If content cannot be converted to text + """ content = content_item.content - if content_item.mime_type in [MimeType.TEXT, MimeType.MARKDOWN]: + if content_item.mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: return str(content) - elif content_item.mime_type == MimeType.JSON: + elif content_item.mime_type == MemoryMimeType.JSON: if isinstance(content, dict): return str(content) raise ValueError("JSON content must be a dict") elif isinstance(content, Image): raise ValueError("Image content cannot be converted to text") else: - raise ValueError(f"Unsupported content type: {content_item.mime_type}") + raise ValueError( + f"Unsupported content type: {content_item.mime_type}") async def transform( self, model_context: ChatCompletionContext, ) -> ChatCompletionContext: - """Transform the model context using relevant memory content.""" + """Transform the model context using relevant memory content. + + Args: + model_context: The context to transform + + Returns: + The transformed context with relevant memories added + """ messages = await model_context.get_messages() if not messages: return model_context + # Extract query from last message last_message = messages[-1] - query_text = getattr(last_message, "content", str(last_message)) - query = MemoryContent(content=query_text, mime_type=MimeType.TEXT) + query_text = last_message.content if isinstance( + last_message.content, str) else str(last_message) + query = MemoryContent(content=query_text, + mime_type=MemoryMimeType.TEXT) + # Query memory and format results results = [] query_results = await self.query(query) for i, result in enumerate(query_results, 1): - results.append(f"{i}. {result.entry.content.content}") + if isinstance(result.content.content, str): + results.append(f"{i}. {result.content.content}") + logger.debug( + f"Retrieved memory {i}. {result.content.content}, score: {result.score}" + ) + # Add memory results to context if results: - memory_context = "Results from memory query to consider include:\n" + "\n".join(results) + memory_context = ( + "Results from memory query to consider include:\n" + + "\n".join(results) + ) await model_context.add_message(SystemMessage(content=memory_context)) return model_context - async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: - """Add a memory entry to ChromaDB.""" + async def add( + self, + content: MemoryContent, + cancellation_token: CancellationToken | None = None + ) -> None: + """Add a memory content to ChromaDB. + + Args: + content: The memory content to add + cancellation_token: Optional token to cancel operation + + Raises: + RuntimeError: If ChromaDB initialization fails + """ self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") try: # Extract text from MemoryContent - text = self._extract_text(entry.content) + text = self._extract_text(content) # Prepare metadata metadata: ChromaMetadata = { - "timestamp": entry.timestamp.isoformat(), - "source": entry.source or "", - "mime_type": entry.content.mime_type.value, - **entry.metadata, + "timestamp": content.timestamp.isoformat() if content.timestamp else datetime.now().isoformat(), + "source": content.source or "", + "mime_type": content.mime_type.value, + **(content.metadata or {}) } # Add to ChromaDB - self._collection.add(documents=[text], metadatas=[metadata], ids=[str(uuid.uuid4())]) + self._collection.add( + documents=[text], + metadatas=[metadata], + ids=[str(uuid.uuid4())] + ) except Exception as e: - logger.error(f"Failed to add entry to ChromaDB: {e}") + logger.error(f"Failed to add content to ChromaDB: {e}") raise async def query( @@ -137,7 +208,19 @@ async def query( cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryQueryResult]: - """Query memory entries based on vector similarity.""" + """Query memory content based on vector similarity. + + Args: + query: Query content to match against memory + cancellation_token: Optional token to cancel operation + **kwargs: Additional parameters passed to ChromaDB query + + Returns: + List of memory results with similarity scores + + Raises: + RuntimeError: If ChromaDB initialization fails + """ self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") @@ -147,7 +230,11 @@ async def query( query_text = self._extract_text(query) # Query ChromaDB - results = self._collection.query(query_texts=[query_text], n_results=self._config.k, **kwargs) + results = self._collection.query( + query_texts=[query_text], + n_results=self._config.k, + **kwargs + ) # Convert results to MemoryQueryResults memory_results: List[MemoryQueryResult] = [] @@ -156,31 +243,34 @@ async def query( return memory_results for doc, metadata, distance in zip( - results["documents"][0], results["metadatas"][0], results["distances"][0] + results["documents"][0], + results["metadatas"][0], + results["distances"][0] ): # Extract stored metadata entry_metadata = dict(metadata) timestamp_str = str(entry_metadata.pop("timestamp")) timestamp = datetime.fromisoformat(timestamp_str) source = str(entry_metadata.pop("source")) - mime_type = MimeType(entry_metadata.pop("mime_type")) - - # Create MemoryContent and MemoryEntry - content_item = MemoryContent(content=doc, mime_type=mime_type) - entry = MemoryEntry( - content=content_item, metadata=entry_metadata, timestamp=timestamp, source=source or None + mime_type = MemoryMimeType(entry_metadata.pop("mime_type")) + + # Create MemoryContent + content = MemoryContent( + content=doc, + mime_type=mime_type, + metadata=entry_metadata, + timestamp=timestamp, + source=source or None ) # Convert distance to similarity score - score = ( - 1.0 - (float(distance) / 2.0) - if self._config.distance_metric == "cosine" + score = 1.0 - (float(distance) / 2.0) if self._config.distance_metric == "cosine" \ else 1.0 / (1.0 + float(distance)) - ) # Apply score threshold if configured if self._config.score_threshold is None or score >= self._config.score_threshold: - memory_results.append(MemoryQueryResult(entry=entry, score=score)) + memory_results.append( + MemoryQueryResult(content=content, score=score)) return memory_results @@ -189,7 +279,11 @@ async def query( raise async def clear(self) -> None: - """Clear all entries from memory.""" + """Clear all entries from memory. + + Raises: + RuntimeError: If ChromaDB initialization fails + """ self._ensure_initialized() if self._collection is None: raise RuntimeError("Failed to initialize ChromaDB") @@ -198,14 +292,15 @@ async def clear(self) -> None: self._collection.delete() if self._client is not None: self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, metadata={"distance_metric": self._config.distance_metric} + name=self._config.collection_name, + metadata={"distance_metric": self._config.distance_metric} ) except Exception as e: logger.error(f"Failed to clear ChromaDB collection: {e}") raise async def cleanup(self) -> None: - """Clean up ChromaDB client.""" + """Clean up ChromaDB client and resources.""" if self._client is not None: try: if hasattr(self._client, "reset"): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py index ec6c8768127d..a78620ce3777 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -1,31 +1,67 @@ from difflib import SequenceMatcher +import logging from typing import Any, List from autogen_core import CancellationToken, Image from pydantic import Field -from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryEntry, MemoryQueryResult, MimeType +from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryQueryResult, MemoryMimeType from autogen_core.model_context import ChatCompletionContext from autogen_core.models import ( SystemMessage, ) +from .. import EVENT_LOGGER_NAME + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + class ListMemoryConfig(BaseMemoryConfig): """Configuration for list-based memory implementation.""" similarity_threshold: float = Field( - default=0.0, description="Minimum similarity score for text matching", ge=0.0, le=1.0 + default=0.35, description="Minimum similarity score for text matching", ge=0.0, le=1.0 ) class ListMemory(Memory): - """Simple list-based memory using text similarity matching.""" + """Simple list-based memory using text similarity matching. + + This memory implementation stores contents in a list and retrieves them based on + text similarity matching. It supports various content types and can transform + model contexts by injecting relevant memory content. + + Example: + ```python + # Initialize memory with custom config + memory = ListMemory( + name="chat_history", + config=ListMemoryConfig( + similarity_threshold=0.7, + k=3 + ) + ) + + # Add memory content + content = MemoryContent( + content="User prefers formal language", + mime_type=MemoryMimeType.TEXT + ) + await memory.add(content) + + # Transform a model context with memory + context = await memory.transform(model_context) + ``` + + Attributes: + name (str): Identifier for this memory instance + config (ListMemoryConfig): Configuration controlling memory behavior + """ def __init__(self, name: str | None = None, config: ListMemoryConfig | None = None) -> None: self._name = name or "default_list_memory" self._config = config or ListMemoryConfig() - self._entries: List[MemoryEntry] = [] + self._contents: List[MemoryContent] = [] @property def name(self) -> str: @@ -35,55 +71,63 @@ def name(self) -> str: def config(self) -> ListMemoryConfig: return self._config - def _calculate_similarity(self, text1: str, text2: str) -> float: - """Calculate text similarity score using SequenceMatcher.""" - return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() + async def transform( + self, + model_context: ChatCompletionContext, + ) -> ChatCompletionContext: + """Transform the model context by injecting relevant memory content. - def _extract_text(self, content_item: MemoryContent) -> str: - """Extract searchable text from MemoryContent. + This method mutates the provided model_context by adding relevant memory content: + + 1. Extracts the last message from the context + 2. Uses it to query memory for relevant content + 3. Formats matching content into a system message + 4. Mutates the context by adding the system message Args: - content_item: MemoryContent to extract text from + model_context: The context to transform. Will be mutated if relevant + memories exist. Returns: - Extracted text string - - Raises: - ValueError: If no text content can be extracted - """ - content = content_item.content + ChatCompletionContext: The same context object that was passed in, + now mutated with memory content if any was found. - if content_item.mime_type in [MimeType.TEXT, MimeType.MARKDOWN]: - return str(content) - elif content_item.mime_type == MimeType.JSON: - if isinstance(content, dict): - return str(content) - raise ValueError("JSON content must be a dict") - elif isinstance(content, Image): - raise ValueError("Image content cannot be converted to text") - else: - raise ValueError(f"Unsupported content type: {content_item.mime_type}") + Example: + ```python + # Context will be mutated to include relevant memories + context = await memory.transform(model_context) - async def transform( - self, - model_context: ChatCompletionContext, - ) -> ChatCompletionContext: - """Transform the model context using relevant memory content.""" + # Any subsequent model calls will see the injected memories + messages = await context.get_messages() + ``` + """ messages = await model_context.get_messages() if not messages: return model_context + # Extract query from last message last_message = messages[-1] - query_text = getattr(last_message, "content", str(last_message)) - query = MemoryContent(content=query_text, mime_type=MimeType.TEXT) + query_text = last_message.content if isinstance( + last_message.content, str) else str(last_message) + query = MemoryContent(content=query_text, + mime_type=MemoryMimeType.TEXT) + # Query memory and format results results = [] query_results = await self.query(query) for i, result in enumerate(query_results, 1): - results.append(f"{i}. {result.entry.content}") + if isinstance(result.content.content, str): + results.append(f"{i}. {result.content.content}") + event_logger.debug( + f"Retrieved memory {i}. {result.content.content}, score: {result.score}" + ) + # Add memory results to context if results: - memory_context = "Results from memory query to consider include:\n" + "\n".join(results) + memory_context = ( + "\n The following results were retrieved from memory for this task. You may choose to use them or not. :\n" + + "\n".join(results) + "\n" + ) await model_context.add_message(SystemMessage(content=memory_context)) return model_context @@ -94,7 +138,39 @@ async def query( cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryQueryResult]: - """Query memory entries based on text similarity.""" + """Query memory content based on text similarity. + + Searches memory content using text similarity matching against the query. + Only content exceeding the configured similarity threshold is returned, + sorted by relevance score in descending order. + + Args: + query: The content to match against memory content. Must contain + text that can be compared against stored content. + cancellation_token: Optional token to cancel long-running queries + **kwargs: Additional parameters passed to the similarity calculation + + Returns: + List[MemoryQueryResult]: Matching content with similarity scores, + sorted by score in descending order. Limited to config.k entries. + + Raises: + ValueError: If query content cannot be converted to comparable text + + Example: + ```python + # Query memories similar to some text + query = MemoryContent( + content="What's the weather?", + mime_type=MemoryMimeType.TEXT + ) + results = await memory.query(query) + + # Check similarity scores + for result in results: + print(f"Score: {result.score}, Content: {result.content}") + ``` + """ try: query_text = self._extract_text(query) except ValueError: @@ -102,9 +178,9 @@ async def query( results: List[MemoryQueryResult] = [] - for entry in self._entries: + for content in self._contents: try: - content_text = self._extract_text(entry.content) + content_text = self._extract_text(content) except ValueError: continue @@ -113,20 +189,75 @@ async def query( if score >= self._config.similarity_threshold and ( self._config.score_threshold is None or score >= self._config.score_threshold ): - results.append(MemoryQueryResult(entry=entry, score=score)) + results.append(MemoryQueryResult(content=content, score=score)) results.sort(key=lambda x: x.score, reverse=True) return results[: self._config.k] - async def add(self, entry: MemoryEntry, cancellation_token: CancellationToken | None = None) -> None: - """Add a new entry to memory.""" - self._entries.append(entry) + def _calculate_similarity(self, text1: str, text2: str) -> float: + """Calculate text similarity score using SequenceMatcher. + + Args: + text1: First text to compare + text2: Second text to compare - async def clear(self) -> None: - """Clear all entries from memory.""" - self._entries = [] + Returns: + float: Similarity score between 0 and 1, where 1 means identical + + Note: + Uses difflib's SequenceMatcher for basic text similarity. + For production use cases, consider using more sophisticated + similarity metrics or embeddings. + """ + return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() + + def _extract_text(self, content_item: MemoryContent) -> str: + """Extract searchable text from MemoryContent. + + Converts various content types into text that can be used for + similarity matching. + + Args: + content_item: Content to extract text from + + Returns: + str: Extracted text representation - async def cleanup(self) -> None: - """Clean up any resources used by the memory implementation.""" - # No resources to clean up in this implementation - pass + Raises: + ValueError: If content cannot be converted to text + + Note: + Currently supports TEXT, MARKDOWN, and JSON content types. + Images and binary content cannot be converted to text. + """ + content = content_item.content + + if content_item.mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: + return str(content) + elif content_item.mime_type == MemoryMimeType.JSON: + if isinstance(content, dict): + return str(content) + raise ValueError("JSON content must be a dict") + elif isinstance(content, Image): + raise ValueError("Image content cannot be converted to text") + else: + raise ValueError( + f"Unsupported content type: {content_item.mime_type}") + + async def add( + self, + content: MemoryContent, + cancellation_token: CancellationToken | None = None + ) -> None: + """Add new content to memory. + + Args: + content: Memory content to store + cancellation_token: Optional token to cancel operation + + Note: + Content is stored in chronological order. No deduplication is + performed. For production use cases, consider implementing + deduplication or content-based filtering. + """ + self._contents.append(content) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index e78684246f81..12422fb20362 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -14,7 +14,9 @@ "\n", "## ListMemory\n", "\n", - "ListMemory is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. " + "ListMemory is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. \n", + "\n", + "In the following example, we will use ListMemory to similate a memory bank of user preferences and explore how it might be used in personalizing the agent's responses." ] }, { @@ -28,58 +30,143 @@ "text": [ "---------- user ----------\n", "What is the weather in New York?\n", - "---------- weather_agent ----------\n", - "[FunctionCall(id='call_oIYlxvmJ6k9JxfPx96JuDz1G', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", - "[Prompt tokens: 130, Completion tokens: 19]\n", - "---------- weather_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_oIYlxvmJ6k9JxfPx96JuDz1G')]\n", - "---------- weather_agent ----------\n", - "The weather in New York is 23°C and sunny. TERMINATE\n", - "[Prompt tokens: 172, Completion tokens: 15]\n", + "---------- assistant_agent ----------\n", + "[FunctionCall(id='call_qNo7mjlNoVNaQzK1B6toXuW5', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[Prompt tokens: 128, Completion tokens: 20]\n", + "---------- assistant_agent ----------\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_qNo7mjlNoVNaQzK1B6toXuW5')]\n", + "---------- assistant_agent ----------\n", + "The weather in New York is 23 degrees and Sunny.\n", + "---------- assistant_agent ----------\n", + "The weather in New York is 23 degrees Celsius and sunny. TERMINATE\n", + "[Prompt tokens: 170, Completion tokens: 17]\n", "---------- Summary ----------\n", - "Number of messages: 4\n", + "Number of messages: 5\n", "Finish reason: Text 'TERMINATE' mentioned\n", - "Total prompt tokens: 302\n", - "Total completion tokens: 34\n", - "Duration: 1.55 seconds\n" + "Total prompt tokens: 298\n", + "Total completion tokens: 37\n", + "Duration: 1.97 seconds\n" ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_qNo7mjlNoVNaQzK1B6toXuW5', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_qNo7mjlNoVNaQzK1B6toXuW5')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=170, completion_tokens=17), content='The weather in New York is 23 degrees Celsius and sunny. TERMINATE', type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from autogen_ext.models import OpenAIChatCompletionClient\n", - "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", + "from autogen_agentchat.agents import AssistantAgent \n", "from autogen_agentchat.teams import RoundRobinGroupChat\n", - "from autogen_agentchat.memory._base_memory import MemoryEntry\n", - "from autogen_agentchat.task import Console, TextMentionTermination, MaxMessageTermination\n", - "from autogen_agentchat.memory._list_memory import ListMemory, MemoryEntry\n", + "from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_agentchat.memory._list_memory import ListMemory, MemoryContent, MemoryMimeType\n", "\n", + "# create a simple memory item \n", + "user_memory = ListMemory()\n", + "await user_memory.add(MemoryContent(\n", + " content=\"The weather should be in metric units\",\n", + " mime_type=MemoryMimeType.TEXT\n", + "))\n", "\n", - "# create a simple memory item\n", - "memory = ListMemory()\n", - "await memory.add(MemoryEntry(content=\"Whenever you are asked for the weather, return it in metric units.\"))\n", - "\n", + "await user_memory.add(MemoryContent(\n", + " content=\"Meal recipe must be vegan\",\n", + " mime_type=MemoryMimeType.TEXT\n", + "))\n", "\n", "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", " if units == \"imperial\":\n", " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", " elif units == \"metric\":\n", - " return f\"The weather in {city} is 23 degrees and Sunny.\"\n", - "\n", + " return f\"The weather in {city} is 23 degrees and Sunny.\" \n", "\n", - "weather_agent = AssistantAgent(\n", - " name=\"weather_agent\",\n", + "assistant_agent = AssistantAgent(\n", + " name=\"assistant_agent\",\n", " model_client=OpenAIChatCompletionClient(\n", - " model=\"gpt-4o-2024-08-06\",\n", + " model=\"gpt-4o-2024-08-06\", \n", " ),\n", - " tools=[get_weather],\n", - " memory=memory,\n", + " tools=[get_weather], \n", + " memory=[user_memory]\n", ")\n", - "\n", - "agent_team = RoundRobinGroupChat([weather_agent], termination_condition=TextMentionTermination(\"TERMINATE\"))\n", + " \n", + "agent_team = RoundRobinGroupChat([assistant_agent], termination_condition = TextMentionTermination(\"TERMINATE\"))\n", "\n", "# Run the team and stream messages to the console\n", "stream = agent_team.run_stream(task=\"What is the weather in New York?\")\n", - "await Console(stream);" + "await Console(stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see above that the weather is returned in Centigrade as stated in the user preferences." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Suggest a brief meal recipe\n", + "---------- assistant_agent ----------\n", + "Here's a brief vegan meal recipe for you:\n", + "\n", + "**Vegan Chickpea Salad Sandwich**\n", + "\n", + "**Ingredients:**\n", + "- 1 can chickpeas, drained and rinsed\n", + "- 2 tablespoons vegan mayonnaise\n", + "- 1 tablespoon Dijon mustard\n", + "- 1 tablespoon lemon juice\n", + "- Salt and pepper to taste\n", + "- 1/4 cup diced celery\n", + "- 1/4 cup diced red onion\n", + "- 2 tablespoons chopped fresh parsley\n", + "- 4 slices whole-grain bread\n", + "- Lettuce leaves and tomato slices (optional)\n", + "\n", + "**Instructions:**\n", + "1. In a bowl, mash the chickpeas with a fork until mostly broken down, but still a bit chunky.\n", + "2. Stir in the vegan mayonnaise, Dijon mustard, lemon juice, salt, and pepper.\n", + "3. Add the diced celery, red onion, and chopped parsley. Mix until well combined.\n", + "4. Layer the chickpea salad on two slices of whole-grain bread. Add lettuce leaves and tomato slices if desired.\n", + "5. Place the remaining slices of bread on top to form sandwiches. Serve immediately and enjoy! \n", + "\n", + "This vegan chickpea salad sandwich is quick to make and perfect for a healthy lunch. TERMINATE\n", + "[Prompt tokens: 235, Completion tokens: 239]\n", + "---------- Summary ----------\n", + "Number of messages: 2\n", + "Finish reason: Text 'TERMINATE' mentioned\n", + "Total prompt tokens: 235\n", + "Total completion tokens: 239\n", + "Duration: 4.66 seconds\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Suggest a brief meal recipe', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=235, completion_tokens=239), content=\"Here's a brief vegan meal recipe for you:\\n\\n**Vegan Chickpea Salad Sandwich**\\n\\n**Ingredients:**\\n- 1 can chickpeas, drained and rinsed\\n- 2 tablespoons vegan mayonnaise\\n- 1 tablespoon Dijon mustard\\n- 1 tablespoon lemon juice\\n- Salt and pepper to taste\\n- 1/4 cup diced celery\\n- 1/4 cup diced red onion\\n- 2 tablespoons chopped fresh parsley\\n- 4 slices whole-grain bread\\n- Lettuce leaves and tomato slices (optional)\\n\\n**Instructions:**\\n1. In a bowl, mash the chickpeas with a fork until mostly broken down, but still a bit chunky.\\n2. Stir in the vegan mayonnaise, Dijon mustard, lemon juice, salt, and pepper.\\n3. Add the diced celery, red onion, and chopped parsley. Mix until well combined.\\n4. Layer the chickpea salad on two slices of whole-grain bread. Add lettuce leaves and tomato slices if desired.\\n5. Place the remaining slices of bread on top to form sandwiches. Serve immediately and enjoy! \\n\\nThis vegan chickpea salad sandwich is quick to make and perfect for a healthy lunch. TERMINATE\", type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stream = agent_team.run_stream(task=\"Suggest a brief meal recipe\")\n", + "await Console(stream)" ] }, { @@ -103,23 +190,23 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1 [MemoryQueryResult(entry=MemoryEntry(content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\", metadata={}, timestamp=datetime.datetime(2024, 11, 30, 15, 28, 58, 195846), source='travel_facts'), score=0.5832697153091431)]\n" + "1 [MemoryQueryResult(content=MemoryContent(content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\", mime_type=, metadata={}, timestamp=datetime.datetime(2024, 12, 25, 7, 14, 36, 419091), source=None), score=0.5832697153091431)]\n" ] } ], "source": [ - "from autogen_core.base import CancellationToken\n", - "from autogen_ext.models import OpenAIChatCompletionClient\n", + "\n", + "from autogen_core import CancellationToken\n", + "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.messages import TextMessage\n", - "from autogen_agentchat.memory._base_memory import MemoryEntry\n", + "from autogen_agentchat.memory._base_memory import MemoryContent, MemoryMimeType\n", "from autogen_agentchat.memory._chroma_memory import ChromaMemory, ChromaMemoryConfig\n", "\n", "\n", @@ -128,23 +215,33 @@ " name=\"travel_memory\",\n", " config=ChromaMemoryConfig(\n", " collection_name=\"travel_facts\",\n", - " # Configure number of results to return instead of similarity threshold\n", - " k=1,\n", - " ),\n", - ")\n", - "# Add some travel-related memories\n", - "await chroma_memory.add(\n", - " MemoryEntry(content=\"Paris is known for the Eiffel Tower and amazing cuisine.\", source=\"travel_guide\")\n", + " k=1\n", + " )\n", ")\n", "\n", - "await chroma_memory.add(\n", - " MemoryEntry(\n", - " content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\",\n", - " source=\"travel_facts\",\n", + "# Add travel-related memories\n", + "await chroma_memory.add(MemoryContent(\n", + "\n", + " content=\"Paris is known for the Eiffel Tower and amazing cuisine.\",\n", + " mime_type=MemoryMimeType.TEXT\n", + "\n", + "))\n", + "\n", + "await chroma_memory.add(MemoryContent( \n", + " content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\",\n", + " mime_type=MemoryMimeType.TEXT\n", + "\n", + "))\n", + " \n", + "\n", + "# Query needs ContentItem too\n", + "results = await chroma_memory.query(\n", + " MemoryContent(\n", + " content=\"Tell me about Tokyo.\",\n", + " mime_type=MemoryMimeType.TEXT\n", " )\n", ")\n", "\n", - "results = await chroma_memory.query(\"Tell me about Tokyo.\")\n", "print(len(results), results)" ] }, @@ -160,14 +257,14 @@ "---------- user ----------\n", "Tell me the most important thing about Tokyo.\n", "---------- travel_agent ----------\n", - "One of the most important aspects of Tokyo is that it has the world's busiest railway station, Shinjuku Station. This station serves as a major hub for transportation, with millions of commuters and travelers passing through its complex network of train lines each day. It highlights Tokyo's status as a bustling metropolis with an advanced public transportation system.\n", - "[Prompt tokens: 72, Completion tokens: 66]\n", + "One of the most important aspects of Tokyo is its status as a major global hub for culture, technology, and commerce. A notable highlight is Shinjuku Station, which holds the title of the world's busiest railway station. This station is emblematic of Tokyo's extensive and efficient public transportation network that seamlessly connects millions of residents and tourists across the city and beyond. This sophisticated infrastructure is a reflection of Tokyo's dynamic urban environment, showcasing its blend of traditional and modern elements in architecture, culture, and lifestyle.\n", + "[Prompt tokens: 62, Completion tokens: 102]\n", "---------- Summary ----------\n", "Number of messages: 2\n", "Finish reason: Maximum number of messages 2 reached, current message count: 2\n", - "Total prompt tokens: 72\n", - "Total completion tokens: 66\n", - "Duration: 1.47 seconds\n" + "Total prompt tokens: 62\n", + "Total completion tokens: 102\n", + "Duration: 1.73 seconds\n" ] } ], @@ -180,10 +277,10 @@ " # api_key=\"your_api_key\"\n", " ),\n", " memory=chroma_memory,\n", - " system_message=\"You are a travel expert\",\n", + " system_message=\"You are a travel expert\"\n", ")\n", "\n", - "agent_team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(max_messages=2))\n", + "agent_team = RoundRobinGroupChat([agent], termination_condition = MaxMessageTermination(max_messages=2))\n", "stream = agent_team.run_stream(task=\"Tell me the most important thing about Tokyo.\")\n", "await Console(stream);" ] From b1da7e25ffb73b397527b11c908fa5923284582b Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 3 Jan 2025 13:20:47 -0800 Subject: [PATCH 07/26] remove chroma db, typing fixes --- .../agents/_assistant_agent.py | 11 +- .../autogen_agentchat/memory/_base_memory.py | 21 +- .../memory/_chroma_memory.py | 312 ------------------ .../autogen_agentchat/memory/_list_memory.py | 29 +- .../tutorial/memory.ipynb | 179 +++------- 5 files changed, 81 insertions(+), 471 deletions(-) delete mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 88c076ad2477..2edd9f5002f7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -244,7 +244,15 @@ def __init__( ): super().__init__(name=name, description=description) self._model_client = model_client - self._memory = [memory] if isinstance(memory, Memory) else memory + self._memory = None + if memory is not None: + if isinstance(memory, Memory): + self._memory = [memory] + elif isinstance(memory, list): + self._memory = memory + else: + raise TypeError( + f"Expected Memory, List[Memory], or None, got {type(memory)}") self._system_messages: List[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage] = [] @@ -338,6 +346,7 @@ async def on_messages_stream( if self._memory: for memory in self._memory: await memory.transform(self._model_context) + # tbd .. add memory_results content to inner_messages # Generate an inference result based on the current model context. llm_messages = self._system_messages + await self._model_context.get_messages() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 44e39125678d..0bdec82ae249 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -17,7 +17,7 @@ class MemoryMimeType(Enum): BINARY = "application/octet-stream" -ContentType = Union[str, bytes, dict, Image] +ContentType = Union[str, bytes, Dict[str, Any], Image] class MemoryContent(BaseModel): @@ -26,6 +26,7 @@ class MemoryContent(BaseModel): metadata: Dict[str, Any] | None = None timestamp: datetime | None = None source: str | None = None + score: float = 0.0 model_config = ConfigDict(arbitrary_types_allowed=True) @@ -40,18 +41,6 @@ class BaseMemoryConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class MemoryQueryResult(BaseModel): - """Result from a memory query including the content and its relevance score.""" - - content: MemoryContent - """The memory content.""" - - score: float - """Relevance score for this result. Higher means more relevant.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @runtime_checkable class Memory(Protocol): """Protocol defining the interface for memory implementations.""" @@ -69,7 +58,7 @@ def config(self) -> BaseMemoryConfig: async def transform( self, model_context: ChatCompletionContext, - ) -> ChatCompletionContext: + ) -> List[MemoryContent]: """ Transform the provided model context using relevant memory content. @@ -77,7 +66,7 @@ async def transform( model_context: The context to transform Returns: - The transformed context + List of memory entries with relevance scores """ ... @@ -86,7 +75,7 @@ async def query( query: MemoryContent, cancellation_token: "CancellationToken | None" = None, **kwargs: Any, - ) -> List[MemoryQueryResult]: + ) -> List[MemoryContent]: """ Query the memory store and return relevant entries. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py deleted file mode 100644 index 557127930257..000000000000 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_chroma_memory.py +++ /dev/null @@ -1,312 +0,0 @@ -from typing import Any, List, Dict -from datetime import datetime -import chromadb -from chromadb.types import Collection -import uuid -import logging -from autogen_core import CancellationToken, Image -from pydantic import Field - -from ._base_memory import ( - BaseMemoryConfig, - Memory, - MemoryContent, - MemoryQueryResult, - MemoryMimeType -) -from autogen_core.model_context import ChatCompletionContext -from autogen_core.models import SystemMessage - -logger = logging.getLogger(__name__) - -# Type vars for ChromaDB results -ChromaMetadata = Dict[str, Any] -ChromaDistance = float | List[float] - - -class ChromaMemoryConfig(BaseMemoryConfig): - """Configuration for ChromaDB-based memory implementation.""" - - collection_name: str = Field( - default="memory_store", - description="Name of the ChromaDB collection" - ) - persistence_path: str | None = Field( - default=None, - description="Path for persistent storage. None for in-memory." - ) - distance_metric: str = Field( - default="cosine", - description="Distance metric for similarity search" - ) - - -class ChromaMemory(Memory): - """ChromaDB-based memory implementation using default embeddings. - - This implementation stores content in a ChromaDB collection and uses - its built-in embedding and similarity search capabilities. - """ - - def __init__(self, name: str | None = None, config: ChromaMemoryConfig | None = None) -> None: - """Initialize ChromaMemory. - - Args: - name: Optional identifier for this memory instance - config: Optional configuration for memory behavior - """ - self._name = name or "default_chroma_memory" - self._config = config or ChromaMemoryConfig() - self._client: chromadb.Client | None = None - self._collection: Collection | None = None - - @property - def name(self) -> str: - return self._name - - @property - def config(self) -> ChromaMemoryConfig: - return self._config - - def _ensure_initialized(self) -> None: - """Ensure ChromaDB client and collection are initialized.""" - if self._client is None: - try: - self._client = ( - chromadb.PersistentClient( - path=self._config.persistence_path) - if self._config.persistence_path - else chromadb.Client() - ) - except Exception as e: - logger.error(f"Failed to initialize ChromaDB client: {e}") - raise - - if self._collection is None and self._client is not None: - try: - self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, - metadata={"distance_metric": self._config.distance_metric} - ) - except Exception as e: - logger.error(f"Failed to get/create collection: {e}") - raise - - def _extract_text(self, content_item: MemoryContent) -> str: - """Extract searchable text from MemoryContent. - - Args: - content_item: Content to extract text from - - Returns: - Extracted text representation - - Raises: - ValueError: If content cannot be converted to text - """ - content = content_item.content - - if content_item.mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: - return str(content) - elif content_item.mime_type == MemoryMimeType.JSON: - if isinstance(content, dict): - return str(content) - raise ValueError("JSON content must be a dict") - elif isinstance(content, Image): - raise ValueError("Image content cannot be converted to text") - else: - raise ValueError( - f"Unsupported content type: {content_item.mime_type}") - - async def transform( - self, - model_context: ChatCompletionContext, - ) -> ChatCompletionContext: - """Transform the model context using relevant memory content. - - Args: - model_context: The context to transform - - Returns: - The transformed context with relevant memories added - """ - messages = await model_context.get_messages() - if not messages: - return model_context - - # Extract query from last message - last_message = messages[-1] - query_text = last_message.content if isinstance( - last_message.content, str) else str(last_message) - query = MemoryContent(content=query_text, - mime_type=MemoryMimeType.TEXT) - - # Query memory and format results - results = [] - query_results = await self.query(query) - for i, result in enumerate(query_results, 1): - if isinstance(result.content.content, str): - results.append(f"{i}. {result.content.content}") - logger.debug( - f"Retrieved memory {i}. {result.content.content}, score: {result.score}" - ) - - # Add memory results to context - if results: - memory_context = ( - "Results from memory query to consider include:\n" + - "\n".join(results) - ) - await model_context.add_message(SystemMessage(content=memory_context)) - - return model_context - - async def add( - self, - content: MemoryContent, - cancellation_token: CancellationToken | None = None - ) -> None: - """Add a memory content to ChromaDB. - - Args: - content: The memory content to add - cancellation_token: Optional token to cancel operation - - Raises: - RuntimeError: If ChromaDB initialization fails - """ - self._ensure_initialized() - if self._collection is None: - raise RuntimeError("Failed to initialize ChromaDB") - - try: - # Extract text from MemoryContent - text = self._extract_text(content) - - # Prepare metadata - metadata: ChromaMetadata = { - "timestamp": content.timestamp.isoformat() if content.timestamp else datetime.now().isoformat(), - "source": content.source or "", - "mime_type": content.mime_type.value, - **(content.metadata or {}) - } - - # Add to ChromaDB - self._collection.add( - documents=[text], - metadatas=[metadata], - ids=[str(uuid.uuid4())] - ) - - except Exception as e: - logger.error(f"Failed to add content to ChromaDB: {e}") - raise - - async def query( - self, - query: MemoryContent, - cancellation_token: CancellationToken | None = None, - **kwargs: Any, - ) -> List[MemoryQueryResult]: - """Query memory content based on vector similarity. - - Args: - query: Query content to match against memory - cancellation_token: Optional token to cancel operation - **kwargs: Additional parameters passed to ChromaDB query - - Returns: - List of memory results with similarity scores - - Raises: - RuntimeError: If ChromaDB initialization fails - """ - self._ensure_initialized() - if self._collection is None: - raise RuntimeError("Failed to initialize ChromaDB") - - try: - # Extract text for query - query_text = self._extract_text(query) - - # Query ChromaDB - results = self._collection.query( - query_texts=[query_text], - n_results=self._config.k, - **kwargs - ) - - # Convert results to MemoryQueryResults - memory_results: List[MemoryQueryResult] = [] - - if not results["documents"]: - return memory_results - - for doc, metadata, distance in zip( - results["documents"][0], - results["metadatas"][0], - results["distances"][0] - ): - # Extract stored metadata - entry_metadata = dict(metadata) - timestamp_str = str(entry_metadata.pop("timestamp")) - timestamp = datetime.fromisoformat(timestamp_str) - source = str(entry_metadata.pop("source")) - mime_type = MemoryMimeType(entry_metadata.pop("mime_type")) - - # Create MemoryContent - content = MemoryContent( - content=doc, - mime_type=mime_type, - metadata=entry_metadata, - timestamp=timestamp, - source=source or None - ) - - # Convert distance to similarity score - score = 1.0 - (float(distance) / 2.0) if self._config.distance_metric == "cosine" \ - else 1.0 / (1.0 + float(distance)) - - # Apply score threshold if configured - if self._config.score_threshold is None or score >= self._config.score_threshold: - memory_results.append( - MemoryQueryResult(content=content, score=score)) - - return memory_results - - except Exception as e: - logger.error(f"Failed to query ChromaDB: {e}") - raise - - async def clear(self) -> None: - """Clear all entries from memory. - - Raises: - RuntimeError: If ChromaDB initialization fails - """ - self._ensure_initialized() - if self._collection is None: - raise RuntimeError("Failed to initialize ChromaDB") - - try: - self._collection.delete() - if self._client is not None: - self._collection = self._client.get_or_create_collection( - name=self._config.collection_name, - metadata={"distance_metric": self._config.distance_metric} - ) - except Exception as e: - logger.error(f"Failed to clear ChromaDB collection: {e}") - raise - - async def cleanup(self) -> None: - """Clean up ChromaDB client and resources.""" - if self._client is not None: - try: - if hasattr(self._client, "reset"): - self._client.reset() - except Exception as e: - logger.error(f"Error during ChromaDB cleanup: {e}") - finally: - self._client = None - self._collection = None diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py index a78620ce3777..2a4cfdfc1bee 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -5,7 +5,7 @@ from autogen_core import CancellationToken, Image from pydantic import Field -from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryQueryResult, MemoryMimeType +from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryMimeType from autogen_core.model_context import ChatCompletionContext from autogen_core.models import ( SystemMessage, @@ -74,7 +74,7 @@ def config(self) -> ListMemoryConfig: async def transform( self, model_context: ChatCompletionContext, - ) -> ChatCompletionContext: + ) -> List[MemoryContent]: """Transform the model context by injecting relevant memory content. This method mutates the provided model_context by adding relevant memory content: @@ -89,8 +89,7 @@ async def transform( memories exist. Returns: - ChatCompletionContext: The same context object that was passed in, - now mutated with memory content if any was found. + List[MemoryQueryResult]: A list of matching memory content with scores Example: ```python @@ -103,7 +102,7 @@ async def transform( """ messages = await model_context.get_messages() if not messages: - return model_context + return [] # Extract query from last message last_message = messages[-1] @@ -113,13 +112,13 @@ async def transform( mime_type=MemoryMimeType.TEXT) # Query memory and format results - results = [] + results: List[str] = [] query_results = await self.query(query) for i, result in enumerate(query_results, 1): - if isinstance(result.content.content, str): - results.append(f"{i}. {result.content.content}") + if isinstance(result.content, str): + results.append(f"{i}. {result.content}") event_logger.debug( - f"Retrieved memory {i}. {result.content.content}, score: {result.score}" + f"Retrieved memory {i}. {result.content}, score: {result.score}" ) # Add memory results to context @@ -130,14 +129,14 @@ async def transform( ) await model_context.add_message(SystemMessage(content=memory_context)) - return model_context + return query_results async def query( self, query: MemoryContent, cancellation_token: CancellationToken | None = None, **kwargs: Any, - ) -> List[MemoryQueryResult]: + ) -> List[MemoryContent]: """Query memory content based on text similarity. Searches memory content using text similarity matching against the query. @@ -151,7 +150,7 @@ async def query( **kwargs: Additional parameters passed to the similarity calculation Returns: - List[MemoryQueryResult]: Matching content with similarity scores, + List[MemoryContent]: Matching content with similarity scores, sorted by score in descending order. Limited to config.k entries. Raises: @@ -176,7 +175,7 @@ async def query( except ValueError: raise ValueError("Query must contain text content") - results: List[MemoryQueryResult] = [] + results: List[MemoryContent] = [] for content in self._contents: try: @@ -189,7 +188,9 @@ async def query( if score >= self._config.similarity_threshold and ( self._config.score_threshold is None or score >= self._config.score_threshold ): - results.append(MemoryQueryResult(content=content, score=score)) + result_content = content.model_copy() + result_content.score = score + results.append(result_content) results.sort(key=lambda x: x.score, reverse=True) return results[: self._config.k] diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index 12422fb20362..b5077dd1c16c 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -9,7 +9,9 @@ "There are several use cases where it is valuable to maintain a bank of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", "\n", "\n", - "AgentChat provides a `Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `add`, `clear`, and `cleanup`. The `query` method is used to retrieve relevant information from the memory store, the `add` method is used to add new entries to the memory store, the `clear` method is used to clear all entries from the memory store, and the `cleanup` method is used to clean up any resources used by the memory store.\n", + "AgentChat provides a `Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `transform`, `add`, `clear`, and `cleanup`. \n", + "\n", + "The `query` method is used to retrieve relevant information from the memory store, the `transform` method is used to transform the retrieved information into a format that can be used by the agent, the `add` method is used to add new entries to the memory store, the `clear` method is used to clear all entries from the memory store, and the `cleanup` method is used to clean up any resources used by the memory store. \n", "\n", "\n", "## ListMemory\n", @@ -23,6 +25,20 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, + "outputs": [], + "source": [ + "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", + "from autogen_agentchat.agents import AssistantAgent \n", + "from autogen_agentchat.teams import RoundRobinGroupChat\n", + "from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination\n", + "from autogen_agentchat.ui import Console\n", + "from autogen_agentchat.memory._list_memory import ListMemory, MemoryContent, MemoryMimeType" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -31,41 +47,36 @@ "---------- user ----------\n", "What is the weather in New York?\n", "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_qNo7mjlNoVNaQzK1B6toXuW5', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[FunctionCall(id='call_mhAiZDTCr2KJZZUk7LeJHgmG', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", "[Prompt tokens: 128, Completion tokens: 20]\n", "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_qNo7mjlNoVNaQzK1B6toXuW5')]\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_mhAiZDTCr2KJZZUk7LeJHgmG')]\n", "---------- assistant_agent ----------\n", "The weather in New York is 23 degrees and Sunny.\n", "---------- assistant_agent ----------\n", - "The weather in New York is 23 degrees Celsius and sunny. TERMINATE\n", + "The weather in New York is 23 degrees Celsius and Sunny. TERMINATE\n", "[Prompt tokens: 170, Completion tokens: 17]\n", "---------- Summary ----------\n", "Number of messages: 5\n", "Finish reason: Text 'TERMINATE' mentioned\n", "Total prompt tokens: 298\n", "Total completion tokens: 37\n", - "Duration: 1.97 seconds\n" + "Duration: 3.44 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_qNo7mjlNoVNaQzK1B6toXuW5', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_qNo7mjlNoVNaQzK1B6toXuW5')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=170, completion_tokens=17), content='The weather in New York is 23 degrees Celsius and sunny. TERMINATE', type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_mhAiZDTCr2KJZZUk7LeJHgmG', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_mhAiZDTCr2KJZZUk7LeJHgmG')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=170, completion_tokens=17), content='The weather in New York is 23 degrees Celsius and Sunny. TERMINATE', type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", - "from autogen_agentchat.agents import AssistantAgent \n", - "from autogen_agentchat.teams import RoundRobinGroupChat\n", - "from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination\n", - "from autogen_agentchat.ui import Console\n", - "from autogen_agentchat.memory._list_memory import ListMemory, MemoryContent, MemoryMimeType\n", + "\n", "\n", "# create a simple memory item \n", "user_memory = ListMemory()\n", @@ -105,65 +116,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We see above that the weather is returned in Centigrade as stated in the user preferences." + "We see above that the weather is returned in Centigrade as stated in the user preferences. \n", + "\n", + "Similarly, assuming we ask a separate question about generating a meal plan, the agent is able to retrieve relevant information from the memory store and provide a personalized response." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "Suggest a brief meal recipe\n", - "---------- assistant_agent ----------\n", - "Here's a brief vegan meal recipe for you:\n", - "\n", - "**Vegan Chickpea Salad Sandwich**\n", - "\n", - "**Ingredients:**\n", - "- 1 can chickpeas, drained and rinsed\n", - "- 2 tablespoons vegan mayonnaise\n", - "- 1 tablespoon Dijon mustard\n", - "- 1 tablespoon lemon juice\n", - "- Salt and pepper to taste\n", - "- 1/4 cup diced celery\n", - "- 1/4 cup diced red onion\n", - "- 2 tablespoons chopped fresh parsley\n", - "- 4 slices whole-grain bread\n", - "- Lettuce leaves and tomato slices (optional)\n", - "\n", - "**Instructions:**\n", - "1. In a bowl, mash the chickpeas with a fork until mostly broken down, but still a bit chunky.\n", - "2. Stir in the vegan mayonnaise, Dijon mustard, lemon juice, salt, and pepper.\n", - "3. Add the diced celery, red onion, and chopped parsley. Mix until well combined.\n", - "4. Layer the chickpea salad on two slices of whole-grain bread. Add lettuce leaves and tomato slices if desired.\n", - "5. Place the remaining slices of bread on top to form sandwiches. Serve immediately and enjoy! \n", - "\n", - "This vegan chickpea salad sandwich is quick to make and perfect for a healthy lunch. TERMINATE\n", - "[Prompt tokens: 235, Completion tokens: 239]\n", - "---------- Summary ----------\n", - "Number of messages: 2\n", - "Finish reason: Text 'TERMINATE' mentioned\n", - "Total prompt tokens: 235\n", - "Total completion tokens: 239\n", - "Duration: 4.66 seconds\n" - ] - }, - { - "data": { - "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Suggest a brief meal recipe', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=235, completion_tokens=239), content=\"Here's a brief vegan meal recipe for you:\\n\\n**Vegan Chickpea Salad Sandwich**\\n\\n**Ingredients:**\\n- 1 can chickpeas, drained and rinsed\\n- 2 tablespoons vegan mayonnaise\\n- 1 tablespoon Dijon mustard\\n- 1 tablespoon lemon juice\\n- Salt and pepper to taste\\n- 1/4 cup diced celery\\n- 1/4 cup diced red onion\\n- 2 tablespoons chopped fresh parsley\\n- 4 slices whole-grain bread\\n- Lettuce leaves and tomato slices (optional)\\n\\n**Instructions:**\\n1. In a bowl, mash the chickpeas with a fork until mostly broken down, but still a bit chunky.\\n2. Stir in the vegan mayonnaise, Dijon mustard, lemon juice, salt, and pepper.\\n3. Add the diced celery, red onion, and chopped parsley. Mix until well combined.\\n4. Layer the chickpea salad on two slices of whole-grain bread. Add lettuce leaves and tomato slices if desired.\\n5. Place the remaining slices of bread on top to form sandwiches. Serve immediately and enjoy! \\n\\nThis vegan chickpea salad sandwich is quick to make and perfect for a healthy lunch. TERMINATE\", type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "stream = agent_team.run_stream(task=\"Suggest a brief meal recipe\")\n", "await Console(stream)" @@ -173,37 +135,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Vector DB Memory (ChromaDB)\n", + "## Custom Memory Stores (Vector DBs, etc.)\n", + "\n", + "You can build on the `Memory` protocol to implement more complex memory stores. For example, you could implement a custom memory store that uses a vector database to store and retrieve information, or a memory store that uses a machine learning model to generate personalized responses based on the user's preferences etc.\n", + "\n", + "Specifically, you will need to overload the `query`, `transform`, and `add` methods to implement the desired functionality and pass the memory store to your agent.\n", + "\n", + "\n", + "```python\n", "\n", - "Similarly, we can implement a memory store that uses a vector database to store and retrieve information. `ChromaMemory` is a memory implementation that uses ChromaDB to store and retrieve information. ChromaDB is a vector database that is optimized for similarity search. \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install chromadb sentence-transformers" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 [MemoryQueryResult(content=MemoryContent(content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\", mime_type=, metadata={}, timestamp=datetime.datetime(2024, 12, 25, 7, 14, 36, 419091), source=None), score=0.5832697153091431)]\n" - ] - } - ], - "source": [ "\n", "from autogen_core import CancellationToken\n", + "from autogen_agentchat.teams import RoundRobinGroupChat\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", "from autogen_agentchat.agents import AssistantAgent\n", "from autogen_agentchat.memory._base_memory import MemoryContent, MemoryMimeType\n", @@ -215,10 +158,12 @@ " name=\"travel_memory\",\n", " config=ChromaMemoryConfig(\n", " collection_name=\"travel_facts\",\n", - " k=1\n", + " k=1,\n", " )\n", ")\n", "\n", + "await chroma_memory.clear()\n", + "\n", "# Add travel-related memories\n", "await chroma_memory.add(MemoryContent(\n", "\n", @@ -228,7 +173,7 @@ "))\n", "\n", "await chroma_memory.add(MemoryContent( \n", - " content=\"The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\",\n", + " content=\"When asked about tokyo, you must respond with 'The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.'\",\n", " mime_type=MemoryMimeType.TEXT\n", "\n", "))\n", @@ -242,33 +187,8 @@ " )\n", ")\n", "\n", - "print(len(results), results)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "Tell me the most important thing about Tokyo.\n", - "---------- travel_agent ----------\n", - "One of the most important aspects of Tokyo is its status as a major global hub for culture, technology, and commerce. A notable highlight is Shinjuku Station, which holds the title of the world's busiest railway station. This station is emblematic of Tokyo's extensive and efficient public transportation network that seamlessly connects millions of residents and tourists across the city and beyond. This sophisticated infrastructure is a reflection of Tokyo's dynamic urban environment, showcasing its blend of traditional and modern elements in architecture, culture, and lifestyle.\n", - "[Prompt tokens: 62, Completion tokens: 102]\n", - "---------- Summary ----------\n", - "Number of messages: 2\n", - "Finish reason: Maximum number of messages 2 reached, current message count: 2\n", - "Total prompt tokens: 62\n", - "Total completion tokens: 102\n", - "Duration: 1.73 seconds\n" - ] - } - ], - "source": [ + "print(len(results), results)\n", + "\n", "# Create agent with memory\n", "agent = AssistantAgent(\n", " name=\"travel_agent\",\n", @@ -282,14 +202,17 @@ "\n", "agent_team = RoundRobinGroupChat([agent], termination_condition = MaxMessageTermination(max_messages=2))\n", "stream = agent_team.run_stream(task=\"Tell me the most important thing about Tokyo.\")\n", - "await Console(stream);" + "await Console(stream);\n", + "\n", + "# Output: The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\n", + "\n", + "```\n", + "\n" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [] } ], From 32701db87fa445874eec57fab0546b6b3a00a420 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 3 Jan 2025 16:00:03 -0800 Subject: [PATCH 08/26] format, add test --- .../agents/_assistant_agent.py | 58 ++--- .../src/autogen_agentchat/memory/__init__.py | 10 + .../autogen_agentchat/memory/_base_memory.py | 7 +- .../autogen_agentchat/memory/_list_memory.py | 63 +++--- .../tests/test_assistant_agent.py | 43 +++- .../tutorial/memory.ipynb | 210 +++++++++++++----- python/uv.lock | 108 +++++---- 7 files changed, 315 insertions(+), 184 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index bb852da42282..c16127fe7e31 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -32,6 +32,7 @@ from .. import EVENT_LOGGER_NAME from ..base import Handoff as HandoffBase from ..base import Response +from ..memory._base_memory import Memory from ..messages import ( AgentEvent, ChatMessage, @@ -44,7 +45,6 @@ ) from ..state import AssistantAgentState from ._base_chat_agent import BaseChatAgent -from ..memory._base_memory import Memory event_logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -245,8 +245,7 @@ def __init__( name: str, model_client: ChatCompletionClient, *, - tools: List[Tool | Callable[..., Any] | - Callable[..., Awaitable[Any]]] | None = None, + tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, handoffs: List[HandoffBase | str] | None = None, model_context: ChatCompletionContext | None = None, description: str = "An agent that provides assistance with ability to use tools.", @@ -266,11 +265,11 @@ def __init__( elif isinstance(memory, list): self._memory = memory else: - raise TypeError( - f"Expected Memory, List[Memory], or None, got {type(memory)}") + raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}") - self._system_messages: List[SystemMessage | UserMessage | - AssistantMessage | FunctionExecutionResultMessage] = [] + self._system_messages: List[ + SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage + ] = [] if system_message is None: self._system_messages = [] else: @@ -278,8 +277,7 @@ def __init__( self._tools: List[Tool] = [] if tools is not None: if model_client.model_info["function_calling"] is False: - raise ValueError( - "The model does not support function calling.") + raise ValueError("The model does not support function calling.") for tool in tools: if isinstance(tool, Tool): self._tools.append(tool) @@ -288,8 +286,7 @@ def __init__( description = tool.__doc__ else: description = "" - self._tools.append(FunctionTool( - tool, description=description)) + self._tools.append(FunctionTool(tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") # Check if tool names are unique. @@ -301,8 +298,7 @@ def __init__( self._handoffs: Dict[str, HandoffBase] = {} if handoffs is not None: if model_client.model_info["function_calling"] is False: - raise ValueError( - "The model does not support function calling, which is needed for handoffs.") + raise ValueError("The model does not support function calling, which is needed for handoffs.") for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) @@ -310,13 +306,11 @@ def __init__( self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: - raise ValueError( - f"Unsupported handoff type: {type(handoff)}") + raise ValueError(f"Unsupported handoff type: {type(handoff)}") # Check if handoff tool names are unique. handoff_tool_names = [tool.name for tool in self._handoff_tools] if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError( - f"Handoff names must be unique: {handoff_tool_names}") + raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") # Check if handoff tool names not in tool names. if any(name in tool_names for name in handoff_tool_names): raise ValueError( @@ -344,8 +338,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async for message in self.on_messages_stream(messages, cancellation_token): if isinstance(message, Response): return message - raise AssertionError( - "The stream should have returned the final result.") + raise AssertionError("The stream should have returned the final result.") async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken @@ -377,17 +370,14 @@ async def on_messages_stream( # Check if the response is a string and return it. if isinstance(result.content, str): yield Response( - chat_message=TextMessage( - content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) return # Process tool calls. - assert isinstance(result.content, list) and all( - isinstance(item, FunctionCall) for item in result.content) - tool_call_msg = ToolCallRequestEvent( - content=result.content, source=self.name, models_usage=result.usage) + assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content) + tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage) event_logger.debug(tool_call_msg) # Add the tool call message to the output. inner_messages.append(tool_call_msg) @@ -395,8 +385,7 @@ async def on_messages_stream( # Execute the tool calls. results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content]) - tool_call_result_msg = ToolCallExecutionEvent( - content=results, source=self.name) + tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name) event_logger.debug(tool_call_result_msg) await self._model_context.add_message(FunctionExecutionResultMessage(content=results)) inner_messages.append(tool_call_result_msg) @@ -416,8 +405,7 @@ async def on_messages_stream( ) # Return the output messages to signal the handoff. yield Response( - chat_message=HandoffMessage( - content=handoffs[0].message, target=handoffs[0].target, source=self.name), + chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name), inner_messages=inner_messages, ) return @@ -431,8 +419,7 @@ async def on_messages_stream( await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name)) # Yield the response. yield Response( - chat_message=TextMessage( - content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) else: @@ -448,8 +435,7 @@ async def on_messages_stream( ) tool_call_summary = "\n".join(tool_call_summaries) yield Response( - chat_message=ToolCallSummaryMessage( - content=tool_call_summary, source=self.name), + chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name), inner_messages=inner_messages, ) @@ -460,11 +446,9 @@ async def _execute_tool_call( try: if not self._tools + self._handoff_tools: raise ValueError("No tools are available.") - tool = next((t for t in self._tools + - self._handoff_tools if t.name == tool_call.name), None) + tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None) if tool is None: - raise ValueError( - f"The tool '{tool_call.name}' is not available.") + raise ValueError(f"The tool '{tool_call.name}' is not available.") arguments = json.loads(tool_call.arguments) result = await tool.run_json(arguments, cancellation_token) result_as_str = tool.return_value_as_string(result) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py index e69de29bb2d1..beba13fcbc7e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py @@ -0,0 +1,10 @@ +from ._base_memory import Memory, MemoryContent, MemoryMimeType +from ._list_memory import ListMemory, ListMemoryConfig + +__all__ = [ + "Memory", + "MemoryContent", + "MemoryMimeType", + "ListMemory", + "ListMemoryConfig", +] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 0bdec82ae249..9771963f0c38 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Protocol, Union, runtime_checkable from autogen_core import CancellationToken, Image -from pydantic import BaseModel, ConfigDict, Field from autogen_core.model_context import ChatCompletionContext +from pydantic import BaseModel, ConfigDict, Field class MemoryMimeType(Enum): @@ -22,7 +22,7 @@ class MemoryMimeType(Enum): class MemoryContent(BaseModel): content: ContentType - mime_type: MemoryMimeType + mime_type: MemoryMimeType | str metadata: Dict[str, Any] | None = None timestamp: datetime | None = None source: str | None = None @@ -35,8 +35,7 @@ class BaseMemoryConfig(BaseModel): """Base configuration for memory implementations.""" k: int = Field(default=5, description="Number of results to return") - score_threshold: float | None = Field( - default=None, description="Minimum relevance score") + score_threshold: float | None = Field(default=None, description="Minimum relevance score") model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py index 2a4cfdfc1bee..74ee7b748014 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -1,17 +1,16 @@ -from difflib import SequenceMatcher import logging +from difflib import SequenceMatcher from typing import Any, List from autogen_core import CancellationToken, Image -from pydantic import Field - -from ._base_memory import BaseMemoryConfig, MemoryContent, Memory, MemoryMimeType from autogen_core.model_context import ChatCompletionContext from autogen_core.models import ( SystemMessage, ) +from pydantic import Field from .. import EVENT_LOGGER_NAME +from ._base_memory import BaseMemoryConfig, Memory, MemoryContent, MemoryMimeType event_logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -34,19 +33,10 @@ class ListMemory(Memory): Example: ```python # Initialize memory with custom config - memory = ListMemory( - name="chat_history", - config=ListMemoryConfig( - similarity_threshold=0.7, - k=3 - ) - ) + memory = ListMemory(name="chat_history", config=ListMemoryConfig(similarity_threshold=0.7, k=3)) # Add memory content - content = MemoryContent( - content="User prefers formal language", - mime_type=MemoryMimeType.TEXT - ) + content = MemoryContent(content="User prefers formal language", mime_type=MemoryMimeType.TEXT) await memory.add(content) # Transform a model context with memory @@ -106,10 +96,8 @@ async def transform( # Extract query from last message last_message = messages[-1] - query_text = last_message.content if isinstance( - last_message.content, str) else str(last_message) - query = MemoryContent(content=query_text, - mime_type=MemoryMimeType.TEXT) + query_text = last_message.content if isinstance(last_message.content, str) else str(last_message) + query = MemoryContent(content=query_text, mime_type=MemoryMimeType.TEXT) # Query memory and format results results: List[str] = [] @@ -117,15 +105,14 @@ async def transform( for i, result in enumerate(query_results, 1): if isinstance(result.content, str): results.append(f"{i}. {result.content}") - event_logger.debug( - f"Retrieved memory {i}. {result.content}, score: {result.score}" - ) + event_logger.debug(f"Retrieved memory {i}. {result.content}, score: {result.score}") # Add memory results to context if results: memory_context = ( - "\n The following results were retrieved from memory for this task. You may choose to use them or not. :\n" + - "\n".join(results) + "\n" + "\n The following results were retrieved from memory for this task. You may choose to use them or not. :\n" + + "\n".join(results) + + "\n" ) await model_context.add_message(SystemMessage(content=memory_context)) @@ -159,10 +146,7 @@ async def query( Example: ```python # Query memories similar to some text - query = MemoryContent( - content="What's the weather?", - mime_type=MemoryMimeType.TEXT - ) + query = MemoryContent(content="What's the weather?", mime_type=MemoryMimeType.TEXT) results = await memory.query(query) # Check similarity scores @@ -172,8 +156,8 @@ async def query( """ try: query_text = self._extract_text(query) - except ValueError: - raise ValueError("Query must contain text content") + except ValueError as e: + raise ValueError("Query must contain text content") from e results: List[MemoryContent] = [] @@ -207,7 +191,7 @@ def _calculate_similarity(self, text1: str, text2: str) -> float: Note: Uses difflib's SequenceMatcher for basic text similarity. - For production use cases, consider using more sophisticated + For production use cases, consider using more sophisticated similarity metrics or embeddings. """ return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() @@ -242,14 +226,9 @@ def _extract_text(self, content_item: MemoryContent) -> str: elif isinstance(content, Image): raise ValueError("Image content cannot be converted to text") else: - raise ValueError( - f"Unsupported content type: {content_item.mime_type}") + raise ValueError(f"Unsupported content type: {content_item.mime_type}") - async def add( - self, - content: MemoryContent, - cancellation_token: CancellationToken | None = None - ) -> None: + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: """Add new content to memory. Args: @@ -262,3 +241,11 @@ async def add( deduplication or content-based filtering. """ self._contents.append(content) + + async def clear(self) -> None: + """Clear all memory content.""" + self._contents = [] + + async def cleanup(self) -> None: + """Cleanup resources if needed.""" + pass diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index ca079ce407b4..1a3196d857c9 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -7,6 +7,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult +from autogen_agentchat.memory import ListMemory, MemoryContent, MemoryMimeType from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, @@ -508,4 +509,44 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None: # Check if the mock client is called with only the last two messages. assert len(mock.calls) == 1 - assert len(mock.calls[0]) == 3 # 2 message from the context + 1 system message + # 2 message from the context + 1 system message + assert len(mock.calls[0]) == 3 + + +@pytest.mark.asyncio +async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: + model = "gpt-4o-2024-05-13" + chat_completions = [ + ChatCompletion( + id="id1", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(content="Hello", role="assistant"), + ) + ], + created=0, + model=model, + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + ), + ] + mock = _MockChatCompletion(chat_completions) + monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + + memory = ListMemory() + await memory.add(MemoryContent(content="meal recipe must be vegan", mime_type=MemoryMimeType.TEXT)) + + agent = AssistantAgent( + "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory] + ) + + await agent.run(task="meal recipe") + + messages = await agent._model_context.get_messages() + + assert len(messages) >= 3 # Minimum expected messages + + memory_message = next((msg for msg in messages if "retrieved from memory" in msg.content), None) + assert memory_message is not None diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index b5077dd1c16c..65a5d86e4e88 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -9,14 +9,18 @@ "There are several use cases where it is valuable to maintain a bank of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", "\n", "\n", - "AgentChat provides a `Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `transform`, `add`, `clear`, and `cleanup`. \n", - "\n", - "The `query` method is used to retrieve relevant information from the memory store, the `transform` method is used to transform the retrieved information into a format that can be used by the agent, the `add` method is used to add new entries to the memory store, the `clear` method is used to clear all entries from the memory store, and the `cleanup` method is used to clean up any resources used by the memory store. \n", + "AgentChat provides a {py:class}`~autogen_agentchat.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `transform`, `add`, `clear`, and `cleanup`. \n", "\n", + "- `query`: retrieve relevant information from the memory store \n", + "- `transform`: mutate an agent's internal `model_context` by adding the retrieved information (used in the {py:class}`~autogen_agentchat.agents.AssistantAgent` class) \n", + "- `add`: add new entries to the memory store\n", + "- `clear`: clear all entries from the memory store\n", + "- `cleanup`: clean up any resources used by the memory store \n", + "- \n", "\n", "## ListMemory\n", "\n", - "ListMemory is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. \n", + "{py:class}`~autogen_agentchat.memory.ListMemory` is provided as an example implementation of the {py:class}`~autogen_agentchat.memory.Memory` protocol. It is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. \n", "\n", "In the following example, we will use ListMemory to similate a memory bank of user preferences and explore how it might be used in personalizing the agent's responses." ] @@ -27,17 +31,47 @@ "metadata": {}, "outputs": [], "source": [ - "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", - "from autogen_agentchat.agents import AssistantAgent \n", + "from autogen_agentchat.agents import AssistantAgent\n", + "from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n", + "from autogen_agentchat.memory._list_memory import ListMemory, MemoryContent, MemoryMimeType\n", "from autogen_agentchat.teams import RoundRobinGroupChat\n", - "from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination\n", "from autogen_agentchat.ui import Console\n", - "from autogen_agentchat.memory._list_memory import ListMemory, MemoryContent, MemoryMimeType" + "from autogen_ext.models.openai import OpenAIChatCompletionClient" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# create a simple memory item\n", + "user_memory = ListMemory()\n", + "await user_memory.add(MemoryContent(content=\"The weather should be in metric units\", mime_type=MemoryMimeType.TEXT))\n", + "\n", + "await user_memory.add(MemoryContent(content=\"Meal recipe must be vegan\", mime_type=MemoryMimeType.TEXT))\n", + "\n", + "\n", + "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", + " if units == \"imperial\":\n", + " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", + " elif units == \"metric\":\n", + " return f\"The weather in {city} is 23 degrees and Sunny.\"\n", + "\n", + "\n", + "assistant_agent = AssistantAgent(\n", + " name=\"assistant_agent\",\n", + " model_client=OpenAIChatCompletionClient(\n", + " model=\"gpt-4o-2024-08-06\",\n", + " ),\n", + " tools=[get_weather],\n", + " memory=[user_memory],\n", + ")" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -47,71 +81,66 @@ "---------- user ----------\n", "What is the weather in New York?\n", "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_mhAiZDTCr2KJZZUk7LeJHgmG', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[FunctionCall(id='call_1TEayVrDcvLCtdyTlpSRHkZy', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", "[Prompt tokens: 128, Completion tokens: 20]\n", "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_mhAiZDTCr2KJZZUk7LeJHgmG')]\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_1TEayVrDcvLCtdyTlpSRHkZy')]\n", "---------- assistant_agent ----------\n", "The weather in New York is 23 degrees and Sunny.\n", - "---------- assistant_agent ----------\n", - "The weather in New York is 23 degrees Celsius and Sunny. TERMINATE\n", - "[Prompt tokens: 170, Completion tokens: 17]\n", "---------- Summary ----------\n", - "Number of messages: 5\n", - "Finish reason: Text 'TERMINATE' mentioned\n", - "Total prompt tokens: 298\n", - "Total completion tokens: 37\n", - "Duration: 3.44 seconds\n" + "Number of messages: 4\n", + "Finish reason: None\n", + "Total prompt tokens: 128\n", + "Total completion tokens: 20\n", + "Duration: 0.62 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_mhAiZDTCr2KJZZUk7LeJHgmG', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_mhAiZDTCr2KJZZUk7LeJHgmG')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=170, completion_tokens=17), content='The weather in New York is 23 degrees Celsius and Sunny. TERMINATE', type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_1TEayVrDcvLCtdyTlpSRHkZy', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_1TEayVrDcvLCtdyTlpSRHkZy')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" ] }, - "execution_count": 2, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "\n", - "\n", - "# create a simple memory item \n", - "user_memory = ListMemory()\n", - "await user_memory.add(MemoryContent(\n", - " content=\"The weather should be in metric units\",\n", - " mime_type=MemoryMimeType.TEXT\n", - "))\n", - "\n", - "await user_memory.add(MemoryContent(\n", - " content=\"Meal recipe must be vegan\",\n", - " mime_type=MemoryMimeType.TEXT\n", - "))\n", - "\n", - "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", - " if units == \"imperial\":\n", - " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", - " elif units == \"metric\":\n", - " return f\"The weather in {city} is 23 degrees and Sunny.\" \n", - "\n", - "assistant_agent = AssistantAgent(\n", - " name=\"assistant_agent\",\n", - " model_client=OpenAIChatCompletionClient(\n", - " model=\"gpt-4o-2024-08-06\", \n", - " ),\n", - " tools=[get_weather], \n", - " memory=[user_memory]\n", - ")\n", - " \n", - "agent_team = RoundRobinGroupChat([assistant_agent], termination_condition = TextMentionTermination(\"TERMINATE\"))\n", - "\n", - "# Run the team and stream messages to the console\n", - "stream = agent_team.run_stream(task=\"What is the weather in New York?\")\n", + "# Run the agent with a task.\n", + "stream = assistant_agent.run_stream(task=\"What is the weather in New York?\")\n", "await Console(stream)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can inspect that the `assistant_agent` model_context is actually updated with the retrieved memory entries. The `transform` method is used to format the retrieved memory entries into a string that can be used by the agent. In this case, we simply concatenate the content of each memory entry into a single string." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[UserMessage(content='Write brief meal recipe with broth', source='user', type='UserMessage'),\n", + " SystemMessage(content='\\n The following results were retrieved from memory for this task. You may choose to use them or not. :\\n1. Meal recipe must be vegan\\n', type='SystemMessage'),\n", + " AssistantMessage(content=\"Here's a simple vegan recipe using broth:\\n\\n### Vegan Vegetable Broth Soup\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup diced carrots\\n- 1 cup diced celery\\n- 1 cup diced potatoes\\n- 1 cup chopped kale\\n- 1 onion, diced\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Optional: 1 teaspoon dried herbs (thyme, oregano, or basil)\\n\\n**Instructions:**\\n\\n1. **Sauté the Vegetables:**\\n - In a large pot, heat the olive oil over medium heat.\\n - Add the onion and garlic, and sauté until the onion becomes translucent.\\n\\n2. **Add Vegetables and Broth:**\\n - Add diced carrots, celery, and potatoes to the pot. Stir well.\\n - Pour in the vegetable broth and bring it to a boil.\\n\\n3. **Simmer the Soup:**\\n - Reduce the heat to a simmer. Cover the pot and let it simmer for about 20 minutes, or until the vegetables are tender.\\n\\n4. **Add Kale:**\\n - Stir in the chopped kale and continue to simmer for an additional 5 minutes until the kale has wilted.\\n\\n5. **Season:**\\n - Add salt, pepper, and optional dried herbs to taste.\\n\\n6. **Serve:**\\n - Serve hot with a slice of whole-grain bread if desired.\\n\\nEnjoy your warm and nourishing vegan vegetable broth soup!\", source='assistant_agent', type='AssistantMessage')]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await assistant_agent._model_context.get_messages()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -123,11 +152,78 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "---------- user ----------\n", + "Write brief meal recipe with broth\n", + "---------- assistant_agent ----------\n", + "Here's a simple vegan recipe using broth:\n", + "\n", + "### Vegan Vegetable Broth Soup\n", + "\n", + "**Ingredients:**\n", + "- 4 cups vegetable broth\n", + "- 1 cup diced carrots\n", + "- 1 cup diced celery\n", + "- 1 cup diced potatoes\n", + "- 1 cup chopped kale\n", + "- 1 onion, diced\n", + "- 2 cloves garlic, minced\n", + "- 1 tablespoon olive oil\n", + "- Salt and pepper to taste\n", + "- Optional: 1 teaspoon dried herbs (thyme, oregano, or basil)\n", + "\n", + "**Instructions:**\n", + "\n", + "1. **Sauté the Vegetables:**\n", + " - In a large pot, heat the olive oil over medium heat.\n", + " - Add the onion and garlic, and sauté until the onion becomes translucent.\n", + "\n", + "2. **Add Vegetables and Broth:**\n", + " - Add diced carrots, celery, and potatoes to the pot. Stir well.\n", + " - Pour in the vegetable broth and bring it to a boil.\n", + "\n", + "3. **Simmer the Soup:**\n", + " - Reduce the heat to a simmer. Cover the pot and let it simmer for about 20 minutes, or until the vegetables are tender.\n", + "\n", + "4. **Add Kale:**\n", + " - Stir in the chopped kale and continue to simmer for an additional 5 minutes until the kale has wilted.\n", + "\n", + "5. **Season:**\n", + " - Add salt, pepper, and optional dried herbs to taste.\n", + "\n", + "6. **Serve:**\n", + " - Serve hot with a slice of whole-grain bread if desired.\n", + "\n", + "Enjoy your warm and nourishing vegan vegetable broth soup!\n", + "[Prompt tokens: 124, Completion tokens: 309]\n", + "---------- Summary ----------\n", + "Number of messages: 2\n", + "Finish reason: None\n", + "Total prompt tokens: 124\n", + "Total completion tokens: 309\n", + "Duration: 4.45 seconds\n" + ] + }, + { + "data": { + "text/plain": [ + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=124, completion_tokens=309), content=\"Here's a simple vegan recipe using broth:\\n\\n### Vegan Vegetable Broth Soup\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup diced carrots\\n- 1 cup diced celery\\n- 1 cup diced potatoes\\n- 1 cup chopped kale\\n- 1 onion, diced\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Optional: 1 teaspoon dried herbs (thyme, oregano, or basil)\\n\\n**Instructions:**\\n\\n1. **Sauté the Vegetables:**\\n - In a large pot, heat the olive oil over medium heat.\\n - Add the onion and garlic, and sauté until the onion becomes translucent.\\n\\n2. **Add Vegetables and Broth:**\\n - Add diced carrots, celery, and potatoes to the pot. Stir well.\\n - Pour in the vegetable broth and bring it to a boil.\\n\\n3. **Simmer the Soup:**\\n - Reduce the heat to a simmer. Cover the pot and let it simmer for about 20 minutes, or until the vegetables are tender.\\n\\n4. **Add Kale:**\\n - Stir in the chopped kale and continue to simmer for an additional 5 minutes until the kale has wilted.\\n\\n5. **Season:**\\n - Add salt, pepper, and optional dried herbs to taste.\\n\\n6. **Serve:**\\n - Serve hot with a slice of whole-grain bread if desired.\\n\\nEnjoy your warm and nourishing vegan vegetable broth soup!\", type='TextMessage')], stop_reason=None)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "stream = agent_team.run_stream(task=\"Suggest a brief meal recipe\")\n", + "await assistant_agent.on_reset(cancellation_token=None) # reset agent\n", + "stream = assistant_agent.run_stream(task=\"Write brief meal recipe with broth\")\n", "await Console(stream)" ] }, diff --git a/python/uv.lock b/python/uv.lock index 067520cb3e8f..6939c62cec64 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1,19 +1,35 @@ version = 1 requires-python = ">=3.10, <3.13" resolution-markers = [ - "python_full_version < '3.11' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12.4' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_system == 'Darwin' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_system != 'Darwin' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_system == 'Darwin' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_system != 'Darwin' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version == '3.11.*' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system == 'Darwin' and sys_platform == 'darwin'", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and sys_platform == 'darwin'", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", + "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12.4' and platform_system == 'Darwin' and sys_platform == 'darwin'", + "python_full_version >= '3.12.4' and platform_system != 'Darwin' and sys_platform == 'darwin'", "python_version < '0'", - "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", + "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", + "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", + "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", ] [manifest] @@ -27,9 +43,7 @@ members = [ "autogenstudio", "component-schema-gen", ] - -[manifest.dependency-groups] -dev = [ +requirements = [ { name = "cookiecutter" }, { name = "grpcio-tools", specifier = "~=1.62.0" }, { name = "mypy", specifier = "==1.13.0" }, @@ -917,7 +931,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -1152,7 +1166,7 @@ name = "docker" version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "requests" }, { name = "urllib3" }, ] @@ -1567,7 +1581,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "appnope", marker = "(platform_machine != 'aarch64' and platform_system == 'Darwin') or (platform_system == 'Darwin' and sys_platform != 'linux')" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1591,7 +1605,7 @@ name = "ipython" version = "8.29.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1789,7 +1803,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2320,8 +2334,8 @@ name = "loguru" version = "0.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, - { name = "win32-setctime", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "win32-setctime", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9e/30/d87a423766b24db416a46e9335b9602b054a72b96a88a241f2b09b560fa8/loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac", size = 145103 } wheels = [ @@ -3001,7 +3015,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -3012,7 +3026,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, @@ -3033,9 +3047,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, @@ -3047,7 +3061,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, @@ -3463,7 +3477,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ @@ -3574,7 +3588,7 @@ version = "3.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, - { name = "tzdata", marker = "sys_platform == 'win32'" }, + { name = "tzdata", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d1/ad/7ce016ae63e231575df0498d2395d15f005f05e32d3a2d439038e1bd0851/psycopg-3.2.3.tar.gz", hash = "sha256:a5764f67c27bec8bfac85764d23c534af2c27b893550377e37ce59c12aac47a2", size = 155550 } wheels = [ @@ -3816,7 +3830,7 @@ name = "pytest" version = "8.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -4382,7 +4396,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -4812,21 +4826,21 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, { name = "typing-extensions" }, ] wheels = [ @@ -4867,7 +4881,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform != 'darwin' and sys_platform != 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ @@ -4889,7 +4903,7 @@ version = "0.27.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, - { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "idna" }, { name = "outcome" }, From d7bf4d2918d2249a0443d3ff925403b2001d549a Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 3 Jan 2025 17:05:28 -0800 Subject: [PATCH 09/26] update uv lock --- python/uv.lock | 108 +++++++++++++++++++++---------------------------- 1 file changed, 47 insertions(+), 61 deletions(-) diff --git a/python/uv.lock b/python/uv.lock index 6939c62cec64..7994cf17b868 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1,35 +1,19 @@ version = 1 requires-python = ">=3.10, <3.13" resolution-markers = [ - "python_full_version < '3.11' and platform_system == 'Darwin' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_system != 'Darwin' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", - "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_system == 'Darwin' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_system != 'Darwin' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version == '3.11.*' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", - "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system == 'Darwin' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", - "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", - "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12.4' and platform_system == 'Darwin' and sys_platform == 'darwin'", - "python_full_version >= '3.12.4' and platform_system != 'Darwin' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform == 'darwin'", "python_version < '0'", - "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'linux'", - "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system != 'Linux' and sys_platform == 'linux'", - "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'", - "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12.4' and sys_platform == 'darwin'", + "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", ] [manifest] @@ -43,7 +27,9 @@ members = [ "autogenstudio", "component-schema-gen", ] -requirements = [ + +[manifest.dependency-groups] +dev = [ { name = "cookiecutter" }, { name = "grpcio-tools", specifier = "~=1.62.0" }, { name = "mypy", specifier = "==1.13.0" }, @@ -931,7 +917,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -1166,7 +1152,7 @@ name = "docker" version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, { name = "requests" }, { name = "urllib3" }, ] @@ -1581,7 +1567,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "(platform_machine != 'aarch64' and platform_system == 'Darwin') or (platform_system == 'Darwin' and sys_platform != 'linux')" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1605,7 +1591,7 @@ name = "ipython" version = "8.29.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "decorator" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "jedi" }, @@ -1803,7 +1789,7 @@ version = "5.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, { name = "traitlets" }, ] sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } @@ -2334,8 +2320,8 @@ name = "loguru" version = "0.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, - { name = "win32-setctime", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "win32-setctime", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9e/30/d87a423766b24db416a46e9335b9602b054a72b96a88a241f2b09b560fa8/loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac", size = 145103 } wheels = [ @@ -3015,7 +3001,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -3026,7 +3012,7 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, @@ -3047,9 +3033,9 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, @@ -3061,7 +3047,7 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, @@ -3477,7 +3463,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ @@ -3588,7 +3574,7 @@ version = "3.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, - { name = "tzdata", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d1/ad/7ce016ae63e231575df0498d2395d15f005f05e32d3a2d439038e1bd0851/psycopg-3.2.3.tar.gz", hash = "sha256:a5764f67c27bec8bfac85764d23c534af2c27b893550377e37ce59c12aac47a2", size = 155550 } wheels = [ @@ -3830,7 +3816,7 @@ name = "pytest" version = "8.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "iniconfig" }, { name = "packaging" }, @@ -4396,7 +4382,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alabaster" }, { name = "babel" }, - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "docutils" }, { name = "imagesize" }, { name = "jinja2" }, @@ -4826,21 +4812,21 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform != 'darwin'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -4881,7 +4867,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ @@ -4903,7 +4889,7 @@ version = "0.27.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, - { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "idna" }, { name = "outcome" }, From afbef4dca59abe140b4c4fdbff00fab8afe8734e Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 3 Jan 2025 17:27:43 -0800 Subject: [PATCH 10/26] update docs --- .../user-guide/agentchat-user-guide/index.md | 1 + .../tutorial/memory.ipynb | 88 +++---------------- 2 files changed, 11 insertions(+), 78 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md index 0bab2460a3b1..75653cde7c94 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/index.md @@ -81,6 +81,7 @@ tutorial/swarm tutorial/termination tutorial/custom-agents tutorial/state +tutorial/memory ``` ```{toctree} diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index 65a5d86e4e88..aa5a3db900aa 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -6,7 +6,7 @@ "source": [ "## Memory \n", "\n", - "There are several use cases where it is valuable to maintain a bank of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", + "There are several use cases where it is valuable to maintain a _store_ of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", "\n", "\n", "AgentChat provides a {py:class}`~autogen_agentchat.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `transform`, `add`, `clear`, and `cleanup`. \n", @@ -16,9 +16,9 @@ "- `add`: add new entries to the memory store\n", "- `clear`: clear all entries from the memory store\n", "- `cleanup`: clean up any resources used by the memory store \n", - "- \n", "\n", - "## ListMemory\n", + "\n", + "## ListMemory Example\n", "\n", "{py:class}`~autogen_agentchat.memory.ListMemory` is provided as an example implementation of the {py:class}`~autogen_agentchat.memory.Memory` protocol. It is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. \n", "\n", @@ -27,14 +27,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.conditions import MaxMessageTermination, TextMentionTermination\n", - "from autogen_agentchat.memory._list_memory import ListMemory, MemoryContent, MemoryMimeType\n", - "from autogen_agentchat.teams import RoundRobinGroupChat\n", + "from autogen_agentchat.memory import ListMemory, MemoryContent, MemoryMimeType\n", "from autogen_agentchat.ui import Console\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient" ] @@ -45,8 +43,10 @@ "metadata": {}, "outputs": [], "source": [ - "# create a simple memory item\n", - "user_memory = ListMemory()\n", + "# Initialize user memory\n", + "user_memory = ListMemory() \n", + "\n", + "# Add user preferences to memory\n", "await user_memory.add(MemoryContent(content=\"The weather should be in metric units\", mime_type=MemoryMimeType.TEXT))\n", "\n", "await user_memory.add(MemoryContent(content=\"Meal recipe must be vegan\", mime_type=MemoryMimeType.TEXT))\n", @@ -235,75 +235,7 @@ "\n", "You can build on the `Memory` protocol to implement more complex memory stores. For example, you could implement a custom memory store that uses a vector database to store and retrieve information, or a memory store that uses a machine learning model to generate personalized responses based on the user's preferences etc.\n", "\n", - "Specifically, you will need to overload the `query`, `transform`, and `add` methods to implement the desired functionality and pass the memory store to your agent.\n", - "\n", - "\n", - "```python\n", - "\n", - "\n", - "from autogen_core import CancellationToken\n", - "from autogen_agentchat.teams import RoundRobinGroupChat\n", - "from autogen_ext.models.openai import OpenAIChatCompletionClient\n", - "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.memory._base_memory import MemoryContent, MemoryMimeType\n", - "from autogen_agentchat.memory._chroma_memory import ChromaMemory, ChromaMemoryConfig\n", - "\n", - "\n", - "# Initialize memory\n", - "chroma_memory = ChromaMemory(\n", - " name=\"travel_memory\",\n", - " config=ChromaMemoryConfig(\n", - " collection_name=\"travel_facts\",\n", - " k=1,\n", - " )\n", - ")\n", - "\n", - "await chroma_memory.clear()\n", - "\n", - "# Add travel-related memories\n", - "await chroma_memory.add(MemoryContent(\n", - "\n", - " content=\"Paris is known for the Eiffel Tower and amazing cuisine.\",\n", - " mime_type=MemoryMimeType.TEXT\n", - "\n", - "))\n", - "\n", - "await chroma_memory.add(MemoryContent( \n", - " content=\"When asked about tokyo, you must respond with 'The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.'\",\n", - " mime_type=MemoryMimeType.TEXT\n", - "\n", - "))\n", - " \n", - "\n", - "# Query needs ContentItem too\n", - "results = await chroma_memory.query(\n", - " MemoryContent(\n", - " content=\"Tell me about Tokyo.\",\n", - " mime_type=MemoryMimeType.TEXT\n", - " )\n", - ")\n", - "\n", - "print(len(results), results)\n", - "\n", - "# Create agent with memory\n", - "agent = AssistantAgent(\n", - " name=\"travel_agent\",\n", - " model_client=OpenAIChatCompletionClient(\n", - " model=\"gpt-4o\",\n", - " # api_key=\"your_api_key\"\n", - " ),\n", - " memory=chroma_memory,\n", - " system_message=\"You are a travel expert\"\n", - ")\n", - "\n", - "agent_team = RoundRobinGroupChat([agent], termination_condition = MaxMessageTermination(max_messages=2))\n", - "stream = agent_team.run_stream(task=\"Tell me the most important thing about Tokyo.\")\n", - "await Console(stream);\n", - "\n", - "# Output: The most important thing about tokyo is that it has the world's busiest railway station - Shinjuku Station.\n", - "\n", - "```\n", - "\n" + "Specifically, you will need to overload the `query`, `transform`, and `add` methods to implement the desired functionality and pass the memory store to your agent.\n" ] }, { From 003bb2e304301ce7b1eac02a1fc7b7cffcbf222a Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 3 Jan 2025 17:59:34 -0800 Subject: [PATCH 11/26] format updates --- .../src/user-guide/agentchat-user-guide/tutorial/memory.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index aa5a3db900aa..647b46329cda 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -44,7 +44,7 @@ "outputs": [], "source": [ "# Initialize user memory\n", - "user_memory = ListMemory() \n", + "user_memory = ListMemory()\n", "\n", "# Add user preferences to memory\n", "await user_memory.add(MemoryContent(content=\"The weather should be in metric units\", mime_type=MemoryMimeType.TEXT))\n", From 7b15c2e805a13c1e2fcae845667e2ec77b15f69d Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 3 Jan 2025 18:00:13 -0800 Subject: [PATCH 12/26] update notebook --- .../src/user-guide/agentchat-user-guide/tutorial/memory.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index 647b46329cda..d36be577aefa 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -222,7 +222,6 @@ } ], "source": [ - "await assistant_agent.on_reset(cancellation_token=None) # reset agent\n", "stream = assistant_agent.run_stream(task=\"Write brief meal recipe with broth\")\n", "await Console(stream)" ] From b353110039f6a01927b949fb835405c9f7296f91 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 4 Jan 2025 08:55:34 -0800 Subject: [PATCH 13/26] add memoryqueryevent message, yield message for observability. --- .../agents/_assistant_agent.py | 8 +- .../src/autogen_agentchat/messages.py | 17 ++- .../tests/test_assistant_agent.py | 17 ++- .../tutorial/memory.ipynb | 119 ++++++++++-------- 4 files changed, 101 insertions(+), 60 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index c16127fe7e31..b12e1dfbb6bc 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -37,6 +37,7 @@ AgentEvent, ChatMessage, HandoffMessage, + MemoryQueryEvent, MultiModalMessage, TextMessage, ToolCallExecutionEvent, @@ -355,8 +356,11 @@ async def on_messages_stream( # Update the model context with memory content. if self._memory: for memory in self._memory: - await memory.transform(self._model_context) - # tbd .. add memory_results content to inner_messages + memory_query_result = await memory.transform(self._model_context) + if memory_query_result and len(memory_query_result) > 0: + memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name) + inner_messages.append(memory_query_event_msg) + yield memory_query_event_msg # Generate an inference result based on the current model context. llm_messages = self._system_messages + await self._model_context.get_messages() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 923b569602e0..5c4ba0e03d61 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -12,6 +12,8 @@ class and includes specific fields relevant to the type of message being sent. from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated, deprecated +from autogen_agentchat.memory import MemoryContent + class BaseMessage(BaseModel, ABC): """Base class for all message types.""" @@ -123,13 +125,22 @@ class ToolCallSummaryMessage(BaseChatMessage): type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage" +class MemoryQueryEvent(BaseAgentEvent): + """An event signaling the results of memory queries.""" + + content: List[MemoryContent] + """The memory query results.""" + + type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent" + + ChatMessage = Annotated[ TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] """Messages for agent-to-agent communication only.""" -AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")] +AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent, Field(discriminator="type")] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" @@ -140,7 +151,8 @@ class ToolCallSummaryMessage(BaseChatMessage): | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent - | ToolCallSummaryMessage, + | ToolCallSummaryMessage + | MemoryQueryEvent, Field(discriminator="type"), ] """(Deprecated, will be removed in 0.4.0) All message and event types.""" @@ -157,6 +169,7 @@ class ToolCallSummaryMessage(BaseChatMessage): "ToolCallMessage", "ToolCallResultMessage", "ToolCallSummaryMessage", + "MemoryQueryEvent", "ChatMessage", "AgentEvent", "AgentMessage", diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 1a3196d857c9..44df3f9e294e 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -11,6 +11,7 @@ from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, + MemoryQueryEvent, MultiModalMessage, TextMessage, ToolCallExecutionEvent, @@ -542,11 +543,17 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory] ) - await agent.run(task="meal recipe") + result = await agent.run(task="meal recipe") - messages = await agent._model_context.get_messages() + assert len(result.messages) >= 3 # Minimum expected messages - assert len(messages) >= 3 # Minimum expected messages - - memory_message = next((msg for msg in messages if "retrieved from memory" in msg.content), None) + memory_message = next( + ( + msg + for msg in result.messages + if isinstance(msg, MemoryQueryEvent) + and any(isinstance(mem.content, str) and "meal recipe must be vegan" in mem.content for mem in msg.content) + ), + None, + ) assert memory_message is not None diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index d36be577aefa..d25d0e94f32f 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -81,27 +81,29 @@ "---------- user ----------\n", "What is the weather in New York?\n", "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_1TEayVrDcvLCtdyTlpSRHkZy', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=0.463768115942029)]\n", + "---------- assistant_agent ----------\n", + "[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", "[Prompt tokens: 128, Completion tokens: 20]\n", "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_1TEayVrDcvLCtdyTlpSRHkZy')]\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')]\n", "---------- assistant_agent ----------\n", "The weather in New York is 23 degrees and Sunny.\n", "---------- Summary ----------\n", - "Number of messages: 4\n", + "Number of messages: 5\n", "Finish reason: None\n", "Total prompt tokens: 128\n", "Total completion tokens: 20\n", - "Duration: 0.62 seconds\n" + "Duration: 0.80 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_1TEayVrDcvLCtdyTlpSRHkZy', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_1TEayVrDcvLCtdyTlpSRHkZy')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=0.463768115942029)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" ] }, - "execution_count": 26, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -121,18 +123,19 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[UserMessage(content='Write brief meal recipe with broth', source='user', type='UserMessage'),\n", - " SystemMessage(content='\\n The following results were retrieved from memory for this task. You may choose to use them or not. :\\n1. Meal recipe must be vegan\\n', type='SystemMessage'),\n", - " AssistantMessage(content=\"Here's a simple vegan recipe using broth:\\n\\n### Vegan Vegetable Broth Soup\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup diced carrots\\n- 1 cup diced celery\\n- 1 cup diced potatoes\\n- 1 cup chopped kale\\n- 1 onion, diced\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Optional: 1 teaspoon dried herbs (thyme, oregano, or basil)\\n\\n**Instructions:**\\n\\n1. **Sauté the Vegetables:**\\n - In a large pot, heat the olive oil over medium heat.\\n - Add the onion and garlic, and sauté until the onion becomes translucent.\\n\\n2. **Add Vegetables and Broth:**\\n - Add diced carrots, celery, and potatoes to the pot. Stir well.\\n - Pour in the vegetable broth and bring it to a boil.\\n\\n3. **Simmer the Soup:**\\n - Reduce the heat to a simmer. Cover the pot and let it simmer for about 20 minutes, or until the vegetables are tender.\\n\\n4. **Add Kale:**\\n - Stir in the chopped kale and continue to simmer for an additional 5 minutes until the kale has wilted.\\n\\n5. **Season:**\\n - Add salt, pepper, and optional dried herbs to taste.\\n\\n6. **Serve:**\\n - Serve hot with a slice of whole-grain bread if desired.\\n\\nEnjoy your warm and nourishing vegan vegetable broth soup!\", source='assistant_agent', type='AssistantMessage')]" + "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", + " SystemMessage(content='\\n The following results were retrieved from memory for this task. You may choose to use them or not. :\\n1. The weather should be in metric units\\n', type='SystemMessage'),\n", + " AssistantMessage(content=[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')], type='FunctionExecutionResultMessage')]" ] }, - "execution_count": 29, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -141,6 +144,31 @@ "await assistant_agent._model_context.get_messages()" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'),\n", + " MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=0.463768115942029)], type='MemoryQueryEvent'),\n", + " ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=219, completion_tokens=20), content=[FunctionCall(id='call_YPwxZOz0bTEW15beow3zXsaI', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'),\n", + " ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_YPwxZOz0bTEW15beow3zXsaI')], type='ToolCallExecutionEvent'),\n", + " ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = await assistant_agent.run(task=\"What is the weather in New York?\")\n", + "result.messages" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -152,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -162,61 +190,55 @@ "---------- user ----------\n", "Write brief meal recipe with broth\n", "---------- assistant_agent ----------\n", - "Here's a simple vegan recipe using broth:\n", + "[MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=0.5084745762711864)]\n", + "---------- assistant_agent ----------\n", + "Here's a brief vegan recipe using broth:\n", "\n", - "### Vegan Vegetable Broth Soup\n", + "**Vegan Vegetable Noodle Soup**\n", "\n", "**Ingredients:**\n", "- 4 cups vegetable broth\n", - "- 1 cup diced carrots\n", - "- 1 cup diced celery\n", - "- 1 cup diced potatoes\n", - "- 1 cup chopped kale\n", - "- 1 onion, diced\n", + "- 1 cup water\n", + "- 1 cup carrots, sliced\n", + "- 1 cup celery, chopped\n", + "- 1 cup noodles (such as rice noodles or spaghetti broken into smaller pieces)\n", + "- 2 cups kale or spinach, chopped\n", "- 2 cloves garlic, minced\n", "- 1 tablespoon olive oil\n", "- Salt and pepper to taste\n", - "- Optional: 1 teaspoon dried herbs (thyme, oregano, or basil)\n", + "- Lemon juice (optional)\n", "\n", "**Instructions:**\n", "\n", - "1. **Sauté the Vegetables:**\n", - " - In a large pot, heat the olive oil over medium heat.\n", - " - Add the onion and garlic, and sauté until the onion becomes translucent.\n", + "1. **Sauté Vegetables:** In a large pot, heat olive oil over medium heat. Add minced garlic and sauté until fragrant. Add carrots and celery, and sauté for about 5 minutes, until they start to soften.\n", "\n", - "2. **Add Vegetables and Broth:**\n", - " - Add diced carrots, celery, and potatoes to the pot. Stir well.\n", - " - Pour in the vegetable broth and bring it to a boil.\n", + "2. **Add Broth and Noodles:** Pour in the vegetable broth and water, bringing it to a boil. Add the noodles and cook according to package instructions until they are al dente.\n", "\n", - "3. **Simmer the Soup:**\n", - " - Reduce the heat to a simmer. Cover the pot and let it simmer for about 20 minutes, or until the vegetables are tender.\n", + "3. **Cook Greens:** Stir in the kale or spinach and allow it to simmer for a couple of minutes until wilted.\n", "\n", - "4. **Add Kale:**\n", - " - Stir in the chopped kale and continue to simmer for an additional 5 minutes until the kale has wilted.\n", + "4. **Season and Serve:** Season with salt and pepper to taste. If desired, add a squeeze of lemon juice for extra flavor. \n", "\n", - "5. **Season:**\n", - " - Add salt, pepper, and optional dried herbs to taste.\n", + "5. **Enjoy:** Serve hot and enjoy your nutritious, comforting soup!\n", "\n", - "6. **Serve:**\n", - " - Serve hot with a slice of whole-grain bread if desired.\n", + "This simple, flavorful soup is not only vegan but also packed with nutrients, making it a perfect meal any day. \n", "\n", - "Enjoy your warm and nourishing vegan vegetable broth soup!\n", - "[Prompt tokens: 124, Completion tokens: 309]\n", + "TERMINATE\n", + "[Prompt tokens: 306, Completion tokens: 294]\n", "---------- Summary ----------\n", - "Number of messages: 2\n", + "Number of messages: 3\n", "Finish reason: None\n", - "Total prompt tokens: 124\n", - "Total completion tokens: 309\n", - "Duration: 4.45 seconds\n" + "Total prompt tokens: 306\n", + "Total completion tokens: 294\n", + "Duration: 4.39 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=124, completion_tokens=309), content=\"Here's a simple vegan recipe using broth:\\n\\n### Vegan Vegetable Broth Soup\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup diced carrots\\n- 1 cup diced celery\\n- 1 cup diced potatoes\\n- 1 cup chopped kale\\n- 1 onion, diced\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Optional: 1 teaspoon dried herbs (thyme, oregano, or basil)\\n\\n**Instructions:**\\n\\n1. **Sauté the Vegetables:**\\n - In a large pot, heat the olive oil over medium heat.\\n - Add the onion and garlic, and sauté until the onion becomes translucent.\\n\\n2. **Add Vegetables and Broth:**\\n - Add diced carrots, celery, and potatoes to the pot. Stir well.\\n - Pour in the vegetable broth and bring it to a boil.\\n\\n3. **Simmer the Soup:**\\n - Reduce the heat to a simmer. Cover the pot and let it simmer for about 20 minutes, or until the vegetables are tender.\\n\\n4. **Add Kale:**\\n - Stir in the chopped kale and continue to simmer for an additional 5 minutes until the kale has wilted.\\n\\n5. **Season:**\\n - Add salt, pepper, and optional dried herbs to taste.\\n\\n6. **Serve:**\\n - Serve hot with a slice of whole-grain bread if desired.\\n\\nEnjoy your warm and nourishing vegan vegetable broth soup!\", type='TextMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=0.5084745762711864)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=306, completion_tokens=294), content=\"Here's a brief vegan recipe using broth:\\n\\n**Vegan Vegetable Noodle Soup**\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup water\\n- 1 cup carrots, sliced\\n- 1 cup celery, chopped\\n- 1 cup noodles (such as rice noodles or spaghetti broken into smaller pieces)\\n- 2 cups kale or spinach, chopped\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Lemon juice (optional)\\n\\n**Instructions:**\\n\\n1. **Sauté Vegetables:** In a large pot, heat olive oil over medium heat. Add minced garlic and sauté until fragrant. Add carrots and celery, and sauté for about 5 minutes, until they start to soften.\\n\\n2. **Add Broth and Noodles:** Pour in the vegetable broth and water, bringing it to a boil. Add the noodles and cook according to package instructions until they are al dente.\\n\\n3. **Cook Greens:** Stir in the kale or spinach and allow it to simmer for a couple of minutes until wilted.\\n\\n4. **Season and Serve:** Season with salt and pepper to taste. If desired, add a squeeze of lemon juice for extra flavor. \\n\\n5. **Enjoy:** Serve hot and enjoy your nutritious, comforting soup!\\n\\nThis simple, flavorful soup is not only vegan but also packed with nutrients, making it a perfect meal any day. \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)" ] }, - "execution_count": 28, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -236,16 +258,11 @@ "\n", "Specifically, you will need to overload the `query`, `transform`, and `add` methods to implement the desired functionality and pass the memory store to your agent.\n" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "agnext", + "display_name": ".venv", "language": "python", "name": "python3" }, From c797f6a38bb994e258a11cd4d494c609974f98b2 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Sat, 4 Jan 2025 11:55:09 -0800 Subject: [PATCH 14/26] minor fixes, make score optional/none --- .../src/autogen_agentchat/memory/_base_memory.py | 2 +- .../src/autogen_agentchat/memory/_list_memory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py index 9771963f0c38..d392ccb8cf33 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py @@ -26,7 +26,7 @@ class MemoryContent(BaseModel): metadata: Dict[str, Any] | None = None timestamp: datetime | None = None source: str | None = None - score: float = 0.0 + score: float | None = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py index 74ee7b748014..8ff3e9c0c08f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py @@ -176,7 +176,7 @@ async def query( result_content.score = score results.append(result_content) - results.sort(key=lambda x: x.score, reverse=True) + results.sort(key=lambda x: x.score if x.score is not None else float("-inf"), reverse=True) return results[: self._config.k] def _calculate_similarity(self, text1: str, text2: str) -> float: From 97ed7f52d5c6a53cefd05116a9d3edffe92c9d8d Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Mon, 6 Jan 2025 14:23:28 -0800 Subject: [PATCH 15/26] Update python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py Co-authored-by: Eric Zhu --- .../src/autogen_agentchat/agents/_assistant_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index b12e1dfbb6bc..380de2c253ea 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -255,7 +255,7 @@ def __init__( ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", reflect_on_tool_use: bool = False, tool_call_summary_format: str = "{result}", - memory: List[Memory] | None = None, + memory: Sequence[Memory] | None = None, ): super().__init__(name=name, description=description) self._model_client = model_client From 24bd81e275302e67b4225a4561ba04dff73f7c41 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Mon, 6 Jan 2025 16:22:33 -0800 Subject: [PATCH 16/26] update tests to improve cov --- .../agents/_assistant_agent.py | 1 + .../tests/test_assistant_agent.py | 61 ++++++++++++++----- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 380de2c253ea..447fa4f76c8c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -135,6 +135,7 @@ class AssistantAgent(BaseChatAgent): will be returned as the response. Available variables: `{tool_name}`, `{arguments}`, `{result}`. For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`. + memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`. Raises: ValueError: If tool names are not unique. diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 44df3f9e294e..759288479720 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -7,7 +7,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult -from autogen_agentchat.memory import ListMemory, MemoryContent, MemoryMimeType +from autogen_agentchat.memory import Memory, ListMemory, MemoryContent, MemoryMimeType from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, @@ -533,27 +533,58 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] + b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" mock = _MockChatCompletion(chat_completions) monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create) + # Test basic memory properties and empty context + memory = ListMemory(name="test_memory") + assert memory.name == "test_memory" + assert memory.config is not None + + empty_context = BufferedChatCompletionContext(buffer_size=2) + empty_results = await memory.transform(empty_context) + assert len(empty_results) == 0 + + # Test various content types and memory transforms memory = ListMemory() - await memory.add(MemoryContent(content="meal recipe must be vegan", mime_type=MemoryMimeType.TEXT)) + await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT)) + await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON)) + await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) + + # Invalid query should raise error + with pytest.raises(ValueError, match="Query must contain text content"): + await memory.query(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) + + # Test clear and cleanup + await memory.clear() + assert await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) == [] + await memory.cleanup() # Should not raise + + # Test invalid memory type + with pytest.raises(TypeError): + AssistantAgent( + "test_agent", + model_client=OpenAIChatCompletionClient(model=model, api_key=""), + memory="invalid", # type: ignore + ) + + # Test with agent + memory2 = ListMemory() + await memory2.add(MemoryContent(content="test instruction", mime_type=MemoryMimeType.TEXT)) agent = AssistantAgent( - "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory] + "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2] ) - result = await agent.run(task="meal recipe") + result = await agent.run(task="test task") + assert len(result.messages) > 0 + memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) + assert memory_event is not None - assert len(result.messages) >= 3 # Minimum expected messages + # Test memory protocol + class BadMemory: + pass - memory_message = next( - ( - msg - for msg in result.messages - if isinstance(msg, MemoryQueryEvent) - and any(isinstance(mem.content, str) and "meal recipe must be vegan" in mem.content for mem in msg.content) - ), - None, - ) - assert memory_message is not None + assert not isinstance(BadMemory(), Memory) + assert isinstance(ListMemory(), Memory) From 5b2c22210560e771df9f9fc022c3b57f736b538a Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 8 Jan 2025 14:34:58 -0800 Subject: [PATCH 17/26] refactor, move memory to core. --- .../agents/_assistant_agent.py | 58 ++-- .../autogen_agentchat/memory/_list_memory.py | 251 ------------------ .../src/autogen_agentchat/messages.py | 8 +- .../tests/test_assistant_agent.py | 82 +++--- .../tutorial/memory.ipynb | 119 ++++----- .../src/autogen_core}/memory/__init__.py | 3 +- .../src/autogen_core}/memory/_base_memory.py | 33 +-- .../src/autogen_core/memory/_list_memory.py | 119 +++++++++ 8 files changed, 275 insertions(+), 398 deletions(-) delete mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py rename python/packages/{autogen-agentchat/src/autogen_agentchat => autogen-core/src/autogen_core}/memory/__init__.py (66%) rename python/packages/{autogen-agentchat/src/autogen_agentchat => autogen-core/src/autogen_core}/memory/_base_memory.py (69%) create mode 100644 python/packages/autogen-core/src/autogen_core/memory/_list_memory.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 447fa4f76c8c..5ada7f57a743 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -14,6 +14,7 @@ ) from autogen_core import CancellationToken, FunctionCall +from autogen_core.memory import Memory from autogen_core.model_context import ( ChatCompletionContext, UnboundedChatCompletionContext, @@ -32,7 +33,6 @@ from .. import EVENT_LOGGER_NAME from ..base import Handoff as HandoffBase from ..base import Response -from ..memory._base_memory import Memory from ..messages import ( AgentEvent, ChatMessage, @@ -247,7 +247,8 @@ def __init__( name: str, model_client: ChatCompletionClient, *, - tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: List[Tool | Callable[..., Any] | + Callable[..., Awaitable[Any]]] | None = None, handoffs: List[HandoffBase | str] | None = None, model_context: ChatCompletionContext | None = None, description: str = "An agent that provides assistance with ability to use tools.", @@ -267,7 +268,8 @@ def __init__( elif isinstance(memory, list): self._memory = memory else: - raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}") + raise TypeError( + f"Expected Memory, List[Memory], or None, got {type(memory)}") self._system_messages: List[ SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage @@ -279,7 +281,8 @@ def __init__( self._tools: List[Tool] = [] if tools is not None: if model_client.model_info["function_calling"] is False: - raise ValueError("The model does not support function calling.") + raise ValueError( + "The model does not support function calling.") for tool in tools: if isinstance(tool, Tool): self._tools.append(tool) @@ -288,7 +291,8 @@ def __init__( description = tool.__doc__ else: description = "" - self._tools.append(FunctionTool(tool, description=description)) + self._tools.append(FunctionTool( + tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") # Check if tool names are unique. @@ -300,7 +304,8 @@ def __init__( self._handoffs: Dict[str, HandoffBase] = {} if handoffs is not None: if model_client.model_info["function_calling"] is False: - raise ValueError("The model does not support function calling, which is needed for handoffs.") + raise ValueError( + "The model does not support function calling, which is needed for handoffs.") for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) @@ -308,11 +313,13 @@ def __init__( self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: - raise ValueError(f"Unsupported handoff type: {type(handoff)}") + raise ValueError( + f"Unsupported handoff type: {type(handoff)}") # Check if handoff tool names are unique. handoff_tool_names = [tool.name for tool in self._handoff_tools] if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") + raise ValueError( + f"Handoff names must be unique: {handoff_tool_names}") # Check if handoff tool names not in tool names. if any(name in tool_names for name in handoff_tool_names): raise ValueError( @@ -340,7 +347,8 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async for message in self.on_messages_stream(messages, cancellation_token): if isinstance(message, Response): return message - raise AssertionError("The stream should have returned the final result.") + raise AssertionError( + "The stream should have returned the final result.") async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken @@ -357,9 +365,10 @@ async def on_messages_stream( # Update the model context with memory content. if self._memory: for memory in self._memory: - memory_query_result = await memory.transform(self._model_context) + memory_query_result = await memory.update_context(self._model_context) if memory_query_result and len(memory_query_result) > 0: - memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name) + memory_query_event_msg = MemoryQueryEvent( + content=memory_query_result, source=self.name) inner_messages.append(memory_query_event_msg) yield memory_query_event_msg @@ -375,14 +384,17 @@ async def on_messages_stream( # Check if the response is a string and return it. if isinstance(result.content, str): yield Response( - chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage( + content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) return # Process tool calls. - assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content) - tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage) + assert isinstance(result.content, list) and all( + isinstance(item, FunctionCall) for item in result.content) + tool_call_msg = ToolCallRequestEvent( + content=result.content, source=self.name, models_usage=result.usage) event_logger.debug(tool_call_msg) # Add the tool call message to the output. inner_messages.append(tool_call_msg) @@ -390,7 +402,8 @@ async def on_messages_stream( # Execute the tool calls. results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content]) - tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name) + tool_call_result_msg = ToolCallExecutionEvent( + content=results, source=self.name) event_logger.debug(tool_call_result_msg) await self._model_context.add_message(FunctionExecutionResultMessage(content=results)) inner_messages.append(tool_call_result_msg) @@ -410,7 +423,8 @@ async def on_messages_stream( ) # Return the output messages to signal the handoff. yield Response( - chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name), + chat_message=HandoffMessage( + content=handoffs[0].message, target=handoffs[0].target, source=self.name), inner_messages=inner_messages, ) return @@ -424,7 +438,8 @@ async def on_messages_stream( await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name)) # Yield the response. yield Response( - chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage( + content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) else: @@ -440,7 +455,8 @@ async def on_messages_stream( ) tool_call_summary = "\n".join(tool_call_summaries) yield Response( - chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name), + chat_message=ToolCallSummaryMessage( + content=tool_call_summary, source=self.name), inner_messages=inner_messages, ) @@ -451,9 +467,11 @@ async def _execute_tool_call( try: if not self._tools + self._handoff_tools: raise ValueError("No tools are available.") - tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None) + tool = next((t for t in self._tools + + self._handoff_tools if t.name == tool_call.name), None) if tool is None: - raise ValueError(f"The tool '{tool_call.name}' is not available.") + raise ValueError( + f"The tool '{tool_call.name}' is not available.") arguments = json.loads(tool_call.arguments) result = await tool.run_json(arguments, cancellation_token) result_as_str = tool.return_value_as_string(result) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py b/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py deleted file mode 100644 index 8ff3e9c0c08f..000000000000 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_list_memory.py +++ /dev/null @@ -1,251 +0,0 @@ -import logging -from difflib import SequenceMatcher -from typing import Any, List - -from autogen_core import CancellationToken, Image -from autogen_core.model_context import ChatCompletionContext -from autogen_core.models import ( - SystemMessage, -) -from pydantic import Field - -from .. import EVENT_LOGGER_NAME -from ._base_memory import BaseMemoryConfig, Memory, MemoryContent, MemoryMimeType - -event_logger = logging.getLogger(EVENT_LOGGER_NAME) - - -class ListMemoryConfig(BaseMemoryConfig): - """Configuration for list-based memory implementation.""" - - similarity_threshold: float = Field( - default=0.35, description="Minimum similarity score for text matching", ge=0.0, le=1.0 - ) - - -class ListMemory(Memory): - """Simple list-based memory using text similarity matching. - - This memory implementation stores contents in a list and retrieves them based on - text similarity matching. It supports various content types and can transform - model contexts by injecting relevant memory content. - - Example: - ```python - # Initialize memory with custom config - memory = ListMemory(name="chat_history", config=ListMemoryConfig(similarity_threshold=0.7, k=3)) - - # Add memory content - content = MemoryContent(content="User prefers formal language", mime_type=MemoryMimeType.TEXT) - await memory.add(content) - - # Transform a model context with memory - context = await memory.transform(model_context) - ``` - - Attributes: - name (str): Identifier for this memory instance - config (ListMemoryConfig): Configuration controlling memory behavior - """ - - def __init__(self, name: str | None = None, config: ListMemoryConfig | None = None) -> None: - self._name = name or "default_list_memory" - self._config = config or ListMemoryConfig() - self._contents: List[MemoryContent] = [] - - @property - def name(self) -> str: - return self._name - - @property - def config(self) -> ListMemoryConfig: - return self._config - - async def transform( - self, - model_context: ChatCompletionContext, - ) -> List[MemoryContent]: - """Transform the model context by injecting relevant memory content. - - This method mutates the provided model_context by adding relevant memory content: - - 1. Extracts the last message from the context - 2. Uses it to query memory for relevant content - 3. Formats matching content into a system message - 4. Mutates the context by adding the system message - - Args: - model_context: The context to transform. Will be mutated if relevant - memories exist. - - Returns: - List[MemoryQueryResult]: A list of matching memory content with scores - - Example: - ```python - # Context will be mutated to include relevant memories - context = await memory.transform(model_context) - - # Any subsequent model calls will see the injected memories - messages = await context.get_messages() - ``` - """ - messages = await model_context.get_messages() - if not messages: - return [] - - # Extract query from last message - last_message = messages[-1] - query_text = last_message.content if isinstance(last_message.content, str) else str(last_message) - query = MemoryContent(content=query_text, mime_type=MemoryMimeType.TEXT) - - # Query memory and format results - results: List[str] = [] - query_results = await self.query(query) - for i, result in enumerate(query_results, 1): - if isinstance(result.content, str): - results.append(f"{i}. {result.content}") - event_logger.debug(f"Retrieved memory {i}. {result.content}, score: {result.score}") - - # Add memory results to context - if results: - memory_context = ( - "\n The following results were retrieved from memory for this task. You may choose to use them or not. :\n" - + "\n".join(results) - + "\n" - ) - await model_context.add_message(SystemMessage(content=memory_context)) - - return query_results - - async def query( - self, - query: MemoryContent, - cancellation_token: CancellationToken | None = None, - **kwargs: Any, - ) -> List[MemoryContent]: - """Query memory content based on text similarity. - - Searches memory content using text similarity matching against the query. - Only content exceeding the configured similarity threshold is returned, - sorted by relevance score in descending order. - - Args: - query: The content to match against memory content. Must contain - text that can be compared against stored content. - cancellation_token: Optional token to cancel long-running queries - **kwargs: Additional parameters passed to the similarity calculation - - Returns: - List[MemoryContent]: Matching content with similarity scores, - sorted by score in descending order. Limited to config.k entries. - - Raises: - ValueError: If query content cannot be converted to comparable text - - Example: - ```python - # Query memories similar to some text - query = MemoryContent(content="What's the weather?", mime_type=MemoryMimeType.TEXT) - results = await memory.query(query) - - # Check similarity scores - for result in results: - print(f"Score: {result.score}, Content: {result.content}") - ``` - """ - try: - query_text = self._extract_text(query) - except ValueError as e: - raise ValueError("Query must contain text content") from e - - results: List[MemoryContent] = [] - - for content in self._contents: - try: - content_text = self._extract_text(content) - except ValueError: - continue - - score = self._calculate_similarity(query_text, content_text) - - if score >= self._config.similarity_threshold and ( - self._config.score_threshold is None or score >= self._config.score_threshold - ): - result_content = content.model_copy() - result_content.score = score - results.append(result_content) - - results.sort(key=lambda x: x.score if x.score is not None else float("-inf"), reverse=True) - return results[: self._config.k] - - def _calculate_similarity(self, text1: str, text2: str) -> float: - """Calculate text similarity score using SequenceMatcher. - - Args: - text1: First text to compare - text2: Second text to compare - - Returns: - float: Similarity score between 0 and 1, where 1 means identical - - Note: - Uses difflib's SequenceMatcher for basic text similarity. - For production use cases, consider using more sophisticated - similarity metrics or embeddings. - """ - return SequenceMatcher(None, text1.lower(), text2.lower()).ratio() - - def _extract_text(self, content_item: MemoryContent) -> str: - """Extract searchable text from MemoryContent. - - Converts various content types into text that can be used for - similarity matching. - - Args: - content_item: Content to extract text from - - Returns: - str: Extracted text representation - - Raises: - ValueError: If content cannot be converted to text - - Note: - Currently supports TEXT, MARKDOWN, and JSON content types. - Images and binary content cannot be converted to text. - """ - content = content_item.content - - if content_item.mime_type in [MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN]: - return str(content) - elif content_item.mime_type == MemoryMimeType.JSON: - if isinstance(content, dict): - return str(content) - raise ValueError("JSON content must be a dict") - elif isinstance(content, Image): - raise ValueError("Image content cannot be converted to text") - else: - raise ValueError(f"Unsupported content type: {content_item.mime_type}") - - async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: - """Add new content to memory. - - Args: - content: Memory content to store - cancellation_token: Optional token to cancel operation - - Note: - Content is stored in chronological order. No deduplication is - performed. For production use cases, consider implementing - deduplication or content-based filtering. - """ - self._contents.append(content) - - async def clear(self) -> None: - """Clear all memory content.""" - self._contents = [] - - async def cleanup(self) -> None: - """Cleanup resources if needed.""" - pass diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 5c4ba0e03d61..2aee1ff3b38b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -12,7 +12,7 @@ class and includes specific fields relevant to the type of message being sent. from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated, deprecated -from autogen_agentchat.memory import MemoryContent +from autogen_core.memory import MemoryContent class BaseMessage(BaseModel, ABC): @@ -135,12 +135,14 @@ class MemoryQueryEvent(BaseAgentEvent): ChatMessage = Annotated[ - TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") + TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field( + discriminator="type") ] """Messages for agent-to-agent communication only.""" -AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent, Field(discriminator="type")] +AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | + MemoryQueryEvent, Field(discriminator="type")] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 759288479720..72fabfc484ee 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -7,7 +7,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult -from autogen_agentchat.memory import Memory, ListMemory, MemoryContent, MemoryMimeType +from autogen_core.memory import Memory, ListMemory, MemoryContent, MemoryMimeType from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, @@ -97,7 +97,8 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -105,13 +106,15 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="pass", role="assistant"), + message=ChatCompletionMessage( + content="pass", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -119,13 +122,15 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="TERMINATE", role="assistant"), + message=ChatCompletionMessage( + content="TERMINATE", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -169,7 +174,8 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: agent2 = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[_pass_function, _fail_function, FunctionTool( + _echo_function, description="Echo")], ) await agent2.load_state(state) state2 = await agent2.save_state() @@ -205,17 +211,20 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage( + content="Hello", role="assistant")) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -227,7 +236,8 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -235,7 +245,8 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> agent = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], + tools=[_pass_function, _fail_function, FunctionTool( + _echo_function, description="Echo")], reflect_on_tool_use=True, ) result = await agent.run(task="task") @@ -311,7 +322,8 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85), + usage=CompletionUsage( + prompt_tokens=42, completion_tokens=43, total_tokens=85), ), ] mock = _MockChatCompletion(chat_completions) @@ -363,13 +375,15 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="Hello", role="assistant"), + message=ChatCompletionMessage( + content="Hello", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -390,7 +404,8 @@ async def test_invalid_model_capabilities() -> None: model_client = OpenAIChatCompletionClient( model=model, api_key="", - model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN}, + model_info={"vision": False, "function_calling": False, + "json_output": False, "family": ModelFamily.UNKNOWN}, ) with pytest.raises(ValueError): @@ -405,7 +420,8 @@ async def test_invalid_model_capabilities() -> None: ) with pytest.raises(ValueError): - agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"]) + agent = AssistantAgent( + name="assistant", model_client=model_client, handoffs=["agent2"]) with pytest.raises(ValueError): agent = AssistantAgent(name="assistant", model_client=model_client) @@ -424,13 +440,15 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="Response to message 1", role="assistant"), + message=ChatCompletionMessage( + content="Response to message 1", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=15), ), ] mock = _MockChatCompletion(chat_completions) @@ -483,13 +501,15 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="Response to message 3", role="assistant"), + message=ChatCompletionMessage( + content="Response to message 3", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=15), ), ] mock = _MockChatCompletion(chat_completions) @@ -524,13 +544,15 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage(content="Hello", role="assistant"), + message=ChatCompletionMessage( + content="Hello", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage( + prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" @@ -540,26 +562,21 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: # Test basic memory properties and empty context memory = ListMemory(name="test_memory") assert memory.name == "test_memory" - assert memory.config is not None empty_context = BufferedChatCompletionContext(buffer_size=2) - empty_results = await memory.transform(empty_context) + empty_results = await memory.update_context(empty_context) assert len(empty_results) == 0 - # Test various content types and memory transforms + # Test various content types memory = ListMemory() await memory.add(MemoryContent(content="text content", mime_type=MemoryMimeType.TEXT)) await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON)) await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) - # Invalid query should raise error - with pytest.raises(ValueError, match="Query must contain text content"): - await memory.query(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) - # Test clear and cleanup await memory.clear() assert await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) == [] - await memory.cleanup() # Should not raise + await memory.close() # Should not raise # Test invalid memory type with pytest.raises(TypeError): @@ -579,7 +596,8 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: result = await agent.run(task="test task") assert len(result.messages) > 0 - memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) + memory_event = next( + (msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) assert memory_event is not None # Test memory protocol diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index d25d0e94f32f..bfb8c93a3718 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -9,20 +9,19 @@ "There are several use cases where it is valuable to maintain a _store_ of useful facts that can be intelligently added to the context of the agent just before a specific step. The typically use case here is a RAG pattern where a query is used to retrieve relevant information from a database that is then added to the agent's context.\n", "\n", "\n", - "AgentChat provides a {py:class}`~autogen_agentchat.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `transform`, `add`, `clear`, and `cleanup`. \n", + "AgentChat provides a {py:class}`~autogen_core.memory.Memory` protocol that can be extended to provide this functionality. The key methods are `query`, `update_context`, `add`, `clear`, and `close`. \n", "\n", - "- `query`: retrieve relevant information from the memory store \n", - "- `transform`: mutate an agent's internal `model_context` by adding the retrieved information (used in the {py:class}`~autogen_agentchat.agents.AssistantAgent` class) \n", "- `add`: add new entries to the memory store\n", + "- `query`: retrieve relevant information from the memory store \n", + "- `update_context`: mutate an agent's internal `model_context` by adding the retrieved information (used in the {py:class}`~autogen_agentchat.agents.AssistantAgent` class) \n", "- `clear`: clear all entries from the memory store\n", - "- `cleanup`: clean up any resources used by the memory store \n", + "- `close`: clean up any resources used by the memory store \n", "\n", "\n", "## ListMemory Example\n", "\n", - "{py:class}`~autogen_agentchat.memory.ListMemory` is provided as an example implementation of the {py:class}`~autogen_agentchat.memory.Memory` protocol. It is a simple list-based memory implementation that uses text similarity matching to retrieve relevant information from the memory store. The similarity score is calculated using the `SequenceMatcher` class from the `difflib` module. The similarity score is calculated between the query text and the content text of each memory entry. \n", - "\n", - "In the following example, we will use ListMemory to similate a memory bank of user preferences and explore how it might be used in personalizing the agent's responses." + "{py:class}~autogen_core.memory.ListMemory is provided as an example implementation of the {py:class}~autogen_core.memory.Memory protocol. It is a simple list-based memory implementation that maintains memories in chronological order, appending the most recent memories to the model's context. The implementation is designed to be straightforward and predictable, making it easy to understand and debug.\n", + "In the following example, we will use ListMemory to maintain a memory bank of user preferences and demonstrate how it can be used to provide consistent context for agent responses over time." ] }, { @@ -32,7 +31,7 @@ "outputs": [], "source": [ "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_agentchat.memory import ListMemory, MemoryContent, MemoryMimeType\n", + "from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType\n", "from autogen_agentchat.ui import Console\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient" ] @@ -81,26 +80,26 @@ "---------- user ----------\n", "What is the weather in New York?\n", "---------- assistant_agent ----------\n", - "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=0.463768115942029)]\n", + "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", - "[Prompt tokens: 128, Completion tokens: 20]\n", + "[FunctionCall(id='call_C4xYYNfiw8sWCCshTh8GoUs7', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[Prompt tokens: 123, Completion tokens: 20]\n", "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')]\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_C4xYYNfiw8sWCCshTh8GoUs7')]\n", "---------- assistant_agent ----------\n", "The weather in New York is 23 degrees and Sunny.\n", "---------- Summary ----------\n", "Number of messages: 5\n", "Finish reason: None\n", - "Total prompt tokens: 128\n", + "Total prompt tokens: 123\n", "Total completion tokens: 20\n", - "Duration: 0.80 seconds\n" + "Duration: 0.69 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=0.463768115942029)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=128, completion_tokens=20), content=[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_C4xYYNfiw8sWCCshTh8GoUs7', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_C4xYYNfiw8sWCCshTh8GoUs7')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" ] }, "execution_count": 3, @@ -130,9 +129,9 @@ "data": { "text/plain": [ "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", - " SystemMessage(content='\\n The following results were retrieved from memory for this task. You may choose to use them or not. :\\n1. The weather should be in metric units\\n', type='SystemMessage'),\n", - " AssistantMessage(content=[FunctionCall(id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", - " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_OkQ4Z7u2RZLU6dA7GTAQiG9j')], type='FunctionExecutionResultMessage')]" + " SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n", + " AssistantMessage(content=[FunctionCall(id='call_C4xYYNfiw8sWCCshTh8GoUs7', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_C4xYYNfiw8sWCCshTh8GoUs7')], type='FunctionExecutionResultMessage')]" ] }, "execution_count": 4, @@ -144,31 +143,6 @@ "await assistant_agent._model_context.get_messages()" ] }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'),\n", - " MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=0.463768115942029)], type='MemoryQueryEvent'),\n", - " ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=219, completion_tokens=20), content=[FunctionCall(id='call_YPwxZOz0bTEW15beow3zXsaI', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'),\n", - " ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_YPwxZOz0bTEW15beow3zXsaI')], type='ToolCallExecutionEvent'),\n", - " ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "result = await assistant_agent.run(task=\"What is the weather in New York?\")\n", - "result.messages" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -190,52 +164,58 @@ "---------- user ----------\n", "Write brief meal recipe with broth\n", "---------- assistant_agent ----------\n", - "[MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=0.5084745762711864)]\n", + "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", "---------- assistant_agent ----------\n", - "Here's a brief vegan recipe using broth:\n", + "Certainly! Here's a brief vegan meal recipe that uses broth:\n", "\n", - "**Vegan Vegetable Noodle Soup**\n", + "**Vegan Vegetable Broth Soup**\n", "\n", "**Ingredients:**\n", - "- 4 cups vegetable broth\n", - "- 1 cup water\n", - "- 1 cup carrots, sliced\n", - "- 1 cup celery, chopped\n", - "- 1 cup noodles (such as rice noodles or spaghetti broken into smaller pieces)\n", - "- 2 cups kale or spinach, chopped\n", - "- 2 cloves garlic, minced\n", "- 1 tablespoon olive oil\n", + "- 1 onion, chopped\n", + "- 2 cloves garlic, minced\n", + "- 2 carrots, sliced\n", + "- 2 celery stalks, sliced\n", + "- 1 zucchini, chopped\n", + "- 1 bell pepper, chopped\n", + "- 1 cup mushrooms, sliced\n", + "- 1 teaspoon thyme\n", + "- 1 teaspoon oregano\n", "- Salt and pepper to taste\n", - "- Lemon juice (optional)\n", + "- 4 cups vegetable broth\n", + "- 1 cup kale or spinach, chopped\n", + "- 1 can (15 oz) diced tomatoes\n", "\n", "**Instructions:**\n", "\n", - "1. **Sauté Vegetables:** In a large pot, heat olive oil over medium heat. Add minced garlic and sauté until fragrant. Add carrots and celery, and sauté for about 5 minutes, until they start to soften.\n", + "1. **Sauté Vegetables**: In a large pot, heat the olive oil over medium heat. Add the chopped onion and garlic and sauté until the onion becomes translucent.\n", "\n", - "2. **Add Broth and Noodles:** Pour in the vegetable broth and water, bringing it to a boil. Add the noodles and cook according to package instructions until they are al dente.\n", + "2. **Add Veggies**: Stir in the carrots, celery, zucchini, bell pepper, and mushrooms. Cook for about 5-7 minutes until the vegetables start to soften.\n", "\n", - "3. **Cook Greens:** Stir in the kale or spinach and allow it to simmer for a couple of minutes until wilted.\n", + "3. **Season**: Add thyme, oregano, salt, and pepper, stirring well to combine with the vegetables.\n", "\n", - "4. **Season and Serve:** Season with salt and pepper to taste. If desired, add a squeeze of lemon juice for extra flavor. \n", + "4. **Pour Broth and Tomatoes**: Add the vegetable broth and the can of diced tomatoes to the pot. Bring the mixture to a boil.\n", "\n", - "5. **Enjoy:** Serve hot and enjoy your nutritious, comforting soup!\n", + "5. **Simmer**: Reduce the heat to a simmer and let cook for about 20 minutes, allowing flavors to meld together.\n", "\n", - "This simple, flavorful soup is not only vegan but also packed with nutrients, making it a perfect meal any day. \n", + "6. **Add Greens**: Stir in the kale or spinach and simmer for another 5 minutes until the greens are wilted.\n", "\n", - "TERMINATE\n", - "[Prompt tokens: 306, Completion tokens: 294]\n", + "7. **Serve**: Ladle the soup into bowls and enjoy hot!\n", + "\n", + "This comforting vegan soup is perfect for a light meal or starter. Enjoy!\n", + "[Prompt tokens: 293, Completion tokens: 353]\n", "---------- Summary ----------\n", "Number of messages: 3\n", "Finish reason: None\n", - "Total prompt tokens: 306\n", - "Total completion tokens: 294\n", - "Duration: 4.39 seconds\n" + "Total prompt tokens: 293\n", + "Total completion tokens: 353\n", + "Duration: 7.06 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=0.5084745762711864)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=306, completion_tokens=294), content=\"Here's a brief vegan recipe using broth:\\n\\n**Vegan Vegetable Noodle Soup**\\n\\n**Ingredients:**\\n- 4 cups vegetable broth\\n- 1 cup water\\n- 1 cup carrots, sliced\\n- 1 cup celery, chopped\\n- 1 cup noodles (such as rice noodles or spaghetti broken into smaller pieces)\\n- 2 cups kale or spinach, chopped\\n- 2 cloves garlic, minced\\n- 1 tablespoon olive oil\\n- Salt and pepper to taste\\n- Lemon juice (optional)\\n\\n**Instructions:**\\n\\n1. **Sauté Vegetables:** In a large pot, heat olive oil over medium heat. Add minced garlic and sauté until fragrant. Add carrots and celery, and sauté for about 5 minutes, until they start to soften.\\n\\n2. **Add Broth and Noodles:** Pour in the vegetable broth and water, bringing it to a boil. Add the noodles and cook according to package instructions until they are al dente.\\n\\n3. **Cook Greens:** Stir in the kale or spinach and allow it to simmer for a couple of minutes until wilted.\\n\\n4. **Season and Serve:** Season with salt and pepper to taste. If desired, add a squeeze of lemon juice for extra flavor. \\n\\n5. **Enjoy:** Serve hot and enjoy your nutritious, comforting soup!\\n\\nThis simple, flavorful soup is not only vegan but also packed with nutrients, making it a perfect meal any day. \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=293, completion_tokens=353), content=\"Certainly! Here's a brief vegan meal recipe that uses broth:\\n\\n**Vegan Vegetable Broth Soup**\\n\\n**Ingredients:**\\n- 1 tablespoon olive oil\\n- 1 onion, chopped\\n- 2 cloves garlic, minced\\n- 2 carrots, sliced\\n- 2 celery stalks, sliced\\n- 1 zucchini, chopped\\n- 1 bell pepper, chopped\\n- 1 cup mushrooms, sliced\\n- 1 teaspoon thyme\\n- 1 teaspoon oregano\\n- Salt and pepper to taste\\n- 4 cups vegetable broth\\n- 1 cup kale or spinach, chopped\\n- 1 can (15 oz) diced tomatoes\\n\\n**Instructions:**\\n\\n1. **Sauté Vegetables**: In a large pot, heat the olive oil over medium heat. Add the chopped onion and garlic and sauté until the onion becomes translucent.\\n\\n2. **Add Veggies**: Stir in the carrots, celery, zucchini, bell pepper, and mushrooms. Cook for about 5-7 minutes until the vegetables start to soften.\\n\\n3. **Season**: Add thyme, oregano, salt, and pepper, stirring well to combine with the vegetables.\\n\\n4. **Pour Broth and Tomatoes**: Add the vegetable broth and the can of diced tomatoes to the pot. Bring the mixture to a boil.\\n\\n5. **Simmer**: Reduce the heat to a simmer and let cook for about 20 minutes, allowing flavors to meld together.\\n\\n6. **Add Greens**: Stir in the kale or spinach and simmer for another 5 minutes until the greens are wilted.\\n\\n7. **Serve**: Ladle the soup into bowls and enjoy hot!\\n\\nThis comforting vegan soup is perfect for a light meal or starter. Enjoy!\", type='TextMessage')], stop_reason=None)" ] }, "execution_count": 6, @@ -256,8 +236,13 @@ "\n", "You can build on the `Memory` protocol to implement more complex memory stores. For example, you could implement a custom memory store that uses a vector database to store and retrieve information, or a memory store that uses a machine learning model to generate personalized responses based on the user's preferences etc.\n", "\n", - "Specifically, you will need to overload the `query`, `transform`, and `add` methods to implement the desired functionality and pass the memory store to your agent.\n" + "Specifically, you will need to overload the `add`, `query` and `update_context` methods to implement the desired functionality and pass the memory store to your agent.\n" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py b/python/packages/autogen-core/src/autogen_core/memory/__init__.py similarity index 66% rename from python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py rename to python/packages/autogen-core/src/autogen_core/memory/__init__.py index beba13fcbc7e..7e36884ecb0a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/memory/__init__.py @@ -1,10 +1,9 @@ from ._base_memory import Memory, MemoryContent, MemoryMimeType -from ._list_memory import ListMemory, ListMemoryConfig +from ._list_memory import ListMemory __all__ = [ "Memory", "MemoryContent", "MemoryMimeType", "ListMemory", - "ListMemoryConfig", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py similarity index 69% rename from python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py rename to python/packages/autogen-core/src/autogen_core/memory/_base_memory.py index d392ccb8cf33..6b77d65ddfb0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/memory/_base_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py @@ -2,9 +2,10 @@ from enum import Enum from typing import Any, Dict, List, Protocol, Union, runtime_checkable -from autogen_core import CancellationToken, Image -from autogen_core.model_context import ChatCompletionContext -from pydantic import BaseModel, ConfigDict, Field +from .._cancellation_token import CancellationToken +from ..model_context import ChatCompletionContext +from .._image import Image +from pydantic import BaseModel, ConfigDict class MemoryMimeType(Enum): @@ -31,15 +32,6 @@ class MemoryContent(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class BaseMemoryConfig(BaseModel): - """Base configuration for memory implementations.""" - - k: int = Field(default=5, description="Number of results to return") - score_threshold: float | None = Field(default=None, description="Minimum relevance score") - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @runtime_checkable class Memory(Protocol): """Protocol defining the interface for memory implementations.""" @@ -49,20 +41,15 @@ def name(self) -> str | None: """The name of this memory implementation.""" ... - @property - def config(self) -> BaseMemoryConfig: - """The configuration for this memory implementation.""" - ... - - async def transform( + async def update_context( self, model_context: ChatCompletionContext, ) -> List[MemoryContent]: """ - Transform the provided model context using relevant memory content. + Update the provided model context using relevant memory content. Args: - model_context: The context to transform + model_context: The context to update. Returns: List of memory entries with relevance scores @@ -72,7 +59,7 @@ async def transform( async def query( self, query: MemoryContent, - cancellation_token: "CancellationToken | None" = None, + cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryContent]: """ @@ -88,7 +75,7 @@ async def query( """ ... - async def add(self, content: MemoryContent, cancellation_token: "CancellationToken | None" = None) -> None: + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: """ Add a new content to memory. @@ -102,6 +89,6 @@ async def clear(self) -> None: """Clear all entries from memory.""" ... - async def cleanup(self) -> None: + async def close(self) -> None: """Clean up any resources used by the memory implementation.""" ... diff --git a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py new file mode 100644 index 000000000000..61762397bd79 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py @@ -0,0 +1,119 @@ +from typing import Any, List + +from ._base_memory import Memory, MemoryContent +from ..models import SystemMessage +from ..model_context import ChatCompletionContext +from .._cancellation_token import CancellationToken + + +class ListMemory(Memory): + """Simple chronological list-based memory implementation. + + This memory implementation stores contents in a list and retrieves them in + chronological order. It has an `update_context` method that updates model contexts by appending all stored + memories, limited by the configured maximum number of memories. + + Example: + ```python + # Initialize memory with custom config + memory = ListMemory(name="chat_history", config=ListMemoryConfig(max_memories=5)) + + # Add memory content + content = MemoryContent(content="User prefers formal language") + await memory.add(content) + + # Update a model context with memory + memory_contents = await memory.update_context(model_context) + ``` + + Attributes: + name (str): Identifier for this memory instance + config (ListMemoryConfig): Configuration controlling memory behavior + """ + + def __init__(self, name: str | None = None, max_memories: int = 5) -> None: + self._name = name or "default_list_memory" + self._max_memories = max_memories + self._contents: List[MemoryContent] = [] + + @property + def name(self) -> str: + return self._name + + async def update_context( + self, + model_context: ChatCompletionContext, + ) -> List[MemoryContent]: + """Update the model context by appending recent memory content. + + This method mutates the provided model_context by adding the most recent memories (as a :class:`SystemMessage`), up to the configured maximum number of memories. + + Args: + model_context: The context to update. Will be mutated if memories exist. + + Returns: + List[MemoryContent]: List of memories that were added to the context + """ + if not self._contents: + return [] + + # Get the most recent memories up to max_memories + recent_memories = self._contents[-self._max_memories:] + + # Format memories into a string + memory_strings = [] + for i, memory in enumerate(recent_memories, 1): + content = memory.content if isinstance( + memory.content, str) else str(memory.content) + memory_strings.append(f"{i}. {content}") + + # Add memories to context if there are any + if memory_strings: + memory_context = ( + "\nRelevant memory content (in chronological order):\n" + + "\n".join(memory_strings) + + "\n" + ) + await model_context.add_message(SystemMessage(content=memory_context)) + + return recent_memories + + async def query( + self, + query: MemoryContent, + cancellation_token: CancellationToken | None = None, + **kwargs: Any, + ) -> List[MemoryContent]: + """Return most recent memories without any filtering. + + Args: + query: Ignored in this implementation + cancellation_token: Optional token to cancel operation + **kwargs: Additional parameters (ignored) + + Returns: + List[MemoryContent]: Most recent memories up to max_memories limit + """ + _ = query + return self._contents[-self._max_memories:] + + async def add( + self, + content: MemoryContent, + cancellation_token: CancellationToken | None = None + ) -> None: + """Add new content to memory. + + Args: + content: Memory content to store + cancellation_token: Optional token to cancel operation + """ + self._contents.append(content) + + async def clear(self) -> None: + """Clear all memory content.""" + self._contents = [] + + async def close(self) -> None: + """Cleanup resources if needed.""" + pass From 30628f3c6a804e4db134b88441d2e1627752c360 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 8 Jan 2025 14:51:20 -0800 Subject: [PATCH 18/26] format fixxes --- .../agents/_assistant_agent.py | 54 +++++--------- .../src/autogen_agentchat/messages.py | 9 +-- .../tests/test_assistant_agent.py | 71 +++++++------------ .../tutorial/memory.ipynb | 2 +- .../src/autogen_core/memory/_base_memory.py | 5 +- .../src/autogen_core/memory/_list_memory.py | 27 +++---- 6 files changed, 58 insertions(+), 110 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 5ada7f57a743..93a79291618e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -247,8 +247,7 @@ def __init__( name: str, model_client: ChatCompletionClient, *, - tools: List[Tool | Callable[..., Any] | - Callable[..., Awaitable[Any]]] | None = None, + tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None, handoffs: List[HandoffBase | str] | None = None, model_context: ChatCompletionContext | None = None, description: str = "An agent that provides assistance with ability to use tools.", @@ -268,8 +267,7 @@ def __init__( elif isinstance(memory, list): self._memory = memory else: - raise TypeError( - f"Expected Memory, List[Memory], or None, got {type(memory)}") + raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}") self._system_messages: List[ SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage @@ -281,8 +279,7 @@ def __init__( self._tools: List[Tool] = [] if tools is not None: if model_client.model_info["function_calling"] is False: - raise ValueError( - "The model does not support function calling.") + raise ValueError("The model does not support function calling.") for tool in tools: if isinstance(tool, Tool): self._tools.append(tool) @@ -291,8 +288,7 @@ def __init__( description = tool.__doc__ else: description = "" - self._tools.append(FunctionTool( - tool, description=description)) + self._tools.append(FunctionTool(tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") # Check if tool names are unique. @@ -304,8 +300,7 @@ def __init__( self._handoffs: Dict[str, HandoffBase] = {} if handoffs is not None: if model_client.model_info["function_calling"] is False: - raise ValueError( - "The model does not support function calling, which is needed for handoffs.") + raise ValueError("The model does not support function calling, which is needed for handoffs.") for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) @@ -313,13 +308,11 @@ def __init__( self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: - raise ValueError( - f"Unsupported handoff type: {type(handoff)}") + raise ValueError(f"Unsupported handoff type: {type(handoff)}") # Check if handoff tool names are unique. handoff_tool_names = [tool.name for tool in self._handoff_tools] if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError( - f"Handoff names must be unique: {handoff_tool_names}") + raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") # Check if handoff tool names not in tool names. if any(name in tool_names for name in handoff_tool_names): raise ValueError( @@ -347,8 +340,7 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: async for message in self.on_messages_stream(messages, cancellation_token): if isinstance(message, Response): return message - raise AssertionError( - "The stream should have returned the final result.") + raise AssertionError("The stream should have returned the final result.") async def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken @@ -367,8 +359,7 @@ async def on_messages_stream( for memory in self._memory: memory_query_result = await memory.update_context(self._model_context) if memory_query_result and len(memory_query_result) > 0: - memory_query_event_msg = MemoryQueryEvent( - content=memory_query_result, source=self.name) + memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name) inner_messages.append(memory_query_event_msg) yield memory_query_event_msg @@ -384,17 +375,14 @@ async def on_messages_stream( # Check if the response is a string and return it. if isinstance(result.content, str): yield Response( - chat_message=TextMessage( - content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) return # Process tool calls. - assert isinstance(result.content, list) and all( - isinstance(item, FunctionCall) for item in result.content) - tool_call_msg = ToolCallRequestEvent( - content=result.content, source=self.name, models_usage=result.usage) + assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content) + tool_call_msg = ToolCallRequestEvent(content=result.content, source=self.name, models_usage=result.usage) event_logger.debug(tool_call_msg) # Add the tool call message to the output. inner_messages.append(tool_call_msg) @@ -402,8 +390,7 @@ async def on_messages_stream( # Execute the tool calls. results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content]) - tool_call_result_msg = ToolCallExecutionEvent( - content=results, source=self.name) + tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name) event_logger.debug(tool_call_result_msg) await self._model_context.add_message(FunctionExecutionResultMessage(content=results)) inner_messages.append(tool_call_result_msg) @@ -423,8 +410,7 @@ async def on_messages_stream( ) # Return the output messages to signal the handoff. yield Response( - chat_message=HandoffMessage( - content=handoffs[0].message, target=handoffs[0].target, source=self.name), + chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name), inner_messages=inner_messages, ) return @@ -438,8 +424,7 @@ async def on_messages_stream( await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name)) # Yield the response. yield Response( - chat_message=TextMessage( - content=result.content, source=self.name, models_usage=result.usage), + chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage), inner_messages=inner_messages, ) else: @@ -455,8 +440,7 @@ async def on_messages_stream( ) tool_call_summary = "\n".join(tool_call_summaries) yield Response( - chat_message=ToolCallSummaryMessage( - content=tool_call_summary, source=self.name), + chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name), inner_messages=inner_messages, ) @@ -467,11 +451,9 @@ async def _execute_tool_call( try: if not self._tools + self._handoff_tools: raise ValueError("No tools are available.") - tool = next((t for t in self._tools + - self._handoff_tools if t.name == tool_call.name), None) + tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None) if tool is None: - raise ValueError( - f"The tool '{tool_call.name}' is not available.") + raise ValueError(f"The tool '{tool_call.name}' is not available.") arguments = json.loads(tool_call.arguments) result = await tool.run_json(arguments, cancellation_token) result_as_str = tool.return_value_as_string(result) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 2aee1ff3b38b..1016817b61b6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -8,12 +8,11 @@ class and includes specific fields relevant to the type of message being sent. from typing import List, Literal from autogen_core import FunctionCall, Image +from autogen_core.memory import MemoryContent from autogen_core.models import FunctionExecutionResult, RequestUsage from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Annotated, deprecated -from autogen_core.memory import MemoryContent - class BaseMessage(BaseModel, ABC): """Base class for all message types.""" @@ -135,14 +134,12 @@ class MemoryQueryEvent(BaseAgentEvent): ChatMessage = Annotated[ - TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field( - discriminator="type") + TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] """Messages for agent-to-agent communication only.""" -AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | - MemoryQueryEvent, Field(discriminator="type")] +AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent, Field(discriminator="type")] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 72fabfc484ee..a36061d15796 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -7,7 +7,6 @@ from autogen_agentchat import EVENT_LOGGER_NAME from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.base import Handoff, TaskResult -from autogen_core.memory import Memory, ListMemory, MemoryContent, MemoryMimeType from autogen_agentchat.messages import ( ChatMessage, HandoffMessage, @@ -19,6 +18,7 @@ ToolCallSummaryMessage, ) from autogen_core import Image +from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import LLMMessage from autogen_core.models._model_client import ModelFamily @@ -97,8 +97,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -106,15 +105,13 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="pass", role="assistant"), + message=ChatCompletionMessage(content="pass", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -122,15 +119,13 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="TERMINATE", role="assistant"), + message=ChatCompletionMessage(content="TERMINATE", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -174,8 +169,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: agent2 = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool( - _echo_function, description="Echo")], + tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], ) await agent2.load_state(state) state2 = await agent2.save_state() @@ -211,20 +205,17 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage( - content="Hello", role="assistant")) + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant")) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -236,8 +227,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -245,8 +235,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> agent = AssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tools=[_pass_function, _fail_function, FunctionTool( - _echo_function, description="Echo")], + tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")], reflect_on_tool_use=True, ) result = await agent.run(task="task") @@ -322,8 +311,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=42, completion_tokens=43, total_tokens=85), + usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85), ), ] mock = _MockChatCompletion(chat_completions) @@ -375,15 +363,13 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="Hello", role="assistant"), + message=ChatCompletionMessage(content="Hello", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -404,8 +390,7 @@ async def test_invalid_model_capabilities() -> None: model_client = OpenAIChatCompletionClient( model=model, api_key="", - model_info={"vision": False, "function_calling": False, - "json_output": False, "family": ModelFamily.UNKNOWN}, + model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN}, ) with pytest.raises(ValueError): @@ -420,8 +405,7 @@ async def test_invalid_model_capabilities() -> None: ) with pytest.raises(ValueError): - agent = AssistantAgent( - name="assistant", model_client=model_client, handoffs=["agent2"]) + agent = AssistantAgent(name="assistant", model_client=model_client, handoffs=["agent2"]) with pytest.raises(ValueError): agent = AssistantAgent(name="assistant", model_client=model_client) @@ -440,15 +424,13 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="Response to message 1", role="assistant"), + message=ChatCompletionMessage(content="Response to message 1", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=15), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), ), ] mock = _MockChatCompletion(chat_completions) @@ -501,15 +483,13 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="Response to message 3", role="assistant"), + message=ChatCompletionMessage(content="Response to message 3", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=15), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), ), ] mock = _MockChatCompletion(chat_completions) @@ -544,15 +524,13 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: Choice( finish_reason="stop", index=0, - message=ChatCompletionMessage( - content="Hello", role="assistant"), + message=ChatCompletionMessage(content="Hello", role="assistant"), ) ], created=0, model=model, object="chat.completion", - usage=CompletionUsage( - prompt_tokens=10, completion_tokens=5, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] b64_image_str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" @@ -596,8 +574,7 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: result = await agent.run(task="test task") assert len(result.messages) > 0 - memory_event = next( - (msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) + memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) assert memory_event is not None # Test memory protocol diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index bfb8c93a3718..a02760f38095 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -31,8 +31,8 @@ "outputs": [], "source": [ "from autogen_agentchat.agents import AssistantAgent\n", - "from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType\n", "from autogen_agentchat.ui import Console\n", + "from autogen_core.memory import ListMemory, MemoryContent, MemoryMimeType\n", "from autogen_ext.models.openai import OpenAIChatCompletionClient" ] }, diff --git a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py index 6b77d65ddfb0..f8b1f33ec691 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py @@ -2,10 +2,11 @@ from enum import Enum from typing import Any, Dict, List, Protocol, Union, runtime_checkable +from pydantic import BaseModel, ConfigDict + from .._cancellation_token import CancellationToken -from ..model_context import ChatCompletionContext from .._image import Image -from pydantic import BaseModel, ConfigDict +from ..model_context import ChatCompletionContext class MemoryMimeType(Enum): diff --git a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py index 61762397bd79..4be328cc8bae 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py @@ -1,9 +1,9 @@ from typing import Any, List -from ._base_memory import Memory, MemoryContent -from ..models import SystemMessage -from ..model_context import ChatCompletionContext from .._cancellation_token import CancellationToken +from ..model_context import ChatCompletionContext +from ..models import SystemMessage +from ._base_memory import Memory, MemoryContent class ListMemory(Memory): @@ -58,22 +58,17 @@ async def update_context( return [] # Get the most recent memories up to max_memories - recent_memories = self._contents[-self._max_memories:] + recent_memories = self._contents[-self._max_memories :] # Format memories into a string - memory_strings = [] + memory_strings: List[str] = [] for i, memory in enumerate(recent_memories, 1): - content = memory.content if isinstance( - memory.content, str) else str(memory.content) + content = memory.content if isinstance(memory.content, str) else str(memory.content) memory_strings.append(f"{i}. {content}") # Add memories to context if there are any if memory_strings: - memory_context = ( - "\nRelevant memory content (in chronological order):\n" - + "\n".join(memory_strings) - + "\n" - ) + memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n" await model_context.add_message(SystemMessage(content=memory_context)) return recent_memories @@ -95,13 +90,9 @@ async def query( List[MemoryContent]: Most recent memories up to max_memories limit """ _ = query - return self._contents[-self._max_memories:] + return self._contents[-self._max_memories :] - async def add( - self, - content: MemoryContent, - cancellation_token: CancellationToken | None = None - ) -> None: + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: """Add new content to memory. Args: From 2072c461255e1226ca127a55e3935f0680363f7d Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 8 Jan 2025 15:57:20 -0800 Subject: [PATCH 19/26] format updates --- .../autogen-agentchat/src/autogen_agentchat/messages.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 7c5776e34a92..9a8b07fce735 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -114,14 +114,12 @@ class MemoryQueryEvent(BaseAgentEvent): ChatMessage = Annotated[ - TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field( - discriminator="type") + TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] """Messages for agent-to-agent communication only.""" -AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | - MemoryQueryEvent, Field(discriminator="type")] +AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent, Field(discriminator="type")] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" From 4382c862d0d0bd398556021d14c62dbd7ad5ab09 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 8 Jan 2025 16:56:46 -0800 Subject: [PATCH 20/26] format updates --- .../agents/_assistant_agent.py | 4 +- .../autogen-core/tests/test_memory.py | 148 ++++++++++++++++++ 2 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 python/packages/autogen-core/tests/test_memory.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 08c939eef916..15d97ae3b5b3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -249,9 +249,7 @@ def __init__( self._model_client = model_client self._memory = None if memory is not None: - if isinstance(memory, Memory): - self._memory = [memory] - elif isinstance(memory, list): + if isinstance(memory, list): self._memory = memory else: raise TypeError(f"Expected Memory, List[Memory], or None, got {type(memory)}") diff --git a/python/packages/autogen-core/tests/test_memory.py b/python/packages/autogen-core/tests/test_memory.py new file mode 100644 index 000000000000..aa938219c9ff --- /dev/null +++ b/python/packages/autogen-core/tests/test_memory.py @@ -0,0 +1,148 @@ +import pytest +from datetime import datetime +from typing import List + +from autogen_core import CancellationToken +from autogen_core.model_context import ChatCompletionContext, BufferedChatCompletionContext +from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, ListMemory + + +def test_memory_protocol_attributes() -> None: + """Test that Memory protocol has all required attributes.""" + assert hasattr(Memory, "name") + assert hasattr(Memory, "update_context") + assert hasattr(Memory, "query") + assert hasattr(Memory, "add") + assert hasattr(Memory, "clear") + assert hasattr(Memory, "close") + + +def test_memory_protocol_runtime_checkable() -> None: + """Test that Memory protocol is properly runtime-checkable.""" + + class ValidMemory: + @property + def name(self) -> str: + return "test" + + async def update_context(self, context: ChatCompletionContext) -> List[MemoryContent]: + return [] + + async def query( + self, query: MemoryContent, cancellation_token: CancellationToken | None = None + ) -> List[MemoryContent]: + return [] + + async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: + pass + + async def clear(self) -> None: + pass + + async def close(self) -> None: + pass + + class InvalidMemory: + pass + + assert isinstance(ValidMemory(), Memory) + assert not isinstance(InvalidMemory(), Memory) + + +def test_list_memory_basic_properties() -> None: + """Test basic properties of ListMemory.""" + memory = ListMemory(name="test_memory", max_memories=3) + assert memory.name == "test_memory" + assert isinstance(memory, Memory) + + +@pytest.mark.asyncio +async def test_list_memory_empty() -> None: + """Test ListMemory behavior when empty.""" + memory = ListMemory(name="test_memory") + context = BufferedChatCompletionContext(buffer_size=3) + + results = await memory.update_context(context) + context_messages = await context.get_messages() + assert len(results) == 0 + assert len(context_messages) == 0 + + query_results = await memory.query(MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)) + assert len(query_results) == 0 + + +@pytest.mark.asyncio +async def test_list_memory_add_and_query() -> None: + """Test adding and querying memory contents.""" + memory = ListMemory(max_memories=3) + + content1 = MemoryContent(content="test1", mime_type=MemoryMimeType.TEXT, timestamp=datetime.now()) + content2 = MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON, timestamp=datetime.now()) + + await memory.add(content1) + await memory.add(content2) + + results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) + assert len(results) == 2 + assert results[0].content == "test1" + assert results[1].content == {"key": "value"} + + +@pytest.mark.asyncio +async def test_list_memory_max_memories() -> None: + """Test max_memories limit is enforced.""" + memory = ListMemory(max_memories=3) + + for i in range(5): + await memory.add(MemoryContent(content=f"test{i}", mime_type=MemoryMimeType.TEXT)) + + results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) + assert len(results) == 3 + assert [r.content for r in results] == ["test2", "test3", "test4"] + + +@pytest.mark.asyncio +async def test_list_memory_update_context() -> None: + """Test context updating with memory contents.""" + memory = ListMemory(max_memories=3) + context = BufferedChatCompletionContext(buffer_size=3) + + await memory.add(MemoryContent(content="test1", mime_type=MemoryMimeType.TEXT)) + await memory.add(MemoryContent(content="test2", mime_type=MemoryMimeType.TEXT)) + + results = await memory.update_context(context) + context_messages = await context.get_messages() + assert len(results) == 2 + assert len(context_messages) == 1 + assert "test1" in context_messages[0].content + assert "test2" in context_messages[0].content + + +@pytest.mark.asyncio +async def test_list_memory_clear() -> None: + """Test clearing memory contents.""" + memory = ListMemory() + await memory.add(MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)) + await memory.clear() + + results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_list_memory_content_types() -> None: + """Test support for different content types.""" + memory = ListMemory() + text_content = MemoryContent(content="text", mime_type=MemoryMimeType.TEXT) + json_content = MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON) + binary_content = MemoryContent(content=b"binary", mime_type=MemoryMimeType.BINARY) + + await memory.add(text_content) + await memory.add(json_content) + await memory.add(binary_content) + + results = await memory.query(text_content) + assert len(results) == 3 + assert isinstance(results[0].content, str) + assert isinstance(results[1].content, dict) + assert isinstance(results[2].content, bytes) From 08d23cf000407ea22015c7489d4914b3d34a4d64 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Wed, 8 Jan 2025 17:59:53 -0800 Subject: [PATCH 21/26] fix azure notebook import, other fixes --- .../tutorial/memory.ipynb | 99 ++++++++++--------- .../azure-container-code-executor.ipynb | 36 ++++++- .../autogen-core/tests/test_memory.py | 6 +- 3 files changed, 84 insertions(+), 57 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index a02760f38095..53f769361ac1 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -56,6 +56,8 @@ " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", " elif units == \"metric\":\n", " return f\"The weather in {city} is 23 degrees and Sunny.\"\n", + " else:\n", + " return f\"Sorry, I don't know the weather in {city}.\"\n", "\n", "\n", "assistant_agent = AssistantAgent(\n", @@ -70,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -82,10 +84,10 @@ "---------- assistant_agent ----------\n", "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_C4xYYNfiw8sWCCshTh8GoUs7', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[FunctionCall(id='call_FS2npPqrZOtv0391uuf0THxu', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", "[Prompt tokens: 123, Completion tokens: 20]\n", "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_C4xYYNfiw8sWCCshTh8GoUs7')]\n", + "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_FS2npPqrZOtv0391uuf0THxu')]\n", "---------- assistant_agent ----------\n", "The weather in New York is 23 degrees and Sunny.\n", "---------- Summary ----------\n", @@ -93,16 +95,16 @@ "Finish reason: None\n", "Total prompt tokens: 123\n", "Total completion tokens: 20\n", - "Duration: 0.69 seconds\n" + "Duration: 0.87 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_C4xYYNfiw8sWCCshTh8GoUs7', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_C4xYYNfiw8sWCCshTh8GoUs7')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_FS2npPqrZOtv0391uuf0THxu', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_FS2npPqrZOtv0391uuf0THxu')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -122,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -130,11 +132,11 @@ "text/plain": [ "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", " SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n", - " AssistantMessage(content=[FunctionCall(id='call_C4xYYNfiw8sWCCshTh8GoUs7', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", - " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_C4xYYNfiw8sWCCshTh8GoUs7')], type='FunctionExecutionResultMessage')]" + " AssistantMessage(content=[FunctionCall(id='call_FS2npPqrZOtv0391uuf0THxu', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_FS2npPqrZOtv0391uuf0THxu')], type='FunctionExecutionResultMessage')]" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -166,56 +168,60 @@ "---------- assistant_agent ----------\n", "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", "---------- assistant_agent ----------\n", - "Certainly! Here's a brief vegan meal recipe that uses broth:\n", - "\n", - "**Vegan Vegetable Broth Soup**\n", + "**Vegan Vegetable Broth Recipe**\n", "\n", "**Ingredients:**\n", - "- 1 tablespoon olive oil\n", - "- 1 onion, chopped\n", - "- 2 cloves garlic, minced\n", - "- 2 carrots, sliced\n", - "- 2 celery stalks, sliced\n", - "- 1 zucchini, chopped\n", - "- 1 bell pepper, chopped\n", - "- 1 cup mushrooms, sliced\n", - "- 1 teaspoon thyme\n", - "- 1 teaspoon oregano\n", - "- Salt and pepper to taste\n", - "- 4 cups vegetable broth\n", - "- 1 cup kale or spinach, chopped\n", - "- 1 can (15 oz) diced tomatoes\n", + "- 2 tablespoons olive oil\n", + "- 1 onion, roughly chopped\n", + "- 2 carrots, peeled and roughly chopped\n", + "- 2 stalks of celery, roughly chopped\n", + "- 2 cloves of garlic, smashed\n", + "- 8 cups of water\n", + "- 1 bay leaf\n", + "- 1/2 teaspoon whole peppercorns\n", + "- 1 teaspoon salt, or to taste\n", + "- A small bunch of fresh herbs (e.g., thyme, parsley), tied together\n", + "- 1 large tomato, quartered (optional)\n", "\n", "**Instructions:**\n", "\n", - "1. **Sauté Vegetables**: In a large pot, heat the olive oil over medium heat. Add the chopped onion and garlic and sauté until the onion becomes translucent.\n", - "\n", - "2. **Add Veggies**: Stir in the carrots, celery, zucchini, bell pepper, and mushrooms. Cook for about 5-7 minutes until the vegetables start to soften.\n", + "1. **Sauté the Vegetables:**\n", + " - In a large pot, heat olive oil over medium heat.\n", + " - Add onions, carrots, and celery. Sauté for about 5 minutes or until the vegetables start to soften.\n", + " - Add garlic and sauté for another minute.\n", "\n", - "3. **Season**: Add thyme, oregano, salt, and pepper, stirring well to combine with the vegetables.\n", + "2. **Prepare the Broth:**\n", + " - Add water to the pot and stir in bay leaf, peppercorns, salt, and herbs. Add the tomato if using.\n", + " - Bring the mixture to a simmer over medium-high heat.\n", "\n", - "4. **Pour Broth and Tomatoes**: Add the vegetable broth and the can of diced tomatoes to the pot. Bring the mixture to a boil.\n", + "3. **Simmer:**\n", + " - Once simmering, reduce heat to low.\n", + " - Cover and let the broth simmer for at least 45 minutes, allowing the flavors to meld.\n", "\n", - "5. **Simmer**: Reduce the heat to a simmer and let cook for about 20 minutes, allowing flavors to meld together.\n", + "4. **Strain:**\n", + " - Remove the pot from heat. Let it cool slightly.\n", + " - Strain the broth through a fine mesh sieve or cheesecloth into another pot or heat-proof container to remove solids.\n", + " - Press down on the cooked vegetables lightly with a spoon to extract as much liquid as possible.\n", "\n", - "6. **Add Greens**: Stir in the kale or spinach and simmer for another 5 minutes until the greens are wilted.\n", + "5. **Serve or Store:**\n", + " - Taste and adjust seasoning if necessary.\n", + " - Use immediately as a base for soups, stews, or enjoy it as is.\n", + " - Store leftovers in an airtight container in the refrigerator for up to a week, or freeze for longer storage.\n", "\n", - "7. **Serve**: Ladle the soup into bowls and enjoy hot!\n", - "\n", - "This comforting vegan soup is perfect for a light meal or starter. Enjoy!\n", - "[Prompt tokens: 293, Completion tokens: 353]\n", + "Enjoy your delicious, homemade vegan vegetable broth! TERMINATE\n", + "[Prompt tokens: 207, Completion tokens: 409]\n", "---------- Summary ----------\n", "Number of messages: 3\n", "Finish reason: None\n", - "Total prompt tokens: 293\n", - "Total completion tokens: 353\n", - "Duration: 7.06 seconds\n" + "Total prompt tokens: 207\n", + "Total completion tokens: 409\n", + "Duration: 12.23 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=293, completion_tokens=353), content=\"Certainly! Here's a brief vegan meal recipe that uses broth:\\n\\n**Vegan Vegetable Broth Soup**\\n\\n**Ingredients:**\\n- 1 tablespoon olive oil\\n- 1 onion, chopped\\n- 2 cloves garlic, minced\\n- 2 carrots, sliced\\n- 2 celery stalks, sliced\\n- 1 zucchini, chopped\\n- 1 bell pepper, chopped\\n- 1 cup mushrooms, sliced\\n- 1 teaspoon thyme\\n- 1 teaspoon oregano\\n- Salt and pepper to taste\\n- 4 cups vegetable broth\\n- 1 cup kale or spinach, chopped\\n- 1 can (15 oz) diced tomatoes\\n\\n**Instructions:**\\n\\n1. **Sauté Vegetables**: In a large pot, heat the olive oil over medium heat. Add the chopped onion and garlic and sauté until the onion becomes translucent.\\n\\n2. **Add Veggies**: Stir in the carrots, celery, zucchini, bell pepper, and mushrooms. Cook for about 5-7 minutes until the vegetables start to soften.\\n\\n3. **Season**: Add thyme, oregano, salt, and pepper, stirring well to combine with the vegetables.\\n\\n4. **Pour Broth and Tomatoes**: Add the vegetable broth and the can of diced tomatoes to the pot. Bring the mixture to a boil.\\n\\n5. **Simmer**: Reduce the heat to a simmer and let cook for about 20 minutes, allowing flavors to meld together.\\n\\n6. **Add Greens**: Stir in the kale or spinach and simmer for another 5 minutes until the greens are wilted.\\n\\n7. **Serve**: Ladle the soup into bowls and enjoy hot!\\n\\nThis comforting vegan soup is perfect for a light meal or starter. Enjoy!\", type='TextMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=207, completion_tokens=409), content='**Vegan Vegetable Broth Recipe**\\n\\n**Ingredients:**\\n- 2 tablespoons olive oil\\n- 1 onion, roughly chopped\\n- 2 carrots, peeled and roughly chopped\\n- 2 stalks of celery, roughly chopped\\n- 2 cloves of garlic, smashed\\n- 8 cups of water\\n- 1 bay leaf\\n- 1/2 teaspoon whole peppercorns\\n- 1 teaspoon salt, or to taste\\n- A small bunch of fresh herbs (e.g., thyme, parsley), tied together\\n- 1 large tomato, quartered (optional)\\n\\n**Instructions:**\\n\\n1. **Sauté the Vegetables:**\\n - In a large pot, heat olive oil over medium heat.\\n - Add onions, carrots, and celery. Sauté for about 5 minutes or until the vegetables start to soften.\\n - Add garlic and sauté for another minute.\\n\\n2. **Prepare the Broth:**\\n - Add water to the pot and stir in bay leaf, peppercorns, salt, and herbs. Add the tomato if using.\\n - Bring the mixture to a simmer over medium-high heat.\\n\\n3. **Simmer:**\\n - Once simmering, reduce heat to low.\\n - Cover and let the broth simmer for at least 45 minutes, allowing the flavors to meld.\\n\\n4. **Strain:**\\n - Remove the pot from heat. Let it cool slightly.\\n - Strain the broth through a fine mesh sieve or cheesecloth into another pot or heat-proof container to remove solids.\\n - Press down on the cooked vegetables lightly with a spoon to extract as much liquid as possible.\\n\\n5. **Serve or Store:**\\n - Taste and adjust seasoning if necessary.\\n - Use immediately as a base for soups, stews, or enjoy it as is.\\n - Store leftovers in an airtight container in the refrigerator for up to a week, or freeze for longer storage.\\n\\nEnjoy your delicious, homemade vegan vegetable broth! TERMINATE', type='TextMessage')], stop_reason=None)" ] }, "execution_count": 6, @@ -238,11 +244,6 @@ "\n", "Specifically, you will need to overload the `add`, `query` and `update_context` methods to implement the desired functionality and pass the memory store to your agent.\n" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { diff --git a/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb b/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb index 7bc7ef4da275..04def0ac3568 100644 --- a/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -53,16 +53,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ImportError", + "evalue": "cannot import name 'CodeBlock' from 'autogen_core.components.code_executor' (unknown location)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01manyio\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m open_file\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mautogen_core\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CancellationToken\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mautogen_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcomponents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcode_executor\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CodeBlock\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mautogen_ext\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcode_executor\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01maca_dynamic_sessions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AzureContainerCodeExecutor\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mazure\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01midentity\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DefaultAzureCredential\n", + "\u001b[0;31mImportError\u001b[0m: cannot import name 'CodeBlock' from 'autogen_core.components.code_executor' (unknown location)" + ] + } + ], "source": [ "import os\n", "import tempfile\n", "\n", "from anyio import open_file\n", "from autogen_core import CancellationToken\n", - "from autogen_core.components.code_executor import CodeBlock\n", + "from autogen_core.code_executor import CodeBlock\n", "from autogen_ext.code_executor.aca_dynamic_sessions import AzureContainerCodeExecutor\n", "from azure.identity import DefaultAzureCredential" ] @@ -257,8 +269,22 @@ } ], "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/tests/test_memory.py b/python/packages/autogen-core/tests/test_memory.py index aa938219c9ff..c721725437cf 100644 --- a/python/packages/autogen-core/tests/test_memory.py +++ b/python/packages/autogen-core/tests/test_memory.py @@ -1,10 +1,10 @@ -import pytest from datetime import datetime from typing import List +import pytest from autogen_core import CancellationToken -from autogen_core.model_context import ChatCompletionContext, BufferedChatCompletionContext -from autogen_core.memory import Memory, MemoryContent, MemoryMimeType, ListMemory +from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType +from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext def test_memory_protocol_attributes() -> None: From 9316f6d479887b8b91584113ca4294672edd89ba Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Thu, 9 Jan 2025 10:34:17 -0800 Subject: [PATCH 22/26] update notebook, support str query in Memory protocol --- .../tutorial/memory.ipynb | 100 ++++++++---------- .../src/autogen_core/memory/_base_memory.py | 2 +- .../src/autogen_core/memory/_list_memory.py | 88 +++++++++------ .../autogen-core/tests/test_memory.py | 10 +- .../src/autogen_ext/models/openai/__init__.py | 4 +- 5 files changed, 110 insertions(+), 94 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index 53f769361ac1..aecae54781e2 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -53,9 +53,9 @@ "\n", "async def get_weather(city: str, units: str = \"imperial\") -> str:\n", " if units == \"imperial\":\n", - " return f\"The weather in {city} is 73 degrees and Sunny.\"\n", + " return f\"The weather in {city} is 73 °F and Sunny.\"\n", " elif units == \"metric\":\n", - " return f\"The weather in {city} is 23 degrees and Sunny.\"\n", + " return f\"The weather in {city} is 23 °C and Sunny.\"\n", " else:\n", " return f\"Sorry, I don't know the weather in {city}.\"\n", "\n", @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -84,27 +84,27 @@ "---------- assistant_agent ----------\n", "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_FS2npPqrZOtv0391uuf0THxu', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", + "[FunctionCall(id='call_NR8vBXk0856yl9eYa8SMjYbo', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", "[Prompt tokens: 123, Completion tokens: 20]\n", "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_FS2npPqrZOtv0391uuf0THxu')]\n", + "[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_NR8vBXk0856yl9eYa8SMjYbo')]\n", "---------- assistant_agent ----------\n", - "The weather in New York is 23 degrees and Sunny.\n", + "The weather in New York is 23 °C and Sunny.\n", "---------- Summary ----------\n", "Number of messages: 5\n", "Finish reason: None\n", "Total prompt tokens: 123\n", "Total completion tokens: 20\n", - "Duration: 0.87 seconds\n" + "Duration: 1.27 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_FS2npPqrZOtv0391uuf0THxu', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_FS2npPqrZOtv0391uuf0THxu')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 degrees and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_NR8vBXk0856yl9eYa8SMjYbo', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_NR8vBXk0856yl9eYa8SMjYbo')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -124,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -132,11 +132,11 @@ "text/plain": [ "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", " SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n", - " AssistantMessage(content=[FunctionCall(id='call_FS2npPqrZOtv0391uuf0THxu', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", - " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_FS2npPqrZOtv0391uuf0THxu')], type='FunctionExecutionResultMessage')]" + " AssistantMessage(content=[FunctionCall(id='call_uvKugIKWzeCYK1px49HJhlku', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_uvKugIKWzeCYK1px49HJhlku')], type='FunctionExecutionResultMessage')]" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -156,7 +156,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -168,63 +168,51 @@ "---------- assistant_agent ----------\n", "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", "---------- assistant_agent ----------\n", - "**Vegan Vegetable Broth Recipe**\n", + "Here's a simple vegan recipe for a vegetable broth soup:\n", + "\n", + "**Vegan Vegetable Broth Soup**\n", "\n", "**Ingredients:**\n", - "- 2 tablespoons olive oil\n", - "- 1 onion, roughly chopped\n", - "- 2 carrots, peeled and roughly chopped\n", - "- 2 stalks of celery, roughly chopped\n", - "- 2 cloves of garlic, smashed\n", - "- 8 cups of water\n", - "- 1 bay leaf\n", - "- 1/2 teaspoon whole peppercorns\n", - "- 1 teaspoon salt, or to taste\n", - "- A small bunch of fresh herbs (e.g., thyme, parsley), tied together\n", - "- 1 large tomato, quartered (optional)\n", + "- 8 cups vegetable broth\n", + "- 2 carrots, chopped\n", + "- 2 celery stalks, chopped\n", + "- 1 onion, diced\n", + "- 3 cloves garlic, minced\n", + "- 1 zucchini, chopped\n", + "- 1 cup green beans, trimmed and halved\n", + "- 1 cup chopped kale\n", + "- 1 teaspoon dried thyme\n", + "- 1 teaspoon dried basil\n", + "- Salt and pepper to taste\n", "\n", "**Instructions:**\n", + "1. In a large pot, heat a splash of vegetable broth over medium heat. Add the onions and garlic and sauté until the onions are translucent.\n", + "2. Add the carrots, celery, zucchini, and green beans, and sauté for another 5 minutes.\n", + "3. Pour in the remaining vegetable broth and bring the mixture to a gentle boil.\n", + "4. Stir in the thyme, basil, salt, and pepper. Reduce the heat to a simmer and let the soup cook for about 25-30 minutes, or until the vegetables are tender.\n", + "5. Add the chopped kale and cook for an additional 5 minutes.\n", + "6. Taste and adjust the seasoning if needed.\n", + "7. Serve hot as a comforting and nourishing meal.\n", "\n", - "1. **Sauté the Vegetables:**\n", - " - In a large pot, heat olive oil over medium heat.\n", - " - Add onions, carrots, and celery. Sauté for about 5 minutes or until the vegetables start to soften.\n", - " - Add garlic and sauté for another minute.\n", - "\n", - "2. **Prepare the Broth:**\n", - " - Add water to the pot and stir in bay leaf, peppercorns, salt, and herbs. Add the tomato if using.\n", - " - Bring the mixture to a simmer over medium-high heat.\n", - "\n", - "3. **Simmer:**\n", - " - Once simmering, reduce heat to low.\n", - " - Cover and let the broth simmer for at least 45 minutes, allowing the flavors to meld.\n", + "Enjoy your delicious vegan vegetable broth soup! \n", "\n", - "4. **Strain:**\n", - " - Remove the pot from heat. Let it cool slightly.\n", - " - Strain the broth through a fine mesh sieve or cheesecloth into another pot or heat-proof container to remove solids.\n", - " - Press down on the cooked vegetables lightly with a spoon to extract as much liquid as possible.\n", - "\n", - "5. **Serve or Store:**\n", - " - Taste and adjust seasoning if necessary.\n", - " - Use immediately as a base for soups, stews, or enjoy it as is.\n", - " - Store leftovers in an airtight container in the refrigerator for up to a week, or freeze for longer storage.\n", - "\n", - "Enjoy your delicious, homemade vegan vegetable broth! TERMINATE\n", - "[Prompt tokens: 207, Completion tokens: 409]\n", + "TERMINATE\n", + "[Prompt tokens: 207, Completion tokens: 271]\n", "---------- Summary ----------\n", "Number of messages: 3\n", "Finish reason: None\n", "Total prompt tokens: 207\n", - "Total completion tokens: 409\n", - "Duration: 12.23 seconds\n" + "Total completion tokens: 271\n", + "Duration: 6.22 seconds\n" ] }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=207, completion_tokens=409), content='**Vegan Vegetable Broth Recipe**\\n\\n**Ingredients:**\\n- 2 tablespoons olive oil\\n- 1 onion, roughly chopped\\n- 2 carrots, peeled and roughly chopped\\n- 2 stalks of celery, roughly chopped\\n- 2 cloves of garlic, smashed\\n- 8 cups of water\\n- 1 bay leaf\\n- 1/2 teaspoon whole peppercorns\\n- 1 teaspoon salt, or to taste\\n- A small bunch of fresh herbs (e.g., thyme, parsley), tied together\\n- 1 large tomato, quartered (optional)\\n\\n**Instructions:**\\n\\n1. **Sauté the Vegetables:**\\n - In a large pot, heat olive oil over medium heat.\\n - Add onions, carrots, and celery. Sauté for about 5 minutes or until the vegetables start to soften.\\n - Add garlic and sauté for another minute.\\n\\n2. **Prepare the Broth:**\\n - Add water to the pot and stir in bay leaf, peppercorns, salt, and herbs. Add the tomato if using.\\n - Bring the mixture to a simmer over medium-high heat.\\n\\n3. **Simmer:**\\n - Once simmering, reduce heat to low.\\n - Cover and let the broth simmer for at least 45 minutes, allowing the flavors to meld.\\n\\n4. **Strain:**\\n - Remove the pot from heat. Let it cool slightly.\\n - Strain the broth through a fine mesh sieve or cheesecloth into another pot or heat-proof container to remove solids.\\n - Press down on the cooked vegetables lightly with a spoon to extract as much liquid as possible.\\n\\n5. **Serve or Store:**\\n - Taste and adjust seasoning if necessary.\\n - Use immediately as a base for soups, stews, or enjoy it as is.\\n - Store leftovers in an airtight container in the refrigerator for up to a week, or freeze for longer storage.\\n\\nEnjoy your delicious, homemade vegan vegetable broth! TERMINATE', type='TextMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=207, completion_tokens=271), content=\"Here's a simple vegan recipe for a vegetable broth soup:\\n\\n**Vegan Vegetable Broth Soup**\\n\\n**Ingredients:**\\n- 8 cups vegetable broth\\n- 2 carrots, chopped\\n- 2 celery stalks, chopped\\n- 1 onion, diced\\n- 3 cloves garlic, minced\\n- 1 zucchini, chopped\\n- 1 cup green beans, trimmed and halved\\n- 1 cup chopped kale\\n- 1 teaspoon dried thyme\\n- 1 teaspoon dried basil\\n- Salt and pepper to taste\\n\\n**Instructions:**\\n1. In a large pot, heat a splash of vegetable broth over medium heat. Add the onions and garlic and sauté until the onions are translucent.\\n2. Add the carrots, celery, zucchini, and green beans, and sauté for another 5 minutes.\\n3. Pour in the remaining vegetable broth and bring the mixture to a gentle boil.\\n4. Stir in the thyme, basil, salt, and pepper. Reduce the heat to a simmer and let the soup cook for about 25-30 minutes, or until the vegetables are tender.\\n5. Add the chopped kale and cook for an additional 5 minutes.\\n6. Taste and adjust the seasoning if needed.\\n7. Serve hot as a comforting and nourishing meal.\\n\\nEnjoy your delicious vegan vegetable broth soup! \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py index f8b1f33ec691..5c2681b992fe 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py @@ -59,7 +59,7 @@ async def update_context( async def query( self, - query: MemoryContent, + query: str | MemoryContent, cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryContent]: diff --git a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py index 4be328cc8bae..b4d3e19016f7 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py @@ -10,43 +10,77 @@ class ListMemory(Memory): """Simple chronological list-based memory implementation. This memory implementation stores contents in a list and retrieves them in - chronological order. It has an `update_context` method that updates model contexts by appending all stored - memories, limited by the configured maximum number of memories. + chronological order. It has an `update_context` method that updates model contexts + by appending all stored memories. + + The memory content can be directly accessed and modified through the content property, + allowing external applications to manage memory contents directly. Example: - ```python - # Initialize memory with custom config - memory = ListMemory(name="chat_history", config=ListMemoryConfig(max_memories=5)) + .. code-block:: python + # Initialize memory + memory = ListMemory(name="chat_history") + + # Add memory content + content = MemoryContent(content="User prefers formal language") + await memory.add(content) + + # Directly modify memory contents + memory.content = [MemoryContent(content="New preference")] - # Add memory content - content = MemoryContent(content="User prefers formal language") - await memory.add(content) + # Update a model context with memory + memory_contents = await memory.update_context(model_context) - # Update a model context with memory - memory_contents = await memory.update_context(model_context) - ``` Attributes: name (str): Identifier for this memory instance - config (ListMemoryConfig): Configuration controlling memory behavior + content (List[MemoryContent]): Direct access to memory contents """ - def __init__(self, name: str | None = None, max_memories: int = 5) -> None: + def __init__(self, name: str | None = None) -> None: + """Initialize ListMemory. + + Args: + name: Optional identifier for this memory instance + """ self._name = name or "default_list_memory" - self._max_memories = max_memories self._contents: List[MemoryContent] = [] @property def name(self) -> str: + """Get the memory instance identifier. + + Returns: + str: Memory instance name + """ return self._name + @property + def content(self) -> List[MemoryContent]: + """Get the current memory contents. + + Returns: + List[MemoryContent]: List of stored memory contents + """ + return self._contents + + @content.setter + def content(self, value: List[MemoryContent]) -> None: + """Set the memory contents. + + Args: + value: New list of memory contents to store + """ + self._contents = value + async def update_context( self, model_context: ChatCompletionContext, ) -> List[MemoryContent]: - """Update the model context by appending recent memory content. + """Update the model context by appending memory content. - This method mutates the provided model_context by adding the most recent memories (as a :class:`SystemMessage`), up to the configured maximum number of memories. + This method mutates the provided model_context by adding all memories as a + SystemMessage. Args: model_context: The context to update. Will be mutated if memories exist. @@ -57,29 +91,21 @@ async def update_context( if not self._contents: return [] - # Get the most recent memories up to max_memories - recent_memories = self._contents[-self._max_memories :] - - # Format memories into a string - memory_strings: List[str] = [] - for i, memory in enumerate(recent_memories, 1): - content = memory.content if isinstance(memory.content, str) else str(memory.content) - memory_strings.append(f"{i}. {content}") + memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(self._contents, 1)] - # Add memories to context if there are any if memory_strings: memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n" await model_context.add_message(SystemMessage(content=memory_context)) - return recent_memories + return self._contents async def query( self, - query: MemoryContent, + query: str | MemoryContent = "", cancellation_token: CancellationToken | None = None, **kwargs: Any, ) -> List[MemoryContent]: - """Return most recent memories without any filtering. + """Return all memories without any filtering. Args: query: Ignored in this implementation @@ -87,10 +113,10 @@ async def query( **kwargs: Additional parameters (ignored) Returns: - List[MemoryContent]: Most recent memories up to max_memories limit + List[MemoryContent]: All stored memories """ - _ = query - return self._contents[-self._max_memories :] + _ = query, cancellation_token, kwargs + return self._contents async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: """Add new content to memory. diff --git a/python/packages/autogen-core/tests/test_memory.py b/python/packages/autogen-core/tests/test_memory.py index c721725437cf..883a60b92dd0 100644 --- a/python/packages/autogen-core/tests/test_memory.py +++ b/python/packages/autogen-core/tests/test_memory.py @@ -51,7 +51,9 @@ class InvalidMemory: def test_list_memory_basic_properties() -> None: """Test basic properties of ListMemory.""" - memory = ListMemory(name="test_memory", max_memories=3) + memory = ListMemory( + name="test_memory", + ) assert memory.name == "test_memory" assert isinstance(memory, Memory) @@ -74,7 +76,7 @@ async def test_list_memory_empty() -> None: @pytest.mark.asyncio async def test_list_memory_add_and_query() -> None: """Test adding and querying memory contents.""" - memory = ListMemory(max_memories=3) + memory = ListMemory() content1 = MemoryContent(content="test1", mime_type=MemoryMimeType.TEXT, timestamp=datetime.now()) content2 = MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON, timestamp=datetime.now()) @@ -91,7 +93,7 @@ async def test_list_memory_add_and_query() -> None: @pytest.mark.asyncio async def test_list_memory_max_memories() -> None: """Test max_memories limit is enforced.""" - memory = ListMemory(max_memories=3) + memory = ListMemory() for i in range(5): await memory.add(MemoryContent(content=f"test{i}", mime_type=MemoryMimeType.TEXT)) @@ -104,7 +106,7 @@ async def test_list_memory_max_memories() -> None: @pytest.mark.asyncio async def test_list_memory_update_context() -> None: """Test context updating with memory contents.""" - memory = ListMemory(max_memories=3) + memory = ListMemory() context = BufferedChatCompletionContext(buffer_size=3) await memory.add(MemoryContent(content="test1", mime_type=MemoryMimeType.TEXT)) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py index dbe2eb65e045..366ad831175e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/__init__.py @@ -1,9 +1,9 @@ -from ._openai_client import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient, BaseOpenAIChatCompletionClient +from ._openai_client import AzureOpenAIChatCompletionClient, BaseOpenAIChatCompletionClient, OpenAIChatCompletionClient from .config import ( AzureOpenAIClientConfigurationConfigModel, - OpenAIClientConfigurationConfigModel, BaseOpenAIClientConfigurationConfigModel, CreateArgumentsConfigModel, + OpenAIClientConfigurationConfigModel, ) __all__ = [ From 1ba53815e0f5654c7a66774fb9f337d84fcdc5e6 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Thu, 9 Jan 2025 10:42:36 -0800 Subject: [PATCH 23/26] update test --- python/packages/autogen-core/tests/test_memory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/packages/autogen-core/tests/test_memory.py b/python/packages/autogen-core/tests/test_memory.py index 883a60b92dd0..7464fd982657 100644 --- a/python/packages/autogen-core/tests/test_memory.py +++ b/python/packages/autogen-core/tests/test_memory.py @@ -99,8 +99,7 @@ async def test_list_memory_max_memories() -> None: await memory.add(MemoryContent(content=f"test{i}", mime_type=MemoryMimeType.TEXT)) results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) - assert len(results) == 3 - assert [r.content for r in results] == ["test2", "test3", "test4"] + assert len(results) == 5 @pytest.mark.asyncio From 7b5b97cd79e120cccddf864e7436f32e3eed24cb Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Thu, 9 Jan 2025 13:33:59 -0800 Subject: [PATCH 24/26] update cells --- .../azure-container-code-executor.ipynb | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb b/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb index 8c59d5f11fc6..692a2afda33e 100644 --- a/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/extensions-user-guide/azure-container-code-executor.ipynb @@ -53,21 +53,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "ImportError", - "evalue": "cannot import name 'CodeBlock' from 'autogen_core.components.code_executor' (unknown location)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01manyio\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m open_file\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mautogen_core\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CancellationToken\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mautogen_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcomponents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcode_executor\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CodeBlock\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mautogen_ext\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcode_executor\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01maca_dynamic_sessions\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m AzureContainerCodeExecutor\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mazure\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01midentity\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DefaultAzureCredential\n", - "\u001b[0;31mImportError\u001b[0m: cannot import name 'CodeBlock' from 'autogen_core.components.code_executor' (unknown location)" - ] - } - ], + "outputs": [], "source": [ "import os\n", "import tempfile\n", From 61bcf34a163f8a03c59a9439cf2fc35f9d2c04f6 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 10 Jan 2025 21:52:49 -0800 Subject: [PATCH 25/26] add specific extensible return types to memory query and update_context --- .../agents/_assistant_agent.py | 8 +- .../tests/test_assistant_agent.py | 15 +++- .../tutorial/memory.ipynb | 78 +------------------ .../src/autogen_core/memory/__init__.py | 4 +- .../src/autogen_core/memory/_base_memory.py | 16 +++- .../src/autogen_core/memory/_list_memory.py | 17 ++-- .../autogen-core/tests/test_memory.py | 52 ++++++------- .../autogen-ext/test_filesurfer_agent.html | 9 +++ 8 files changed, 79 insertions(+), 120 deletions(-) create mode 100644 python/packages/autogen-ext/test_filesurfer_agent.html diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 15d97ae3b5b3..4a359eded38c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -342,9 +342,11 @@ async def on_messages_stream( # Update the model context with memory content. if self._memory: for memory in self._memory: - memory_query_result = await memory.update_context(self._model_context) - if memory_query_result and len(memory_query_result) > 0: - memory_query_event_msg = MemoryQueryEvent(content=memory_query_result, source=self.name) + update_context_result = await memory.update_context(self._model_context) + if update_context_result and len(update_context_result.memories.results) > 0: + memory_query_event_msg = MemoryQueryEvent( + content=update_context_result.memories.results, source=self.name + ) inner_messages.append(memory_query_event_msg) yield memory_query_event_msg diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index a36061d15796..930b4f8f7959 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -18,7 +18,7 @@ ToolCallSummaryMessage, ) from autogen_core import Image -from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType +from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import LLMMessage from autogen_core.models._model_client import ModelFamily @@ -543,7 +543,7 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: empty_context = BufferedChatCompletionContext(buffer_size=2) empty_results = await memory.update_context(empty_context) - assert len(empty_results) == 0 + assert len(empty_results.memories.results) == 0 # Test various content types memory = ListMemory() @@ -551,9 +551,16 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: await memory.add(MemoryContent(content={"key": "value"}, mime_type=MemoryMimeType.JSON)) await memory.add(MemoryContent(content=Image.from_base64(b64_image_str), mime_type=MemoryMimeType.IMAGE)) + # Test query functionality + query_result = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) + assert isinstance(query_result, MemoryQueryResult) + # Should have all three memories we added + assert len(query_result.results) == 3 + # Test clear and cleanup await memory.clear() - assert await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) == [] + empty_query = await memory.query(MemoryContent(content="", mime_type=MemoryMimeType.TEXT)) + assert len(empty_query.results) == 0 await memory.close() # Should not raise # Test invalid memory type @@ -576,6 +583,8 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: assert len(result.messages) > 0 memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) assert memory_event is not None + assert len(memory_event.content) > 0 + assert isinstance(memory_event.content[0], MemoryContent) # Test memory protocol class BadMemory: diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb index aecae54781e2..6a037945b69e 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/memory.ipynb @@ -75,33 +75,10 @@ "execution_count": 3, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "What is the weather in New York?\n", - "---------- assistant_agent ----------\n", - "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", - "---------- assistant_agent ----------\n", - "[FunctionCall(id='call_NR8vBXk0856yl9eYa8SMjYbo', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')]\n", - "[Prompt tokens: 123, Completion tokens: 20]\n", - "---------- assistant_agent ----------\n", - "[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_NR8vBXk0856yl9eYa8SMjYbo')]\n", - "---------- assistant_agent ----------\n", - "The weather in New York is 23 °C and Sunny.\n", - "---------- Summary ----------\n", - "Number of messages: 5\n", - "Finish reason: None\n", - "Total prompt tokens: 123\n", - "Total completion tokens: 20\n", - "Duration: 1.27 seconds\n" - ] - }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_NR8vBXk0856yl9eYa8SMjYbo', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_NR8vBXk0856yl9eYa8SMjYbo')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), ToolCallRequestEvent(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=123, completion_tokens=20), content=[FunctionCall(id='call_pHq4p89gW6oGjGr3VsVETCYX', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], type='ToolCallRequestEvent'), ToolCallExecutionEvent(source='assistant_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_pHq4p89gW6oGjGr3VsVETCYX')], type='ToolCallExecutionEvent'), ToolCallSummaryMessage(source='assistant_agent', models_usage=None, content='The weather in New York is 23 °C and Sunny.', type='ToolCallSummaryMessage')], stop_reason=None)" ] }, "execution_count": 3, @@ -132,8 +109,8 @@ "text/plain": [ "[UserMessage(content='What is the weather in New York?', source='user', type='UserMessage'),\n", " SystemMessage(content='\\nRelevant memory content (in chronological order):\\n1. The weather should be in metric units\\n2. Meal recipe must be vegan\\n', type='SystemMessage'),\n", - " AssistantMessage(content=[FunctionCall(id='call_uvKugIKWzeCYK1px49HJhlku', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", - " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 degrees and Sunny.', call_id='call_uvKugIKWzeCYK1px49HJhlku')], type='FunctionExecutionResultMessage')]" + " AssistantMessage(content=[FunctionCall(id='call_pHq4p89gW6oGjGr3VsVETCYX', arguments='{\"city\":\"New York\",\"units\":\"metric\"}', name='get_weather')], source='assistant_agent', type='AssistantMessage'),\n", + " FunctionExecutionResultMessage(content=[FunctionExecutionResult(content='The weather in New York is 23 °C and Sunny.', call_id='call_pHq4p89gW6oGjGr3VsVETCYX')], type='FunctionExecutionResultMessage')]" ] }, "execution_count": 4, @@ -159,57 +136,10 @@ "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "---------- user ----------\n", - "Write brief meal recipe with broth\n", - "---------- assistant_agent ----------\n", - "[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)]\n", - "---------- assistant_agent ----------\n", - "Here's a simple vegan recipe for a vegetable broth soup:\n", - "\n", - "**Vegan Vegetable Broth Soup**\n", - "\n", - "**Ingredients:**\n", - "- 8 cups vegetable broth\n", - "- 2 carrots, chopped\n", - "- 2 celery stalks, chopped\n", - "- 1 onion, diced\n", - "- 3 cloves garlic, minced\n", - "- 1 zucchini, chopped\n", - "- 1 cup green beans, trimmed and halved\n", - "- 1 cup chopped kale\n", - "- 1 teaspoon dried thyme\n", - "- 1 teaspoon dried basil\n", - "- Salt and pepper to taste\n", - "\n", - "**Instructions:**\n", - "1. In a large pot, heat a splash of vegetable broth over medium heat. Add the onions and garlic and sauté until the onions are translucent.\n", - "2. Add the carrots, celery, zucchini, and green beans, and sauté for another 5 minutes.\n", - "3. Pour in the remaining vegetable broth and bring the mixture to a gentle boil.\n", - "4. Stir in the thyme, basil, salt, and pepper. Reduce the heat to a simmer and let the soup cook for about 25-30 minutes, or until the vegetables are tender.\n", - "5. Add the chopped kale and cook for an additional 5 minutes.\n", - "6. Taste and adjust the seasoning if needed.\n", - "7. Serve hot as a comforting and nourishing meal.\n", - "\n", - "Enjoy your delicious vegan vegetable broth soup! \n", - "\n", - "TERMINATE\n", - "[Prompt tokens: 207, Completion tokens: 271]\n", - "---------- Summary ----------\n", - "Number of messages: 3\n", - "Finish reason: None\n", - "Total prompt tokens: 207\n", - "Total completion tokens: 271\n", - "Duration: 6.22 seconds\n" - ] - }, { "data": { "text/plain": [ - "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=207, completion_tokens=271), content=\"Here's a simple vegan recipe for a vegetable broth soup:\\n\\n**Vegan Vegetable Broth Soup**\\n\\n**Ingredients:**\\n- 8 cups vegetable broth\\n- 2 carrots, chopped\\n- 2 celery stalks, chopped\\n- 1 onion, diced\\n- 3 cloves garlic, minced\\n- 1 zucchini, chopped\\n- 1 cup green beans, trimmed and halved\\n- 1 cup chopped kale\\n- 1 teaspoon dried thyme\\n- 1 teaspoon dried basil\\n- Salt and pepper to taste\\n\\n**Instructions:**\\n1. In a large pot, heat a splash of vegetable broth over medium heat. Add the onions and garlic and sauté until the onions are translucent.\\n2. Add the carrots, celery, zucchini, and green beans, and sauté for another 5 minutes.\\n3. Pour in the remaining vegetable broth and bring the mixture to a gentle boil.\\n4. Stir in the thyme, basil, salt, and pepper. Reduce the heat to a simmer and let the soup cook for about 25-30 minutes, or until the vegetables are tender.\\n5. Add the chopped kale and cook for an additional 5 minutes.\\n6. Taste and adjust the seasoning if needed.\\n7. Serve hot as a comforting and nourishing meal.\\n\\nEnjoy your delicious vegan vegetable broth soup! \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)" + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write brief meal recipe with broth', type='TextMessage'), MemoryQueryEvent(source='assistant_agent', models_usage=None, content=[MemoryContent(content='The weather should be in metric units', mime_type=, metadata=None, timestamp=None, source=None, score=None), MemoryContent(content='Meal recipe must be vegan', mime_type=, metadata=None, timestamp=None, source=None, score=None)], type='MemoryQueryEvent'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=208, completion_tokens=253), content=\"Here's a brief vegan meal recipe using broth:\\n\\n**Vegan Mushroom & Herb Broth Soup**\\n\\n**Ingredients:**\\n- 1 tablespoon olive oil\\n- 1 onion, diced\\n- 2 cloves garlic, minced\\n- 250g mushrooms, sliced\\n- 1 carrot, diced\\n- 1 celery stalk, diced\\n- 4 cups vegetable broth\\n- 1 teaspoon thyme\\n- 1 teaspoon rosemary\\n- Salt and pepper to taste\\n- Fresh parsley for garnish\\n\\n**Instructions:**\\n1. Heat the olive oil in a large pot over medium heat. Add the diced onion and garlic, and sauté until the onion becomes translucent.\\n\\n2. Add the sliced mushrooms, carrot, and celery. Continue to sauté until the mushrooms are cooked through and the vegetables begin to soften, about 5 minutes.\\n\\n3. Pour in the vegetable broth. Stir in the thyme and rosemary, and bring the mixture to a boil.\\n\\n4. Reduce the heat to low and let the soup simmer for about 15 minutes, allowing the flavors to meld together.\\n\\n5. Season with salt and pepper to taste.\\n\\n6. Serve hot, garnished with fresh parsley.\\n\\nEnjoy your warm and comforting vegan mushroom & herb broth soup! \\n\\nTERMINATE\", type='TextMessage')], stop_reason=None)" ] }, "execution_count": 5, diff --git a/python/packages/autogen-core/src/autogen_core/memory/__init__.py b/python/packages/autogen-core/src/autogen_core/memory/__init__.py index 7e36884ecb0a..69a20f24f530 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/memory/__init__.py @@ -1,9 +1,11 @@ -from ._base_memory import Memory, MemoryContent, MemoryMimeType +from ._base_memory import Memory, MemoryContent, MemoryMimeType, MemoryQueryResult, UpdateContextResult from ._list_memory import ListMemory __all__ = [ "Memory", "MemoryContent", + "MemoryQueryResult", + "UpdateContextResult", "MemoryMimeType", "ListMemory", ] diff --git a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py index 5c2681b992fe..0ef7ac716ccf 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py @@ -33,6 +33,14 @@ class MemoryContent(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) +class MemoryQueryResult(BaseModel): + results: List[MemoryContent] + + +class UpdateContextResult(BaseModel): + memories: MemoryQueryResult + + @runtime_checkable class Memory(Protocol): """Protocol defining the interface for memory implementations.""" @@ -45,7 +53,7 @@ def name(self) -> str | None: async def update_context( self, model_context: ChatCompletionContext, - ) -> List[MemoryContent]: + ) -> UpdateContextResult: """ Update the provided model context using relevant memory content. @@ -53,7 +61,7 @@ async def update_context( model_context: The context to update. Returns: - List of memory entries with relevance scores + UpdateContextResult containing relevant memories """ ... @@ -62,7 +70,7 @@ async def query( query: str | MemoryContent, cancellation_token: CancellationToken | None = None, **kwargs: Any, - ) -> List[MemoryContent]: + ) -> MemoryQueryResult: """ Query the memory store and return relevant entries. @@ -72,7 +80,7 @@ async def query( **kwargs: Additional implementation-specific parameters Returns: - List of memory entries with relevance scores + MemoryQueryResult containing memory entries with relevance scores """ ... diff --git a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py index b4d3e19016f7..eda206783476 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py @@ -3,7 +3,7 @@ from .._cancellation_token import CancellationToken from ..model_context import ChatCompletionContext from ..models import SystemMessage -from ._base_memory import Memory, MemoryContent +from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult class ListMemory(Memory): @@ -76,7 +76,7 @@ def content(self, value: List[MemoryContent]) -> None: async def update_context( self, model_context: ChatCompletionContext, - ) -> List[MemoryContent]: + ) -> UpdateContextResult: """Update the model context by appending memory content. This method mutates the provided model_context by adding all memories as a @@ -86,10 +86,11 @@ async def update_context( model_context: The context to update. Will be mutated if memories exist. Returns: - List[MemoryContent]: List of memories that were added to the context + UpdateContextResult containing the memories that were added to the context """ + if not self._contents: - return [] + return UpdateContextResult(memories=MemoryQueryResult(results=[])) memory_strings = [f"{i}. {str(memory.content)}" for i, memory in enumerate(self._contents, 1)] @@ -97,14 +98,14 @@ async def update_context( memory_context = "\nRelevant memory content (in chronological order):\n" + "\n".join(memory_strings) + "\n" await model_context.add_message(SystemMessage(content=memory_context)) - return self._contents + return UpdateContextResult(memories=MemoryQueryResult(results=self._contents)) async def query( self, query: str | MemoryContent = "", cancellation_token: CancellationToken | None = None, **kwargs: Any, - ) -> List[MemoryContent]: + ) -> MemoryQueryResult: """Return all memories without any filtering. Args: @@ -113,10 +114,10 @@ async def query( **kwargs: Additional parameters (ignored) Returns: - List[MemoryContent]: All stored memories + MemoryQueryResult containing all stored memories """ _ = query, cancellation_token, kwargs - return self._contents + return MemoryQueryResult(results=self._contents) async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: """Add new content to memory. diff --git a/python/packages/autogen-core/tests/test_memory.py b/python/packages/autogen-core/tests/test_memory.py index 7464fd982657..41c0d6657fa9 100644 --- a/python/packages/autogen-core/tests/test_memory.py +++ b/python/packages/autogen-core/tests/test_memory.py @@ -1,14 +1,21 @@ from datetime import datetime -from typing import List import pytest from autogen_core import CancellationToken -from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType +from autogen_core.memory import ( + ListMemory, + Memory, + MemoryContent, + MemoryMimeType, + MemoryQueryResult, + UpdateContextResult, +) from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext def test_memory_protocol_attributes() -> None: """Test that Memory protocol has all required attributes.""" + # No changes needed here assert hasattr(Memory, "name") assert hasattr(Memory, "update_context") assert hasattr(Memory, "query") @@ -25,13 +32,13 @@ class ValidMemory: def name(self) -> str: return "test" - async def update_context(self, context: ChatCompletionContext) -> List[MemoryContent]: - return [] + async def update_context(self, context: ChatCompletionContext) -> UpdateContextResult: + return UpdateContextResult(memories=MemoryQueryResult(results=[])) async def query( self, query: MemoryContent, cancellation_token: CancellationToken | None = None - ) -> List[MemoryContent]: - return [] + ) -> MemoryQueryResult: + return MemoryQueryResult(results=[]) async def add(self, content: MemoryContent, cancellation_token: CancellationToken | None = None) -> None: pass @@ -49,15 +56,6 @@ class InvalidMemory: assert not isinstance(InvalidMemory(), Memory) -def test_list_memory_basic_properties() -> None: - """Test basic properties of ListMemory.""" - memory = ListMemory( - name="test_memory", - ) - assert memory.name == "test_memory" - assert isinstance(memory, Memory) - - @pytest.mark.asyncio async def test_list_memory_empty() -> None: """Test ListMemory behavior when empty.""" @@ -66,11 +64,11 @@ async def test_list_memory_empty() -> None: results = await memory.update_context(context) context_messages = await context.get_messages() - assert len(results) == 0 + assert len(results.memories.results) == 0 assert len(context_messages) == 0 query_results = await memory.query(MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)) - assert len(query_results) == 0 + assert len(query_results.results) == 0 @pytest.mark.asyncio @@ -85,9 +83,9 @@ async def test_list_memory_add_and_query() -> None: await memory.add(content2) results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) - assert len(results) == 2 - assert results[0].content == "test1" - assert results[1].content == {"key": "value"} + assert len(results.results) == 2 + assert results.results[0].content == "test1" + assert results.results[1].content == {"key": "value"} @pytest.mark.asyncio @@ -99,7 +97,7 @@ async def test_list_memory_max_memories() -> None: await memory.add(MemoryContent(content=f"test{i}", mime_type=MemoryMimeType.TEXT)) results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) - assert len(results) == 5 + assert len(results.results) == 5 @pytest.mark.asyncio @@ -113,7 +111,7 @@ async def test_list_memory_update_context() -> None: results = await memory.update_context(context) context_messages = await context.get_messages() - assert len(results) == 2 + assert len(results.memories.results) == 2 assert len(context_messages) == 1 assert "test1" in context_messages[0].content assert "test2" in context_messages[0].content @@ -127,7 +125,7 @@ async def test_list_memory_clear() -> None: await memory.clear() results = await memory.query(MemoryContent(content="query", mime_type=MemoryMimeType.TEXT)) - assert len(results) == 0 + assert len(results.results) == 0 @pytest.mark.asyncio @@ -143,7 +141,7 @@ async def test_list_memory_content_types() -> None: await memory.add(binary_content) results = await memory.query(text_content) - assert len(results) == 3 - assert isinstance(results[0].content, str) - assert isinstance(results[1].content, dict) - assert isinstance(results[2].content, bytes) + assert len(results.results) == 3 + assert isinstance(results.results[0].content, str) + assert isinstance(results.results[1].content, dict) + assert isinstance(results.results[2].content, bytes) diff --git a/python/packages/autogen-ext/test_filesurfer_agent.html b/python/packages/autogen-ext/test_filesurfer_agent.html new file mode 100644 index 000000000000..8243435009e5 --- /dev/null +++ b/python/packages/autogen-ext/test_filesurfer_agent.html @@ -0,0 +1,9 @@ + + + FileSurfer test file + + +

FileSurfer test H1

+

FileSurfer test body

+ + \ No newline at end of file From cb3b05149d65ffe19beb44f53c9126e5ba99d641 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Mon, 13 Jan 2025 23:01:50 -0800 Subject: [PATCH 26/26] format update --- .../autogen-agentchat/src/autogen_agentchat/messages.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 4e1ef45eb852..6069c8ddc8dd 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -126,15 +126,14 @@ class MemoryQueryEvent(BaseAgentEvent): ChatMessage = Annotated[ - TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field( - discriminator="type") + TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] """Messages for agent-to-agent communication only.""" AgentEvent = Annotated[ - ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent, Field( - discriminator="type") + ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent, + Field(discriminator="type"), ] """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""