
Commit 5d684db

lint

1 parent ff6812a commit 5d684db

File tree

1 file changed: +26 -20 lines changed

patchwork/common/client/llm/openai_.py

Lines changed: 26 additions & 20 deletions

@@ -115,36 +115,42 @@ def chat_completion(
             top_p=top_p,
         )
 
-        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in ['json_object', 'json_schema']
+        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in [
+            "json_object",
+            "json_schema",
+        ]
         if model.startswith("o1") and is_json_output_required:
             return self.__o1_chat_completion(**input_kwargs)
 
         return self.client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))
 
     def __o1_chat_completion(
-        self,
-        messages: Iterable[ChatCompletionMessageParam],
-        model: str,
-        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-        n: Optional[int] | NotGiven = NOT_GIVEN,
-        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-        temperature: Optional[float] | NotGiven = NOT_GIVEN,
-        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ):
         o1_messages = list(messages)
-        if response_format.get("type") == 'json_schema':
+        if response_format.get("type") == "json_schema":
             last_msg_idx = len(o1_messages) - 1
             last_msg = o1_messages[last_msg_idx]
-            last_msg["content"] = last_msg["content"] + f"""
-            Response with the following json schema in mind:
+            last_msg["content"] = (
+                last_msg["content"]
+                + f"""
+            Respond with the following json schema in mind:
             {response_format.get('json_schema')}
             """
+            )
         o1_input_kwargs = dict(
             messages=o1_messages,
             model=model,
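
The call in this hunk filters keyword arguments through NotGiven.remove_not_given before they reach the SDK, so the request only carries parameters the caller actually set. A plausible shape for such a helper, written here as a hypothetical standalone function rather than the project's actual classmethod:

    from openai import NOT_GIVEN

    def remove_not_given(kwargs: dict) -> dict:
        # Drop NOT_GIVEN sentinels so only explicitly-set
        # arguments are forwarded to the client.
        return {k: v for k, v in kwargs.items() if v is not NOT_GIVEN}
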
@@ -168,7 +174,7 @@ def __o1_chat_completion(
             messages=[
                 {
                     "role": "user",
-                    "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}"
+                    "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}",
                 }
             ],
             model="gpt-4o-mini",
@@ -187,4 +193,4 @@ def __o1_chat_completion(
             reconstructed_response.usage.total_tokens += o1_choices_parser_response.usage.total_tokens
             reconstructed_response.choices[i].message.content = o1_choices_parser_response.choices[0].message.content
 
-        return reconstructed_response
+        return reconstructed_response
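
Taken together, the method implements a two-step workaround, evidently because o1-series models did not accept a JSON response_format at the time: __o1_chat_completion folds the schema into the prompt text, then has gpt-4o-mini re-emit each o1 draft under the real response_format. A minimal sketch of the same pattern against the plain openai client; the model names, prompt, and example schema below are illustrative, not taken from the repository:

    from openai import OpenAI

    client = OpenAI()

    # Hypothetical schema for illustration.
    schema = {
        "name": "greeting",
        "schema": {"type": "object", "properties": {"text": {"type": "string"}}},
    }

    # Step 1: the o1 model sees the schema only as prompt text.
    draft = client.chat.completions.create(
        model="o1-mini",
        messages=[{
            "role": "user",
            "content": f"Say hello.\nRespond with the following json schema in mind:\n{schema}",
        }],
    )

    # Step 2: a cheaper model coerces the draft into the requested format.
    formatted = client.chat.completions.create(
        model="gpt-4o-mini",
        response_format={"type": "json_schema", "json_schema": schema},
        messages=[{
            "role": "user",
            "content": f"Given the following data, format it with the given response format: "
            f"{draft.choices[0].message.content}",
        }],
    )
    print(formatted.choices[0].message.content)
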
