
Commit

lint and add o1 support with structured outputs
CTY-git committed Nov 15, 2024
1 parent df45086 commit ff6812a
Showing 6 changed files with 738 additions and 697 deletions.
73 changes: 73 additions & 0 deletions patchwork/common/client/llm/openai_.py
@@ -31,6 +31,7 @@ class OpenAiLlmClient(LlmClient):
"gpt-3.5-turbo": 16_385,
"gpt-4 ": 8_192,
"gpt-4-turbo": 8_192,
"o1-preview": 128_000,
"o1-mini": 128_000,
"gpt-4o-mini": 128_000,
"gpt-4o": 128_000,
@@ -114,4 +115,76 @@ def chat_completion(
            top_p=top_p,
        )

        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in ("json_object", "json_schema")
        if model.startswith("o1") and is_json_output_required:
            return self.__o1_chat_completion(**input_kwargs)

        return self.client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))

    def __o1_chat_completion(
        self,
        messages: Iterable[ChatCompletionMessageParam],
        model: str,
        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
        n: Optional[int] | NotGiven = NOT_GIVEN,
        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
        temperature: Optional[float] | NotGiven = NOT_GIVEN,
        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
        top_p: Optional[float] | NotGiven = NOT_GIVEN,
    ):
        o1_messages = list(messages)
        if response_format.get("type") == "json_schema":
            last_msg_idx = len(o1_messages) - 1
            last_msg = o1_messages[last_msg_idx]
            last_msg["content"] = last_msg["content"] + f"""
Respond with the following JSON schema in mind:
{response_format.get('json_schema')}
"""
        o1_input_kwargs = dict(
            messages=o1_messages,
            model=model,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            stop=stop,
            temperature=temperature,
            top_logprobs=top_logprobs,
            top_p=top_p,
        )

        o1_response = self.client.chat.completions.create(**NotGiven.remove_not_given(o1_input_kwargs))

        o1_choices_parser_responses = []
        for o1_choice in o1_response.choices:
            parser_input_kwargs = dict(
                messages=[
                    {
                        "role": "user",
                        "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}",
                    }
                ],
                model="gpt-4o-mini",
                max_tokens=max_tokens,
                n=1,
                response_format=response_format,
            )
            parser_response = self.client.beta.chat.completions.parse(**NotGiven.remove_not_given(parser_input_kwargs))
            o1_choices_parser_responses.append(parser_response)

        reconstructed_response = o1_response.model_copy()
        for i, o1_choices_parser_response in enumerate(o1_choices_parser_responses):
            if reconstructed_response.usage is not None:
                reconstructed_response.usage.completion_tokens += o1_choices_parser_response.usage.completion_tokens
                reconstructed_response.usage.prompt_tokens += o1_choices_parser_response.usage.prompt_tokens
                reconstructed_response.usage.total_tokens += o1_choices_parser_response.usage.total_tokens
            reconstructed_response.choices[i].message.content = o1_choices_parser_response.choices[0].message.content

        return reconstructed_response
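
In effect, the new __o1_chat_completion path emulates structured outputs for o1 models, which did not accept response_format natively at the time: the JSON schema is appended to the last message, the raw o1 answer is generated without a response format, and each resulting choice is then reformatted by gpt-4o-mini through the beta parse endpoint, with token usage from both calls summed into the returned response. A minimal usage sketch of the path from the caller's side follows; the constructor argument and the schema are illustrative assumptions, not taken from this diff.

client = OpenAiLlmClient(api_key="sk-...")  # hypothetical setup; the constructor is not shown in this diff

response = client.chat_completion(
    model="o1-mini",
    messages=[{"role": "user", "content": "List three prime numbers."}],
    # A json_schema response format on an o1 model triggers the two-step path above.
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "primes",
            "schema": {
                "type": "object",
                "properties": {"primes": {"type": "array", "items": {"type": "integer"}}},
            },
        },
    },
)

# Each choice now holds gpt-4o-mini's schema-conformant rewrite of the o1 answer,
# and response.usage aggregates tokens from both the o1 call and the parse calls.
print(response.choices[0].message.content)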
20 changes: 8 additions & 12 deletions patchwork/patchflows/GenerateUnitTests/GenerateUnitTests.py
@@ -1,27 +1,23 @@
import json
from pathlib import Path

import yaml

from patchwork.common.utils.step_typing import validate_steps_with_inputs
from patchwork.step import Step
from patchwork.steps import (
    LLM,
    CallCode2Prompt,
    ModifyCode,
    PR
)
from patchwork.steps import LLM, PR, CallCode2Prompt, ModifyCode

_DEFAULT_INPUT_FILE = Path(__file__).parent / "defaults.yml"
_DEFAULT_PROMPT_JSON = Path(__file__).parent / "default_prompt.json"


class GenerateUnitTests(Step):
    def __init__(self, inputs):
        super().__init__(inputs)

        final_inputs = yaml.safe_load(_DEFAULT_INPUT_FILE.read_text())
        if final_inputs is None:
            final_inputs = {}

        final_inputs.update(inputs)

        final_inputs["prompt_id"] = "GenerateUnitTests"
@@ -37,16 +33,16 @@ def __init__(self, inputs):
final_inputs["branch_prefix"] = f"{self.__class__.__name__.lower()}-"

validate_steps_with_inputs(
set(final_inputs.keys()).union({"prompt_values","files_to_patch"}), LLM, CallCode2Prompt,ModifyCode,PR
set(final_inputs.keys()).union({"prompt_values", "files_to_patch"}), LLM, CallCode2Prompt, ModifyCode, PR
)
self.inputs = final_inputs

def run(self):
outputs = CallCode2Prompt(self.inputs).run()
new_file_name = f"test_file.{self.inputs['test_file_extension']}"
new_file_path = Path(outputs['uri']).with_name(new_file_name)
Path(outputs['uri']).rename(new_file_path)
outputs['uri'] = str(new_file_path)
new_file_path = Path(outputs["uri"]).with_name(new_file_name)
Path(outputs["uri"]).rename(new_file_path)
outputs["uri"] = str(new_file_path)
self.inputs["response_partitions"] = {"patch": ["```", "\n", "```"]}
self.inputs["files_to_patch"] = self.inputs["prompt_values"] = [outputs]
outputs = LLM(self.inputs).run()
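The response_partitions value set in run() reads like a delimiter walk used downstream to carve the generated tests out of a fenced model reply: skip past the opening ```, then past the newline that ends the language tag, and stop at the closing ```. The standalone sketch below shows that presumed extraction logic; it is an assumption about how the LLM step consumes the spec, not a copy of its implementation.

def extract_partition(text: str, delimiters: list[str]) -> str:
    # Narrow the text past each leading delimiter, then cut at the final one.
    for delimiter in delimiters[:-1]:
        _, _, text = text.partition(delimiter)
    extracted, _, _ = text.partition(delimiters[-1])
    return extracted

reply = "Here are the tests:\n```python\ndef test_add():\n    assert 1 + 1 == 2\n```"
print(extract_partition(reply, ["```", "\n", "```"]))
# Prints the test body between the fences, ready for ModifyCode to write out.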
12 changes: 10 additions & 2 deletions patchwork/patchflows/__init__.py
@@ -2,8 +2,16 @@
from .DependencyUpgrade.DependencyUpgrade import DependencyUpgrade
from .GenerateDocstring.GenerateDocstring import GenerateDocstring
from .GenerateREADME.GenerateREADME import GenerateREADME
from .GenerateUnitTests.GenerateUnitTests import GenerateUnitTests
from .PRReview.PRReview import PRReview
from .ResolveIssue.ResolveIssue import ResolveIssue
from .GenerateUnitTests.GenerateUnitTests import GenerateUnitTests

__all__ = ["AutoFix", "DependencyUpgrade", "GenerateREADME", "PRReview", "ResolveIssue", "GenerateDocstring", "GenerateUnitTests"]
__all__ = [
    "AutoFix",
    "DependencyUpgrade",
    "GenerateREADME",
    "PRReview",
    "ResolveIssue",
    "GenerateDocstring",
    "GenerateUnitTests",
]
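
With the sorted import and the expanded __all__ in place, the new patchflow is importable from the package root, e.g. from patchwork.patchflows import GenerateUnitTests.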
2 changes: 1 addition & 1 deletion patchwork/steps/ModifyCode/ModifyCode.py
Expand Up @@ -88,7 +88,7 @@ def run(self) -> dict:

            if new_code is None:
                continue

            replace_code_in_file(uri, start_line, end_line, new_code)
            modified_code_file = dict(path=uri, start_line=start_line, end_line=end_line, **extracted_response)
            modified_code_files.append(modified_code_file)
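For reference, replace_code_in_file splices the model's replacement into the target file by line range. A minimal sketch of that splice under assumed semantics (0-indexed, end-exclusive range, mirroring list slicing; the real helper is defined earlier in ModifyCode.py and may differ):

from pathlib import Path

def replace_code_in_file_sketch(uri: str, start_line: int, end_line: int, new_code: str) -> None:
    # Assumed semantics: replace lines [start_line, end_line) with the new code.
    path = Path(uri)
    lines = path.read_text().splitlines(keepends=True)
    lines[start_line:end_line] = new_code.splitlines(keepends=True)
    path.write_text("".join(lines))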
