diff --git a/docs/source/_templates/autosummary/base.rst b/docs/source/_templates/autosummary/base.rst index b7556ebf..7d05ce46 100644 --- a/docs/source/_templates/autosummary/base.rst +++ b/docs/source/_templates/autosummary/base.rst @@ -1,5 +1,6 @@ -{{ fullname | escape | underline}} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} + .. auto{{ objtype }}:: {{ objname }} diff --git a/docs/source/_templates/autosummary/class.rst b/docs/source/_templates/autosummary/class.rst deleted file mode 100644 index 87b1f721..00000000 --- a/docs/source/_templates/autosummary/class.rst +++ /dev/null @@ -1,27 +0,0 @@ -{{ fullname | escape | underline}} - -.. currentmodule:: {{ module }} - -.. autoclass:: {{ objname }} - - {% block methods %} - {% if methods %} - .. rubric:: {{ _('Methods') }} - - .. autosummary:: - {% for item in methods %} - .. automethod:: {{ name }}.{{ item }} - {%- endfor %} - {% endif %} - {% endblock %} - - {% block attributes %} - {% if attributes %} - .. rubric:: {{ _('Attributes') }} - - .. autosummary:: - {% for item in attributes %} - .. autoattribute:: {{ name }}.{{ item }} - {%- endfor %} - {% endif %} - {% endblock %} diff --git a/docs/source/_templates/autosummary/module.rst b/docs/source/_templates/autosummary/module.rst index d7774f1d..9f79716e 100644 --- a/docs/source/_templates/autosummary/module.rst +++ b/docs/source/_templates/autosummary/module.rst @@ -1,65 +1,61 @@ {{ fullname | escape | underline}} -.. automodule:: {{ fullname }} - - {% block attributes %} - {% if attributes %} - .. rubric:: {{ _('Module Attributes') }} - - .. autosummary:: - {% for item in attributes %} - {{ item }} - {%- endfor %} - {% endif %} - {% endblock %} - - {% block functions %} - {% if functions %} - .. rubric:: {{ _('Functions') }} - - .. autosummary:: - {% for item in functions %} - .. autofunction:: {{ item }} - {%- endfor %} - {% endif %} - {% endblock %} - {% block classes %} - {% if classes %} - .. rubric:: {{ _('Classes') }} - - .. autosummary:: - :toctree: - {% for item in classes %} - {{item}} - {%- endfor %} - {% for item in classes %} - .. autoclass:: {{item}} - :members: - {%- endfor %} - {% endif %} - {% endblock %} - - {% block exceptions %} - {% if exceptions %} - .. rubric:: {{ _('Exceptions') }} - - .. autosummary:: - {% for item in exceptions %} - {{ item }} - {%- endfor %} - {% endif %} - {% endblock %} - -{% block modules %} -{% if modules %} -.. rubric:: Modules +.. automodule:: {{ fullname }} -.. autosummary:: - :toctree: - :recursive: -{% for item in modules %} - {{ item }} -{%- endfor %} -{% endif %} -{% endblock %} + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Module Attributes') }} + + .. autosummary:: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block modules %} + {% if modules %} + .. rubric:: Modules + + .. autosummary:: + :toctree: + :recursive: + {% for item in modules %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/source/api.rst b/docs/source/api.rst index 86d5e40e..3cec553d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -5,4 +5,4 @@ API :toctree: _autosummary :recursive: - motleycrew \ No newline at end of file + motleycrew diff --git a/docs/source/conf.py b/docs/source/conf.py index 4aa0062c..8d129ff2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,7 +11,6 @@ sys.path.append(os.path.abspath("../..")) - project = "motleycrew" copyright = "2024, motleycrew" author = "motleycrew" @@ -29,12 +28,22 @@ "nbsphinx_link", ] + templates_path = ["_templates", "_templates/autosummary"] exclude_patterns = [] -autosummary_generate = True + +# autodoc_default_options = { +# "member-order": "bysource", +# "special-members": "__init__", +# } + autodoc_default_options = { + "members": True, "member-order": "bysource", "special-members": "__init__", + "show-inheritance": True, + "inherited-members": False, + "undoc-members": True, } napoleon_google_docstring = True @@ -50,7 +59,11 @@ nbsphinx_allow_errors = True nbsphinx_execute = "never" -html_theme_options = { - "display_github": True, - "github_url": "https://github.com/ShoggothAI/motleycrew", -} +# Additional configuration for better auto-generated documentation +autosummary_generate = True # Turn on autosummary + +# Create separate .rst files for each module +autosummary_generate_overwrite = False + +# Make sure that the generated files are included in the toctree +autosummary_generate_include_files = True diff --git a/docs/source/index.rst b/docs/source/index.rst index c60a5769..134ac7d8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -16,7 +16,6 @@ Welcome to motleycrew's documentation! :caption: Contents: - Indices and tables ================== diff --git a/docs/source/key_concepts.rst b/docs/source/key_concepts.rst index 05bbea78..7c29dbe6 100644 --- a/docs/source/key_concepts.rst +++ b/docs/source/key_concepts.rst @@ -11,7 +11,7 @@ For a basic introduction, you can check out the `quickstart `_. Crew and knowledge graph ------------------------ -The crew is a central concept in motleycrew. It is the orchestrator that knows what tasks sould be done in which order, +The crew (:class:`motleycrew.crew.crew.MotleyCrew`) is a central concept in motleycrew. It is the orchestrator that knows what tasks sould be done in which order, and manages the execution of those tasks. The crew has an underlying knowledge graph, in which it stores all information relevant to the execution of the tasks. @@ -79,8 +79,8 @@ that you can override to customize the task's behavior. #. ``get_next_unit()`` should return the next task unit to be processed. If there are no units to do at the moment, it should return `None`. #. ``get_worker()`` should return the worker (typically an agent) that will process the task's units. -#. `optional` ``register_started_unit(unit)`` is called by the crew when a task unit is dispatched. -#. `optional` ``register_completed_unit(unit)`` is called by the crew when a task unit is completed. +#. `optional` ``on_unit_dispatch(unit)`` is called by the crew when a task unit is dispatched. +#. `optional` ``on_unit_completion(unit)`` is called by the crew when a task unit is completed. Task hierarchy @@ -106,15 +106,15 @@ The crew queries the tasks for task units and dispatches them in a loop. The cre tasks are completed or available tasks stop providing task units. A task is considered completed when it has ``done`` attribute set to ``True``. For example, in the case of `SimpleTask`, -this happens when its only task unit is completed and the crew calls the task's ``register_completed_unit`` method. +this happens when its only task unit is completed and the crew calls the task's ``on_unit_completion`` method. In case of a custom task, this behavior is up to the task's implementation. Available tasks are defined as tasks that have not been completed and have no incomplete upstream tasks. On each iteration, available tasks are queried for task units one by one, and the crew will dispatch the task unit to the worker that the task provides. -When a task unit is dispatched, the crew adds it to the knowledge graph and calls the task's ``register_started_unit`` -method. When the worker finishes processing the task unit, the crew calls the task's ``register_completed_unit`` method. +When a task unit is dispatched, the crew adds it to the knowledge graph and calls the task's ``on_unit_dispatch`` +method. When the worker finishes processing the task unit, the crew calls the task's ``on_unit_completion`` method. .. image:: images/crew_diagram.png :alt: Crew main loop diff --git a/motleycrew/agents/abstract_parent.py b/motleycrew/agents/abstract_parent.py index df38c5f3..50174811 100644 --- a/motleycrew/agents/abstract_parent.py +++ b/motleycrew/agents/abstract_parent.py @@ -1,14 +1,18 @@ -""" Module description""" from abc import ABC, abstractmethod from typing import Optional, Any, TYPE_CHECKING -from langchain_core.runnables import RunnableConfig +from langchain_core.runnables import Runnable, RunnableConfig if TYPE_CHECKING: from motleycrew.tools import MotleyTool -class MotleyAgentAbstractParent(ABC): +class MotleyAgentAbstractParent(Runnable, ABC): + """Abstract class for describing agents. + + Agents in motleycrew implement the Langchain Runnable interface. + """ + @abstractmethod def invoke( self, @@ -16,24 +20,13 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - """ Description - - Args: - input (dict): - config (:obj:`RunnableConfig`, optional): - **kwargs: - - Returns: - Any: - """ pass @abstractmethod def as_tool(self) -> "MotleyTool": - """ Description + """Convert the agent to a MotleyTool to be used by other agents via delegation. Returns: - MotleyTool - + The tool representation of the agent. """ pass diff --git a/motleycrew/agents/crewai/agent_with_config.py b/motleycrew/agents/crewai/agent_with_config.py index f139db39..313e4ff0 100644 --- a/motleycrew/agents/crewai/agent_with_config.py +++ b/motleycrew/agents/crewai/agent_with_config.py @@ -1,9 +1,7 @@ -""" Module description """ - from typing import Any, Optional, List -from langchain_core.runnables import RunnableConfig from langchain.tools.render import render_text_description +from langchain_core.runnables import RunnableConfig from motleycrew.common.utils import ensure_module_is_installed diff --git a/motleycrew/agents/crewai/crewai.py b/motleycrew/agents/crewai/crewai.py index a7a8d759..7b57f984 100644 --- a/motleycrew/agents/crewai/crewai.py +++ b/motleycrew/agents/crewai/crewai.py @@ -1,10 +1,10 @@ -""" Module description """ +from __future__ import annotations from typing import Any, Optional, Sequence +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool -from langchain_core.pydantic_v1 import BaseModel, Field from motleycrew.agents.crewai import CrewAIAgentWithConfig from motleycrew.agents.parent import MotleyAgentParent @@ -21,6 +21,8 @@ class CrewAIMotleyAgentParent(MotleyAgentParent): + """Base MotleyCrew wrapper for CrewAI agents.""" + def __init__( self, goal: str, @@ -32,16 +34,33 @@ def __init__( output_handler: MotleySupportedTool | None = None, verbose: bool = False, ): - """Description - + """ Args: - goal (str): - prompt_prefix (:obj:`str`, optional): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - agent_factory (:obj:`MotleyAgentFactory`, optional): - tools (:obj:`Sequence[MotleySupportedTool]`, optional: - verbose (bool): + goal: Goal of the agent. + + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + name: Name of the agent. + The name is used for identifying the agent when it is given as a tool + to other agents, as well as for logging purposes. + It is not included in the agent's prompt. + + agent_factory: Factory function to create the agent. + The factory function should accept a dictionary of tools and return + a CrewAIAgentWithConfig instance. + + tools: Tools to add to the agent. + + output_handler: Output handler for the agent. + + verbose: Whether to log verbose output. """ if output_handler: @@ -67,16 +86,6 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - """Description - - Args: - input (dict): - config (:obj:`RunnableConfig`, optional): - **kwargs: - - Returns: - Any: - """ prompt = self.prepare_for_invocation(input=input) langchain_tools = [tool.to_langchain_tool() for tool in self.tools.values()] @@ -105,25 +114,9 @@ def materialize(self): # TODO: what do these do? def set_cache_handler(self, cache_handler: Any) -> None: - """Description - - Args: - cache_handler (Any): - - Returns: - None: - """ return self.agent.set_cache_handler(cache_handler) def set_rpm_controller(self, rpm_controller: Any) -> None: - """Description - - Args: - rpm_controller (Any): - - Returns: - None: - """ return self.agent.set_rpm_controller(rpm_controller) @staticmethod @@ -132,15 +125,17 @@ def from_agent( tools: Sequence[MotleySupportedTool] | None = None, verbose: bool = False, ) -> "CrewAIMotleyAgentParent": - """Description + """Create a CrewAIMotleyAgentParent from a CrewAIAgentWithConfig instance. + + Using this method, you can wrap an existing CrewAIAgentWithConfig + without providing a factory function. Args: - agent (CrewAIAgentWithConfig): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - verbose (bool): + agent: CrewAIAgentWithConfig instance to wrap. + + tools: Tools to add to the agent. - Returns: - CrewAIMotleyAgentParent: + verbose: Whether to log verbose output. """ if tools or agent.tools: tools = list(tools or []) + list(agent.tools or []) @@ -162,9 +157,7 @@ def as_tool(self) -> MotleyTool: class CrewAIAgentInputSchema(BaseModel): prompt: str = Field(..., description="Prompt to be passed to the agent") - expected_output: str = Field( - ..., description="Expected output of the agent" - ) + expected_output: str = Field(..., description="Expected output of the agent") def call_agent(prompt: str, expected_output: str): return self.invoke( diff --git a/motleycrew/agents/crewai/crewai_agent.py b/motleycrew/agents/crewai/crewai_agent.py index 0927c079..2c89b8a7 100644 --- a/motleycrew/agents/crewai/crewai_agent.py +++ b/motleycrew/agents/crewai/crewai_agent.py @@ -1,41 +1,63 @@ -""" Module description """ +from __future__ import annotations from typing import Optional, Any, Sequence -from motleycrew.tools import MotleyTool -from motleycrew.common import MotleySupportedTool +from motleycrew.agents.crewai import CrewAIAgentWithConfig +from motleycrew.agents.crewai import CrewAIMotleyAgentParent from motleycrew.common import LLMFramework +from motleycrew.common import MotleySupportedTool from motleycrew.common.llms import init_llm -from motleycrew.agents.crewai import CrewAIMotleyAgentParent -from motleycrew.agents.crewai import CrewAIAgentWithConfig +from motleycrew.tools import MotleyTool class CrewAIMotleyAgent(CrewAIMotleyAgentParent): + """MotleyCrew wrapper for CrewAI Agent. + + This wrapper is made to mimic the CrewAI agent's interface. + That is why it has mostly the same arguments. + """ + def __init__( - self, - role: str, - goal: str, - backstory: str, - prompt_prefix: str | None = None, - description: str | None = None, - delegation: bool = False, - tools: Sequence[MotleySupportedTool] | None = None, - llm: Optional[Any] = None, - output_handler: MotleySupportedTool | None = None, - verbose: bool = False, + self, + role: str, + goal: str, + backstory: str, + prompt_prefix: str | None = None, + description: str | None = None, + delegation: bool = False, + tools: Sequence[MotleySupportedTool] | None = None, + llm: Optional[Any] = None, + output_handler: MotleySupportedTool | None = None, + verbose: bool = False, ): - """Description - + """ Args: - role (str): - goal (str): - backstory (str): - prompt_prefix (str): - description (str, optional): - delegation (bool): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - llm (:obj:'Any', optional): - verbose (bool): + role: ``role`` param of the CrewAI Agent. + + goal: ``goal`` param of the CrewAI Agent. + + backstory: ``backstory`` param of the CrewAI Agent. + + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + delegation: Whether to allow delegation or not. + **Delegation is not supported in this wrapper.** + Instead, pass the agents you want to delegate to as tools. + + tools: Tools to add to the agent. + + llm: LLM instance to use. + + output_handler: Output handler for the agent. + + verbose: Whether to log verbose output. """ if tools is None: tools = [] diff --git a/motleycrew/agents/langchain/langchain.py b/motleycrew/agents/langchain/langchain.py index 8e49da90..acda0e48 100644 --- a/motleycrew/agents/langchain/langchain.py +++ b/motleycrew/agents/langchain/langchain.py @@ -1,4 +1,4 @@ -""" Module description """ +from __future__ import annotations from typing import Any, Optional, Sequence @@ -15,31 +15,54 @@ class LangchainMotleyAgent(MotleyAgentParent, LangchainOutputHandlingAgentMixin): + """MotleyCrew wrapper for Langchain agents.""" + def __init__( self, - prompt_prefix: str | None = None, description: str | None = None, name: str | None = None, + prompt_prefix: str | None = None, agent_factory: MotleyAgentFactory[AgentExecutor] | None = None, tools: Sequence[MotleySupportedTool] | None = None, output_handler: MotleySupportedTool | None = None, - verbose: bool = False, chat_history: bool | GetSessionHistoryCallable = True, + verbose: bool = False, ): - """Description - + """ Args: - prompt_prefix (:obj:`str`, optional): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - agent_factory (:obj:`MotleyAgentFactory`, optional): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - output_handler (:obj:`MotleySupportedTool`, optional): - verbose (bool): - chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`): - Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`. - If a callable is passed, it is used to get the chat history by session_id. - See Langchain `RunnableWithMessageHistory` get_session_history param for more details. + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + name: Name of the agent. + The name is used for identifying the agent when it is given as a tool + to other agents, as well as for logging purposes. + + It is not included in the agent's prompt. + + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + agent_factory: Factory function to create the agent. + The factory function should accept a dictionary of tools and return + an AgentExecutor instance. + + See :class:`motleycrew.common.types.MotleyAgentFactory` for more details. + + tools: Tools to add to the agent. + + output_handler: Output handler for the agent. + + chat_history: Whether to use chat history or not. + If `True`, uses `InMemoryChatMessageHistory`. + If a callable is passed, it is used to get the chat history by session_id. + + See :class:`langchain_core.runnables.history.RunnableWithMessageHistory` + for more details. + + verbose: Whether to log verbose output. """ super().__init__( prompt_prefix=prompt_prefix, @@ -115,16 +138,6 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - """Description - - Args: - input (dict): - config (:obj:`RunnableConfig`, optional): - **kwargs: - - Returns: - - """ prompt = self.prepare_for_invocation(input=input) config = add_default_callbacks_to_langchain_config(config) @@ -141,22 +154,31 @@ def invoke( @staticmethod def from_agent( agent: AgentExecutor, - goal: str, description: str | None = None, + prompt_prefix: str | None = None, tools: Sequence[MotleySupportedTool] | None = None, verbose: bool = False, ) -> "LangchainMotleyAgent": - """Description + """Create a LangchainMotleyAgent from a :class:`langchain.agents.AgentExecutor` instance. + + Using this method, you can wrap an existing AgentExecutor + without providing a factory function. Args: - agent (AgentExecutor): - goal (str): - description (:obj:`str`, optional) - tools(:obj:`Sequence[MotleySupportedTool]`, optional): - verbose (bool): - - Returns: - LangchainMotleyAgent + agent: AgentExecutor instance to wrap. + + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + tools: Tools to add to the agent. + + verbose: Whether to log verbose output. """ # TODO: do we really need to unite the tools implicitly like this? # TODO: confused users might pass tools both ways at the same time @@ -166,7 +188,7 @@ def from_agent( tools = list(tools or []) + list(agent.tools or []) wrapped_agent = LangchainMotleyAgent( - prompt_prefix=goal, description=description, tools=tools, verbose=verbose + prompt_prefix=prompt_prefix, description=description, tools=tools, verbose=verbose ) wrapped_agent._agent = agent return wrapped_agent diff --git a/motleycrew/agents/langchain/react.py b/motleycrew/agents/langchain/react.py index 1c9a096b..932d21c7 100644 --- a/motleycrew/agents/langchain/react.py +++ b/motleycrew/agents/langchain/react.py @@ -1,19 +1,18 @@ -""" Module description""" +from __future__ import annotations from typing import Sequence, Optional from langchain import hub -from langchain_core.language_models import BaseLanguageModel -from langchain_core.runnables.history import GetSessionHistoryCallable from langchain.agents import AgentExecutor from langchain.agents import create_react_agent +from langchain_core.language_models import BaseLanguageModel +from langchain_core.runnables.history import GetSessionHistoryCallable from motleycrew.agents.langchain import LangchainMotleyAgent -from motleycrew.tools import MotleyTool -from motleycrew.common import MotleySupportedTool from motleycrew.common import LLMFramework +from motleycrew.common import MotleySupportedTool from motleycrew.common.llms import init_llm - +from motleycrew.tools import MotleyTool OUTPUT_HANDLER_WITH_DEFAULT_PROMPT_MESSAGE = ( "Langchain's default ReAct prompt tells the agent to include a final answer keyword, " @@ -23,39 +22,39 @@ class ReActMotleyAgent(LangchainMotleyAgent): + """Basic ReAct agent compatible with older models without dedicated tool calling support. + + It's probably better to use the more advanced + :class:`motleycrew.agents.langchain.tool_calling_react.ReActToolCallingAgent` with newer models. + """ + def __init__( self, tools: Sequence[MotleySupportedTool], - prompt_prefix: str | None = None, description: str | None = None, name: str | None = None, - prompt: str | None = None, + prompt_prefix: str | None = None, output_handler: MotleySupportedTool | None = None, chat_history: bool | GetSessionHistoryCallable = True, + prompt: str | None = None, handle_parsing_errors: bool = True, handle_tool_errors: bool = True, llm: BaseLanguageModel | None = None, verbose: bool = False, ): - """Basic ReAct agent compatible with older models without dedicated tool calling support. - It's probably better to use the more advanced `ReActToolCallingAgent` with newer models. - + """ Args: - tools (Sequence[MotleySupportedTool]): - prompt_prefix (:obj:`str`, optional): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - prompt (:obj:`str`, optional): Prompt to use. If not provided, uses hwchase17/react. - chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`): - Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`. - If a callable is passed, it is used to get the chat history by session_id. - See Langchain `RunnableWithMessageHistory` get_session_history param for more details. - output_handler (BaseTool, optional): Tool to use for returning agent's output. - handle_parsing_errors (:obj:`bool`, optional): Whether to handle parsing errors or not. - handle_tool_errors (:obj:`bool`, optional): Whether to handle tool errors or not. - If True, `handle_tool_error` and `handle_validation_error` in all tools are set to True. - llm (:obj:`BaseLanguageModel`, optional): - verbose (:obj:`bool`, optional): + tools: Tools to add to the agent. + description: Description of the agent. + name: Name of the agent. + prompt_prefix: Prefix to the agent's prompt. + output_handler: Output handler for the agent. + chat_history: Whether to use chat history or not. + prompt: Custom prompt to use with the agent. + handle_parsing_errors: Whether to handle parsing errors. + handle_tool_errors: Whether to handle tool errors. + llm: Language model to use. + verbose: Whether to log verbose output. """ if prompt is None: if output_handler is not None: diff --git a/motleycrew/agents/langchain/tool_calling_react.py b/motleycrew/agents/langchain/tool_calling_react.py index 6e7bc636..bf20db2b 100644 --- a/motleycrew/agents/langchain/tool_calling_react.py +++ b/motleycrew/agents/langchain/tool_calling_react.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Sequence, List, Optional from langchain.agents import AgentExecutor @@ -91,27 +93,12 @@ def check_variables(prompt: ChatPromptTemplate): - """ - Args: - prompt (ChatPromptTemplate): - - Returns: - - """ missing_vars = {"agent_scratchpad"}.difference(prompt.input_variables) if missing_vars: raise ValueError(f"Prompt missing required variables: {missing_vars}") def add_thought_to_background(x: dict): - """Description - - Args: - x (dict): - - Returns: - - """ out = x["background"] out["agent_scratchpad"] += [x["thought"]] return out @@ -131,14 +118,6 @@ def cast_thought_to_human_message(thought: BaseMessage): def add_messages_to_action( actions: List[AgentActionMessageLog] | AgentFinish, messages: List[BaseMessage] ) -> List[AgentActionMessageLog] | AgentFinish: - """ - Args: - actions (:obj:`List[AgentActionMessageLog]`, :obj:`AgentFinish`): - messages (List[BaseMessage]): - - Returns: - List[AgentActionMessageLog] | AgentFinish: - """ if not isinstance(actions, AgentFinish): for action in actions: action.message_log = messages + list(action.message_log) @@ -151,10 +130,10 @@ def merge_consecutive_messages(messages: Sequence[BaseMessage]) -> List[BaseMess multiple AIMessages in a row. Args: - messages (Sequence[BaseMessage]): The list of messages to process. + messages: The list of messages to process. Returns: - List[BaseMessage]: The list of messages with consecutive messages of the same type merged. + The list of messages with consecutive messages of the same type merged. """ merged_messages = [] for message in messages: @@ -181,13 +160,13 @@ def create_tool_calling_react_agent( """Create a ReAct-style agent that supports tool calling. Args: - llm (BaseChatModel): LLM to use as the agent. - tools (Sequence[BaseTool]): Tools this agent has access to. - think_prompt (ChatPromptTemplate, optional): The thinking step prompt to use. + llm: LLM to use as the agent. + tools: Tools this agent has access to. + think_prompt: The thinking step prompt to use. See Prompt section below for more on the expected input variables. - act_prompt (ChatPromptTemplate, optional): The acting step prompt to use. + act_prompt: The acting step prompt to use. See Prompt section below for more on the expected input variables. - output_handler (BaseTool, optional): Tool to use for returning agent's output. + output_handler: Tool to use for returning agent's output. Returns: A Runnable sequence representing an agent. It takes as input all the same input @@ -247,12 +226,19 @@ def create_tool_calling_react_agent( class ReActToolCallingAgent(LangchainMotleyAgent): + """Universal ReAct-style agent that supports tool calling. + + This agent only works with newer models that support tool calling. + If you are using an older model, you should use + :class:`motleycrew.agents.langchain.react.ReActMotleyAgent` instead. + """ + def __init__( self, tools: Sequence[MotleySupportedTool], - prompt_prefix: str | None = None, description: str | None = None, name: str | None = None, + prompt_prefix: str | None = None, think_prompt: ChatPromptTemplate | None = None, act_prompt: ChatPromptTemplate | None = None, chat_history: bool | GetSessionHistoryCallable = True, @@ -262,26 +248,33 @@ def __init__( llm: BaseChatModel | None = None, verbose: bool = False, ): - """Universal ReAct-style agent that supports tool calling. - + """ Args: - tools (Sequence[MotleySupportedTool]): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - think_prompt (ChatPromptTemplate, optional): The thinking step prompt to use. + tools: Tools to add to the agent. + description: Description of the agent. + name: Name of the agent. + prompt_prefix: Prefix to the agent's prompt. + think_prompt: The thinking step prompt to use. See Prompt section below for more on the expected input variables. - act_prompt (ChatPromptTemplate, optional): The acting step prompt to use. + act_prompt: The acting step prompt to use. See Prompt section below for more on the expected input variables. - chat_history (:obj:`bool`, :obj:`GetSessionHistoryCallable`): - Whether to use chat history or not. If `True`, uses `InMemoryChatMessageHistory`. + chat_history: Whether to use chat history or not. + If `True`, uses `InMemoryChatMessageHistory`. If a callable is passed, it is used to get the chat history by session_id. - See Langchain `RunnableWithMessageHistory` get_session_history param for more details. - output_handler (BaseTool, optional): Tool to use for returning agent's output. - handle_parsing_errors (:obj:`bool`, optional): Whether to handle parsing errors or not. - handle_tool_errors (:obj:`bool`, optional): Whether to handle tool errors or not. - If True, `handle_tool_error` and `handle_validation_error` in all tools are set to True. - llm (:obj:`BaseLanguageModel`, optional): - verbose (:obj:`bool`, optional): + See :class:`langchain_core.runnables.history.RunnableWithMessageHistory` + for more details. + output_handler: Output handler for the agent. + handle_parsing_errors: Whether to handle parsing errors. + handle_tool_errors: Whether to handle tool errors. + If True, `handle_tool_error` and `handle_validation_error` in all tools + are set to True. + llm: Language model to use. + verbose: Whether to log verbose output. + + Prompt: + This agent uses two prompts, one for thinking and one for acting. The prompts + must have `agent_scratchpad` and `chat_history` ``MessagesPlaceholder``s. + If a prompt is not passed in, the default one is used. """ if llm is None: llm = init_llm(llm_framework=LLMFramework.LANGCHAIN) diff --git a/motleycrew/agents/llama_index/llama_index.py b/motleycrew/agents/llama_index/llama_index.py index 5ed18c79..447e432e 100644 --- a/motleycrew/agents/llama_index/llama_index.py +++ b/motleycrew/agents/llama_index/llama_index.py @@ -1,4 +1,4 @@ -""" Module description """ +from __future__ import annotations import uuid from typing import Any, Optional, Sequence @@ -24,6 +24,7 @@ class LlamaIndexMotleyAgent(MotleyAgentParent): + """MotleyCrew wrapper for LlamaIndex agents.""" def __init__( self, @@ -35,15 +36,34 @@ def __init__( output_handler: MotleySupportedTool | None = None, verbose: bool = False, ): - """Description - + """ Args: - prompt_prefix (:obj:`str`, optional): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - agent_factory (:obj:`MotleyAgentFactory`, optional): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - verbose (:obj:`bool`, optional): + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + name: Name of the agent. + The name is used for identifying the agent when it is given as a tool + to other agents, as well as for logging purposes. + + It is not included in the agent's prompt. + + agent_factory: Factory function to create the agent. + The factory function should accept a dictionary of tools and return + an AgentRunner instance. + + See :class:`motleycrew.common.types.MotleyAgentFactory` for more details. + + tools: Tools to add to the agent. + + output_handler: Output handler for the agent. + + verbose: Whether to log verbose output. """ super().__init__( description=description, @@ -55,8 +75,11 @@ def __init__( verbose=verbose, ) - def run_step_decorator(self): - """Decorator for inclusion in the call chain of the agent, the output handler tool""" + def _run_step_decorator(self): + """Decorator for the ``AgentRunner._run_step`` method that catches DirectOutput exceptions. + + It also blocks plain output and forces the use of the output handler tool if it is present. + """ ensure_module_is_installed("llama_index") def decorator(func): @@ -106,7 +129,7 @@ def wrapper( def materialize(self): super(LlamaIndexMotleyAgent, self).materialize() - self._agent._run_step = self.run_step_decorator()(self._agent._run_step) + self._agent._run_step = self._run_step_decorator()(self._agent._run_step) def invoke( self, @@ -114,16 +137,6 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - """Description - - Args: - input (dict): - config (:obj:`RunnableConfig`, optional): - **kwargs: - - Returns: - Any: - """ prompt = self.prepare_for_invocation(input=input) output = self.agent.chat(prompt) @@ -141,17 +154,27 @@ def from_agent( tools: Sequence[MotleySupportedTool] | None = None, verbose: bool = False, ) -> "LlamaIndexMotleyAgent": - """Description + """Create a LlamaIndexMotleyAgent from a :class:`llama_index.core.agent.AgentRunner` + instance. + + Using this method, you can wrap an existing AgentRunner + without providing a factory function. Args: - agent (AgentRunner): - description (:obj:`str`, optional): - prompt_prefix (:obj:`str`, optional): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - verbose (:obj:`bool`, optional): - - Returns: - LlamaIndexMotleyAgent: + agent: AgentRunner instance to wrap. + + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + tools: Tools to add to the agent. + + verbose: Whether to log verbose output. """ ensure_module_is_installed("llama_index") wrapped_agent = LlamaIndexMotleyAgent( diff --git a/motleycrew/agents/llama_index/llama_index_react.py b/motleycrew/agents/llama_index/llama_index_react.py index 915aabe9..a4294bcc 100644 --- a/motleycrew/agents/llama_index/llama_index_react.py +++ b/motleycrew/agents/llama_index/llama_index_react.py @@ -1,4 +1,4 @@ -""" Module description """ +from __future__ import annotations from typing import Sequence @@ -19,6 +19,8 @@ class ReActLlamaIndexMotleyAgent(LlamaIndexMotleyAgent): + """Wrapper for LlamaIndex implementation of ReAct agent.""" + def __init__( self, prompt_prefix: str | None = None, @@ -30,15 +32,33 @@ def __init__( verbose: bool = False, max_iterations: int = 10, ): - """Description - + """ Args: - prompt_prefix (:obj:`str`, optional): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - llm (:obj:`LLM`, optional): - verbose (:obj:`bool`, optional): + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + + name: Name of the agent. + The name is used for identifying the agent when it is given as a tool + to other agents, as well as for logging purposes. + + It is not included in the agent's prompt. + + tools: Tools to add to the agent. + + llm: LLM instance to use. + + output_handler: Output handler for the agent. + + verbose: Whether to log verbose output. + + max_iterations: Maximum number of iterations for the agent. + Passed on to the ``max_iterations`` parameter of the ReActAgent. """ ensure_module_is_installed("llama_index") if llm is None: diff --git a/motleycrew/agents/mixins.py b/motleycrew/agents/mixins.py index 5b24f92f..6104083c 100644 --- a/motleycrew/agents/mixins.py +++ b/motleycrew/agents/mixins.py @@ -80,7 +80,7 @@ def wrapper( def take_next_step_decorator(self, func: Callable): """ - Decorator for AgentExecutor._take_next_step() method that catches DirectOutput exceptions. + Decorator for ``AgentExecutor._take_next_step()`` method that catches DirectOutput exceptions. """ def wrapper( diff --git a/motleycrew/agents/output_handler.py b/motleycrew/agents/output_handler.py index 337f0e64..26f936b1 100644 --- a/motleycrew/agents/output_handler.py +++ b/motleycrew/agents/output_handler.py @@ -1,36 +1,44 @@ -from typing import Optional from abc import ABC, abstractmethod -from langchain_core.tools import StructuredTool +from typing import Optional + from langchain_core.pydantic_v1 import BaseModel +from langchain_core.tools import StructuredTool from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent -from motleycrew.common.exceptions import InvalidOutput from motleycrew.common import Defaults +from motleycrew.common.exceptions import InvalidOutput from motleycrew.tools import MotleyTool class MotleyOutputHandler(MotleyTool, ABC): - _name: str = "output_handler" - """Name of the output handler tool.""" + """Base class for output handler tools. - _description: str = "Output handler. ONLY RETURN THE FINAL RESULT USING THIS TOOL!" - """Description of the output handler tool.""" + Output handler tools are used to process the final output of an agent. - _args_schema: Optional[BaseModel] = None - """Pydantic schema for the arguments of the output handler tool. - Inferred from the `handle_output` method if not provided.""" + For creating an output handler tool, inherit from this class and implement + the `handle_output` method. + + Attributes: + _name: Name of the output handler tool. + _description: Description of the output handler tool. + _args_schema: Pydantic schema for the arguments of the output handler tool. + Inferred from the ``handle_output`` method signature if not provided. + _exceptions_to_handle: Exceptions that should be returned to the agent when raised. + """ + _name: str = "output_handler" + _description: str = "Output handler. ONLY RETURN THE FINAL RESULT USING THIS TOOL!" + _args_schema: Optional[BaseModel] = None _exceptions_to_handle: tuple[Exception] = (InvalidOutput,) - """Exceptions that should be returned to the agent when raised in the `handle_output` method.""" def __init__(self, max_iterations: int = Defaults.DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS): - """Initialize the output handler tool. - + """ Args: - max_iterations (int): Maximum number of iterations to run the output handler. - If an exception is raised in the `handle_output` method, the output handler will return - the exception to the agent unless the number of iterations exceeds `max_iterations`, - in which case the output handler will raise OutputHandlerMaxIterationsExceeded. + max_iterations: Maximum number of iterations to run the output handler. + If an exception is raised in the ``handle_output`` method, the output handler + will return the exception to the agent unless the number of iterations exceeds + ``max_iterations``, in which case the output handler will raise + :class:`motleycrew.common.exceptions.OutputHandlerMaxIterationsExceeded`. """ self.max_iterations = max_iterations langchain_tool = self._create_langchain_tool() @@ -53,4 +61,8 @@ def _create_langchain_tool(self): @abstractmethod def handle_output(self, *args, **kwargs): + """Method for processing the final output of an agent. + + Implement this method in your output handler tool. + """ pass diff --git a/motleycrew/agents/parent.py b/motleycrew/agents/parent.py index 3f500ab6..47457608 100644 --- a/motleycrew/agents/parent.py +++ b/motleycrew/agents/parent.py @@ -1,26 +1,22 @@ -""" Module description """ +from __future__ import annotations import inspect +from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Optional, Sequence, Any, - Callable, - Dict, - Type, Union, - Tuple, ) from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessage, SystemMessage -from langchain_core.runnables import Runnable +from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool from langchain_core.tools import Tool -from motleycrew.agents.output_handler import MotleyOutputHandler -from pydantic import BaseModel from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent +from motleycrew.agents.output_handler import MotleyOutputHandler from motleycrew.common import MotleyAgentFactory, MotleySupportedTool from motleycrew.common import logger, Defaults from motleycrew.common.exceptions import ( @@ -36,11 +32,27 @@ class DirectOutput(BaseException): + """Auxiliary exception to return the agent output directly through the output handler. + + When the output handler returns an output, this exception is raised with the output. + It is then handled by the agent, who should gracefully return the output to the user. + """ + def __init__(self, output: Any): self.output = output -class MotleyAgentParent(MotleyAgentAbstractParent, Runnable): +class MotleyAgentParent(MotleyAgentAbstractParent, ABC): + """Parent class for all motleycrew agents. + + This class is abstract and should be subclassed by all agents in motleycrew. + + In most cases, it's better to use one of the specialized agent classes, + such as LangchainMotleyAgent or LlamaIndexMotleyAgent, which provide various + useful features, such as observability and output handling, out of the box. + + If you need to create a custom agent, subclass this class and implement the `invoke` method. + """ def __init__( self, @@ -52,16 +64,28 @@ def __init__( output_handler: MotleySupportedTool | None = None, verbose: bool = False, ): - """Description - + """ Args: - prompt_prefix (:obj:`str`, optional): - description (:obj:`str`, optional): - name (:obj:`str`, optional): - agent_factory (:obj:`MotleyAgentFactory`, optional): - tools (:obj:`Sequence[MotleySupportedTool]`, optional): - output_handler (:obj:`MotleySupportedTool`, optional): - verbose (:obj:`bool`, optional): + prompt_prefix: Prefix to the agent's prompt. + Can be used for providing additional context, such as the agent's role or backstory. + description: Description of the agent. + + Unlike the prompt prefix, it is not included in the prompt. + The description is only used for describing the agent's purpose + when giving it as a tool to other agents. + name: Name of the agent. + The name is used for identifying the agent when it is given as a tool + to other agents, as well as for logging purposes. + + It is not included in the agent's prompt. + agent_factory: Factory function to create the agent. + The factory function should accept a dictionary of tools and return the agent. + It is usually called right before the agent is invoked for the first time. + + See :class:`motleycrew.common.types.MotleyAgentFactory` for more details. + tools: Tools to add to the agent. + output_handler: Output handler for the agent. + verbose: Whether to log verbose output. """ self.name = name or description self.description = description # becomes tool description @@ -86,6 +110,15 @@ def __str__(self): def compose_prompt( self, input_dict: dict, prompt: ChatPromptTemplate | str ) -> Union[str, ChatPromptTemplate]: + """Compose the agent's prompt from the prompt prefix and the provided prompt. + + Args: + input_dict: The input dictionary to the agent. + prompt: The prompt to be added to the agent's prompt. + + Returns: + The composed prompt. + """ # TODO: always cast description and prompt to ChatPromptTemplate first? prompt_messages = [] @@ -100,9 +133,7 @@ def compose_prompt( prompt_messages.append(SystemMessage(content=self.prompt_prefix)) else: - raise ValueError( - "Agent description must be a string or a ChatPromptTemplate" - ) + raise ValueError("Agent description must be a string or a ChatPromptTemplate") if prompt: if isinstance(prompt, ChatPromptTemplate): @@ -130,6 +161,7 @@ def agent(self): @property def is_materialized(self): + """Whether the agent is materialized.""" return self._agent is not None def _prepare_output_handler(self) -> Optional[MotleyTool]: @@ -183,6 +215,10 @@ def handle_agent_output(*args, **kwargs): return MotleyTool.from_langchain_tool(prepared_output_handler) def materialize(self): + """Materialize the agent by creating the agent instance using the agent factory. + This method should be called before invoking the agent for the first time. + """ + if self.is_materialized: logger.info("Agent is already materialized, skipping materialization") return @@ -192,13 +228,9 @@ def materialize(self): if inspect.signature(self.agent_factory).parameters.get("output_handler"): logger.info("Agent factory accepts output handler, passing it") - self._agent = self.agent_factory( - tools=self.tools, output_handler=output_handler - ) + self._agent = self.agent_factory(tools=self.tools, output_handler=output_handler) elif output_handler: - logger.info( - "Agent factory does not accept output handler, passing it as a tool" - ) + logger.info("Agent factory does not accept output handler, passing it as a tool") tools_with_output_handler = self.tools.copy() tools_with_output_handler[output_handler.name] = output_handler self._agent = self.agent_factory(tools=tools_with_output_handler) @@ -206,12 +238,12 @@ def materialize(self): self._agent = self.agent_factory(tools=self.tools) def prepare_for_invocation(self, input: dict) -> str: - """Prepares the agent for invocation by materializing it and composing the prompt. + """Prepare the agent for invocation by materializing it and composing the prompt. Should be called in the beginning of the agent's invoke method. Args: - input (dict): the input to the agent + input: the input to the agent Returns: str: the composed prompt @@ -226,13 +258,10 @@ def prepare_for_invocation(self, input: dict) -> str: return prompt def add_tools(self, tools: Sequence[MotleySupportedTool]): - """Description + """Add tools to the agent. Args: - tools (Sequence[MotleySupportedTool]): - - Returns: - + tools: The tools to add to the agent. """ if self.is_materialized and tools: raise CannotModifyMaterializedAgent(agent_name=self.name) @@ -242,14 +271,11 @@ def add_tools(self, tools: Sequence[MotleySupportedTool]): if motley_tool.name not in self.tools: self.tools[motley_tool.name] = motley_tool - def as_tool(self, input_schema: Optional[Type[BaseModel]] = None) -> MotleyTool: - """Description - - Args: - input_schema (:obj:`Type[BaseModel]`, optional): + def as_tool(self) -> MotleyTool: + """Convert the agent to a tool to be used by other agents via delegation. Returns: - MotleyTool: + The tool representation of the agent. """ if not self.description: @@ -271,42 +297,14 @@ def call_agent(*args, **kwargs): ).lower(), # OpenAI doesn't accept spaces in function names description=self.description, func=call_agent, - args_schema=input_schema, ) ) - # def call_as_tool(self, *args, **kwargs) -> Any: - # logger.info("Entering delegation for %s", self.name) - # assert self.crew, "can't accept delegated task outside of a crew" - # - # if len(args) > 0: - # input_ = args[0] - # elif "tool_input" in kwargs: - # # Is this a crewai notation? - # input_ = kwargs["tool_input"] - # else: - # input_ = json.dumps(kwargs) - # - # logger.info("Made the args: %s", input_) - # - # # TODO: pass context of parent task to agent nicely? - # # TODO: mark the current task as depending on the new task - # task = SimpleTaskRecipe( - # description=input_, - # name=input_, - # agent=self, - # # TODO inject the new subtask as a dep and reschedule the parent - # # TODO probably can't do this from here since we won't know if - # # there are other tasks to schedule - # crew=self.crew, - # ) - # - # # TODO: make sure tools return task objects, which are properly used by callers - # logger.info("Executing subtask '%s'", task.name) - # self.crew.task_graph.set_task_running(task=task) - # result = self.crew.execute(task, return_result=True) - # - # logger.info("Finished subtask '%s' - %s", task.name, result) - # self.crew.task_graph.set_task_done(task=task) - # - # return result + @abstractmethod + def invoke( + self, + input: dict, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + pass diff --git a/motleycrew/applications/research_agent/answer_task.py b/motleycrew/applications/research_agent/answer_task.py index 3cfcc800..aaf76413 100644 --- a/motleycrew/applications/research_agent/answer_task.py +++ b/motleycrew/applications/research_agent/answer_task.py @@ -1,31 +1,27 @@ -""" Module description""" - from typing import List, Optional + from langchain_core.runnables import Runnable -from motleycrew.crew import MotleyCrew -from motleycrew.tools import MotleyTool -from motleycrew.tasks import Task -from motleycrew.tasks.task_unit import TaskUnitType -from motleycrew.tasks import TaskUnit -from motleycrew.applications.research_agent.question import Question, QuestionAnsweringTaskUnit +from motleycrew.applications.research_agent.question import Question from motleycrew.applications.research_agent.question_answerer import AnswerSubQuestionTool -from motleycrew.storage import MotleyGraphStore from motleycrew.common import logger +from motleycrew.crew import MotleyCrew +from motleycrew.tasks import Task, TaskUnit +from motleycrew.tools import MotleyTool + + +class QuestionAnsweringTaskUnit(TaskUnit): + question: Question class AnswerTask(Task): + """Task to answer a question based on the notes and sub-questions.""" + def __init__( self, crew: MotleyCrew, answer_length: int = 1000, ): - """Description - - Args: - crew (MotleyCrew): - answer_length (:obj:`int`, optional): - """ super().__init__( name="AnswerTask", task_unit_class=QuestionAnsweringTaskUnit, @@ -38,11 +34,10 @@ def __init__( ) def get_next_unit(self) -> QuestionAnsweringTaskUnit | None: - """Description + """Choose an unanswered question to answer. + + The question should have a context and no unanswered subquestions.""" - Returns: - QuestionAnsweringTaskUnit | None: - """ query = ( "MATCH (n1:{}) " "WHERE n1.answer IS NULL AND n1.context IS NOT NULL " @@ -62,12 +57,4 @@ def get_next_unit(self) -> QuestionAnsweringTaskUnit | None: return None def get_worker(self, tools: Optional[List[MotleyTool]]) -> Runnable: - """Description - - Args: - tools (List[MotleyTool]): - - Returns: - Runnable: - """ return self.answerer diff --git a/motleycrew/applications/research_agent/question.py b/motleycrew/applications/research_agent/question.py index 9aecee81..3d95c87b 100644 --- a/motleycrew/applications/research_agent/question.py +++ b/motleycrew/applications/research_agent/question.py @@ -1,23 +1,13 @@ -""" Module description""" from typing import Optional -from dataclasses import dataclass -import json from motleycrew.storage.graph_node import MotleyGraphNode -from motleycrew.tasks import TaskUnit REPR_CONTEXT_LENGTH_LIMIT = 30 class Question(MotleyGraphNode): - """ Description + """Represents a question node in the graph.""" - Attributes: - question (str): - answer (:obj:`str`, optional): - context (:obj:`List[str]`, optional) - - """ question: str answer: Optional[str] = None context: Optional[list[str]] = None @@ -35,11 +25,3 @@ def __repr__(self): return "Question(id={}, question={}, answer={}, context={})".format( self.id, self.question, self.answer, context_repr ) - - -class QuestionGenerationTaskUnit(TaskUnit): - question: Question - - -class QuestionAnsweringTaskUnit(TaskUnit): - question: Question diff --git a/motleycrew/applications/research_agent/question_answerer.py b/motleycrew/applications/research_agent/question_answerer.py index 0e435df2..9301887b 100644 --- a/motleycrew/applications/research_agent/question_answerer.py +++ b/motleycrew/applications/research_agent/question_answerer.py @@ -1,24 +1,17 @@ -""" Module description - -Attributes: - _default_prompt: -""" -from langchain_core.pydantic_v1 import BaseModel, Field from langchain.prompts import PromptTemplate from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.tools import Tool +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import ( RunnablePassthrough, RunnableLambda, chain, ) +from langchain_core.tools import Tool -from motleycrew.tools import MotleyTool, LLMTool -from motleycrew.storage import MotleyGraphStore +from motleycrew.applications.research_agent.question import Question from motleycrew.common.utils import print_passthrough - -from motleycrew.applications.research_agent.question import Question, QuestionAnsweringTaskUnit - +from motleycrew.storage import MotleyGraphStore +from motleycrew.tools import MotleyTool, LLMTool _default_prompt = PromptTemplate.from_template( """ @@ -37,19 +30,14 @@ class AnswerSubQuestionTool(MotleyTool): + """Tool to answer a question based on the notes and sub-questions.""" + def __init__( self, graph: MotleyGraphStore, answer_length: int, prompt: str | BasePromptTemplate = None, ): - """ Description - - Args: - graph (MotleyGraphStore): - answer_length (int): - prompt (:obj:`str`, :obj:`BasePromptTemplate`, optional): - """ langchain_tool = create_answer_question_langchain_tool( graph=graph, answer_length=answer_length, @@ -60,27 +48,12 @@ def __init__( class QuestionAnswererInput(BaseModel, arbitrary_types_allowed=True): - """Data on the question to answer. - - Attributes: - question (Question): - """ - question: Question = Field( description="Question node to process.", ) def get_subquestions(graph: MotleyGraphStore, question: Question) -> list[Question]: - """ Description - - Args: - graph (MotleyGraphStore): - question (Question): - - Returns: - list[Question]: - """ query = ( "MATCH (n1:{})-[]->(n2:{}) " "WHERE n1.id = $question_id and n2.context IS NOT NULL " @@ -98,16 +71,6 @@ def create_answer_question_langchain_tool( answer_length: int, prompt: str | BasePromptTemplate = None, ) -> Tool: - """ Creates a LangChainTool for the AnswerSubQuestionTool. - - Args: - graph (MotleyGraphStore): - answer_length (int): - prompt (:obj:`str`, :obj:`BasePromptTemplate`, optional): - - Returns: - - """ if prompt is None: prompt = _default_prompt diff --git a/motleycrew/applications/research_agent/question_generator.py b/motleycrew/applications/research_agent/question_generator.py index 3ce49f1d..95ee6615 100644 --- a/motleycrew/applications/research_agent/question_generator.py +++ b/motleycrew/applications/research_agent/question_generator.py @@ -1,35 +1,22 @@ -""" Module description - -Attributes: - IS_SUBQUESTION_PREDICATE (str): - default_prompt (PromptTemplate): - -""" - from typing import Optional -from pathlib import Path -import time from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import PromptTemplate +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import ( RunnablePassthrough, RunnableLambda, ) from langchain_core.tools import Tool -from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.prompts import PromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field - -from motleycrew.tools import MotleyTool +from motleycrew.applications.research_agent.question import Question from motleycrew.common import LLMFramework +from motleycrew.common import logger from motleycrew.common.llms import init_llm from motleycrew.common.utils import print_passthrough from motleycrew.storage import MotleyGraphStore -from motleycrew.common import logger - - -from motleycrew.applications.research_agent.question import Question, QuestionGenerationTaskUnit +from motleycrew.tools import MotleyTool IS_SUBQUESTION_PREDICATE = "is_subquestion" @@ -49,9 +36,6 @@ """ ) -# " The new questions should have no semantic overlap with questions in the following list:\n" -# " {previous_questions}\n" - class QuestionGeneratorTool(MotleyTool): """ @@ -72,15 +56,6 @@ def __init__( llm: Optional[BaseLanguageModel] = None, prompt: str | BasePromptTemplate = None, ): - """Description - - Args: - query_tool (MotleyTool): - graph (MotleyGraphStore): - max_questions (:obj:`int`, optional): - llm (:obj:`BaseLanguageModel`, optional: - prompt (:obj:`str`, :obj:`BasePromptTemplate`, optional): - """ langchain_tool = create_question_generator_langchain_tool( query_tool=query_tool, graph=graph, @@ -93,11 +68,7 @@ def __init__( class QuestionGeneratorToolInput(BaseModel, arbitrary_types_allowed=True): - """Input for the Question Generator Tool. - - Attributes: - question (Question): - """ + """Input for the Question Generator Tool.""" question: Question = Field(description="The input question for which to generate subquestions.") @@ -109,18 +80,6 @@ def create_question_generator_langchain_tool( llm: Optional[BaseLanguageModel] = None, prompt: str | BasePromptTemplate = None, ): - """Description - - Args: - query_tool (MotleyTool): - graph (MotleyGraphStore): - max_questions (:obj:`int`, optional): - llm (:obj:`BaseLanguageModel`, optional: - prompt (:obj:`str`, :obj:`BasePromptTemplate`, optional): - - Returns: - - """ if llm is None: llm = init_llm(llm_framework=LLMFramework.LANGCHAIN) @@ -167,36 +126,3 @@ def set_context(input_dict: dict): and insert them into the knowledge graph.""", args_schema=QuestionGeneratorToolInput, ) - - -if __name__ == "__main__": - import kuzu - from llama_index.graph_stores.kuzu import KuzuGraphStore - - here = Path(__file__).parent - db_path = str(here / "test2") - - db = kuzu.Database(db_path) - graph_store = KuzuGraphStore(db) - - query_tool = MotleyTool.from_langchain_tool( - Tool.from_function( - func=lambda question: [ - "Germany has consisted of many different states over the years", - "The capital of France has moved in 1815, from Lyons to Paris", - "France actually has two capitals, one in the north and one in the south", - ], - name="Query Tool", - description="Query the library for relevant information.", - args_schema=QuestionGeneratorToolInput, - ) - ) - - tool = QuestionGeneratorTool( - query_tool=query_tool, - graph=graph_store, - max_questions=3, - ) - - tool.invoke({"question": "What is the capital of France?"}) - print("Done!") diff --git a/motleycrew/applications/research_agent/question_prioritizer.py b/motleycrew/applications/research_agent/question_prioritizer.py index acd9f9b9..f24cbbb6 100644 --- a/motleycrew/applications/research_agent/question_prioritizer.py +++ b/motleycrew/applications/research_agent/question_prioritizer.py @@ -1,35 +1,26 @@ -""" Module description - -Attributes: - _default_prompt (PromptTemplate): -""" -from langchain_core.pydantic_v1 import BaseModel, Field from langchain.prompts import PromptTemplate from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.tools import StructuredTool +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import ( RunnablePassthrough, RunnableLambda, chain, ) - -from motleycrew.tools import MotleyTool -from motleycrew.tools import LLMTool -from motleycrew.common.utils import print_passthrough +from langchain_core.tools import StructuredTool from motleycrew.applications.research_agent.question import Question +from motleycrew.common.utils import print_passthrough +from motleycrew.tools import LLMTool +from motleycrew.tools import MotleyTool class QuestionPrioritizerTool(MotleyTool): + """Tool to prioritize subquestions based on the original question.""" + def __init__( self, prompt: str | BasePromptTemplate = None, ): - """ Description - - Args: - prompt (:obj:`str`, :obj:`BasePromptTemplate`, optional): - """ langchain_tool = create_question_prioritizer_langchain_tool(prompt=prompt) super().__init__(langchain_tool) @@ -50,13 +41,8 @@ def __init__( class QuestionPrioritizerInput(BaseModel, arbitrary_types_allowed=True): - """ Description + """Input for the QuestionPrioritizerTool.""" - Attributes: - original_question (Question): - unanswered_questions (list[Question]): - - """ original_question: Question = Field(description="The original question.") unanswered_questions: list[Question] = Field( description="Questions to pick the most pertinent to the original question from.", @@ -66,14 +52,6 @@ class QuestionPrioritizerInput(BaseModel, arbitrary_types_allowed=True): def create_question_prioritizer_langchain_tool( prompt: str | BasePromptTemplate = None, ) -> StructuredTool: - """ Creates a LangChainTool for the AnswerSubQuestionTool. - - Args: - prompt (:obj:`str`, :obj:`BasePromptTemplate`, optional): - - Returns: - StructuredTool - """ if prompt is None: prompt = _default_prompt @@ -128,17 +106,3 @@ def get_most_pertinent_question(input_dict: dict): ) return langchain_tool - - -if __name__ == "__main__": - q = Question(question="What color is the sky?") - unanswered = [ - Question(question="What time of day is it?"), - Question(question="Who was H.P.Lovecraft?"), - ] - - out = QuestionPrioritizerTool().invoke( - {"unanswered_questions": unanswered, "original_question": q} - ) - print(out) - print("yay!") diff --git a/motleycrew/applications/research_agent/question_task.py b/motleycrew/applications/research_agent/question_task.py index 345d9caf..1ef85603 100644 --- a/motleycrew/applications/research_agent/question_task.py +++ b/motleycrew/applications/research_agent/question_task.py @@ -1,21 +1,24 @@ -""" Module description """ - from typing import List, Optional from langchain_core.runnables import Runnable -from motleycrew.tasks import Task +from motleycrew.common import logger +from motleycrew.crew import MotleyCrew +from motleycrew.tasks import Task, TaskUnit from motleycrew.tasks.task_unit import TaskUnitType from motleycrew.tools import MotleyTool -from motleycrew.crew import MotleyCrew -from motleycrew.common import TaskUnitStatus -from .question import Question, QuestionGenerationTaskUnit +from .question import Question from .question_generator import QuestionGeneratorTool from .question_prioritizer import QuestionPrioritizerTool -from motleycrew.common import logger + + +class QuestionGenerationTaskUnit(TaskUnit): + question: Question class QuestionTask(Task): + """Task to generate subquestions based on a given question.""" + def __init__( self, question: str, @@ -25,17 +28,6 @@ def __init__( allow_async_units: bool = False, name: str = "QuestionTask", ): - """Description - - Args: - question (str): - query_tool (MotleyTool): - crew (MotleyCrew): - max_iter (:obj:`int`, optional): - name (:obj:`str`, optional): - """ - # Need to supply the crew already at this stage - # because need to use the graph store in constructor super().__init__( name=name, task_unit_class=QuestionGenerationTaskUnit, @@ -53,11 +45,8 @@ def __init__( ) def get_next_unit(self) -> QuestionGenerationTaskUnit | None: - """Description + """Choose the most pertinent question to generate subquestions for.""" - Returns: - QuestionGenerationTaskUnit - """ if self.done or self.n_iter >= self.max_iter: return None @@ -82,42 +71,26 @@ def get_next_unit(self) -> QuestionGenerationTaskUnit | None: logger.info("Most pertinent question according to the tool: %s", most_pertinent_question) return QuestionGenerationTaskUnit(question=most_pertinent_question) - def register_started_unit(self, unit: TaskUnitType) -> None: - """Description + def on_unit_dispatch(self, unit: TaskUnitType) -> None: + """Increment the iteration count when a unit is dispatched.""" - Args: - unit (TaskUnitType): - - Returns: - - """ logger.info("==== Started iteration %s of %s ====", self.n_iter + 1, self.max_iter) self.n_iter += 1 - def register_completed_unit(self, unit: TaskUnitType) -> None: + def on_unit_completion(self, unit: TaskUnitType) -> None: + """Check if the task is done after each unit completion. + + The task is done if the maximum number of iterations is reached.""" + if self.n_iter >= self.max_iter: self.set_done(True) def get_worker(self, tools: Optional[List[MotleyTool]]) -> Runnable: - """Description + """Return the worker that will process the task units.""" - Args: - tools (List[MotleyTool]): - - Returns: - Runnable - """ return self.question_generation_tool def get_unanswered_questions(self, only_without_children: bool = False) -> list[Question]: - """Description - - Args: - only_without_children (:obj:`bool`, optional): - - Returns: - list[Question] - """ if only_without_children: query = ( "MATCH (n1:{}) WHERE n1.answer IS NULL AND NOT (n1)-[]->(:{}) RETURN n1;".format( diff --git a/motleycrew/caching/http_cache.py b/motleycrew/caching/http_cache.py deleted file mode 100644 index e69de29b..00000000 diff --git a/motleycrew/common/defaults.py b/motleycrew/common/defaults.py index d2e3fdb0..970acf0c 100644 --- a/motleycrew/common/defaults.py +++ b/motleycrew/common/defaults.py @@ -1,24 +1,9 @@ -""" Module description """ - from motleycrew.common import LLMFamily from motleycrew.common import GraphStoreType class Defaults: - """Description - - Attributes: - DEFAULT_LLM_FAMILY (str): - DEFAULT_LLM_NAME (str): - DEFAULT_LLM_TEMPERATURE (float): - LLM_MAP (dict): - DEFAULT_GRAPH_STORE_TYPE (str): - MODULE_INSTALL_COMMANDS (dict): - DEFAULT_NUM_THREADS (int): - DEFAULT_EVENT_LOOP_SLEEP (int): - DEFAULT_OUTPUT_HANDLER_MAX_ITERATIONS (int): - - """ + """Default values for various settings.""" DEFAULT_LLM_FAMILY = LLMFamily.OPENAI DEFAULT_LLM_NAME = "gpt-4o" diff --git a/motleycrew/common/enums.py b/motleycrew/common/enums.py index 2afa1827..a095eca4 100644 --- a/motleycrew/common/enums.py +++ b/motleycrew/common/enums.py @@ -1,60 +1,27 @@ -class LLMFamily: - """ Description +"""Various enums used in the project.""" - Attributes: - OPENAI (str): - """ +class LLMFamily: OPENAI = "openai" ANTHROPIC = "anthropic" class LLMFramework: - """ Description - - Attributes: - LANGCHAIN (str): - LLAMA_INDEX (str): - - """ LANGCHAIN = "langchain" LLAMA_INDEX = "llama_index" class GraphStoreType: - """ Description - - Attributes: - KUZU (str): - - """ KUZU = "kuzu" class TaskUnitStatus: - """Description - - Attributes: - PENDING (str): - RUNNING (str): - DONE (str): - """ PENDING = "pending" RUNNING = "running" DONE = "done" class LunaryRunType: - """ Description - - Attributes: - LLM (str): - AGENT (str): - TOOL (str): - CHAIN (str): - EMBED (str): - - """ LLM = "llm" AGENT = "agent" TOOL = "tool" @@ -63,15 +30,6 @@ class LunaryRunType: class LunaryEventName: - """ Description - - Attributes: - START (str): - END (str): - UPDATE (str): - ERROR (str): - - """ START = "start" END = "end" UPDATE = "update" @@ -79,14 +37,14 @@ class LunaryEventName: class AsyncBackend: - """ Backends for parallel launch + """Backends for parallel crew execution. Attributes: - ASYNCIO (str): Asynchronous startup using asyncio - THREADING (str): Running using threads - NONE (str): Synchronous startup - + ASYNCIO: Asynchronous execution using asyncio. + THREADING: Parallel execution using threads. + NONE: Synchronous execution. """ + ASYNCIO = "asyncio" THREADING = "threading" NONE = "none" diff --git a/motleycrew/common/exceptions.py b/motleycrew/common/exceptions.py index eb101b42..a02bc127 100644 --- a/motleycrew/common/exceptions.py +++ b/motleycrew/common/exceptions.py @@ -1,4 +1,4 @@ -""" Module description""" +"""Exceptions for motleycrew""" from typing import Any, Dict, Optional @@ -6,12 +6,7 @@ class LLMFamilyNotSupported(Exception): - """Description - - Args: - llm_framework (str): - llm_family (str): - """ + """Raised when an LLM family is not supported in motleycrew via a framework.""" def __init__(self, llm_framework: str, llm_family: str): self.llm_framework = llm_framework @@ -22,12 +17,9 @@ def __str__(self) -> str: class LLMFrameworkNotSupported(Exception): - def __init__(self, llm_framework: str): - """Description + """Raised when an LLM framework is not supported in motleycrew.""" - Args: - llm_framework (str): - """ + def __init__(self, llm_framework: str): self.llm_framework = llm_framework def __str__(self) -> str: @@ -35,11 +27,7 @@ def __str__(self) -> str: class AgentNotMaterialized(Exception): - """Description - - Args: - agent_name (str): - """ + """Raised when an attempt is made to use an agent that is not yet materialized.""" def __init__(self, agent_name: str): self.agent_name = agent_name @@ -49,18 +37,14 @@ def __str__(self) -> str: class CannotModifyMaterializedAgent(Exception): - """Description - - Args: - agent_name (str): - """ + """Raised when an attempt is made to modify a materialized agent, e.g. to add tools.""" def __init__(self, agent_name: str | None): self.agent_name = agent_name def __str__(self) -> str: return "Cannot modify agent{} as it is already materialized".format( - f" `{self.agent_name}`" if self.agent_name is not None else "" + f" '{self.agent_name}'" if self.agent_name is not None else "" ) @@ -69,13 +53,13 @@ class TaskDependencyCycleError(Exception): class IntegrationTestException(Exception): - """Integration tests exception - - Args: - test_names (list[str]): list of names of failed integration tests - """ + """One or more integration tests failed.""" def __init__(self, test_names: list[str]): + """ + Args: + test_names: List of names of failed integration tests. + """ self.test_names = test_names def __str__(self): @@ -83,32 +67,28 @@ def __str__(self): class IpynbIntegrationTestResultNotFound(Exception): - """Ipynb integration test not found result file exception - - Args: - ipynb_path (str): path to running ipynb - result_path (str): path to execution result file - """ + """Raised when the result file of an ipynb integration test run is not found.""" def __init__(self, ipynb_path: str, result_path: str): self.ipynb_path = ipynb_path self.result_path = result_path def __str__(self): - return "File result {} of the ipynb {} execution, not found.".format( + return "File {} with result of the ipynb {} execution not found.".format( self.result_path, self.ipynb_path ) class ModuleNotInstalled(Exception): - """Module not installed - - Args: - module_name (str): the name of the non-installed module - install_command (:obj:`str`, optional): the command to install + """Raised when trying to use some functionality that requires a module that is not installed. """ def __init__(self, module_name: str, install_command: str = None): + """ + Args: + module_name: Name of the module. + install_command: Command to install the module. + """ self.module_name = module_name self.install_command = install_command or Defaults.MODULE_INSTALL_COMMANDS.get( module_name, None @@ -139,13 +119,13 @@ def __str__(self): class InvalidOutput(Exception): - """Raised in output handlers when an agent's output is not accepted""" + """Raised in output handlers when an agent's output is not accepted.""" pass class OutputHandlerMaxIterationsExceeded(BaseException): - """Raised when the output handlers iteration limit is exceeded""" + """Raised when the output handler iterations limit is exceeded.""" def __init__( self, @@ -153,9 +133,17 @@ def __init__( last_call_kwargs: Dict[str, Any], last_exception: Exception, ): + """ + Args: + last_call_args: Positional arguments with which the output handler was last called. + last_call_kwargs: Keyword arguments with which the output handler was last called. + last_exception: Exception that occurred during the last output handler iteration. + """ self.last_call_args = last_call_args self.last_call_kwargs = last_call_kwargs self.last_exception = last_exception def __str__(self): - return "Maximum number of output handler iterations exceeded" + return "Maximum number of output handler iterations exceeded. Last exception: {}".format( + self.last_exception + ) diff --git a/motleycrew/common/llms.py b/motleycrew/common/llms.py index 1d551336..cef8a242 100644 --- a/motleycrew/common/llms.py +++ b/motleycrew/common/llms.py @@ -1,4 +1,5 @@ -""" Module description""" +"""Helper functions to initialize Language Models (LLMs) from different frameworks.""" + from motleycrew.common import Defaults from motleycrew.common import LLMFamily, LLMFramework from motleycrew.common.exceptions import LLMFamilyNotSupported, LLMFrameworkNotSupported @@ -10,15 +11,11 @@ def langchain_openai_llm( llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, **kwargs, ): - """ Description + """Initialize an OpenAI LLM client for use with Langchain. Args: - llm_name (:obj:`str`, optional): - llm_temperature (:obj:`float`, optional): - **kwargs: - - Returns: - + llm_name: Name of the LLM in OpenAI API. + llm_temperature: Temperature for the LLM. """ from langchain_openai import ChatOpenAI @@ -30,16 +27,13 @@ def llama_index_openai_llm( llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, **kwargs, ): - """ Description + """Initialize an OpenAI LLM client for use with LlamaIndex. Args: - llm_name (:obj:`str`, optional): - llm_temperature (:obj:`float`, optional): - **kwargs: - - Returns: - + llm_name: Name of the LLM in OpenAI API. + llm_temperature: Temperature for the LLM. """ + ensure_module_is_installed("llama_index") from llama_index.llms.openai import OpenAI @@ -51,6 +45,13 @@ def langchain_anthropic_llm( llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, **kwargs, ): + """Initialize an Anthropic LLM client for use with Langchain. + + Args: + llm_name: Name of the LLM in Anthropic API. + llm_temperature: Temperature for the LLM. + """ + from langchain_anthropic import ChatAnthropic return ChatAnthropic(model=llm_name, temperature=llm_temperature, **kwargs) @@ -61,6 +62,12 @@ def llama_index_anthropic_llm( llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, **kwargs, ): + """Initialize an Anthropic LLM client for use with LlamaIndex. + + Args: + llm_name: Name of the LLM in Anthropic API. + llm_temperature: Temperature for the LLM. + """ ensure_module_is_installed("llama_index") from llama_index.llms.anthropic import Anthropic @@ -82,20 +89,13 @@ def init_llm( llm_temperature: float = Defaults.DEFAULT_LLM_TEMPERATURE, **kwargs, ): - """ Description + """Initialize an LLM client for use with the specified framework and family. Args: - llm_framework (str): - llm_family (:obj:`str`, optional): - llm_name (:obj:`str`, optional): - llm_temperature (:obj:`float`, optional): - **kwargs: - - Raises: - LLMFamilyNotSupported - - Returns: - + llm_framework: Framework of the LLM client. + llm_family: Family of the LLM. + llm_name: Name of the LLM. + llm_temperature: Temperature for the LLM. """ func = Defaults.LLM_MAP.get((llm_framework, llm_family), None) diff --git a/motleycrew/common/logging.py b/motleycrew/common/logging.py index 3b665f64..e65af957 100644 --- a/motleycrew/common/logging.py +++ b/motleycrew/common/logging.py @@ -1,4 +1,4 @@ -""" Project logger configuration module +"""Project logger configuration module Attributes: logger (logging.Logger): project logger @@ -16,7 +16,7 @@ def configure_logging(verbose: bool = False, debug: bool = False): - """ Logging configuration + """Logging configuration Args: verbose (:obj:`bool`, optional): if true logging level = INFO diff --git a/motleycrew/common/types.py b/motleycrew/common/types.py index da3c0654..48bd1a62 100644 --- a/motleycrew/common/types.py +++ b/motleycrew/common/types.py @@ -1,4 +1,12 @@ -""" Module description""" +"""Various types and type protocols used in motleycrew. + +Attributes: + MotleySupportedTool: Type that represents a tool that is supported by motleycrew. + It includes tools from motleycrew, langchain, llama_index, and motleycrew agents. +""" + +from __future__ import annotations + from typing import TYPE_CHECKING, Union, Optional, Protocol, TypeVar if TYPE_CHECKING: @@ -26,8 +34,8 @@ class MotleyAgentFactory(Protocol[AgentType]): - """ - Type protocol for an agent factory. + """Type protocol for an agent factory. + It is a function that accepts tools as an argument and returns an agent instance of an appropriate class. """ diff --git a/motleycrew/common/utils.py b/motleycrew/common/utils.py index 507ccf84..bb96fca6 100644 --- a/motleycrew/common/utils.py +++ b/motleycrew/common/utils.py @@ -1,4 +1,4 @@ -""" Module description""" +"""Various helpers and utility functions used throughout the project.""" import sys from typing import Optional, Sequence import hashlib @@ -9,14 +9,8 @@ def to_str(value: str | BaseMessage | Sequence[str] | Sequence[BaseMessage]) -> str: - """ Description + """Converts a message to a string.""" - Args: - value (:obj:`str`, :obj:`BaseMessage`, :obj:`Sequence[str]`, :obj:`Sequence[BaseMessage]`): - - Returns: - str: - """ if isinstance(value, str): return value elif isinstance(value, BaseMessage): @@ -29,14 +23,8 @@ def to_str(value: str | BaseMessage | Sequence[str] | Sequence[BaseMessage]) -> def is_http_url(url): - """ Description - - Args: - url (str): + """Check if the URL is an HTTP URL.""" - Returns: - bool: - """ try: parsed_url = urlparse(url) return parsed_url.scheme in ["http", "https"] @@ -45,15 +33,8 @@ def is_http_url(url): def generate_hex_hash(data: str, length: Optional[int] = None): - """ Description - - Args: - data (str): - length (:obj:`int`, optional): - - Returns: + """Generate a SHA256 hex digest from the given data.""" - """ hash_obj = hashlib.sha256() hash_obj.update(data.encode("utf-8")) hex_hash = hash_obj.hexdigest() @@ -64,19 +45,17 @@ def generate_hex_hash(data: str, length: Optional[int] = None): def print_passthrough(x): + """A helper function useful for debugging LCEL chains. It just returns the input value. + + You can put a breakpoint in this function to debug the chain. + """ + return x def ensure_module_is_installed(module_name: str, install_command: str = None) -> None: - """ Checking the installation of the module - - Args: - module_name (str): - install_command (:obj:`str`, optional): + """Ensure that the given module is installed.""" - Raises: - ModuleNotInstalled: - """ module_path = sys.modules.get(module_name, None) if module_path is None: raise ModuleNotInstalled(module_name, install_command) diff --git a/motleycrew/crew/crew.py b/motleycrew/crew/crew.py index a5073259..ba13b726 100644 --- a/motleycrew/crew/crew.py +++ b/motleycrew/crew/crew.py @@ -1,7 +1,7 @@ import asyncio -from typing import Collection, Generator, Sequence, Optional, Any, List, Tuple import threading import time +from typing import Collection, Generator, Optional, Any from motleycrew.agents.parent import MotleyAgentParent from motleycrew.common import logger, AsyncBackend, Defaults @@ -26,14 +26,14 @@ def __init__( """Initialize the crew. Args: - graph_store (:obj:`MotleyGraphStore`, optional): The graph store to use. + graph_store: The graph store to use. If not provided, a new one will be created with default settings. - async_backend (:obj:`AsyncBackend`, optional): The type of async backend to use. - Defaults to `AsyncBackend.NONE`, which means the crew will run synchronously. - The other options are `AsyncBackend.ASYNCIO` and `AsyncBackend.THREADING`. + async_backend: The type of async backend to use. + Defaults to :obj:`AsyncBackend.NONE`, which means the crew will run synchronously. + The other options are :obj:`AsyncBackend.ASYNCIO` and :obj:`AsyncBackend.THREADING`. - num_threads (:obj:`int`, optional): + num_threads: The number of threads to use when running in threaded mode. """ if graph_store is None: @@ -72,8 +72,8 @@ def add_dependency(self, upstream: Task, downstream: Task): """Add a dependency between two tasks in the graph store. Args: - upstream (Task): The upstream task. - downstream (Task): The downstream task. + upstream: The upstream task. + downstream: The downstream task. """ self.graph_store.create_relation( upstream.node, downstream.node, label=Task.TASK_IS_UPSTREAM_LABEL @@ -85,7 +85,7 @@ def register_tasks(self, tasks: Collection[Task]): """Insert tasks into the crew's graph store. Args: - tasks (Collection[Task]): The tasks to register. + tasks: The tasks to register. """ for task in tasks: if task not in self.tasks: @@ -106,10 +106,10 @@ def _prepare_next_unit_for_dispatch( ) -> Generator[MotleyAgentParent, Task, TaskUnit]: """Retrieve and prepare the next unit for dispatch. Args: - running_sync_tasks (set): Collection of currently running forced synchronous tasks + running_sync_tasks: Collection of currently running forced synchronous tasks. Yields: - tuple: agent, task, unit to be dispatched + Agent, task, unit to be dispatched. """ available_tasks = self.get_available_tasks() logger.info("Available tasks: %s", available_tasks) @@ -138,7 +138,7 @@ def _prepare_next_unit_for_dispatch( logger.info("Assigned unit %s to agent %s, dispatching", next_unit, agent) next_unit.set_running() self.add_task_unit_to_graph(task=task, unit=next_unit) - task.register_started_unit(next_unit) + task.on_unit_dispatch(next_unit) yield agent, task, next_unit @@ -153,11 +153,11 @@ def _handle_task_unit_completion( """Handle task unit completion. Args: - task (Task): Task object - unit (TaskUnit): Task unit object - result (Any): Result of the task unit - running_sync_tasks (set): Collection of currently running forced synchronous tasks - done_units (list): List of completed task units + task: Task object. + unit: Task unit object. + result: Result of the task unit. + running_sync_tasks: Collection of currently running forced synchronous tasks. + done_units: List of completed task units. """ if task in running_sync_tasks: @@ -166,14 +166,14 @@ def _handle_task_unit_completion( unit.output = result logger.info("Task unit %s completed, marking as done", unit) unit.set_done() - task.register_completed_unit(unit) + task.on_unit_completion(unit) done_units.append(unit) def _run_sync(self) -> list[TaskUnit]: """Run the crew synchronously. Returns: - :obj:`list` of :obj:`TaskUnit`: List of completed task units + List of completed task units. """ done_units = [] while True: @@ -203,7 +203,7 @@ def _run_threaded(self) -> list[TaskUnit]: """Run the crew in threads. Returns: - :obj:`list` of :obj:`TaskUnit`: List of completed task units + List of completed task units. """ done_units = [] @@ -225,7 +225,7 @@ def _run_threaded(self) -> list[TaskUnit]: ): thread_pool.add_task_unit(agent, next_task, next_unit) - if thread_pool.is_completed(): + if thread_pool.is_completed: logger.info("Nothing left to do, exiting") return done_units @@ -241,7 +241,7 @@ async def _run_async(self) -> list[TaskUnit]: """Run the crew asynchronously. Returns: - :obj:`list` of :obj:`TaskUnit`: List of completed task units + List of completed task units. """ done_units = [] @@ -261,12 +261,8 @@ async def _run_async(self) -> list[TaskUnit]: done_units=done_units, ) - for agent, next_task, next_unit in self._prepare_next_unit_for_dispatch( - running_tasks - ): - async_task = asyncio.create_task( - MotleyCrew._async_invoke_agent(agent, next_unit) - ) + for agent, next_task, next_unit in self._prepare_next_unit_for_dispatch(running_tasks): + async_task = asyncio.create_task(MotleyCrew._async_invoke_agent(agent, next_unit)) async_units[async_task] = (next_task, next_unit) if not async_units: @@ -276,10 +272,11 @@ async def _run_async(self) -> list[TaskUnit]: await asyncio.sleep(Defaults.DEFAULT_EVENT_LOOP_SLEEP) def get_available_tasks(self) -> list[Task]: - """Get tasks that are available for dispatching units at the moment. + """Get tasks that are able to dispatch units at the moment. + These are tasks that have no upstream dependencies that are not done. Returns: - :obj:`list` of :obj:`Task`: List of tasks that are available + List of tasks that are available. """ query = ( "MATCH (downstream:{}) " @@ -292,17 +289,15 @@ def get_available_tasks(self) -> list[Task]: Task.NODE_CLASS.get_label(), Task.TASK_IS_UPSTREAM_LABEL, ) - available_task_nodes = self.graph_store.run_cypher_query( - query, container=Task.NODE_CLASS - ) + available_task_nodes = self.graph_store.run_cypher_query(query, container=Task.NODE_CLASS) return [task for task in self.tasks if task.node in available_task_nodes] def add_task_unit_to_graph(self, task: Task, unit: TaskUnitType): """Add a task unit to the graph and connect it to its task. Args: - task (Task): The task to which the unit belongs. - unit (TaskUnitType): The unit to add. + task: The task to which the unit belongs. + unit: The unit to add. """ assert isinstance(unit, task.task_unit_class) assert not unit.done @@ -314,6 +309,11 @@ def add_task_unit_to_graph(self, task: Task, unit: TaskUnitType): ) def get_extra_tools(self, task: Task) -> list[MotleyTool]: + """Get tools that should be added to the agent for a given task. + + Args: + task: The task for which to get extra tools. + """ # TODO: Smart tool selection goes here tools = [] tools += self.tools or [] diff --git a/motleycrew/crew/crew_threads.py b/motleycrew/crew/crew_threads.py index 2916b239..c19a0660 100644 --- a/motleycrew/crew/crew_threads.py +++ b/motleycrew/crew/crew_threads.py @@ -1,4 +1,4 @@ -"""Thread pool module for running agents""" +"""Thread pool module for running agents.""" import threading from enum import Enum @@ -24,14 +24,16 @@ class TaskUnitThreadState(Enum): class TaskUnitThread(threading.Thread): + """The thread class for running agents on task units.""" + def __init__(self, input_queue: Queue, output_queue: Queue, *args, **kwargs): - """The thread class for running task units. + """Initialize the thread. Args: - input_queue (Queue): queue of task units to complete - output_queue (Queue): queue of completed task units - *args: - **kwargs: + input_queue: Queue of task units to complete. + output_queue: Queue of completed task units. + *args: threading.Thread arguments. + **kwargs: threading.Thread keyword arguments. """ self.input_queue = input_queue self.output_queue = output_queue @@ -41,6 +43,7 @@ def __init__(self, input_queue: Queue, output_queue: Queue, *args, **kwargs): @property def state(self): + """State of the thread.""" return self._state def run(self) -> None: @@ -71,11 +74,13 @@ def run(self) -> None: class TaskUnitThreadPool: + """The thread pool class for running agents on task units.""" + def __init__(self, num_threads: int = Defaults.DEFAULT_NUM_THREADS): - """The thread pool class for performing task units. + """Initialize the thread pool. Args: - num_threads (int): number of threads to create + num_threads: Number of threads to create. """ self.num_threads = num_threads @@ -93,9 +98,9 @@ def add_task_unit(self, agent: Runnable, task: "Task", unit: "TaskUnit") -> None """Adds a task unit to the queue for execution. Args: - agent (Runnable): agent to run the task unit - task (Task): task to which the unit belongs - unit (TaskUnit): task unit to run + agent: Agent to run the task unit. + task: Task to which the unit belongs. + unit: Task unit to run. """ self._task_units_in_progress.append((task, unit)) self.input_queue.put((agent, task, unit)) @@ -104,7 +109,7 @@ def get_completed_task_units(self) -> List[Tuple["Task", "TaskUnit", Any]]: """Returns a list of completed task units with their results. Returns: - List[Tuple[Task, TaskUnit, Any]]: list of triplets of (task, task unit, result) + List of triplets of (task, task unit, result). """ completed_tasks = [] while not self.output_queue.empty(): @@ -127,10 +132,7 @@ def wait_and_close(self): for t in self._threads: t.join() + @property def is_completed(self) -> bool: - """Returns whether all task units have been completed. - - Returns: - bool: - """ + """Whether all task units have been completed.""" return not bool(self._task_units_in_progress) diff --git a/motleycrew/storage/graph_node.py b/motleycrew/storage/graph_node.py index 19c03fb8..b210a88a 100644 --- a/motleycrew/storage/graph_node.py +++ b/motleycrew/storage/graph_node.py @@ -1,19 +1,18 @@ -""" Module description - -Attributes: - MotleyGraphNodeType (TypeVar): - -""" - from typing import Optional, Any, TypeVar, TYPE_CHECKING +from abc import ABC from pydantic import BaseModel if TYPE_CHECKING: from motleycrew.storage import MotleyGraphStore -class MotleyGraphNode(BaseModel): - """Description""" +class MotleyGraphNode(BaseModel, ABC): + """Base class for describing nodes in the graph. + + Attributes: + __label__: Label of the node in the graph. If not set, the class name is used. + __graph_store__: Graph store in which the node is stored. + """ # Q: KuzuGraphNode a better name? Because def id is specific? # A: No, I think _id attribute is pretty universal @@ -22,19 +21,30 @@ class MotleyGraphNode(BaseModel): @property def id(self) -> Optional[Any]: + """Identifier of the node in the graph. + + The identifier is unique **among nodes of the same label**. + If the node is not inserted in the graph, the identifier is None. + """ return getattr(self, "_id", None) @property def is_inserted(self) -> bool: + """Whether the node is inserted in the graph.""" return self.id is not None @classmethod def get_label(cls) -> str: - """Description + """Get the label of the node. + + Labels can be viewed as node types in the graph. + Generally, the label is the class name, + but it can be overridden by setting the __label__ attribute. Returns: - str: + Label of the node. """ + # Q: why not @property def label(cls) -> str: return cls.__label__ or cls.__name__ ? # A: Because we want to be able to call this method without an instance # and properties can't be class methods since Python 3.12 @@ -43,6 +53,13 @@ def get_label(cls) -> str: return cls.__name__ def __setattr__(self, name, value): + """Set the attribute value + and update the property in the graph store if the node is inserted. + + Args: + name: Name of the attribute. + value: Value of the attribute. + """ super().__setattr__(name, value) if name not in self.model_fields: @@ -54,6 +71,16 @@ def __setattr__(self, name, value): self.__graph_store__.update_property(self, name) def __eq__(self, other): + """Comparison operator for nodes. + + Two nodes are considered equal if they have the same label and identifier. + + Args: + other: Node to compare with. + + Returns: + Whether the nodes are equal. + """ return self.is_inserted and self.get_label() == other.get_label() and self.id == other.id diff --git a/motleycrew/storage/graph_store.py b/motleycrew/storage/graph_store.py index dae6a3eb..34e37f27 100644 --- a/motleycrew/storage/graph_store.py +++ b/motleycrew/storage/graph_store.py @@ -1,50 +1,47 @@ -""" Module description """ from abc import ABC, abstractmethod from typing import Optional, Type from motleycrew.storage import MotleyGraphNode, MotleyGraphNodeType class MotleyGraphStore(ABC): + """Abstract class for a graph database store.""" + @abstractmethod def check_node_exists_by_class_and_id( self, node_class: Type[MotleyGraphNode], node_id: int ) -> bool: - """ Check if a node of given class with given id is present in the database. + """Check if a node of given class with given id is present in the database. Args: - node_class (Type[MotleyGraphNode]): - node_id (int): - - Returns: - bool: + node_class: Python class of the node + node_id: id of the node """ pass @abstractmethod def check_node_exists(self, node: MotleyGraphNode) -> bool: - """ Check if the given node is present in the database. + """Check if the given node is present in the database. Args: - node (MotleyGraphNode): + node: node to check Returns: - bool: + whether the node is present in the database """ + pass @abstractmethod def check_relation_exists( self, from_node: MotleyGraphNode, to_node: MotleyGraphNode, label: Optional[str] ) -> bool: - """ Check if a relation exists between two nodes with given label. + """Check if a relation exists between two nodes with given label. Args: - from_node (MotleyGraphNode): - to_node (MotleyGraphNode): - label (:obj:`str`, None): + from_node: starting node + to_node: ending node + label: relation label. If None, check if any relation exists between the nodes. - Returns: - bool: """ pass diff --git a/motleycrew/storage/graph_store_utils.py b/motleycrew/storage/graph_store_utils.py index 7152155e..56ab9174 100644 --- a/motleycrew/storage/graph_store_utils.py +++ b/motleycrew/storage/graph_store_utils.py @@ -1,4 +1,3 @@ -""" Module description """ from typing import Optional import tempfile import os @@ -6,21 +5,21 @@ from motleycrew.common import Defaults from motleycrew.common import GraphStoreType from motleycrew.common import logger -from motleycrew.storage import MotleyKuzuGraphStore +from motleycrew.storage import MotleyKuzuGraphStore, MotleyGraphStore def init_graph_store( graph_store_type: str = Defaults.DEFAULT_GRAPH_STORE_TYPE, db_path: Optional[str] = None, -): - """ Description +) -> MotleyGraphStore: + """Create and initialize a graph store with the given parameters. Args: - graph_store_type (:obj:`str`, optional): - db_path (:obj:`str`, optional): + graph_store_type: Type of the graph store to use. + db_path: Path to the database for the graph store. Returns: - + Initialized graph store. """ if graph_store_type == GraphStoreType.KUZU: import kuzu diff --git a/motleycrew/storage/kuzu_graph_store.py b/motleycrew/storage/kuzu_graph_store.py index abc7a098..2d8541ba 100644 --- a/motleycrew/storage/kuzu_graph_store.py +++ b/motleycrew/storage/kuzu_graph_store.py @@ -1,11 +1,12 @@ """ Code derived from: https://github.com/run-llama/llama_index/blob/802064aee72b03ab38ead0cda780cfa3e37ce728/llama-index-integrations/graph_stores/llama-index-graph-stores-kuzu/llama_index/graph_stores/kuzu/base.py + Kùzu graph store index. """ import json import os -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, Collection from kuzu import Connection, PreparedStatement, QueryResult @@ -16,13 +17,7 @@ class MotleyKuzuGraphStore(MotleyGraphStore): - """ - Attributes: - ID_ATTR (str): _id - JSON_CONTENT_PREFIX (str): "JSON__" - PYTHON_TO_CYPHER_TYPES_MAPPING (dict) - - """ + """Kuzu graph store implementation for motleycrew.""" ID_ATTR = "_id" @@ -43,7 +38,7 @@ def __init__(self, database: Any) -> None: """Initialize Kuzu graph store. Args: - database (Any): Kuzu database client + database: Kuzu database client """ self.database = database self.connection = Connection(database) @@ -64,11 +59,11 @@ def _execute_query( """Execute a query, logging it for debugging purposes. Args: - query (:obj:`str`, :obj:`PreparedStatement`): Cypher query or prepared statement - parameters (:obj:`dict` of :obj:`str` , :obj:`Any`): Query parameters + query: Cypher query or prepared statement. + parameters: Query parameters. Returns: - QueryResult: Query result + Query result. """ logger.debug("Executing query: %s", query) if parameters: @@ -77,14 +72,14 @@ def _execute_query( # TODO: retries? return self.connection.execute(query=query, parameters=parameters) - def _check_node_table_exists(self, label: str): - """Description + def _check_node_table_exists(self, label: str) -> bool: + """Check if a table for storing nodes with given label exists in the database. Args: - label (str): + label: Node label. Returns: - + Whether the table exists. """ return label in self.connection._get_node_table_names() @@ -93,16 +88,17 @@ def _check_rel_table_exists( from_label: Optional[str] = None, to_label: Optional[str] = None, rel_label: Optional[str] = None, - ): - """Description + ) -> bool: + """Check if a table for storing relations between nodes with given labels + exists in the database. Args: - from_label (:obj:`str`, optional): - to_label (:obj:`str`, optional): - rel_label (:obj:`str`, optional): + from_label: Label of the source node. + to_label: Label of the destination node. + rel_label: Label of the relation. Returns: - + Whether the table exists. """ for row in self.connection._get_rel_table_names(): if ( @@ -113,14 +109,14 @@ def _check_rel_table_exists( return True return False - def _get_node_property_names(self, label: str): - """Description + def _get_node_property_names(self, label: str) -> Collection[str]: + """Get the names of properties for nodes with given label. Args: - label (str): + label: Node label. Returns: - + Collection of property names. """ return self.connection._get_node_property_names(table_name=label) @@ -128,11 +124,11 @@ def ensure_node_table(self, node_class: Type[MotleyGraphNode]) -> str: """Create a table for storing nodes of that class if such does not already exist. If it does exist, create all missing columns. - Args: - node_class (Type[MotleyGraphNode]): - Returns: - str: Table name + Args: + node_class: Node Python class. + Returns: + Node table name. """ table_name = node_class.get_label() if not self._check_node_table_exists(table_name): @@ -163,17 +159,14 @@ def ensure_node_table(self, node_class: Type[MotleyGraphNode]) -> str: def ensure_relation_table( self, from_class: Type[MotleyGraphNode], to_class: Type[MotleyGraphNode], label: str - ): + ) -> None: """Create a table for storing relations from from_node-like nodes to to_node-like nodes, if such does not already exist. Args: - from_class (Type[MotleyGraphNode]): - to_class (Type[MotleyGraphNode]): - label (str): - - Returns: - + from_class: Source node Python class. + to_class: Destination node Python class. + label: Relation label. """ if not self._check_rel_table_exists( from_label=from_class.get_label(), to_label=to_class.get_label(), rel_label=label @@ -197,11 +190,11 @@ def check_node_exists_by_class_and_id( """Check if a node of given class with given id is present in the database. Args: - node_class (Type[MotleyGraphNode]): - node_id (int): + node_class: Node Python class. + node_id: Node id. Returns: - bool: + Whether the node exists in the database. """ if not self._check_node_table_exists(node_class.get_label()): return False @@ -216,10 +209,10 @@ def check_node_exists(self, node: MotleyGraphNode) -> bool: """Check if the given node is present in the database. Args: - node (MotleyGraphNode): + node: Node to check. Returns: - bool: + Whether the node exists in the database. """ if node.id is None: return False # for cases when id attribute is not set => node does not exist @@ -232,12 +225,12 @@ def check_relation_exists( """Check if a relation exists between two nodes with given label. Args: - from_node (MotleyGraphNode): - to_node (MotleyGraphNode): - label (:obj:`str`, None): + from_node: Source node. + to_node: Destination node. + label: Relation label. If None, any relation is taken into account. Returns: - bool: + Whether the relation exists in the database. """ if from_node.id is None or to_node.id is None: return False @@ -272,14 +265,13 @@ def get_node_by_class_and_id( self, node_class: Type[MotleyGraphNodeType], node_id: int ) -> Optional[MotleyGraphNodeType]: """Retrieve the node of given class with given id if it is present in the database. - Otherwise, return None. Args: - node_class (Type[MotleyGraphNodeType]): - node_id (int): + node_class: Node Python class. + node_id: Node id. Returns: - :obj:`MotleyGraphNodeType`, None: + Node object or None if it does not exist. """ if not self._check_node_table_exists(node_class.get_label()): return None @@ -299,13 +291,14 @@ def get_node_by_class_and_id( def insert_node(self, node: MotleyGraphNodeType) -> MotleyGraphNodeType: """Insert a new node and populate its id. - If node table or some columns do not exist, this method also creates them. + + If the node table or some columns do not exist, this method also creates them. Args: - node (MotleyGraphNodeType): + node: Node to insert. Returns: - MotleyGraphNodeType + Inserted node. """ assert node.id is None, "Entity has its id set, looks like it is already in the DB" @@ -336,15 +329,13 @@ def create_relation( self, from_node: MotleyGraphNode, to_node: MotleyGraphNode, label: str ) -> None: """Create a relation between existing nodes. - If relation table does not exist, this method also creates them - - Args: - from_node (MotleyGraphNode): - to_node (MotleyGraphNode): - label (str): - Returns: + If the relation table does not exist, this method also creates it. + Args: + from_node: Source node. + to_node: Destination node. + label: Relation label. """ assert self.check_node_exists(from_node), ( "From-node is not present in the database, " @@ -380,18 +371,18 @@ def create_relation( assert create_result.has_next() logger.info("Relation created OK") - def upsert_triplet(self, from_node: MotleyGraphNode, to_node: MotleyGraphNode, label: str): + def upsert_triplet( + self, from_node: MotleyGraphNode, to_node: MotleyGraphNode, label: str + ) -> None: """Create a relation with a given label between nodes, if such does not already exist. + If the nodes do not already exist, create them too. - This method also creates and/or updates all necessary tables + This method also creates and/or updates all necessary tables. Args: - from_node (MotleyGraphNode): - to_node (MotleyGraphNode): - label (str): - - Returns: - + from_node: Source node. + to_node: Destination node. + label: Relation label. """ if not self.check_node_exists(from_node): logger.info("Node %s does not exist, creating", from_node) @@ -409,10 +400,7 @@ def delete_node(self, node: MotleyGraphNode) -> None: """Delete a given node and its relations. Args: - node (MotleyGraphNode): - - Returns: - + node: Node to delete. """ def inner_delete_relations(node_label: str, node_id: int) -> None: @@ -449,11 +437,11 @@ def update_property(self, node: MotleyGraphNode, property_name: str) -> MotleyGr """Update a graph node's property with the corresponding value from the node object. Args: - node (MotleyGraphNode): - property_name (str): + node: Node to update. + property_name: Property name to update. Returns: - + Updated node. """ property_value = getattr(node, property_name) @@ -506,15 +494,16 @@ def run_cypher_query( container: Optional[Type[MotleyGraphNodeType]] = None, ) -> list[list | MotleyGraphNodeType]: """Run a Cypher query and return the results. + If container class is provided, deserialize the results into objects of that class. Args: - query (:obj:`dict`, None): - parameters (:obj:`dict`, optional): - container (:obj:`Type[MotleyGraphNodeType]`, optional): + query: Cypher query. + parameters: Query parameters. + container: Node class to deserialize the results into. If None, return raw results. Returns: - + List of query results. """ query_result = self._execute_query(query=query, parameters=parameters) retval = [] @@ -530,14 +519,14 @@ def run_cypher_query( def _deserialize_node( self, node_dict: dict, node_class: Type[MotleyGraphNode] ) -> MotleyGraphNode: - """Description + """Deserialize a node from a dictionary. Args: - node_dict (dict): - node_class (Type[MotleyGraphNode]): + node_dict: Dictionary representation of the node. + node_class: Node class. Returns: - MotleyGraphNode + Deserialized node. """ for field_name, value in node_dict.copy().items(): if isinstance(value, str) and value.startswith( @@ -561,26 +550,23 @@ def _deserialize_node( @staticmethod def _set_node_id(node: MotleyGraphNode, node_id: Optional[int]) -> None: - """Description + """Set the id of the node. Args: - node (MotleyGraphNode): - node_id (:obj:`int`, optional): - - Returns: - + node: Node. + node_id: Node id. """ setattr(node, MotleyKuzuGraphStore.ID_ATTR, node_id) @staticmethod def _node_to_cypher_mapping_with_parameters(node: MotleyGraphNode) -> tuple[str, dict]: - """Description + """Convert a node to a Cypher mapping and parameters. Args: - node (MotleyGraphNode): + node: Node to convert. Returns: - :obj:`tuple` of :obj:`str`, :obj:`dict`: + A tuple of Cypher mapping and parameters. """ node_dict = node.model_dump() @@ -613,10 +599,10 @@ def _get_cypher_type_and_is_json_by_python_type_annotation( and whether the data should be stored in JSON-serialized strings. Args: - annotation (Type): + annotation: Python type annotation. Returns: - :obj:`tuple` of :obj:`str`, :obj:`bool`: + A tuple of Cypher type and whether the data should be stored in JSON-serialized strings. """ cypher_type = MotleyKuzuGraphStore.PYTHON_TO_CYPHER_TYPES_MAPPING.get(annotation) if not cypher_type: @@ -635,11 +621,12 @@ def from_persist_dir( """Load from persist dir. Args: - persist_dir (str): + persist_dir (str): Persist directory. Returns: - MotleyKuzuGraphStore: + Graph store. """ + try: import kuzu except ImportError: @@ -652,9 +639,9 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "MotleyKuzuGraphStore": """Initialize graph store from configuration dictionary. Args: - config_dict (dict): Configuration dictionary. + config_dict: Configuration dictionary. Returns: - MotleyKuzuGraphStore:Graph store. + Graph store. """ return cls(**config_dict) diff --git a/motleycrew/tasks/__init__.py b/motleycrew/tasks/__init__.py index 6916196a..e48fdc92 100644 --- a/motleycrew/tasks/__init__.py +++ b/motleycrew/tasks/__init__.py @@ -1,6 +1,4 @@ +from motleycrew.tasks.simple import SimpleTask +from motleycrew.tasks.task import Task from motleycrew.tasks.task_unit import TaskUnit from motleycrew.tasks.task_unit import TaskUnitType -from motleycrew.tasks.task import Task -from motleycrew.tasks.simple import SimpleTask - -__all__ = ["TaskUnit", "TaskUnitType", "Task", "SimpleTask"] diff --git a/motleycrew/tasks/simple.py b/motleycrew/tasks/simple.py index dd045386..c15d1848 100644 --- a/motleycrew/tasks/simple.py +++ b/motleycrew/tasks/simple.py @@ -1,48 +1,52 @@ -""" Module description - -Attributes: - PROMPT_TEMPLATE_WITH_DEPS (str): - -""" - from __future__ import annotations from typing import TYPE_CHECKING, Any, Sequence, List, Optional +from langchain_core.prompts import PromptTemplate + from motleycrew.agents.abstract_parent import MotleyAgentAbstractParent from motleycrew.common import logger -from motleycrew.tasks import TaskUnit from motleycrew.tasks.task import Task +from motleycrew.tasks.task_unit import TaskUnit from motleycrew.tools import MotleyTool if TYPE_CHECKING: from motleycrew.crew import MotleyCrew -PROMPT_TEMPLATE_WITH_DEPS = """ -{description} +PROMPT_TEMPLATE_WITH_UPSTREAM_TASKS = PromptTemplate.from_template( + """{description} You must use the results of these upstream tasks: -{upstream_results_section} +{upstream_results} """ +) def compose_simple_task_prompt_with_dependencies( description: str, upstream_task_units: List[TaskUnit], + prompt_template_with_upstreams: PromptTemplate, default_task_name: str = "Unnamed task", ) -> str: - """Description + """Compose a prompt for a simple task with upstream dependencies. Args: - description (str): - upstream_task_units (:obj:`list` of :obj:`TaskUnit`): - default_task_name (:obj:`str`, optional): - - Returns: - str: + description: Description of the task, to be included in the prompt. + upstream_task_units: List of upstream task units whose results should be used. + prompt_template_with_upstreams: Prompt template to use for generating the prompt + if the task has upstream dependencies. Otherwise, just the description is used. + The template must have input variables 'description' and 'upstream_results'. + default_task_name: Name to use for task units that don't have a ``name`` attribute. """ + if set(prompt_template_with_upstreams.input_variables) != { + "description", + "upstream_results", + }: + raise ValueError( + "Prompt template must have input variables 'description' and 'upstream_results'" + ) upstream_results = [] for unit in upstream_task_units: if not unit.output: @@ -55,19 +59,19 @@ def compose_simple_task_prompt_with_dependencies( return description upstream_results_section = "\n\n".join(upstream_results) - return PROMPT_TEMPLATE_WITH_DEPS.format( + return prompt_template_with_upstreams.format( description=description, - upstream_results_section=upstream_results_section, + upstream_results=upstream_results_section, ) class SimpleTaskUnit(TaskUnit): - """Description + """Task unit for a simple task. Attributes: - name (str): - prompt (str): - + name: Name of the task unit. + prompt: Prompt for the task unit. + additional_params: Additional parameters for the task unit (can be used by the agent). """ name: str @@ -76,6 +80,15 @@ class SimpleTaskUnit(TaskUnit): class SimpleTask(Task): + """Simple task class. + + A simple task consists of a description and an agent that can execute the task. + It produces a single task unit with a prompt based on the description + and the results of upstream tasks. + + The task is considered done when the task unit is completed. + """ + def __init__( self, crew: MotleyCrew, @@ -83,37 +96,39 @@ def __init__( name: str | None = None, agent: MotleyAgentAbstractParent | None = None, tools: Sequence[MotleyTool] | None = None, - documents: Sequence[Any] | None = None, additional_params: dict[str, Any] | None = None, + prompt_template_with_upstreams: PromptTemplate = PROMPT_TEMPLATE_WITH_UPSTREAM_TASKS, ): - """Description + """Initialize the simple task. Args: - crew (MotleyCrew): - description (str): - name (:obj:`str`, optional): - agent (:obj:`MotleyAgentAbstractParent`, optional): - tools (:obj:`Sequence[MotleyTool]`, optional): - documents (:obj:`Sequence[Any]`, optional): - additional_kwargs (:obj:`dict`, optional): + crew: Crew to which the task belongs. + description: Description of the task. + name: Name of the task (will be used as the name of the task unit). + agent: Agent to execute the task. + tools: Tools to use for the task. + additional_params: Additional parameters for the task. + prompt_template_with_upstreams: Prompt template to use for generating the prompt + if the task has upstream dependencies. Otherwise, just the description is used. + The template must have input variables 'description' and 'upstream_results'. """ + super().__init__(name=name or description, task_unit_class=SimpleTaskUnit, crew=crew) self.description = description self.agent = agent # to be auto-assigned at crew creation if missing? self.tools = tools or [] - # should tasks own agents or should agents own tasks? - self.documents = documents # to be passed to an auto-init'd retrieval, later on self.additional_params = additional_params or {} - self.output = None # to be filled in by the agent(s) once the task is complete + self.prompt_template_with_upstreams = prompt_template_with_upstreams - def register_completed_unit(self, unit: SimpleTaskUnit) -> None: - """Description + self.output = None # to be filled in by the agent(s) once the task is complete - Args: - unit (SimpleTaskUnit): + def on_unit_completion(self, unit: SimpleTaskUnit) -> None: + """Handle completion of the task unit. - Returns: + Sets the task as done and stores the output of the task unit. + Args: + unit: Task unit that has completed. """ assert isinstance(unit, SimpleTaskUnit) assert unit.done @@ -122,10 +137,14 @@ def register_completed_unit(self, unit: SimpleTaskUnit) -> None: self.set_done() def get_next_unit(self) -> SimpleTaskUnit | None: - """Description + """Get the next task unit to run. + + If all upstream tasks are done, returns a task unit with the prompt + based on the description and the results of the upstream tasks. + Otherwise, returns None (the task is not ready to run yet). Returns: - :obj:`SimpleTaskUnit`, None: + Task unit to run if the task is ready, None otherwise. """ if self.done: logger.info("Task %s is already done", self) @@ -136,7 +155,11 @@ def get_next_unit(self) -> SimpleTaskUnit | None: return None upstream_task_units = [unit for task in upstream_tasks for unit in task.get_units()] - prompt = compose_simple_task_prompt_with_dependencies(self.description, upstream_task_units) + prompt = compose_simple_task_prompt_with_dependencies( + description=self.description, + upstream_task_units=upstream_task_units, + prompt_template_with_upstreams=self.prompt_template_with_upstreams, + ) return SimpleTaskUnit( name=self.name, prompt=prompt, @@ -144,13 +167,16 @@ def get_next_unit(self) -> SimpleTaskUnit | None: ) def get_worker(self, tools: Optional[List[MotleyTool]]) -> MotleyAgentAbstractParent: - """Description + """Get the worker for the task. + + If the task is associated with a crew and an agent, returns the agent. + Otherwise, raises an exception. Args: - tools (:obj:`List[MotleyTool]`, :obj:`None`): + tools: Additional tools to add to the agent. Returns: - MotleyAgentAbstractParent + Agent to run the task unit. """ if self.crew is None: raise ValueError("Task is not associated with a crew") diff --git a/motleycrew/tasks/task.py b/motleycrew/tasks/task.py index 3b567923..798a06bc 100644 --- a/motleycrew/tasks/task.py +++ b/motleycrew/tasks/task.py @@ -1,19 +1,13 @@ -""" Module description - -Attributes: - TaskNodeType (TypeVar): - -""" - from __future__ import annotations from abc import ABC, abstractmethod from typing import Optional, Sequence, List, Type, TypeVar, Generic, TYPE_CHECKING from langchain_core.runnables import Runnable + from motleycrew.common.exceptions import TaskDependencyCycleError from motleycrew.storage import MotleyGraphStore, MotleyGraphNode, MotleyKuzuGraphStore -from motleycrew.tasks import TaskUnitType +from motleycrew.tasks.task_unit import TaskUnitType from motleycrew.tools import MotleyTool if TYPE_CHECKING: @@ -21,11 +15,11 @@ class TaskNode(MotleyGraphNode): - """Description + """Node representing a task in the graph. Attributes: - name (str): - done (bool): + name: Name of the task. + done: Whether the task is done. """ @@ -41,10 +35,13 @@ def __eq__(self, other): class Task(ABC, Generic[TaskUnitType]): - """ + """Base class for describing tasks. + + This class is abstract and must be subclassed to implement the task logic. + Attributes: - NODE_CLASS (TaskNodeType): - TASK_IS_UPSTREAM_LABEL (str): + NODE_CLASS: Class for representing task nodes, can be overridden. + TASK_IS_UPSTREAM_LABEL: Label for indicating upstream tasks, can be overridden. """ NODE_CLASS: Type[TaskNodeType] = TaskNode @@ -57,13 +54,16 @@ def __init__( crew: Optional[MotleyCrew] = None, allow_async_units: bool = False, ): - """Description + """Initialize the task. Args: - name (str): - task_unit_class (Type[TaskUnitType]): - crew (:obj:`MotleyCrew`, optional): - allow_async_units (:obj:'bool', optional) + name: Name of the task. + task_unit_class: Class for representing task units. + crew: Crew to which the task belongs. + If not provided, the task should be registered with a crew later. + allow_async_units: Whether the task allows asynchronous units. + Default is False. If True, the task may be queried for the next unit even if it + has other units in progress. """ self.name = name self.done = False @@ -79,11 +79,7 @@ def __init__( self.prepare_graph_store() def prepare_graph_store(self): - """Description - - Returns: - - """ + """Prepare the graph store for storing tasks and their units.""" if isinstance(self.graph_store, MotleyKuzuGraphStore): self.graph_store.ensure_node_table(self.NODE_CLASS) self.graph_store.ensure_node_table(self.task_unit_class) @@ -95,6 +91,10 @@ def prepare_graph_store(self): @property def graph_store(self) -> MotleyGraphStore: + """The graph store where the task is stored. + + This is an alias for the graph store of the crew that the task belongs to. + """ if self.crew is None: raise ValueError("Task must be registered with a crew for accessing graph store") return self.crew.graph_store @@ -106,13 +106,13 @@ def __str__(self) -> str: return self.__repr__() def set_upstream(self, task: Task) -> Task: - """Description + """Set a task as an upstream task for the current task. - Args: - task (Task): + This means that the current task will not be queried for task units + until the upstream task is marked as done. - Returns: - Task: + Args: + task: Upstream task. """ if self.crew is None or task.crew is None: raise ValueError("Both tasks must be registered with a crew") @@ -125,6 +125,11 @@ def set_upstream(self, task: Task) -> Task: return self def __rshift__(self, other: Task | Sequence[Task]) -> Task: + """Syntactic sugar for setting tasks order with the ``>>`` operator. + + Args: + other: Task or sequence of tasks to set as downstream. + """ if isinstance(other, Task): tasks = {other} else: @@ -136,19 +141,25 @@ def __rshift__(self, other: Task | Sequence[Task]) -> Task: return self def __rrshift__(self, other: Sequence[Task]) -> Sequence[Task]: + """Syntactic sugar for setting tasks order with the ``>>`` operator. + + Args: + other: Task or sequence of tasks to set as upstream. + """ for task in other: self.set_upstream(task) return other def get_units(self, status: Optional[str] = None) -> List[TaskUnitType]: - """ - Description + """Get the units of the task that are already inserted in the graph. + + This method should be used for fetching the existing task units. Args: - status (str | None): if provided, return only units with this status + status: Status of the task units to filter by. Returns: - :obj:`list` of :obj:`TaskUnitType`: + List of task units. """ assert self.crew is not None, "Task must be registered with a crew for accessing task units" @@ -172,10 +183,10 @@ def get_units(self, status: Optional[str] = None) -> List[TaskUnitType]: return task_units def get_upstream_tasks(self) -> List[Task]: - """Description + """Get the upstream tasks of the current task. Returns: - :obj:`list` of :obj:`Task` + List of upstream tasks. """ assert ( self.crew is not None and self.node.is_inserted @@ -195,10 +206,10 @@ def get_upstream_tasks(self) -> List[Task]: return [task for task in self.crew.tasks if task.node in upstream_task_nodes] def get_downstream_tasks(self) -> List[Task]: - """Description + """Get the downstream tasks of the current task. Returns: - :obj:`list` of :obj:`Task` + List of downstream tasks. """ assert ( self.crew is not None and self.node.is_inserted @@ -218,56 +229,69 @@ def get_downstream_tasks(self) -> List[Task]: return [task for task in self.crew.tasks if task.node in downstream_task_nodes] def set_done(self, value: bool = True): - """Description + """Set the done status of the task. Args: - value (bool): - - Returns: - + value: Value to set the done status to. """ self.done = value self.node.done = value - def register_started_unit(self, unit: TaskUnitType) -> None: - """Description + def on_unit_dispatch(self, unit: TaskUnitType) -> None: + """Method that is called by the crew when a unit of the task is dispatched. - Args: - unit (TaskUnitType): - - Returns: + Should be implemented by the subclass if needed. + Args: + unit: Task unit that is dispatched. """ pass - def register_completed_unit(self, unit: TaskUnitType) -> None: - """Description + def on_unit_completion(self, unit: TaskUnitType) -> None: + """Method that is called by the crew when a unit of the task is completed. - Args: - unit (TaskUnitType): - - Returns: + Should be implemented by the subclass if needed. + Args: + unit: Task unit that is completed. """ pass @abstractmethod def get_next_unit(self) -> TaskUnitType | None: - """Description + """Get the next unit of the task to run. Must be implemented by the subclass. + + This method is called in the crew's main loop repeatedly while the task is not done + and there are units in progress. + + **Note that returning a unit does not guarantee that it will be dispatched.** + Because of this, any state changes are strongly discouraged in this method. + If you need to perform some actions when the unit is dispatched or completed, + you should implement the ``on_unit_dispatch`` and/or ``on_unit_completion`` methods. + + If you need to find which units already exist in order to generate the next one, + you can use the ``get_units`` method. Returns: - :obj:`TaskUnitType` | None: + Next unit to run, or None if there are no units to run at the moment. """ + pass @abstractmethod def get_worker(self, tools: Optional[List[MotleyTool]]) -> Runnable: - """Description + """Get the worker that will run the task units. + + This method is called by the crew when a unit of the task is dispatched. + The unit will be converted to a dictionary and passed to the worker's ``invoke`` method. + + Typically, the worker is an agent, but it can be any object + that implements the Langchain Runnable interface. Args: - tools (:obj:`List[MotleyTool]`, None): + tools: Tools to be used by the worker. Returns: - Runnable: + Worker that will run the task units. """ pass diff --git a/motleycrew/tasks/task_unit.py b/motleycrew/tasks/task_unit.py index 3d4eab98..8d532a26 100644 --- a/motleycrew/tasks/task_unit.py +++ b/motleycrew/tasks/task_unit.py @@ -1,9 +1,3 @@ -""" Module description - -Attributes: - TaskUnitType (TypeVar): - -""" from __future__ import annotations from abc import ABC @@ -14,13 +8,17 @@ class TaskUnit(MotleyGraphNode, ABC): - """ Description + """Base class for describing task units. + A task unit should contain all the input data for the worker (usually an agent). + When a task unit is dispatched by the crew, it is converted to a dictionary + and passed to the worker's ``invoke()`` method. Attributes: - status (:obj:`str`, optional): - output (:obj:`Any`, optional): + status: Status of the task unit. + output: Output of the task unit. """ + status: str = TaskUnitStatus.PENDING output: Optional[Any] = None @@ -35,23 +33,29 @@ def __eq__(self, other: TaskUnit): @property def pending(self): + """Whether the task unit is pending.""" return self.status == TaskUnitStatus.PENDING @property def running(self): + """Whether the task unit is running.""" return self.status == TaskUnitStatus.RUNNING @property def done(self): + """Whether the task unit is done.""" return self.status == TaskUnitStatus.DONE def set_pending(self): + """Set the task unit status to pending.""" self.status = TaskUnitStatus.PENDING def set_running(self): + """Set the task unit status to running.""" self.status = TaskUnitStatus.RUNNING def set_done(self): + """Set the task unit status to done.""" self.status = TaskUnitStatus.DONE def as_dict(self): diff --git a/motleycrew/tools/autogen_chat_tool.py b/motleycrew/tools/autogen_chat_tool.py index f3e46f49..a551b12b 100644 --- a/motleycrew/tools/autogen_chat_tool.py +++ b/motleycrew/tools/autogen_chat_tool.py @@ -1,10 +1,9 @@ -""" Module description """ from typing import Optional, Type, Callable, Any -from langchain_core.tools import StructuredTool from langchain_core.prompts import PromptTemplate from langchain_core.prompts.base import BasePromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field, create_model +from langchain_core.tools import StructuredTool try: from autogen import ConversableAgent, ChatResult @@ -17,20 +16,14 @@ def get_last_message(chat_result: ChatResult) -> str: - """ Description - - Args: - chat_result (ChatResult): - - Returns: - str: - """ for message in reversed(chat_result.chat_history): if message.get("content") and "TERMINATE" not in message["content"]: return message["content"] class AutoGenChatTool(MotleyTool): + """A tool for incorporating AutoGen chats into MotleyCrew.""" + def __init__( self, name: str, @@ -41,16 +34,19 @@ def __init__( result_extractor: Callable[[ChatResult], Any] = get_last_message, input_schema: Optional[Type[BaseModel]] = None, ): - """ Description - + """ Args: - name (str): - description (str): - prompt (:obj:`str`, :obj:`BasePromptTemplate`): - initiator (ConversableAgent): - recipient (ConversableAgent): - result_extractor (:obj:`Callable[[ChatResult]`, :obj:`Any`, optional): - input_schema (:obj:`Type[BaseModel]`, optional): + name: Name of the tool. + description: Description of the tool. + prompt: Prompt to use for the tool. Can be a string or a PromptTemplate object. + initiator: The agent initiating the chat. + recipient: The first recipient agent. + This is the agent that you would specify in ``initiate_chat`` arguments. + result_extractor: Function to extract the result from the chat result. + input_schema: Input schema for the tool. + The input variables should match the variables in the prompt. + If not provided, a schema will be generated based on the input variables + in the prompt, if any, with string fields. """ ensure_module_is_installed("autogen") langchain_tool = create_autogen_chat_tool( @@ -74,20 +70,6 @@ def create_autogen_chat_tool( result_extractor: Callable[[ChatResult], Any], input_schema: Optional[Type[BaseModel]] = None, ): - """ Description - - Args: - name (str): - description (str): - prompt (:obj:`str`, :obj:`BasePromptTemplate`): - initiator (ConversableAgent): - recipient (ConversableAgent): - result_extractor (:obj:`Callable[[ChatResult]`, :obj:`Any`, optional): - input_schema (:obj:`Type[BaseModel]`, optional): - - Returns: - - """ if not isinstance(prompt, BasePromptTemplate): prompt = PromptTemplate.from_template(prompt) diff --git a/motleycrew/tools/code/aider_tool.py b/motleycrew/tools/code/aider_tool.py index 590ddc7d..efcacb65 100644 --- a/motleycrew/tools/code/aider_tool.py +++ b/motleycrew/tools/code/aider_tool.py @@ -15,14 +15,9 @@ class AiderTool(MotleyTool): + """Tool for code generation using Aider.""" def __init__(self, model: str = None, **kwargs): - """Tool for code generation using Aider. - - Args: - model (str): model name - **kwargs: - """ ensure_module_is_installed("aider") model = model or Defaults.DEFAULT_LLM_NAME @@ -34,21 +29,12 @@ def __init__(self, model: str = None, **kwargs): class AiderToolInput(BaseModel): - """Input for the Aider tool. - - Attributes: - with_message (str): - """ + """Input for the Aider tool.""" with_message: str = Field(description="instructions for code generation") def create_aider_tool(coder: Coder): - """Create langchain tool from Aider Coder.run() method - - Returns: - Tool: - """ return Tool.from_function( func=coder.run, name="aider tool", diff --git a/motleycrew/tools/code/postgresql_linter.py b/motleycrew/tools/code/postgresql_linter.py index 7d534f23..ac033978 100644 --- a/motleycrew/tools/code/postgresql_linter.py +++ b/motleycrew/tools/code/postgresql_linter.py @@ -1,5 +1,5 @@ -from langchain_core.tools import Tool from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import Tool try: from pglast import parse_sql, prettify @@ -14,10 +14,9 @@ class PostgreSQLLinterTool(MotleyTool): + """PostgreSQL code verification tool.""" def __init__(self): - """PostgreSQL code verification tool - """ ensure_module_is_installed("pglast") langchain_tool = create_pgsql_linter_tool() @@ -25,21 +24,12 @@ def __init__(self): class PostgreSQLLinterInput(BaseModel): - """Input for the PostgreSQLLinterTool. - - Attributes: - query (str): - """ + """Input for the PostgreSQLLinterTool.""" query: str = Field(description="SQL code for validation") def create_pgsql_linter_tool() -> Tool: - """Create the underlying langchain tool for PostgreSQLLinterTool - - Returns: - Tool: - """ def parse_func(query: str) -> str: try: parse_sql(query) diff --git a/motleycrew/tools/code/python_linter.py b/motleycrew/tools/code/python_linter.py index b70c967c..f2af73ec 100644 --- a/motleycrew/tools/code/python_linter.py +++ b/motleycrew/tools/code/python_linter.py @@ -1,8 +1,8 @@ import os from typing import Union -from langchain_core.tools import StructuredTool from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.tools import StructuredTool try: from aider.linter import Linter @@ -14,10 +14,9 @@ class PythonLinterTool(MotleyTool): + """Python code verification tool""" def __init__(self): - """Python code verification tool - """ ensure_module_is_installed("aider") langchain_tool = create_python_linter_tool() @@ -25,24 +24,13 @@ def __init__(self): class PythonLinterInput(BaseModel): - """Input for the PythonLinterTool. - - Attributes: - code (str): - file_name (str): - """ + """Input for the PythonLinterTool.""" code: str = Field(description="Python code for linting") file_name: str = Field(description="file name for the code", default="code.py") def create_python_linter_tool() -> StructuredTool: - """Create the underlying langchain tool for PythonLinterTool - - Returns: - Tool: - """ - def lint(code: str, file_name: str = None) -> Union[str, None]: # create temp python file temp_file_name = file_name or "code.py" @@ -50,7 +38,7 @@ def lint(code: str, file_name: str = None) -> Union[str, None]: if file_ext != ".py": raise ValueError("The file extension must be .py") - with open(temp_file_name, 'w') as f: + with open(temp_file_name, "w") as f: f.write(code) # lint code diff --git a/motleycrew/tools/html_render_tool.py b/motleycrew/tools/html_render_tool.py index f730b318..873a97af 100644 --- a/motleycrew/tools/html_render_tool.py +++ b/motleycrew/tools/html_render_tool.py @@ -19,14 +19,15 @@ class HTMLRenderer: + """Helper for rendering HTML code as an image.""" + def __init__( self, work_dir: str, - executable_path: str | None = None, + chromedriver_path: str | None = None, headless: bool = True, window_size: Optional[Tuple[int, int]] = None, ): - """Helper for rendering HTML code as an image""" ensure_module_is_installed( "selenium", "see documentation: https://pypi.org/project/selenium/, ChromeDriver is also required", @@ -39,21 +40,21 @@ def __init__( self.options = webdriver.ChromeOptions() if headless: self.options.add_argument("--headless") - self.service = Service(executable_path=executable_path) + self.service = Service(executable_path=chromedriver_path) self.window_size = window_size def render_image(self, html: str, file_name: str | None = None): - """Create image with png extension from html code + """Create a PNG image from HTML code. Args: - html (str): html code for rendering image - file_name (str): file name with not extension + html (str): HTML code for rendering image. + file_name (str): File name without extension. Returns: - file path to created image + Path to the rendered image. """ logger.info("Trying to render image from HTML code") - html_path, image_path = self.build_save_file_paths(file_name) + html_path, image_path = self.build_file_paths(file_name) browser = webdriver.Chrome(options=self.options, service=self.service) try: if self.window_size: @@ -80,15 +81,8 @@ def render_image(self, html: str, file_name: str | None = None): return image_path - def build_save_file_paths(self, file_name: str | None = None) -> Tuple[str, str]: - """Builds paths to html and image files - - Args: - file_name (str): file name with not extension - - Returns: - tuple[str, str]: html file path and image file path - """ + def build_file_paths(self, file_name: str | None = None) -> Tuple[str, str]: + """Builds paths to html and image files""" # check exists dirs: for _dir in (self.work_dir, self.html_dir, self.images_dir): @@ -103,22 +97,23 @@ def build_save_file_paths(self, file_name: str | None = None) -> Tuple[str, str] class HTMLRenderTool(MotleyTool): + """Tool for rendering HTML as image.""" def __init__( self, work_dir: str, - executable_path: str | None = None, + chromedriver_path: str | None = None, headless: bool = True, window_size: Optional[Tuple[int, int]] = None, ): - """Tool for rendering HTML as image - + """ Args: - work_dir (str): Directory for saving images and html files + work_dir: Directory for saving images and HTML files. + chromedriver_path: Path to the ChromeDriver executable. """ renderer = HTMLRenderer( work_dir=work_dir, - executable_path=executable_path, + chromedriver_path=chromedriver_path, headless=headless, window_size=window_size, ) @@ -127,21 +122,12 @@ def __init__( class HTMLRenderToolInput(BaseModel): - """Input for the HTMLRenderTool. - - Attributes: - html (str): - """ + """Input for the HTMLRenderTool.""" html: str = Field(description="HTML code for rendering") def create_render_tool(renderer: HTMLRenderer): - """Create langchain tool from HTMLRenderer.render_image method - - Returns: - Tool: - """ return Tool.from_function( func=renderer.render_image, name="HTML rendering tool", diff --git a/motleycrew/tools/image/dall_e.py b/motleycrew/tools/image/dall_e.py index 064785ec..da133d33 100644 --- a/motleycrew/tools/image/dall_e.py +++ b/motleycrew/tools/image/dall_e.py @@ -1,37 +1,21 @@ -""" Module description - -Attributes: - prompt_template (str): - dall_e_template (str): -""" +import mimetypes +import os from typing import Optional -import os import requests -import mimetypes - from langchain.agents import Tool -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain.prompts import PromptTemplate from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper +from langchain_core.pydantic_v1 import BaseModel, Field -from motleycrew.tools.tool import MotleyTool import motleycrew.common.utils as motley_utils from motleycrew.common import LLMFramework -from motleycrew.common.llms import init_llm from motleycrew.common import logger -from langchain.prompts import PromptTemplate +from motleycrew.common.llms import init_llm +from motleycrew.tools.tool import MotleyTool def download_image(url: str, file_path: str) -> Optional[str]: - """ Description - - Args: - url (str): - file_path (str): - - Returns: - :obj:`str`, None: - """ response = requests.get(url, stream=True) if response.status_code == requests.codes.ok: content_type = response.headers.get("content-type") @@ -51,29 +35,48 @@ def download_image(url: str, file_path: str) -> Optional[str]: logger.error("Failed to download image. Status code: %s", response.status_code) +DEFAULT_REFINE_PROMPT = """Generate a detailed DALL-E prompt to generate an image +based on the following description: +```{text}``` +Your output MUST NOT exceed 3500 characters""" + +DEFAULT_DALL_E_PROMPT = """{text} +Note: Do not include any text in the images. +""" + + class DallEImageGeneratorTool(MotleyTool): + """A tool for generating images using the OpenAI DALL-E API. + + See the OpenAI API reference for more information: + https://platform.openai.com/docs/guides/images/usage + """ + def __init__( self, images_directory: Optional[str] = None, refine_prompt_with_llm: bool = True, + dall_e_prompt_template: str | PromptTemplate = DEFAULT_DALL_E_PROMPT, + refine_prompt_template: str | PromptTemplate = DEFAULT_REFINE_PROMPT, model: str = "dall-e-3", quality: str = "standard", size: str = "1024x1024", style: Optional[str] = None, ): - """ Description - + """ Args: - images_directory (:obj:`str`, optional): - refine_prompt_with_llm (:obj:`bool`, optional): - model (:obj:`str`, optional): - quality (:obj:`str`, optional): - size (:obj:`str`, optional): - style (:obj:`str`, optional): + images_directory: Directory to save the generated images. + refine_prompt_with_llm: Whether to refine the prompt using a language model. + model: DALL-E model to use. + quality: Image quality. Can be "standard" or "hd". + size: Image size. + style: Style to use for the model. """ langchain_tool = create_dalle_image_generator_langchain_tool( images_directory=images_directory, refine_prompt_with_llm=refine_prompt_with_llm, + dall_e_prompt_template=dall_e_prompt_template, + refine_prompt_template=refine_prompt_template, model=model, quality=quality, size=size, @@ -83,55 +86,27 @@ def __init__( class DallEToolInput(BaseModel): - """Input for the Dall-E tool. - - Attributes: - description (str): - """ + """Input for the Dall-E tool.""" description: str = Field(description="image description") -prompt_template = """Generate a detailed DALL-E prompt to generate an image -based on the following description: -```{text}``` -Your output MUST NOT exceed 3500 characters""" - -dall_e_template = """{text} -Note: Do not include any text in the images. -""" - - def run_dalle_and_save_images( description: str, images_directory: Optional[str] = None, refine_prompt_with_llm: bool = True, + dall_e_prompt_template: str | PromptTemplate = DEFAULT_DALL_E_PROMPT, + refine_prompt_template: str | PromptTemplate = DEFAULT_REFINE_PROMPT, model: str = "dall-e-3", quality: str = "standard", size: str = "1024x1024", style: Optional[str] = None, file_name_length: int = 8, ) -> Optional[list[str]]: - """ Description - - Args: - description (str): - images_directory (:obj:`str`, optional): - refine_prompt_with_llm(:obj:`bool`, optional): - model (:obj:`str`, optional): - quality (:obj:`str`, optional): - size (:obj:`str`, optional): - style (:obj:`str`, optional): - file_name_length (:obj:`int`, optional): - - Returns: - :obj:`list` of :obj:`str`: - """ - - dall_e_prompt = PromptTemplate.from_template(dall_e_template) + dall_e_prompt = PromptTemplate.from_template(dall_e_prompt_template) if refine_prompt_with_llm: - prompt = PromptTemplate.from_template(template=prompt_template) + prompt = PromptTemplate.from_template(refine_prompt_template) llm = init_llm(llm_framework=LLMFramework.LANGCHAIN) dall_e_prompt = prompt | llm | (lambda x: {"text": x.content}) | dall_e_prompt @@ -172,29 +147,20 @@ def run_dalle_and_save_images( def create_dalle_image_generator_langchain_tool( images_directory: Optional[str] = None, refine_prompt_with_llm: bool = True, + dall_e_prompt_template: str | PromptTemplate = DEFAULT_DALL_E_PROMPT, + refine_prompt_template: str | PromptTemplate = DEFAULT_REFINE_PROMPT, model: str = "dall-e-3", quality: str = "standard", size: str = "1024x1024", style: Optional[str] = None, ): - """ Description - - Args: - images_directory (:obj:`str`, optional): - refine_prompt_with_llm (:obj:`bool`, optional): - model (:obj:`str`, optional): - quality (:obj:`str`, optional): - size (:obj:`str`, optional): - style (:obj:`str`, optional): - - Returns: - Tool: - """ def run_dalle_and_save_images_partial(description: str): return run_dalle_and_save_images( description=description, images_directory=images_directory, refine_prompt_with_llm=refine_prompt_with_llm, + dall_e_prompt_template=dall_e_prompt_template, + refine_prompt_template=refine_prompt_template, model=model, quality=quality, size=size, @@ -208,9 +174,3 @@ def run_dalle_and_save_images_partial(description: str): "Input should be an image description.", args_schema=DallEToolInput, ) - - -if __name__ == "__main__": - tool = DallEImageGeneratorTool() - out = tool.invoke("A beautiful castle on top of a hill at sunset") - logger.info(out) diff --git a/motleycrew/tools/llm_tool.py b/motleycrew/tools/llm_tool.py index cabfc40e..b208c0e7 100644 --- a/motleycrew/tools/llm_tool.py +++ b/motleycrew/tools/llm_tool.py @@ -1,18 +1,19 @@ -""" Module description""" from typing import Optional, Type -from langchain_core.tools import StructuredTool +from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate from langchain_core.prompts.base import BasePromptTemplate -from langchain_core.language_models import BaseLanguageModel from langchain_core.pydantic_v1 import BaseModel, Field, create_model +from langchain_core.tools import StructuredTool -from motleycrew.tools import MotleyTool from motleycrew.common import LLMFramework from motleycrew.common.llms import init_llm +from motleycrew.tools import MotleyTool class LLMTool(MotleyTool): + """A tool that uses a language model to generate output based on a prompt.""" + def __init__( self, name: str, @@ -21,14 +22,16 @@ def __init__( llm: Optional[BaseLanguageModel] = None, input_schema: Optional[Type[BaseModel]] = None, ): - """ Description - + """ Args: - name (str): - description (str): - prompt (:obj:`str`, :obj:`BasePromptTemplate`): - llm (:obj:`BaseLanguageModel`, optional): - input_schema (:obj:`Type[BaseModel]`, optional): + name: Name of the tool. + description: Description of the tool. + prompt: Prompt to use for the tool. Can be a string or a PromptTemplate object. + llm: Language model to use for the tool. + input_schema: Input schema for the tool. + The input variables should match the variables in the prompt. + If not provided, a schema will be generated based on the input variables + in the prompt, if any, with string fields. """ langchain_tool = create_llm_langchain_tool( name=name, @@ -47,18 +50,6 @@ def create_llm_langchain_tool( llm: Optional[BaseLanguageModel] = None, input_schema: Optional[Type[BaseModel]] = None, ): - """ Description - - Args: - name (str): - description (str): - prompt (:obj:`str`, :obj:`BasePromptTemplate`): - llm (:obj:`BaseLanguageModel`, optional): - input_schema (:obj:`Type[BaseModel]`, optional): - - Returns: - - """ if llm is None: llm = init_llm(llm_framework=LLMFramework.LANGCHAIN) diff --git a/motleycrew/tools/mermaid_evaluator_tool.py b/motleycrew/tools/mermaid_evaluator_tool.py index b25f70fa..3e8540bd 100644 --- a/motleycrew/tools/mermaid_evaluator_tool.py +++ b/motleycrew/tools/mermaid_evaluator_tool.py @@ -1,24 +1,19 @@ -""" Module description """ # https://nodejs.org/en/download # npm install -g @mermaid-js/mermaid-cli +import io import os.path import subprocess -import io import tempfile from typing import Optional from langchain_core.pydantic_v1 import create_model, Field from langchain_core.tools import Tool + from motleycrew.tools import MotleyTool class MermaidEvaluatorTool(MotleyTool): def __init__(self, format: Optional[str] = "svg"): - """ Description - - Args: - format (:obj:`str`, None): - """ def eval_mermaid_partial(mermaid_code: str): return eval_mermaid(mermaid_code, format) @@ -35,15 +30,6 @@ def eval_mermaid_partial(mermaid_code: str): def eval_mermaid(mermaid_code: str, format: Optional[str] = "svg") -> io.BytesIO: - """ Description - - Args: - mermaid_code (str): - format (:obj:`str`, optional): - - Returns: - io.BytesIO: - """ with tempfile.NamedTemporaryFile(delete=True, mode="w+", suffix=".mmd") as temp_in: temp_in.write(mermaid_code) temp_in.flush() # Ensure all data is written to disk diff --git a/motleycrew/tools/python_repl.py b/motleycrew/tools/python_repl.py index c03c1aac..eee7b432 100644 --- a/motleycrew/tools/python_repl.py +++ b/motleycrew/tools/python_repl.py @@ -1,4 +1,3 @@ -""" Module description """ from langchain.agents import Tool from langchain_experimental.utilities import PythonREPL from langchain_core.pydantic_v1 import BaseModel, Field @@ -7,31 +6,24 @@ class PythonREPLTool(MotleyTool): - def __init__(self): - """ Description + """Python REPL tool. Use this to execute python commands. + + Note that the tool's output is the content printed to stdout by the executed code. + Because of this, any data you want to be in the output should be printed using `print(...)`. + """ - """ + def __init__(self): langchain_tool = create_repl_tool() super().__init__(langchain_tool) class REPLToolInput(BaseModel): - """Input for the REPL tool. - - Attributes: - command (str): - """ + """Input for the REPL tool.""" command: str = Field(description="code to execute") -# You can create the tool to pass to an agent def create_repl_tool(): - """ Description - - Returns: - Tool: - """ return Tool.from_function( func=PythonREPL().run, name="python_repl", diff --git a/motleycrew/tools/simple_retriever_tool.py b/motleycrew/tools/simple_retriever_tool.py index 03976294..cba1be25 100644 --- a/motleycrew/tools/simple_retriever_tool.py +++ b/motleycrew/tools/simple_retriever_tool.py @@ -1,76 +1,59 @@ -""" Module description """ import os.path from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import StructuredTool - -from llama_index.core.node_parser import SentenceSplitter -from llama_index.embeddings.openai import OpenAIEmbedding - from llama_index.core import ( VectorStoreIndex, SimpleDirectoryReader, StorageContext, load_index_from_storage, ) +from llama_index.core.node_parser import SentenceSplitter +from llama_index.embeddings.openai import OpenAIEmbedding -from motleycrew.tools import MotleyTool from motleycrew.applications.research_agent.question import Question +from motleycrew.tools import MotleyTool class SimpleRetrieverTool(MotleyTool): - def __init__(self, DATA_DIR, PERSIST_DIR, return_strings_only: bool = False): - """ Description + """A simple retriever tool that retrieves relevant documents from a local knowledge base.""" + def __init__(self, data_dir: str, persist_dir: str, return_strings_only: bool = False): + """ Args: - DATA_DIR (str): - PERSIST_DIR (str): - return_strings_only (:obj:`bool`, optional): + data_dir: Path to the directory containing the documents. + persist_dir: Path to the directory to store the index. + return_strings_only: Whether to return only the text of the retrieved documents. """ tool = make_retriever_langchain_tool( - DATA_DIR, PERSIST_DIR, return_strings_only=return_strings_only + data_dir, persist_dir, return_strings_only=return_strings_only ) super().__init__(tool) class RetrieverToolInput(BaseModel, arbitrary_types_allowed=True): - """Input for the Retriever Tool. - - Attributes: - question (Question): - - """ + """Input for the retriever tool.""" question: Question = Field( description="The input question for which to retrieve relevant data." ) -def make_retriever_langchain_tool(DATA_DIR, PERSIST_DIR, return_strings_only: bool = False): - """ Description - - Args: - DATA_DIR (str): - PERSIST_DIR (str): - return_strings_only (:obj:`bool`, optional): - - Returns: - - """ +def make_retriever_langchain_tool(data_dir, persist_dir, return_strings_only: bool = False): text_embedding_model = "text-embedding-ada-002" embeddings = OpenAIEmbedding(model=text_embedding_model) - if not os.path.exists(PERSIST_DIR): + if not os.path.exists(persist_dir): # load the documents and create the index - documents = SimpleDirectoryReader(DATA_DIR).load_data() + documents = SimpleDirectoryReader(data_dir).load_data() index = VectorStoreIndex.from_documents( documents, transformations=[SentenceSplitter(chunk_size=512), embeddings] ) # store it for later - index.storage_context.persist(persist_dir=PERSIST_DIR) + index.storage_context.persist(persist_dir=persist_dir) else: # load the existing index - storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR) + storage_context = StorageContext.from_defaults(persist_dir=persist_dir) index = load_index_from_storage(storage_context) retriever = index.as_retriever( @@ -92,18 +75,3 @@ def call_retriever(question: Question) -> list: args_schema=RetrieverToolInput, ) return retriever_tool - - -if __name__ == "__main__": - - # check if storage already exists - here = os.path.dirname(os.path.abspath(__file__)) - DATA_DIR = os.path.join(here, "mahabharata/text/TinyTales") - - PERSIST_DIR = "../../examples/research_agent/storage" - - retriever_tool = SimpleRetrieverTool(DATA_DIR, PERSIST_DIR) - response2 = retriever_tool.invoke( - {"question": Question(question="What are the most interesting facts about Arjuna?")} - ) - print(response2) diff --git a/motleycrew/tools/tool.py b/motleycrew/tools/tool.py index ec03b3c7..1d19d384 100644 --- a/motleycrew/tools/tool.py +++ b/motleycrew/tools/tool.py @@ -1,7 +1,5 @@ -""" Module description """ - -from typing import Union, Optional, Dict, Any from typing import Callable +from typing import Union, Optional, Dict, Any from langchain.tools import BaseTool from langchain_core.runnables import Runnable, RunnableConfig @@ -19,13 +17,16 @@ class MotleyTool(Runnable): + """Base tool class compatible with MotleyAgents. + + It is a wrapper for Langchain BaseTool, containing all necessary adapters and converters. + """ def __init__(self, tool: BaseTool): - """Base tool class compatible with MotleyAgents. - It is a wrapper for LangChain's BaseTool, containing all necessary adapters and converters. + """Initialize the MotleyTool. Args: - tool (BaseTool): + tool: Langchain BaseTool to wrap. """ self.tool = tool @@ -37,15 +38,17 @@ def __str__(self): @property def name(self): - # TODO: do we really want to make a thin wrapper in this fashion? + """Name of the tool.""" return self.tool.name @property def description(self): + """Description of the tool.""" return self.tool.description @property def args_schema(self): + """Schema of the tool arguments.""" return self.tool.args_schema def invoke( @@ -61,39 +64,43 @@ def _run(self, *args: tuple, **kwargs: Dict[str, Any]) -> Any: @staticmethod def from_langchain_tool(langchain_tool: BaseTool) -> "MotleyTool": - """Description + """Create a MotleyTool from a Langchain tool. Args: - langchain_tool (BaseTool): + langchain_tool: Langchain tool to convert. Returns: - MotleyTool: + MotleyTool instance. """ + return MotleyTool(tool=langchain_tool) @staticmethod def from_llama_index_tool(llama_index_tool: LlamaIndex__BaseTool) -> "MotleyTool": - """Description + """Create a MotleyTool from a LlamaIndex tool. Args: - llama_index_tool (LlamaIndex__BaseTool): + llama_index_tool: LlamaIndex tool to convert. Returns: - MotleyTool: + MotleyTool instance. """ + ensure_module_is_installed("llama_index") langchain_tool = llama_index_tool.to_langchain_tool() return MotleyTool.from_langchain_tool(langchain_tool=langchain_tool) @staticmethod def from_supported_tool(tool: MotleySupportedTool) -> "MotleyTool": - """Description + """Create a MotleyTool from any supported tool type. Args: - tool (:obj:`MotleyTool`, :obj:`BaseTool`, :obj:`LlamaIndex__BaseTool`, :obj:`MotleyAgentAbstractParent`): + tool: Tool of any supported type. + Currently, we support tools from Langchain, LlamaIndex, + as well as motleycrew agents. Returns: - + MotleyTool instance. """ if isinstance(tool, MotleyTool): return tool @@ -109,18 +116,18 @@ def from_supported_tool(tool: MotleySupportedTool) -> "MotleyTool": ) def to_langchain_tool(self) -> BaseTool: - """Description + """Convert the MotleyTool to a Langchain tool. Returns: - BaseTool: + Langchain tool. """ return self.tool def to_llama_index_tool(self) -> LlamaIndex__BaseTool: - """Description + """Convert the MotleyTool to a LlamaIndex tool. Returns: - LlamaIndex__BaseTool: + LlamaIndex tool. """ ensure_module_is_installed("llama_index") llama_index_tool = LlamaIndex__FunctionTool.from_defaults( @@ -132,10 +139,14 @@ def to_llama_index_tool(self) -> LlamaIndex__BaseTool: return llama_index_tool def to_autogen_tool(self) -> Callable: - """Description + """Convert the MotleyTool to an AutoGen tool. + + An AutoGen tool is basically a function. AutoGen infers the tool input schema + from the function signature. For this reason, because we can't generate the signature + dynamically, we can only convert tools with a single input field. Returns: - Callable: + AutoGen tool function. """ fields = list(self.tool.args_schema.__fields__.values()) if len(fields) != 1: diff --git a/tests/test_agents/test_llama_index_output_handler.py b/tests/test_agents/test_llama_index_output_handler.py index fd451617..7ba66af5 100644 --- a/tests/test_agents/test_llama_index_output_handler.py +++ b/tests/test_agents/test_llama_index_output_handler.py @@ -57,7 +57,7 @@ def agent(): ) agent.materialize() agent._agent._run_step = fake_run_step - agent._agent._run_step = agent.run_step_decorator()(agent._agent._run_step) + agent._agent._run_step = agent._run_step_decorator()(agent._agent._run_step) return agent diff --git a/tests/test_crew/__init__.py b/tests/test_crew/__init__.py index a10b94c4..d8114faf 100644 --- a/tests/test_crew/__init__.py +++ b/tests/test_crew/__init__.py @@ -1,6 +1,7 @@ import pytest from motleycrew.crew import MotleyCrew +from motleycrew.tasks import SimpleTask class AgentMock: @@ -39,6 +40,6 @@ def tasks(self, request, crew, agent): tasks = [] for i in range(num_tasks): description = "task{} description".format(self.num_task) - tasks.append(crew.create_simple_task(description=description, agent=agent)) + tasks.append(SimpleTask(description=description, agent=agent, crew=crew)) CrewFixtures.num_task += 1 return tasks diff --git a/tests/test_crew/test_crew_threads.py b/tests/test_crew/test_crew_threads.py index acfa5325..fda0b622 100644 --- a/tests/test_crew/test_crew_threads.py +++ b/tests/test_crew/test_crew_threads.py @@ -27,7 +27,7 @@ def test_init_thread_pool(self, thread_pool): assert all([t.is_alive() for t in thread_pool._threads]) assert thread_pool.input_queue.empty() assert thread_pool.output_queue.empty() - assert thread_pool.is_completed() + assert thread_pool.is_completed @pytest.mark.parametrize("tasks", [4], indirect=True) def test_put(self, thread_pool, agent, tasks): @@ -35,7 +35,7 @@ def test_put(self, thread_pool, agent, tasks): unit = task.get_next_unit() thread_pool.add_task_unit(agent, task, unit) - assert not thread_pool.is_completed() + assert not thread_pool.is_completed assert len(thread_pool._task_units_in_progress) == 4 @pytest.mark.parametrize("tasks", [4], indirect=True) @@ -49,7 +49,7 @@ def test_get_completed_tasks(self, thread_pool, agent, tasks): assert len(completed_tasks) == 4 assert len(thread_pool._task_units_in_progress) == 0 - assert thread_pool.is_completed() + assert thread_pool.is_completed assert all([t.state == TaskUnitThreadState.EXITED for t in thread_pool._threads]) @pytest.mark.parametrize("tasks", [1], indirect=True) @@ -61,7 +61,7 @@ def test_get_completed_task_exception(self, thread_pool, agent, tasks): with pytest.raises(AttributeError): thread_pool.get_completed_task_units() - assert not thread_pool.is_completed() + assert not thread_pool.is_completed def test_close(self, thread_pool): thread_pool.wait_and_close() diff --git a/tests/test_tasks/test_simple_task.py b/tests/test_tasks/test_simple_task.py index b4850111..f5732986 100644 --- a/tests/test_tasks/test_simple_task.py +++ b/tests/test_tasks/test_simple_task.py @@ -1,9 +1,8 @@ import pytest - from langchain_community.tools import DuckDuckGoSearchRun -from motleycrew.crew import MotleyCrew from motleycrew.agents.langchain.tool_calling_react import ReActToolCallingAgent +from motleycrew.crew import MotleyCrew from motleycrew.storage.graph_store_utils import init_graph_store from motleycrew.tasks.simple import ( SimpleTask, @@ -50,9 +49,9 @@ def test_register_completed_unit(self, tasks, crew): unit.output = task1.description with pytest.raises(AssertionError): - task1.register_completed_unit(unit) + task1.on_unit_completion(unit) unit.set_done() - task1.register_completed_unit(unit) + task1.on_unit_completion(unit) assert task1.done assert task1.output == unit.output assert task1.node.done @@ -61,7 +60,11 @@ def test_get_next_unit(self, tasks, crew): task1, task2 = tasks crew.add_dependency(task1, task2) assert task2.get_next_unit() is None - prompt = compose_simple_task_prompt_with_dependencies(task1.description, task1.get_units()) + prompt = compose_simple_task_prompt_with_dependencies( + description=task1.description, + upstream_task_units=task1.get_units(), + prompt_template_with_upstreams=task1.prompt_template_with_upstreams, + ) expected_unit = SimpleTaskUnit( name=task1.name, prompt=prompt,