From e24dc75f2fe60efb5bc888fd70d2aede80027c25 Mon Sep 17 00:00:00 2001
From: Daniel Mesejo
Date: Tue, 22 Aug 2023 15:45:06 +0200
Subject: [PATCH] feat: add register_avro and read_table (#461)

---
 datafusion/tests/test_context.py | 10 ++++++++
 datafusion/tests/test_sql.py     | 26 ++++++++++++++++++++
 src/context.rs                   | 41 ++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 55a324ae..97bff9bb 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -214,6 +214,16 @@ def test_register_table(ctx, database):
     assert public.names() == {"csv", "csv1", "csv2", "csv3"}
 
 
+def test_read_table(ctx, database):
+    default = ctx.catalog()
+    public = default.database("public")
+    assert public.names() == {"csv", "csv1", "csv2"}
+
+    table = public.table("csv")
+    table_df = ctx.read_table(table)
+    table_df.show()
+
+
 def test_deregister_table(ctx, database):
     default = ctx.catalog()
     public = default.database("public")
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 9d42a1f9..19a2ad2c 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -205,6 +205,32 @@ def test_register_json(ctx, tmp_path):
         ctx.register_json("json4", gzip_path, file_compression_type="rar")
 
 
+def test_register_avro(ctx):
+    path = "testing/data/avro/alltypes_plain.avro"
+    ctx.register_avro("alltypes_plain", path)
+    result = ctx.sql(
+        "SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
+    ).collect()
+    result = pa.Table.from_batches(result).to_pydict()
+    assert result["tinyint_sum"][0] > 0
+
+    alternative_schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+        ]
+    )
+
+    ctx.register_avro(
+        "alltypes_plain_schema",
+        path,
+        schema=alternative_schema,
+        infinite=False,
+    )
+    result = ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
+    result = pa.Table.from_batches(result)
+    assert result.schema == alternative_schema
+
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
 
diff --git a/src/context.rs b/src/context.rs
index 317ab785..c7f89f2e 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -545,6 +545,39 @@ impl PySessionContext {
         Ok(())
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name,
+                        path,
+                        schema=None,
+                        file_extension=".avro",
+                        table_partition_cols=vec![],
+                        infinite=false))]
+    fn register_avro(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        infinite: bool,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+        let mut options = AvroReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .mark_infinite(infinite);
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_avro(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
         let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
@@ -734,6 +767,14 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
+    fn read_table(&self, table: &PyTable) -> PyResult<PyDataFrame> {
+        let df = self
+            .ctx
+            .read_table(table.table())
+            .map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(df))
+    }
+
     fn __repr__(&self) -> PyResult<String> {
         let config = self.ctx.copied_config();
         let mut config_entries = config