Return RecordBatchReader in CSV read API (#931)
For #594 and #596.
kylebarron authored Dec 11, 2024
1 parent b4e5511 commit 398d9d9
Showing 6 changed files with 237 additions and 101 deletions.
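In short: `read_csv` still returns a `Table` by default, but passing `downcast_geometry=False` now returns a `RecordBatchReader`, so batches can be processed as they are read instead of materializing the whole file. A minimal sketch of the new call pattern, mirroring the docstring example below (the `fires.csv` path is a placeholder):

```py
from geoarrow.rust.io import read_csv

# Default behavior: load everything, then downcast the geometry column
table = read_csv("fires.csv", geometry_name="report location")

# New in this commit: stream the file batch by batch
reader = read_csv(
    "fires.csv",
    geometry_name="report location",
    downcast_geometry=False,
)
for batch in reader:
    ...  # process each RecordBatch as it arrives
```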
78 changes: 71 additions & 7 deletions python/geoarrow-io/python/geoarrow/rust/io/_csv.pyi
@@ -1,15 +1,17 @@
 from pathlib import Path
-from typing import BinaryIO
+from typing import BinaryIO, Literal, overload
 
-from arro3.core import Table
+from arro3.core import RecordBatchReader, Table
 from arro3.core.types import ArrowStreamExportable
 from geoarrow.rust.core.enums import CoordType
 from geoarrow.rust.core.types import CoordTypeT
 
+@overload
 def read_csv(
     file: str | Path | BinaryIO,
     *,
     geometry_name: str | None = None,
+    downcast_geometry: Literal[True] = True,
     batch_size: int = 65536,
     coord_type: CoordType | CoordTypeT | None = None,
     has_header: bool = True,
@@ -19,15 +21,76 @@ def read_csv(
     quote: str | None = None,
     terminator: str | None = None,
     comment: str | None = None,
-) -> Table:
-    """
-    Read a CSV file from a path on disk into a Table.
+) -> Table: ...
+@overload
+def read_csv(
+    file: str | Path | BinaryIO,
+    *,
+    geometry_name: str | None = None,
+    downcast_geometry: Literal[False],
+    batch_size: int = 65536,
+    coord_type: CoordType | CoordTypeT | None = None,
+    has_header: bool = True,
+    max_records: int | None = None,
+    delimiter: str | None = None,
+    escape: str | None = None,
+    quote: str | None = None,
+    terminator: str | None = None,
+    comment: str | None = None,
+) -> RecordBatchReader: ...
+def read_csv(
+    file: str | Path | BinaryIO,
+    *,
+    geometry_name: str | None = None,
+    downcast_geometry: bool = True,
+    batch_size: int = 65536,
+    coord_type: CoordType | CoordTypeT | None = None,
+    has_header: bool = True,
+    max_records: int | None = None,
+    delimiter: str | None = None,
+    escape: str | None = None,
+    quote: str | None = None,
+    terminator: str | None = None,
+    comment: str | None = None,
+) -> RecordBatchReader | Table:
+    '''
+    Read a CSV file with a WKT-encoded geometry column.
+
+    Example:
+
+    ```py
+    csv_text = """
+    address,type,datetime,report location,incident number
+    904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
+    9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946"
+    """
+
+    table = read_csv(BytesIO(csv_text.encode()), geometry_name="report location")
+    ```
+
+    Or, if you'd like to stream the data, you can pass `downcast_geometry=False`:
+
+    ```py
+    record_batch_reader = read_csv(
+        path_to_csv,
+        geometry_name="report location",
+        downcast_geometry=False,
+        batch_size=100_000,
+    )
+    for record_batch in record_batch_reader:
+        # Use each record batch.
+        ...
+    ```
 
     Args:
         file: the path to the file or a Python file object in binary read mode.
 
     Other args:
-        geometry_name: the name of the geometry column within the CSV.
+        geometry_name: the name of the geometry column within the CSV. By default,
+            will look for a column named "geometry", case insensitive.
+        downcast_geometry: Whether to simplify the type of the geometry column. When
+            `downcast_geometry` is `False`, the GeoArrow geometry column is of type
+            "Geometry", which is fully generic. When `downcast_geometry` is `True`,
+            the GeoArrow geometry column will be simplified to its most basic
+            representation. That is, if the table only includes points, the GeoArrow
+            geometry column will be converted to a Point-type array.
+
+            Downcasting is only possible when all chunks have been loaded into
+            memory. Use `downcast_geometry=False` if you would like to iterate over
+            batches of the table, without loading all of them into memory at once.
         batch_size: the number of rows to include in each internal batch of the table.
         coord_type: The coordinate type. Defaults to None.
         has_header: Set whether the CSV file has a header. Defaults to `True`.
@@ -41,8 +104,9 @@ def read_csv(
         comment: Set the comment character. Defaults to `None`.
 
     Returns:
-        Table from CSV file.
-    """
+        A `Table` if `downcast_geometry` is `True` (the default). If
+        `downcast_geometry` is `False`, returns a `RecordBatchReader`, enabling
+        streaming processing.
+    '''
 
 def write_csv(table: ArrowStreamExportable, file: str | Path | BinaryIO) -> None:
     """
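The paired `@overload` stubs above are what let a static type checker narrow the return type from the `downcast_geometry` argument. A small illustration (`reveal_type` is interpreted by mypy/pyright rather than executed; the path is a placeholder):

```py
from geoarrow.rust.io import read_csv

t = read_csv("fires.csv")  # matches the Literal[True] overload
reveal_type(t)  # Revealed type is "Table"

r = read_csv("fires.csv", downcast_geometry=False)  # matches Literal[False]
reveal_type(r)  # Revealed type is "RecordBatchReader"
```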
49 changes: 27 additions & 22 deletions python/geoarrow-io/src/io/csv.rs
@@ -1,14 +1,13 @@
-use std::io::{Seek, SeekFrom};
 
 use crate::error::PyGeoArrowResult;
 use crate::io::input::sync::{FileReader, FileWriter};
 use arrow::array::RecordBatchReader;
 use geoarrow::algorithm::native::DowncastTable;
 use geoarrow::io::csv;
-use geoarrow::io::csv::CSVReaderOptions;
+use geoarrow::io::csv::{CSVReader, CSVReaderOptions};
 use geoarrow::table::Table;
 use pyo3::prelude::*;
-use pyo3_arrow::export::Arro3Table;
+use pyo3_arrow::export::{Arro3RecordBatchReader, Arro3Table};
 use pyo3_arrow::input::AnyRecordBatch;
-use pyo3_arrow::PyTable;
+use pyo3_arrow::{PyRecordBatchReader, PyTable};
 use pyo3_geoarrow::PyCoordType;
 
 #[pyfunction]
@@ -26,12 +25,14 @@ use pyo3_geoarrow::PyCoordType;
         quote=None,
         terminator=None,
         comment=None,
+        downcast_geometry=true,
     ),
-    text_signature = "(file, *, geometry_name=None, batch_size=65536, coord_type='interleaved', has_header=True, max_records=None, delimiter=None, escape=None, quote=None, terminator=None, comment=None)"
+    text_signature = "(file, *, geometry_name=None, batch_size=65536, coord_type='interleaved', has_header=True, max_records=None, delimiter=None, escape=None, quote=None, terminator=None, comment=None, downcast_geometry=True)"
 )]
 #[allow(clippy::too_many_arguments)]
 pub fn read_csv(
-    mut file: FileReader,
+    py: Python,
+    file: FileReader,
     geometry_name: Option<String>,
     batch_size: usize,
     coord_type: PyCoordType,
@@ -42,8 +43,9 @@ pub fn read_csv(
     quote: Option<char>,
     terminator: Option<char>,
     comment: Option<char>,
-) -> PyGeoArrowResult<Arro3Table> {
-    let mut options = CSVReaderOptions {
+    downcast_geometry: bool,
+) -> PyGeoArrowResult<PyObject> {
+    let options = CSVReaderOptions {
         coord_type: coord_type.into(),
         batch_size,
         geometry_column_name: geometry_name,
@@ -55,19 +57,22 @@ pub fn read_csv(
         terminator,
         comment,
     };
+    let reader = CSVReader::try_new(file, options)?;
 
-    let pos = file.stream_position()?;
-    let (schema, _rows_read, geometry_col_name) = csv::infer_csv_schema(&mut file, &options)?;
-
-    // So we don't have to search for the geometry column a second time if not provided
-    options.geometry_column_name = Some(geometry_col_name);
-
-    file.seek(SeekFrom::Start(pos))?;
-
-    let record_batch_reader = csv::read_csv(file, schema, options)?;
-    let schema = record_batch_reader.schema();
-    let batches = record_batch_reader.collect::<std::result::Result<Vec<_>, _>>()?;
-    Ok(PyTable::try_new(batches, schema)?.into())
+    if downcast_geometry {
+        // Load the file into a table and then downcast
+        let batch_reader = geoarrow::io::RecordBatchReader::new(Box::new(reader));
+        let table = Table::try_from(batch_reader)?;
+        let table = table.downcast()?;
+        let (batches, schema) = table.into_inner();
+        Ok(Arro3Table::from(PyTable::try_new(batches, schema)?)
+            .into_pyobject(py)?
+            .unbind())
+    } else {
+        let batch_reader = PyRecordBatchReader::new(Box::new(reader));
+        let batch_reader = Arro3RecordBatchReader::from(batch_reader);
+        Ok(batch_reader.into_pyobject(py)?.unbind())
+    }
 }
 
 #[pyfunction]
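Note the shape of the Rust change: downcasting has to inspect every chunk, so the `if downcast_geometry` branch collects the reader into a `Table` first, while the `else` branch hands the still-lazy reader straight to Python. On the Python side both results export an Arrow C stream, so other Arrow libraries can consume them without copying. A hedged sketch using pyarrow (assumes pyarrow >= 14, which added `RecordBatchReader.from_stream`; the path is a placeholder):

```py
import pyarrow as pa
from geoarrow.rust.io import read_csv

reader = read_csv(
    "fires.csv",
    geometry_name="report location",
    downcast_geometry=False,
)

# arro3's RecordBatchReader exposes __arrow_c_stream__, which pyarrow imports here
pa_reader = pa.RecordBatchReader.from_stream(reader)
for batch in pa_reader:
    print(batch.num_rows)
```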
27 changes: 26 additions & 1 deletion python/tests/io/test_csv.py
@@ -1,6 +1,8 @@
 from io import BytesIO
 
+from geoarrow.rust.core import geometry_col
 from geoarrow.rust.io import read_flatgeobuf, read_csv, write_csv
+from arro3.core import DataType
 
 from tests.utils import FIXTURES_DIR

@@ -13,7 +15,7 @@ def test_read_write_csv():
     # Write to csv
     buf = BytesIO()
     write_csv(table, buf)
-    print(buf.getvalue().decode())
+    # print(buf.getvalue().decode())
 
     # Read back from CSV
     buf.seek(0)
@@ -23,3 +25,26 @@ def test_read_write_csv():
     # assertions until we have a geometry equality operation.
     assert len(table) == len(retour)
     assert table.schema.names == retour.schema.names
+    assert geometry_col(table).type == geometry_col(retour).type
+
+
+CSV_TEXT = """
+address,type,datetime,report location,incident number
+904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
+9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946"
+"""
+
+
+def test_downcast():
+    table = read_csv(BytesIO(CSV_TEXT.encode()), geometry_name="report location")
+    assert DataType.is_fixed_size_list(table["geometry"].type)
+
+
+def test_reader_no_downcast():
+    reader = read_csv(
+        BytesIO(CSV_TEXT.encode()),
+        geometry_name="report location",
+        downcast_geometry=False,
+    )
+    table = reader.read_all()
+    assert table.num_rows == 2
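A natural companion to these tests (not part of this commit; a sketch that assumes the streaming path keeps the generic "Geometry" type, as the docstring states) would assert that no downcast happens when streaming:

```py
def test_reader_geometry_stays_generic():
    reader = read_csv(
        BytesIO(CSV_TEXT.encode()),
        geometry_name="report location",
        downcast_geometry=False,
    )
    table = reader.read_all()
    # Without downcasting, the column is not the fixed-size-list Point
    # layout asserted in test_downcast above.
    assert not DataType.is_fixed_size_list(table["geometry"].type)
```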
25 changes: 6 additions & 19 deletions rust/geoarrow/src/io/csv/mod.rs
@@ -8,41 +8,28 @@
 //! use arrow_array::RecordBatchReader;
 //!
 //! use geoarrow::array::CoordType;
-//! use geoarrow::io::csv::{infer_csv_schema, read_csv, CSVReaderOptions};
+//! use geoarrow::io::csv::{CSVReader, CSVReaderOptions};
 //! use geoarrow::table::Table;
 //!
 //! let s = r#"
 //! address,type,datetime,report location,incident number
 //! 904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
 //! 9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946"#;
-//! let mut cursor = Cursor::new(s);
 //!
 //! let options = CSVReaderOptions {
 //!     coord_type: CoordType::Separated,
 //!     geometry_column_name: Some("report location".to_string()),
 //!     has_header: Some(true),
 //!     ..Default::default()
 //! };
+//! let reader = CSVReader::try_new(Cursor::new(s), options).unwrap();
 //!
-//! // Note: this initial schema currently represents the CSV data _on disk_. That is, the
-//! // geometry column is represented as a string. This may change in the future.
-//! let (schema, _read_records, _geometry_column_name) =
-//!     infer_csv_schema(&mut cursor, &options).unwrap();
-//! cursor.rewind().unwrap();
-//!
-//! // `read_csv` returns a RecordBatchReader, which enables streaming the CSV without reading
-//! // all of it.
-//! let record_batch_reader = read_csv(cursor, schema, options).unwrap();
-//! let geospatial_schema = record_batch_reader.schema();
-//! let table = Table::try_new(
-//!     record_batch_reader.collect::<Result<_, _>>().unwrap(),
-//!     geospatial_schema,
-//! )
-//! .unwrap();
+//! // Now `reader` implements `arrow_array::RecordBatchReader`, so we can use TryFrom to convert
+//! // it to a geoarrow Table.
+//! let table = Table::try_from(Box::new(reader) as Box<dyn arrow_array::RecordBatchReader>).unwrap();
 //! ```
 //!
-pub use reader::{infer_csv_schema, read_csv, CSVReaderOptions};
+pub use reader::{CSVReader, CSVReaderOptions};
 pub use writer::write_csv;
 
 mod reader;
(two more changed files not shown)
