@@ -115,36 +115,42 @@ def chat_completion(
             top_p=top_p,
         )
 
-        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in ['json_object', 'json_schema']
+        is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in [
+            "json_object",
+            "json_schema",
+        ]
 
         if model.startswith("o1") and is_json_output_required:
             return self.__o1_chat_completion(**input_kwargs)
 
         return self.client.chat.completions.create(**NotGiven.remove_not_given(input_kwargs))
 
     def __o1_chat_completion(
-            self,
-            messages: Iterable[ChatCompletionMessageParam],
-            model: str,
-            frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
-            logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
-            max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
-            n: Optional[int] | NotGiven = NOT_GIVEN,
-            presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
-            response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
-            stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
-            temperature: Optional[float] | NotGiven = NOT_GIVEN,
-            top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
-            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ):
         o1_messages = list(messages)
-        if response_format.get("type") == 'json_schema':
+        if response_format.get("type") == "json_schema":
             last_msg_idx = len(o1_messages) - 1
             last_msg = o1_messages[last_msg_idx]
-            last_msg["content"] = last_msg["content"] + f"""
-            Response with the following json schema in mind:
+            last_msg["content"] = (
+                last_msg["content"]
+                + f"""
+            Respond with the following json schema in mind:
 
             {response_format.get('json_schema')}
             """
+            )
 
         o1_input_kwargs = dict(
             messages=o1_messages,
             model=model,
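
For context, a minimal self-contained sketch of the two pieces this hunk introduces: the predicate that decides when to reroute, and the prompt rewrite that splices the schema into the last message. NOT_GIVEN here is a local stand-in for openai's sentinel, and the helper names (needs_o1_fallback, append_schema_hint) are hypothetical, not part of the diff:

from typing import Any, Dict, List

NOT_GIVEN = object()  # local stand-in for openai's NOT_GIVEN sentinel


def needs_o1_fallback(model: str, response_format: Any) -> bool:
    # Mirrors the predicate above: o1 models do not honor response_format
    # natively, so JSON requests are rerouted through __o1_chat_completion.
    is_json_output_required = response_format is not NOT_GIVEN and response_format.get("type") in [
        "json_object",
        "json_schema",
    ]
    return model.startswith("o1") and is_json_output_required


def append_schema_hint(messages: List[Dict[str, str]], response_format: Dict[str, Any]) -> List[Dict[str, str]]:
    # Mirrors the prompt rewrite above: the JSON schema is appended to the
    # last message so the o1 model sees it as plain instructions.
    out = list(messages)
    if response_format.get("type") == "json_schema":
        last = dict(out[-1])
        last["content"] = last["content"] + f"""
            Respond with the following json schema in mind:

            {response_format.get('json_schema')}
            """
        out[-1] = last
    return out


if __name__ == "__main__":
    rf = {"type": "json_schema", "json_schema": {"name": "demo", "schema": {"type": "object"}}}
    assert needs_o1_fallback("o1-mini", rf)
    assert not needs_o1_fallback("gpt-4o", rf)
    print(append_schema_hint([{"role": "user", "content": "hi"}], rf)[-1]["content"])
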
@@ -168,7 +174,7 @@ def __o1_chat_completion(
                 messages=[
                     {
                         "role": "user",
-                        "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}"
+                        "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}",
                     }
                 ],
                 model="gpt-4o-mini",
@@ -187,4 +193,4 @@ def __o1_chat_completion(
         reconstructed_response.usage.total_tokens += o1_choices_parser_response.usage.total_tokens
         reconstructed_response.choices[i].message.content = o1_choices_parser_response.choices[0].message.content
 
-        return reconstructed_response
+        return reconstructed_response
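
The hunks above show only fragments of the second pass, so here is a hedged sketch of how the visible pieces plausibly fit together; the function name parse_o1_choices and the exact loop shape are assumptions, while the message template, the gpt-4o-mini model name, and the usage/content assignments come straight from the diff. Each o1 choice is re-sent to gpt-4o-mini with the original response_format, and the parser call's token usage is folded into the reconstructed response:

from openai import OpenAI
from openai.types.chat import ChatCompletion


def parse_o1_choices(
    client: OpenAI,
    reconstructed_response: ChatCompletion,
    response_format: dict,
) -> ChatCompletion:
    # Hypothetical reconstruction of the truncated middle of
    # __o1_chat_completion: reformat each o1 choice with gpt-4o-mini and
    # accumulate the parser call's token usage onto the o1 response.
    for i, o1_choice in enumerate(reconstructed_response.choices):
        o1_choices_parser_response = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": f"Given the following data, format it with the given response format: {o1_choice.message.content}",
                }
            ],
            model="gpt-4o-mini",
            response_format=response_format,
        )
        reconstructed_response.usage.total_tokens += o1_choices_parser_response.usage.total_tokens
        reconstructed_response.choices[i].message.content = o1_choices_parser_response.choices[0].message.content
    return reconstructed_response
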