-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add crew tests * add simple task tests * add tests for Task class * add task unit tests * add autogen tool conversation test * add test agent chain * add test tool chain * Use different react langchain agent in tests for speed --------- Co-authored-by: User <[email protected]> Co-authored-by: whimo <[email protected]>
- Loading branch information
1 parent
ccd2bc6
commit c796987
Showing
11 changed files
with
275 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import os | ||
|
||
os.environ["OPENAI_API_KEY"] = "YOUR OPENAI API KEY" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import os | ||
|
||
os.environ["OPENAI_API_KEY"] = "YOUR OPENAI API KEY" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import pytest | ||
|
||
from motleycrew.agents.parent import MotleyAgentParent | ||
|
||
|
||
class AgentMock: | ||
def invoke(self, input_dict: dict, *args, **kwargs): | ||
return input_dict | ||
|
||
|
||
class MotleyAgentMock(MotleyAgentParent): | ||
|
||
def invoke(self, *args, **kwargs): | ||
self.materialize() | ||
return self.agent.invoke(*args, **kwargs) | ||
|
||
|
||
def agent_factory(*args, **kwargs): | ||
return AgentMock() | ||
|
||
|
||
@pytest.fixture | ||
def motley_agents(): | ||
agent1 = MotleyAgentMock("agent1 description", agent_factory=agent_factory) | ||
agent2 = MotleyAgentMock("agent2 description", agent_factory=agent_factory) | ||
return [agent1, agent2] | ||
|
||
|
||
def test_agent_chain(motley_agents): | ||
agent1, agent2 = motley_agents | ||
agent_chain = agent1 | agent2 | ||
assert hasattr(agent_chain, "invoke") | ||
prompt = {"prompt": "test_prompt"} | ||
assert agent_chain.invoke(prompt) == prompt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import pytest | ||
|
||
from motleycrew.crew import MotleyCrew | ||
from motleycrew.tasks.simple import SimpleTask, SimpleTaskUnit | ||
|
||
|
||
class AgentMock: | ||
def invoke(self, input_dict) -> str: | ||
clear_dict = self.clear_input_dict(input_dict) | ||
return str(clear_dict) | ||
|
||
@staticmethod | ||
def clear_input_dict(input_dict: dict) -> dict: | ||
clear_dict = {} | ||
for param in ("name", "prompt"): | ||
value = input_dict.get(param, None) | ||
if value is not None: | ||
clear_dict[param] = value | ||
return clear_dict | ||
|
||
|
||
class TestCrew: | ||
num_task = 0 | ||
|
||
@pytest.fixture(scope="class") | ||
def crew(self): | ||
obj = MotleyCrew() | ||
return obj | ||
|
||
@pytest.fixture | ||
def agent(self): | ||
return AgentMock() | ||
|
||
@pytest.fixture | ||
def tasks(self, request, crew, agent): | ||
num_tasks = request.param or 1 | ||
tasks = [] | ||
for i in range(num_tasks): | ||
description = "task{} description".format(self.num_task) | ||
tasks.append(crew.create_simple_task(description=description, agent=agent)) | ||
TestCrew.num_task += 1 | ||
return tasks | ||
|
||
def test_create_simple_task(self, crew, agent): | ||
assert len(crew.tasks) == 0 | ||
simple_task = crew.create_simple_task(description="task description", agent=agent) | ||
assert isinstance(simple_task, SimpleTask) | ||
assert len(crew.tasks) == 1 | ||
node = simple_task.node | ||
assert crew.graph_store.get_node_by_class_and_id(type(node), node.id) == node | ||
|
||
@pytest.mark.parametrize("tasks", [2], indirect=True) | ||
def test_add_dependency(self, crew, tasks): | ||
task1, task2 = tasks | ||
crew.add_dependency(task1, task2) | ||
assert crew.graph_store.check_relation_exists(task1.node, task2.node) | ||
|
||
@pytest.mark.parametrize("tasks", [1], indirect=True) | ||
def test_register_added_task(self, tasks, crew): | ||
task = tasks[0] | ||
len_tasks = len(crew.tasks) | ||
crew.register_tasks([task]) | ||
assert len(crew.tasks) == len_tasks | ||
|
||
def test_get_available_task(self, crew): | ||
tasks = crew.get_available_tasks() | ||
assert len(tasks) == 3 | ||
|
||
def test_get_extra_tools(self, crew): | ||
tasks = crew.get_available_tasks() | ||
assert not crew.get_extra_tools(tasks[-1]) | ||
|
||
def test_run(self, crew, agent): | ||
available_tasks = crew.get_available_tasks() | ||
crew.run() | ||
for task in crew.tasks: | ||
assert task.done | ||
assert task.node.done | ||
unit = SimpleTaskUnit( | ||
name=task.name, | ||
prompt=task.description, | ||
) | ||
if task in available_tasks: | ||
assert agent.invoke(unit.as_dict()) == task.output | ||
else: | ||
assert agent.invoke(unit.as_dict()) != task.output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import pytest | ||
|
||
from langchain_community.tools import DuckDuckGoSearchRun | ||
from motleycrew.crew import MotleyCrew | ||
from motleycrew.agents.langchain.openai_tools_react import ReactOpenAIToolsAgent | ||
from motleycrew.tasks.simple import ( | ||
SimpleTask, | ||
SimpleTaskUnit, | ||
compose_simple_task_prompt_with_dependencies, | ||
) | ||
|
||
|
||
class TestSimpleTask: | ||
@pytest.fixture(scope="class") | ||
def crew(self): | ||
obj = MotleyCrew() | ||
return obj | ||
|
||
@pytest.fixture(scope="class") | ||
def agent(self): | ||
agent = ReactOpenAIToolsAgent( | ||
name="AI writer agent", | ||
tools=[DuckDuckGoSearchRun()], | ||
verbose=True, | ||
) | ||
return agent | ||
|
||
@pytest.fixture(scope="class") | ||
def tasks(self, crew, agent): | ||
task1 = SimpleTask(crew=crew, description="task1 description", agent=agent) | ||
task2 = SimpleTask(crew=crew, description="task2 description") | ||
crew.register_tasks([task1, task2]) | ||
return [task1, task2] | ||
|
||
def test_register_completed_unit(self, tasks, crew): | ||
task1, task2 = tasks | ||
assert not task1.done | ||
assert task1.output is None | ||
unit = task1.get_next_unit() | ||
unit.output = task1.description | ||
|
||
with pytest.raises(AssertionError): | ||
task1.register_completed_unit(unit) | ||
unit.set_done() | ||
task1.register_completed_unit(unit) | ||
assert task1.done | ||
assert task1.output == unit.output | ||
assert task1.node.done | ||
|
||
def test_get_next_unit(self, tasks, crew): | ||
task1, task2 = tasks | ||
crew.add_dependency(task1, task2) | ||
assert task1.get_next_unit() is None | ||
prompt = compose_simple_task_prompt_with_dependencies(task2.description, task2.get_units()) | ||
expected_unit = SimpleTaskUnit( | ||
name=task2.name, | ||
prompt=prompt, | ||
) | ||
next_unit = task2.get_next_unit() | ||
assert next_unit.prompt == expected_unit.prompt | ||
|
||
def test_get_worker(self, tasks, agent): | ||
task1, task2 = tasks | ||
assert task1.get_worker([]) == agent | ||
with pytest.raises(ValueError): | ||
task2.get_worker([]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pytest | ||
|
||
from motleycrew.tasks import TaskUnit | ||
|
||
|
||
class TestTaskUnit: | ||
|
||
@pytest.fixture(scope="class") | ||
def unit(self): | ||
return TaskUnit() | ||
|
||
def test_set_pending(self, unit): | ||
unit.set_pending() | ||
assert unit.pending | ||
|
||
def test_set_running(self, unit): | ||
unit.set_running() | ||
assert unit.running | ||
|
||
def test_set_done(self, unit): | ||
unit.set_done() | ||
assert unit.done | ||
|
||
def test_as_dict(self, unit): | ||
assert dict(unit) == unit.as_dict() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import pytest | ||
|
||
from motleycrew.tools import MotleyTool | ||
|
||
|
||
class ToolMock: | ||
def invoke(self, input_dict: dict, *args, **kwargs): | ||
return input_dict | ||
|
||
|
||
@pytest.fixture | ||
def tools(): | ||
tool1 = MotleyTool(ToolMock()) | ||
tool2 = MotleyTool(ToolMock()) | ||
return [tool1, tool2] | ||
|
||
|
||
def test_tool_chain(tools): | ||
tool1, tool2 = tools | ||
tool_chain = tool1 | tool2 | ||
assert hasattr(tool_chain, "invoke") | ||
prompt = {"prompt": "test prompt"} | ||
assert tool_chain.invoke(prompt) == prompt |