diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 9d42a1f9d..19a2ad2cf 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -205,6 +205,37 @@ def test_register_json(ctx, tmp_path):
         ctx.register_json("json4", gzip_path, file_compression_type="rar")
 
 
+def test_register_avro(ctx):
+    """Register an Avro file without, then with, an explicit schema."""
+    path = "testing/data/avro/alltypes_plain.avro"
+    # No schema is passed here; the aggregate below only checks that the
+    # registered table can be queried.
+    ctx.register_avro("alltypes_plain", path)
+    result = ctx.sql(
+        "SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
+    ).collect()
+    result = pa.Table.from_batches(result).to_pydict()
+    assert result["tinyint_sum"][0] > 0
+
+    # Register the same file again under an explicit single-column schema;
+    # the collected result must report exactly that schema.
+    alternative_schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+        ]
+    )
+
+    ctx.register_avro(
+        "alltypes_plain_schema",
+        path,
+        schema=alternative_schema,
+        infinite=False,
+    )
+    result = ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
+    result = pa.Table.from_batches(result)
+    assert result.schema == alternative_schema
+
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
 
diff --git a/src/context.rs b/src/context.rs
index 317ab785e..87b2b5be9 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -545,6 +545,43 @@ impl PySessionContext {
         Ok(())
     }
 
+    // Registers the Avro data at `path` as a table called `name`.
+    //
+    // `schema`, when given, is handed to the reader options as-is; `file_extension`,
+    // `table_partition_cols` and `infinite` are forwarded to `AvroReadOptions`.
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name,
+                        path,
+                        schema=None,
+                        file_extension=".avro",
+                        table_partition_cols=vec![],
+                        infinite=false))]
+    fn register_avro(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        infinite: bool,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+        let mut options = AvroReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .mark_infinite(infinite);
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_avro(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
         let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);