Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: impl quantile_over_time function #1287

Merged
merged 5 commits into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/common/function-macro/src/range_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ macro_rules! ok {
}

pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenStream {
let mut result = TokenStream::new();

// extract arg map
let arg_pairs = parse_macro_input!(args as AttributeArgs);
let arg_span = arg_pairs[0].span();
Expand All @@ -59,12 +61,17 @@ pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenSt
let arg_types = ok!(extract_input_types(inputs));

// build the struct and its impl block
let struct_code = build_struct(
attrs,
vis,
ok!(get_ident(&arg_map, "name", arg_span)),
ok!(get_ident(&arg_map, "display_name", arg_span)),
);
// only do this when `display_name` is specified
if let Ok(display_name) = get_ident(&arg_map, "display_name", arg_span) {
let struct_code = build_struct(
attrs,
vis,
ok!(get_ident(&arg_map, "name", arg_span)),
display_name,
);
result.extend(struct_code);
}

let calc_fn_code = build_calc_fn(
ok!(get_ident(&arg_map, "name", arg_span)),
arg_types,
Expand All @@ -77,8 +84,6 @@ pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenSt
}
.into();

let mut result = TokenStream::new();
result.extend(struct_code);
result.extend(calc_fn_code);
result.extend(input_fn_code);
result
Expand Down
11 changes: 2 additions & 9 deletions src/promql/src/extension_plan/empty_metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use datafusion::physical_plan::{
DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
};
use datafusion::prelude::Expr;
use datafusion::sql::TableReference;
use datatypes::arrow::array::TimestampMillisecondArray;
use datatypes::arrow::datatypes::SchemaRef;
use datatypes::arrow::record_batch::RecordBatch;
Expand All @@ -57,17 +56,12 @@ impl EmptyMetric {
let schema = Arc::new(DFSchema::new_with_metadata(
vec![
DFField::new(
None::<TableReference>,
Some(""),
&time_index_column_name,
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
DFField::new(
None::<TableReference>,
&value_column_name,
DataType::Float64,
true,
),
DFField::new(Some(""), &value_column_name, DataType::Float64, true),
],
HashMap::new(),
)?);
Expand All @@ -81,7 +75,6 @@ impl EmptyMetric {
}

pub fn to_execution_plan(&self) -> Arc<dyn ExecutionPlan> {
// let schema = self.schema.to
Arc::new(EmptyMetricExec {
start: self.start,
end: self.end,
Expand Down
2 changes: 2 additions & 0 deletions src/promql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod changes;
mod deriv;
mod extrapolate_rate;
mod idelta;
mod quantile;
mod resets;
#[cfg(test)]
mod test_util;
Expand All @@ -30,6 +31,7 @@ use datafusion::error::DataFusionError;
use datafusion::physical_plan::ColumnarValue;
pub use extrapolate_rate::{Delta, Increase, Rate};
pub use idelta::IDelta;
pub use quantile::QuantileOverTime;

pub(crate) fn extract_array(columnar_value: &ColumnarValue) -> Result<ArrayRef, DataFusionError> {
if let ColumnarValue::Array(array) = columnar_value {
Expand Down
210 changes: 210 additions & 0 deletions src/promql/src/functions/quantile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use datafusion::arrow::array::Float64Array;
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::{ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::physical_plan::ColumnarValue;
use datatypes::arrow::array::Array;
use datatypes::arrow::datatypes::DataType;

use crate::error;
use crate::functions::extract_array;
use crate::range_array::RangeArray;

pub struct QuantileOverTime {
quantile: f64,
}

impl QuantileOverTime {
fn new(quantile: f64) -> Self {
Self { quantile }
}

pub const fn name() -> &'static str {
"prom_quantile_over_time"
}

pub fn scalar_udf(quantile: f64) -> ScalarUDF {
ScalarUDF {
name: Self::name().to_string(),
signature: Signature::new(
TypeSignature::Exact(Self::input_type()),
Volatility::Immutable,
),
return_type: Arc::new(|_| Ok(Arc::new(Self::return_type()))),
fun: Arc::new(move |input| Self::new(quantile).calc(input)),
}
}

// time index column and value column
fn input_type() -> Vec<DataType> {
vec![
RangeArray::convert_data_type(DataType::Timestamp(TimeUnit::Millisecond, None)),
RangeArray::convert_data_type(DataType::Float64),
]
}

fn return_type() -> DataType {
DataType::Float64
}

fn calc(&self, input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
// construct matrix from input.
// The third one is quantile param, which is included in fields.
assert_eq!(input.len(), 3);
let ts_array = extract_array(&input[0])?;
let value_array = extract_array(&input[1])?;

let ts_range: RangeArray = RangeArray::try_new(ts_array.data().clone().into())?;
let value_range: RangeArray = RangeArray::try_new(value_array.data().clone().into())?;
error::ensure(
ts_range.len() == value_range.len(),
DataFusionError::Execution(format!(
"{}: input arrays should have the same length, found {} and {}",
Self::name(),
ts_range.len(),
value_range.len()
)),
)?;
error::ensure(
ts_range.value_type() == DataType::Timestamp(TimeUnit::Millisecond, None),
DataFusionError::Execution(format!(
"{}: expect TimestampMillisecond as time index array's type, found {}",
Self::name(),
ts_range.value_type()
)),
)?;
error::ensure(
value_range.value_type() == DataType::Float64,
DataFusionError::Execution(format!(
"{}: expect Float64 as value array's type, found {}",
Self::name(),
value_range.value_type()
)),
)?;

// calculation
let mut result_array = Vec::with_capacity(ts_range.len());

for index in 0..ts_range.len() {
let timestamps = ts_range.get(index).unwrap();
let values = value_range.get(index).unwrap();
let values = values
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.values();
error::ensure(
timestamps.len() == values.len(),
DataFusionError::Execution(format!(
"{}: input arrays should have the same length, found {} and {}",
Self::name(),
timestamps.len(),
values.len()
)),
)?;

let retule = quantile_impl(values, self.quantile);

result_array.push(retule);
}

let result = ColumnarValue::Array(Arc::new(Float64Array::from_iter(result_array)));
Ok(result)
}
}

/// Refer to https://github.com/prometheus/prometheus/blob/6e2905a4d4ff9b47b1f6d201333f5bd53633f921/promql/quantile.go#L357-L386
fn quantile_impl(values: &[f64], quantile: f64) -> Option<f64> {
if quantile.is_nan() || values.is_empty() {
return Some(f64::NAN);
}
if quantile < 0.0 {
return Some(f64::NEG_INFINITY);
}
if quantile > 1.0 {
return Some(f64::INFINITY);
}

let mut values = values.to_vec();
values.sort_unstable_by(f64::total_cmp);

let length = values.len();
let rank = quantile * (length - 1) as f64;

let lower_index = 0.max(rank.floor() as usize);
let upper_index = (length - 1).min(lower_index + 1);
let weight = rank - rank.floor();

let result = values[lower_index] * (1.0 - weight) + values[upper_index] * weight;
Some(result)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_quantile_impl_empty() {
let values = &[];
let q = 0.5;
assert!(quantile_impl(values, q).unwrap().is_nan());
}

#[test]
fn test_quantile_impl_nan() {
let values = &[1.0, 2.0, 3.0];
let q = f64::NAN;
assert!(quantile_impl(values, q).unwrap().is_nan());
}

#[test]
fn test_quantile_impl_negative_quantile() {
let values = &[1.0, 2.0, 3.0];
let q = -0.5;
assert_eq!(quantile_impl(values, q).unwrap(), f64::NEG_INFINITY);
}

#[test]
fn test_quantile_impl_greater_than_one_quantile() {
let values = &[1.0, 2.0, 3.0];
let q = 1.5;
assert_eq!(quantile_impl(values, q).unwrap(), f64::INFINITY);
}

#[test]
fn test_quantile_impl_single_element() {
let values = &[1.0];
let q = 0.8;
assert_eq!(quantile_impl(values, q).unwrap(), 1.0);
}

#[test]
fn test_quantile_impl_even_length() {
let values = &[3.0, 1.0, 5.0, 2.0];
let q = 0.5;
assert_eq!(quantile_impl(values, q).unwrap(), 2.5);
}

#[test]
fn test_quantile_impl_odd_length() {
let values = &[4.0, 1.0, 3.0, 2.0, 5.0];
let q = 0.25;
assert_eq!(quantile_impl(values, q).unwrap(), 2.0);
}
}
16 changes: 13 additions & 3 deletions src/promql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ use table::table::adapter::DfTableProviderAdapter;
use crate::error::{
CatalogSnafu, DataFusionPlanningSnafu, ExpectExprSnafu, ExpectRangeSelectorSnafu,
MultipleVectorSnafu, Result, TableNameNotFoundSnafu, TimeIndexNotFoundSnafu,
UnexpectedTokenSnafu, UnknownTableSnafu, UnsupportedExprSnafu, ValueNotFoundSnafu,
ZeroRangeSelectorSnafu,
UnexpectedPlanExprSnafu, UnexpectedTokenSnafu, UnknownTableSnafu, UnsupportedExprSnafu,
ValueNotFoundSnafu, ZeroRangeSelectorSnafu,
};
use crate::extension_plan::{
EmptyMetric, InstantManipulate, Millisecond, RangeManipulate, SeriesDivide, SeriesNormalize,
};
use crate::functions::{
AbsentOverTime, AvgOverTime, CountOverTime, Delta, IDelta, Increase, LastOverTime, MaxOverTime,
MinOverTime, PresentOverTime, Rate, SumOverTime,
MinOverTime, PresentOverTime, QuantileOverTime, Rate, SumOverTime,
};

const LEFT_PLAN_JOIN_ALIAS: &str = "lhs";
Expand Down Expand Up @@ -692,6 +692,16 @@ impl PromPlanner {
"last_over_time" => ScalarFunc::Udf(LastOverTime::scalar_udf()),
"absent_over_time" => ScalarFunc::Udf(AbsentOverTime::scalar_udf()),
"present_over_time" => ScalarFunc::Udf(PresentOverTime::scalar_udf()),
"quantile_over_time" => {
let quantile_expr = match other_input_exprs.get(0) {
Some(DfExpr::Literal(ScalarValue::Float64(Some(quantile)))) => *quantile,
other => UnexpectedPlanExprSnafu {
desc: format!("expect f64 literal as quantile, but found {:?}", other),
}
.fail()?,
};
ScalarFunc::Udf(QuantileOverTime::scalar_udf(quantile_expr))
}
_ => ScalarFunc::DataFusionBuiltin(
BuiltinScalarFunction::from_str(func.name).map_err(|_| {
UnsupportedExprSnafu {
Expand Down