From 755f1950db28b4c7f6bfbccee50c1a5dedc2b8d9 Mon Sep 17 00:00:00 2001
From: "raoha.rh"
Date: Wed, 28 Aug 2024 14:20:20 +0800
Subject: [PATCH] feat: optimize diff format

---
 server/agent/prompts/pull_request.py    |   6 +-
 server/agent/tools/pull_request.py      |   1 +
 server/event_handler/pull_request.py    |   8 +-
 server/tests/utils/test_path_to_hunk.py | 100 ++++++++++--------------
 server/utils/path_to_hunk.py            |  61 ++++++---------
 5 files changed, 75 insertions(+), 101 deletions(-)

diff --git a/server/agent/prompts/pull_request.py b/server/agent/prompts/pull_request.py
index c0995018..cb2e3722 100644
--- a/server/agent/prompts/pull_request.py
+++ b/server/agent/prompts/pull_request.py
@@ -49,8 +49,8 @@ def get_role_prompt(repo_name: str, ref: str):
 def get_pr_summary(repo_name: str, pull_number: int, title: str, description: str, file_diff: str):
     return PULL_REQUEST_SUMMARY.format(
         repo_name=repo_name,
-        pull_number={pull_number},
-        title={title},
-        description={description},
+        pull_number=pull_number,
+        title=title,
+        description=description,
         file_diff=file_diff
     )
\ No newline at end of file
diff --git a/server/agent/tools/pull_request.py b/server/agent/tools/pull_request.py
index 6222f832..9c4c7ec3 100644
--- a/server/agent/tools/pull_request.py
+++ b/server/agent/tools/pull_request.py
@@ -43,6 +43,7 @@ def create_pr_summary(repo_name: str, pull_number: int, summary: str):
     g = Github(auth=token)
     repo = g.get_repo(repo_name)
     pull_request = repo.get_pull(pull_number)
+    # print(f"create_pr_summary, pull_request={pull_request}, summary={summary}")
     pull_request.create_issue_comment(summary)
     return json.dumps([])
 @tool
diff --git a/server/event_handler/pull_request.py b/server/event_handler/pull_request.py
index 96fd9e49..7fb7ca9d 100644
--- a/server/event_handler/pull_request.py
+++ b/server/event_handler/pull_request.py
@@ -79,7 +79,13 @@ async def execute(self):
         role_prompt = get_role_prompt(repo.full_name, pr.head.ref)
         prompt = get_pr_summary(repo.full_name, pr.number, pr.title, pr.body, file_diff)
 
-        pr_content = f"{pr.title}:{pr.body}"
+        print(f"file_diff={file_diff}")
+        pr_content = f'''
+        ### PR Title
+        {pr.title}
+        ### PR Description
+        {pr.body}
+        '''
 
         bot = Bot(
             id=random_str(),
diff --git a/server/tests/utils/test_path_to_hunk.py b/server/tests/utils/test_path_to_hunk.py
index 77c88f3a..33c3ed35 100644
--- a/server/tests/utils/test_path_to_hunk.py
+++ b/server/tests/utils/test_path_to_hunk.py
@@ -14,22 +14,16 @@
 }
 }
 '''
-basic_hunk = '''---new_hunk---
-3: function greet(name) {
-4:   if (name) {
-5:     return `Hello, ${name}`;
-6:   } else {
-7:     return 'Hello, world!';
-8:   }
-9: }
----old_hunk---
- function greet(name) {
-   if (name) {
-    return 'Hello, ' + name;
-   } else {
-     return 'Hello, world!';
-   }
- }'''
+basic_hunk = '''NewFile OldFile SourceCode
+0     0
+3     3      function greet(name) {
+4     4        if (name) {
+ 5     -    return 'Hello, ' + name;
+5     +    return `Hello, ${name}`;
+6     6        } else {
+7     7          return 'Hello, world!';
+8     8        }
+9     9      }'''
 
 long_patch = '''
 @@ -6,7 +6,7 @@
@@ -65,49 +59,37 @@ def chat_history_transform(self, messages: list[Message]):
     async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:
 '''
 
-long_hunk = '''---new_hunk---
-6: from langchain.agents.format_scratchpad.openai_tools import (
-7:     format_to_openai_tool_messages,
-8: )
-9: from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage, SystemMessage
-10: from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
-11: from langchain.prompts import MessagesPlaceholder
-12: from langchain_core.prompts import ChatPromptTemplate
-92:     def chat_history_transform(self, messages: list[Message]):
-93:         transformed_messages = []
-94:         for message in messages:
-95:             match message.role:
-96:                 case "user":
-97:                     transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
-98:                 case "assistant":
-99:                     transformed_messages.append(AIMessage(content=message.content))
-100:                 case "system":
-101:                     transformed_messages.append(SystemMessage(content=message.content))
-102:                 case _:
-103:                     transformed_messages.append(FunctionMessage(content=message.content))
-104:         return transformed_messages
-105: 
-106:     async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:
----old_hunk---
- from langchain.agents.format_scratchpad.openai_tools import (
-     format_to_openai_tool_messages,
- )
-from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
- from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
- from langchain.prompts import MessagesPlaceholder
- from langchain_core.prompts import ChatPromptTemplate
-     def chat_history_transform(self, messages: list[Message]):
-         transformed_messages = []
-         for message in messages:
-            if message.role == "user":
-                transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
-            elif message.role == "assistant":
-                transformed_messages.append(AIMessage(content=message.content))
-            else:
-                transformed_messages.append(FunctionMessage(content=message.content))
-         return transformed_messages
-
-     async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:'''
+long_hunk = '''NewFile OldFile SourceCode
+0     0
+6     6      from langchain.agents.format_scratchpad.openai_tools import (
+7     7          format_to_openai_tool_messages,
+8     8      )
+ 9     -from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
+9     +from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage, SystemMessage
+10    10     from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
+11    11     from langchain.prompts import MessagesPlaceholder
+12    12     from langchain_core.prompts import ChatPromptTemplate
+92    92         def chat_history_transform(self, messages: list[Message]):
+93    93             transformed_messages = []
+94    94             for message in messages:
+ 95    -            if message.role == "user":
+ 96    -                transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
+ 97    -            elif message.role == "assistant":
+ 98    -                transformed_messages.append(AIMessage(content=message.content))
+ 99    -            else:
+ 100   -                transformed_messages.append(FunctionMessage(content=message.content))
+95    +            match message.role:
+96    +                case "user":
+97    +                    transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
+98    +                case "assistant":
+99    +                    transformed_messages.append(AIMessage(content=message.content))
+100   +                case "system":
+101   +                    transformed_messages.append(SystemMessage(content=message.content))
+102   +                case _:
+103   +                    transformed_messages.append(FunctionMessage(content=message.content))
+104   101          return transformed_messages
+105   102
+106   103      async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:'''
 class TestPathToHunk(TestCase):
     def test_basic_covnert(self):
         result = convert_patch_to_hunk(basic_patch)
diff --git a/server/utils/path_to_hunk.py b/server/utils/path_to_hunk.py
index 07fb088e..36695c2f 100644
--- a/server/utils/path_to_hunk.py
+++ b/server/utils/path_to_hunk.py
@@ -1,41 +1,26 @@
-def convert_patch_to_hunk(patch: str) -> str:
-    if patch is None:
-        return ""
-    # Split the patch into lines
-    lines = patch.strip().split('\n')
-
-    new_hunk = []
-    old_hunk = []
-
-    new_line_num = None
-    old_line_num = None
-
-    for line in lines:
+import re
+
+def convert_patch_to_hunk(diff):
+    old_line, new_line = 0, 0
+    result = []
+
+    for line in diff.splitlines():
         if line.startswith('@@'):
-            # Get the starting line numbers of the old and new files
-            parts = line.split()
-            old_line_num = int(parts[1].split(',')[0][1:])
-            new_line_num = int(parts[2].split(',')[0][1:])
-        elif line.startswith('+'):
-            # Added line
-            new_hunk.append(f"{new_line_num}: {line[1:]}")
-            new_line_num += 1
+            # Use a regular expression to extract the starting line numbers of the old and new files
+            match = re.search(r'@@ -(\d+),?\d* \+(\d+),?\d* @@', line)
+            if match:
+                old_line = int(match.group(1))
+                new_line = int(match.group(2))
+            continue  # Skip the @@ header line in the output
         elif line.startswith('-'):
-            # Removed line
-            old_hunk.append(line[1:])
-            old_line_num += 1
+            result.append(f" {old_line:<5} {line}")  # Line only in the old file
+            old_line += 1
+        elif line.startswith('+'):
+            result.append(f"{new_line:<5} {line}")  # Line only in the new file
+            new_line += 1
         else:
-            # Unchanged line
-            new_hunk.append(f"{new_line_num}: {line}")
-            old_hunk.append(line)
-            new_line_num += 1
-            old_line_num += 1
-
-    # Format the output
-    result = []
-    result.append('---new_hunk---')
-    result.extend(new_hunk)
-    result.append('---old_hunk---')
-    result.extend(old_hunk)
-
-    return '\n'.join(result)
\ No newline at end of file
+            result.append(f"{new_line:<5} {old_line:<5} {line}")  # Line present in both files
+            old_line += 1
+            new_line += 1
+
+    return "NewFile OldFile SourceCode \n" + "\n".join(result)
\ No newline at end of file
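
A minimal usage sketch of the new table-style hunk output, assuming convert_patch_to_hunk is importable as utils.path_to_hunk from the server package; the sample patch below is made up for illustration.

from utils.path_to_hunk import convert_patch_to_hunk  # assumed import path

# Hypothetical two-line change: one context line and one replaced line.
sample_patch = "\n".join([
    "@@ -1,2 +1,2 @@",
    " print('hello')",
    "-x = 1",
    "+x = 2",
])

# Prints the NewFile/OldFile table: the context line carries both line numbers,
# the removed line only an OldFile number, the added line only a NewFile number.
print(convert_patch_to_hunk(sample_patch))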