diff --git a/sweepai/core/chat.py b/sweepai/core/chat.py index 73a64d4dd9..a04053b828 100644 --- a/sweepai/core/chat.py +++ b/sweepai/core/chat.py @@ -422,7 +422,8 @@ def llm_stream(): try: save_messages_for_visualization( self.messages + [Message(role="assistant", content=response)], - use_openai=use_openai + use_openai=use_openai, + model_name=model ) except Exception as e: logger.warning(f"Error saving messages for visualization: {e}") @@ -479,7 +480,8 @@ def llm_stream(): try: save_messages_for_visualization( self.messages + [Message(role="assistant", content=response)], - use_openai=use_openai + use_openai=use_openai, + model_name=model ) except Exception as e: logger.warning(f"Error saving messages for visualization: {e}") @@ -621,7 +623,7 @@ def call_anthropic( if verbose: logger.debug(f'{"Openai" if use_openai else "Anthropic"} response: {self.messages[-1].content}') try: - save_messages_for_visualization(messages=self.messages, use_openai=use_openai) + save_messages_for_visualization(messages=self.messages, use_openai=use_openai, model_name=model) except Exception as e: logger.exception(f"Failed to save messages for visualization due to {e}") self.prev_message_states.append(self.messages) diff --git a/sweepai/core/viz_utils.py b/sweepai/core/viz_utils.py index c8b9d4ea10..d174eeea90 100644 --- a/sweepai/core/viz_utils.py +++ b/sweepai/core/viz_utils.py @@ -1,57 +1,53 @@ +import os from datetime import datetime from inspect import stack -import os -import re - -from loguru import logger from pytz import timezone +from loguru import logger from sweepai.core.entities import Message +import html +import re pst_timezone = timezone("US/Pacific") -def print_bar_chart(data: dict[str, list]): - total_length = sum(len(v) for v in data.values()) - max_bar_length = 50 +def wrap_xml_tags_with_details(text: str) -> str: + def process_tag(match): + full_tag = match.group(0) + is_closing = full_tag.startswith('' + else: + escaped_tag = html.escape(full_tag) + return f'
{escaped_tag}' + + processed_text = re.sub(r'<[^>]+>', process_tag, text) - # Sort the data based on the values in descending order - sorted_data = sorted(data.items(), key=lambda x: len(x[1]), reverse=True) + lines = processed_text.split('\n') + for i, line in enumerate(lines): + if not (line.strip().startswith(' str: - def replace_tag_pair(match): - tag = match.group(1) - content = match.group(2) - return f"
<{tag}>\n\n```xml\n{content}\n```\n\n
" - return re.sub(r'<([^>]+)>(.*?)', replace_tag_pair, text, flags=re.DOTALL) + return processed_text functions_to_unique_f_locals_string_getter = { "on_ticket": lambda x: "issue_" + str(x["issue_number"]), "review_pr": lambda x: "pr_" + str(x["pr"].number), "on_failing_github_actions": lambda x: "pr_" + str(x["pull_request"].number), -} # just need to add the function name and the lambda to get the unique f_locals +} -# these are common wrappers that we don't want to use as our caller_function_name llm_call_wrappers = ["continuous_llm_calls", "call_llm", "_bootstrap_inner"] -def save_messages_for_visualization(messages: list[Message], use_openai: bool): +def save_messages_for_visualization(messages: list[Message], use_openai: bool, model_name: str): current_datetime = datetime.now(pst_timezone) - current_year_month_day = current_datetime.strftime("%Y_%h_%d") + current_year_month_day = current_datetime.strftime("%Y_%m_%d") current_hour_minute_second = current_datetime.strftime("%I:%M:%S%p") subfolder = f"sweepai_messages/{current_year_month_day}" llm_type = "openai" if use_openai else "anthropic" os.makedirs(subfolder, exist_ok=True) - # goes up the stack to unify shared logs frames = stack() function_names = [frame.function for frame in frames] for i, function_name in enumerate(function_names): @@ -61,12 +57,10 @@ def save_messages_for_visualization(messages: list[Message], use_openai: bool): os.makedirs(subfolder, exist_ok=True) break else: - # terminate on the second to last item if i == len(function_names) - 2: - subfolder = os.path.join(subfolder, f"{function_name}_{current_hour_minute_second}") + subfolder = os.path.join(subfolder, f"{current_hour_minute_second}_{function_name}") os.makedirs(subfolder, exist_ok=True) - # finished going up the stack - + caller_function_name = "unknown" if len(function_names) < 2: caller_function_name = "unknown" @@ -75,29 +69,79 @@ def save_messages_for_visualization(messages: list[Message], use_openai: bool): caller_function_name = function_names[i] break - # add the current hour and minute to the caller function name caller_function_name = f"{current_hour_minute_second}_{caller_function_name}" raw_file = os.path.join(subfolder, f'{caller_function_name}.xml') - md_file = os.path.join(subfolder, f'{caller_function_name}.md') - # if the md/raw files exist, append _1, _2, etc. to the filename + html_file = os.path.join(subfolder, f'{caller_function_name}.html') for i in range(1, 1000): - if not os.path.exists(raw_file) and not os.path.exists(md_file): - break # we can safely use the current filename + if not os.path.exists(raw_file) and not os.path.exists(html_file): + break else: raw_file = os.path.join(subfolder, f'{caller_function_name}_{i}.xml') - md_file = os.path.join(subfolder, f'{caller_function_name}_{i}.md') + html_file = os.path.join(subfolder, f'{caller_function_name}_{i}.html') - with open(raw_file, 'w') as f_raw, open(md_file, 'w') as f_md: + with open(raw_file, 'w') as f_raw, open(html_file, 'w') as f_html: + f_html.write(''' + + + + + + Message Visualization + + + +

Message Visualization

+
+''') total_length = 0 - for message in messages: - content_raw = message.content - total_length += len(content_raw) - content_md = wrap_xml_tags_with_details(content_raw) - token_estimate_factor = 4 if use_openai else 3.5 - message_tokens = int(len(content_raw) // token_estimate_factor) - message_header = f"{llm_type} {message.role} - {message_tokens} tokens - {int(total_length // token_estimate_factor)} total tokens" - f_raw.write(f"{message_header}\n{content_raw}\n\n") - f_md.write(f"## {message_header}\n\n{content_md}\n\n") + for i, message in enumerate(messages): + try: + content_raw = message.content + total_length += len(content_raw) + content_html = wrap_xml_tags_with_details(content_raw) + token_estimate_factor = 4 if use_openai else 3.5 + message_tokens = int(len(content_raw) // token_estimate_factor) + message_header = f"{llm_type} {model_name} {message.role} - {message_tokens} tokens - {int(total_length // token_estimate_factor)} total tokens" + f_raw.write(f"{message_header}\n{content_raw}\n\n") + f_html.write(f'
{html.escape(message_header)}\n
{content_html}
\n
\n\n') + except Exception as e: + logger.error(f"Error processing message: {e}") + f_raw.write(f"Error in message processing: {e}\nRaw content: {content_raw}\n\n") + f_html.write(f'
Error in message processing\n
{html.escape(str(e))}\n{html.escape(content_raw)}
\n
\n\n') + + f_html.write('
') + cwd = os.getcwd() - logger.info(f"Messages saved to {os.path.join(cwd, raw_file)} and {os.path.join(cwd, md_file)}") \ No newline at end of file + logger.info(f"Messages saved to {os.path.join(cwd, raw_file)} and {os.path.join(cwd, html_file)}") diff --git a/tests/rerun_chat_modify_direct.py b/tests/rerun_chat_modify_direct.py index 81f62bb1a3..d381aff2ec 100644 --- a/tests/rerun_chat_modify_direct.py +++ b/tests/rerun_chat_modify_direct.py @@ -2,24 +2,28 @@ import os from sweepai.agents.modify import modify -from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM from sweepai.core.entities import FileChangeRequest from sweepai.dataclasses.code_suggestions import CodeSuggestion -from sweepai.utils.github_utils import ClonedRepo, get_github_client, get_installation_id +from sweepai.utils.github_utils import MockClonedRepo -repo_name = os.environ["REPO_FULL_NAME"] +repo_full_name = os.environ["REPO_FULL_NAME"] branch = os.environ["BRANCH"] code_suggestions_path = os.environ["CODE_SUGGESTIONS_PATH"] - -org_name, repo = repo_name.split("/") -installation_id = get_installation_id(org_name, GITHUB_APP_PEM, GITHUB_APP_ID) -user_token, g = get_github_client(installation_id=installation_id) -cloned_repo = ClonedRepo( - repo_name, - installation_id=installation_id, - token=user_token, - branch=branch +REPO_DIR = os.environ["REPO_DIR"] + +# org_name, repo = repo_full_name.split("/") +# installation_id = get_installation_id(org_name, GITHUB_APP_PEM, GITHUB_APP_ID) +# user_token, g = get_github_client(installation_id=installation_id) +# cloned_repo = ClonedRepo( +# repo_full_name, +# installation_id=installation_id, +# token=user_token, +# branch=branch +# ) +cloned_repo = MockClonedRepo( + _repo_dir=REPO_DIR, + repo_full_name=repo_full_name, ) file_change_requests = [] @@ -44,7 +48,6 @@ ) try: - breakpoint() for stateful_code_suggestions in modify.stream( fcrs=file_change_requests, request="", @@ -53,5 +56,4 @@ ): pass except Exception as e: - raise e - + raise e \ No newline at end of file