Skip to content

Commit

Permalink
add PythonLinterTool
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed Jun 18, 2024
1 parent 7d6921a commit 1f846be
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 4 deletions.
2 changes: 1 addition & 1 deletion motleycrew/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .llm_tool import LLMTool
from .mermaid_evaluator_tool import MermaidEvaluatorTool
from .python_repl import PythonREPLTool
from .linter_tools import PgSqlLinterTool
from .linter_tools import PgSqlLinterTool, PythonLinterTool
66 changes: 64 additions & 2 deletions motleycrew/tools/linter_tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import Callable
import os
from typing import Callable, Union

from pglast import parse_sql, prettify
from pglast.parser import ParseError
from langchain.tools import Tool
from langchain.tools import Tool, StructuredTool
from langchain_core.pydantic_v1 import BaseModel, Field

try:
from aider.linter import Linter
except ImportError:
Linter = None

from motleycrew.tools import MotleyTool
from motleycrew.common.utils import ensure_module_is_installed


class PgSqlLinterTool(MotleyTool):
Expand Down Expand Up @@ -45,3 +52,58 @@ def create_pgsql_linter_tool(parse_func: Callable) -> Tool:
description="Tool for checking the health of the sql code of the postgresql database",
args_schema=PgSqlLinterInput,
)


class PythonLinterTool(MotleyTool):

def __init__(self):
"""Python code verification tool
"""
ensure_module_is_installed("aider", "pip install aider-chat")

def lint(code: str, file_name: str = None) -> Union[str, None]:
# create temp python file
temp_file_name = file_name or "code.py"
_, file_ext = os.path.splitext(temp_file_name)
if file_ext != ".py":
raise ValueError("The file extension must be py")

with open(temp_file_name, 'w') as f:
f.write(code)

# lint code
try:
linter = Linter()
return linter.lint(temp_file_name)
except Exception as e:
return str(e)
finally:
os.remove(temp_file_name)

langchain_tool = create_python_linter_tool(lint)
super().__init__(langchain_tool)


class PythonLinterInput(BaseModel):
"""Input for the PgSqlLinterTool.
Attributes:
code (str): python code
file_name (str): name temp python file
"""

code: str = Field(description="python code for verification")
file_name: str = Field(description="file name python code", default="code.py")

def create_python_linter_tool(lint_func: Callable) -> StructuredTool:
"""Create langchain tool from lint_func for PythonLinterTool
Returns:
Tool:
"""
return StructuredTool.from_function(
func=lint_func,
name="python linter tool",
description="Tool for checking the health of python code",
args_schema=PythonLinterInput,
)
38 changes: 37 additions & 1 deletion tests/test_tools/test_linter_tool.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import pytest

from motleycrew.tools import PgSqlLinterTool
from motleycrew.tools import PgSqlLinterTool, PythonLinterTool
from motleycrew.common.exceptions import ModuleNotInstalledException

@pytest.fixture
def pgsql_linter_tool():
tool = PgSqlLinterTool()
return tool

@pytest.fixture
def python_linter_tool():
try:
tool = PythonLinterTool()
except ModuleNotInstalledException:
tool = None
return tool

@pytest.mark.parametrize(
"query, expected",
[
Expand All @@ -18,3 +27,30 @@ def pgsql_linter_tool():
def test_pgsql_tool(pgsql_linter_tool, query, expected):
parse_result = pgsql_linter_tool.invoke({"query": query})
assert expected == parse_result

@pytest.mark.parametrize(
"code, file_name, valid_code, raises",
[
("def plus(a, b):\n\treturn a + b", None, True, False),
("def plus(a):\n\treturn a + b", "test_code.py", False, False),
("def plus(a, b):\nreturn a + b", "test_code.py", False, False),
("def plus(a, b):\n\treturn a + b", "code.js", True, True),
]
)
def test_python_tool(python_linter_tool, code, file_name, valid_code, raises):
if python_linter_tool is None:
return

params = {"code": code}
if file_name:
params["file_name"] = file_name

if raises:
with pytest.raises(ValueError):
python_linter_tool.invoke(params)
else:
linter_result = python_linter_tool.invoke(params)
if valid_code:
assert linter_result is None
else:
assert isinstance(linter_result, str)

0 comments on commit 1f846be

Please sign in to comment.