From 1dd00ad56f9fc43b915a15e55f90ec696184f7bd Mon Sep 17 00:00:00 2001
From: Judah Rand <17158624+judahrand@users.noreply.github.com>
Date: Thu, 6 Jun 2024 11:37:26 +0100
Subject: [PATCH] Add support for serializing `pd.DataFrame` in Arrow IPC
 formats
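
Example client-side usage (a minimal sketch; the host, endpoint path, and
HTTP client are placeholders, not part of this change):

    import pandas as pd
    import pyarrow
    import requests

    df = pd.DataFrame({"col1": [101]})

    # Serialize the DataFrame as an Arrow IPC stream.
    sink = pyarrow.BufferOutputStream()
    batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
    with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)

    # POST with the matching content type so the server selects
    # SerializationFormat.ARROW_STREAM.
    requests.post(
        "http://localhost:3000/predict_dataframe",  # placeholder endpoint
        headers={"Content-Type": "application/vnd.apache.arrow.stream"},
        data=sink.getvalue().to_pybytes(),
    )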
---
 .../_internal/io_descriptors/pandas.py        | 52 +++++++++++++++++--
 tests/e2e/bento_server_http/tests/test_io.py  | 31 +++++++++++
 2 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/src/bentoml/_internal/io_descriptors/pandas.py b/src/bentoml/_internal/io_descriptors/pandas.py
index 7c0c99a32c3..795ee79cff3 100644
--- a/src/bentoml/_internal/io_descriptors/pandas.py
+++ b/src/bentoml/_internal/io_descriptors/pandas.py
@@ -46,6 +46,7 @@
     pb_v1alpha1, _ = import_generated_stubs("v1alpha1")
     pd = LazyLoader("pd", globals(), "pandas", exc_msg=EXC_MSG)
     np = LazyLoader("np", globals(), "numpy")
+    pyarrow = LazyLoader("pyarrow", globals(), "pyarrow")
 
 logger = logging.getLogger(__name__)
 
@@ -144,6 +145,8 @@ def _series_openapi_schema(
 class SerializationFormat(Enum):
     JSON = "application/json"
     PARQUET = "application/octet-stream"
+    ARROW_FILE = "application/vnd.apache.arrow.file"
+    ARROW_STREAM = "application/vnd.apache.arrow.stream"
     CSV = "text/csv"
 
     def __init__(self, mime_type: str):
@@ -156,6 +159,10 @@ def __str__(self) -> str:
             return "parquet"
         elif self == SerializationFormat.CSV:
             return "csv"
+        elif self == SerializationFormat.ARROW_FILE:
+            return "arrow_file"
+        elif self == SerializationFormat.ARROW_STREAM:
+            return "arrow_stream"
         else:
             raise ValueError(f"Unknown serialization format: {self}")
 
@@ -173,6 +180,10 @@ def _infer_serialization_format_from_request(
         return SerializationFormat.PARQUET
     elif content_type == "text/csv":
         return SerializationFormat.CSV
+    elif content_type == "application/vnd.apache.arrow.file":
+        return SerializationFormat.ARROW_FILE
+    elif content_type == "application/vnd.apache.arrow.stream":
+        return SerializationFormat.ARROW_STREAM
     elif content_type:
         logger.debug(
             "Unknown Content-Type ('%s'), falling back to '%s' serialization format.",
@@ -196,6 +207,13 @@ def _validate_serialization_format(serialization_format: SerializationFormat):
         raise MissingDependencyException(
             "Parquet serialization is not available. Try installing pyarrow or fastparquet first."
         )
+    if (
+        serialization_format is SerializationFormat.ARROW_FILE
+        or serialization_format is SerializationFormat.ARROW_STREAM
+    ) and find_spec("pyarrow") is None:
+        raise MissingDependencyException(
+            "Arrow serialization is not available. Try installing pyarrow first."
+        )
 
 
 class PandasDataFrame(
@@ -311,6 +329,8 @@ def predict(input_df: pd.DataFrame) -> pd.DataFrame:
         - :obj:`json` - JSON text format (inferred from content-type ``"application/json"``)
         - :obj:`parquet` - Parquet binary format (inferred from content-type ``"application/octet-stream"``)
         - :obj:`csv` - CSV text format (inferred from content-type ``"text/csv"``)
+        - :obj:`arrow_file` - Arrow file format (inferred from content-type ``"application/vnd.apache.arrow.file"``)
+        - :obj:`arrow_stream` - Arrow stream format (inferred from content-type ``"application/vnd.apache.arrow.stream"``)
 
     Returns:
         :obj:`PandasDataFrame`: IO Descriptor that represents a :code:`pd.DataFrame`.
@@ -325,7 +345,13 @@ def __init__(
         self,
         enforce_dtype: bool = False,
         shape: tuple[int, ...] | None = None,
         enforce_shape: bool = False,
-        default_format: t.Literal["json", "parquet", "csv"] = "json",
+        default_format: t.Literal[
+            "json",
+            "parquet",
+            "csv",
+            "arrow_file",
+            "arrow_stream",
+        ] = "json",
     ):
         self._orient: ext.DataFrameOrient = orient
         self._columns = columns
@@ -371,6 +397,8 @@ def _from_sample(self, sample: ext.PdDataFrame) -> ext.PdDataFrame:
         - :obj:`json` - JSON text format (inferred from content-type ``"application/json"``)
         - :obj:`parquet` - Parquet binary format (inferred from content-type ``"application/octet-stream"``)
         - :obj:`csv` - CSV text format (inferred from content-type ``"text/csv"``)
+        - :obj:`arrow_file` - Arrow file format (inferred from content-type ``"application/vnd.apache.arrow.file"``)
+        - :obj:`arrow_stream` - Arrow stream format (inferred from content-type ``"application/vnd.apache.arrow.stream"``)
 
     Returns:
         :class:`~bentoml._internal.io_descriptors.pandas.PandasDataFrame`: IODescriptor from given users inputs.
@@ -539,6 +567,12 @@ async def from_http_request(self, request: Request) -> ext.PdDataFrame:
             res = pd.read_parquet(io.BytesIO(obj), engine=get_parquet_engine())
         elif serialization_format is SerializationFormat.CSV:
             res: ext.PdDataFrame = pd.read_csv(io.BytesIO(obj), dtype=dtype)
+        elif serialization_format is SerializationFormat.ARROW_FILE:
+            with pyarrow.ipc.open_file(obj) as reader:
+                res = reader.read_pandas()
+        elif serialization_format is SerializationFormat.ARROW_STREAM:
+            with pyarrow.ipc.open_stream(obj) as reader:
+                res = reader.read_pandas()
         else:
             raise InvalidArgument(
                 f"Unknown serialization format ({serialization_format})."
@@ -576,6 +610,18 @@ async def to_http_response(
             resp = obj.to_parquet(engine=get_parquet_engine())
         elif serialization_format is SerializationFormat.CSV:
             resp = obj.to_csv()
+        elif serialization_format is SerializationFormat.ARROW_FILE:
+            sink = pyarrow.BufferOutputStream()
+            batch = self.to_arrow(obj)
+            with pyarrow.ipc.new_file(sink, batch.schema) as writer:
+                writer.write_batch(batch)
+            resp = sink.getvalue().to_pybytes()
+        elif serialization_format is SerializationFormat.ARROW_STREAM:
+            sink = pyarrow.BufferOutputStream()
+            batch = self.to_arrow(obj)
+            with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
+                writer.write_batch(batch)
+            resp = sink.getvalue().to_pybytes()
         else:
             raise InvalidArgument(
                 f"Unknown serialization format ({serialization_format})."
@@ -743,7 +789,7 @@ def from_arrow(self, batch: pyarrow.RecordBatch) -> ext.PdDataFrame:
     def to_arrow(self, df: pd.Series[t.Any]) -> pyarrow.RecordBatch:
         import pyarrow
 
-        return pyarrow.RecordBatch.from_pandas(df)
+        return pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
 
     def spark_schema(self) -> pyspark.sql.types.StructType:
         from pyspark.pandas.typedef import as_spark_type
@@ -1201,7 +1247,7 @@ def to_arrow(self, series: pd.Series[t.Any]) -> pyarrow.RecordBatch:
         import pyarrow
 
         df = series.to_frame()
-        return pyarrow.RecordBatch.from_pandas(df)
+        return pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
 
     def spark_schema(self) -> pyspark.sql.types.StructType:
         from pyspark.pandas.typedef import as_spark_type
diff --git a/tests/e2e/bento_server_http/tests/test_io.py b/tests/e2e/bento_server_http/tests/test_io.py
index 56ba22d4625..03fd0c88f0b 100644
--- a/tests/e2e/bento_server_http/tests/test_io.py
+++ b/tests/e2e/bento_server_http/tests/test_io.py
@@ -7,6 +7,7 @@
 from typing import Tuple
 
 import numpy as np
+import pyarrow
 import pytest
 
 from bentoml.client import AsyncHTTPClient
@@ -144,6 +145,36 @@ async def test_pandas(host: str):
     assert response.status_code == 200
     assert await response.aread() == b'[{"col1":202}]'
 
+    headers = {
+        "Content-Type": "application/vnd.apache.arrow.stream",
+        "Origin": ORIGIN,
+    }
+    sink = pyarrow.BufferOutputStream()
+    batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
+    with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
+        writer.write_batch(batch)
+    data = sink.getvalue().to_pybytes()
+    response = await client.client.post(
+        "/predict_dataframe", headers=headers, data=data
+    )
+    assert response.status_code == 200
+    assert await response.aread() == b'[{"col1":202}]'
+
+    headers = {
+        "Content-Type": "application/vnd.apache.arrow.file",
+        "Origin": ORIGIN,
+    }
+    sink = pyarrow.BufferOutputStream()
+    batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
+    with pyarrow.ipc.new_file(sink, batch.schema) as writer:
+        writer.write_batch(batch)
+    data = sink.getvalue().to_pybytes()
+    response = await client.client.post(
+        "/predict_dataframe", headers=headers, data=data
+    )
+    assert response.status_code == 200
+    assert await response.aread() == b'[{"col1":202}]'
+
 
 @pytest.mark.asyncio
 async def test_file(host: str, bin_file: str):
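
Note: to opt into Arrow IPC on the response side, the new `default_format`
values can be passed to the IO descriptor. A minimal sketch of a hypothetical
service definition (the service and endpoint names are illustrative only, not
part of this change):

    import bentoml
    import pandas as pd
    from bentoml.io import PandasDataFrame

    svc = bentoml.Service("dataframe_service")  # hypothetical service name

    @svc.api(
        input=PandasDataFrame(default_format="arrow_stream"),
        output=PandasDataFrame(default_format="arrow_stream"),
    )
    def predict(df: pd.DataFrame) -> pd.DataFrame:
        # Echo the input back; a real service would run a model here.
        return df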