Skip to content

Commit a0ad376

Browse files
authored
[Minor] Refactor approx_percentile (#11769)
* Refactor approx_percentile * Refactor approx_percentile * Types * Types * Types
1 parent f044bc8 commit a0ad376

File tree

3 files changed

+41
-31
lines changed

3 files changed

+41
-31
lines changed

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl AggregateUDFImpl for ApproxMedian {
7878
Ok(vec![
7979
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
8080
Field::new(format_state_name(args.name, "sum"), Float64, false),
81-
Field::new(format_state_name(args.name, "count"), Float64, false),
81+
Field::new(format_state_name(args.name, "count"), UInt64, false),
8282
Field::new(format_state_name(args.name, "max"), Float64, false),
8383
Field::new(format_state_name(args.name, "min"), Float64, false),
8484
Field::new_list(

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
214214
),
215215
Field::new(
216216
format_state_name(args.name, "count"),
217-
DataType::Float64,
217+
DataType::UInt64,
218218
false,
219219
),
220220
Field::new(
@@ -406,7 +406,7 @@ impl Accumulator for ApproxPercentileAccumulator {
406406
}
407407

408408
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
409-
if self.digest.count() == 0.0 {
409+
if self.digest.count() == 0 {
410410
return ScalarValue::try_from(self.return_type.clone());
411411
}
412412
let q = self.digest.estimate_quantile(self.percentile);
@@ -487,8 +487,8 @@ mod tests {
487487
ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);
488488

489489
accumulator.merge_digests(&[t1]);
490-
assert_eq!(accumulator.digest.count(), 50_000.0);
490+
assert_eq!(accumulator.digest.count(), 50_000);
491491
accumulator.merge_digests(&[t2]);
492-
assert_eq!(accumulator.digest.count(), 100_000.0);
492+
assert_eq!(accumulator.digest.count(), 100_000);
493493
}
494494
}

datafusion/physical-expr-common/src/aggregate/tdigest.rs

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ macro_rules! cast_scalar_f64 {
4747
};
4848
}
4949

50+
// Cast a non-null [`ScalarValue::UInt64`] to a [`u64`], or
51+
// panic.
52+
macro_rules! cast_scalar_u64 {
53+
($value:expr ) => {
54+
match &$value {
55+
ScalarValue::UInt64(Some(v)) => *v,
56+
v => panic!("invalid type {:?}", v),
57+
}
58+
};
59+
}
60+
5061
/// This trait is implemented for each type a [`TDigest`] can operate on,
5162
/// allowing it to support both numerical rust types (obtained from
5263
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
@@ -142,7 +153,7 @@ pub struct TDigest {
142153
centroids: Vec<Centroid>,
143154
max_size: usize,
144155
sum: f64,
145-
count: f64,
156+
count: u64,
146157
max: f64,
147158
min: f64,
148159
}
@@ -153,7 +164,7 @@ impl TDigest {
153164
centroids: Vec::new(),
154165
max_size,
155166
sum: 0_f64,
156-
count: 0_f64,
167+
count: 0,
157168
max: f64::NAN,
158169
min: f64::NAN,
159170
}
@@ -164,14 +175,14 @@ impl TDigest {
164175
centroids: vec![centroid.clone()],
165176
max_size,
166177
sum: centroid.mean * centroid.weight,
167-
count: 1_f64,
178+
count: 1,
168179
max: centroid.mean,
169180
min: centroid.mean,
170181
}
171182
}
172183

173184
#[inline]
174-
pub fn count(&self) -> f64 {
185+
pub fn count(&self) -> u64 {
175186
self.count
176187
}
177188

@@ -203,16 +214,16 @@ impl Default for TDigest {
203214
centroids: Vec::new(),
204215
max_size: 100,
205216
sum: 0_f64,
206-
count: 0_f64,
217+
count: 0,
207218
max: f64::NAN,
208219
min: f64::NAN,
209220
}
210221
}
211222
}
212223

213224
impl TDigest {
214-
fn k_to_q(k: f64, d: f64) -> f64 {
215-
let k_div_d = k / d;
225+
fn k_to_q(k: u64, d: usize) -> f64 {
226+
let k_div_d = k as f64 / d as f64;
216227
if k_div_d >= 0.5 {
217228
let base = 1.0 - k_div_d;
218229
1.0 - 2.0 * base * base
@@ -244,12 +255,12 @@ impl TDigest {
244255
}
245256

246257
let mut result = TDigest::new(self.max_size());
247-
result.count = self.count() + (sorted_values.len() as f64);
258+
result.count = self.count() + sorted_values.len() as u64;
248259

249260
let maybe_min = *sorted_values.first().unwrap();
250261
let maybe_max = *sorted_values.last().unwrap();
251262

252-
if self.count() > 0.0 {
263+
if self.count() > 0 {
253264
result.min = self.min.min(maybe_min);
254265
result.max = self.max.max(maybe_max);
255266
} else {
@@ -259,10 +270,10 @@ impl TDigest {
259270

260271
let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
261272

262-
let mut k_limit: f64 = 1.0;
273+
let mut k_limit: u64 = 1;
263274
let mut q_limit_times_count =
264-
Self::k_to_q(k_limit, self.max_size as f64) * result.count();
265-
k_limit += 1.0;
275+
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
276+
k_limit += 1;
266277

267278
let mut iter_centroids = self.centroids.iter().peekable();
268279
let mut iter_sorted_values = sorted_values.iter().peekable();
@@ -309,8 +320,8 @@ impl TDigest {
309320

310321
compressed.push(curr.clone());
311322
q_limit_times_count =
312-
Self::k_to_q(k_limit, self.max_size as f64) * result.count();
313-
k_limit += 1.0;
323+
Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
324+
k_limit += 1;
314325
curr = next;
315326
}
316327
}
@@ -381,16 +392,16 @@ impl TDigest {
381392
let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
382393
let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
383394

384-
let mut count: f64 = 0.0;
395+
let mut count = 0;
385396
let mut min = f64::INFINITY;
386397
let mut max = f64::NEG_INFINITY;
387398

388399
let mut start: usize = 0;
389400
for digest in digests.iter() {
390401
starts.push(start);
391402

392-
let curr_count: f64 = digest.count();
393-
if curr_count > 0.0 {
403+
let curr_count = digest.count();
404+
if curr_count > 0 {
394405
min = min.min(digest.min);
395406
max = max.max(digest.max);
396407
count += curr_count;
@@ -424,8 +435,8 @@ impl TDigest {
424435
let mut result = TDigest::new(max_size);
425436
let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
426437

427-
let mut k_limit: f64 = 1.0;
428-
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count);
438+
let mut k_limit = 1;
439+
let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
429440

430441
let mut iter_centroids = centroids.iter_mut();
431442
let mut curr = iter_centroids.next().unwrap();
@@ -444,8 +455,8 @@ impl TDigest {
444455
sums_to_merge = 0_f64;
445456
weights_to_merge = 0_f64;
446457
compressed.push(curr.clone());
447-
q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count);
448-
k_limit += 1.0;
458+
q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64;
459+
k_limit += 1;
449460
curr = centroid;
450461
}
451462
}
@@ -468,8 +479,7 @@ impl TDigest {
468479
return 0.0;
469480
}
470481

471-
let count_ = self.count;
472-
let rank = q * count_;
482+
let rank = q * self.count as f64;
473483

474484
let mut pos: usize;
475485
let mut t;
@@ -479,7 +489,7 @@ impl TDigest {
479489
}
480490

481491
pos = 0;
482-
t = count_;
492+
t = self.count as f64;
483493

484494
for (k, centroid) in self.centroids.iter().enumerate().rev() {
485495
t -= centroid.weight();
@@ -581,7 +591,7 @@ impl TDigest {
581591
vec![
582592
ScalarValue::UInt64(Some(self.max_size as u64)),
583593
ScalarValue::Float64(Some(self.sum)),
584-
ScalarValue::Float64(Some(self.count)),
594+
ScalarValue::UInt64(Some(self.count)),
585595
ScalarValue::Float64(Some(self.max)),
586596
ScalarValue::Float64(Some(self.min)),
587597
ScalarValue::List(arr),
@@ -627,7 +637,7 @@ impl TDigest {
627637
Self {
628638
max_size,
629639
sum: cast_scalar_f64!(state[1]),
630-
count: cast_scalar_f64!(&state[2]),
640+
count: cast_scalar_u64!(&state[2]),
631641
max,
632642
min,
633643
centroids,

0 commit comments

Comments
 (0)