diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 87b0656d171d..48011a323764 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -2,7 +2,9 @@ use std::ops::Sub; use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::export::regex; -use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit}; +use polars_core::prelude::{ + polars_bail, polars_err, DataType, PolarsResult, QuantileInterpolOptions, Schema, TimeUnit, +}; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] use polars_lazy::dsl::ListNameSpaceExtension; @@ -504,6 +506,13 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT MEDIAN(column_1) FROM df; /// ``` Median, + /// SQL 'quantile_cont' function + /// Returns the continuous quantile element from the grouping + /// (interpolated value between two closest values). + /// ```sql + /// SELECT QUANTILE_CONT(column_1) FROM df; + /// ``` + QuantileCont, /// SQL 'min' function /// Returns the smallest (minimum) of all the elements in the grouping. /// ```sql @@ -686,6 +695,7 @@ impl PolarsSQLFunctions { "pi", "pow", "power", + "quantile_cont", "radians", "regexp_like", "replace", @@ -818,6 +828,7 @@ impl PolarsSQLFunctions { "last" => Self::Last, "max" => Self::Max, "median" => Self::Median, + "quantile_cont" => Self::QuantileCont, "min" => Self::Min, "stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev, "sum" => Self::Sum, @@ -1243,6 +1254,32 @@ impl SQLFunctionVisitor<'_> { Last => self.visit_unary(Expr::last), Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max), Median => self.visit_unary(Expr::median), + QuantileCont => { + let args = extract_args(function)?; + match args.len() { + 2 => self.try_visit_binary(|e, q| { + let value = match q { + Expr::Literal(LiteralValue::Float(f)) => { + if (0.0..=1.0).contains(&f) { + Expr::from(f) + } else { + polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1]) + } + }, + Expr::Literal(LiteralValue::Int(n)) => { + if (0..=1).contains(&n) { + Expr::from(n as f64) + } else { + polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1]) + }; + Ok(e.quantile(value, QuantileInterpolOptions::Linear)) + }), + _ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()), + } + }, Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min), StdDev => self.visit_unary(|e| e.std(1)), Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum), diff --git a/crates/polars-sql/tests/functions_aggregate.rs b/crates/polars-sql/tests/functions_aggregate.rs new file mode 100644 index 000000000000..621ca18bd355 --- /dev/null +++ b/crates/polars-sql/tests/functions_aggregate.rs @@ -0,0 +1,65 @@ +use polars_core::prelude::*; +use polars_lazy::prelude::*; +use polars_plan::dsl::Expr; +use polars_sql::*; + +fn create_df() -> LazyFrame { + df! { + "Year" => [2018, 2018, 2019, 2019, 2020, 2020], + "Country" => ["US", "UK", "US", "UK", "US", "UK"], + "Sales" => [1000, 2000, 3000, 4000, 5000, 6000] + } + .unwrap() + .lazy() +} + +fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { + let df = create_df(); + let alias = "TEST"; + + let query = format!( + r#" + SELECT + {sql} as {alias} + FROM + df + "# + ); + + let expected = df + .clone() + .select(&[expr.alias(alias)]) + .sort([alias], Default::default()) + .collect() + .unwrap(); + let mut ctx = SQLContext::new(); + ctx.register("df", df); + + let actual = ctx.execute(&query).unwrap().collect().unwrap(); + (expected, actual) +} + +#[test] +fn test_median() { + let expr = col("Sales").median(); + + let sql_expr = "MEDIAN(Sales)"; + let (expected, actual) = create_expected(expr, sql_expr); + + assert!(expected.equals(&actual)) +} + +#[test] +fn test_quantile_cont() { + for &q in &[0.25, 0.5, 0.75] { + let expr = col("Sales").quantile(lit(q), QuantileInterpolOptions::Linear); + + let sql_expr = format!("QUANTILE_CONT(Sales, {})", q); + let (expected, actual) = create_expected(expr, &sql_expr); + + assert!( + expected.equals(&actual), + "q: {q}: expected {expected:?}, got {actual:?}" + ) + } +}