feat: add compression options (#456)
mesejo committed Aug 11, 2023
1 parent 37c91f4 commit e1b3740
Showing 3 changed files with 83 additions and 9 deletions.
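
A minimal usage sketch of the new options, pieced together from the tests in this commit — the file names are placeholders, the SessionContext setup is assumed to match the ctx fixture the tests use, and file_compression_type accepts gzip, bz2, xz, or zstd (defaulting to uncompressed when omitted):

from datafusion import SessionContext

ctx = SessionContext()

# Read a gzip-compressed CSV: file_extension overrides the ".csv" default so
# the file is matched, and file_compression_type enables decompression.
df = ctx.read_csv("data.csv.gz", file_extension=".gz", file_compression_type="gzip")

# The same option is available for newline-delimited JSON ("gz" is an
# accepted alias for "gzip")...
df = ctx.read_json("data.json.gz", file_extension=".gz", file_compression_type="gz")

# ...and when registering a CSV as a SQL table.
ctx.register_csv("t", "data.csv.gz", file_extension="gz", file_compression_type="gzip")
ctx.sql("SELECT COUNT(*) FROM t").collect()
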
38 changes: 37 additions & 1 deletion datafusion/tests/test_context.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import gzip
import os

import pyarrow as pa
@@ -336,11 +336,47 @@ def test_read_json(ctx):
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_json_compressed(ctx, tmp_path):
    path = os.path.dirname(os.path.abspath(__file__))
    test_data_path = os.path.join(path, "data_test_context", "data.json")

    # Compress the JSON fixture so the reader has to decompress it
    gzip_path = tmp_path / "data.json.gz"

    with open(test_data_path, "rb") as json_file:
        with gzip.open(gzip_path, "wb") as gzipped_file:
            gzipped_file.writelines(json_file)

    df = ctx.read_json(
        gzip_path, file_extension=".gz", file_compression_type="gz"
    )
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_csv(ctx):
    csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv")
    csv_df.select(column("c1")).show()


def test_read_csv_compressed(ctx, tmp_path):
    test_data_path = "testing/data/csv/aggregate_test_100.csv"

    # Compress the CSV fixture so the reader has to decompress it
    gzip_path = tmp_path / "aggregate_test_100.csv.gz"

    with open(test_data_path, "rb") as csv_file:
        with gzip.open(gzip_path, "wb") as gzipped_file:
            gzipped_file.writelines(csv_file)

    csv_df = ctx.read_csv(
        gzip_path, file_extension=".gz", file_compression_type="gz"
    )
    csv_df.select(column("c1")).show()


def test_read_parquet(ctx):
    csv_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet")
    csv_df.show()
23 changes: 21 additions & 2 deletions datafusion/tests/test_sql.py
@@ -19,6 +19,7 @@
import gzip

import pyarrow as pa
import pyarrow.dataset as ds
import pytest

from datafusion import udf

@@ -32,6 +33,7 @@ def test_no_table(ctx):

def test_register_csv(ctx, tmp_path):
    path = tmp_path / "test.csv"
    gzip_path = tmp_path / "test.csv.gz"

    table = pa.Table.from_arrays(
        [
@@ -43,6 +45,10 @@ def test_register_csv(ctx, tmp_path):
    )
    pa.csv.write_csv(table, path)

    with open(path, "rb") as csv_file:
        with gzip.open(gzip_path, "wb") as gzipped_file:
            gzipped_file.writelines(csv_file)

ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
@@ -52,6 +58,13 @@ def test_register_csv(ctx, tmp_path):
delimiter=",",
schema_infer_max_records=10,
)
ctx.register_csv(
"csv_gzip",
gzip_path,
file_extension="gz",
file_compression_type="gzip",
)

    alternative_schema = pa.schema(
        [
            ("some_int", pa.int16()),
@@ -61,9 +74,9 @@ def test_register_csv(ctx, tmp_path):
    )
    ctx.register_csv("csv3", path, schema=alternative_schema)

-    assert ctx.tables() == {"csv", "csv1", "csv2", "csv3"}
+    assert ctx.tables() == {"csv", "csv1", "csv2", "csv3", "csv_gzip"}

-    for table in ["csv", "csv1", "csv2"]:
+    for table in ["csv", "csv1", "csv2", "csv_gzip"]:
        result = ctx.sql(f"SELECT COUNT(int) AS cnt FROM {table}").collect()
        result = pa.Table.from_batches(result)
        assert result.to_pydict() == {"cnt": [4]}
@@ -77,6 +90,12 @@ def test_register_csv(ctx, tmp_path):
    ):
        ctx.register_csv("csv4", path, delimiter="wrong")

    with pytest.raises(
        ValueError,
        match="file_compression_type must be one of: gzip, bz2, xz, zstd",
    ):
        ctx.register_csv("csv4", path, file_compression_type="rar")


def test_register_parquet(ctx, tmp_path):
    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
31 changes: 25 additions & 6 deletions src/context.rs
@@ -17,6 +17,7 @@

use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use object_store::ObjectStore;
@@ -40,6 +41,7 @@ use crate::utils::{get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::file_type::FileCompressionType;
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{SessionConfig, SessionContext, TaskContext};
@@ -469,7 +471,8 @@ impl PySessionContext {
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
-        file_extension=".csv"))]
+        file_extension=".csv",
+        file_compression_type=None))]
    fn register_csv(
        &mut self,
        name: &str,
@@ -479,6 +482,7 @@ impl PySessionContext {
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyResult<()> {
        let path = path
@@ -495,7 +499,8 @@ impl PySessionContext {
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
-            .file_extension(file_extension);
+            .file_extension(file_extension)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_csv(name, path, options);
@@ -559,21 +564,23 @@ impl PySessionContext {
    }

    #[allow(clippy::too_many_arguments)]
-    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![]))]
+    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    fn read_json(
        &mut self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
@@ -595,7 +602,8 @@ impl PySessionContext {
delimiter=",",
schema_infer_max_records=1000,
file_extension=".csv",
table_partition_cols=vec![]))]
table_partition_cols=vec![],
file_compression_type=None))]
    fn read_csv(
        &self,
        path: PathBuf,
@@ -605,6 +613,7 @@ impl PySessionContext {
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyResult<PyDataFrame> {
        let path = path
@@ -623,7 +632,8 @@ impl PySessionContext {
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);

        if let Some(py_schema) = schema {
            options.schema = Some(&py_schema.0);
@@ -743,6 +753,15 @@ fn convert_table_partition_cols(
        .collect::<Result<Vec<_>, _>>()
}

fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    // None becomes "", which FileCompressionType::from_str parses as
    // uncompressed, so omitting the option preserves the old behavior.
    FileCompressionType::from_str(file_compression_type.unwrap_or_default().as_str())
        .map_err(|_| {
            PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd")
        })
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
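
A note on the helper above: None (the default on every Python entry point added here) becomes the empty string, which parses as uncompressed, so existing callers are unaffected, while an unrecognized codec fails fast with a ValueError. A small sketch of both paths, reusing the aggregate_test_100.csv fixture path from the tests:

import pytest
from datafusion import SessionContext

ctx = SessionContext()

# Omitting file_compression_type keeps the previous, uncompressed behavior.
ctx.register_csv("plain", "testing/data/csv/aggregate_test_100.csv")

# An unknown codec is rejected by parse_file_compression_type before any I/O.
with pytest.raises(ValueError, match="file_compression_type must be one of"):
    ctx.register_csv(
        "bad",
        "testing/data/csv/aggregate_test_100.csv",
        file_compression_type="rar",
    )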
