Skip to content

Commit

Permalink
Prompting fixes (#4066)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinlu1248 authored Jun 21, 2024
2 parents 37e08d0 + e32adf3 commit 5d80b91
Show file tree
Hide file tree
Showing 7 changed files with 399 additions and 112 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,4 @@ disable=[
]

[tool.ruff]
select = ["T100"]
lint.select = ["T100"]
24 changes: 0 additions & 24 deletions sweepai/api_test.py

This file was deleted.

173 changes: 129 additions & 44 deletions sweepai/chat/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import wraps
import time
import traceback
from typing import Any, Callable
import uuid
Expand All @@ -17,19 +18,23 @@

from sweepai.agents.modify_utils import get_error_message_dict, validate_and_parse_function_call
from sweepai.agents.search_agent import extract_xml_tag
from sweepai.chat.search_prompts import relevant_snippets_message, relevant_snippet_template, anthropic_system_message, function_response, pr_format, relevant_snippets_message_for_pr, openai_system_message, query_optimizer_system_prompt, query_optimizer_user_prompt, anthropic_format_message
from sweepai.chat.search_prompts import relevant_snippets_message, relevant_snippet_template, anthropic_system_message, function_response, pr_format, relevant_snippets_message_for_pr, openai_system_message, query_optimizer_system_prompt, query_optimizer_user_prompt, openai_format_message, anthropic_format_message
from sweepai.config.client import SweepConfig
from sweepai.config.server import CACHE_DIRECTORY, GITHUB_APP_ID, GITHUB_APP_PEM
from sweepai.config.server import CACHE_DIRECTORY, DOCKER_ENABLED, GITHUB_APP_ID, GITHUB_APP_PEM
from sweepai.core.chat import ChatGPT, call_llm
from sweepai.core.entities import FileChangeRequest, Message, Snippet
from sweepai.core.pull_request_bot import get_pr_summary_for_chat
from sweepai.core.review_utils import split_diff_into_patches
from sweepai.dataclasses.check_status import CheckStatus, gha_to_check_status, gha_to_message
from sweepai.dataclasses.code_suggestions import CodeSuggestion
from sweepai.handlers.on_check_suite import get_failing_docker_logs
from sweepai.handlers.on_failing_github_actions import handle_failing_github_actions
from sweepai.utils.convert_openai_anthropic import AnthropicFunctionCall
from sweepai.utils.github_utils import ClonedRepo, CustomGithub, MockClonedRepo, clean_branch_name, commit_multi_file_changes, create_branch, get_github_client, get_installation_id
from sweepai.utils.event_logger import posthog
from sweepai.utils.str_utils import extract_objects_from_string, get_hash
from sweepai.utils.streamable_functions import streamable
from sweepai.utils.ticket_rendering_utils import get_failing_gha_logs
from sweepai.utils.ticket_utils import prep_snippets
from sweepai.utils.timer import Timer

Expand Down Expand Up @@ -534,12 +539,16 @@ def chat_codebase_stream(
role="user"
),
*messages[:-1],
Message(
content=anthropic_format_message,
role="user",
)
]

if len(messages) <= 2:
chat_gpt.messages.append(
Message(
content=openai_format_message if use_openai else anthropic_format_message,
role="user",
)
)

