diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 78cb50f1..ce7d89e7 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os import pyarrow as pa +import pyarrow.parquet as pq import pytest from datafusion import functions as f @@ -645,3 +647,68 @@ def test_describe(df): "b": [3.0, 3.0, 5.0, 1.0, 4.0, 6.0, 5.0], "c": [3.0, 3.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0], } + + +def test_write_parquet(df, tmp_path): + path = tmp_path + + df.write_parquet(str(path)) + result = pq.read_table(str(path)).to_pydict() + expected = df.to_pydict() + + assert result == expected + + +@pytest.mark.parametrize( + "compression, compression_level", + [("gzip", 6), ("brotli", 7), ("zstd", 15)], +) +def test_write_compressed_parquet( + df, tmp_path, compression, compression_level +): + path = tmp_path + + df.write_parquet( + str(path), compression=compression, compression_level=compression_level + ) + + # test that the actual compression scheme is the one written + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(".parquet"): + metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() + for row_group in metadata["row_groups"]: + for columns in row_group["columns"]: + assert columns["compression"].lower() == compression + + result = pq.read_table(str(path)).to_pydict() + expected = df.to_pydict() + + assert result == expected + + +@pytest.mark.parametrize( + "compression, compression_level", + [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)], +) +def test_write_compressed_parquet_wrong_compression_level( + df, tmp_path, compression, compression_level +): + path = tmp_path + + with pytest.raises(ValueError): + df.write_parquet( + str(path), + compression=compression, + compression_level=compression_level, + ) + + +@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"]) +def test_write_compressed_parquet_missing_compression_level( + df, tmp_path, compression +): + path = tmp_path + + with pytest.raises(ValueError): + df.write_parquet(str(path), compression=compression) diff --git a/src/dataframe.rs b/src/dataframe.rs index b8d8ddc3..61a44484 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -23,8 +23,10 @@ use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; use datafusion::dataframe::DataFrame; +use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::prelude::*; -use pyo3::exceptions::PyTypeError; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyTuple; use std::sync::Arc; @@ -308,8 +310,58 @@ impl PyDataFrame { } /// Write a `DataFrame` to a Parquet file. - fn write_parquet(&self, path: &str, py: Python) -> PyResult<()> { - wait_for_future(py, self.df.as_ref().clone().write_parquet(path, None))?; + #[pyo3(signature = ( + path, + compression="uncompressed", + compression_level=None + ))] + fn write_parquet( + &self, + path: &str, + compression: &str, + compression_level: Option, + py: Python, + ) -> PyResult<()> { + fn verify_compression_level(cl: Option) -> Result { + cl.ok_or(PyValueError::new_err("compression_level is not defined")) + } + + let compression_type = match compression.to_lowercase().as_str() { + "snappy" => Compression::SNAPPY, + "gzip" => Compression::GZIP( + GzipLevel::try_new(compression_level.unwrap_or(6)) + .map_err(|e| PyValueError::new_err(format!("{e}")))?, + ), + "brotli" => Compression::BROTLI( + BrotliLevel::try_new(verify_compression_level(compression_level)?) + .map_err(|e| PyValueError::new_err(format!("{e}")))?, + ), + "zstd" => Compression::ZSTD( + ZstdLevel::try_new(verify_compression_level(compression_level)? as i32) + .map_err(|e| PyValueError::new_err(format!("{e}")))?, + ), + "lz0" => Compression::LZO, + "lz4" => Compression::LZ4, + "lz4_raw" => Compression::LZ4_RAW, + "uncompressed" => Compression::UNCOMPRESSED, + _ => { + return Err(PyValueError::new_err(format!( + "Unrecognized compression type {compression}" + ))); + } + }; + + let writer_properties = WriterProperties::builder() + .set_compression(compression_type) + .build(); + + wait_for_future( + py, + self.df + .as_ref() + .clone() + .write_parquet(path, Option::from(writer_properties)), + )?; Ok(()) }