diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index 9944abee..c8c74fa2 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -623,6 +623,36 @@ def test_to_arrow_table(df):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+def test_execute_stream(df):
+    stream = df.execute_stream()
+    assert all(batch is not None for batch in stream)
+    assert not list(stream)  # after one iteration the generator must be exhausted
+
+
+@pytest.mark.parametrize("schema", [True, False])
+def test_execute_stream_to_arrow_table(df, schema):
+    stream = df.execute_stream()
+
+    if schema:
+        pyarrow_table = pa.Table.from_batches(
+            (batch.to_pyarrow() for batch in stream), schema=df.schema()
+        )
+    else:
+        pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))
+
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (3, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+
+
+def test_execute_stream_partitioned(df):
+    streams = df.execute_stream_partitioned()
+    assert all(batch is not None for stream in streams for batch in stream)
+    assert all(
+        not list(stream) for stream in streams
+    )  # after one iteration all generators must be exhausted
+
+
 def test_empty_to_arrow_table(df):
     # Convert empty datafusion dataframe to pyarrow Table
     pyarrow_table = df.limit(0).to_arrow_table()
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 1e879099..92a25188 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -15,21 +15,27 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::physical_plan::PyExecutionPlan;
-use crate::sql::logical::PyLogicalPlan;
-use crate::utils::wait_for_future;
-use crate::{errors::DataFusionError, expr::PyExpr};
+use std::sync::Arc;
+
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
+use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::parquet::file::properties::WriterProperties;
 use datafusion::prelude::*;
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
-use std::sync::Arc;
+use tokio::task::JoinHandle;
+
+use crate::errors::py_datafusion_err;
+use crate::physical_plan::PyExecutionPlan;
+use crate::record_batch::PyRecordBatchStream;
+use crate::sql::logical::PyLogicalPlan;
+use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::{errors::DataFusionError, expr::PyExpr};
 
 /// A PyDataFrame is a representation of a logical plan and an API to compose statements.
 /// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -399,6 +405,35 @@ impl PyDataFrame {
         })
     }
 
+    fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
+        // create a Tokio runtime to run the async code
+        let rt = &get_tokio_runtime(py).0;
+        let df = self.df.as_ref().clone();
+        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
+            rt.spawn(async move { df.execute_stream().await });
+        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+        Ok(PyRecordBatchStream::new(stream?))
+    }
+
+    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
+        // create a Tokio runtime to run the async code
+        let rt = &get_tokio_runtime(py).0;
+        let df = self.df.as_ref().clone();
+        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
+            rt.spawn(async move { df.execute_stream_partitioned().await });
+        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+
+        match stream {
+            Ok(batches) => Ok(batches
+                .into_iter()
+                .map(|batch_stream| PyRecordBatchStream::new(batch_stream))
+                .collect()),
+            _ => Err(PyValueError::new_err(
+                "Unable to execute stream partitioned",
+            )),
+        }
+    }
+
     /// Convert to pandas dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
     fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
diff --git a/src/record_batch.rs b/src/record_batch.rs
index aa3a392d..427807f2 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -20,6 +20,7 @@ use datafusion::arrow::pyarrow::ToPyArrow;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::physical_plan::SendableRecordBatchStream;
 use futures::StreamExt;
+use pyo3::prelude::*;
 use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
 
 #[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
@@ -61,4 +62,12 @@ impl PyRecordBatchStream {
             Some(Err(e)) => Err(e.into()),
         }
     }
+
+    fn __next__(&mut self, py: Python) -> PyResult<Option<PyObject>> {
+        self.next(py)
+    }
+
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
 }
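
A minimal usage sketch of the new streaming API from the Python side (not part of the patch; the SELECT query is a hypothetical stand-in for whatever produces the a/b/c table in the tests):

    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.sql("SELECT 1 AS a, 2 AS b, 3 AS c")  # hypothetical query

    # Because PyRecordBatchStream now implements __iter__/__next__,
    # batches can be consumed lazily instead of collected up front.
    for batch in df.execute_stream():
        print(batch.to_pyarrow())

    # Or drain each partition's stream independently.
    for stream in df.execute_stream_partitioned():
        for batch in stream:
            print(batch.to_pyarrow())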