Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add autouse memory #1027

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ env = [
'D:MARVIN_AGENT_MODEL=openai:gpt-4o-mini',
'D:MARVIN_AGENT_TEMPERATURE=0.0',
'D:MARVIN_AGENT_RETRIES=3',
'D:MARVIN_MEMORY_PROVIDER=chroma-ephemeral',
'D:MARVIN_LOG_LEVEL=DEBUG',
'D:MARVIN_ENABLE_DEFAULT_PRINT_HANDLER=0',
]
Expand Down
5 changes: 5 additions & 0 deletions src/marvin/agents/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import marvin
import marvin.utilities.asyncio
from marvin.memory.memory import Memory
from marvin.prompts import Template
from marvin.thread import Thread

Expand Down Expand Up @@ -67,6 +68,10 @@ def get_tools(self) -> list[Callable[..., Any]]:
"""A list of tools that this actor can use during its turn."""
return []

def get_memories(self) -> list[Memory]:
"""A list of memories that this actor can use during its turn."""
return []

def get_end_turn_tools(self) -> list[type["marvin.engine.end_turn.EndTurn"]]:
"""A list of `EndTurn` tools that this actor can use to end its turn."""
return []
Expand Down
3 changes: 3 additions & 0 deletions src/marvin/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def get_delegates(self) -> list[Actor] | None:
def get_model(self) -> Model | KnownModelName:
return self.model or marvin.defaults.model

def get_memories(self) -> list[Memory]:
return self.memories

def get_tools(self) -> list[Callable[..., Any]]:
return (
self.tools
Expand Down
10 changes: 6 additions & 4 deletions src/marvin/agents/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from marvin.agents.actor import Actor
from marvin.agents.names import TEAM_NAMES
from marvin.engine.end_turn import DelegateToAgent
from marvin.memory.memory import Memory
from marvin.prompts import Template

if TYPE_CHECKING:
Expand All @@ -19,7 +20,6 @@
@dataclass(kw_only=True)
class Team(Actor):
agents: list[Actor]
tools: list[Callable[..., Any]] = field(default_factory=list)
name: str = field(
default_factory=lambda: random.choice(TEAM_NAMES),
metadata={"description": "Name of the team"},
Expand Down Expand Up @@ -63,7 +63,7 @@ def get_agentlet(
**kwargs,
) -> pydantic_ai.Agent[Any, Any]:
return self.active_agent.get_agentlet(
tools=self.tools + self.get_end_turn_tools() + (tools or []),
tools=self.get_end_turn_tools() + (tools or []),
result_types=result_types,
**kwargs,
)
Expand All @@ -72,8 +72,10 @@ def get_prompt(self) -> str:
return Template(source=self.prompt).render(team=self)

def get_tools(self) -> list[Callable[..., Any]]:
tools = self.tools + self.active_agent.get_tools()
return tools
return self.active_agent.get_tools()

def get_memories(self) -> list[Memory]:
return self.active_agent.get_memories()

def get_end_turn_tools(self) -> list[type["EndTurn"]]:
return []
Expand Down
27 changes: 23 additions & 4 deletions src/marvin/engine/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import json
import math
from asyncio import CancelledError
from collections.abc import Callable
Expand Down Expand Up @@ -34,9 +35,10 @@
from marvin.engine.handlers import AsyncHandler, Handler
from marvin.engine.print_handler import PrintHandler
from marvin.instructions import get_instructions
from marvin.memory.memory import Memory
from marvin.prompts import Template
from marvin.tasks.task import Task
from marvin.thread import Thread, get_thread
from marvin.thread import Thread, get_thread, message_adapter
from marvin.utilities.logging import get_logger

T = TypeVar("T")
Expand All @@ -58,6 +60,7 @@ class OrchestratorPrompt(Template):
tasks: list[Task[Any]]
instructions: list[str]
end_turn_tools: list[EndTurn]
memories: list[str]


@dataclass(kw_only=True)
Expand Down Expand Up @@ -126,24 +129,40 @@ async def _run_turn(self):

# --- get end turn tools
end_turn_tools = set()

for t in tasks:
end_turn_tools.update(t.get_end_turn_tools())

if self.get_delegates():
end_turn_tools.add(DelegateToAgent)
end_turn_tools.update(self.team.get_end_turn_tools())
end_turn_tools = list(end_turn_tools)

# --- get memories
memories: set[Memory] = set()
for t in tasks:
memories.update(t.memories)
memories.update(self.team.get_memories())
memories = [m for m in memories if m.auto_use]

# --- prepare messages
messages = await self.thread.get_messages_async()

# load auto-use memories
if memories and messages:
query = "\n\n".join(
message_adapter.dump_json(m).decode() for m in messages[-3:]
)
memories = [
json.dumps({m.key: await m.search(query=query, n=3)}) for m in memories
]

orchestrator_prompt = OrchestratorPrompt(
orchestrator=self,
tasks=self.get_all_tasks(),
instructions=get_instructions(),
end_turn_tools=end_turn_tools,
memories=memories,
).render()

messages = await self.thread.get_messages_async()
all_messages = [
marvin.engine.llm.SystemMessage(content=orchestrator_prompt),
] + messages
Expand Down
8 changes: 7 additions & 1 deletion src/marvin/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class Memory:
default_factory=lambda: marvin.defaults.memory_provider,
repr=False,
)
auto_use: bool = field(
default=False,
metadata={
"description": "If true, the memory will automatically be queried before the agent is run, using the most recent messages.",
},
)

def __hash__(self) -> int:
return id(self)
Expand Down Expand Up @@ -119,7 +125,7 @@ def get_tools(self) -> list[Callable[..., Any]]:
update_fn(
self.search,
name=f"search_memories__{self.key}",
description=f"Search {self.friendly_name()}. {self.instructions or ''}".rstrip(),
description=f"Provide a query string to search {self.friendly_name()}. {self.instructions or ''}".rstrip(),
),
]

