Skip to content

Commit

Permalink
feat: add execute_stream and execute_stream_partitioned
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo committed Mar 12, 2024
1 parent 18ac182 commit da30bd4
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 5 deletions.
30 changes: 30 additions & 0 deletions datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,36 @@ def test_to_arrow_table(df):
assert set(pyarrow_table.column_names) == {"a", "b", "c"}


def test_execute_stream(df):
stream = df.execute_stream()
assert all(batch is not None for batch in stream)
assert not list(stream) # after one iteration the generator must be exhausted


@pytest.mark.parametrize("schema", [True, False])
def test_execute_stream_to_arrow_table(df, schema):
stream = df.execute_stream()

if schema:
pyarrow_table = pa.Table.from_batches(
(batch.to_pyarrow() for batch in stream), schema=df.schema()
)
else:
pyarrow_table = pa.Table.from_batches((batch.to_pyarrow() for batch in stream))

assert isinstance(pyarrow_table, pa.Table)
assert pyarrow_table.shape == (3, 3)
assert set(pyarrow_table.column_names) == {"a", "b", "c"}


def test_execute_stream_partitioned(df):
streams = df.execute_stream_partitioned()
assert all(batch is not None for stream in streams for batch in stream)
assert all(
not list(stream) for stream in streams
) # after one iteration all generators must be exhausted


def test_empty_to_arrow_table(df):
# Convert empty datafusion dataframe to pyarrow Table
pyarrow_table = df.limit(0).to_arrow_table()
Expand Down
45 changes: 40 additions & 5 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
// specific language governing permissions and limitations
// under the License.

use crate::physical_plan::PyExecutionPlan;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::wait_for_future;
use crate::{errors::DataFusionError, expr::PyExpr};
use std::sync::Arc;

use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::prelude::*;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use std::sync::Arc;
use tokio::task::JoinHandle;

use crate::errors::py_datafusion_err;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{get_tokio_runtime, wait_for_future};
use crate::{errors::DataFusionError, expr::PyExpr};

/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
Expand Down Expand Up @@ -399,6 +405,35 @@ impl PyDataFrame {
})
}

fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
// create a Tokio runtime to run the async code
let rt = &get_tokio_runtime(py).0;
let df = self.df.as_ref().clone();
let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { df.execute_stream().await });
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
Ok(PyRecordBatchStream::new(stream?))
}

fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
// create a Tokio runtime to run the async code
let rt = &get_tokio_runtime(py).0;
let df = self.df.as_ref().clone();
let fut: JoinHandle<datafusion_common::Result<Vec<SendableRecordBatchStream>>> =
rt.spawn(async move { df.execute_stream_partitioned().await });
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;

match stream {
Ok(batches) => Ok(batches
.into_iter()
.map(|batch_stream| PyRecordBatchStream::new(batch_stream))
.collect()),
_ => Err(PyValueError::new_err(
"Unable to execute stream partitioned",
)),
}
}

/// Convert to pandas dataframe with pyarrow
/// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
Expand Down
9 changes: 9 additions & 0 deletions src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};

#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
Expand Down Expand Up @@ -61,4 +62,12 @@ impl PyRecordBatchStream {
Some(Err(e)) => Err(e.into()),
}
}

fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
self.next(py)
}

fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
}

0 comments on commit da30bd4

Please sign in to comment.