Skip to content

Commit

Permalink
Catch and format prediction result errors
Browse files Browse the repository at this point in the history
Catch and format as a JSON response, exceptions that occur when
converting the result of model.predict() to the expected Response type
(as documented in the mlflow model metadata)

Fixes #12
  • Loading branch information
bloomonkey committed Apr 17, 2023
1 parent f1fd20b commit be7521c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 23 deletions.
6 changes: 4 additions & 2 deletions fastapi_mlflow/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from fastapi.responses import JSONResponse

from mlflow.pyfunc import PyFuncModel # type: ignore
from fastapi_mlflow.predictors import build_predictor, PyFuncModelPredictError

from fastapi_mlflow.exceptions import DictSerialisableException
from fastapi_mlflow.predictors import build_predictor


def build_app(pyfunc_model: PyFuncModel) -> FastAPI:
Expand All @@ -29,7 +31,7 @@ def build_app(pyfunc_model: PyFuncModel) -> FastAPI:
)

@app.exception_handler(Exception)
def handle_exception(_: Request, exc: PyFuncModelPredictError):
def handle_exception(_: Request, exc: DictSerialisableException):
return JSONResponse(
status_code=500,
content=exc.to_dict(),
Expand Down
22 changes: 22 additions & 0 deletions fastapi_mlflow/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""Exceptions.
Copyright (C) 2023, Auto Trader UK
Created 17. Apr 2023
"""


class DictSerialisableException(Exception):
"""An Exception wrapper for formatting."""

def __init__(self, name: str, message: str):
self.name = name
self.message = message

@classmethod
def from_exception(cls, exc: Exception):
return cls(name=exc.__class__.__name__, message=str(exc))

def to_dict(self):
return {"name": self.name, "message": self.message}
18 changes: 4 additions & 14 deletions fastapi_mlflow/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,7 @@
build_model_fields,
MLFLOW_SIGNATURE_TO_PYTHON_TYPE_MAP,
)


class PyFuncModelPredictError(Exception):
def __init__(self, exc: Exception):
super().__init__()
self.error_type_name = exc.__class__.__name__
self.message = str(exc)

def to_dict(self):
return {"name": self.error_type_name, "message": self.message}
from fastapi_mlflow.exceptions import DictSerialisableException


@no_type_check # Some types here are too dynamic for type checking
Expand Down Expand Up @@ -80,11 +71,10 @@ def request_to_dataframe(request: Request) -> pd.DataFrame:
def predictor(request: Request) -> Response:
try:
predictions = model.predict(request_to_dataframe(request))
response_data = convert_predictions_to_python(predictions)
return Response(data=response_data)
except Exception as exc:
raise PyFuncModelPredictError(exc) from exc

response_data = convert_predictions_to_python(predictions)
return Response(data=response_data)
raise DictSerialisableException.from_exception(exc) from exc

return predictor # type: ignore

Expand Down
27 changes: 20 additions & 7 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime
from inspect import signature
from typing import Union
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
Expand All @@ -15,11 +16,8 @@
import pytest
from mlflow.pyfunc import PyFuncModel # type: ignore

from fastapi_mlflow.predictors import (
build_predictor,
convert_predictions_to_python,
PyFuncModelPredictError,
)
from fastapi_mlflow.exceptions import DictSerialisableException
from fastapi_mlflow.predictors import build_predictor, convert_predictions_to_python


def test_build_predictor_returns_callable(
Expand Down Expand Up @@ -152,18 +150,33 @@ def test_predictor_handles_model_returning_nan(
assert item.b is None


def test_predictor_exception_from_model_raised_from(
def test_predictor_raises_custom_wrapped_exception_on_model_error(
model_input: pd.DataFrame, pyfunc_model_value_error: PyFuncModel
):
predictor = build_predictor(pyfunc_model_value_error)

request_type = signature(predictor).parameters["request"].annotation
request_obj = request_type(data=model_input.to_dict(orient="records"))

with pytest.raises(PyFuncModelPredictError):
with pytest.raises(DictSerialisableException):
predictor(request_obj)


def test_predictor_raises_custom_wrapped_exception_on_model_output_conversion(
model_input: pd.DataFrame,
pyfunc_model_ndarray: PyFuncModel,
):
predictor = build_predictor(pyfunc_model_ndarray)

request_type = signature(predictor).parameters["request"].annotation
request_obj = request_type(data=model_input.to_dict(orient="records"))

with patch.object(pyfunc_model_ndarray, "predict") as predict:
predict.return_value = ["Fail!"]
with pytest.raises(DictSerialisableException):
predictor(request_obj) * len(model_input)


def test_convert_predictions_to_python_ndarray():
predictions = np.array([1, 2, 3, np.nan])
response_data = convert_predictions_to_python(predictions)
Expand Down

0 comments on commit be7521c

Please sign in to comment.