Expand Down
2 changes: 1 addition & 1 deletion src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def setup_logging(self) -> Self:
# ------------ Memory settings ------------

memory_provider: str = Field(
default="chroma-db",
default="chroma-ephemeral",
description="The default memory provider for agents.",
)

Expand Down
9 changes: 9 additions & 0 deletions src/marvin/templates/orchestrator.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,13 @@
All teams and agents.
{{ pretty_print(orchestrator.get_agent_tree()) | indent(8) }}
</agent-tree>

{% if memories %}
<memories>
The following memories were automatically recalled:
{% for memory in memories %}
<memory> {{ memory | indent(8) }} </memory>
{% endfor %}
</memories>
{% endif %}
</orchestrator>
60 changes: 56 additions & 4 deletions tests/ai/memory/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

def test_agent_memory():
a = marvin.Agent(memories=[marvin.Memory(key="numbers")])
a.say("remember the number 123")
result = a.say("what number did you remember?")
marvin.run("remember the number 123", agents=[a])
result = marvin.run("what number did you remember?", agents=[a])
assert "123" in result


Expand All @@ -26,6 +26,58 @@ def test_instructions():
instructions="when remembering a color, always store it and the word 'house' e.g. 'red house'",
)
a = marvin.Agent(memories=[m])
a.say("remember the color green")
result = a.say("what exactly did you remember?")
marvin.run("remember the color green", agents=[a])
result = marvin.run("what exactly did you remember?", agents=[a])
assert "green" in result and "house" in result


def test_use_memory_as_tool():
m = marvin.Memory(key="colors")
a = marvin.Agent(memories=[m])
marvin.run("remember the color green", agents=[a])
with marvin.Thread() as t:
result = marvin.run("what color did you remember?", agents=[a])

assert "green" in result

# --- check tool call ---
messages = t.get_messages()
found_tool_call = False
for message in messages:
for part in message.parts:
if (
part.part_kind == "tool-call"
and part.tool_name == "search_memories__colors"
):
found_tool_call = True
break
if found_tool_call:
break
assert found_tool_call, "Expected to find a tool call to search_memories__colors"


def test_autouse_memory():
m = marvin.Memory(key="colors", auto_use=True)
a = marvin.Agent(memories=[m])
marvin.run("remember the color green", agents=[a])
with marvin.Thread() as t:
result = marvin.run("what color did you remember?", agents=[a])

assert "green" in result

# --- check tool call did NOT happen---
messages = t.get_messages()
found_tool_call = False
for message in messages:
for part in message.parts:
if (
part.part_kind == "tool-call"
and part.tool_name == "search_memories__colors"
):
found_tool_call = True
break
if found_tool_call:
break
assert not found_tool_call, (
"Expected to not find a tool call to search_memories__colors since it is auto-used"
)
22 changes: 11 additions & 11 deletions tests/basic/memory/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@


class TestMemory:
def test_store_and_retrieve(self):
async def test_store_and_retrieve(self):
m = marvin.Memory(key="test", instructions="test")
m.add("The number is 42")
result = m.search("numbers")
await m.add("The number is 42")
result = await m.search("numbers")
assert len(result) == 1
assert "The number is 42" in result.values()

def test_delete(self):
async def test_delete(self):
m = marvin.Memory(key="test", instructions="test")
m_id = m.add("The number is 42")
m.delete(m_id)
result = m.search("numbers")
m_id = await m.add("The number is 42")
await m.delete(m_id)
result = await m.search("numbers")
assert len(result) == 0

def test_search(self):
async def test_search(self):
m = marvin.Memory(key="test", instructions="test")
m.add("The number is 42")
m.add("The number is 43")
result = m.search("numbers")
await m.add("The number is 42")
await m.add("The number is 43")
result = await m.search("numbers")
assert len(result) == 2
assert "The number is 42" in result.values()
assert "The number is 43" in result.values()
Expand Down
Loading