Commit

lint
CTY-git committed Nov 15, 2024
1 parent ff6812a commit 5d684db
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions patchwork/common/client/llm/openai_.py
@@ -115,36 +115,42 @@ def chat_completion(
             top_p=top_p,
         )
 
-        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in ['json_object', 'json_schema']
+        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in [
+            "json_object",
+            "json_schema",
+        ]
         if model.startswith("o1") and is_json_output_required:
             return self.__o1_chat_completion(**input_kwargs)
 
         return self.client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))
 
     def __o1_chat_completion(
-            self,
-            messages: Iterable[ChatCompletionMessageParam],
-            model: str,
-            frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-            logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-            max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-            n: Optional[int] | NotGiven = NOT_GIVEN,
-            presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-            stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-            temperature: Optional[float] | NotGiven = NOT_GIVEN,
-            top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ):
         o1_messages = list(messages)
-        if response_format.get("type") == 'json_schema':
+        if response_format.get("type") == "json_schema":
             last_msg_idx = len(o1_messages) - 1
             last_msg = o1_messages[last_msg_idx]
-            last_msg["content"] = last_msg["content"] + f"""
-Response with the following json schema in mind:
+            last_msg["content"] = (
+                last_msg["content"]
+                + f"""
+Respond with the following json schema in mind:
 {response_format.get('json_schema')}
 """
+            )
         o1_input_kwargs = dict(
             messages=o1_messages,
             model=model,
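The guard in this hunk exists because, at the time of this commit, OpenAI's o1 models did not accept the response_format parameter, so a JSON schema has to be requested through the prompt text instead. Below is a minimal sketch of the schema-append step the hunk reformats, assuming plain-dict messages; the helper name is illustrative and not part of patchwork's API:

def append_schema_hint(messages: list[dict], response_format: dict) -> list[dict]:
    # Fold the JSON schema request into the last message as prompt text,
    # mirroring the block above (the message dicts are mutated in place,
    # just as in the diff).
    o1_messages = list(messages)
    if response_format.get("type") == "json_schema":
        schema = response_format.get("json_schema")
        o1_messages[-1]["content"] += f"\nRespond with the following json schema in mind:\n{schema}\n"
    return o1_messages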
@@ -168,7 +174,7 @@ def __o1_chat_completion(
                 messages=[
                     {
                         "role": "user",
-                        "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}"
+                        "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}",
                     }
                 ],
                 model="gpt-4o-mini",
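As the surrounding collapsed lines indicate (the loop variables i and o1_choice appear in the next hunk), each o1 choice is replayed through a cheaper model that applies the response format o1 itself could not. A hedged sketch of that second call, using only the openai client API visible in the hunk; the wrapper function and module-level client are assumptions, since the real code presumably calls self.client:

from openai import OpenAI

client = OpenAI()  # assumption: the actual code reuses self.client

def parse_o1_choice(content: str, response_format: dict) -> str:
    # Stage two: gpt-4o-mini reformats o1's free-text answer into the
    # structured output the caller originally requested.
    parsed = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": f"Given the following data, format it with the given response format: {content}",
            }
        ],
        model="gpt-4o-mini",
        response_format=response_format,
    )
    return parsed.choices[0].message.content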
@@ -187,4 +193,4 @@ def __o1_chat_completion(
             reconstructed_response.usage.total_tokens += o1_choices_parser_response.usage.total_tokens
             reconstructed_response.choices[i].message.content = o1_choices_parser_response.choices[0].message.content
 
-        return reconstructed_response
\ No newline at end of file
+        return reconstructed_response
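For context, a call that takes the new o1 path might look like the following; the client class name and constructor arguments are assumptions, since neither appears in this diff:

# Hypothetical usage; the class name OpenAiClient and its constructor are
# assumed, not confirmed by the diff.
from patchwork.common.client.llm.openai_ import OpenAiClient

llm = OpenAiClient(api_key="sk-...")
response = llm.chat_completion(
    messages=[{"role": "user", "content": "List three primes as JSON."}],
    model="o1-mini",
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "primes", "schema": {"type": "object"}},
    },
)
print(response.choices[0].message.content)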
