Skip to content

Commit

Permalink
diff-converse: Add conversation + interactive mode (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank authored Feb 5, 2024
1 parent 8b2fd37 commit 9626a3d
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 55 deletions.
9 changes: 5 additions & 4 deletions src/cwhy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,14 @@ def main() -> None:
"subcommand",
nargs="?",
default="explain",
choices=["explain", "diff", "converse"],
choices=["explain", "diff", "converse", "diff-converse"],
metavar="subcommand",
help=textwrap.dedent(
"""
explain: explain the diagnostic (default)
diff: \[experimental] generate a diff to fix the diagnostic
converse: \[experimental] interactively converse with CWhy
explain: explain the diagnostic (default)
diff: \[experimental] generate a diff to fix the diagnostic
converse: \[experimental] interactively converse with CWhy
diff-converse: \[experimental] interactively fix errors with CWhy
"""
).strip(),
)
Expand Down
119 changes: 86 additions & 33 deletions src/cwhy/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,16 @@
import argparse
import json
import textwrap

import litellm # type: ignore
import llm_utils

from . import functions
from . import utils
from .diff_functions import DiffFunctions
from .explain_functions import ExplainFunctions


def get_truncated_error_message(args: argparse.Namespace, diagnostic: str) -> str:
    """
    Alternate taking front and back lines until the maximum number of tokens.

    Lines are taken from the start and the end of ``diagnostic`` in turn.
    After each addition the candidate message is measured with
    ``llm_utils.count_tokens(args.llm, ...)``; as soon as it exceeds
    ``args.max_error_tokens`` the last line is dropped again and the
    message is returned with a ``[...]`` marker between the two halves.

    When every line fits, nothing was elided, so the lines are returned
    without the ``[...]`` marker. The result always ends with a newline.
    """
    front: list[str] = []
    back: list[str] = []
    diagnostic_lines = diagnostic.splitlines()
    n = len(diagnostic_lines)

    def build_diagnostic_string() -> str:
        return "\n".join(front) + "\n\n[...]\n\n" + "\n".join(reversed(back)) + "\n"

    truncated = False
    for i in range(n):
        # Even steps consume the next line from the top, odd steps the next
        # line from the bottom.
        if i % 2 == 0:
            line = diagnostic_lines[i // 2]
            target = front
        else:
            line = diagnostic_lines[n - i // 2 - 1]
            target = back
        target.append(line)
        # The budget is checked against the message including the marker, so
        # the marker-free form returned below can only be shorter.
        count = llm_utils.count_tokens(args.llm, build_diagnostic_string())
        if count > args.max_error_tokens:
            target.pop()
            truncated = True
            break

    if not truncated:
        # Everything fit: do not claim that part of the message is missing.
        return "\n".join(diagnostic_lines) + "\n"
    return build_diagnostic_string()


def converse(client, args, diagnostic):
fns = functions.Functions(args, diagnostic)
def converse(args, diagnostic):
fns = ExplainFunctions(args)
available_functions_names = [fn["function"]["name"] for fn in fns.as_tools()]
system_message = textwrap.dedent(
f"""
Expand All @@ -44,14 +20,14 @@ def converse(client, args, diagnostic):
Once you have identified the problem, explain the diagnostic and provide a way to fix the issue if you can.
"""
).strip()
user_message = f"Here is my error message:\n\n```\n{get_truncated_error_message(args, diagnostic)}\n```\n\nWhat's the problem?"
user_message = f"Here is my error message:\n\n```\n{utils.get_truncated_error_message(args, diagnostic)}\n```\n\nWhat's the problem?"
conversation = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]

while True:
completion = client.chat.completions.create(
completion = litellm.completion(
model=args.llm,
messages=conversation,
tools=fns.as_tools(),
Expand All @@ -71,8 +47,85 @@ def converse(client, args, diagnostic):
"content": function_response,
}
)
print()
elif choice.finish_reason == "stop":
text = completion.choices[0].message.content
return llm_utils.word_wrap_except_code_blocks(text)
else:
print(f"Not found: {choice.finish_reason}.")


def diff_converse(args, diagnostic):
    """
    Interactively drive the model toward a fix for ``diagnostic``.

    Each round makes two forced tool calls: the first makes the model choose
    one of the available tools by name (``pick_action``), the second makes it
    produce the arguments for the chosen tool. The tool is then executed via
    ``DiffFunctions.dispatch``; when that yields a non-empty response, the
    model's call and the tool result are appended to the conversation.

    The loop has no exit condition of its own: it keeps going until something
    it calls terminates the process or raises.
    """
    fns = DiffFunctions(args)
    tools = fns.as_tools()
    tool_names = [entry["function"]["name"] for entry in tools]
    joined_names = ", ".join(tool_names)

    system_message = textwrap.dedent(
        f"""
        You are an assistant programmer. The user is having an issue with their code, and you are trying to help them fix the code.
        You may only call the following available functions: {joined_names}.
        Your task is done only when the program can successfully compile and/or run, call as many functions as needed to reach this goal.
        """
    ).strip()
    truncated_error = utils.get_truncated_error_message(args, diagnostic)
    user_message = f"Here is my error message:\n\n```\n{truncated_error}\n```\n\nPlease help me fix it."
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    # Synthetic tool whose only argument is the name of a real tool.
    pick_action_tool = {
        "type": "function",
        "function": {
            "name": "pick_action",
            "description": "Picks an action to take to get more information about or fix the code.",
            "parameters": {
                "type": "object",
                "properties": {
                    "action": {
                        "type": "string",
                        "enum": tool_names,
                    },
                },
                "required": ["action"],
            },
        },
    }

    def forced_call(tool):
        # Offer exactly one tool and require the model to call it.
        return litellm.completion(
            model=args.llm,
            messages=conversation,
            tools=[tool],
            tool_choice={
                "type": "function",
                "function": {"name": tool["function"]["name"]},
            },
        )

    while True:
        # 1. Pick an action.
        picked = forced_call(pick_action_tool).choices[0].message.tool_calls[0].function
        action = json.loads(picked.arguments)["action"]

        # 2. Ask the model for the arguments of the chosen tool and run it.
        matching = [t for t in tools if t["function"]["name"] == action]
        choice = forced_call(matching[0]).choices[0]
        tool_call = choice.message.tool_calls[0]
        function_response = fns.dispatch(tool_call.function)

        # Only record the exchange when the tool produced something to report.
        if function_response:
            conversation.append(choice.message)
            tool_message = {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": tool_call.function.name,
                "content": function_response,
            }
            conversation.append(tool_message)

        print()
138 changes: 138 additions & 0 deletions src/cwhy/conversation/diff_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import argparse
import difflib
import json
import subprocess
import sys
import traceback
from typing import Optional

from . import utils
from .explain_functions import ExplainFunctions


class DiffFunctions:
    """
    Tools that let the model edit source files and re-run the user's command.

    Combines the tools of an ``ExplainFunctions`` instance with two more:
    ``apply_modification`` (edit a file after user confirmation) and
    ``try_compiling`` (re-run ``args.command``).
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.explain_functions = ExplainFunctions(args)

    def as_tools(self):
        """Return the explain tools followed by the two diff tools, each
        wrapped as ``{"type": "function", "function": schema}``."""
        return self.explain_functions.as_tools() + [
            {"type": "function", "function": schema}
            for schema in [
                self.apply_modification_schema(),
                self.try_compiling_schema(),
            ]
        ]

    def dispatch(self, function_call) -> Optional[str]:
        """
        Run the tool named by ``function_call.name``.

        Returns the tool's response, or ``None`` if the call raised an
        ``Exception`` (the traceback is printed). Unknown names are forwarded
        to the wrapped ``ExplainFunctions``. ``SystemExit`` raised by
        ``try_compiling`` is not an ``Exception`` and therefore propagates.
        """
        try:
            if function_call.name == "apply_modification":
                print("Calling: apply_modification(...)")
                # Parsed inside the try block so that malformed arguments are
                # reported like any other tool failure instead of crashing.
                arguments = json.loads(function_call.arguments)
                return self.apply_modification(
                    arguments["filename"],
                    arguments["start-line-number"],
                    arguments["number-lines-remove"],
                    arguments["replacement"],
                )
            elif function_call.name == "try_compiling":
                # Takes no parameters, so the arguments are not parsed at all.
                print("Calling: try_compiling()")
                return self.try_compiling()
            else:
                return self.explain_functions.dispatch(function_call)
        except Exception:
            traceback.print_exc()
            return None

    def apply_modification_schema(self):
        """Return the JSON schema describing ``apply_modification``."""
        return {
            "name": "apply_modification",
            "description": "Applies a single modification to the source file with the goal of fixing any existing compilation errors.",
            "parameters": {
                "type": "object",
                "properties": {
                    "filename": {
                        "type": "string",
                        "description": "The filename to modify.",
                    },
                    "start-line-number": {
                        "type": "integer",
                        "description": "The line number to start replacing at.",
                    },
                    "number-lines-remove": {
                        "type": "integer",
                        "description": "The number of lines to remove, which can be zero to only add new code.",
                    },
                    "replacement": {
                        "type": "string",
                        "description": "The replacement code, which can be blank to simply remove lines.",
                    },
                },
                "required": [
                    "filename",
                    "start-line-number",
                    "number-lines-remove",
                    "replacement",
                ],
            },
        }

    def apply_modification(
        self,
        filename: str,
        start_line_number: int,
        number_lines_remove: int,
        replacement: str,
    ) -> Optional[str]:
        """
        Replace lines of ``filename`` after showing a diff and asking the user.

        ``start_line_number`` is 1-based. ``number_lines_remove`` lines
        starting there are replaced by the lines of ``replacement``.

        Returns a message for the model: confirmation that the change was
        applied, a note that the user declined, or a description of why the
        requested line range is invalid. The file is only written after the
        user answers "y".
        """
        # A start below 1 or a negative count would turn into negative slice
        # indices and silently splice the wrong part of the file.
        if start_line_number < 1 or number_lines_remove < 0:
            return (
                "Invalid modification: start-line-number must be at least 1 "
                "and number-lines-remove must not be negative."
            )

        with open(filename, "r") as f:
            raw_lines = f.readlines()
        # Strip only the line terminator so untouched lines are written back
        # exactly as they were read.
        lines = [line.rstrip("\n") for line in raw_lines]
        keep_final_newline = not raw_lines or raw_lines[-1].endswith("\n")

        # One past the last line is allowed: it appends to the file.
        if start_line_number > len(lines) + 1:
            return f"Invalid modification: {filename} only has {len(lines)} lines."

        start = start_line_number - 1
        end = start + number_lines_remove
        pre_lines = lines[:start]
        replaced_lines = lines[start:end]
        post_lines = lines[end:]
        replacement_lines = replacement.splitlines()

        # If replacing a single line, make sure we keep indentation.
        if (
            number_lines_remove == 1
            and replaced_lines
            and len(replacement_lines) == 1
        ):
            replaced_line = replaced_lines[0]
            indent_width = len(replaced_line) - len(replaced_line.lstrip())
            replacement_lines[0] = (
                replaced_line[:indent_width] + replacement_lines[0].lstrip()
            )

        print("CWhy wants to do the following modification:")
        # The inputs carry no line terminators, so the diff control lines
        # must not carry any either.
        for line in difflib.unified_diff(
            replaced_lines, replacement_lines, lineterm=""
        ):
            print(line)
        answer = input("Is this modification okay? (y/n) ")
        if answer.strip().lower() != "y":
            return "The user declined this modification, it is probably wrong."

        lines = pre_lines + replacement_lines + post_lines
        text = "\n".join(lines)
        if lines and keep_final_newline:
            text += "\n"
        with open(filename, "w") as f:
            f.write(text)
        return "Modification applied."

    def try_compiling_schema(self):
        """Return the JSON schema describing ``try_compiling`` (no parameters)."""
        return {
            "name": "try_compiling",
            "description": "Attempts to compile the code again after the user has made changes. Returns the new error message if there is one.",
        }

    def try_compiling(self) -> Optional[str]:
        """
        Re-run ``self.args.command``.

        Exits the whole process with status 0 when the command succeeds;
        otherwise returns the command's standard error, truncated by
        ``utils.get_truncated_error_message``.
        """
        process = subprocess.run(
            self.args.command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        if process.returncode == 0:
            print("Compilation successful!")
            sys.exit(0)

        return utils.get_truncated_error_message(self.args, process.stderr)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import llm_utils


class Functions:
class ExplainFunctions:
def __init__(self, args: argparse.Namespace):
self.args = args

Expand Down Expand Up @@ -45,7 +45,9 @@ def get_compile_or_run_command_schema(self):
}

def get_compile_or_run_command(self) -> str:
return " ".join(self.args.command)
result = " ".join(self.args.command)
print(result)
return result

def get_code_surrounding_schema(self):
return {
Expand All @@ -69,7 +71,9 @@ def get_code_surrounding_schema(self):

def get_code_surrounding(self, filename: str, lineno: int) -> str:
(lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3)
return llm_utils.number_group_of_lines(lines, first)
result = llm_utils.number_group_of_lines(lines, first)
print(result)
return result

def list_directory_schema(self):
return {
Expand Down
30 changes: 30 additions & 0 deletions src/cwhy/conversation/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import argparse

import llm_utils


def get_truncated_error_message(args: argparse.Namespace, diagnostic: str) -> str:
    """
    Alternate taking front and back lines until the maximum number of tokens.

    Lines are taken from the start and the end of ``diagnostic`` in turn.
    After each addition the candidate message is measured with
    ``llm_utils.count_tokens(args.llm, ...)``; as soon as it exceeds
    ``args.max_error_tokens`` the last line is dropped again and the
    message is returned with a ``[...]`` marker between the two halves.

    When every line fits, nothing was elided, so the lines are returned
    without the ``[...]`` marker. The result always ends with a newline.
    """
    front: list[str] = []
    back: list[str] = []
    diagnostic_lines = diagnostic.splitlines()
    n = len(diagnostic_lines)

    def build_diagnostic_string() -> str:
        return "\n".join(front) + "\n\n[...]\n\n" + "\n".join(reversed(back)) + "\n"

    truncated = False
    for i in range(n):
        # Even steps consume the next line from the top, odd steps the next
        # line from the bottom.
        if i % 2 == 0:
            line = diagnostic_lines[i // 2]
            target = front
        else:
            line = diagnostic_lines[n - i // 2 - 1]
            target = back
        target.append(line)
        # The budget is checked against the message including the marker, so
        # the marker-free form returned below can only be shorter.
        count = llm_utils.count_tokens(args.llm, build_diagnostic_string())
        if count > args.max_error_tokens:
            target.pop()
            truncated = True
            break

    if not truncated:
        # Everything fit: do not claim that part of the message is missing.
        return "\n".join(diagnostic_lines) + "\n"
    return build_diagnostic_string()
Loading

0 comments on commit 9626a3d

Please sign in to comment.