diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 9d4c8f67..8ec2ffb1 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -21,8 +21,9 @@
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
+from datafusion.object_store import LocalFileSystem
 
-from datafusion import udf
+from datafusion import udf, col
 
 from . import generic as helpers
 
@@ -374,3 +375,58 @@ def test_simple_select(ctx, tmp_path, arr):
     result = batches[0].column(0)
 
     np.testing.assert_equal(result, arr)
+
+
+@pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]]))
+@pytest.mark.parametrize("pass_schema", (True, False))
+def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order):
+    dir_root = tmp_path / "dataset_parquet_partitioned"
+    dir_root.mkdir(exist_ok=False)
+    (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)
+
+    table = pa.Table.from_arrays(
+        [
+            [1, 2, 3, 4, 5, 6, 7],
+            ["a", "b", "c", "d", "e", "f", "g"],
+            [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7],
+        ],
+        names=["int", "str", "float"],
+    )
+    pa.parquet.write_table(
+        table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
+    )
+    pa.parquet.write_table(
+        table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
+    )
+    pa.parquet.write_table(
+        table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
+    )
+
+    ctx.register_object_store("file://local", LocalFileSystem(), None)
+    ctx.register_listing_table(
+        "my_table",
+        f"file://{dir_root}/",
+        table_partition_cols=[("grp", "string"), ("date_id", "int")],
+        file_extension=".parquet",
+        schema=table.schema if pass_schema else None,
+        file_sort_order=file_sort_order,
+    )
+    assert ctx.tables() == {"my_table"}
+
+    result = ctx.sql(
+        "SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp"
+    ).collect()
+    result = pa.Table.from_batches(result)
+
+    rd = result.to_pydict()
+    assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}
+
+    result = ctx.sql(
+        "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp"
+    ).collect()
+    result = pa.Table.from_batches(result)
+
+    rd = result.to_pydict()
+    assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 155b0cf9..6cec3a15 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -75,6 +75,7 @@ Example
    Github and Issue Tracker
    Rust's API Docs
    Code of conduct
+   Examples
 
 .. _toc.guide:
 .. toctree::
@@ -84,6 +85,7 @@ Example
 
    user-guide/introduction
    user-guide/basics
+   user-guide/configuration
    user-guide/common-operations/index
    user-guide/io/index
    user-guide/sql
diff --git a/docs/source/user-guide/configuration.rst b/docs/source/user-guide/configuration.rst
new file mode 100644
index 00000000..63b0dc3e
--- /dev/null
+++ b/docs/source/user-guide/configuration.rst
@@ -0,0 +1,51 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+Configuration
+=============
+
+Let's look at how we can configure DataFusion. When creating a :code:`SessionContext`, you can pass in
+a :code:`SessionConfig` and :code:`RuntimeConfig` object. These two cover a wide range of options.
+
+.. code-block:: python
+
+    from datafusion import RuntimeConfig, SessionConfig, SessionContext
+
+    # create a session context with default settings
+    ctx = SessionContext()
+    print(ctx)
+
+    # create a session context with explicit runtime and config settings
+    runtime = RuntimeConfig().with_disk_manager_os().with_fair_spill_pool(10000000)
+    config = (
+        SessionConfig()
+        .with_create_default_catalog_and_schema(True)
+        .with_default_catalog_and_schema("foo", "bar")
+        .with_target_partitions(8)
+        .with_information_schema(True)
+        .with_repartition_joins(False)
+        .with_repartition_aggregations(False)
+        .with_repartition_windows(False)
+        .with_parquet_pruning(False)
+        .set("datafusion.execution.parquet.pushdown_filters", "true")
+    )
+    ctx = SessionContext(config, runtime)
+    print(ctx)
+
+
+You can read more about available :code:`SessionConfig` options `here `_,
+and about :code:`RuntimeConfig` options `here <https://docs.rs/datafusion/latest/datafusion/execution/runtime_env/struct.RuntimeConfig.html>`_.
diff --git a/src/context.rs b/src/context.rs
index f34fbce8..da5f60cc 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -39,10 +39,14 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::utils::{get_tokio_runtime, wait_for_future};
-use datafusion::arrow::datatypes::{DataType, Schema};
+use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
+use datafusion::datasource::file_format::parquet::ParquetFormat;
+use datafusion::datasource::listing::{
+    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
+};
 use datafusion::datasource::MemTable;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
@@ -244,7 +248,7 @@ impl PySessionContext {
         })
     }
 
-    /// Register a an object store with the given name
+    /// Register an object store with the given name
     pub fn register_object_store(
         &mut self,
         scheme: &str,
@@ -278,6 +282,53 @@
         Ok(())
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name, path, table_partition_cols=vec![],
+                        file_extension=".parquet",
+                        schema=None,
+                        file_sort_order=None))]
+    pub fn register_listing_table(
+        &mut self,
+        name: &str,
+        path: &str,
+        table_partition_cols: Vec<(String, String)>,
+        file_extension: &str,
+        schema: Option<PyArrowType<Schema>>,
+        file_sort_order: Option<Vec<Vec<PyExpr>>>,
+        py: Python,
+    ) -> PyResult<()> {
+        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
+            .with_file_extension(file_extension)
+            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .with_file_sort_order(
+                file_sort_order
+                    .unwrap_or_default()
+                    .into_iter()
+                    .map(|e| e.into_iter().map(|f| f.into()).collect())
+                    .collect(),
+            );
+        let table_path = ListingTableUrl::parse(path)?;
+        let resolved_schema: SchemaRef = match schema {
+            Some(s) => Arc::new(s.0),
+            None => {
+                let state = self.ctx.state();
+                let schema = options.infer_schema(&state, &table_path);
+                wait_for_future(py, schema).map_err(DataFusionError::from)?
+            }
+        };
+        let config = ListingTableConfig::new(table_path)
+            .with_listing_options(options)
+            .with_schema(resolved_schema);
+        let table = ListingTable::try_new(config)?;
+        self.register_table(
+            name,
+            &PyTable {
+                table: Arc::new(table),
+            },
+        )?;
+        Ok(())
+    }
+
     /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
     pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
         let result = self.ctx.sql(query);
@@ -849,8 +900,9 @@ pub fn convert_table_partition_cols(
         .into_iter()
         .map(|(name, ty)| match ty.as_str() {
             "string" => Ok((name, DataType::Utf8)),
+            "int" => Ok((name, DataType::Int32)),
             _ => Err(DataFusionError::Common(format!(
-                "Unsupported data type '{ty}' for partition column"
+                "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
             ))),
         })
         .collect::<Result<Vec<_>, _>>()
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 92a25188..a239a35f 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -424,10 +424,7 @@
         let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
 
         match stream {
-            Ok(batches) => Ok(batches
-                .into_iter()
-                .map(|batch_stream| PyRecordBatchStream::new(batch_stream))
-                .collect()),
+            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
             _ => Err(PyValueError::new_err(
                 "Unable to execute stream partitioned",
             )),
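
For reference, a minimal usage sketch of the Python API this patch adds, mirroring the test above; the dataset path and column names here are hypothetical, and the schema is left to be inferred from the Parquet files::

    from datafusion import SessionContext, col
    from datafusion.object_store import LocalFileSystem

    ctx = SessionContext()
    # Make the local filesystem available under the file:// scheme.
    ctx.register_object_store("file://local", LocalFileSystem(), None)

    # Register a hive-partitioned Parquet directory (grp=.../date_id=.../*.parquet)
    # as a listing table; schema=None means the schema is inferred.
    ctx.register_listing_table(
        "my_table",
        "file:///path/to/dataset/",
        table_partition_cols=[("grp", "string"), ("date_id", "int")],
        file_extension=".parquet",
        schema=None,
        file_sort_order=[[col("int").sort(True, True)]],
    )

    df = ctx.sql("SELECT grp, COUNT(*) FROM my_table GROUP BY grp")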