Fix streaming endpoint failure handling #314

Merged
merged 11 commits on Oct 11, 2023
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -55,6 +55,7 @@ repos:
hooks:
- id: mypy
name: mypy-clients-python
files: clients/python/.*
Contributor Author:

We shouldn't be using the clients/python mypy hook to check server files.

entry: mypy --config-file clients/python/mypy.ini
language: system
- repo: https://github.com/pre-commit/mirrors-mypy
68 changes: 32 additions & 36 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -226,42 +226,38 @@ async def create_completion_stream_task(
logger.info(
f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
)
try:
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
)
response = use_case.execute(
user=auth, model_endpoint_name=model_endpoint_name, request=request
)
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
)
response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
Contributor Author:

Note that use_case.execute doesn't actually execute until L237, when the response is iterated.
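
For context, a minimal sketch (not from this PR) of the async-generator behavior being relied on here: calling an async generator function only creates the generator object; its body, and any exceptions it raises, runs only once the caller starts iterating. That is why the error handling has to live inside event_generator rather than around use_case.execute.

    import asyncio

    async def stream():
        # The body below does not run at call time; it executes during iteration.
        raise RuntimeError("boom")
        yield "data"  # unreachable, but its presence makes this an async generator

    async def main():
        gen = stream()              # no exception here; nothing has executed yet
        try:
            async for _ in gen:     # the body runs (and raises) only now
                pass
        except RuntimeError as exc:
            print(f"caught during iteration: {exc}")

    asyncio.run(main())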


async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
except InvalidRequestException as exc:
yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
return
async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
except InvalidRequestException as exc:
yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
except UpstreamServiceError as exc:
request_id = get_request_id()
logger.exception(f"Upstream service error for request {request_id}")
yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
print(str(exc))
yield {"data": {"error": {"status_code": 404, "detail": str(exc)}}}
except ObjectHasInvalidValueException as exc:
yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
Member:

Should we have a helper function to generate these payloads?

Contributor Author:

I think this snippet is short enough that it's okay not to have a helper function.
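
For illustration only (the inline dicts were kept in the end), a hypothetical sketch of what such a payload helper could look like:

    from typing import Any, Dict

    def error_event(status_code: int, detail: str) -> Dict[str, Any]:
        # Hypothetical helper, not part of this PR: build the error payload that
        # event_generator yields when the stream fails.
        return {"data": {"error": {"status_code": status_code, "detail": detail}}}

    # Usage inside event_generator would look something like:
    #     except InvalidRequestException as exc:
    #         yield error_event(400, str(exc))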

except EndpointUnsupportedInferenceTypeException as exc:
yield {
"data": {
"error": {
"status_code": 400,
"detail": f"Unsupported inference type: {str(exc)}",
}
}
}

return EventSourceResponse(event_generator())
except UpstreamServiceError:
request_id = get_request_id()
logger.exception(f"Upstream service error for request {request_id}")
return EventSourceResponse(
iter((CompletionStreamV1Response(request_id=request_id).json(),)) # type: ignore
)
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
raise HTTPException(
status_code=404,
detail="The specified endpoint could not be found.",
) from exc
except ObjectHasInvalidValueException as exc:
raise HTTPException(status_code=400, detail=str(exc))
except EndpointUnsupportedInferenceTypeException as exc:
raise HTTPException(
status_code=400,
detail=f"Unsupported inference type: {str(exc)}",
) from exc
return EventSourceResponse(event_generator())


@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
@@ -405,12 +401,12 @@ async def delete_llm_model_endpoint(
model_endpoint_service=external_interfaces.model_endpoint_service,
)
return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
except (ObjectNotFoundException) as exc:
except ObjectNotFoundException as exc:
raise HTTPException(
status_code=404,
detail="The requested model endpoint could not be found.",
) from exc
except (ObjectNotAuthorizedException) as exc:
except ObjectNotAuthorizedException as exc:
raise HTTPException(
status_code=403,
detail="You don't have permission to delete the requested model endpoint.",
@@ -1296,7 +1296,7 @@ async def execute(
)

if len(model_endpoints) == 0:
raise ObjectNotFoundException
raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")

if len(model_endpoints) > 1:
raise ObjectHasInvalidValueException(
58 changes: 58 additions & 0 deletions model-engine/tests/unit/api/test_llms.py
@@ -113,6 +113,32 @@ def test_completion_sync_success(
assert response_1.json().keys() == {"output", "request_id"}


def test_completion_sync_endpoint_not_found_returns_404(
llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
completion_sync_request: Dict[str, Any],
get_test_client_wrapper,
):
client = get_test_client_wrapper(
fake_docker_repository_image_always_exists=True,
fake_model_bundle_repository_contents={},
fake_model_endpoint_record_repository_contents={},
fake_model_endpoint_infra_gateway_contents={
llm_model_endpoint_sync[0]
.infra_state.deployment_name: llm_model_endpoint_sync[0]
.infra_state,
},
fake_batch_job_record_repository_contents={},
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
response_1 = client.post(
f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
auth=("no_user", ""),
json=completion_sync_request,
)
assert response_1.status_code == 404


@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
Contributor:

Do we still want to skip this test?

Contributor Author:

Yes, unfortunately I still haven't figured out how to run two streaming tests in a row. There's something wrong with how the test client uses the event loop that I wasn't able to fix; more context in encode/starlette#1315.

def test_completion_stream_success(
llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
auth=("no_user", ""),
json=completion_stream_request,
stream=True,
)
assert response_1.status_code == 200
count = 0
Expand All @@ -146,3 +173,34 @@ def test_completion_stream_success(
)
count += 1
assert count == 1


@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
def test_completion_stream_endpoint_not_found_returns_404(
llm_model_endpoint_streaming: ModelEndpoint,
completion_stream_request: Dict[str, Any],
get_test_client_wrapper,
):
client = get_test_client_wrapper(
fake_docker_repository_image_always_exists=True,
fake_model_bundle_repository_contents={},
fake_model_endpoint_record_repository_contents={},
fake_model_endpoint_infra_gateway_contents={
llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
},
fake_batch_job_record_repository_contents={},
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
response_1 = client.post(
f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
auth=("no_user", ""),
json=completion_stream_request,
stream=True,
)

assert response_1.status_code == 200

for message in response_1:
print(message)
assert "404" in message.decode("utf-8")
1 change: 1 addition & 0 deletions model-engine/tests/unit/api/test_tasks.py
@@ -364,6 +364,7 @@ def test_create_streaming_task_success(
f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}",
auth=(test_api_key, ""),
json=endpoint_predict_request_1[1],
stream=True,
)
assert response.status_code == 200
count = 0