Skip to content

Commit 7590726

Browse files
authored
fix: Handle leading nulls when computing stats in compressor (#2608)
1 parent 8b9bb08 commit 7590726

File tree

3 files changed

+135
-55
lines changed

3 files changed

+135
-55
lines changed

vortex-btrblocks/src/float/stats.rs

+66-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::hash::Hash;
22

3+
use itertools::Itertools;
34
use num_traits::Float;
45
use rustc_hash::FxBuildHasher;
56
use vortex_array::aliases::hash_map::HashMap;
@@ -9,6 +10,7 @@ use vortex_array::{Array, ToCanonical};
910
use vortex_dtype::half::f16;
1011
use vortex_dtype::{NativePType, PType};
1112
use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
13+
use vortex_mask::AllOr;
1214

1315
use crate::sample::sample;
1416
use crate::{CompressorStats, GenerateStatsOptions};
@@ -138,33 +140,57 @@ where
138140
let value_count = validity.true_count();
139141
let mut min = T::max_value();
140142
let mut max = T::min_value();
141-
let mut distinct_values_count: u32 = if count_distinct_values { 0 } else { u32::MAX };
143+
// Keep a HashMap of T, then convert the keys into PValue afterward, since the native value is
144+
// so much more efficient to hash and search for.
142145
let mut distinct_values = if count_distinct_values {
143146
HashMap::with_capacity_and_hasher(array.len() / 2, FxBuildHasher)
144147
} else {
145148
HashMap::with_hasher(FxBuildHasher)
146149
};
147150

148-
// Keep a HashMap of T, then convert the keys into PValue afterward since value is
149-
// so much more efficient to hash and search for.
150151
let mut runs = 1;
151-
let mut prev = array.as_slice::<T>()[0];
152-
153-
for (idx, &value) in array.buffer::<T>().iter().enumerate() {
154-
if validity.value(idx) {
155-
min = min.min(value);
156-
max = max.max(value);
157-
158-
if count_distinct_values {
159-
*distinct_values.entry(value.to_bits()).or_insert(0) += 1;
160-
distinct_values_count = distinct_values.len().try_into().vortex_unwrap();
161-
} else {
162-
distinct_values_count = u32::MAX;
152+
let head_idx = validity
153+
.first()
154+
.vortex_expect("All null masks have been handled before");
155+
let buff = array.buffer::<T>();
156+
let mut prev = buff[head_idx];
157+
158+
let first_valid_buff = buff.slice(head_idx..array.len());
159+
match validity.boolean_buffer() {
160+
AllOr::All => {
161+
for value in first_valid_buff {
162+
min = min.min(value);
163+
max = max.max(value);
164+
165+
if count_distinct_values {
166+
*distinct_values.entry(value.to_bits()).or_insert(0) += 1;
167+
}
168+
169+
if value != prev {
170+
prev = value;
171+
runs += 1;
172+
}
163173
}
174+
}
175+
AllOr::None => unreachable!("All invalid arrays have been handled earlier"),
176+
AllOr::Some(v) => {
177+
for (&value, valid) in first_valid_buff
178+
.iter()
179+
.zip_eq(v.slice(head_idx, array.len() - head_idx).iter())
180+
{
181+
if valid {
182+
min = min.min(value);
183+
max = max.max(value);
184+
185+
if count_distinct_values {
186+
*distinct_values.entry(value.to_bits()).or_insert(0) += 1;
187+
}
164188

165-
if value != prev {
166-
prev = value;
167-
runs += 1;
189+
if value != prev {
190+
prev = value;
191+
runs += 1;
192+
}
193+
}
168194
}
169195
}
170196
}
@@ -175,6 +201,11 @@ where
175201
let value_count = value_count
176202
.try_into()
177203
.vortex_expect("null_count must fit in u32");
204+
let distinct_values_count = if count_distinct_values {
205+
distinct_values.len().try_into().vortex_unwrap()
206+
} else {
207+
u32::MAX
208+
};
178209

179210
FloatStats {
180211
null_count,
@@ -191,6 +222,8 @@ where
191222

192223
#[cfg(test)]
193224
mod tests {
225+
use vortex_array::arrays::PrimitiveArray;
226+
use vortex_array::validity::Validity;
194227
use vortex_array::{IntoArray, ToCanonical};
195228
use vortex_buffer::buffer;
196229

@@ -209,4 +242,19 @@ mod tests {
209242
assert_eq!(stats.average_run_length, 1);
210243
assert_eq!(stats.distinct_values_count, 3);
211244
}
245+
246+
#[test]
247+
fn test_float_stats_leading_nulls() {
248+
let floats = PrimitiveArray::new(
249+
buffer![0.0f32, 1.0f32, 2.0f32],
250+
Validity::from_iter([false, true, true]),
251+
);
252+
253+
let stats = FloatStats::generate(&floats);
254+
255+
assert_eq!(stats.value_count, 2);
256+
assert_eq!(stats.null_count, 1);
257+
assert_eq!(stats.average_run_length, 1);
258+
assert_eq!(stats.distinct_values_count, 2);
259+
}
212260
}

vortex-btrblocks/src/integer/stats.rs

+67-37
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use vortex_array::variants::PrimitiveArrayTrait;
99
use vortex_array::{Array, ToCanonical};
1010
use vortex_dtype::{NativePType, match_each_integer_ptype};
1111
use vortex_error::{VortexExpect, VortexUnwrap};
12+
use vortex_mask::AllOr;
1213
use vortex_scalar::PValue;
1314

1415
use crate::sample::sample;
@@ -196,7 +197,12 @@ where
196197
let value_count = validity.true_count();
197198

198199
// Initialize loop state
199-
let head = array.as_slice::<T>()[0];
200+
let head_idx = validity
201+
.first()
202+
.vortex_expect("All null masks have been handled before");
203+
let buffer = array.buffer::<T>();
204+
let head = buffer[head_idx];
205+
200206
let mut loop_state = LoopState {
201207
min: head,
202208
max: head,
@@ -205,43 +211,55 @@ where
205211
} else {
206212
HashMap::with_hasher(FxBuildHasher)
207213
},
208-
distinct_values_count: if count_distinct_values { 0 } else { u32::MAX },
209214
prev: head,
210215
runs: 1,
211216
};
212217

213-
let values = array.buffer::<T>();
214-
let mask = validity.to_boolean_buffer();
215-
216-
let mut offset = 0;
217-
for chunk in values.as_slice().chunks(64) {
218-
let validity = mask.slice(offset, chunk.len());
219-
offset += chunk.len();
220-
221-
if chunk.len() < 64 {
222-
// Final iteration, run naive loop
223-
inner_loop_naive(chunk, count_distinct_values, &validity, &mut loop_state);
224-
break;
225-
}
226-
227-
let set_bits = validity.count_set_bits();
228-
229-
match set_bits {
230-
// All nulls -> no stats to update
231-
0 => continue,
232-
// Inner loop for when validity check can be elided
233-
64 => inner_loop_nonnull(
234-
chunk.try_into().vortex_unwrap(),
218+
let sliced = buffer.slice(head_idx..array.len());
219+
let mut chunks = sliced.as_slice().array_chunks::<64>();
220+
match validity.boolean_buffer() {
221+
AllOr::All => {
222+
for chunk in &mut chunks {
223+
inner_loop_nonnull(chunk, count_distinct_values, &mut loop_state)
224+
}
225+
let remainder = chunks.remainder();
226+
inner_loop_naive(
227+
remainder,
235228
count_distinct_values,
229+
&BooleanBuffer::new_set(remainder.len()),
236230
&mut loop_state,
237-
),
238-
// Inner loop for when we need to check validity
239-
_ => inner_loop_nullable(
240-
chunk.try_into().vortex_unwrap(),
231+
);
232+
}
233+
AllOr::None => unreachable!("All invalid arrays have been handled before"),
234+
AllOr::Some(v) => {
235+
let mask = v.slice(head_idx, array.len() - head_idx);
236+
let mut offset = 0;
237+
for chunk in &mut chunks {
238+
let validity = mask.slice(offset, 64);
239+
offset += 64;
240+
241+
match validity.count_set_bits() {
242+
// All nulls -> no stats to update
243+
0 => continue,
244+
// Inner loop for when validity check can be elided
245+
64 => inner_loop_nonnull(chunk, count_distinct_values, &mut loop_state),
246+
// Inner loop for when we need to check validity
247+
_ => inner_loop_nullable(
248+
chunk,
249+
count_distinct_values,
250+
&validity,
251+
&mut loop_state,
252+
),
253+
}
254+
}
255+
// Final iteration, run naive loop
256+
let remainder = chunks.remainder();
257+
inner_loop_naive(
258+
remainder,
241259
count_distinct_values,
242-
&validity,
260+
&mask.slice(offset, remainder.len()),
243261
&mut loop_state,
244-
),
262+
);
245263
}
246264
}
247265

@@ -257,7 +275,11 @@ where
257275
};
258276

259277
let runs = loop_state.runs;
260-
let distinct_values_count = loop_state.distinct_values_count;
278+
let distinct_values_count = if count_distinct_values {
279+
loop_state.distinct_values.len().try_into().vortex_unwrap()
280+
} else {
281+
u32::MAX
282+
};
261283

262284
let typed = TypedStats {
263285
min: loop_state.min,
@@ -289,7 +311,6 @@ struct LoopState<T> {
289311
max: T,
290312
prev: T,
291313
runs: u32,
292-
distinct_values_count: u32,
293314
distinct_values: HashMap<T, u32, FxBuildHasher>,
294315
}
295316

@@ -305,7 +326,6 @@ fn inner_loop_nonnull<T: PrimInt + Hash>(
305326

306327
if count_distinct_values {
307328
*state.distinct_values.entry(value).or_insert(0) += 1;
308-
state.distinct_values_count = state.distinct_values.len().try_into().vortex_unwrap();
309329
}
310330

311331
if value != state.prev {
@@ -329,8 +349,6 @@ fn inner_loop_nullable<T: PrimInt + Hash>(
329349

330350
if count_distinct_values {
331351
*state.distinct_values.entry(value).or_insert(0) += 1;
332-
state.distinct_values_count =
333-
state.distinct_values.len().try_into().vortex_unwrap();
334352
}
335353

336354
if value != state.prev {
@@ -355,8 +373,6 @@ fn inner_loop_naive<T: PrimInt + Hash>(
355373

356374
if count_distinct_values {
357375
*state.distinct_values.entry(value).or_insert(0) += 1;
358-
state.distinct_values_count =
359-
state.distinct_values.len().try_into().vortex_unwrap();
360376
}
361377

362378
if value != state.prev {
@@ -376,6 +392,8 @@ mod tests {
376392
use vortex_array::validity::Validity;
377393
use vortex_buffer::{Buffer, buffer};
378394

395+
use crate::CompressorStats;
396+
use crate::integer::IntegerStats;
379397
use crate::integer::stats::typed_int_stats;
380398

381399
#[test]
@@ -413,4 +431,16 @@ mod tests {
413431
let stats = typed_int_stats::<u8>(&array, true);
414432
assert_eq!(stats.distinct_values_count, 64);
415433
}
434+
435+
#[test]
436+
fn test_integer_stats_leading_nulls() {
437+
let ints = PrimitiveArray::new(buffer![0, 1, 2], Validity::from_iter([false, true, true]));
438+
439+
let stats = IntegerStats::generate(&ints);
440+
441+
assert_eq!(stats.value_count, 2);
442+
assert_eq!(stats.null_count, 1);
443+
assert_eq!(stats.average_run_length, 1);
444+
assert_eq!(stats.distinct_values_count, 2);
445+
}
416446
}

vortex-btrblocks/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#![feature(array_chunks)]
2+
13
use std::fmt::Debug;
24
use std::hash::Hash;
35

0 commit comments

Comments
 (0)