@@ -3,7 +3,7 @@
 import re
 import shutil
 from tempfile import TemporaryDirectory
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import jsonschema
 import openai  # use the official client for correctness check
@@ -268,118 +268,27 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     assert len(completion.choices[0].text) >= 0
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name, prompt_logprobs",
-    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
-)
-async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
-                                    model_name: str, prompt_logprobs: int):
-    params: Dict = {
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "Who won the world series in 2020?"
-        }, {
-            "role":
-            "assistant",
-            "content":
-            "The Los Angeles Dodgers won the World Series in 2020."
-        }, {
-            "role": "user",
-            "content": "Where was it played?"
-        }],
-        "model":
-        model_name
-    }
-
-    if prompt_logprobs is not None:
-        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
-
-    if prompt_logprobs and prompt_logprobs < 0:
-        with pytest.raises(BadRequestError) as err_info:
-            await client.chat.completions.create(**params)
-        expected_err_string = (
-            "Error code: 400 - {'object': 'error', 'message': "
-            "'Prompt_logprobs set to invalid negative value: -1',"
-            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
-        assert str(err_info.value) == expected_err_string
-    else:
-        completion = await client.chat.completions.create(**params)
-        if prompt_logprobs and prompt_logprobs > 0:
-            assert completion.prompt_logprobs is not None
-            assert len(completion.prompt_logprobs) > 0
-        else:
-            assert completion.prompt_logprobs is None
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
-                                                  model_name: str):
-    params: Dict = {
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "Who won the world series in 2020?"
-        }, {
-            "role":
-            "assistant",
-            "content":
-            "The Los Angeles Dodgers won the World Series in 2020."
-        }, {
-            "role": "user",
-            "content": "Where was it played?"
-        }],
-        "model":
-        model_name,
-        "extra_body": {
-            "prompt_logprobs": 1
-        }
-    }
-
-    completion_1 = await client.chat.completions.create(**params)
-
-    params["extra_body"] = {"prompt_logprobs": 2}
-    completion_2 = await client.chat.completions.create(**params)
-
-    assert len(completion_1.prompt_logprobs[3]) == 1
-    assert len(completion_2.prompt_logprobs[3]) == 2
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
                                                          (MODEL_NAME, 0),
                                                          (MODEL_NAME, 1),
                                                          (MODEL_NAME, None)])
 async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
                                           model_name: str,
-                                          prompt_logprobs: int):
+                                          prompt_logprobs: Optional[int]):
     params: Dict = {
         "prompt": ["A robot may not injure another robot", "My name is"],
         "model": model_name,
     }
     if prompt_logprobs is not None:
         params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
 
-    if prompt_logprobs and prompt_logprobs < 0:
-        with pytest.raises(BadRequestError) as err_info:
+    if prompt_logprobs is not None and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError):
             await client.completions.create(**params)
-        expected_err_string = (
-            "Error code: 400 - {'object': 'error', 'message': "
-            "'Prompt_logprobs set to invalid negative value: -1',"
-            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
-        assert str(err_info.value) == expected_err_string
     else:
         completion = await client.completions.create(**params)
-        if prompt_logprobs and prompt_logprobs > 0:
+        if prompt_logprobs is not None:
             assert completion.choices[0].prompt_logprobs is not None
             assert len(completion.choices[0].prompt_logprobs) > 0
 
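The retained test drives the completions endpoint through the official openai client, passing the server-side prompt_logprobs extension via extra_body. A minimal standalone sketch of that request pattern follows; the base_url, api_key, and model id are placeholder assumptions, not values taken from this diff, and a vLLM-style OpenAI-compatible server is assumed to be running locally:

# Minimal sketch (assumptions: an OpenAI-compatible server with the
# prompt_logprobs extension is already serving at http://localhost:8000/v1;
# base_url, api_key, and model id below are placeholders).
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")  # placeholder key
    # prompt_logprobs is not part of the official OpenAI API, so it must be
    # sent through extra_body; the client would reject it as an unexpected
    # top-level keyword argument.
    completion = await client.completions.create(
        model="my-model",  # placeholder model id
        prompt="A robot may not injure another robot",
        extra_body={"prompt_logprobs": 1},
    )
    # The response carries one logprobs entry per prompt token.
    print(completion.choices[0].prompt_logprobs)


asyncio.run(main())

Per the assertions in the deleted chat tests (len(completion.prompt_logprobs[3]) == 1 and == 2), a value of N yields N candidate logprobs per prompt-token entry.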