Skip to content

Commit

Permalink
simplify tool convert
Browse files Browse the repository at this point in the history
  • Loading branch information
garylin2099 committed Mar 11, 2024
1 parent 0116de0 commit bf4b13e
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 113 deletions.
2 changes: 1 addition & 1 deletion metagpt/prompts/mi/write_analysis_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
- Always prioritize using pre-defined tools for the same functionality.
# Output
Output code in the following format:
While some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:
```python
your code
```
Expand Down
9 changes: 6 additions & 3 deletions metagpt/tools/libs/data_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from typing import Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -90,14 +91,16 @@ class FillMissingValue(DataPreprocessTool):
Completing missing values with simple strategies.
"""

def __init__(self, features: list, strategy: str = "mean", fill_value=None):
def __init__(
self, features: list, strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean", fill_value=None
):
"""
Initialize self.
Args:
features (list): Columns to be processed.
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice 'mean' and 'median' can only
be used for numeric features. Defaults to 'mean'.
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
Defaults to None.
"""
Expand Down
71 changes: 25 additions & 46 deletions metagpt/tools/tool_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces

PARSER = GoogleDocstringParser


def convert_code_to_tool_schema(obj, include: list[str] = None):
docstring = inspect.getdoc(obj)
Expand All @@ -23,54 +25,31 @@ def convert_code_to_tool_schema(obj, include: list[str] = None):
return schema


def function_docstring_to_schema(fn_obj, docstring):
def function_docstring_to_schema(fn_obj, docstring) -> dict:
"""
Converts a function's docstring into a schema dictionary.
Args:
fn_obj: The function object.
docstring: The docstring of the function.
Returns:
A dictionary representing the schema of the function's docstring.
The dictionary contains the following keys:
- 'type': The type of the function ('function' or 'async_function').
- 'description': The first section of the docstring describing the function overall. Provided to LLMs for both recommending and using the function.
- 'signature': The signature of the function, which helps LLMs understand how to call the function.
- 'parameters': Docstring section describing parameters including args and returns, served as extra details for LLM perception.
"""
signature = inspect.signature(fn_obj)

docstring = remove_spaces(docstring)

overall_desc, param_desc = PARSER.parse(docstring)

function_type = "function" if not inspect.iscoroutinefunction(fn_obj) else "async_function"
return {"type": function_type, **docstring_to_schema(docstring)}


def docstring_to_schema(docstring: str):
if docstring is None:
return {}

parser = GoogleDocstringParser(docstring=docstring)

# 匹配简介部分
description = parser.parse_desc()

# 匹配Args部分
params = parser.parse_params()
parameter_schema = {"properties": {}, "required": []}
for param in params:
param_name, param_type, param_desc = param
# check required or optional
is_optional, param_type = parser.check_and_parse_optional(param_type)
if not is_optional:
parameter_schema["required"].append(param_name)
# type and desc
param_dict = {"type": param_type, "description": remove_spaces(param_desc)}
# match Default for optional args
has_default_val, default_val = parser.check_and_parse_default_value(param_desc)
if has_default_val:
param_dict["default"] = default_val
# match Enum
has_enum, enum_vals = parser.check_and_parse_enum(param_desc)
if has_enum:
param_dict["enum"] = enum_vals
# add to parameter schema
parameter_schema["properties"].update({param_name: param_dict})

# 匹配Returns部分
returns = parser.parse_returns()

# 构建YAML字典
schema = {
"description": description,
"parameters": parameter_schema,
}
if returns:
schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns]

return schema
return {"type": function_type, "description": overall_desc, "signature": str(signature), "parameters": param_desc}


def get_class_method_docstring(cls, method_name):
Expand Down
82 changes: 19 additions & 63 deletions metagpt/utils/parse_docstring.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,23 @@
import re
from typing import Tuple

from pydantic import BaseModel


def remove_spaces(text):
return re.sub(r"\s+", " ", text).strip()


class DocstringParser(BaseModel):
docstring: str

def parse_desc(self) -> str:
"""Parse and return the description from the docstring."""

def parse_params(self) -> list[Tuple[str, str, str]]:
"""Parse and return the parameters from the docstring.
Returns:
list[Tuple[str, str, str]]: A list of input paramter info. Each info is a triple of (param name, param type, param description)
"""
class DocstringParser:
@staticmethod
def parse(docstring: str) -> Tuple[str, str]:
"""Parse the docstring and return the overall description and the parameter description.
def parse_returns(self) -> list[Tuple[str, str]]:
"""Parse and return the output information from the docstring.
Args:
docstring (str): The docstring to be parsed.
Returns:
list[Tuple[str, str]]: A list of output info. Each info is a tuple of (return type, return description)
Tuple[str, str]: A tuple of (overall description, parameter description)
"""

@staticmethod
def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
"""Check if a parameter is optional and return a processed param_type rid of the optionality info if so"""

@staticmethod
def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
"""Check if a parameter has a default value and return the default value if so"""

@staticmethod
def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
"""Check if a parameter description includes an enum and return enum values if so"""


class reSTDocstringParser(DocstringParser):
"""A parser for reStructuredText (reST) docstring"""
Expand All @@ -48,40 +26,18 @@ class reSTDocstringParser(DocstringParser):
class GoogleDocstringParser(DocstringParser):
"""A parser for Google-stype docstring"""

docstring: str

def parse_desc(self) -> str:
description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", self.docstring, re.DOTALL)
description = remove_spaces(description_match.group(1)) if description_match else ""
return description

def parse_params(self) -> list[Tuple[str, str, str]]:
args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", self.docstring, re.DOTALL)
_args = args_match.group(1).strip() if args_match else ""
# variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)")
variable_pattern = re.compile(
r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL
) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter (indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z).
params = variable_pattern.findall(_args)
return params

def parse_returns(self) -> list[Tuple[str, str]]:
returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", self.docstring, re.DOTALL)
returns = returns_match.group(1).strip() if returns_match else ""
return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$")
returns = return_pattern.findall(returns)
return returns

@staticmethod
def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
return "optional" in param_type, param_type.replace(", optional", "")
def parse(docstring: str) -> Tuple[str, str]:
if not docstring:
return "", ""

@staticmethod
def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
default_val = re.search(r"Defaults to (.+?)\.", param_desc)
return (True, default_val.group(1)) if default_val else (False, "")
docstring = remove_spaces(docstring)

@staticmethod
def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
enum_val = re.search(r"Enum: \[(.+?)\]", param_desc)
return (True, [e.strip() for e in enum_val.group(1).split(",")]) if enum_val else (False, [])
if "Args:" in docstring:
overall_desc, param_desc = docstring.split("Args:")
param_desc = "Args:" + param_desc
else:
overall_desc = docstring
param_desc = ""

return overall_desc, param_desc

0 comments on commit bf4b13e

Please sign in to comment.