Commit
feat: optimize diff format
RaoHai committed Aug 28, 2024
1 parent ec43ec7 commit 755f195
Showing 5 changed files with 75 additions and 101 deletions.
6 changes: 3 additions & 3 deletions server/agent/prompts/pull_request.py
@@ -49,8 +49,8 @@ def get_role_prompt(repo_name: str, ref: str):
 def get_pr_summary(repo_name: str, pull_number: int, title: str, description: str, file_diff: str):
     return PULL_REQUEST_SUMMARY.format(
         repo_name=repo_name,
-        pull_number={pull_number},
-        title={title},
-        description={description},
+        pull_number=pull_number,
+        title=title,
+        description=description,
         file_diff=file_diff
     )
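
Not part of the commit, but for context on this fix: wrapping a keyword argument in braces builds a Python set, so str.format rendered the literal set (e.g. {123}) into the prompt instead of the bare value. A minimal illustration, using a made-up template name:

```python
# Hypothetical template; only illustrates the set-vs-value difference fixed above.
TEMPLATE = "PR #{pull_number}: {title}"

print(TEMPLATE.format(pull_number={42}, title={"Fix diff"}))  # -> PR #{42}: {'Fix diff'}
print(TEMPLATE.format(pull_number=42, title="Fix diff"))      # -> PR #42: Fix diff
```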
1 change: 1 addition & 0 deletions server/agent/tools/pull_request.py
@@ -43,6 +43,7 @@ def create_pr_summary(repo_name: str, pull_number: int, summary: str):
     g = Github(auth=token)
     repo = g.get_repo(repo_name)
     pull_request = repo.get_pull(pull_number)
+    # print(f"create_pr_summary, pull_request={pull_request}, summary={summary}")
     pull_request.create_issue_comment(summary)
     return json.dumps([])
 @tool
8 changes: 7 additions & 1 deletion server/event_handler/pull_request.py
@@ -79,7 +79,13 @@ async def execute(self):
         role_prompt = get_role_prompt(repo.full_name, pr.head.ref)
         prompt = get_pr_summary(repo.full_name, pr.number, pr.title, pr.body, file_diff)
 
-        pr_content = f"{pr.title}:{pr.body}"
+        print(f"file_diff={file_diff}")
+        pr_content = f'''
+### Pr Title
+{pr.title}
+### Pr Description
+{pr.body}
+'''
 
         bot = Bot(
             id=random_str(),
100 changes: 41 additions & 59 deletions server/tests/utils/test_path_to_hunk.py
@@ -14,22 +14,16 @@
 }
 }
 '''
-basic_hunk = '''---new_hunk---
-3: function greet(name) {
-4: if (name) {
-5: return `Hello, ${name}`;
-6: } else {
-7: return 'Hello, world!';
-8: }
-9: }
----old_hunk---
-function greet(name) {
-if (name) {
-return 'Hello, ' + name;
-} else {
-return 'Hello, world!';
-}
-}'''
+basic_hunk = '''NewFile OldFile SourceCode
+0 0
+3 3 function greet(name) {
+4 4 if (name) {
+5 - return 'Hello, ' + name;
+5 + return `Hello, ${name}`;
+6 6 } else {
+7 7 return 'Hello, world!';
+8 8 }
+9 9 }'''
 
 long_patch = '''
 @@ -6,7 +6,7 @@
@@ -65,49 +59,37 @@ def chat_history_transform(self, messages: list[Message]):
 async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:
 '''
 
-long_hunk = '''---new_hunk---
-6: from langchain.agents.format_scratchpad.openai_tools import (
-7: format_to_openai_tool_messages,
-8: )
-9: from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage, SystemMessage
-10: from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
-11: from langchain.prompts import MessagesPlaceholder
-12: from langchain_core.prompts import ChatPromptTemplate
-92: def chat_history_transform(self, messages: list[Message]):
-93: transformed_messages = []
-94: for message in messages:
-95: match message.role:
-96: case "user":
-97: transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
-98: case "assistant":
-99: transformed_messages.append(AIMessage(content=message.content))
-100: case "system":
-101: transformed_messages.append(SystemMessage(content=message.content))
-102: case _:
-103: transformed_messages.append(FunctionMessage(content=message.content))
-104: return transformed_messages
-105:
-106: async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:
----old_hunk---
-from langchain.agents.format_scratchpad.openai_tools import (
-format_to_openai_tool_messages,
-)
-from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
-from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
-from langchain.prompts import MessagesPlaceholder
-from langchain_core.prompts import ChatPromptTemplate
-def chat_history_transform(self, messages: list[Message]):
-transformed_messages = []
-for message in messages:
-if message.role == "user":
-transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
-elif message.role == "assistant":
-transformed_messages.append(AIMessage(content=message.content))
-else:
-transformed_messages.append(FunctionMessage(content=message.content))
-return transformed_messages
-async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:'''
+long_hunk = '''NewFile OldFile SourceCode
+0 0
+6 6 from langchain.agents.format_scratchpad.openai_tools import (
+7 7 format_to_openai_tool_messages,
+8 8 )
+9 -from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
+9 +from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage, SystemMessage
+10 10 from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
+11 11 from langchain.prompts import MessagesPlaceholder
+12 12 from langchain_core.prompts import ChatPromptTemplate
+92 92 def chat_history_transform(self, messages: list[Message]):
+93 93 transformed_messages = []
+94 94 for message in messages:
+95 - if message.role == "user":
+96 - transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
+97 - elif message.role == "assistant":
+98 - transformed_messages.append(AIMessage(content=message.content))
+99 - else:
+100 - transformed_messages.append(FunctionMessage(content=message.content))
+95 + match message.role:
+96 + case "user":
+97 + transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
+98 + case "assistant":
+99 + transformed_messages.append(AIMessage(content=message.content))
+100 + case "system":
+101 + transformed_messages.append(SystemMessage(content=message.content))
+102 + case _:
+103 + transformed_messages.append(FunctionMessage(content=message.content))
+104 101 return transformed_messages
+105 102
+106 103 async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:'''
 class TestPathToHunk(TestCase):
     def test_basic_covnert(self):
         result = convert_patch_to_hunk(basic_patch)
61 changes: 23 additions & 38 deletions server/utils/path_to_hunk.py
@@ -1,41 +1,26 @@
-def convert_patch_to_hunk(patch: str) -> str:
-    if patch is None:
-        return ""
-    # split the patch into individual lines
-    lines = patch.strip().split('\n')
-
-    new_hunk = []
-    old_hunk = []
-
-    new_line_num = None
-    old_line_num = None
-
-    for line in lines:
+import re
+
+def convert_patch_to_hunk(diff):
+    old_line, new_line = 0, 0
+    result = []
+
+    for line in diff.splitlines():
         if line.startswith('@@'):
-            # read the starting line numbers of the old and new side
-            parts = line.split()
-            old_line_num = int(parts[1].split(',')[0][1:])
-            new_line_num = int(parts[2].split(',')[0][1:])
-        elif line.startswith('+'):
-            # added line
-            new_hunk.append(f"{new_line_num}: {line[1:]}")
-            new_line_num += 1
+            # use a regex to extract the starting line numbers of the old and new file
+            match = re.search(r'@@ -(\d+),?\d* \+(\d+),?\d* @@', line)
+            if match:
+                old_line = int(match.group(1))
+                new_line = int(match.group(2))
+            continue  # skip emitting the @@ line itself
         elif line.startswith('-'):
-            # deleted line
-            old_hunk.append(line[1:])
-            old_line_num += 1
+            result.append(f" {old_line:<5} {line}")  # present only in the old file
+            old_line += 1
+        elif line.startswith('+'):
+            result.append(f"{new_line:<5} {line}")  # present only in the new file
+            new_line += 1
         else:
-            # unchanged line
-            new_hunk.append(f"{new_line_num}: {line}")
-            old_hunk.append(line)
-            new_line_num += 1
-            old_line_num += 1
-
-    # format the output
-    result = []
-    result.append('---new_hunk---')
-    result.extend(new_hunk)
-    result.append('---old_hunk---')
-    result.extend(old_hunk)
-
-    return '\n'.join(result)
+            result.append(f"{new_line:<5} {old_line:<5} {line}")  # present in both files
+            old_line += 1
+            new_line += 1
+
+    return "NewFile OldFile SourceCode \n" + "\n".join(result)
