diff --git a/optd-cost-model/Cargo.lock b/optd-cost-model/Cargo.lock index 8ebe52f..f749199 100644 --- a/optd-cost-model/Cargo.lock +++ b/optd-cost-model/Cargo.lock @@ -274,7 +274,7 @@ dependencies = [ "arrow-schema 47.0.0", "chrono", "half", - "indexmap", + "indexmap 2.6.0", "lexical-core", "num", "serde", @@ -721,6 +721,47 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.11" @@ -793,6 +834,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", + "strsim", "syn 2.0.87", ] @@ -1368,6 +1410,17 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.6.0" @@ -1376,6 +1429,7 @@ checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.1", + "serde", ] [[package]] @@ -1404,6 +1458,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1733,10 +1796,16 @@ version = "0.1.0" dependencies = [ "arrow-schema 53.2.0", "chrono", + "crossbeam", "datafusion-expr", + "itertools 0.13.0", + "lazy_static", "optd-persistent", "ordered-float 4.5.0", + "rand", "serde", + "serde_json", + "serde_with", ] [[package]] @@ -1789,7 +1858,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39b0deead1528fd0e5947a8546a9642a9777c25f6e1e26f34c97b204bbb465bd" dependencies = [ "heck 0.4.1", - "itertools", + "itertools 0.12.1", "proc-macro2", "proc-macro2-diagnostics", "quote", @@ -2514,6 +2583,36 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" +dependencies = [ + "base64", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.6.0", + "serde", + "serde_derive", + 
"serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "sha1" version = "0.10.6" @@ -2687,7 +2786,7 @@ dependencies = [ "hashbrown 0.14.5", "hashlink", "hex", - "indexmap", + "indexmap 2.6.0", "log", "memchr", "once_cell", @@ -3140,7 +3239,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap", + "indexmap 2.6.0", "toml_datetime", "winnow", ] diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index cb3285f..d37c72a 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -6,8 +6,15 @@ edition = "2021" [dependencies] optd-persistent = { path = "../optd-persistent", version = "0.1" } serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +serde_with = { version = "3.7.0", features = ["json"] } arrow-schema = "53.2.0" datafusion-expr = "32.0.0" ordered-float = "4.0" chrono = "0.4" +itertools = "0.13" +lazy_static = "1.5" +[dev-dependencies] +crossbeam = "0.8" +rand = "0.8" diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index 6aeb476..f081d4c 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -7,6 +7,7 @@ use optd_persistent::cost_model::interface::{Stat, StatType}; pub mod common; pub mod cost; pub mod cost_model; +pub mod stats; pub mod storage; pub enum StatValue { diff --git a/optd-cost-model/src/stats/arith_encoder.rs b/optd-cost-model/src/stats/arith_encoder.rs new file mode 100644 index 0000000..4285939 --- /dev/null +++ b/optd-cost-model/src/stats/arith_encoder.rs @@ -0,0 +1,74 @@ +//! This module provides an encoder that converts alpha-numeric strings +//! into f64 values, designed to maintain the natural ordering of strings. +//! +//! While the encoding is theoretically lossless, in practice, it may suffer +//! from precision loss due to floating-point errors. +//! +//! Non-alpha-numeric characters are relegated to the end of the encoded value, +//! rendering them indistinguishable from one another in this context. + +use std::collections::HashMap; + +// TODO: Use lazy cell instead of lazy static. +use lazy_static::lazy_static; + +// The alphanumerical ordering. +const ALPHANUMERIC_ORDER: [char; 95] = [ + ' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', + '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', '0', '1', '2', '3', '4', + '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', + 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', + 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', +]; + +const PMF: f64 = 1.0 / (ALPHANUMERIC_ORDER.len() as f64); + +lazy_static! { + static ref CDF: HashMap = { + let length = ALPHANUMERIC_ORDER.len() + 1; // To account for non-alpha-numeric characters. 
+        let mut cdf = HashMap::with_capacity(length);
+        for (index, &char) in ALPHANUMERIC_ORDER.iter().enumerate() {
+            cdf.insert(char, (index as f64) / (length as f64));
+        }
+        cdf
+    };
+}
+
+pub fn encode(string: &str) -> f64 {
+    let mut left = 0.0;
+    // 10_000.0 is fairly arbitrary; don't use f64::MAX, though, because it causes
+    // overflow in other places of the code.
+    let mut right = 10_000.0;
+
+    for char in string.chars() {
+        let cdf = CDF.get(&char).unwrap_or(&1.0);
+        let distance = right - left;
+        right = left + distance * (cdf + PMF);
+        left += distance * cdf;
+    }
+
+    left
+}
+
+// Start of unit testing section.
+#[cfg(test)]
+mod tests {
+    use super::encode;
+
+    #[test]
+    fn encode_tests() {
+        assert!(encode("") < encode("abc"));
+        assert!(encode("abc") < encode("bcd"));
+
+        assert!(encode("a") < encode("aaa"));
+        assert!(encode("!a") < encode("a!"));
+        assert!(encode("Alexis") < encode("Schlomer"));
+
+        assert!(encode("Gungnir Rules!") < encode("Schlomer"));
+
+        assert_eq!(encode(" "), encode(" "));
+        assert_eq!(encode("Same"), encode("Same"));
+        assert!(encode("Nicolas ") < encode("Nicolas💰💼"));
+    }
+}
diff --git a/optd-cost-model/src/stats/counter.rs b/optd-cost-model/src/stats/counter.rs
new file mode 100644
index 0000000..baa32ab
--- /dev/null
+++ b/optd-cost-model/src/stats/counter.rs
@@ -0,0 +1,196 @@
+use std::collections::HashMap;
+use std::hash::Hash;
+
+use serde::de::DeserializeOwned;
+use serde::{Deserialize, Serialize};
+
+/// The Counter structure to track exact frequencies of fixed elements.
+#[serde_with::serde_as]
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Counter<T: PartialEq + Eq + Hash + Clone + Serialize + DeserializeOwned> {
+    #[serde_as(as = "HashMap<serde_with::json::JsonString, _>")]
+    counts: HashMap<T, i32>, // The exact counts of an element T.
+    total_count: i32,        // The total number of elements.
+}
+
+// Self-contained implementation of the Counter data structure.
+impl<T> Counter<T>
+where
+    T: PartialEq + Eq + Hash + Clone + Serialize + DeserializeOwned,
+{
+    /// Creates and initializes a new empty Counter with the frequency map sized
+    /// based on the number of unique elements in `to_track`.
+    pub fn new(to_track: &[T]) -> Self {
+        let mut counts: HashMap<T, i32> = HashMap::with_capacity(to_track.len());
+        for item in to_track {
+            counts.insert(item.clone(), 0);
+        }
+
+        Counter::<T> {
+            counts,
+            total_count: 0,
+        }
+    }
+
+    /// Inserts an element in the Counter if it is being tracked.
+    pub fn insert_element(&mut self, elem: T, occ: i32) {
+        if let Some(frequency) = self.counts.get_mut(&elem) {
+            *frequency += occ;
+        }
+    }
+
+    /// Digests an array of data into the Counter structure.
+    pub fn aggregate(&mut self, data: &[T]) {
+        data.iter()
+            .for_each(|key| self.insert_element(key.clone(), 1));
+        self.total_count += data.len() as i32;
+    }
+
+    /// Merges another Counter into the current one.
+    /// Particularly useful for parallel execution.
+    pub fn merge(&mut self, other: &Counter<T>) {
+        other
+            .counts
+            .iter()
+            .for_each(|(key, occ)| self.insert_element(key.clone(), *occ));
+        self.total_count += other.total_count;
+    }
+
+    /// Returns the frequencies of the most common values.
+    pub fn frequencies(&self) -> HashMap<T, f64> {
+        self.counts
+            .iter()
+            .map(|(key, &value)| (key.clone(), value as f64 / self.total_count as f64))
+            .collect()
+    }
+
+    /// Whether the counter tracks the given key.
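+    ///
+    /// e.g., a Counter built with `new(&[1, 2, 3])` tracks 1 but not 4; after
+    /// `aggregate(&[1, 1, 4])`, the reported frequency of 1 is 2/3, while the
+    /// occurrence of 4 is only reflected in the total count.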
+    pub fn is_tracking(&self, key: &T) -> bool {
+        self.counts.contains_key(key)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::sync::{Arc, Mutex};
+
+    use crossbeam::thread;
+    use rand::rngs::StdRng;
+    use rand::seq::SliceRandom;
+    use rand::SeedableRng;
+
+    use super::Counter;
+
+    // Generates hardcoded frequencies and returns them,
+    // along with a flattened randomized array containing those frequencies.
+    fn generate_frequencies() -> (HashMap<i32, i32>, Vec<i32>) {
+        let mut frequencies = HashMap::new();
+
+        frequencies.insert(0, 2);
+        frequencies.insert(1, 4);
+        frequencies.insert(2, 9);
+        frequencies.insert(3, 8);
+        frequencies.insert(4, 50);
+        frequencies.insert(5, 6);
+
+        let mut flattened = Vec::new();
+        for (key, &value) in &frequencies {
+            for _ in 0..value {
+                flattened.push(*key);
+            }
+        }
+
+        let mut rng = StdRng::seed_from_u64(0);
+        flattened.shuffle(&mut rng);
+
+        (frequencies, flattened)
+    }
+
+    #[test]
+    fn aggregate() {
+        let to_track = vec![0, 1, 2, 3];
+        let mut mcv = Counter::<i32>::new(&to_track);
+
+        let (frequencies, flattened) = generate_frequencies();
+
+        mcv.aggregate(&flattened);
+
+        let mcv_freq = mcv.frequencies();
+        assert_eq!(mcv_freq.len(), to_track.len());
+
+        to_track.iter().for_each(|item| {
+            assert!(mcv_freq.contains_key(item));
+            assert_eq!(
+                mcv_freq.get(item),
+                frequencies
+                    .get(item)
+                    .map(|e| (*e as f64 / flattened.len() as f64))
+                    .as_ref()
+            );
+        });
+    }
+
+    #[test]
+    fn merge() {
+        let to_track = vec![0, 1, 2, 3];
+        let n_jobs = 16;
+
+        let total_frequencies = Arc::new(Mutex::new(HashMap::<i32, i32>::new()));
+        let total_count = Arc::new(Mutex::new(0));
+        let result_mcv = Arc::new(Mutex::new(Counter::<i32>::new(&to_track)));
+        thread::scope(|s| {
+            for _ in 0..n_jobs {
+                s.spawn(|_| {
+                    let mut local_mcv = Counter::<i32>::new(&to_track);
+
+                    let (local_frequencies, flattened) = generate_frequencies();
+                    let mut total_frequencies = total_frequencies.lock().unwrap();
+                    let mut total_count = total_count.lock().unwrap();
+                    for (&key, &value) in &local_frequencies {
+                        *total_frequencies.entry(key).or_insert(0) += value;
+                        *total_count += value;
+                    }
+
+                    local_mcv.aggregate(&flattened);
+
+                    let mcv_local_freq = local_mcv.frequencies();
+                    assert_eq!(mcv_local_freq.len(), to_track.len());
+
+                    to_track.iter().for_each(|item| {
+                        assert!(mcv_local_freq.contains_key(item));
+                        assert_eq!(
+                            mcv_local_freq.get(item),
+                            local_frequencies
+                                .get(item)
+                                .map(|e| (*e as f64 / flattened.len() as f64))
+                                .as_ref()
+                        );
+                    });
+
+                    let mut result = result_mcv.lock().unwrap();
+                    result.merge(&local_mcv);
+                });
+            }
+        })
+        .unwrap();
+
+        let mcv = result_mcv.lock().unwrap();
+        let total_count = total_count.lock().unwrap();
+        let mcv_freq = mcv.frequencies();
+
+        assert_eq!(*total_count, mcv.total_count);
+        to_track.iter().for_each(|item| {
+            assert!(mcv_freq.contains_key(item));
+            assert_eq!(
+                mcv_freq.get(item),
+                total_frequencies
+                    .lock()
+                    .unwrap()
+                    .get(item)
+                    .map(|e| (*e as f64 / *total_count as f64))
+                    .as_ref()
+            );
+        });
+    }
+}
diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs
new file mode 100644
index 0000000..287b20a
--- /dev/null
+++ b/optd-cost-model/src/stats/mod.rs
@@ -0,0 +1,123 @@
+#![allow(unused)]
+
+mod arith_encoder;
+pub mod counter;
+pub mod tdigest;
+
+use crate::common::values::Value;
+use counter::Counter;
+use serde::{Deserialize, Serialize};
+
+// Default n-distinct estimate for derived columns or columns lacking statistics.
+pub const DEFAULT_NUM_DISTINCT: u64 = 200;
+// A placeholder for unimplemented!() for code paths exercised by plannertest.
+pub const UNIMPLEMENTED_SEL: f64 = 0.01;
+// Default statistics. All are from selfuncs.h in Postgres unless specified otherwise.
+// Default selectivity estimate for equalities such as "A = b".
+pub const DEFAULT_EQ_SEL: f64 = 0.005;
+// Default selectivity estimate for inequalities such as "A < b".
+pub const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333;
+// Used for estimating pattern selectivity character-by-character. These numbers
+// are not used on their own. Depending on the characters in the pattern, the
+// selectivity is multiplied by these factors.
+//
+// See `FULL_WILDCARD_SEL` and `FIXED_CHAR_SEL` in Postgres.
+pub const FULL_WILDCARD_SEL_FACTOR: f64 = 5.0;
+pub const FIXED_CHAR_SEL_FACTOR: f64 = 0.2;
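+// Worked example (illustrative only; assumes the factors multiply a base
+// selectivity as in Postgres' patternsel): for a pattern like "%ab%", the two
+// '%' wildcards contribute FULL_WILDCARD_SEL_FACTOR each and the two fixed
+// characters FIXED_CHAR_SEL_FACTOR each, i.e. a combined multiplier of
+// 5.0 * 5.0 * 0.2 * 0.2 = 1.0 on the base selectivity.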
+
+pub type AttributeCombValue = Vec<Option<Value>>;
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(tag = "type")]
+pub enum MostCommonValues {
+    Counter(Counter<AttributeCombValue>),
+    // Add more types here...
+}
+
+impl MostCommonValues {
+    // We could expose only freq_over_pred() and implement freq() and total_freq()
+    // in terms of it. However, freq() and total_freq() each have potential
+    // optimizations (freq() is O(1) instead of O(n), and total_freq() can be
+    // cached), and it makes sense for freq() to return an Option instead of 0
+    // when the value does not exist. Hence the three separate methods.
+    pub fn freq(&self, value: &AttributeCombValue) -> Option<f64> {
+        match self {
+            MostCommonValues::Counter(counter) => counter.frequencies().get(value).copied(),
+        }
+    }
+
+    pub fn total_freq(&self) -> f64 {
+        match self {
+            MostCommonValues::Counter(counter) => counter.frequencies().values().sum(),
+        }
+    }
+
+    pub fn freq_over_pred(&self, pred: Box<dyn Fn(&AttributeCombValue) -> bool>) -> f64 {
+        match self {
+            MostCommonValues::Counter(counter) => counter
+                .frequencies()
+                .iter()
+                .filter(|(val, _)| pred(val))
+                .map(|(_, freq)| freq)
+                .sum(),
+        }
+    }
+
+    // Returns the number of entries (i.e., value-frequency pairs) in the most
+    // common values structure.
+    pub fn cnt(&self) -> usize {
+        match self {
+            MostCommonValues::Counter(counter) => counter.frequencies().len(),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub enum Distribution {
+    TDigest(tdigest::TDigest<Value>),
+    // Add more types here...
+}
+
+impl Distribution {
+    pub fn cdf(&self, value: &Value) -> f64 {
+        match self {
+            Distribution::TDigest(tdigest) => {
+                let nb_rows = tdigest.norm_weight;
+                if nb_rows == 0 {
+                    tdigest.cdf(value)
+                } else {
+                    tdigest.centroids.len() as f64 * tdigest.cdf(value) / nb_rows as f64
+                }
+            }
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct AttributeCombValueStats {
+    pub mcvs: MostCommonValues,      // Does NOT contain full nulls.
+    pub distr: Option<Distribution>, // Does NOT contain mcvs; optional.
+    pub ndistinct: u64,              // Does NOT contain full nulls.
+    pub null_frac: f64,              // % of full nulls.
+}
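+
+// Hedged sketch of how these fields are meant to combine for an equality
+// predicate "attr = v" (assumed usage, in the style of Postgres' var_eq_const;
+// not code in this module): if `v` is a tracked MCV, the selectivity is
+// `mcvs.freq(&comb)`; otherwise the leftover probability mass is spread over
+// the remaining distinct values:
+//   (1.0 - mcvs.total_freq() - null_frac) / ((ndistinct - mcvs.cnt() as u64) as f64)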
+
+impl AttributeCombValueStats {
+    pub fn new(
+        mcvs: MostCommonValues,
+        ndistinct: u64,
+        null_frac: f64,
+        distr: Option<Distribution>,
+    ) -> Self {
+        Self {
+            mcvs,
+            ndistinct,
+            null_frac,
+            distr,
+        }
+    }
+}
+
+impl From<serde_json::Value> for AttributeCombValueStats {
+    fn from(value: serde_json::Value) -> Self {
+        serde_json::from_value(value).unwrap()
+    }
+}
diff --git a/optd-cost-model/src/stats/tdigest.rs b/optd-cost-model/src/stats/tdigest.rs
new file mode 100644
index 0000000..83dc9b5
--- /dev/null
+++ b/optd-cost-model/src/stats/tdigest.rs
@@ -0,0 +1,395 @@
+// Copyright (c) 2023-2024 CMU Database Group
+//
+// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
+// https://opensource.org/licenses/MIT.
+
+//! Simplified implementation of the TDigest data structure as described in
+//! Ted Dunning's paper:
+//! "Computing Extremely Accurate Quantiles Using t-Digests" (2019).
+//! For more details, refer to: https://arxiv.org/pdf/1902.04023.pdf
+
+use std::f64::consts::PI;
+use std::hash::Hash;
+use std::marker::PhantomData;
+
+use itertools::Itertools;
+use serde::{Deserialize, Serialize};
+
+use crate::common::values::Value;
+
+use super::arith_encoder;
+
+pub const DEFAULT_COMPRESSION: f64 = 200.0;
+
+/// Trait to transform any object into an f64 value.
+pub trait IntoFloat {
+    fn to_float(&self) -> f64;
+}
+
+/// The TDigest structure for the statistical aggregator to query quantiles.
+#[derive(Clone, Serialize, Deserialize, Debug)]
+pub struct TDigest<T> {
+    /// A sorted array of Centroids, according to their mean.
+    pub centroids: Vec<Centroid>, /* TODO(Alexis): Temporary fix to normalize the stats in
+                                   * stats.rs [pub]. */
+    /// Compression factor: higher is more precise, but has higher memory requirements.
+    compression: f64,
+    /// Number of values in the TDigest (sum of all centroids).
+    total_weight: usize,
+
+    // TODO(Alexis): Temporary fix to normalize the stats in stats.rs [field].
+    pub norm_weight: usize,
+
+    data_type: PhantomData<T>, // For type checker.
+}
+
+/// A Centroid is a cluster of aggregated data points.
+#[derive(PartialEq, PartialOrd, Clone, Serialize, Deserialize, Debug)]
+pub struct Centroid {
+    // TODO(Alexis): Temporary fix to normalize the stats in stats.rs [pub].
+    /// Mean of all aggregated points in this cluster.
+    mean: f64,
+    /// The number of points in this cluster.
+    weight: usize,
+}
+
+// Utility functions defined on a Centroid.
+impl Centroid {
+    // Merges an existing Centroid into itself.
+    fn merge(&mut self, other: &Centroid) {
+        let weight = self.weight + other.weight;
+        self.mean =
+            ((self.mean * self.weight as f64) + (other.mean * other.weight as f64)) / weight as f64;
+        self.weight = weight;
+    }
+}
+
+// IntoFloat implementation of optd's Value.
+impl IntoFloat for Value {
+    fn to_float(&self) -> f64 {
+        match self {
+            Value::UInt8(v) => *v as f64,
+            Value::UInt16(v) => *v as f64,
+            Value::UInt32(v) => *v as f64,
+            Value::UInt64(v) => *v as f64,
+            Value::Int8(v) => *v as f64,
+            Value::Int16(v) => *v as f64,
+            Value::Int32(v) => *v as f64,
+            Value::Int64(v) => *v as f64,
+            Value::Float(v) => *v.0,
+            Value::Bool(v) => *v as i64 as f64,
+            Value::String(v) => arith_encoder::encode(v),
+            Value::Date32(v) => *v as f64,
+            _ => unreachable!(),
+        }
+    }
+}
+
+// Self-contained implementation of the TDigest data structure.
+impl<T> TDigest<T>
+where
+    T: IntoFloat + Eq + Hash + Clone,
+{
+    /// Creates and initializes a new empty TDigest.
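+    ///
+    /// Per Dunning's paper, the number of retained centroids stays on the order
+    /// of the compression factor, so e.g. `TDigest::new(DEFAULT_COMPRESSION)`
+    /// summarizes arbitrarily many values with at most a few hundred centroids.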
+    pub fn new(compression: f64) -> Self {
+        TDigest {
+            centroids: Vec::new(),
+            compression,
+            total_weight: 0,
+
+            norm_weight: 0,
+            data_type: PhantomData,
+        }
+    }
+
+    /// Ingests an array of values into the TDigest; their f64 encodings must not be NaN.
+    pub fn merge_values(&mut self, values: &[T]) {
+        let centroids = values
+            .iter()
+            .map(|val| val.to_float())
+            .sorted_by(|a, b| a.partial_cmp(b).unwrap())
+            .map(|v| Centroid { mean: v, weight: 1 })
+            .collect_vec();
+        let compression = self.compression;
+        let total_weight = centroids.len();
+
+        // Create an ephemeral TDigest to reuse the same interface.
+        self.merge(&TDigest {
+            centroids,
+            compression,
+            total_weight,
+
+            norm_weight: 0,
+            data_type: PhantomData,
+        });
+    }
+
+    /// Merges another TDigest into the current one, in place.
+    /// Particularly useful for parallel execution.
+    /// Note: `norm_weight` is *NOT* updated.
+    pub fn merge(&mut self, other: &TDigest<T>) {
+        let mut sorted_centroids = self.centroids.iter().merge(other.centroids.iter());
+
+        let mut new_centroids = Vec::new();
+        let total_weight = self.total_weight + other.total_weight;
+
+        // Initialize the greedy merging (copy first Centroid as a starting point).
+        let mut q_curr = 0.0;
+        let mut q_limit = self.k_rev_scale(self.k_scale(q_curr) + 1.0);
+
+        let mut tmp_centroid = match sorted_centroids.next() {
+            Some(centroid) => centroid.clone(),
+            None => {
+                return;
+            }
+        };
+
+        // Iterate over ordered and merged Centroids (starting from index 1).
+        for centroid in sorted_centroids {
+            let q_new = (tmp_centroid.weight + centroid.weight) as f64 / total_weight as f64;
+            if (q_curr + q_new) <= q_limit {
+                tmp_centroid.merge(centroid);
+            } else {
+                q_curr += tmp_centroid.weight as f64 / total_weight as f64;
+                q_limit = self.k_rev_scale(self.k_scale(q_curr) + 1.0);
+                new_centroids.push(tmp_centroid);
+                tmp_centroid = centroid.clone();
+            }
+        }
+        new_centroids.push(tmp_centroid);
+
+        self.centroids = new_centroids;
+        self.total_weight += other.total_weight;
+    }
+
+    /// Obtains a given quantile from the TDigest.
+    /// Returns 0.0 if the TDigest is empty.
+    /// Performs a linear interpolation between two neighboring Centroids if needed.
+    /// Note: the result is *not* normalized by `norm_weight`.
+    pub fn quantile(&self, q: f64) -> f64 {
+        let target_cum = q * (self.total_weight as f64);
+        let pos_cum = self // Finds the centroid whose *cumulative weight* exceeds or equals the quantile.
+            .centroids
+            .iter()
+            .map(|c| c.weight)
+            .scan(0, |acc, weight| {
+                *acc += weight;
+                Some(*acc)
+            })
+            .enumerate()
+            .find(|&(_, cum)| target_cum < (cum as f64));
+
+        match pos_cum {
+            Some((pos, cum)) => {
+                // TODO: We ignore edge-cases where Centroid's weights are 1.
+                if (pos == 0) || (pos == self.centroids.len() - 1) {
+                    self.centroids[pos].mean
+                } else {
+                    // The quantile lies between the (prev+curr)/2 and (curr+next)/2 means.
+                    let (prev, curr, next) = (
+                        &self.centroids[pos - 1],
+                        &self.centroids[pos],
+                        &self.centroids[pos + 1],
+                    );
+                    let (min_q, max_q) =
+                        ((prev.mean + curr.mean) / 2.0, (curr.mean + next.mean) / 2.0);
+                    lerp(
+                        min_q,
+                        max_q,
+                        ((cum as f64) - target_cum) / (curr.weight as f64),
+                    )
+                }
+            }
+            None => self.centroids.last().map(|c| c.mean).unwrap_or(0.0),
+        }
+    }
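+
+    // Note: `quantile` and `cdf` below are approximate inverses of one another:
+    // for a digest over Uniform(0, 1), `quantile(0.25)` and the `cdf` of 0.25
+    // both land near 0.25, with accuracy improving toward the extreme quantiles
+    // under the k1 scale function (see `k_scale` below).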
+
+    /// Obtains the CDF corresponding to a given value.
+    /// Returns 0.0 if the TDigest is empty.
+    /// Note: normalization by `norm_weight` is performed by the caller
+    /// (see `Distribution::cdf` in mod.rs), not here.
+    pub fn cdf(&self, v: &T) -> f64 {
+        let mut cum_sum = 0;
+        let pos_cum = self // Finds the centroid whose *mean* exceeds or equals the given value.
+            .centroids
+            .iter()
+            .enumerate()
+            .find(|(_, c)| {
+                cum_sum += c.weight; // Accumulates cum_sum as a side effect.
+                v.to_float() < c.mean
+            })
+            .map(|(pos, _)| (pos, cum_sum));
+
+        let nb_total = self.total_weight as f64;
+        match pos_cum {
+            Some((_pos, cum)) => {
+                // TODO: Could do better with 2 lerps, left as future work.
+                // TODO: We ignore edge-cases where Centroid's weights are 1.
+                (cum as f64) / nb_total
+            }
+            None => self.centroids.last().map(|_| 1.0).unwrap_or(0.0),
+        }
+    }
+
+    // Obtains the k-distance for a given quantile.
+    // Note: The scaling function implemented is k1 in Ted Dunning's paper.
+    fn k_scale(&self, quantile: f64) -> f64 {
+        (self.compression / (2.0 * PI)) * (2.0 * quantile - 1.0).asin()
+    }
+
+    // Obtains the quantile associated to a k-distance.
+    // There are probably numerical optimizations to flatten the nested
+    // k_scale(k_rev_scale()) calls, but let's keep it simple.
+    fn k_rev_scale(&self, k_distance: f64) -> f64 {
+        ((2.0 * PI * k_distance / self.compression).sin() + 1.0) / 2.0
+    }
+}
+
+// Performs the linear interpolation between a and b, given a fraction f.
+fn lerp(a: f64, b: f64, f: f64) -> f64 {
+    (a * (1.0 - f)) + (b * f)
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::{Arc, Mutex};
+
+    use crossbeam::thread;
+    use ordered_float::OrderedFloat;
+    use rand::distributions::{Distribution, Uniform, WeightedIndex};
+    use rand::rngs::StdRng;
+    use rand::SeedableRng;
+
+    use super::{IntoFloat, TDigest};
+
+    impl IntoFloat for OrderedFloat<f64> {
+        fn to_float(&self) -> f64 {
+            self.0
+        }
+    }
+
+    // Whether obtained = expected +/- error.
+    fn is_close(obtained: f64, expected: f64, error: f64) -> bool {
+        ((expected - error) < obtained) && (obtained < (expected + error))
+    }
+
+    // Checks whether the TDigest follows a uniform distribution.
+    fn check_tdigest_uniform(
+        tdigest: &TDigest<OrderedFloat<f64>>,
+        buckets: i32,
+        max: f64,
+        min: f64,
+        error: f64,
+    ) {
+        for k in 0..buckets {
+            let expected_cdf = (k as f64) / (buckets as f64);
+            let expected_quantile = (max - min) * expected_cdf + min;
+
+            let obtained_cdf = tdigest.cdf(&OrderedFloat(expected_quantile));
+            let obtained_quantile = tdigest.quantile(expected_cdf);
+
+            assert!(is_close(obtained_cdf, expected_cdf, error));
+            assert!(is_close(
+                obtained_quantile,
+                expected_quantile,
+                (max - min) * error,
+            ));
+        }
+    }
+
+    #[test]
+    fn uniform_merge_sequential() {
+        let buckets = 200;
+        let error = 0.03; // 3% absolute error on each quantile; error gets worse near the median.
+        let mut tdigest = TDigest::new(buckets as f64);
+
+        let (min, max) = (-1000.0, 1000.0);
+        let uniform_distr = Uniform::new(min, max);
+        let mut rng = StdRng::seed_from_u64(0);
+
+        let batch_size = 1024;
+        let batch_numbers = 64;
+
+        for _ in 0..batch_numbers {
+            let mut random_numbers = Vec::with_capacity(batch_size);
+            for _ in 0..batch_size {
+                let num: f64 = uniform_distr.sample(&mut rng);
+                random_numbers.push(OrderedFloat(num));
+            }
+            tdigest.merge_values(&random_numbers);
+        }
+
+        check_tdigest_uniform(&tdigest, buckets, max, min, error);
+    }
+
+    #[test]
+    fn uniform_merge_parallel() {
+        let buckets = 200;
+        let error = 0.03; // 3% absolute error on each quantile; error gets worse near the median.
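+        // Each worker below builds a local digest over its own sample and merges
+        // it into `result_tdigest` under a mutex. Merge order varies across runs,
+        // but the tolerance above absorbs t-digest's slight merge-order sensitivity.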
+
+        let (min, max) = (-1000.0, 1000.0);
+
+        let batch_size = 65536;
+        let batch_numbers = 64;
+
+        let result_tdigest = Arc::new(Mutex::new(TDigest::new(buckets as f64)));
+        thread::scope(|s| {
+            for _ in 0..batch_numbers {
+                s.spawn(|_| {
+                    let mut local_tdigest = TDigest::new(buckets as f64);
+
+                    let mut random_numbers = Vec::with_capacity(batch_size);
+                    let uniform_distr = Uniform::new(min, max);
+                    let mut rng = StdRng::seed_from_u64(0);
+
+                    for _ in 0..batch_size {
+                        let num: f64 = uniform_distr.sample(&mut rng);
+                        random_numbers.push(OrderedFloat(num));
+                    }
+                    local_tdigest.merge_values(&random_numbers);
+
+                    let mut result = result_tdigest.lock().unwrap();
+                    result.merge(&local_tdigest);
+                });
+            }
+        })
+        .unwrap();
+
+        let tdigest = result_tdigest.lock().unwrap();
+        check_tdigest_uniform(&tdigest, buckets, max, min, error);
+    }
+
+    #[test]
+    fn weighted_merge() {
+        let buckets = 200;
+        let error = 0.05; // 5% absolute error on each quantile; error gets worse near the median.
+
+        let mut tdigest = TDigest::new(buckets as f64);
+
+        let choices = [9.0, 900.0, 990.0, 9990.0, 190000.0, 990000.0];
+        let weights = [1, 2, 1, 3, 4, 5]; // Total of 16.
+        let total_weight: i32 = weights.iter().sum();
+
+        let weighted_distr = WeightedIndex::new(weights).unwrap();
+        let mut rng = StdRng::seed_from_u64(0);
+
+        let batch_size = 128;
+        let batch_numbers = 16;
+
+        for _ in 0..batch_numbers {
+            let mut random_numbers = Vec::with_capacity(batch_size);
+            for _ in 0..batch_size {
+                let num: f64 = choices[weighted_distr.sample(&mut rng)];
+                random_numbers.push(OrderedFloat(num));
+            }
+            tdigest.merge_values(&random_numbers);
+        }
+
+        let mut curr_weight = 0;
+        for (c, w) in choices.iter().zip(weights) {
+            curr_weight += w;
+            let obtained_cdf = tdigest.cdf(&OrderedFloat(*c));
+            let expected_cdf = (curr_weight as f64) / (total_weight as f64);
+            assert!(is_close(obtained_cdf, expected_cdf, error));
+        }
+    }
+}