diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f01..57cbef70 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -61,3 +61,7 @@ def __init__(self, guardrail_result: "OutputGuardrailResult"): super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" ) + + +class ProviderError(AgentsException): + """Exception raised when the provider fails.""" diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 89619f83..9d1ee7d0 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -12,6 +12,7 @@ from .. import _debug from ..agent_output import AgentOutputSchemaBase +from ..exceptions import ProviderError from ..handoffs import Handoff from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ..logger import logger @@ -70,6 +71,9 @@ async def get_response( stream=False, ) + if not getattr(response, "choices", None): + raise ProviderError(f"LLM provider error: {getattr(response, 'error', 'unknown')}") + if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") else: @@ -252,7 +256,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py index ba3ec68d..84ffe992 100644 --- a/tests/test_openai_chatcompletions.py +++ b/tests/test_openai_chatcompletions.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib from collections.abc import AsyncIterator from typing import Any @@ -30,6 +31,7 @@ OpenAIProvider, generation_span, ) +from agents.exceptions import ProviderError from agents.models.chatcmpl_helpers import ChatCmplHelpers from agents.models.fake_id import FAKE_RESPONSES_ID @@ -330,3 +332,41 @@ def test_store_param(): assert ChatCmplHelpers.get_store_param(client, model_settings) is True, ( "Should respect explicitly set store=True" ) + + +@pytest.mark.asyncio +async def test_get_response_raises_provider_error_if_no_choices(monkeypatch): + # Import the class under test _inside_ the function so + # pytest’s conftest autouse fixtures don’t stomp it out. + import agents.models.openai_chatcompletions as chatmod + + chatmod = importlib.reload(chatmod) + + ModelClass = chatmod.OpenAIChatCompletionsModel + + dummy_client = AsyncOpenAI(api_key="fake", base_url="http://localhost") + model = ModelClass(model="test-model", openai_client=dummy_client) + + class FakeResponse: + choices = [] + error = "service unavailable" + + async def fake_fetch_response(*args, **kwargs): + return FakeResponse() + + monkeypatch.setattr(ModelClass, "_fetch_response", fake_fetch_response) + + settings = ModelSettings(temperature=0.0, max_tokens=1) + with pytest.raises(ProviderError) as exc: + await model.get_response( + system_instructions="", + input="Hello?", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=ModelTracing.DISABLED, + previous_response_id=None, + ) + + assert "service unavailable" in str(exc.value)