Fix streaming endpoint failure handling #314

@@ -226,42 +226,38 @@ async def create_completion_stream_task(
     logger.info(
         f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
     )
-    try:
-        use_case = CompletionStreamV1UseCase(
-            model_endpoint_service=external_interfaces.model_endpoint_service,
-            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
-        )
-        response = use_case.execute(
-            user=auth, model_endpoint_name=model_endpoint_name, request=request
-        )
+    use_case = CompletionStreamV1UseCase(
+        model_endpoint_service=external_interfaces.model_endpoint_service,
+        llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+    )
+    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
-        async def event_generator():
-            try:
-                async for message in response:
-                    yield {"data": message.json()}
-            except InvalidRequestException as exc:
-                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-                return
+    async def event_generator():
+        try:
+            async for message in response:
+                yield {"data": message.json()}
+        except InvalidRequestException as exc:
+            yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
+        except UpstreamServiceError as exc:
+            request_id = get_request_id()
+            logger.exception(f"Upstream service error for request {request_id}")
+            yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}
+        except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
+            print(str(exc))
+            yield {"data": {"error": {"status_code": 404, "detail": str(exc)}}}
+        except ObjectHasInvalidValueException as exc:
+            yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}

Review comment: Should we have a helper function to generate these payloads?
Reply: I think this snippet is short enough, so it's okay not to add a helper function.
(A sketch of what such a helper could look like follows this hunk.)

+        except EndpointUnsupportedInferenceTypeException as exc:
+            yield {
+                "data": {
+                    "error": {
+                        "status_code": 400,
+                        "detail": f"Unsupported inference type: {str(exc)}",
+                    }
+                }
+            }
-        return EventSourceResponse(event_generator())
-    except UpstreamServiceError:
-        request_id = get_request_id()
-        logger.exception(f"Upstream service error for request {request_id}")
-        return EventSourceResponse(
-            iter((CompletionStreamV1Response(request_id=request_id).json(),))  # type: ignore
-        )
-    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-        raise HTTPException(
-            status_code=404,
-            detail="The specified endpoint could not be found.",
-        ) from exc
-    except ObjectHasInvalidValueException as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
-    except EndpointUnsupportedInferenceTypeException as exc:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported inference type: {str(exc)}",
-        ) from exc
+    return EventSourceResponse(event_generator())


 @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
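
As floated in the review thread above, a shared helper for building these SSE error payloads could be added later. A minimal sketch, assuming a hypothetical name `make_error_payload` that is not part of this PR or the codebase:

```python
from typing import Any, Dict


def make_error_payload(status_code: int, detail: str) -> Dict[str, Any]:
    """Hypothetical helper: build the SSE error payload yielded by event_generator."""
    return {"data": {"error": {"status_code": status_code, "detail": detail}}}


# Each except branch in event_generator could then shrink to, e.g.:
#     except InvalidRequestException as exc:
#         yield make_error_payload(400, str(exc))
```
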
@@ -405,12 +401,12 @@ async def delete_llm_model_endpoint(
             model_endpoint_service=external_interfaces.model_endpoint_service,
         )
         return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
-    except (ObjectNotFoundException) as exc:
+    except ObjectNotFoundException as exc:
         raise HTTPException(
             status_code=404,
             detail="The requested model endpoint could not be found.",
         ) from exc
-    except (ObjectNotAuthorizedException) as exc:
+    except ObjectNotAuthorizedException as exc:
         raise HTTPException(
             status_code=403,
             detail="You don't have permission to delete the requested model endpoint.",
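
For context on what the first hunk means for callers: once the stream has started, failures are now delivered inside the SSE stream as error payloads rather than as HTTP exceptions. A minimal client-side sketch of how a consumer might surface them; the name `iter_stream_payloads` is illustrative, and it assumes each SSE data field arrives as JSON (the real client for this service may differ):

```python
import json
from typing import Any, Dict, Iterable, Iterator


def iter_stream_payloads(sse_lines: Iterable[str]) -> Iterator[Dict[str, Any]]:
    """Yield parsed completion chunks; raise if the server streamed an error payload.

    `sse_lines` is any iterable of decoded SSE lines, e.g. an HTTP client
    iterating over the response body line by line.
    """
    for line in sse_lines:
        if not line.startswith("data:"):
            continue  # skip blank keep-alive lines and SSE comments
        payload = json.loads(line[len("data:"):].strip())
        error = payload.get("error") if isinstance(payload, dict) else None
        if error is not None:
            # Mirrors the server-side shape: {"error": {"status_code": ..., "detail": ...}}
            raise RuntimeError(f"stream failed ({error['status_code']}): {error['detail']}")
        yield payload


# Example (assuming `resp` iterates over decoded lines of the streaming response):
#     for chunk in iter_stream_payloads(resp):
#         print(chunk)
```
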

Second changed file (tests):

@@ -113,6 +113,32 @@ def test_completion_sync_success(
     assert response_1.json().keys() == {"output", "request_id"}


+def test_completion_sync_endpoint_not_found_returns_404(
+    llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
+    completion_sync_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_sync[0]
+            .infra_state.deployment_name: llm_model_endpoint_sync[0]
+            .infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
+        auth=("no_user", ""),
+        json=completion_sync_request,
+    )
+    assert response_1.status_code == 404
+
+
 @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")

Review comment: Do we still want to skip this test?
Reply: Yes, unfortunately I still haven't figured out how to run two streaming tests in a row. There's something wrong with how the test client uses the event loop that I wasn't able to fix; more context in encode/starlette#1315.

 def test_completion_stream_success(
     llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
         f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
         auth=("no_user", ""),
         json=completion_stream_request,
+        stream=True,
     )
     assert response_1.status_code == 200
     count = 0
@@ -146,3 +173,34 @@
         )
         count += 1
     assert count == 1
+
+
+@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
+def test_completion_stream_endpoint_not_found_returns_404(
+    llm_model_endpoint_streaming: ModelEndpoint,
+    completion_stream_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
+        auth=("no_user", ""),
+        json=completion_stream_request,
+        stream=True,
+    )
+
+    assert response_1.status_code == 200
+
+    for message in response_1:
+        print(message)
+        assert "404" in message.decode("utf-8")

Review comment: We shouldn't be using clients/python's mypy to check server files.