From 7d93ec4f076c39565bc8237f3a1b7cc51964e9db Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Sun, 3 Nov 2024 14:36:24 +0100 Subject: [PATCH] fix: Parquet decoding of nested dictionary values (#19605) --- .../arrow/read/deserialize/nested_utils.rs | 7 +- .../read/deserialize/utils/dict_encoded.rs | 156 +++++++----------- .../src/parquet/encoding/bitpacked/decode.rs | 24 +-- .../encoding/delta_bitpacked/decoder.rs | 4 +- py-polars/polars/io/spreadsheet/_utils.py | 15 +- .../tests/unit/io/test_lazy_count_star.py | 14 +- py-polars/tests/unit/io/test_parquet.py | 11 +- .../unit/operations/namespaces/test_meta.py | 1 + py-polars/tests/unit/test_cse.py | 20 +-- 9 files changed, 116 insertions(+), 136 deletions(-) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index 78b4a813693d..add07af2ce45 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -686,16 +686,19 @@ impl PageNestedDecoder { batched_collector.finalize()?; + let leaf_validity = leaf_validity.freeze(); + let leaf_filter = leaf_filter.freeze(); + let state = utils::State::new_nested( &self.decoder, &page, self.dict.as_ref(), - Some(leaf_validity.freeze()), + Some(leaf_validity), )?; state.decode( &mut self.decoder, &mut target, - Some(Filter::Mask(leaf_filter.freeze())), + Some(Filter::Mask(leaf_filter)), )?; self.iter.reuse_page_buffer(page); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs index 69c1cd6c549d..fdb6a7dab1d2 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/dict_encoded.rs @@ -81,7 +81,7 @@ pub(crate) fn constrain_page_validity( #[inline(never)] pub fn decode_dict_dispatch( - values: 
HybridRleDecoder<'_>, + mut values: HybridRleDecoder<'_>, dict: &[B], is_optional: bool, page_validity: Option<&Bitmap>, @@ -100,9 +100,10 @@ pub fn decode_dict_dispatch( let page_validity = constrain_page_validity(values.len(), page_validity, filter.as_ref()); match (filter, page_validity) { - (None, None) => decode_required_dict(values, dict, None, target), + (None, None) => decode_required_dict(values, dict, target), (Some(Filter::Range(rng)), None) if rng.start == 0 => { - decode_required_dict(values, dict, Some(rng.end), target) + values.limit_to(rng.end); + decode_required_dict(values, dict, target) }, (None, Some(page_validity)) => decode_optional_dict(values, dict, &page_validity, target), (Some(Filter::Range(rng)), Some(page_validity)) if rng.start == 0 => { @@ -156,85 +157,67 @@ fn verify_dict_indices(indices: &[u32; 32], dict_size: usize) -> ParquetResult<( pub fn decode_required_dict( mut values: HybridRleDecoder<'_>, dict: &[B], - limit: Option, target: &mut Vec, ) -> ParquetResult<()> { if dict.is_empty() && values.len() > 0 { return Err(oob_dict_idx()); } - let mut limit = limit.unwrap_or(values.len()); - assert!(limit <= values.len()); let start_length = target.len(); - let end_length = start_length + limit; + let end_length = start_length + values.len(); - target.reserve(limit); + target.reserve(values.len()); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; - while limit > 0 { + while values.len() > 0 { let chunk = values.next_chunk()?.unwrap(); match chunk { HybridRleChunk::Rle(value, length) => { - let length = length.min(limit); + if length == 0 { + continue; + } - let Some(&value) = dict.get(value as usize) else { - return Err(oob_dict_idx()); - }; let target_slice; // SAFETY: - // 1. `target_ptr..target_ptr + limit` is allocated + // 1. `target_ptr..target_ptr + values.len()` is allocated // 2. 
`length <= limit` unsafe { target_slice = std::slice::from_raw_parts_mut(target_ptr, length); target_ptr = target_ptr.add(length); } + let Some(&value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + target_slice.fill(value); - limit -= length; }, HybridRleChunk::Bitpacked(mut decoder) => { let mut chunked = decoder.chunked(); - loop { - if limit < 32 { - break; - } - - let Some(chunk) = chunked.next() else { - break; - }; - + for chunk in chunked.by_ref() { verify_dict_indices(&chunk, dict.len())?; for (i, &idx) in chunk.iter().enumerate() { - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(i).write(value) }; + unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; } - unsafe { target_ptr = target_ptr.add(32); } - limit -= 32; } - if let Some((chunk, chunk_size)) = chunked.next_inexact() { - let chunk_size = chunk_size.min(limit); - + if let Some((chunk, chunk_size)) = chunked.remainder() { let highest_idx = chunk[..chunk_size].iter().copied().max().unwrap(); - assert!((highest_idx as usize) < dict.len()); + if highest_idx as usize >= dict.len() { + return Err(oob_dict_idx()); + } for (i, &idx) in chunk[..chunk_size].iter().enumerate() { - let value = unsafe { dict.get_unchecked(idx as usize) }; - let value = *value; - unsafe { target_ptr.add(i).write(value) }; + unsafe { target_ptr.add(i).write(*dict.get_unchecked(idx as usize)) }; } - unsafe { target_ptr = target_ptr.add(chunk_size); } - - limit -= chunk_size; } }, } @@ -254,12 +237,12 @@ pub fn decode_optional_dict( validity: &Bitmap, target: &mut Vec, ) -> ParquetResult<()> { - let mut limit = validity.len(); let num_valid_values = validity.set_bits(); // Dispatch to the required kernel if all rows are valid anyway. 
if num_valid_values == validity.len() { - return decode_required_dict(values, dict, Some(validity.len()), target); + values.limit_to(validity.len()); + return decode_required_dict(values, dict, target); } if dict.is_empty() && num_valid_values > 0 { @@ -273,21 +256,18 @@ pub fn decode_optional_dict( target.reserve(validity.len()); let mut target_ptr = unsafe { target.as_mut_ptr().add(start_length) }; + values.limit_to(num_valid_values); let mut validity = BitMask::from_bitmap(validity); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; - loop { - if limit == 0 { - break; - } - - let Some(chunk) = values.next_chunk()? else { - break; - }; - - match chunk { + for chunk in values.into_chunk_iter() { + match chunk? { HybridRleChunk::Rle(value, size) => { + if size == 0 { + continue; + } + // If we know that we have `size` times `value` that we can append, but there might // be nulls in between those values. // @@ -303,6 +283,7 @@ pub fn decode_optional_dict( let Some(&value) = dict.get(value as usize) else { return Err(oob_dict_idx()); }; + let target_slice; // SAFETY: // Given `validity_iter` before the `advance_by_bits` @@ -315,7 +296,6 @@ pub fn decode_optional_dict( } target_slice.fill(value); - limit -= num_chunk_rows; }, HybridRleChunk::Bitpacked(mut decoder) => { let mut chunked = decoder.chunked(); @@ -328,9 +308,7 @@ pub fn decode_optional_dict( let mut num_done = 0; let mut validity_iter = validity.fast_iter_u56(); - 'outer: while limit >= 64 { - let v = validity_iter.next().unwrap(); - + 'outer: for v in validity_iter.by_ref() { while num_buffered < v.count_ones() as usize { let buffer_part = <&mut [u32; 32]>::try_from( &mut values_buffer[buffer_part_idx * 32..][..32], @@ -371,7 +349,6 @@ pub fn decode_optional_dict( target_ptr = target_ptr.add(56); } num_done += 56; - limit -= 56; } (_, validity) = unsafe { validity.split_at_unchecked(num_done) }; @@ -382,10 +359,9 @@ pub fn decode_optional_dict( 
.nth_set_bit_idx(num_decoder_remaining, 0) .unwrap_or(validity.len()); - let num_remaining = limit.min(decoder_limit); let current_validity; (current_validity, validity) = - unsafe { validity.split_at_unchecked(num_remaining) }; + unsafe { validity.split_at_unchecked(decoder_limit) }; let (v, _) = current_validity.fast_iter_u56().remainder(); while num_buffered < v.count_ones() as usize { @@ -405,7 +381,7 @@ pub fn decode_optional_dict( let mut num_read = 0; - for i in 0..num_remaining { + for i in 0..decoder_limit { let idx = values_buffer[(values_offset + num_read) % 128]; let value = unsafe { dict.get_unchecked(idx as usize) }; let value = *value; @@ -414,9 +390,8 @@ pub fn decode_optional_dict( } unsafe { - target_ptr = target_ptr.add(num_remaining); + target_ptr = target_ptr.add(decoder_limit); } - limit -= num_remaining; }, } } @@ -425,13 +400,8 @@ pub fn decode_optional_dict( assert_eq!(validity.set_bits(), 0); } - let target_slice; - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, limit); - } - + let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, validity.len()) }; target_slice.fill(B::zeroed()); - unsafe { target.set_len(end_length); } @@ -474,25 +444,22 @@ pub fn decode_masked_optional_dict( let mut filter = BitMask::from_bitmap(filter); let mut validity = BitMask::from_bitmap(validity); + values.limit_to(num_valid_values); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; let mut num_rows_left = num_rows; - loop { + for chunk in values.into_chunk_iter() { // Early stop if we have no more rows to load. if num_rows_left == 0 { break; } - let Some(chunk) = values.next_chunk()? else { - break; - }; - - match chunk { + match chunk? 
{ HybridRleChunk::Rle(value, size) => { - if value as usize >= dict.len() { - return Err(oob_dict_idx()); + if size == 0 { + continue; } // If we know that we have `size` times `value` that we can append, but there might @@ -512,9 +479,6 @@ pub fn decode_masked_optional_dict( let num_chunk_rows = current_filter.set_bits(); if num_chunk_rows > 0 { - // SAFETY: Bounds check done before. - let value = unsafe { dict.get_unchecked(value as usize) }; - let target_slice; // SAFETY: // Given `filter_iter` before the `advance_by_bits`. @@ -526,6 +490,10 @@ pub fn decode_masked_optional_dict( target_ptr = target_ptr.add(num_chunk_rows); } + let Some(value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + target_slice.fill(*value); num_rows_left -= num_chunk_rows; } @@ -655,13 +623,8 @@ pub fn decode_masked_optional_dict( assert_eq!(validity.set_bits(), 0); } - let target_slice; - unsafe { - target_slice = std::slice::from_raw_parts_mut(target_ptr, num_rows_left); - } - + let target_slice = unsafe { std::slice::from_raw_parts_mut(target_ptr, num_rows_left) }; target_slice.fill(B::zeroed()); - unsafe { target.set_len(start_length + num_rows); } @@ -680,10 +643,11 @@ pub fn decode_masked_required_dict( // Dispatch to the non-filter kernel if all rows are needed anyway. if num_rows == filter.len() { - return decode_required_dict(values, dict, Some(filter.len()), target); + values.limit_to(filter.len()); + return decode_required_dict(values, dict, target); } - if dict.is_empty() && values.len() > 0 { + if dict.is_empty() && !filter.is_empty() { return Err(oob_dict_idx()); } @@ -694,24 +658,21 @@ pub fn decode_masked_required_dict( let mut filter = BitMask::from_bitmap(filter); + values.limit_to(filter.len()); let mut values_buffer = [0u32; 128]; let values_buffer = &mut values_buffer; let mut num_rows_left = num_rows; - loop { + for chunk in values.into_chunk_iter() { if num_rows_left == 0 { break; } - let Some(chunk) = values.next_chunk()? 
else { - break; - }; - - match chunk { + match chunk? { HybridRleChunk::Rle(value, size) => { - if value as usize >= dict.len() { - return Err(oob_dict_idx()); + if size == 0 { + continue; } let size = size.min(filter.len()); @@ -730,9 +691,6 @@ pub fn decode_masked_required_dict( let num_chunk_rows = current_filter.set_bits(); if num_chunk_rows > 0 { - // SAFETY: Bounds check done before. - let value = unsafe { dict.get_unchecked(value as usize) }; - let target_slice; // SAFETY: // Given `filter_iter` before the `advance_by_bits`. @@ -744,6 +702,10 @@ pub fn decode_masked_required_dict( target_ptr = target_ptr.add(num_chunk_rows); } + let Some(value) = dict.get(value as usize) else { + return Err(oob_dict_idx()); + }; + target_slice.fill(*value); num_rows_left -= num_chunk_rows; } diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs index 8c43d2694590..44a30ff57245 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs @@ -149,19 +149,22 @@ impl Iterator for ChunkedDecoder<'_, '_, T> { impl ExactSizeIterator for ChunkedDecoder<'_, '_, T> {} impl ChunkedDecoder<'_, '_, T> { - /// Get and consume the remainder chunk if it exists + /// Get and consume the remainder chunk if it exists. + /// + /// This should only be called after all the full chunks are consumed.
pub fn remainder(&mut self) -> Option<(T::Unpacked, usize)> { - let remainder_len = self.decoder.len() % T::Unpacked::LENGTH; - - if remainder_len > 0 { - let mut unpacked = T::Unpacked::zero(); - let packed = self.decoder.packed.next_back().unwrap(); - decode_pack::(packed, self.decoder.num_bits, &mut unpacked); - self.decoder.length -= remainder_len; - return Some((unpacked, remainder_len)); + if self.decoder.len() == 0 { + return None; } - None + debug_assert!(self.decoder.len() < T::Unpacked::LENGTH); + let remainder_len = self.decoder.len() % T::Unpacked::LENGTH; + + let mut unpacked = T::Unpacked::zero(); + let packed = self.decoder.packed.next()?; + decode_pack::(packed, self.decoder.num_bits, &mut unpacked); + self.decoder.length -= remainder_len; + Some((unpacked, remainder_len)) } /// Get the next (possibly partial) chunk and its filled length @@ -173,6 +176,7 @@ impl ChunkedDecoder<'_, '_, T> { } } + /// Consume the next chunk into `unpacked`. pub fn next_into(&mut self, unpacked: &mut T::Unpacked) -> Option { if self.decoder.len() == 0 { return None; diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs index f176ef9862d4..a89d824bdb6a 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs @@ -176,7 +176,7 @@ fn gather_bitpacked( gatherer: &mut G, ) -> ParquetResult<()> { let mut chunked = decoder.chunked(); - for mut chunk in &mut chunked { + for mut chunk in chunked.by_ref() { for value in &mut chunk { *last_value = last_value .wrapping_add(*value as i64) @@ -188,7 +188,7 @@ fn gather_bitpacked( gatherer.gather_chunk(target, chunk)?; } - if let Some((mut chunk, length)) = chunked.next_inexact() { + if let Some((mut chunk, length)) = chunked.remainder() { let slice = &mut chunk[..length]; for value in slice.iter_mut() { diff --git 
a/py-polars/polars/io/spreadsheet/_utils.py b/py-polars/polars/io/spreadsheet/_utils.py index 090effbe8ce6..a19299dd47f5 100644 --- a/py-polars/polars/io/spreadsheet/_utils.py +++ b/py-polars/polars/io/spreadsheet/_utils.py @@ -42,10 +42,11 @@ def PortableTemporaryFile( "errors": errors, }, ) - tmp = NamedTemporaryFile(**params) # noqa: SIM115 - try: - yield tmp - finally: - tmp.close() - if delete: - Path(tmp.name).unlink(missing_ok=True) + + with NamedTemporaryFile(**params) as tmp: + try: + yield tmp + finally: + tmp.close() + if delete: + Path(tmp.name).unlink(missing_ok=True) diff --git a/py-polars/tests/unit/io/test_lazy_count_star.py b/py-polars/tests/unit/io/test_lazy_count_star.py index 013b649f4b2f..a3723f2ad939 100644 --- a/py-polars/tests/unit/io/test_lazy_count_star.py +++ b/py-polars/tests/unit/io/test_lazy_count_star.py @@ -29,15 +29,15 @@ def test_count_csv(io_files_path: Path, path: str, n_rows: int) -> None: @pytest.mark.write_disk def test_commented_csv() -> None: - csv_a = NamedTemporaryFile() # noqa: SIM115 - csv_a.write(b"A,B\nGr1,A\nGr1,B\n# comment line\n") - csv_a.seek(0) + with NamedTemporaryFile() as csv_a: + csv_a.write(b"A,B\nGr1,A\nGr1,B\n# comment line\n") + csv_a.seek(0) - expected = pl.DataFrame(pl.Series("len", [2], dtype=pl.UInt32)) - lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len()) + expected = pl.DataFrame(pl.Series("len", [2], dtype=pl.UInt32)) + lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len()) - assert "FAST COUNT" in lf.explain() - assert_frame_equal(lf.collect(), expected) + assert "FAST COUNT" in lf.explain() + assert_frame_equal(lf.collect(), expected) @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 84e93eb0c5e4..fe75fccd1ad0 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -1099,7 +1099,6 @@ def test_hybrid_rle() -> None: @pytest.mark.slow def 
test_roundtrip_parametric(df: pl.DataFrame) -> None: f = io.BytesIO() - print(df) df.write_parquet(f) f.seek(0) result = pl.read_parquet(f) @@ -2263,3 +2262,13 @@ def test_nested_nulls() -> None: f.seek(0) out = pl.read_parquet(f) assert_frame_equal(out, df) + + +@pytest.mark.parametrize("content", [[], [None], [None, 0.0]]) +def test_nested_dicts(content: list[float | None]) -> None: + df = pl.Series("a", [content], pl.List(pl.Float64)).to_frame() + + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=True) + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) diff --git a/py-polars/tests/unit/operations/namespaces/test_meta.py b/py-polars/tests/unit/operations/namespaces/test_meta.py index f038dd27fdbf..38835244557e 100644 --- a/py-polars/tests/unit/operations/namespaces/test_meta.py +++ b/py-polars/tests/unit/operations/namespaces/test_meta.py @@ -149,6 +149,7 @@ def test_meta_tree_format(namespace_files_path: Path) -> None: def test_meta_show_graph(namespace_files_path: Path) -> None: e = (pl.col("foo") * pl.col("bar")).sum().over(pl.col("ham")) / 2 dot = e.meta.show_graph(show=False, raw_output=True) + assert dot is not None assert len(dot) > 0 # Don't check output contents since this creates a maintenance burden # Assume output check in test_meta_tree_format is enough diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index fc44699b4e3e..06efd675c84f 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -147,18 +147,18 @@ def test_cse_9630() -> None: @pytest.mark.write_disk def test_schema_row_index_cse() -> None: - csv_a = NamedTemporaryFile() # noqa: SIM115 - csv_a.write(b"A,B\nGr1,A\nGr1,B") - csv_a.seek(0) + with NamedTemporaryFile() as csv_a: + csv_a.write(b"A,B\nGr1,A\nGr1,B") + csv_a.seek(0) - df_a = pl.scan_csv(csv_a.name).with_row_index("Idx") + df_a = pl.scan_csv(csv_a.name).with_row_index("Idx") - result = ( - df_a.join(df_a, on="B") - .group_by("A", maintain_order=True) - .all() - 
.collect(comm_subexpr_elim=True) - ) + result = ( + df_a.join(df_a, on="B") + .group_by("A", maintain_order=True) + .all() + .collect(comm_subexpr_elim=True) + ) expected = pl.DataFrame( {