From e1b37401a2d1af86ab16b899f1dda8237a0d3535 Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Fri, 11 Aug 2023 21:46:22 +0200 Subject: [PATCH] feat: add compression options (#456) --- datafusion/tests/test_context.py | 38 +++++++++++++++++++++++++++++++- datafusion/tests/test_sql.py | 23 +++++++++++++++++-- src/context.rs | 31 +++++++++++++++++++++----- 3 files changed, 83 insertions(+), 9 deletions(-) diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py index 6b1223a1..55a324ae 100644 --- a/datafusion/tests/test_context.py +++ b/datafusion/tests/test_context.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import gzip import os import pyarrow as pa @@ -336,11 +336,47 @@ def test_read_json(ctx): assert result[0].column(1) == pa.array([1, 2, 3]) +def test_read_json_compressed(ctx, tmp_path): + path = os.path.dirname(os.path.abspath(__file__)) + test_data_path = os.path.join(path, "data_test_context", "data.json") + + # File compression type + gzip_path = tmp_path / "data.json.gz" + + with open(test_data_path, "rb") as csv_file: + with gzip.open(gzip_path, "wb") as gzipped_file: + gzipped_file.writelines(csv_file) + + df = ctx.read_json( + gzip_path, file_extension=".gz", file_compression_type="gz" + ) + result = df.collect() + + assert result[0].column(0) == pa.array(["a", "b", "c"]) + assert result[0].column(1) == pa.array([1, 2, 3]) + + def test_read_csv(ctx): csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv") csv_df.select(column("c1")).show() +def test_read_csv_compressed(ctx, tmp_path): + test_data_path = "testing/data/csv/aggregate_test_100.csv" + + # File compression type + gzip_path = tmp_path / "aggregate_test_100.csv.gz" + + with open(test_data_path, "rb") as csv_file: + with gzip.open(gzip_path, "wb") as gzipped_file: + gzipped_file.writelines(csv_file) + + csv_df = ctx.read_csv( + gzip_path, file_extension=".gz", file_compression_type="gz" + ) + csv_df.select(column("c1")).show() + + def test_read_parquet(ctx): csv_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet") csv_df.show() diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py index 638a222d..608bb196 100644 --- a/datafusion/tests/test_sql.py +++ b/datafusion/tests/test_sql.py @@ -19,6 +19,7 @@ import pyarrow as pa import pyarrow.dataset as ds import pytest +import gzip from datafusion import udf @@ -32,6 +33,7 @@ def test_no_table(ctx): def test_register_csv(ctx, tmp_path): path = tmp_path / "test.csv" + gzip_path = tmp_path / "test.csv.gz" table = pa.Table.from_arrays( [ @@ -43,6 +45,10 @@ def test_register_csv(ctx, tmp_path): ) pa.csv.write_csv(table, path) + with open(path, "rb") as csv_file: + with gzip.open(gzip_path, "wb") as gzipped_file: + gzipped_file.writelines(csv_file) + ctx.register_csv("csv", path) ctx.register_csv("csv1", str(path)) ctx.register_csv( @@ -52,6 +58,13 @@ def test_register_csv(ctx, tmp_path): delimiter=",", schema_infer_max_records=10, ) + ctx.register_csv( + "csv_gzip", + gzip_path, + file_extension="gz", + file_compression_type="gzip", + ) + alternative_schema = pa.schema( [ ("some_int", pa.int16()), @@ -61,9 +74,9 @@ def test_register_csv(ctx, tmp_path): ) ctx.register_csv("csv3", path, schema=alternative_schema) - assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"} + assert ctx.tables() == {"csv", "csv1", "csv2", "csv3", "csv_gzip"} - for table in ["csv", "csv1", "csv2"]: + for table in ["csv", "csv1", "csv2", "csv_gzip"]: result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect() result = pa.Table.from_batches(result) assert result.to_pydict() == {"cnt": [4]} @@ -77,6 +90,12 @@ def test_register_csv(ctx, tmp_path): ): ctx.register_csv("csv4", path, delimiter="wrong") + with pytest.raises( + ValueError, + match="file_compression_type must one of: gzip, bz2, xz, zstd", + ): + ctx.register_csv("csv4", path, file_compression_type="rar") + def test_register_parquet(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) diff --git a/src/context.rs b/src/context.rs index cf133d79..1dca8a79 100644 --- a/src/context.rs +++ b/src/context.rs @@ -17,6 +17,7 @@ use std::collections::{HashMap, HashSet}; use std::path::PathBuf; +use std::str::FromStr; use std::sync::Arc; use object_store::ObjectStore; @@ -40,6 +41,7 @@ use crate::utils::{get_tokio_runtime, wait_for_future}; use datafusion::arrow::datatypes::{DataType, Schema}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::datasource::file_format::file_type::FileCompressionType; use datafusion::datasource::MemTable; use datafusion::datasource::TableProvider; use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext}; @@ -469,7 +471,8 @@ impl PySessionContext { has_header=true, delimiter=",", schema_infer_max_records=1000, - file_extension=".csv"))] + file_extension=".csv", + file_compression_type=None))] fn register_csv( &mut self, name: &str, @@ -479,6 +482,7 @@ impl PySessionContext { delimiter: &str, schema_infer_max_records: usize, file_extension: &str, + file_compression_type: Option, py: Python, ) -> PyResult<()> { let path = path @@ -495,7 +499,8 @@ impl PySessionContext { .has_header(has_header) .delimiter(delimiter[0]) .schema_infer_max_records(schema_infer_max_records) - .file_extension(file_extension); + .file_extension(file_extension) + .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema = schema.as_ref().map(|x| &x.0); let result = self.ctx.register_csv(name, path, options); @@ -559,7 +564,7 @@ impl PySessionContext { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![]))] + #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))] fn read_json( &mut self, path: PathBuf, @@ -567,13 +572,15 @@ impl PySessionContext { schema_infer_max_records: usize, file_extension: &str, table_partition_cols: Vec<(String, String)>, + file_compression_type: Option, py: Python, ) -> PyResult { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; let mut options = NdJsonReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; let df = if let Some(schema) = schema { @@ -595,7 +602,8 @@ impl PySessionContext { delimiter=",", schema_infer_max_records=1000, file_extension=".csv", - table_partition_cols=vec![]))] + table_partition_cols=vec![], + file_compression_type=None))] fn read_csv( &self, path: PathBuf, @@ -605,6 +613,7 @@ impl PySessionContext { schema_infer_max_records: usize, file_extension: &str, table_partition_cols: Vec<(String, String)>, + file_compression_type: Option, py: Python, ) -> PyResult { let path = path @@ -623,7 +632,8 @@ impl PySessionContext { .delimiter(delimiter[0]) .schema_infer_max_records(schema_infer_max_records) .file_extension(file_extension) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .file_compression_type(parse_file_compression_type(file_compression_type)?); if let Some(py_schema) = schema { options.schema = Some(&py_schema.0); @@ -743,6 +753,15 @@ fn convert_table_partition_cols( .collect::, _>>() } +fn parse_file_compression_type( + file_compression_type: Option, +) -> Result { + FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str()) + .map_err(|_| { + PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd") + }) +} + impl From for SessionContext { fn from(ctx: PySessionContext) -> SessionContext { ctx.ctx