diff --git a/src/cwhy/__main__.py b/src/cwhy/__main__.py index 11aaed6..04933cd 100755 --- a/src/cwhy/__main__.py +++ b/src/cwhy/__main__.py @@ -163,6 +163,11 @@ def main() -> None: action="store_true", help="when enabled, only print prompt and exit (for debugging purposes)", ) + parser.add_argument( + "--debug", + action="store_true", + help=argparse.SUPPRESS, + ) parser.add_argument( "--wrapper", action="store_true", diff --git a/src/cwhy/conversation/__init__.py b/src/cwhy/conversation/__init__.py index 636cfd5..67123aa 100644 --- a/src/cwhy/conversation/__init__.py +++ b/src/cwhy/conversation/__init__.py @@ -7,6 +7,7 @@ from . import utils from .diff_functions import DiffFunctions from .explain_functions import ExplainFunctions +from ..print_debug import dprint def converse(args, diagnostic): @@ -47,12 +48,12 @@ def converse(args, diagnostic): "content": function_response, } ) - print() + dprint() elif choice.finish_reason == "stop": text = completion.choices[0].message.content return llm_utils.word_wrap_except_code_blocks(text) else: - print(f"Not found: {choice.finish_reason}.") + dprint(f"Not found: {choice.finish_reason}.") def diff_converse(args, diagnostic): @@ -128,4 +129,4 @@ def diff_converse(args, diagnostic): } ) - print() + dprint() diff --git a/src/cwhy/conversation/diff_functions.py b/src/cwhy/conversation/diff_functions.py index 8bcf702..3ed77d8 100644 --- a/src/cwhy/conversation/diff_functions.py +++ b/src/cwhy/conversation/diff_functions.py @@ -8,6 +8,7 @@ from . import utils from .explain_functions import ExplainFunctions +from ..print_debug import dprint class DiffFunctions: @@ -28,7 +29,7 @@ def dispatch(self, function_call) -> Optional[str]: arguments = json.loads(function_call.arguments) try: if function_call.name == "apply_modification": - print("Calling: apply_modification(...)") + dprint("Calling: apply_modification(...)") return self.apply_modification( arguments["filename"], arguments["start-line-number"], @@ -36,7 +37,7 @@ def dispatch(self, function_call) -> Optional[str]: arguments["replacement"], ) elif function_call.name == "try_compiling": - print("Calling: try_compiling()") + dprint("Calling: try_compiling()") return self.try_compiling() else: return self.explain_functions.dispatch(function_call) @@ -106,9 +107,9 @@ def apply_modification( whitespace = replaced_line[:n] replacement_lines[0] = whitespace + replacement_lines[0] - print("CWhy wants to do the following modification:") + dprint("CWhy wants to do the following modification:") for line in difflib.unified_diff(replaced_lines, replacement_lines): - print(line) + dprint(line) if not input("Is this modification okay? (y/n) ") == "y": return "The user declined this modification, it is probably wrong." @@ -132,7 +133,7 @@ def try_compiling(self) -> Optional[str]: ) if process.returncode == 0: - print("Compilation successful!") + dprint("Compilation successful!") sys.exit(0) return utils.get_truncated_error_message(self.args, process.stderr) diff --git a/src/cwhy/conversation/explain_functions.py b/src/cwhy/conversation/explain_functions.py index 69b7596..c7098b6 100644 --- a/src/cwhy/conversation/explain_functions.py +++ b/src/cwhy/conversation/explain_functions.py @@ -5,6 +5,8 @@ import llm_utils +from ..print_debug import dprint + class ExplainFunctions: def __init__(self, args: argparse.Namespace): @@ -22,7 +24,7 @@ def as_tools(self): def dispatch(self, function_call) -> Optional[str]: arguments = json.loads(function_call.arguments) - print( + dprint( f"Calling: {function_call.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})" ) try: @@ -35,7 +37,7 @@ def dispatch(self, function_call) -> Optional[str]: elif function_call.name == "list_directory": return self.list_directory(arguments["path"]) except Exception as e: - print(e) + dprint(e) return None def get_compile_or_run_command_schema(self): @@ -46,7 +48,7 @@ def get_compile_or_run_command_schema(self): def get_compile_or_run_command(self) -> str: result = " ".join(self.args.command) - print(result) + dprint(result) return result def get_code_surrounding_schema(self): @@ -72,7 +74,7 @@ def get_code_surrounding_schema(self): def get_code_surrounding(self, filename: str, lineno: int) -> str: (lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3) result = llm_utils.number_group_of_lines(lines, first) - print(result) + dprint(result) return result def list_directory_schema(self): diff --git a/src/cwhy/cwhy.py b/src/cwhy/cwhy.py index 931262a..72d62f6 100755 --- a/src/cwhy/cwhy.py +++ b/src/cwhy/cwhy.py @@ -15,6 +15,7 @@ ) from . import conversation, prompts +from .print_debug import dprint, enable_debug_printing # Turn off most logging @@ -23,21 +24,23 @@ def print_key_info(): - print("You need a key (or keys) from an AI service to use CWhy.") - print() - print("OpenAI:") - print(" You can get a key here: https://platform.openai.com/api-keys") - print(" Set the environment variable OPENAI_API_KEY to your key value:") - print(" export OPENAI_API_KEY=") - print() - print("Bedrock:") - print(" To use Bedrock, you need an AWS account.") - print(" Set the following environment variables:") - print(" export AWS_ACCESS_KEY_ID=") - print(" export AWS_SECRET_ACCESS_KEY=") - print(" export AWS_REGION_NAME=us-west-2") - print(" You also need to request access to Claude:") - print(" https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access") + dprint("You need a key (or keys) from an AI service to use CWhy.") + dprint() + dprint("OpenAI:") + dprint(" You can get a key here: https://platform.openai.com/api-keys") + dprint(" Set the environment variable OPENAI_API_KEY to your key value:") + dprint(" export OPENAI_API_KEY=") + dprint() + dprint("Bedrock:") + dprint(" To use Bedrock, you need an AWS account.") + dprint(" Set the following environment variables:") + dprint(" export AWS_ACCESS_KEY_ID=") + dprint(" export AWS_SECRET_ACCESS_KEY=") + dprint(" export AWS_REGION_NAME=us-west-2") + dprint(" You also need to request access to Claude:") + dprint( + " https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access" + ) _DEFAULT_FALLBACK_MODELS = [] @@ -64,17 +67,17 @@ def complete(args, user_prompt, **kwargs): ) return completion except NotFoundError as e: - print(f"'{args.llm}' either does not exist or you do not have access to it.") + dprint(f"'{args.llm}' either does not exist or you do not have access to it.") raise e except BadRequestError as e: - print("Something is wrong with your prompt.") + dprint("Something is wrong with your prompt.") raise e except RateLimitError as e: - print("You have exceeded a rate limit or have no remaining funds.") + dprint("You have exceeded a rate limit or have no remaining funds.") raise e except APITimeoutError as e: - print("The API timed out.") - print("You can increase the timeout with the --timeout option.") + dprint("The API timed out.") + dprint("You can increase the timeout with the --timeout option.") raise e @@ -128,7 +131,7 @@ def evaluate_diff(args, stdin): def evaluate_with_fallback(args, stdin): for i, model in enumerate(_DEFAULT_FALLBACK_MODELS): if i != 0: - print(f"Falling back to {model}...") + dprint(f"Falling back to {model}...") args.llm = model try: return evaluate(args, stdin) @@ -166,29 +169,32 @@ def main(args: argparse.Namespace) -> None: if process.returncode == 0: return + if args.debug: + enable_debug_printing() + if args.show_prompt: - print("===================== Prompt =====================") + dprint("===================== Prompt =====================") if args.llm == "default": args.llm = _DEFAULT_FALLBACK_MODELS[0] if args.subcommand == "explain": - print(prompts.explain_prompt(args, process.stderr)) + dprint(prompts.explain_prompt(args, process.stderr)) elif args.subcommand == "diff": - print(prompts.diff_prompt(args, process.stderr)) - print("==================================================") + dprint(prompts.diff_prompt(args, process.stderr)) + dprint("==================================================") sys.exit(0) - print(process.stdout, end="") - print(process.stderr, file=sys.stderr, end="") - print("==================================================") - print("CWhy") - print("==================================================") + dprint(process.stdout) + dprint(process.stderr, file=sys.stderr) + dprint("==================================================") + dprint("CWhy") + dprint("==================================================") try: result = evaluate(args, process.stderr if process.stderr else process.stdout) - print(result) + dprint(result) except OpenAIError: print_key_info() sys.exit(1) - print("==================================================") + dprint("==================================================") sys.exit(process.returncode) @@ -197,8 +203,8 @@ def evaluate_text_prompt(args, prompt, wrap=True, **kwargs): completion = complete(args, prompt, **kwargs) msg = f"Analysis from {args.llm}:" - print(msg) - print("-" * len(msg)) + dprint(msg) + dprint("-" * len(msg)) text = completion.choices[0].message.content if wrap: diff --git a/src/cwhy/print_debug.py b/src/cwhy/print_debug.py new file mode 100644 index 0000000..58a557b --- /dev/null +++ b/src/cwhy/print_debug.py @@ -0,0 +1,30 @@ +import inspect + +_debug = False + +_INFO_WIDTH = 30 + + +def dprint(*objects, sep=" ", end="\n", file=None, flush=False): + if not _debug: + return print(*objects, sep=sep, end=end, file=file, flush=flush) + + frame = inspect.currentframe() + caller = inspect.getouterframes(frame, 2) + filename = caller[1].filename + lineno = caller[1].lineno + + info = f"{filename}:{lineno}" + if len(info) > _INFO_WIDTH: + info = f"...{info[-_INFO_WIDTH + 3:]}" + info = info.ljust(_INFO_WIDTH - 2) + info = f"{info} |" + + message = sep.join(map(str, objects)) + end + for line in message.splitlines(): + print(info, line, file=file, flush=flush) + + +def enable_debug_printing(): + global _debug + _debug = True diff --git a/src/cwhy/prompts.py b/src/cwhy/prompts.py index 832626f..0ab7e9e 100644 --- a/src/cwhy/prompts.py +++ b/src/cwhy/prompts.py @@ -6,6 +6,8 @@ import llm_utils +from .print_debug import dprint + # Define error patterns with associated information. The numbers # correspond to the groups matching file name and line number. @@ -72,7 +74,7 @@ def __init__(self, args: argparse.Namespace, diagnostic: str): file_name, line_number - 7, line_number + 3 ) except FileNotFoundError: - print( + dprint( f"Cwhy warning: file not found: {file_name.lstrip()}", file=sys.stderr, )