
Commit

fix(o1_transformation.py): return 'stream' param support for o1-mini/o1-preview

o1 currently doesn't support streaming, but the other model versions do

Fixes #7292
krrishdholakia committed Dec 18, 2024
1 parent e6e3686 commit 9a0d6db
Showing 2 changed files with 58 additions and 1 deletion.
4 changes: 3 additions & 1 deletion litellm/llms/openai/chat/o1_transformation.py
@@ -50,7 +50,9 @@ def get_supported_openai_params(self, model: str) -> list:
"top_logprobs",
]

if "o1-mini" not in model:
supported_streaming_models = ["o1-preview", "o1-mini"]

if model not in supported_streaming_models:
non_supported_params.append("stream")
non_supported_params.append("stream_options")

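The effect of this change can be checked directly against the OpenAI provider config, mirroring the new test below. This is a minimal sketch, assuming a litellm install at this commit (it uses only the helpers that appear in the new test):

import os

import litellm
from litellm.utils import ProviderConfigManager
from litellm.types.utils import LlmProviders

# Use the locally bundled model cost map, as the new test does,
# so no network call is needed.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

for model in ["o1-preview", "o1-mini", "o1"]:
    config = ProviderConfigManager.get_provider_chat_config(
        model=model, provider=LlmProviders.OPENAI
    )
    supported_params = config.get_supported_openai_params(model=model)
    # 'stream' should now be reported for o1-preview and o1-mini,
    # but not for o1.
    print(model, "stream" in supported_params)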
55 changes: 55 additions & 0 deletions tests/llm_translation/test_openai_o1.py
@@ -65,6 +65,61 @@ async def test_o1_handle_system_role(model):
]


@pytest.mark.parametrize(
    "model, expected_streaming_support",
    [("o1-preview", True), ("o1-mini", True), ("o1", False)],
)
@pytest.mark.asyncio
async def test_o1_handle_streaming_optional_params(model, expected_streaming_support):
    """
    Tests that 'stream' is reported as a supported OpenAI param for
    o1-preview and o1-mini, but not for the o1 model.
    """
    from openai import AsyncOpenAI
    from litellm.utils import ProviderConfigManager
    from litellm.types.utils import LlmProviders

    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    config = ProviderConfigManager.get_provider_chat_config(
        model=model, provider=LlmProviders.OPENAI
    )

    supported_params = config.get_supported_openai_params(model=model)

    assert expected_streaming_support == ("stream" in supported_params)


# @pytest.mark.parametrize(
#     "model",
#     ["o1"],  # "o1-preview", "o1-mini",
# )
# @pytest.mark.asyncio
# async def test_o1_handle_streaming_e2e(model):
#     """
#     End-to-end check that streaming completions work for the given o1 model.
#     """
#     from openai import AsyncOpenAI
#     from litellm.utils import ProviderConfigManager
#     from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
#     from litellm.types.utils import LlmProviders

#     resp = litellm.completion(
#         model=model,
#         messages=[{"role": "user", "content": "Hello!"}],
#         stream=True,
#     )
#     assert isinstance(resp, CustomStreamWrapper)
#     for chunk in resp:
#         print("chunk: ", chunk)

#     assert True


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["gpt-4", "gpt-4-0314", "gpt-4-32k", "o1-preview"])
async def test_o1_max_completion_tokens(model: str):