Skip to content

Commit

Permalink
Optionally downcast geometry in csv reader
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Dec 11, 2024
1 parent 53cbf2c commit f72c3a0
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 16 deletions.
64 changes: 58 additions & 6 deletions python/geoarrow-io/python/geoarrow/rust/io/_csv.pyi
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from pathlib import Path
from typing import BinaryIO
from typing import BinaryIO, Literal, overload

from arro3.core import Table
from arro3.core import RecordBatchReader, Table
from arro3.core.types import ArrowStreamExportable
from geoarrow.rust.core.enums import CoordType
from geoarrow.rust.core.types import CoordTypeT

@overload
def read_csv(
file: str | Path | BinaryIO,
*,
geometry_name: str | None = None,
downcast_geometry: Literal[True] = True,
batch_size: int = 65536,
coord_type: CoordType | CoordTypeT | None = None,
has_header: bool = True,
Expand All @@ -19,15 +21,64 @@ def read_csv(
quote: str | None = None,
terminator: str | None = None,
comment: str | None = None,
) -> Table:
) -> Table: ...
# Overload for downcast_geometry=False: the CSV is consumed lazily, batch by
# batch, so the result is a streaming RecordBatchReader rather than a fully
# materialized Table. Note that downcast_geometry has no default in this
# overload — Literal[False] must be passed explicitly to select it.
@overload
def read_csv(
    file: str | Path | BinaryIO,
    *,
    geometry_name: str | None = None,
    downcast_geometry: Literal[False],
    batch_size: int = 65536,
    coord_type: CoordType | CoordTypeT | None = None,
    has_header: bool = True,
    max_records: int | None = None,
    delimiter: str | None = None,
    escape: str | None = None,
    quote: str | None = None,
    terminator: str | None = None,
    comment: str | None = None,
) -> RecordBatchReader: ...
def read_csv(
file: str | Path | BinaryIO,
*,
geometry_name: str | None = None,
downcast_geometry: bool = True,
batch_size: int = 65536,
coord_type: CoordType | CoordTypeT | None = None,
has_header: bool = True,
max_records: int | None = None,
delimiter: str | None = None,
escape: str | None = None,
quote: str | None = None,
terminator: str | None = None,
comment: str | None = None,
) -> RecordBatchReader | Table:
'''
Read a CSV file with a WKT-encoded geometry column.
Example:
```
csv_text = """
address,type,datetime,report location,incident number
904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946"
"""
Read a CSV file from a path on disk into a Table.
reader = read_csv(BytesIO(csv_text.encode()), geometry_name="report location")
reader.read_all()
```
Args:
file: the path to the file or a Python file object in binary read mode.
Other args:
geometry_name: the name of the geometry column within the CSV.
downcast_geometry: Whether to simplify the type of the geometry column. When `downcast_geometry` is `False`, the GeoArrow geometry column is of type "Geometry", which is fully generic. When `downcast_geometry` is `True`, the GeoArrow geometry column will be simplified to its most basic representation. That is, if the table only includes points, the GeoArrow geometry column will be converted to a Point-type array.
Downcasting is only possible when all chunks have been loaded into memory.
Use `downcast_geometry=False` if you would like to iterate over batches of
the table, without loading all of them into memory at once.
batch_size: the number of rows to include in each internal batch of the table.
coord_type: The coordinate type. Defaults to None.
has_header: Set whether the CSV file has a header. Defaults to `True`.
Expand All @@ -41,8 +92,9 @@ def read_csv(
comment: Set the comment character. Defaults to `None`.
Returns:
Table from CSV file.
"""
A Table if `downcast_geometry` is `True` (the default). If `downcast_geometry`
is `False`, returns a `RecordBatchReader`, enabling streaming processing.
'''

