Skip to content

Commit

Permalink
add docstrings for tasks package
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed May 29, 2024
1 parent 0413cb7 commit 59b0d73
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 0 deletions.
57 changes: 57 additions & 0 deletions motleycrew/tasks/simple.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
""" Module description
Attributes:
PROMPT_TEMPLATE_WITH_DEPS (str):
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, List, Optional

Expand All @@ -24,6 +30,16 @@
def compose_simple_task_prompt_with_dependencies(
description: str, upstream_task_units: List[TaskUnit], default_task_name: str = "Unnamed task"
) -> str:
""" Description
Args:
description (str):
upstream_task_units (:obj:`list` of :obj:`TaskUnit`):
default_task_name (:obj:`str`, optional):
Returns:
str:
"""
upstream_results = []
for unit in upstream_task_units:
if not unit.output:
Expand All @@ -43,6 +59,14 @@ def compose_simple_task_prompt_with_dependencies(


class SimpleTaskUnit(TaskUnit):
""" Description
Attributes:
name (str):
prompt (str):
message_history (:obj:`list` of :obj:`str`):
"""
name: str
prompt: str
message_history: List[str] = []
Expand All @@ -60,6 +84,18 @@ def __init__(
creator_name: str | None = None,
return_to_creator: bool = False,
):
""" Description
Args:
crew (MotleyCrew):
description (str):
name (:obj:`str`, optional):
agent (:obj:`MotleyAgentAbstractParent`, optional):
tools (:obj:`Sequence[MotleyTool]`, optional):
documents (:obj:`Sequence[Any]`, optional):
creator_name (:obj:`str`, optional):
return_to_creator (:obj:`bool`, optional):
"""
super().__init__(name=name or description, task_unit_class=SimpleTaskUnit, crew=crew)
self.description = description
self.agent = agent # to be auto-assigned at crew creation if missing?
Expand All @@ -73,13 +109,26 @@ def __init__(
self.output = None # to be filled in by the agent(s) once the task is complete

def register_completed_unit(self, unit: SimpleTaskUnit) -> None:
""" Description
Args:
unit (SimpleTaskUnit):
Returns:
"""
assert isinstance(unit, SimpleTaskUnit)
assert unit.done

self.output = unit.output
self.set_done()

def get_next_unit(self) -> SimpleTaskUnit | None:
""" Description
Returns:
:obj:`SimpleTaskUnit`, None:
"""
if self.done:
logger.info("Task %s is already done", self)
return None
Expand All @@ -96,6 +145,14 @@ def get_next_unit(self) -> SimpleTaskUnit | None:
)

def get_worker(self, tools: Optional[List[MotleyTool]]) -> MotleyAgentAbstractParent:
""" Description
Args:
tools (:obj:`List[MotleyTool]`, :obj:`None`):
Returns:
MotleyAgentAbstractParent
"""
if self.crew is None:
raise ValueError("Task is not associated with a crew")
if self.agent is None:
Expand Down
82 changes: 82 additions & 0 deletions motleycrew/tasks/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
""" Module description
Attributes:
TaskNodeType (TypeVar):
"""
from __future__ import annotations

from abc import ABC, abstractmethod
Expand All @@ -14,6 +20,13 @@


class TaskNode(MotleyGraphNode):
""" Description
Attributes:
name (str):
done (bool):
"""
__label__ = "TaskNode"
name: str
done: bool = False
Expand All @@ -26,12 +39,24 @@ def __eq__(self, other):


class Task(ABC, Generic[TaskUnitType]):
"""
Attributes:
NODE_CLASS (TaskNodeType):
TASK_IS_UPSTREAM_LABEL (str):
"""
NODE_CLASS: Type[TaskNodeType] = TaskNode
TASK_IS_UPSTREAM_LABEL = "task_is_upstream"

def __init__(
self, name: str, task_unit_class: Type[TaskUnitType], crew: Optional[MotleyCrew] = None
):
""" Description
Args:
name (str):
task_unit_class (Type[TaskUnitType]):
crew (:obj:`MotleyCrew`, optional):
"""
self.name = name
self.done = False
self.node = self.NODE_CLASS(name=name, done=self.done)
Expand All @@ -45,6 +70,11 @@ def __init__(
self.prepare_graph_store()

def prepare_graph_store(self):
""" Description
Returns:
"""
if isinstance(self.graph_store, MotleyKuzuGraphStore):
self.graph_store.ensure_node_table(self.NODE_CLASS)
self.graph_store.ensure_node_table(self.task_unit_class)
Expand All @@ -67,6 +97,14 @@ def __str__(self) -> str:
return self.__repr__()

def set_upstream(self, task: Task) -> Task:
""" Description
Args:
task (Task):
Returns:
Task:
"""
if self.crew is None or task.crew is None:
raise ValueError("Both tasks must be registered with a crew")

Expand Down Expand Up @@ -94,6 +132,11 @@ def __rrshift__(self, other: Sequence[Task]) -> Sequence[Task]:
return other

def get_units(self) -> List[TaskUnitType]:
"""
Returns:
:obj:`list` of :obj:`TaskUnitType`:
"""
assert self.crew is not None, "Task must be registered with a crew for accessing task units"

query = "MATCH (unit:{})-[:{}]->(task:{}) WHERE task.id = $self_id RETURN unit".format(
Expand All @@ -107,6 +150,11 @@ def get_units(self) -> List[TaskUnitType]:
return task_units

def get_upstream_tasks(self) -> List[Task]:
""" Description
Returns:
:obj:`list` of :obj:`Task`
"""
assert (
self.crew is not None and self.node.is_inserted
), "Task must be registered with a crew for accessing upstream tasks"
Expand All @@ -125,6 +173,11 @@ def get_upstream_tasks(self) -> List[Task]:
return [task for task in self.crew.tasks if task.node in upstream_task_nodes]

def get_downstream_tasks(self) -> List[Task]:
""" Description
Returns:
:obj:`list` of :obj:`Task`
"""
assert (
self.crew is not None and self.node.is_inserted
), "Task must be registered with a crew for accessing downstream tasks"
Expand All @@ -143,10 +196,26 @@ def get_downstream_tasks(self) -> List[Task]:
return [task for task in self.crew.tasks if task.node in downstream_task_nodes]

def set_done(self, value: bool = True):
""" Description
Args:
value (bool):
Returns:
"""
self.done = value
self.node.done = value

def register_started_unit(self, unit: TaskUnitType) -> None:
""" Description
Args:
unit (TaskUnitType):
Returns:
"""
assert isinstance(unit, self.task_unit_class)
assert not unit.done

Expand All @@ -161,8 +230,21 @@ def register_completed_unit(self, unit: TaskUnitType) -> None:

@abstractmethod
def get_next_unit(self) -> TaskUnitType | None:
""" Description
Returns:
:obj:`TaskUnitType` | None:
"""
pass

@abstractmethod
def get_worker(self, tools: Optional[List[MotleyTool]]) -> Runnable:
""" Description
Args:
tools (:obj:`List[MotleyTool]`, None):
Returns:
Runnable:
"""
pass
13 changes: 13 additions & 0 deletions motleycrew/tasks/task_unit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
""" Module description
Attributes:
TaskUnitType (TypeVar):
"""
from __future__ import annotations

from abc import ABC
Expand All @@ -8,6 +14,13 @@


class TaskUnit(MotleyGraphNode, ABC):
""" Description
Attributes:
status (:obj:`str`, optional):
output (:obj:`Any`, optional):
"""
status: str = TaskUnitStatus.PENDING
output: Optional[Any] = None

Expand Down

0 comments on commit 59b0d73

Please sign in to comment.