def stream_state(
initial_user_message: str,
snippets: list[Snippet],
Expand Down Expand Up @@ -585,7 +594,7 @@ def stream_state(
if not token:
continue
result_string += token
if len(result_string) < 30:
if len(result_string) < 50:
continue
current_string, *_ = result_string.split("<function_call>")
if "<analysis>" in current_string:
Expand Down Expand Up @@ -712,6 +721,15 @@ def stream_state(

message_content = new_messages[-1].content
code_suggestions_raw, _ = extract_objects_from_string(message_content, "code_change", ["file_path", "original_code", "new_code"])
# combine additions of the same file together
new_code_suggestions_raw = []
for code_suggestion in code_suggestions_raw:
fcr = next((fcr for fcr in new_code_suggestions_raw if fcr["file_path"] == code_suggestion["file_path"] and fcr["original_code"] == code_suggestion["original_code"] == ""), None)
if fcr:
fcr["new_code"] += "\n\n" + code_suggestion["new_code"].lstrip("\n")
else:
new_code_suggestions_raw.append(code_suggestion)
code_suggestions_raw = new_code_suggestions_raw
if code_suggestions_raw:
new_messages[-1].annotations = {
"codeSuggestions": [
Expand Down Expand Up @@ -812,43 +830,6 @@ def handle_function_call(function_call: AnthropicFunctionCall, repo_name: str, s
else:
return "ERROR\n\nTool not found.", []

# @app.post("/backend/autofix")
# async def autofix(
# repo_name: str = Body(...),
# code_suggestions: list[CodeSuggestion] = Body(...),
# access_token: str = Depends(get_token_header)
# ):
# with Timer() as timer:
# g = get_authenticated_github_client(repo_name, access_token)
# logger.debug(f"Getting authenticated GitHub client took {timer.time_elapsed} seconds")
# if not g:
# return {"success": False, "error": "The repository may not exist or you may not have access to this repository."}

# file_change_requests = []
# for code_suggestion in code_suggestions:
# file_change_requests.append(FileChangeRequest(
# filename=code_suggestion.file_path,
# instructions=f"<original_code>\n{code_suggestion.original_code}\n</original_code>\n<new_code>\n{code_suggestion.new_code}\n</new_code>",
# change_type="modify",
# ))

# org_name, repo_name_ = repo_name.split("/")
# cloned_repo = MockClonedRepo(
# f"{repo_cache}/{repo_name_}",
# repo_name,
# token=access_token
# )

# error_messages, error_indices = get_error_message_formatted(
# file_change_requests=file_change_requests,
# cloned_repo=cloned_repo,
# )

# return {
# "success": True,
# "error_messages": error_messages
# }

@app.post("/backend/autofix")
async def autofix(
repo_name: str = Body(...),
Expand Down Expand Up @@ -908,6 +889,9 @@ def stream():

return StreamingResponse(stream())

# TODO: refactor all the PR stuff together
# TODO: refactor all the github client stuff

@app.post("/backend/create_pull")
async def create_pull(
repo_name: str = Body(...),
Expand Down Expand Up @@ -1082,6 +1066,107 @@ async def create_pull_metadata(
"branch": clean_branch_name(title),
}

@app.post("/backend/validate_pull")
async def validate_pull(
    repo_name: str = Body(...),
    pull_request_number: int = Body(...),
    access_token: str = Depends(get_token_header)
):
    """
    Stream the check statuses of a pull request's head commit.

    Each streamed chunk is a JSON-encoded list of CheckStatus dicts (Docker
    statuses first, then one entry per GitHub workflow run sorted by name).
    If anything raises while streaming, a final ``{"error": ...}`` chunk is
    emitted before the exception propagates.

    Args:
        repo_name: Repository in ``org/name`` form.
        pull_request_number: Number of the pull request to validate.
        access_token: GitHub token resolved by ``get_token_header``.

    Returns:
        A ``StreamingResponse`` over the JSON chunks, or a
        ``{"success": False, "error": ...}`` dict when no authenticated
        GitHub client could be obtained.
    """
    # Imported locally so this handler does not rely on the module-level
    # import list, which only pulls in the other two status maps.
    from sweepai.dataclasses.check_status import gha_to_succeeded

    with Timer() as timer:
        g = get_authenticated_github_client(repo_name, access_token)
    logger.debug(f"Getting authenticated GitHub client took {timer.time_elapsed} seconds")
    if not g:
        return {"success": False, "error": "The repository may not exist or you may not have access to this repository."}

    org_name, _ = repo_name.split("/")
    repo = g.get_repo(repo_name)
    pull_request = repo.get_pull(int(pull_request_number))

    cloned_repo = get_cloned_repo(repo_name, access_token, pull_request.head.ref)
    installation_id = get_installation_id(org_name, GITHUB_APP_PEM, GITHUB_APP_ID)
    current_commit = pull_request.head.sha

    def stream():
        try:
            docker_statuses: list[CheckStatus] = []
            if DOCKER_ENABLED:
                # Keep the most recently streamed snapshot; it is prepended to
                # every GitHub Actions update below.
                for streamed_statuses in get_failing_docker_logs.stream(cloned_repo):
                    docker_statuses = streamed_statuses
                    yield json.dumps(docker_statuses)

            # Only poll GitHub Actions when no Docker check has failed. The
            # previous condition also tested an always-empty list, which made
            # it unconditionally true and the polling below unreachable.
            any_failed = any(status["succeeded"] is False for status in docker_statuses)
            if any_failed:
                return

            # Up to 360 polls, sleeping 10 seconds between polls while checks
            # are still pending.
            for _ in range(60 * 6):
                runs = list(repo.get_commit(current_commit).get_check_runs())
                suite_runs = list(repo.get_workflow_runs(branch=pull_request.head.ref, head_sha=pull_request.head.sha))
                # Sort once so indices in suite_statuses line up with sorted_runs.
                sorted_runs = sorted(suite_runs, key=lambda run: run.name)
                suite_statuses: list[CheckStatus] = [
                    {
                        "message": gha_to_message[run.status],
                        "stdout": "",  # TODO: fill this in
                        # CheckStatus.succeeded is Optional[bool]; statuses
                        # without a known outcome map to None.
                        "succeeded": gha_to_succeeded.get(run.status),
                        "status": gha_to_check_status[run.status],
                        "llm_message": "",
                        "container_name": run.name,
                    }
                    for run in sorted_runs
                ]
                yield json.dumps(docker_statuses + suite_statuses)

                all_finished_ok = all(
                    run.conclusion in ["success", "skipped", None]
                    and run.status not in ["in_progress", "waiting", "pending", "requested", "queued"]
                    for run in runs
                )
                if all_finished_ok:
                    logger.info("All Github Actions have succeeded or have no result.")
                    break

                if not any(run.conclusion == "failure" for run in runs):
                    # Nothing has failed yet but something is still running.
                    time.sleep(10)
                    continue

                # At least one check run failed: attach logs to each failed
                # workflow run and stream the updated statuses.
                for i, run in enumerate(sorted_runs):
                    if run.conclusion == "failure":
                        failed_logs = get_failing_gha_logs(
                            [run],
                            installation_id,
                        )
                        suite_statuses[i]["stdout"] = failed_logs
                        suite_statuses[i]["succeeded"] = False
                        suite_statuses[i]["status"] = "failure"
                        suite_statuses[i]["llm_message"] = failed_logs
                        yield json.dumps(docker_statuses + suite_statuses)
                logger.info("Github Actions failed!")
                break
        except Exception as e:
            # Surface the error to the client before propagating it.
            yield json.dumps({"error": str(e)})
            raise

    return StreamingResponse(stream())

@app.post("/backend/fix_pull")
async def fix_pull(
    repo_name: str = Body(...),
    pull_request_number: int = Body(...),
    problem_statement: str = Body(...),
    failing_logs: str = Body(...),
    snippets: list[Snippet] = Body(...),
    access_token: str = Depends(get_token_header)
):
    """
    Hand a pull request's failing logs to ``handle_failing_github_actions``.

    NOTE(review): the original docstring read "Temporarily disabled", but
    nothing in this handler disables it -- confirm whether the route is
    meant to be live.

    Args:
        repo_name: Repository in ``org/name`` form.
        pull_request_number: Number of the pull request to fix.
        problem_statement: Forwarded unchanged to the handler.
        failing_logs: Forwarded unchanged to the handler.
        snippets: Accepted in the request body but not used here.
        access_token: GitHub token resolved by ``get_token_header``.

    Returns:
        Whatever ``handle_failing_github_actions`` returns, or a
        ``{"success": False, "error": ...}`` dict when no authenticated
        GitHub client could be obtained.
    """
    with Timer() as timer:
        g = get_authenticated_github_client(repo_name, access_token)
    logger.debug(f"Getting authenticated GitHub client took {timer.time_elapsed} seconds")
    if not g:
        return {"success": False, "error": "The repository may not exist or you may not have access to this repository."}

    org_name, _ = repo_name.split("/")
    # Fetch the repository once and reuse it for the pull request lookup
    # instead of issuing the same request twice.
    repo = g.get_repo(repo_name)
    pull_request = repo.get_pull(pull_request_number)

    commit = handle_failing_github_actions(
        problem_statement=problem_statement,
        failing_logs=failing_logs,
        repo=repo,
        pull_request=pull_request,
        user_token=access_token,
        username=Github(access_token).get_user().login,
        installation_id=get_installation_id(org_name, GITHUB_APP_PEM, GITHUB_APP_ID),
    )

    return commit

@app.post("/backend/messages/save")
async def write_message_to_disk(
repo_name: str = Body(...),
Expand Down
7 changes: 4 additions & 3 deletions sweepai/chat/search_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
When showing relevant examples of code, only show MINIMAL excerpts of code that address the user's question. Do NOT copy the whole file, but only the lines that are relevant to the user's question.
When suggesting code changes, you add <code_change> blocks inside the <user_response></user_response> tags.
To suggest code changes, first list each section of each file that you would like to change. Then for each section, write a <code_change> block for that change. These changes should be atomic -- to change multiple parts of the file, you MUST write multiple separate <code_change> blocks.
</user_response>"""

openai_format_message = """You MUST follow the following XML-based format, including <user_response> and </user_respose> tags:
Expand Down Expand Up @@ -124,7 +124,8 @@
- Only show code as supplementary evidence or to enhance the explanations. When doing so, only show MINIMAL excerpts of code that address the user's question. Do NOT copy the whole file, but only the lines that are relevant to the user's question. Be concise, it's hard for a user to read entire files worth of content.
- Use markdown for your responses, using headers where applicable to improve clarity and lists to enumerate examples.
- Wherever possible, you should suggest code changes. To do so, you must add <code_change> blocks to the <user_response> block following the format provided below.
- Code changes must be atomic. Each code change must be in its own block, unless they are contiguous changes in the same file.
- Code changes must be atomic. Each code change must be in its own block, unless they are contiguous changes in the same file.
- To change multiple parts of the file, write separate <code_change> blocks.
# <code_change> Format
First, indicate whether you want to modify an existing file or create a new file, then write in the following format:
Expand All @@ -134,7 +135,7 @@
path/to/file.py
</file_path>
<original_code>
Copy the original section of code from path/to/file.py. This is the section of code that you will change. Paraphrasing, abbreviating the source code, or placeholder comments such as "# rest of code" are NEVER PERMITTED.
Copy the original section of code from path/to/file.py. This is the section of code that you will change. Paraphrasing, abbreviating the source code, or placeholder comments such as "# rest of code" are NEVER PERMITTED. Leave empty for creating new files.
</original_code>
<new_code>
New code to replace <original_code> with.
Expand Down
55 changes: 55 additions & 0 deletions sweepai/dataclasses/check_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Optional, Literal, TypedDict

class CheckStatus(TypedDict):
    """Typed dictionary describing the state of a single check."""

    # Short description of the check's state.
    message: str
    # Captured output of the check.
    stdout: str
    # True / False for a known outcome, None otherwise.
    succeeded: Optional[bool]
    # Normalized state of the check.
    status: Literal["pending", "running", "success", "failure", "cancelled"]
    # Text intended for an LLM -- presumably the failure logs; confirm with callers.
    llm_message: str
    # Name identifying the check or container.
    container_name: str

# Status can be one of: completed, action_required, cancelled, failure, neutral, skipped, stale, success, timed_out, in_progress, queued, requested, waiting, pending.
# The three maps below share exactly this key set, so any of them can be
# indexed with any of the statuses listed above.

# Status string -> normalized CheckStatus["status"] value.
gha_to_check_status: dict[str, str] = {
    "completed": "success",
    "action_required": "success",
    "cancelled": "cancelled",
    "failure": "failure",
    "neutral": "success",
    "skipped": "success",
    "stale": "success",
    "success": "success",
    "timed_out": "failure",
    "in_progress": "running",
    "queued": "pending",
    "requested": "pending",
    "waiting": "pending",
    "pending": "pending",
}

# Status string -> CheckStatus["succeeded"] value. Statuses that have not
# finished yet map to None (no outcome known).
# NOTE(review): "action_required" is False here but "success" in
# gha_to_check_status -- confirm which is intended.
gha_to_succeeded: dict[str, Optional[bool]] = {
    "completed": True,
    "action_required": False,
    "cancelled": False,
    "failure": False,
    "neutral": True,
    "skipped": True,
    "stale": True,
    "success": True,
    "timed_out": False,
    "in_progress": None,
    "queued": None,
    "requested": None,
    "waiting": None,
    "pending": None,
}

# Status string -> human-readable message.
gha_to_message: dict[str, str] = {
    "completed": "Github Action completed",
    "action_required": "Github Action action required",
    "cancelled": "Github Action cancelled",
    "failure": "Github Action failed",
    "neutral": "Github Action neutral",
    "skipped": "Github Action skipped",
    "stale": "Github Action stale",
    "success": "Github Action succeeded",
    "timed_out": "Github Action timed out",
    "in_progress": "Github Action in progress",
    "queued": "Github Action queued",
    "requested": "Github Action requested",
    "waiting": "Github Action waiting",
    "pending": "Github Action pending",
}
Loading

0 comments on commit 5d80b91

Please sign in to comment.