Skip to content

Commit

Permalink
Add debug printing support for better debugging and tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank committed Feb 6, 2024
1 parent 5455c0b commit f7c4b0b
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 47 deletions.
5 changes: 5 additions & 0 deletions src/cwhy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ def main() -> None:
action="store_true",
help="when enabled, only print prompt and exit (for debugging purposes)",
)
parser.add_argument(
"--debug",
action="store_true",
help=argparse.SUPPRESS,
)
parser.add_argument(
"--wrapper",
action="store_true",
Expand Down
7 changes: 4 additions & 3 deletions src/cwhy/conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import utils
from .diff_functions import DiffFunctions
from .explain_functions import ExplainFunctions
from ..print_debug import dprint


def converse(args, diagnostic):
Expand Down Expand Up @@ -47,12 +48,12 @@ def converse(args, diagnostic):
"content": function_response,
}
)
print()
dprint()
elif choice.finish_reason == "stop":
text = completion.choices[0].message.content
return llm_utils.word_wrap_except_code_blocks(text)
else:
print(f"Not found: {choice.finish_reason}.")
dprint(f"Not found: {choice.finish_reason}.")


def diff_converse(args, diagnostic):
Expand Down Expand Up @@ -128,4 +129,4 @@ def diff_converse(args, diagnostic):
}
)

print()
dprint()
11 changes: 6 additions & 5 deletions src/cwhy/conversation/diff_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from . import utils
from .explain_functions import ExplainFunctions
from ..print_debug import dprint


class DiffFunctions:
Expand All @@ -28,15 +29,15 @@ def dispatch(self, function_call) -> Optional[str]:
arguments = json.loads(function_call.arguments)
try:
if function_call.name == "apply_modification":
print("Calling: apply_modification(...)")
dprint("Calling: apply_modification(...)")
return self.apply_modification(
arguments["filename"],
arguments["start-line-number"],
arguments["number-lines-remove"],
arguments["replacement"],
)
elif function_call.name == "try_compiling":
print("Calling: try_compiling()")
dprint("Calling: try_compiling()")
return self.try_compiling()
else:
return self.explain_functions.dispatch(function_call)
Expand Down Expand Up @@ -106,9 +107,9 @@ def apply_modification(
whitespace = replaced_line[:n]
replacement_lines[0] = whitespace + replacement_lines[0]

print("CWhy wants to do the following modification:")
dprint("CWhy wants to do the following modification:")
for line in difflib.unified_diff(replaced_lines, replacement_lines):
print(line)
dprint(line)
if not input("Is this modification okay? (y/n) ") == "y":
return "The user declined this modification, it is probably wrong."

Expand All @@ -132,7 +133,7 @@ def try_compiling(self) -> Optional[str]:
)

if process.returncode == 0:
print("Compilation successful!")
dprint("Compilation successful!")
sys.exit(0)

return utils.get_truncated_error_message(self.args, process.stderr)
10 changes: 6 additions & 4 deletions src/cwhy/conversation/explain_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import llm_utils

from ..print_debug import dprint


class ExplainFunctions:
def __init__(self, args: argparse.Namespace):
Expand All @@ -22,7 +24,7 @@ def as_tools(self):

def dispatch(self, function_call) -> Optional[str]:
arguments = json.loads(function_call.arguments)
print(
dprint(
f"Calling: {function_call.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})"
)
try:
Expand All @@ -35,7 +37,7 @@ def dispatch(self, function_call) -> Optional[str]:
elif function_call.name == "list_directory":
return self.list_directory(arguments["path"])
except Exception as e:
print(e)
dprint(e)
return None

def get_compile_or_run_command_schema(self):
Expand All @@ -46,7 +48,7 @@ def get_compile_or_run_command_schema(self):

def get_compile_or_run_command(self) -> str:
result = " ".join(self.args.command)
print(result)
dprint(result)
return result

def get_code_surrounding_schema(self):
Expand All @@ -72,7 +74,7 @@ def get_code_surrounding_schema(self):
def get_code_surrounding(self, filename: str, lineno: int) -> str:
(lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3)
result = llm_utils.number_group_of_lines(lines, first)
print(result)
dprint(result)
return result

def list_directory_schema(self):
Expand Down
74 changes: 40 additions & 34 deletions src/cwhy/cwhy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

from . import conversation, prompts
from .print_debug import dprint, enable_debug_printing


# Turn off most logging
Expand All @@ -23,21 +24,23 @@


def print_key_info():
print("You need a key (or keys) from an AI service to use CWhy.")
print()
print("OpenAI:")
print(" You can get a key here: https://platform.openai.com/api-keys")
print(" Set the environment variable OPENAI_API_KEY to your key value:")
print(" export OPENAI_API_KEY=<your key>")
print()
print("Bedrock:")
print(" To use Bedrock, you need an AWS account.")
print(" Set the following environment variables:")
print(" export AWS_ACCESS_KEY_ID=<your key id>")
print(" export AWS_SECRET_ACCESS_KEY=<your secret key>")
print(" export AWS_REGION_NAME=us-west-2")
print(" You also need to request access to Claude:")
print(" https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access")
dprint("You need a key (or keys) from an AI service to use CWhy.")
dprint()
dprint("OpenAI:")
dprint(" You can get a key here: https://platform.openai.com/api-keys")
dprint(" Set the environment variable OPENAI_API_KEY to your key value:")
dprint(" export OPENAI_API_KEY=<your key>")
dprint()
dprint("Bedrock:")
dprint(" To use Bedrock, you need an AWS account.")
dprint(" Set the following environment variables:")
dprint(" export AWS_ACCESS_KEY_ID=<your key id>")
dprint(" export AWS_SECRET_ACCESS_KEY=<your secret key>")
dprint(" export AWS_REGION_NAME=us-west-2")
dprint(" You also need to request access to Claude:")
dprint(
" https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access"
)


_DEFAULT_FALLBACK_MODELS = []
Expand All @@ -64,17 +67,17 @@ def complete(args, user_prompt, **kwargs):
)
return completion
except NotFoundError as e:
print(f"'{args.llm}' either does not exist or you do not have access to it.")
dprint(f"'{args.llm}' either does not exist or you do not have access to it.")
raise e
except BadRequestError as e:
print("Something is wrong with your prompt.")
dprint("Something is wrong with your prompt.")
raise e
except RateLimitError as e:
print("You have exceeded a rate limit or have no remaining funds.")
dprint("You have exceeded a rate limit or have no remaining funds.")
raise e
except APITimeoutError as e:
print("The API timed out.")
print("You can increase the timeout with the --timeout option.")
dprint("The API timed out.")
dprint("You can increase the timeout with the --timeout option.")
raise e


