Skip to content

Commit

Permalink
Use the PyDoc trick instead of separate schema functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank committed May 22, 2024
1 parent bd67121 commit 2dc1068
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
36 changes: 18 additions & 18 deletions src/cwhy/conversation/diff_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def __init__(self, args: argparse.Namespace):

def as_tools(self):
return self.explain_functions.as_tools() + [
{"type": "function", "function": schema}
for schema in [
self.apply_modification_schema(),
self.try_compiling_schema(),
{"type": "function", "function": json.loads(f.__doc__)}
for f in [
self.apply_modification,
self.try_compiling,
]
]

Expand All @@ -45,8 +45,15 @@ def dispatch(self, function_call) -> Optional[str]:
traceback.print_exc()
return None

def apply_modification_schema(self):
return {
def apply_modification(
self,
filename: str,
start_line_number: int,
number_lines_remove: int,
replacement: str,
) -> Optional[str]:
"""
{
"name": "apply_modification",
"description": "Applies a single modification to the source file with the goal of fixing any existing compilation errors.",
"parameters": {
Expand Down Expand Up @@ -77,14 +84,7 @@ def apply_modification_schema(self):
],
},
}

def apply_modification(
self,
filename: str,
start_line_number: int,
number_lines_remove: int,
replacement: str,
) -> Optional[str]:
"""
with open(filename, "r") as f:
lines = [line.rstrip() for line in f.readlines()]

Expand Down Expand Up @@ -118,13 +118,13 @@ def apply_modification(
f.write("\n".join(lines))
return "Modification applied."

def try_compiling_schema(self):
return {
def try_compiling(self) -> Optional[str]:
"""
{
"name": "try_compiling",
"description": "Attempts to compile the code again after the user has made changes. Returns the new error message if there is one.",
}

def try_compiling(self) -> Optional[str]:
"""
process = subprocess.run(
self.args.command,
stdout=subprocess.PIPE,
Expand Down
34 changes: 17 additions & 17 deletions src/cwhy/conversation/explain_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ def __init__(self, args: argparse.Namespace):

def as_tools(self):
return [
{"type": "function", "function": schema}
for schema in [
self.get_compile_or_run_command_schema(),
self.get_code_surrounding_schema(),
self.list_directory_schema(),
{"type": "function", "function": json.loads(f.__doc__)}
for f in [
self.get_compile_or_run_command,
self.get_code_surrounding,
self.list_directory,
]
]

Expand All @@ -40,19 +40,20 @@ def dispatch(self, function_call) -> Optional[str]:
dprint(e)
return None

def get_compile_or_run_command_schema(self):
return {
def get_compile_or_run_command(self) -> str:
"""
{
"name": "get_compile_or_run_command",
"description": "Returns the command used to compile or run the code. This will include any flags and options used.",
}

def get_compile_or_run_command(self) -> str:
"""
result = " ".join(self.args.command)
dprint(result)
return result

def get_code_surrounding_schema(self):
return {
def get_code_surrounding(self, filename: str, lineno: int) -> str:
"""
{
"name": "get_code_surrounding",
"description": "Returns the code in the given file surrounding and including the provided line number.",
"parameters": {
Expand All @@ -70,15 +71,15 @@ def get_code_surrounding_schema(self):
"required": ["filename", "lineno"],
},
}

def get_code_surrounding(self, filename: str, lineno: int) -> str:
"""
(lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3)
result = llm_utils.number_group_of_lines(lines, first)
dprint(result)
return result

def list_directory_schema(self):
return {
def list_directory(self, path: str) -> str:
"""
{
"name": "list_directory",
"description": "Returns a list of all files and directories in the given directory.",
"parameters": {
Expand All @@ -92,8 +93,7 @@ def list_directory_schema(self):
"required": ["path"],
},
}

def list_directory(self, path: str) -> str:
"""
entries = os.listdir(path)
for i in range(len(entries)):
if os.path.isdir(os.path.join(path, entries[i])):
Expand Down

0 comments on commit 2dc1068

Please sign in to comment.