Skip to content

Commit

Permalink
fix: remove invalidated content-length in error response
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jan 22, 2025
1 parent 6d21828 commit c081fbb
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 17 deletions.
13 changes: 10 additions & 3 deletions aidial_adapter_openai/exception_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ResponseWrapper,
parse_adapter_exception,
)
from aidial_adapter_openai.utils.log_config import logger


def to_adapter_exception(exc: Exception) -> AdapterException:
Expand All @@ -32,7 +33,7 @@ def to_adapter_exception(exc: Exception) -> AdapterException:

return parse_adapter_exception(
status_code=r.status_code,
headers=dict(httpx_headers.items()),
headers=httpx_headers,
content=r.text,
)

Expand Down Expand Up @@ -71,6 +72,12 @@ def to_adapter_exception(exc: Exception) -> AdapterException:


def adapter_exception_handler(
request: FastAPIRequest, exc: Exception
request: FastAPIRequest, e: Exception
) -> FastAPIResponse:
return to_adapter_exception(exc).to_fastapi_response()
adapter_exception = to_adapter_exception(e)

logger.error(
f"Caught exception: {type(e).__module__}.{type(e).__name__}. "
f"Converted to the adapter exception: {adapter_exception!r}"
)
return adapter_exception.to_fastapi_response()
17 changes: 11 additions & 6 deletions aidial_adapter_openai/utils/adapter_exception.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict
from typing import Any, MutableMapping

from aidial_sdk.exceptions import HTTPException as DialException
from fastapi.responses import Response as FastAPIResponse
Expand All @@ -8,14 +8,14 @@
class ResponseWrapper(Exception):
content: Any
status_code: int
headers: Dict[str, str] | None
headers: MutableMapping[str, str] | None

def __init__(
self,
*,
content: Any,
status_code: int,
headers: Dict[str, str] | None,
headers: MutableMapping[str, str] | None,
) -> None:
super().__init__(str(content))
self.content = content
Expand Down Expand Up @@ -51,7 +51,7 @@ def json_error(self) -> dict:


def _parse_dial_exception(
*, status_code: int, headers: Dict[str, str], content: Any
*, status_code: int, headers: MutableMapping[str, str], content: Any
) -> DialException | None:
if isinstance(content, str):
try:
Expand All @@ -61,6 +61,11 @@ def _parse_dial_exception(
else:
obj = content

# The content length is invalidated as soon as
# the original content is lost
if "Content-Length" in headers:
del headers["Content-Length"]

if (
isinstance(obj, dict)
and (error := obj.get("error"))
Expand All @@ -79,14 +84,14 @@ def _parse_dial_exception(
param=param,
code=code,
display_message=display_message,
headers=headers,
headers=dict(headers.items()),
)

return None


def parse_adapter_exception(
*, status_code: int, headers: Dict[str, str], content: Any
*, status_code: int, headers: MutableMapping[str, str], content: Any
) -> AdapterException:
return _parse_dial_exception(
status_code=status_code, headers=headers, content=content
Expand Down
10 changes: 4 additions & 6 deletions aidial_adapter_openai/utils/sse_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,11 @@ async def to_openai_sse_stream(
async for chunk in stream:
yield format_chunk(chunk)
except Exception as e:
logger.exception(
f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
)

adapter_exception = to_adapter_exception(e)
logger.error(
f"converted to the adapter exception: {adapter_exception!r}"

logger.exception(
f"Caught exception while streaming: {type(e).__module__}.{type(e).__name__}. "
f"Converted to the adapter exception: {adapter_exception!r}"
)

yield format_chunk(adapter_exception.json_error())
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def create_test_cases(
build_vision_common,
]
),
ids=lambda tc: tc.get_id(),
ids=lambda tc: tc.get_id() if isinstance(tc, TestCase) else "na",
)
async def test_chat_completion(
test_case: TestCase,
Expand Down
52 changes: 51 additions & 1 deletion tests/unit_tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ def mock_response(
status_code: int,
content_type: str,
content: str,
*,
check_request: Callable[[httpx.Request], None] = lambda _: None,
extra_headers: dict[str, str] = {},
) -> SideEffectTypes:
def side_effect(request: httpx.Request):
check_request(request)
return httpx.Response(
status_code=status_code,
headers={"content-type": content_type},
headers={
"content-type": content_type,
**extra_headers,
},
content=content,
)

Expand Down Expand Up @@ -519,6 +524,51 @@ async def test_connection_error_from_upstream_non_streaming(
}


@respx.mock
async def test_content_length_of_response_error(test_app: httpx.AsyncClient):
upstream_response = """
{
"error": {
"message": "Bad request",
"code": "400"
}
}
"""
upstream_response_content_length = str(len(upstream_response))

respx.post(
"http://localhost:5001/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview"
).mock(
side_effect=mock_response(
400,
"application/json",
upstream_response,
extra_headers={"content-length": upstream_response_content_length},
)
)

response = await test_app.post(
"/openai/deployments/gpt-4/chat/completions?api-version=2023-03-15-preview",
json={"messages": [{"role": "user", "content": "Test content"}]},
headers={
"X-UPSTREAM-KEY": "TEST_API_KEY",
"X-UPSTREAM-ENDPOINT": "http://localhost:5001/openai/deployments/gpt-4/chat/completions",
},
)

expected_response = json.dumps(
json.loads(upstream_response), separators=(",", ":")
)
expected_content_length = str(len(expected_response))

assert response.status_code == 400
assert response.text == expected_response
assert response.headers["content-length"] == expected_content_length
assert upstream_response_content_length != expected_content_length


@respx.mock
async def test_connection_error_from_upstream_streaming(
test_app: httpx.AsyncClient,
Expand Down

0 comments on commit c081fbb

Please sign in to comment.