Expand Down Expand Up @@ -128,7 +131,7 @@ def evaluate_diff(args, stdin):
def evaluate_with_fallback(args, stdin):
for i, model in enumerate(_DEFAULT_FALLBACK_MODELS):
if i != 0:
print(f"Falling back to {model}...")
dprint(f"Falling back to {model}...")
args.llm = model
try:
return evaluate(args, stdin)
Expand Down Expand Up @@ -166,29 +169,32 @@ def main(args: argparse.Namespace) -> None:
if process.returncode == 0:
return

if args.debug:
enable_debug_printing()

if args.show_prompt:
print("===================== Prompt =====================")
dprint("===================== Prompt =====================")
if args.llm == "default":
args.llm = _DEFAULT_FALLBACK_MODELS[0]
if args.subcommand == "explain":
print(prompts.explain_prompt(args, process.stderr))
dprint(prompts.explain_prompt(args, process.stderr))
elif args.subcommand == "diff":
print(prompts.diff_prompt(args, process.stderr))
print("==================================================")
dprint(prompts.diff_prompt(args, process.stderr))
dprint("==================================================")
sys.exit(0)

print(process.stdout, end="")
print(process.stderr, file=sys.stderr, end="")
print("==================================================")
print("CWhy")
print("==================================================")
dprint(process.stdout)
dprint(process.stderr, file=sys.stderr)
dprint("==================================================")
dprint("CWhy")
dprint("==================================================")
try:
result = evaluate(args, process.stderr if process.stderr else process.stdout)
print(result)
dprint(result)
except OpenAIError:
print_key_info()
sys.exit(1)
print("==================================================")
dprint("==================================================")

sys.exit(process.returncode)

Expand All @@ -197,8 +203,8 @@ def evaluate_text_prompt(args, prompt, wrap=True, **kwargs):
completion = complete(args, prompt, **kwargs)

msg = f"Analysis from {args.llm}:"
print(msg)
print("-" * len(msg))
dprint(msg)
dprint("-" * len(msg))
text = completion.choices[0].message.content

if wrap:
Expand Down
30 changes: 30 additions & 0 deletions src/cwhy/print_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import inspect

_debug = False

_INFO_WIDTH = 30


def dprint(*objects, sep=" ", end="\n", file=None, flush=False):
if not _debug:
return print(*objects, sep=sep, end=end, file=file, flush=flush)

frame = inspect.currentframe()
caller = inspect.getouterframes(frame, 2)
filename = caller[1].filename
lineno = caller[1].lineno

info = f"{filename}:{lineno}"
if len(info) > _INFO_WIDTH:
info = f"...{info[-_INFO_WIDTH + 3:]}"
info = info.ljust(_INFO_WIDTH - 2)
info = f"{info} |"

message = sep.join(map(str, objects)) + end
for line in message.splitlines():
print(info, line, file=file, flush=flush)


def enable_debug_printing():
global _debug
_debug = True
4 changes: 3 additions & 1 deletion src/cwhy/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import llm_utils

from .print_debug import dprint


# Define error patterns with associated information. The numbers
# correspond to the groups matching file name and line number.
Expand Down Expand Up @@ -72,7 +74,7 @@ def __init__(self, args: argparse.Namespace, diagnostic: str):
file_name, line_number - 7, line_number + 3
)
except FileNotFoundError:
print(
dprint(
f"Cwhy warning: file not found: {file_name.lstrip()}",
file=sys.stderr,
)
Expand Down

0 comments on commit f7c4b0b

Please sign in to comment.