def write_csv(table: ArrowStreamExportable, file: str | Path | BinaryIO) -> None:
"""
Expand Down
29 changes: 24 additions & 5 deletions python/geoarrow-io/src/io/csv.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::error::PyGeoArrowResult;
use crate::io::input::sync::{FileReader, FileWriter};
use geoarrow::algorithm::native::DowncastTable;
use geoarrow::io::csv;
use geoarrow::io::csv::{CSVReader, CSVReaderOptions};
use geoarrow::table::Table;
use pyo3::prelude::*;
use pyo3_arrow::export::Arro3RecordBatchReader;
use pyo3_arrow::export::{Arro3RecordBatchReader, Arro3Table};
use pyo3_arrow::input::AnyRecordBatch;
use pyo3_arrow::PyRecordBatchReader;
use pyo3_arrow::{PyRecordBatchReader, PyTable};
use pyo3_geoarrow::PyCoordType;

#[pyfunction]
Expand All @@ -23,11 +25,13 @@ use pyo3_geoarrow::PyCoordType;
quote=None,
terminator=None,
comment=None,
downcast_geometry=true,
),
text_signature = "(file, *, geometry_name=None, batch_size=65536, coord_type='interleaved', has_header=True,max_records=None, delimiter=None, escape=None, quote=None, terminator=None, comment=None)"
text_signature = "(file, *, geometry_name=None, batch_size=65536, coord_type='interleaved', has_header=True,max_records=None, delimiter=None, escape=None, quote=None, terminator=None, comment=None, downcast_geometry=True)"
)]
#[allow(clippy::too_many_arguments)]
pub fn read_csv(
py: Python,
file: FileReader,
geometry_name: Option<String>,
batch_size: usize,
Expand All @@ -39,7 +43,8 @@ pub fn read_csv(
quote: Option<char>,
terminator: Option<char>,
comment: Option<char>,
) -> PyGeoArrowResult<Arro3RecordBatchReader> {
downcast_geometry: bool,
) -> PyGeoArrowResult<PyObject> {
let options = CSVReaderOptions {
coord_type: coord_type.into(),
batch_size,
Expand All @@ -53,7 +58,21 @@ pub fn read_csv(
comment,
};
let reader = CSVReader::try_new(file, options)?;
Ok(PyRecordBatchReader::new(Box::new(reader)).into())

if downcast_geometry {
// Load the file into a table and then downcast
let batch_reader = geoarrow::io::RecordBatchReader::new(Box::new(reader));
let table = Table::try_from(batch_reader)?;
let table = table.downcast()?;
let (batches, schema) = table.into_inner();
Ok(Arro3Table::from(PyTable::try_new(batches, schema)?)
.into_pyobject(py)?
.unbind())
} else {
let batch_reader = PyRecordBatchReader::new(Box::new(reader));
let batch_reader = Arro3RecordBatchReader::from(batch_reader);
Ok(batch_reader.into_pyobject(py)?.unbind())
}
}

#[pyfunction]
Expand Down
27 changes: 26 additions & 1 deletion python/tests/io/test_csv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from io import BytesIO

from geoarrow.rust.core import geometry_col
from geoarrow.rust.io import read_flatgeobuf, read_csv, write_csv
from arro3.core import DataType

from tests.utils import FIXTURES_DIR

Expand All @@ -13,7 +15,7 @@ def test_read_write_csv():
# Write to csv
buf = BytesIO()
write_csv(table, buf)
print(buf.getvalue().decode())
# print(buf.getvalue().decode())

# Read back from CSV
buf.seek(0)
Expand All @@ -23,3 +25,26 @@ def test_read_write_csv():
# assertions until we have a geometry equality operation.
assert len(table) == len(retour)
assert table.schema.names == retour.schema.names
assert geometry_col(table).type == geometry_col(retour).type


# Two-row CSV fixture with a WKT-encoded point column ("report location").
# Fix: the second data row previously ended with a stray '"' character
# (F190051946"), which corrupted the last field of that record.
# NOTE(review): the leading blank line is kept from the original fixture —
# presumably the reader skips it when inferring the header; confirm if the
# fixture is reused elsewhere.
CSV_TEXT = """
address,type,datetime,report location,incident number
904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946
"""


def test_downcast():
    # Default read_csv (downcast_geometry=True) simplifies the generic
    # geometry column to its most specific type; a point-only input should
    # therefore come back as a fixed-size-list (interleaved Point) array.
    # NOTE(review): the WKT column is "report location" in the CSV but is
    # accessed as "geometry" here — presumably the reader renames it; confirm.
    table = read_csv(BytesIO(CSV_TEXT.encode()), geometry_name="report location")
    geom_type = table["geometry"].type
    assert DataType.is_fixed_size_list(geom_type)


def test_reader_no_downcast():
    # With downcast_geometry=False, read_csv streams batches instead of
    # materializing a Table, so the return value is a RecordBatchReader.
    # Draining it should still yield both fixture rows.
    buf = BytesIO(CSV_TEXT.encode())
    reader = read_csv(buf, geometry_name="report location", downcast_geometry=False)
    result = reader.read_all()
    assert result.num_rows == 2
8 changes: 4 additions & 4 deletions rust/geoarrow/src/io/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ impl Default for CSVReaderOptions {
}
}

/// Returns (Schema, records_read, geometry column name)
///
/// Note that the geometry column in the Schema is still left as a String.
fn infer_csv_schema(
reader: impl Read,
options: &CSVReaderOptions,
Expand All @@ -107,6 +110,7 @@ fn infer_csv_schema(
Ok((Arc::new(schema), records_read, geometry_col_name))
}

/// A CSV reader that parses a WKT-encoded geometry column
pub struct CSVReader<R> {
reader: arrow_csv::Reader<R>,
output_schema: SchemaRef,
Expand All @@ -126,10 +130,6 @@ impl<R: Read + Seek> CSVReader<R> {
/// By default, the reader will **scan the entire CSV file** to infer the data's
/// schema. If your data is large, you can limit the number of records scanned
/// with the [CSVReaderOptions].
///
/// Returns (Schema, records_read, geometry column name)
///
/// Note that the geometry column in the Schema is still left as a String.
pub fn try_new(mut reader: R, options: CSVReaderOptions) -> Result<Self> {
let (schema, _read_records, _geometry_column_name) =
infer_csv_schema(&mut reader, &options)?;
Expand Down

0 comments on commit f72c3a0

Please sign in to comment.