Commit

Merge pull request #23 from ShoggothAI/autogen-integration
AutoGen integration
ZmeiGorynych authored May 21, 2024
2 parents e9da22a + 7e803d4 commit 36483db
Showing 11 changed files with 1,243 additions and 36 deletions.
756 changes: 756 additions & 0 deletions examples/Using AutoGen chats with motleycrew.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions examples/math_crewai.py
@@ -1,14 +1,14 @@
 from dotenv import load_dotenv

-from motleycrew import MotleyCrew, Task
+from motleycrew import MotleyCrew
 from motleycrew.agents.crewai import CrewAIMotleyAgent
-from motleycrew.tools.python_repl import create_repl_tool
+from motleycrew.tools import PythonREPLTool
 from motleycrew.common.utils import configure_logging


 def main():
     """Main function of running the example."""
-    repl_tool = create_repl_tool()
+    repl_tool = PythonREPLTool()

     # Define your agents with roles and goals
     solver1 = CrewAIMotleyAgent(
10 changes: 5 additions & 5 deletions motleycrew/common/defaults.py
@@ -10,8 +10,8 @@ class Defaults:

     DEFAULT_GRAPH_STORE_TYPE = GraphStoreType.KUZU


-defaults_module_install_commands = {
-    "crewai": "pip install crewai",
-    "llama_index": "pip install llama-index"
-}
+    MODULE_INSTALL_COMMANDS = {
+        "crewai": "pip install crewai",
+        "llama_index": "pip install llama-index",
+        "autogen": "pip install autogen",
+    }
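The install-hint mapping moves from the module-level defaults_module_install_commands dict onto the Defaults class and gains an "autogen" entry, so callers now read it as a class attribute. A quick illustration of the new lookup, using the same import that exceptions.py switches to below:

from motleycrew.common import Defaults

# After this commit, install hints are class attributes of Defaults.
print(Defaults.MODULE_INSTALL_COMMANDS.get("autogen"))  # -> "pip install autogen"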
4 changes: 2 additions & 2 deletions motleycrew/common/exceptions.py
@@ -1,4 +1,4 @@
-from motleycrew.common.defaults import defaults_module_install_commands
+from motleycrew.common import Defaults


 class LLMFamilyNotSupported(Exception):
@@ -55,7 +55,7 @@ class ModuleNotInstalledException(Exception):

     def __init__(self, module_name: str, install_command: str = None):
         self.module_name = module_name
-        self.install_command = install_command or defaults_module_install_commands.get(
+        self.install_command = install_command or Defaults.MODULE_INSTALL_COMMANDS.get(
            module_name, None
        )
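The updated exception is consumed by the ensure_module_is_installed helper that the new autogen_chat_tool.py imports from motleycrew.common.utils. That helper's body is not part of this diff; the following is only a minimal sketch of how it presumably checks for the optional dependency and raises the exception, which in turn resolves the hint via Defaults.MODULE_INSTALL_COMMANDS:

import importlib.util

from motleycrew.common.exceptions import ModuleNotInstalledException


def ensure_module_is_installed(module_name: str, install_command: str = None) -> None:
    # Sketch: raise if the optional dependency cannot be found; the exception
    # falls back to Defaults.MODULE_INSTALL_COMMANDS for the install command.
    if importlib.util.find_spec(module_name) is None:
        raise ModuleNotInstalledException(module_name, install_command)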
6 changes: 5 additions & 1 deletion motleycrew/tools/__init__.py
@@ -1,3 +1,7 @@
 from .tool import MotleyTool
-from .llm_tool import LLMTool

+from .autogen_chat_tool import AutoGenChatTool
+from .image_generation import DallEImageGeneratorTool
+from .llm_tool import LLMTool
+from .mermaid_evaluator_tool import MermaidEvaluatorTool
+from .python_repl import PythonREPLTool
79 changes: 79 additions & 0 deletions motleycrew/tools/autogen_chat_tool.py
@@ -0,0 +1,79 @@
from typing import Optional, Type, Callable, Any

from langchain_core.tools import StructuredTool
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, create_model

try:
    from autogen import ConversableAgent, ChatResult
except ImportError:
    ConversableAgent = None
    ChatResult = None

from motleycrew.tools import MotleyTool
from motleycrew.common.utils import ensure_module_is_installed


def get_last_message(chat_result: ChatResult) -> str:
    for message in reversed(chat_result.chat_history):
        if message.get("content") and "TERMINATE" not in message["content"]:
            return message["content"]


class AutoGenChatTool(MotleyTool):
    def __init__(
        self,
        name: str,
        description: str,
        prompt: str | BasePromptTemplate,
        initiator: ConversableAgent,
        recipient: ConversableAgent,
        result_extractor: Callable[[ChatResult], Any] = get_last_message,
        input_schema: Optional[Type[BaseModel]] = None,
    ):
        ensure_module_is_installed("autogen")
        langchain_tool = create_autogen_chat_tool(
            name=name,
            description=description,
            prompt=prompt,
            initiator=initiator,
            recipient=recipient,
            result_extractor=result_extractor,
            input_schema=input_schema,
        )
        super().__init__(langchain_tool)


def create_autogen_chat_tool(
    name: str,
    description: str,
    prompt: str | BasePromptTemplate,
    initiator: ConversableAgent,
    recipient: ConversableAgent,
    result_extractor: Callable[[ChatResult], Any],
    input_schema: Optional[Type[BaseModel]] = None,
):
    if not isinstance(prompt, BasePromptTemplate):
        prompt = PromptTemplate.from_template(prompt)

    if input_schema is None:
        fields = {
            var: (str, Field(description=f"Input {var} for the tool."))
            for var in prompt.input_variables
        }

        # Create the AutoGenChatToolInput class dynamically
        input_schema = create_model("AutoGenChatToolInput", **fields)

    def run_autogen_chat(**kwargs) -> Any:
        message = prompt.format(**kwargs)
        chat_result = initiator.initiate_chat(recipient, message=message)
        return result_extractor(chat_result)

    return StructuredTool.from_function(
        func=run_autogen_chat,
        name=name,
        description=description,
        args_schema=input_schema,
    )
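The new AutoGenChatTool wraps a whole AutoGen two-agent chat as a single motleycrew tool: the prompt's template variables become the tool's input schema, the chat is kicked off with initiator.initiate_chat, and get_last_message pulls the final non-TERMINATE message as the result. A hypothetical usage sketch follows; the agent names, prompts, and model config are illustrative only, and the full worked example is in examples/Using AutoGen chats with motleycrew.ipynb:

import os

from autogen import ConversableAgent

from motleycrew.tools import AutoGenChatTool

llm_config = {"config_list": [{"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY"]}]}

# Two illustrative AutoGen agents: one answers, one only initiates the chat.
writer = ConversableAgent(
    "writer",
    system_message="You answer questions concisely and factually.",
    llm_config=llm_config,
)
user_proxy = ConversableAgent(
    "user_proxy",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=0,  # end the chat after the writer's first reply
    llm_config=False,
    code_execution_config=False,
)

# {question} in the prompt becomes a field of the auto-generated input schema.
ask_writer = AutoGenChatTool(
    name="ask_writer",
    description="Ask the writer agent a question and return its final answer.",
    prompt="Answer the following question as concisely as you can: {question}",
    initiator=user_proxy,
    recipient=writer,
)

print(ask_writer.invoke({"question": "What is the capital of France?"}))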
20 changes: 12 additions & 8 deletions motleycrew/tools/python_repl.py
@@ -5,6 +5,12 @@
 from .tool import MotleyTool


+class PythonREPLTool(MotleyTool):
+    def __init__(self):
+        langchain_tool = create_repl_tool()
+        super().__init__(langchain_tool)
+
+
 class REPLToolInput(BaseModel):
     """Input for the REPL tool."""

@@ -13,12 +19,10 @@ class REPLToolInput(BaseModel):

 # You can create the tool to pass to an agent
 def create_repl_tool():
-    return MotleyTool.from_langchain_tool(
-        Tool.from_function(
-            func=PythonREPL().run,
-            name="python_repl",
-            description="A Python shell. Use this to execute python commands. Input should be a valid python command. "
-            "MAKE SURE TO PRINT OUT THE RESULTS YOU CARE ABOUT USING `print(...)`.",
-            args_schema=REPLToolInput,
-        )
+    return Tool.from_function(
+        func=PythonREPL().run,
+        name="python_repl",
+        description="A Python shell. Use this to execute python commands. Input should be a valid python command. "
+        "MAKE SURE TO PRINT OUT THE RESULTS YOU CARE ABOUT USING `print(...)`.",
+        args_schema=REPLToolInput,
     )
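PythonREPLTool now wraps the LangChain-backed REPL tool as a MotleyTool subclass, which is what examples/math_crewai.py switches to above. A small sanity-check sketch; the REPLToolInput fields are collapsed in this diff, so the "command" field name below is an assumption:

from motleycrew.tools import PythonREPLTool

repl_tool = PythonREPLTool()
# Assumes the (collapsed) REPLToolInput schema exposes a single "command" field.
print(repl_tool.invoke({"command": "print(21 * 2)"}))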
15 changes: 14 additions & 1 deletion motleycrew/tools/tool.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Annotated

 from langchain.tools import BaseTool
 from langchain_core.runnables import Runnable
@@ -76,3 +76,16 @@ def to_llama_index_tool(self) -> LlamaIndex__BaseTool:
             fn_schema=self.tool.args_schema,
         )
         return llama_index_tool
+
+    def to_autogen_tool(self):
+        fields = list(self.tool.args_schema.__fields__.values())
+        if len(fields) != 1:
+            raise Exception("Multiple input fields are not supported in to_autogen_tool")
+
+        field_name = fields[0].name
+        field_type = fields[0].annotation
+
+        def autogen_tool_fn(input: field_type) -> str:
+            return self.invoke({field_name: input})
+
+        return autogen_tool_fn
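to_autogen_tool turns a single-input MotleyTool into a plain annotated Python function, which is the shape AutoGen's function-calling registration expects. A hedged sketch of wiring it up, reusing the illustrative writer/user_proxy agents from the AutoGenChatTool example above and assuming AutoGen's register_function helper:

from autogen import register_function

from motleycrew.tools import PythonREPLTool

repl_fn = PythonREPLTool().to_autogen_tool()

# caller proposes the tool call, executor runs it (standard AutoGen pattern).
register_function(
    repl_fn,
    caller=writer,
    executor=user_proxy,
    name="python_repl",
    description="Execute Python code; print anything you want returned.",
)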