nameexhaustion committed Jun 26, 2024
1 parent 332e40a commit 4facac0
Showing 5 changed files with 182 additions and 38 deletions.
52 changes: 50 additions & 2 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
@@ -1,3 +1,4 @@
use either::Either;
use expr_expansion::{is_regex_projection, rewrite_projections};
use hive::hive_partitions_from_paths;

@@ -86,7 +87,7 @@ pub fn to_alp_impl(
paths,
predicate,
mut scan_type,
file_options,
mut file_options,
} => {
let mut file_info = if let Some(file_info) = file_info {
file_info
@@ -137,19 +138,66 @@ pub fn to_alp_impl(
let hive_parts = if hive_parts.is_some() {
hive_parts
} else if file_options.hive_options.enabled.unwrap() {
#[allow(unused_assignments)]
let mut owned = None;

hive_partitions_from_paths(
paths.as_ref(),
file_options.hive_options.hive_start_idx,
file_options.hive_options.schema.clone(),
match file_info.reader_schema.as_ref().unwrap() {
Either::Left(v) => {
owned = Some(Schema::from(v));
owned.as_ref().unwrap()
},
Either::Right(v) => v.as_ref(),
},
)?
} else {
None
};

if let Some(ref hive_parts) = hive_parts {
file_info.update_schema_with_hive_schema(hive_parts[0].schema().clone())?;
let hive_schema = hive_parts[0].schema();
file_info.update_schema_with_hive_schema(hive_schema.clone());
}

(|| {
// Update `with_columns` with a projection so that hive columns aren't loaded from the
// file
let Some(ref hive_parts) = hive_parts else {
return;
};

let hive_schema = hive_parts[0].schema();

let Some((first_hive_name, _)) = hive_schema.get_at_index(0) else {
return;
};

let names = match file_info.reader_schema.as_ref().unwrap() {
Either::Left(ref v) => {
let names = v.get_names();
names.contains(&first_hive_name.as_str()).then_some(names)
},
Either::Right(ref v) => {
v.contains(first_hive_name.as_str()).then(|| v.get_names())
},
};

let Some(names) = names else {
return;
};

file_options.with_columns = Some(
names
.iter()
.filter(|x| !hive_schema.contains(x))
.map(ToString::to_string)
.collect::<Arc<[_]>>(),
);
})();

if let Some(row_index) = &file_options.row_index {
let schema = Arc::make_mut(&mut file_info.schema);
*schema = schema
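The closure above rewrites `with_columns` only when the first hive partition column is physically present in the file's reader schema; otherwise the file does not contain the partition columns and no projection is needed. A minimal, self-contained sketch of that check, with plain `String` slices standing in for polars' `Schema` (the helper name `compute_with_columns` is hypothetical):

```rust
/// Returns the columns to read from the file, excluding hive partition
/// columns. Returns `None` when no rewrite is needed: either there are no
/// hive columns, or the first hive column is not present in the file.
fn compute_with_columns(
    reader_columns: &[String],
    hive_columns: &[String],
) -> Option<Vec<String>> {
    let first_hive = hive_columns.first()?;
    if !reader_columns.contains(first_hive) {
        return None;
    }
    Some(
        reader_columns
            .iter()
            .filter(|c| !hive_columns.contains(*c))
            .cloned()
            .collect(),
    )
}

fn main() {
    let reader: Vec<String> = ["x", "y", "a", "b"].iter().map(|s| s.to_string()).collect();
    let hive: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
    // Only `x` and `y` are read from the file; `a` and `b` still come from the path.
    assert_eq!(
        compute_with_columns(&reader, &hive),
        Some(vec!["x".to_string(), "y".to_string()])
    );
}
```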
25 changes: 20 additions & 5 deletions crates/polars-plan/src/plans/hive.rs
@@ -17,11 +17,10 @@ pub struct HivePartitions {
}

impl HivePartitions {
pub fn get_projection_schema_and_indices<T: AsRef<str>>(
pub fn get_projection_schema_and_indices(
&self,
names: &[T],
names: &PlHashSet<String>,
) -> (SchemaRef, Vec<usize>) {
let names = names.iter().map(T::as_ref).collect::<PlHashSet<&str>>();
let mut out_schema = Schema::with_capacity(self.stats.schema().len());
let mut out_indices = Vec::with_capacity(self.stats.column_stats().len());

@@ -66,6 +65,7 @@ pub fn hive_partitions_from_paths(
paths: &[PathBuf],
hive_start_idx: usize,
schema: Option<SchemaRef>,
reader_schema: &Schema,
) -> PolarsResult<Option<Arc<[HivePartitions]>>> {
let Some(path) = paths.first() else {
return Ok(None);
@@ -88,14 +88,29 @@
}};
}

let hive_schema = if let Some(v) = schema {
v
let hive_schema = if let Some(ref schema) = schema {
Arc::new(get_hive_parts_iter!(path_string).map(|(name, _)| {
let Some(dtype) = schema.get(name) else {
polars_bail!(
SchemaFieldNotFound:
"path contains column not present in the given Hive schema: {:?}, path = {:?}",
name,
path
)
};
Ok(Field::new(name, dtype.clone()))
}).collect::<PolarsResult<Schema>>()?)
} else {
let mut hive_schema = Schema::with_capacity(16);
let mut schema_inference_map: PlHashMap<&str, PlHashSet<DataType>> =
PlHashMap::with_capacity(16);

for (name, _) in get_hive_parts_iter!(path_string) {
if let Some(dtype) = reader_schema.get(name) {
hive_schema.insert_at_index(hive_schema.len(), name.into(), dtype.clone())?;
continue;
}

hive_schema.insert_at_index(hive_schema.len(), name.into(), DataType::String)?;
schema_inference_map.insert(name, PlHashSet::with_capacity(4));
}
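The resolution order this hunk introduces for each `key=value` pair in the path: an explicit hive schema must cover every path key (else `SchemaFieldNotFound`), otherwise the dtype is taken from the file's reader schema when the column also exists in the file, falling back to `String` pending value-based inference. A toy sketch under those assumptions; `resolve_hive_dtype` and the two-variant `DataType` are illustrative stand-ins, not polars types:

```rust
use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq)]
enum DataType {
    String,
    Int64,
}

fn resolve_hive_dtype(
    name: &str,
    user_schema: Option<&HashMap<String, DataType>>,
    reader_schema: &HashMap<String, DataType>,
) -> Result<DataType, String> {
    if let Some(user) = user_schema {
        // An explicit hive schema must mention every column found in the path.
        return user.get(name).cloned().ok_or_else(|| {
            format!("path contains column not present in the given Hive schema: {name:?}")
        });
    }
    // No user schema: prefer the file's dtype if the column also lives in the
    // file, otherwise start at String and let value-based inference refine it.
    Ok(reader_schema.get(name).cloned().unwrap_or(DataType::String))
}

fn main() {
    let reader = HashMap::from([("a".to_string(), DataType::Int64)]);
    assert_eq!(resolve_hive_dtype("a", None, &reader), Ok(DataType::Int64));
    assert_eq!(resolve_hive_dtype("b", None, &reader), Ok(DataType::String));
}
```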
57 changes: 51 additions & 6 deletions crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs
@@ -9,6 +9,7 @@ mod rename;
#[cfg(feature = "semi_anti_join")]
mod semi_anti_join;

use either::Either;
use polars_core::datatypes::PlHashSet;
use polars_core::prelude::*;
use polars_io::RowIndex;
@@ -35,7 +36,7 @@ fn init_set() -> PlHashSet<Arc<str>> {

/// utility function to get names of the columns needed in projection at scan level
fn get_scan_columns(
acc_projections: &mut Vec<ColumnNode>,
acc_projections: &Vec<ColumnNode>,
expr_arena: &Arena<AExpr>,
row_index: Option<&RowIndex>,
) -> Option<Arc<[String]>> {
@@ -378,7 +379,7 @@ impl ProjectionPushDown {
mut options,
predicate,
} => {
options.with_columns = get_scan_columns(&mut acc_projections, expr_arena, None);
options.with_columns = get_scan_columns(&acc_projections, expr_arena, None);

options.output_schema = if options.with_columns.is_none() {
None
@@ -417,7 +418,7 @@ impl ProjectionPushDown {

if do_optimization {
file_options.with_columns = get_scan_columns(
&mut acc_projections,
&acc_projections,
expr_arena,
file_options.row_index.as_ref(),
);
@@ -432,7 +433,9 @@

hive_parts = if let Some(hive_parts) = hive_parts {
let (new_schema, projected_indices) = hive_parts[0]
.get_projection_schema_and_indices(with_columns.as_ref());
.get_projection_schema_and_indices(
&with_columns.iter().cloned().collect::<PlHashSet<_>>(),
);

Some(
hive_parts
@@ -448,15 +451,22 @@ impl ProjectionPushDown {
.collect::<Arc<[_]>>(),
)
} else {
hive_parts
None
};

// Hive partitions are created AFTER the projection, so the output
// schema is incorrect. Here we ensure the columns that are projected and hive
// parts are added at the proper place in the schema, which is at the end.
if let Some(ref mut hive_parts) = hive_parts {
if let Some(ref hive_parts) = hive_parts {
let partition_schema = hive_parts.first().unwrap().schema();

file_options.with_columns = file_options.with_columns.map(|x| {
x.iter()
.filter(|x| !partition_schema.contains(x))
.cloned()
.collect::<Arc<[_]>>()
});

for (name, _) in partition_schema.iter() {
if let Some(dt) = schema.shift_remove(name) {
schema.with_column(name.clone(), dt);
@@ -465,6 +475,41 @@ impl ProjectionPushDown {
}
Some(Arc::new(schema))
} else {
(|| {
// Update `with_columns` with a projection so that hive columns aren't loaded from the
// file
let Some(ref hive_parts) = hive_parts else {
return;
};

let hive_schema = hive_parts[0].schema();

let Some((first_hive_name, _)) = hive_schema.get_at_index(0) else {
return;
};

let names = match file_info.reader_schema.as_ref().unwrap() {
Either::Left(ref v) => {
let names = v.get_names();
names.contains(&first_hive_name.as_str()).then_some(names)
},
Either::Right(ref v) => {
v.contains(first_hive_name.as_str()).then(|| v.get_names())
},
};

let Some(names) = names else {
return;
};

file_options.with_columns = Some(
names
.iter()
.filter(|x| !hive_schema.contains(x))
.map(ToString::to_string)
.collect::<Arc<[_]>>(),
);
})();
None
};
}
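Hive partition columns are appended after the file columns, so the pushed-down output schema has to be reordered to keep projected hive columns at the end, as the comment in the hunk notes. A small sketch of that fixup, with an ordered `Vec<String>` standing in for polars' `Schema`; `move_hive_columns_to_end` is a hypothetical name, and `Vec::remove` plays the role of `shift_remove`:

```rust
fn move_hive_columns_to_end(schema: &mut Vec<String>, hive_columns: &[String]) {
    for name in hive_columns {
        // Shift-remove the column and re-append it, mirroring the
        // `shift_remove` + `with_column` pair in the diff above.
        if let Some(pos) = schema.iter().position(|c| c == name) {
            let col = schema.remove(pos);
            schema.push(col);
        }
    }
}

fn main() {
    let mut schema: Vec<String> = ["a", "x", "b", "y"].iter().map(|s| s.to_string()).collect();
    let hive: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
    move_hive_columns_to_end(&mut schema, &hive);
    assert_eq!(schema, ["x", "y", "a", "b"]);
}
```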
24 changes: 11 additions & 13 deletions crates/polars-plan/src/plans/schema.rs
@@ -51,20 +51,18 @@ impl FileInfo {
}

/// Merge the [`Schema`] of a [`HivePartitions`] with the schema of this [`FileInfo`].
///
/// Returns an `Err` if any of the columns in either schema overlap.
pub fn update_schema_with_hive_schema(&mut self, hive_schema: SchemaRef) -> PolarsResult<()> {
let expected_len = self.schema.len() + hive_schema.len();
pub fn update_schema_with_hive_schema(&mut self, hive_schema: SchemaRef) {
let schema = Arc::make_mut(&mut self.schema);

let file_schema = Arc::make_mut(&mut self.schema);
file_schema.merge(Arc::unwrap_or_clone(hive_schema));

polars_ensure!(
file_schema.len() == expected_len,
Duplicate: "invalid Hive partition schema\n\n\
Extending the schema with the Hive partition schema would create duplicate fields."
);
Ok(())
for field in hive_schema.iter_fields() {
if let Ok(existing) = schema.try_get_mut(&field.name) {
*existing = field.data_type().clone();
} else {
schema
.insert_at_index(schema.len(), field.name, field.dtype.clone())
.unwrap();
}
}
}
}

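The merge no longer fails on duplicate names: a hive column that also exists in the file overrides that column's dtype in place, and hive columns absent from the file are appended at the end, which is why the `Duplicate` error and its test are removed below. A self-contained sketch of these semantics, using a `Vec` of `(name, dtype)` pairs as a stand-in for polars' ordered `Schema` and a toy `DataType`:

```rust
#[derive(Clone, Debug, PartialEq)]
enum DataType {
    String,
    Int8,
}

fn update_schema_with_hive_schema(
    schema: &mut Vec<(String, DataType)>,
    hive_schema: &[(String, DataType)],
) {
    for (name, dtype) in hive_schema {
        if let Some(existing) = schema.iter_mut().find(|(n, _)| n == name) {
            // Column exists in the file: the hive dtype wins.
            existing.1 = dtype.clone();
        } else {
            // Column only exists in the path: append it at the end.
            schema.push((name.clone(), dtype.clone()));
        }
    }
}

fn main() {
    let mut schema = vec![
        ("x".to_string(), DataType::Int8),
        ("a".to_string(), DataType::Int8),
    ];
    let hive = [
        ("a".to_string(), DataType::String),
        ("b".to_string(), DataType::String),
    ];
    update_schema_with_hive_schema(&mut schema, &hive);
    assert_eq!(schema[1], ("a".to_string(), DataType::String));
    assert_eq!(schema[2], ("b".to_string(), DataType::String));
}
```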
62 changes: 50 additions & 12 deletions py-polars/tests/unit/io/test_hive.py
@@ -10,7 +10,7 @@
import pytest

import polars as pl
from polars.exceptions import DuplicateError, SchemaFieldNotFoundError
from polars.exceptions import SchemaFieldNotFoundError
from polars.testing import assert_frame_equal, assert_series_equal


@@ -247,17 +247,6 @@ def test_hive_partitioned_projection_pushdown(
assert_frame_equal(result, expected)


@pytest.mark.write_disk()
def test_hive_partitioned_err(io_files_path: Path, tmp_path: Path) -> None:
df = pl.read_ipc(io_files_path / "*.ipc")
root = tmp_path / "sugars_g=10"
root.mkdir()
df.write_parquet(root / "file.parquet")

with pytest.raises(DuplicateError, match="invalid Hive partition schema"):
pl.scan_parquet(tmp_path, hive_partitioning=True).collect()


@pytest.mark.write_disk()
def test_hive_partitioned_projection_skip_files(
io_files_path: Path, tmp_path: Path
@@ -538,3 +527,52 @@ def test_hive_partition_force_async_17155(tmp_path: Path, monkeypatch: Any) -> None:
assert_frame_equal(
lf.collect(), pl.DataFrame({k: [1, 2, 3] for k in ["x", "a", "b"]})
)


@pytest.mark.parametrize("projection_pushdown", [True, False])
def test_hive_partition_columns_contained_in_file(
tmp_path: Path, projection_pushdown: bool
) -> None:
path = tmp_path / "a=1/b=2/data.bin"
path.parent.mkdir(exist_ok=True, parents=True)
df = pl.DataFrame(
{"x": 1, "y": 1, "a": 1, "b": 2},
schema={"x": pl.Int32, "y": pl.Int32, "a": pl.Int8, "b": pl.Int16},
)
df.write_parquet(path)

def assert_with_projections(lf: pl.LazyFrame, df: pl.DataFrame) -> None:
for projection in [
["a"],
["b"],
["x"],
["y"],
["a", "x"],
["b", "x"],
["a", "y"],
["b", "y"],
["x", "y"],
["a", "b", "x"],
["a", "b", "y"],
]:
assert_frame_equal(
lf.select(projection).collect(projection_pushdown=projection_pushdown),
df.select(projection),
)

lf = pl.scan_parquet(path, hive_partitioning=True)
rhs = df
assert_frame_equal(lf.collect(projection_pushdown=projection_pushdown), rhs)
assert_with_projections(lf, rhs)

lf = pl.scan_parquet(
path,
hive_schema={"a": pl.String, "b": pl.String},
hive_partitioning=True,
)
rhs = df.with_columns(pl.col("a", "b").cast(pl.String))
assert_frame_equal(
lf.collect(projection_pushdown=projection_pushdown),
rhs,
)
assert_with_projections(lf, rhs)
