-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
User
committed
Jun 3, 2024
1 parent
e8b2932
commit efa7d94
Showing
3 changed files
with
106 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
|
||
from motleycrew.crew import MotleyCrew | ||
|
||
|
||
class AgentMock: | ||
def invoke(self, input_dict) -> str: | ||
clear_dict = self.clear_input_dict(input_dict) | ||
return str(clear_dict) | ||
|
||
async def ainvoke(self, input_dict) -> str: | ||
return self.invoke(input_dict) | ||
|
||
@staticmethod | ||
def clear_input_dict(input_dict: dict) -> dict: | ||
clear_dict = {} | ||
for param in ("name", "prompt"): | ||
value = input_dict.get(param, None) | ||
if value is not None: | ||
clear_dict[param] = value | ||
return clear_dict | ||
|
||
|
||
class CrewFixtures: | ||
num_task = 0 | ||
|
||
@pytest.fixture(scope="class") | ||
def crew(self): | ||
obj = MotleyCrew() | ||
return obj | ||
|
||
@pytest.fixture | ||
def agent(self): | ||
return AgentMock() | ||
|
||
@pytest.fixture | ||
def tasks(self, request, crew, agent): | ||
num_tasks = request.param or 1 | ||
tasks = [] | ||
for i in range(num_tasks): | ||
description = "task{} description".format(self.num_task) | ||
tasks.append(crew.create_simple_task(description=description, agent=agent)) | ||
CrewFixtures.num_task += 1 | ||
return tasks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import pytest | ||
import time | ||
|
||
from motleycrew.crew.crew_threads import InvokeThreadPool, InvokeThreadState | ||
from motleycrew.common import Defaults | ||
from tests.test_crew import CrewFixtures | ||
|
||
|
||
class TestInvokeThreadPool(CrewFixtures): | ||
|
||
@pytest.fixture(scope="class") | ||
def thread_pool(self): | ||
obj = InvokeThreadPool() | ||
return obj | ||
|
||
def test_init_thread_pool(self, thread_pool): | ||
|
||
assert len(thread_pool._threads) == Defaults.DEFAULT_NUM_THREADS | ||
assert all([t.is_alive() for t in thread_pool._threads]) | ||
assert thread_pool.input_queue.empty() | ||
assert thread_pool.output_queue.empty() | ||
assert thread_pool.is_completed() | ||
|
||
@pytest.mark.parametrize("tasks", [4], indirect=True) | ||
def test_put(self, thread_pool, agent, tasks): | ||
for task in tasks: | ||
unit = task.get_next_unit() | ||
thread_pool.put(agent, task, unit) | ||
|
||
assert not thread_pool.is_completed() | ||
assert len(thread_pool._in_process_tasks) == 4 | ||
|
||
def test_get_completed_tasks(self, thread_pool): | ||
time.sleep(3) | ||
completed_tasks = thread_pool.get_completed_tasks() | ||
assert len(completed_tasks) == 4 | ||
assert len(thread_pool._in_process_tasks) == 0 | ||
assert thread_pool.is_completed() | ||
assert all([t.state == InvokeThreadState.WAITING for t in thread_pool._threads]) | ||
|
||
@pytest.mark.parametrize("tasks", [1], indirect=True) | ||
def test_get_completed_task_exception(self, thread_pool, agent, tasks): | ||
for task in tasks: | ||
thread_pool.put(agent, task, None) | ||
time.sleep(1) | ||
|
||
with pytest.raises(AttributeError): | ||
thread_pool.get_completed_tasks() | ||
|
||
assert not thread_pool.is_completed() | ||
|
||
def test_close(self, thread_pool): | ||
thread_pool.close() | ||
time.sleep(3) | ||
assert all([not t.is_alive() for t in thread_pool._threads]) | ||
assert all([t.state == InvokeThreadState.STOP for t in thread_pool._threads]) | ||
|
||
def test_is_completed(self, thread_pool): | ||
assert len(thread_pool._in_process_tasks) == 1 | ||
assert not thread_pool.is_completed() |