Skip to content

Commit

Permalink
Fix predictors for pydantic 2 API changes
Browse files Browse the repository at this point in the history
Pydantic 2.x BaseModel has a slightly different API, and generates
slightly different JSON schemas.
https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
  • Loading branch information
bloomonkey committed Aug 23, 2023
1 parent 4a3819b commit 9ab9628
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Fixed
- Adapt to pydantic 2.x API changes

## [0.6.1] - 2023-08-23
### Fixed
Expand Down
2 changes: 1 addition & 1 deletion fastapi_mlflow/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Response(BaseModel):
data: List[output_model]

def request_to_dataframe(request: Request) -> pd.DataFrame:
df = pd.DataFrame([row.dict() for row in request.data], dtype=object)
df = pd.DataFrame([row.model_dump() for row in request.data], dtype=object)
for item in input_schema.to_dict():
if item["type"] in ("integer", "int32"):
df[item["name"]] = df[item["name"]].astype(np.int32)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_predictor_has_correct_signature_for_input(
"type in predictor function parameter `request` is not a"
"subclass of pydantic.BaseModel"
)
schema = request_type.schema()
schema = request_type.model_json_schema()
assert "data" in schema["required"]
assert schema["properties"]["data"]["type"] == "array"
assert "RequestRow" in schema["definitions"]
assert schema["definitions"]["RequestRow"]["required"] == list(model_input.columns)
assert "RequestRow" in schema["$defs"]
assert schema["$defs"]["RequestRow"]["required"] == list(model_input.columns)


def test_predictor_signature_type_can_be_constructed(
Expand Down Expand Up @@ -90,10 +90,10 @@ def test_predictor_has_correct_return_type(
"type in predictor function parameter `request` is not a"
"subclass of pydantic.BaseModel"
)
schema = return_type.schema()
schema = return_type.model_json_schema()
assert "data" in schema["required"]
assert schema["properties"]["data"]["type"] == "array"
assert "ResponseRow" in schema["definitions"]
assert "ResponseRow" in schema["$defs"]


@pytest.mark.asyncio
Expand All @@ -119,7 +119,7 @@ async def test_predictor_correctly_applies_model(
request_obj = request_type(data=model_input.to_dict(orient="records"))
response = await predictor(request_obj)
try:
assert response.data == model_output.to_dict(orient="records") # type: ignore
assert [row.model_dump() for row in response.data] == model_output.to_dict(orient="records") # type: ignore
except (AttributeError, TypeError):
predictions = [item.prediction for item in response.data]
assert predictions == model_output.tolist() # type: ignore
Expand Down

0 comments on commit 9ab9628

Please sign in to comment.