Feat/better tracing rendering (#4067)
wwzeng1 authored Jun 20, 2024
1 parent 06b92a8 commit e85e8c2
Showing 3 changed files with 116 additions and 68 deletions.
8 changes: 5 additions & 3 deletions sweepai/core/chat.py
@@ -422,7 +422,8 @@ def llm_stream():
try:
save_messages_for_visualization(
self.messages + [Message(role="assistant", content=response)],
use_openai=use_openai
use_openai=use_openai,
model_name=model
)
except Exception as e:
logger.warning(f"Error saving messages for visualization: {e}")
@@ -479,7 +480,8 @@ def llm_stream():
try:
save_messages_for_visualization(
self.messages + [Message(role="assistant", content=response)],
use_openai=use_openai
use_openai=use_openai,
model_name=model
)
except Exception as e:
logger.warning(f"Error saving messages for visualization: {e}")
@@ -621,7 +623,7 @@ def call_anthropic(
if verbose:
logger.debug(f'{"Openai" if use_openai else "Anthropic"} response: {self.messages[-1].content}')
try:
save_messages_for_visualization(messages=self.messages, use_openai=use_openai)
save_messages_for_visualization(messages=self.messages, use_openai=use_openai, model_name=model)
except Exception as e:
logger.exception(f"Failed to save messages for visualization due to {e}")
self.prev_message_states.append(self.messages)
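Both hunks above now thread the active model name into the tracing helper. For reference, a minimal sketch of the helper's signature after this commit, as defined in the viz_utils.py diff below (illustrative restatement only):

def save_messages_for_visualization(
    messages: list[Message],  # conversation so far, plus the streamed assistant reply
    use_openai: bool,         # selects the "openai"/"anthropic" label and token factor
    model_name: str,          # new in this commit; recorded in each message header
): ...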
144 changes: 94 additions & 50 deletions sweepai/core/viz_utils.py
@@ -1,57 +1,53 @@
import os
from datetime import datetime
from inspect import stack
import os
import re

from loguru import logger
from pytz import timezone
from loguru import logger
from sweepai.core.entities import Message
import html
import re

pst_timezone = timezone("US/Pacific")

def print_bar_chart(data: dict[str, list]):
total_length = sum(len(v) for v in data.values())
max_bar_length = 50
def wrap_xml_tags_with_details(text: str) -> str:
def process_tag(match):
full_tag = match.group(0)
is_closing = full_tag.startswith('</')

if is_closing:
return '</details>'
else:
escaped_tag = html.escape(full_tag)
return f'<details><summary>{escaped_tag}</summary>'

processed_text = re.sub(r'<[^>]+>', process_tag, text)

# Sort the data based on the values in descending order
sorted_data = sorted(data.items(), key=lambda x: len(x[1]), reverse=True)
lines = processed_text.split('\n')
for i, line in enumerate(lines):
if not (line.strip().startswith('<details') or line.strip().startswith('</details') or line.strip().startswith('<summary')):
lines[i] = html.escape(line)

# Find the length of the longest category name
max_category_length = max(len(key) for key in data.keys())
processed_text = '\n'.join(lines)

for key, value in sorted_data:
value = len(value)
ratio = value / total_length
bar_length = int(ratio * max_bar_length)
bar = '█' * bar_length
print(f"{key.ljust(max_category_length)} | {bar} {value}")

def wrap_xml_tags_with_details(text: str) -> str:
def replace_tag_pair(match):
tag = match.group(1)
content = match.group(2)
return f"<details><summary>&lt;{tag}&gt;</summary>\n\n```xml\n{content}\n```\n\n</details>"
return re.sub(r'<([^>]+)>(.*?)</\1>', replace_tag_pair, text, flags=re.DOTALL)
return processed_text
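A quick illustration of the new rendering helper (the input string is hypothetical; the output shape follows from the regex and escaping above): every opening tag becomes a collapsible <details><summary> header, every closing tag becomes </details>, and non-tag lines are HTML-escaped.

sample = "<plan>\n1. Edit chat.py & viz_utils.py\n</plan>"
print(wrap_xml_tags_with_details(sample))
# <details><summary>&lt;plan&gt;</summary>
# 1. Edit chat.py &amp; viz_utils.py
# </details>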

functions_to_unique_f_locals_string_getter = {
"on_ticket": lambda x: "issue_" + str(x["issue_number"]),
"review_pr": lambda x: "pr_" + str(x["pr"].number),
"on_failing_github_actions": lambda x: "pr_" + str(x["pull_request"].number),
} # just need to add the function name and the lambda to get the unique f_locals
}

# these are common wrappers that we don't want to use as our caller_function_name
llm_call_wrappers = ["continuous_llm_calls", "call_llm", "_bootstrap_inner"]
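The dict above maps an entrypoint's function name to a lambda that extracts a unique identifier from that frame's f_locals, so each run gets its own trace folder. Registering a new entrypoint is one line; the example below is hypothetical (on_comment and comment_id are not part of this commit):

functions_to_unique_f_locals_string_getter["on_comment"] = (
    lambda f_locals: "comment_" + str(f_locals["comment_id"])
)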

def save_messages_for_visualization(messages: list[Message], use_openai: bool):
def save_messages_for_visualization(messages: list[Message], use_openai: bool, model_name: str):
current_datetime = datetime.now(pst_timezone)
current_year_month_day = current_datetime.strftime("%Y_%h_%d")
current_year_month_day = current_datetime.strftime("%Y_%m_%d")
current_hour_minute_second = current_datetime.strftime("%I:%M:%S%p")
subfolder = f"sweepai_messages/{current_year_month_day}"
llm_type = "openai" if use_openai else "anthropic"

os.makedirs(subfolder, exist_ok=True)

# goes up the stack to unify shared logs
frames = stack()
function_names = [frame.function for frame in frames]
for i, function_name in enumerate(function_names):
@@ -61,12 +57,10 @@ def save_messages_for_visualization(messages: list[Message], use_openai: bool):
os.makedirs(subfolder, exist_ok=True)
break
else:
# terminate on the second to last item
if i == len(function_names) - 2:
subfolder = os.path.join(subfolder, f"{function_name}_{current_hour_minute_second}")
subfolder = os.path.join(subfolder, f"{current_hour_minute_second}_{function_name}")
os.makedirs(subfolder, exist_ok=True)
# finished going up the stack


caller_function_name = "unknown"
if len(function_names) < 2:
caller_function_name = "unknown"
@@ -75,29 +69,79 @@ def save_messages_for_visualization(messages: list[Message], use_openai: bool):
caller_function_name = function_names[i]
break

# add the current hour and minute to the caller function name
caller_function_name = f"{current_hour_minute_second}_{caller_function_name}"
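The loop above walks the call stack from the innermost frame outward, skipping the generic wrappers in llm_call_wrappers, so traces are filed under the real entrypoint rather than call_llm. A standalone sketch of what inspect.stack() exposes (function names here are illustrative):

from inspect import stack

def call_llm():
    # innermost frame first, e.g. ['call_llm', 'on_ticket', '<module>']
    return [frame.function for frame in stack()]

def on_ticket():
    return call_llm()

print(on_ticket())
# 'call_llm' is skipped as a wrapper; 'on_ticket' matches
# functions_to_unique_f_locals_string_getter, and its frame's f_locals
# (e.g. an issue number) supply the unique folder suffix.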

raw_file = os.path.join(subfolder, f'{caller_function_name}.xml')
md_file = os.path.join(subfolder, f'{caller_function_name}.md')
# if the md/raw files exist, append _1, _2, etc. to the filename
html_file = os.path.join(subfolder, f'{caller_function_name}.html')
for i in range(1, 1000):
if not os.path.exists(raw_file) and not os.path.exists(md_file):
break # we can safely use the current filename
if not os.path.exists(raw_file) and not os.path.exists(html_file):
break
else:
raw_file = os.path.join(subfolder, f'{caller_function_name}_{i}.xml')
md_file = os.path.join(subfolder, f'{caller_function_name}_{i}.md')
html_file = os.path.join(subfolder, f'{caller_function_name}_{i}.html')

with open(raw_file, 'w') as f_raw, open(md_file, 'w') as f_md:
with open(raw_file, 'w') as f_raw, open(html_file, 'w') as f_html:
f_html.write('''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Message Visualization</title>
<style>
body {
font-family: Arial, sans-serif;
line-height: 1.6;
padding: 20px;
background-color: #1e1e1e;
color: #e0e0e0;
}
h2 {
margin-top: 20px;
background-color: #0050a0;
color: #ffffff;
padding: 10px;
border-radius: 5px;
}
details {
margin-bottom: 10px;
background-color: #2a2a2a;
border: 1px solid #444;
border-radius: 5px;
padding: 10px;
}
summary {
cursor: pointer;
font-weight: bold;
background-color: #333333;
color: #ffffff;
padding: 5px;
border-radius: 3px;
margin-bottom: 10px;
}
</style>
</head>
<body>
<h2>Message Visualization</h2>
<div>
''')
total_length = 0
for message in messages:
content_raw = message.content
total_length += len(content_raw)
content_md = wrap_xml_tags_with_details(content_raw)
token_estimate_factor = 4 if use_openai else 3.5
message_tokens = int(len(content_raw) // token_estimate_factor)
message_header = f"{llm_type} {message.role} - {message_tokens} tokens - {int(total_length // token_estimate_factor)} total tokens"
f_raw.write(f"{message_header}\n{content_raw}\n\n")
f_md.write(f"## {message_header}\n\n{content_md}\n\n")
for i, message in enumerate(messages):
try:
content_raw = message.content
total_length += len(content_raw)
content_html = wrap_xml_tags_with_details(content_raw)
token_estimate_factor = 4 if use_openai else 3.5
message_tokens = int(len(content_raw) // token_estimate_factor)
message_header = f"{llm_type} {model_name} {message.role} - {message_tokens} tokens - {int(total_length // token_estimate_factor)} total tokens"
f_raw.write(f"{message_header}\n{content_raw}\n\n")
f_html.write(f'<details><summary>{html.escape(message_header)}</summary>\n<div>{content_html}</div>\n</details>\n\n')
except Exception as e:
logger.error(f"Error processing message: {e}")
f_raw.write(f"Error in message processing: {e}\nRaw content: {content_raw}\n\n")
f_html.write(f'<details><summary>Error in message processing</summary>\n<div><pre>{html.escape(str(e))}\n{html.escape(content_raw)}</pre></div>\n</details>\n\n')

f_html.write('</div></body></html>')

cwd = os.getcwd()
logger.info(f"Messages saved to {os.path.join(cwd, raw_file)} and {os.path.join(cwd, md_file)}")
logger.info(f"Messages saved to {os.path.join(cwd, raw_file)} and {os.path.join(cwd, html_file)}")
32 changes: 17 additions & 15 deletions tests/rerun_chat_modify_direct.py
@@ -2,24 +2,28 @@
import os

from sweepai.agents.modify import modify
from sweepai.config.server import GITHUB_APP_ID, GITHUB_APP_PEM
from sweepai.core.entities import FileChangeRequest
from sweepai.dataclasses.code_suggestions import CodeSuggestion
from sweepai.utils.github_utils import ClonedRepo, get_github_client, get_installation_id
from sweepai.utils.github_utils import MockClonedRepo


repo_name = os.environ["REPO_FULL_NAME"]
repo_full_name = os.environ["REPO_FULL_NAME"]
branch = os.environ["BRANCH"]
code_suggestions_path = os.environ["CODE_SUGGESTIONS_PATH"]

org_name, repo = repo_name.split("/")
installation_id = get_installation_id(org_name, GITHUB_APP_PEM, GITHUB_APP_ID)
user_token, g = get_github_client(installation_id=installation_id)
cloned_repo = ClonedRepo(
repo_name,
installation_id=installation_id,
token=user_token,
branch=branch
REPO_DIR = os.environ["REPO_DIR"]

# org_name, repo = repo_full_name.split("/")
# installation_id = get_installation_id(org_name, GITHUB_APP_PEM, GITHUB_APP_ID)
# user_token, g = get_github_client(installation_id=installation_id)
# cloned_repo = ClonedRepo(
# repo_full_name,
# installation_id=installation_id,
# token=user_token,
# branch=branch
# )
cloned_repo = MockClonedRepo(
_repo_dir=REPO_DIR,
repo_full_name=repo_full_name,
)

file_change_requests = []
@@ -44,7 +48,6 @@
)

try:
breakpoint()
for stateful_code_suggestions in modify.stream(
fcrs=file_change_requests,
request="",
@@ -53,5 +56,4 @@
):
pass
except Exception as e:
raise e

raise e
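The rewritten script swaps the GitHub-App-authenticated ClonedRepo for MockClonedRepo, so it only needs a local checkout plus a few environment variables. A minimal usage sketch grounded in the constructor shown above (the paths and repo name are hypothetical):

# Illustrative invocation, e.g.:
#   REPO_FULL_NAME=sweepai/sweep BRANCH=main \
#   CODE_SUGGESTIONS_PATH=code_suggestions.json REPO_DIR=/tmp/sweep \
#   python tests/rerun_chat_modify_direct.py
from sweepai.utils.github_utils import MockClonedRepo

cloned_repo = MockClonedRepo(
    _repo_dir="/tmp/sweep",          # local working copy (hypothetical path)
    repo_full_name="sweepai/sweep",  # owner/name string; no GitHub auth required
)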
