Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement python record batch reader #637

Merged
merged 7 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions js/src/io/geojson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ pub fn read_geojson(file: &[u8], batch_size: Option<usize>) -> WasmResult<Table>
#[wasm_bindgen(js_name = writeGeoJSON)]
pub fn write_geojson(table: Table) -> WasmResult<Vec<u8>> {
let (schema, batches) = table.into_inner();
let mut rust_table = geoarrow::table::Table::try_new(schema, batches)?;
let rust_table = geoarrow::table::Table::try_new(schema, batches)?;
let mut output_file: Vec<u8> = vec![];
_write_geojson(&mut rust_table, &mut output_file)?;
_write_geojson(rust_table, &mut output_file)?;
Ok(output_file)
}
1 change: 1 addition & 0 deletions python/core/src/ffi/from_python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod chunked;
pub mod ffi_stream;
pub mod input;
pub mod record_batch;
pub mod record_batch_reader;
pub mod scalar;
pub mod schema;
pub mod table;
Expand Down
15 changes: 15 additions & 0 deletions python/core/src/ffi/from_python/record_batch_reader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use crate::ffi::from_python::utils::import_arrow_c_stream;
use crate::stream::PyRecordBatchReader;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::{PyAny, PyResult};

/// Extract a [`PyRecordBatchReader`] from any Python object that implements the
/// Arrow PyCapsule interface (`__arrow_c_stream__`), e.g. a pyarrow
/// `RecordBatchReader`, without copying the underlying data.
impl<'a> FromPyObject<'a> for PyRecordBatchReader {
    fn extract(ob: &'a PyAny) -> PyResult<Self> {
        // Pull the raw ArrowArrayStream capsule off the Python object.
        let ffi_stream = import_arrow_c_stream(ob)?;
        // Wrap it in an arrow-rs reader; surface FFI failures as ValueError.
        match arrow::ffi_stream::ArrowArrayStreamReader::try_new(ffi_stream) {
            Ok(reader) => Ok(Self(Some(Box::new(reader)))),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }
}
1 change: 1 addition & 0 deletions python/core/src/ffi/to_python/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod array;
pub mod chunked;
pub mod ffi_stream;
pub mod record_batch_reader;
pub mod scalar;
pub mod table;

Expand Down
35 changes: 35 additions & 0 deletions python/core/src/ffi/to_python/record_batch_reader.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use crate::error::PyGeoArrowResult;
use crate::stream::PyRecordBatchReader;
use arrow::ffi_stream::FFI_ArrowArrayStream;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use std::ffi::CString;

#[pymethods]
impl PyRecordBatchReader {
    /// An implementation of the [Arrow PyCapsule
    /// Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
    /// This dunder method should not be called directly, but enables zero-copy
    /// data transfer to other Python libraries that understand Arrow memory.
    ///
    /// For example, you can call [`pyarrow.table()`][pyarrow.table] to convert this array
    /// into a pyarrow table, without copying memory.
    ///
    /// Note: exporting the stream consumes the wrapped reader; a second call
    /// raises `ValueError` because the reader is then closed.
    fn __arrow_c_stream__(
        &mut self,
        _requested_schema: Option<PyObject>,
    ) -> PyGeoArrowResult<PyObject> {
        // `take()` moves the reader out, leaving `None` so later calls fail
        // loudly instead of exporting a dangling stream. `ok_or_else` avoids
        // allocating the error message on the success path.
        let reader = self.0.take().ok_or_else(|| {
            PyValueError::new_err("Cannot read from closed RecordBatchReader")
        })?;

        let ffi_stream = FFI_ArrowArrayStream::new(reader);
        // Capsule name is mandated by the Arrow PyCapsule interface spec.
        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();

        Python::with_gil(|py| {
            let stream_capsule = PyCapsule::new(py, ffi_stream, Some(stream_capsule_name))?;
            Ok(stream_capsule.to_object(py))
        })
    }
}
5 changes: 3 additions & 2 deletions python/core/src/io/csv.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{BinaryFileReader, BinaryFileWriter};
use crate::stream::PyRecordBatchReader;
use crate::table::GeoTable;
use geoarrow::io::csv::read_csv as _read_csv;
use geoarrow::io::csv::write_csv as _write_csv;
Expand Down Expand Up @@ -39,8 +40,8 @@ pub fn read_csv(
/// None
#[pyfunction]
#[pyo3(signature = (table, file))]
pub fn write_csv(py: Python, mut table: GeoTable, file: PyObject) -> PyGeoArrowResult<()> {
pub fn write_csv(py: Python, table: PyRecordBatchReader, file: PyObject) -> PyGeoArrowResult<()> {
let writer = file.extract::<BinaryFileWriter>(py)?;
_write_csv(&mut table.0, writer)?;
_write_csv(table.into_reader()?, writer)?;
Ok(())
}
10 changes: 8 additions & 2 deletions python/core/src/io/flatgeobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::error::{PyGeoArrowError, PyGeoArrowResult};
use crate::io::input::sync::BinaryFileWriter;
use crate::io::input::{construct_reader, FileReader};
use crate::io::object_store::PyObjectStore;
use crate::stream::PyRecordBatchReader;
use crate::table::GeoTable;
use flatgeobuf::FgbWriterOptions;
use geoarrow::io::flatgeobuf::read_flatgeobuf_async as _read_flatgeobuf_async;
Expand Down Expand Up @@ -184,7 +185,7 @@ pub fn read_flatgeobuf_async(
#[pyo3(signature = (table, file, *, write_index=true))]
pub fn write_flatgeobuf(
py: Python,
mut table: GeoTable,
table: PyRecordBatchReader,
file: PyObject,
write_index: bool,
) -> PyGeoArrowResult<()> {
Expand All @@ -195,6 +196,11 @@ pub fn write_flatgeobuf(
write_index,
..Default::default()
};
_write_flatgeobuf(&mut table.0, writer, name.as_deref().unwrap_or(""), options)?;
_write_flatgeobuf(
table.into_reader()?,
writer,
name.as_deref().unwrap_or(""),
options,
)?;
Ok(())
}
9 changes: 7 additions & 2 deletions python/core/src/io/geojson.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{BinaryFileReader, BinaryFileWriter};
use crate::stream::PyRecordBatchReader;
use crate::table::GeoTable;
use geoarrow::io::geojson::read_geojson as _read_geojson;
use geoarrow::io::geojson::write_geojson as _write_geojson;
Expand Down Expand Up @@ -33,8 +34,12 @@ pub fn read_geojson(py: Python, file: PyObject, batch_size: usize) -> PyGeoArrow
/// Returns:
/// None
#[pyfunction]
pub fn write_geojson(py: Python, mut table: GeoTable, file: PyObject) -> PyGeoArrowResult<()> {
pub fn write_geojson(
py: Python,
table: PyRecordBatchReader,
file: PyObject,
) -> PyGeoArrowResult<()> {
let writer = file.extract::<BinaryFileWriter>(py)?;
_write_geojson(&mut table.0, writer)?;
_write_geojson(table.into_reader()?, writer)?;
Ok(())
}
5 changes: 3 additions & 2 deletions python/core/src/io/geojson_lines.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{BinaryFileReader, BinaryFileWriter};
use crate::stream::PyRecordBatchReader;
use crate::table::GeoTable;
use geoarrow::io::geojson_lines::read_geojson_lines as _read_geojson_lines;
use geoarrow::io::geojson_lines::write_geojson_lines as _write_geojson_lines;
Expand Down Expand Up @@ -41,10 +42,10 @@ pub fn read_geojson_lines(
#[pyfunction]
pub fn write_geojson_lines(
py: Python,
mut table: GeoTable,
table: PyRecordBatchReader,
file: PyObject,
) -> PyGeoArrowResult<()> {
let writer = file.extract::<BinaryFileWriter>(py)?;
_write_geojson_lines(&mut table.0, writer)?;
_write_geojson_lines(table.into_reader()?, writer)?;
Ok(())
}
13 changes: 9 additions & 4 deletions python/core/src/io/ipc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{BinaryFileReader, BinaryFileWriter};
use crate::stream::PyRecordBatchReader;
use crate::table::GeoTable;
use geoarrow::io::ipc::read_ipc as _read_ipc;
use geoarrow::io::ipc::read_ipc_stream as _read_ipc_stream;
Expand Down Expand Up @@ -46,9 +47,9 @@ pub fn read_ipc_stream(py: Python, file: PyObject) -> PyGeoArrowResult<GeoTable>
/// Returns:
/// None
#[pyfunction]
pub fn write_ipc(py: Python, mut table: GeoTable, file: PyObject) -> PyGeoArrowResult<()> {
pub fn write_ipc(py: Python, table: PyRecordBatchReader, file: PyObject) -> PyGeoArrowResult<()> {
let writer = file.extract::<BinaryFileWriter>(py)?;
_write_ipc(&mut table.0, writer)?;
_write_ipc(table.into_reader()?, writer)?;
Ok(())
}

Expand All @@ -61,8 +62,12 @@ pub fn write_ipc(py: Python, mut table: GeoTable, file: PyObject) -> PyGeoArrowR
/// Returns:
/// None
#[pyfunction]
pub fn write_ipc_stream(py: Python, mut table: GeoTable, file: PyObject) -> PyGeoArrowResult<()> {
pub fn write_ipc_stream(
py: Python,
table: PyRecordBatchReader,
file: PyObject,
) -> PyGeoArrowResult<()> {
let writer = file.extract::<BinaryFileWriter>(py)?;
_write_ipc_stream(&mut table.0, writer)?;
_write_ipc_stream(table.into_reader()?, writer)?;
Ok(())
}
4 changes: 4 additions & 0 deletions python/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod io;
pub mod record_batch;
pub mod scalar;
pub mod schema;
pub mod stream;
pub mod table;

const VERSION: &str = env!("CARGO_PKG_VERSION");
Expand Down Expand Up @@ -92,6 +93,9 @@ fn _rust(_py: Python, m: &PyModule) -> PyResult<()> {
// m.add_class::<chunked_array::ChunkedUInt64Array>()?;
// m.add_class::<chunked_array::ChunkedUInt8Array>()?;

// RecordBatchReader
m.add_class::<stream::PyRecordBatchReader>()?;

// Table
m.add_class::<table::GeoTable>()?;

Expand Down
22 changes: 22 additions & 0 deletions python/core/src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use arrow_array::RecordBatchReader as _RecordBatchReader;
use geoarrow::error::GeoArrowError;
use pyo3::prelude::*;

use crate::error::PyGeoArrowResult;

/// A wrapper around an [arrow_array::RecordBatchReader]
///
/// The inner reader is `Option` so it can be consumed exactly once (either by
/// [`PyRecordBatchReader::into_reader`] or by an Arrow C-stream export);
/// `None` marks a closed/exhausted reader.
#[pyclass(
    module = "geoarrow.rust.core._rust",
    name = "RecordBatchReader",
    subclass
)]
pub struct PyRecordBatchReader(pub(crate) Option<Box<dyn _RecordBatchReader + Send>>);

impl PyRecordBatchReader {
    /// Consume this wrapper, returning the inner [arrow_array::RecordBatchReader].
    ///
    /// # Errors
    /// Returns a general [`GeoArrowError`] if the reader has already been
    /// consumed (i.e. the stream is closed).
    pub fn into_reader(mut self) -> PyGeoArrowResult<Box<dyn _RecordBatchReader + Send>> {
        // `ok_or_else` defers building the error String until it is actually
        // needed, avoiding an allocation on the success path.
        let stream = self.0.take().ok_or_else(|| {
            GeoArrowError::General("Cannot write from closed stream.".to_string())
        })?;
        Ok(stream)
    }
}
10 changes: 5 additions & 5 deletions src/io/csv/writer.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::error::Result;
use crate::table::Table;
use crate::io::stream::RecordBatchReader;
use geozero::csv::CsvWriter;
use geozero::GeozeroDatasource;
use std::io::Write;

/// Write a Table to CSV
pub fn write_csv<W: Write>(table: &mut Table, writer: W) -> Result<()> {
pub fn write_csv<W: Write, S: Into<RecordBatchReader>>(stream: S, writer: W) -> Result<()> {
let mut csv_writer = CsvWriter::new(writer);
table.process(&mut csv_writer)?;
stream.into().process(&mut csv_writer)?;
Ok(())
}

Expand All @@ -19,11 +19,11 @@ mod test {

#[test]
fn test_write() {
let mut table = point::table();
let table = point::table();

let mut output_buffer = Vec::new();
let writer = BufWriter::new(&mut output_buffer);
write_csv(&mut table, writer).unwrap();
write_csv(&table, writer).unwrap();
let output_string = String::from_utf8(output_buffer).unwrap();
println!("{}", output_string);
}
Expand Down
36 changes: 19 additions & 17 deletions src/io/flatgeobuf/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,41 @@ use std::io::Write;
use flatgeobuf::{FgbWriter, FgbWriterOptions};
use geozero::GeozeroDatasource;

use crate::error::GeoArrowError;
use crate::error::Result;
use crate::io::stream::RecordBatchReader;
use crate::schema::GeoSchemaExt;
use crate::table::Table;

// TODO: always write CRS saved in Table metadata (you can do this by adding an option)
/// Write a Table to a FlatGeobuf file.
pub fn write_flatgeobuf<W: Write>(
table: &mut Table,
pub fn write_flatgeobuf<W: Write, S: Into<RecordBatchReader>>(
stream: S,
writer: W,
name: &str,
) -> Result<(), GeoArrowError> {
write_flatgeobuf_with_options(table, writer, name, Default::default())
) -> Result<()> {
write_flatgeobuf_with_options(stream, writer, name, Default::default())
}

/// Write a Table to a FlatGeobuf file with specific writer options.
///
/// Note: this `name` argument is what OGR observes as the layer name of the file.
pub fn write_flatgeobuf_with_options<W: Write>(
table: &mut Table,
pub fn write_flatgeobuf_with_options<W: Write, S: Into<RecordBatchReader>>(
stream: S,
writer: W,
name: &str,
options: FgbWriterOptions,
) -> Result<(), GeoArrowError> {
) -> Result<()> {
let mut stream = stream.into();
let mut fgb =
FgbWriter::create_with_options(name, infer_flatgeobuf_geometry_type(table), options)?;
table.process(&mut fgb)?;
FgbWriter::create_with_options(name, infer_flatgeobuf_geometry_type(&stream)?, options)?;
stream.process(&mut fgb)?;
fgb.write(writer)?;
Ok(())
}

fn infer_flatgeobuf_geometry_type(table: &Table) -> flatgeobuf::GeometryType {
let fields = &table.schema().fields;
let geom_col_idxs = table.schema().as_ref().geometry_columns();
fn infer_flatgeobuf_geometry_type(stream: &RecordBatchReader) -> Result<flatgeobuf::GeometryType> {
let schema = stream.schema()?;
let fields = &schema.fields;
let geom_col_idxs = schema.as_ref().geometry_columns();
if geom_col_idxs.len() != 1 {
panic!("Only one geometry column currently supported in FlatGeobuf writer");
}
Expand All @@ -53,7 +55,7 @@ fn infer_flatgeobuf_geometry_type(table: &Table) -> flatgeobuf::GeometryType {
"geoarrow.geometrycollection" => flatgeobuf::GeometryType::GeometryCollection,
_ => todo!(),
};
geometry_type
Ok(geometry_type)
} else {
todo!()
}
Expand All @@ -68,11 +70,11 @@ mod test {

#[test]
fn test_write() {
let mut table = point::table();
let table = point::table();

let mut output_buffer = Vec::new();
let writer = BufWriter::new(&mut output_buffer);
write_flatgeobuf(&mut table, writer, "name").unwrap();
write_flatgeobuf(&table, writer, "name").unwrap();

let mut reader = Cursor::new(output_buffer);
let new_table = read_flatgeobuf(&mut reader, Default::default()).unwrap();
Expand Down
10 changes: 5 additions & 5 deletions src/io/geojson/writer.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use super::geojson_writer::GeoJsonWriter;
use crate::error::Result;
use crate::table::Table;
use crate::io::stream::RecordBatchReader;
use geozero::GeozeroDatasource;
use std::io::Write;

/// Write a Table to GeoJSON
///
/// Note: Does not reproject to WGS84 for you
pub fn write_geojson<W: Write>(table: &mut Table, writer: W) -> Result<()> {
pub fn write_geojson<W: Write, S: Into<RecordBatchReader>>(stream: S, writer: W) -> Result<()> {
let mut geojson = GeoJsonWriter::new(writer);
table.process(&mut geojson)?;
stream.into().process(&mut geojson)?;
Ok(())
}

Expand All @@ -21,11 +21,11 @@ mod test {

#[test]
fn test_write() {
let mut table = point::table();
let table = point::table();

let mut output_buffer = Vec::new();
let writer = BufWriter::new(&mut output_buffer);
write_geojson(&mut table, writer).unwrap();
write_geojson(&table, writer).unwrap();
let output_string = String::from_utf8(output_buffer).unwrap();
println!("{}", output_string);
}
Expand Down
Loading
Loading