Skip to content

Commit

Permalink
add InvokeThreadPool tests
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed Jun 3, 2024
1 parent e8b2932 commit efa7d94
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 38 deletions.
44 changes: 44 additions & 0 deletions tests/test_crew/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from motleycrew.crew import MotleyCrew


class AgentMock:
def invoke(self, input_dict) -> str:
clear_dict = self.clear_input_dict(input_dict)
return str(clear_dict)

async def ainvoke(self, input_dict) -> str:
return self.invoke(input_dict)

@staticmethod
def clear_input_dict(input_dict: dict) -> dict:
clear_dict = {}
for param in ("name", "prompt"):
value = input_dict.get(param, None)
if value is not None:
clear_dict[param] = value
return clear_dict


class CrewFixtures:
num_task = 0

@pytest.fixture(scope="class")
def crew(self):
obj = MotleyCrew()
return obj

@pytest.fixture
def agent(self):
return AgentMock()

@pytest.fixture
def tasks(self, request, crew, agent):
num_tasks = request.param or 1
tasks = []
for i in range(num_tasks):
description = "task{} description".format(self.num_task)
tasks.append(crew.create_simple_task(description=description, agent=agent))
CrewFixtures.num_task += 1
return tasks
40 changes: 2 additions & 38 deletions tests/test_agents/test_crew.py → tests/test_crew/test_crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,10 @@
from motleycrew.crew import MotleyCrew
from motleycrew.tasks.simple import SimpleTask, SimpleTaskUnit

from tests.test_crew import CrewFixtures

class AgentMock:
def invoke(self, input_dict) -> str:
clear_dict = self.clear_input_dict(input_dict)
return str(clear_dict)

async def ainvoke(self, input_dict) -> str:
return self.invoke(input_dict)

@staticmethod
def clear_input_dict(input_dict: dict) -> dict:
clear_dict = {}
for param in ("name", "prompt"):
value = input_dict.get(param, None)
if value is not None:
clear_dict[param] = value
return clear_dict


class TestCrew:
num_task = 0

@pytest.fixture(scope="class")
def crew(self):
obj = MotleyCrew()
return obj

@pytest.fixture
def agent(self):
return AgentMock()

@pytest.fixture
def tasks(self, request, crew, agent):
num_tasks = request.param or 1
tasks = []
for i in range(num_tasks):
description = "task{} description".format(self.num_task)
tasks.append(crew.create_simple_task(description=description, agent=agent))
TestCrew.num_task += 1
return tasks
class TestCrew(CrewFixtures):

def test_create_simple_task(self, crew, agent):
assert len(crew.tasks) == 0
Expand Down
60 changes: 60 additions & 0 deletions tests/test_crew/test_crew_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest
import time

from motleycrew.crew.crew_threads import InvokeThreadPool, InvokeThreadState
from motleycrew.common import Defaults
from tests.test_crew import CrewFixtures


class TestInvokeThreadPool(CrewFixtures):

@pytest.fixture(scope="class")
def thread_pool(self):
obj = InvokeThreadPool()
return obj

def test_init_thread_pool(self, thread_pool):

assert len(thread_pool._threads) == Defaults.DEFAULT_NUM_THREADS
assert all([t.is_alive() for t in thread_pool._threads])
assert thread_pool.input_queue.empty()
assert thread_pool.output_queue.empty()
assert thread_pool.is_completed()

@pytest.mark.parametrize("tasks", [4], indirect=True)
def test_put(self, thread_pool, agent, tasks):
for task in tasks:
unit = task.get_next_unit()
thread_pool.put(agent, task, unit)

assert not thread_pool.is_completed()
assert len(thread_pool._in_process_tasks) == 4

def test_get_completed_tasks(self, thread_pool):
time.sleep(3)
completed_tasks = thread_pool.get_completed_tasks()
assert len(completed_tasks) == 4
assert len(thread_pool._in_process_tasks) == 0
assert thread_pool.is_completed()
assert all([t.state == InvokeThreadState.WAITING for t in thread_pool._threads])

@pytest.mark.parametrize("tasks", [1], indirect=True)
def test_get_completed_task_exception(self, thread_pool, agent, tasks):
for task in tasks:
thread_pool.put(agent, task, None)
time.sleep(1)

with pytest.raises(AttributeError):
thread_pool.get_completed_tasks()

assert not thread_pool.is_completed()

def test_close(self, thread_pool):
thread_pool.close()
time.sleep(3)
assert all([not t.is_alive() for t in thread_pool._threads])
assert all([t.state == InvokeThreadState.STOP for t in thread_pool._threads])

def test_is_completed(self, thread_pool):
assert len(thread_pool._in_process_tasks) == 1
assert not thread_pool.is_completed()

0 comments on commit efa7d94

Please sign in to comment.