feat: add basic compression configuration to write_parquet #459

Merged
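This PR extends DataFrame.write_parquet with compression and compression_level keyword arguments. A minimal usage sketch of the new API, assuming a hypothetical input file data.csv and writable output paths:

from datafusion import SessionContext

ctx = SessionContext()
ctx.register_csv("t", "data.csv")  # hypothetical input
df = ctx.table("t")

# Codec and level are now configurable; the default stays "uncompressed".
df.write_parquet("/tmp/out_zstd", compression="zstd", compression_level=15)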
67 changes: 67 additions & 0 deletions datafusion/tests/test_dataframe.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from datafusion import functions as f
@@ -645,3 +647,68 @@ def test_describe(df):
"b": [3.0, 3.0, 5.0, 1.0, 4.0, 6.0, 5.0],
"c": [3.0, 3.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0],
}


def test_write_parquet(df, tmp_path):
    path = tmp_path

    df.write_parquet(str(path))
    result = pq.read_table(str(path)).to_pydict()
    expected = df.to_pydict()

    assert result == expected


@pytest.mark.parametrize(
    "compression, compression_level",
    [("gzip", 6), ("brotli", 7), ("zstd", 15)],
)
def test_write_compressed_parquet(
    df, tmp_path, compression, compression_level
):
    path = tmp_path

    df.write_parquet(
        str(path), compression=compression, compression_level=compression_level
    )

    # verify that the files were written with the requested compression codec
    for root, _dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".parquet"):
                metadata = pq.ParquetFile(
                    os.path.join(root, file)
                ).metadata.to_dict()
                for row_group in metadata["row_groups"]:
                    for column in row_group["columns"]:
                        assert column["compression"].lower() == compression

    result = pq.read_table(str(path)).to_pydict()
    expected = df.to_pydict()

    assert result == expected


@pytest.mark.parametrize(
    "compression, compression_level",
    [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
)
def test_write_compressed_parquet_wrong_compression_level(
    df, tmp_path, compression, compression_level
):
    path = tmp_path

    with pytest.raises(ValueError):
        df.write_parquet(
            str(path),
            compression=compression,
            compression_level=compression_level,
        )


@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
def test_write_compressed_parquet_missing_compression_level(
df, tmp_path, compression
):
path = tmp_path

with pytest.raises(ValueError):
df.write_parquet(str(path), compression=compression)
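Taken together, these tests encode the level rules implemented on the Rust side below: gzip falls back to a default level when compression_level is omitted, while brotli and zstd require an explicit one. A short sketch of that contract, assuming a DataFrame df and writable paths:

# gzip works without an explicit level (the Rust code falls back to level 6)
df.write_parquet("/tmp/out_gzip", compression="gzip")

# brotli and zstd have no default, so omitting the level raises ValueError
try:
    df.write_parquet("/tmp/out_zstd", compression="zstd")
except ValueError as err:
    print(err)  # "compression_level is not defined"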
58 changes: 55 additions & 3 deletions src/dataframe.rs
@@ -23,8 +23,10 @@ use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::dataframe::DataFrame;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::prelude::*;
use pyo3::exceptions::PyTypeError;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use std::sync::Arc;
@@ -308,8 +310,58 @@ impl PyDataFrame {
    }

    /// Write a `DataFrame` to a Parquet file.
    #[pyo3(signature = (
        path,
        compression="uncompressed",
        compression_level=None
    ))]
    fn write_parquet(
        &self,
        path: &str,
        compression: &str,
        compression_level: Option<u32>,
        py: Python,
    ) -> PyResult<()> {
        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
        }

        let compression_type = match compression.to_lowercase().as_str() {
            "snappy" => Compression::SNAPPY,
            "gzip" => Compression::GZIP(
                GzipLevel::try_new(compression_level.unwrap_or(6))
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "brotli" => Compression::BROTLI(
                BrotliLevel::try_new(verify_compression_level(compression_level)?)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "zstd" => Compression::ZSTD(
                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "lzo" => Compression::LZO,
            "lz4" => Compression::LZ4,
            "lz4_raw" => Compression::LZ4_RAW,
            "uncompressed" => Compression::UNCOMPRESSED,
            _ => {
                return Err(PyValueError::new_err(format!(
                    "Unrecognized compression type {compression}"
                )));
            }
        };

        let writer_properties = WriterProperties::builder()
            .set_compression(compression_type)
            .build();

        wait_for_future(
            py,
            self.df
                .as_ref()
                .clone()
                .write_parquet(path, Some(writer_properties)),
        )?;
        Ok(())
    }
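For reference, the match above recognizes the codec names snappy, gzip, brotli, zstd, lzo, lz4, lz4_raw, and uncompressed (case-insensitively); any other string falls through to the ValueError arm. A sketch of the no-level codecs, assuming a DataFrame df:

# Codecs without a tunable level need only the name; compression
# defaults to "uncompressed" when the argument is omitted entirely.
df.write_parquet("/tmp/out_snappy", compression="snappy")
df.write_parquet("/tmp/out_plain")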
