diff --git a/Cargo.lock b/Cargo.lock index 76ea5784..4ae3ec1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,12 +107,12 @@ dependencies = [ "serde_json", "snap", "strum 0.25.0", - "strum_macros 0.25.2", + "strum_macros 0.25.3", "thiserror", "typed-builder", "uuid", "xz2", - "zstd", + "zstd 0.12.4", ] [[package]] @@ -312,7 +312,7 @@ version = "47.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d1d179c117b158853e0101bfbed5615e86fe97ee356b4af901f1c5001e1ce4b" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", ] [[package]] @@ -347,9 +347,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb42b2197bf15ccb092b62c74515dbd8b86d0effd934795f6687c93b6e679a2c" +checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2" dependencies = [ "bzip2", "flate2", @@ -359,8 +359,8 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd", - "zstd-safe", + "zstd 0.13.0", + "zstd-safe 7.0.0", ] [[package]] @@ -376,9 +376,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.73" +version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", @@ -429,9 +429,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "blake2" @@ -760,7 +760,7 @@ dependencies = [ "url", "uuid", "xz2", - "zstd", + "zstd 0.12.4", ] [[package]] @@ -817,7 +817,7 @@ dependencies = [ "datafusion-common", "sqlparser", "strum 0.25.0", - "strum_macros 0.25.2", + "strum_macros 0.25.3", ] [[package]] @@ -923,7 +923,7 @@ dependencies = [ "pyo3", "pyo3-build-config 0.20.0", "rand", - "regex-syntax 0.8.1", + "regex-syntax 0.8.2", "syn 2.0.38", "tokio", "url", @@ -1039,9 +1039,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c98ee8095e9d1dcbf2fcc6d95acccb90d1c81db1e44725c6a984b1dbdfb010" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" dependencies = [ "crc32fast", "miniz_oxide", @@ -1184,7 +1184,7 @@ version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf97ba92db08df386e10c8ede66a2a0369bd277090afd8710e19e38de9ec0cd" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "libc", "libgit2-sys", "log", @@ -1359,16 +1359,16 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" +checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows-core", ] [[package]] @@ -1924,7 +1924,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd", + "zstd 0.12.4", ] [[package]] @@ -2242,32 +2242,32 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.0" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", "regex-automata", - "regex-syntax 0.8.1", + "regex-syntax 0.8.2", ] [[package]] name = "regex-automata" -version = "0.4.1" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.1", + "regex-syntax 0.8.2", ] [[package]] name = "regex-lite" -version = "0.1.3" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a6ebcd15653947e6140f59a9811a06ed061d18a5c35dfca2e2e4c5525696878" +checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" [[package]] name = "regex-syntax" @@ -2277,9 +2277,9 @@ checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "regex-syntax" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "regress" @@ -2371,11 +2371,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.18" +version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a74ee2d7c2581cd139b42447d7d9389b889bdaad3a73f1ebb16f2a3237bb19c" +checksum = "745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" dependencies = [ - "bitflags 2.4.0", + "bitflags 2.4.1", "errno", "libc", "linux-raw-sys", @@ -2488,18 +2488,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ "proc-macro2", "quote", @@ -2690,7 +2690,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" dependencies = [ - "strum_macros 0.25.2", + "strum_macros 0.25.3", ] [[package]] @@ -2708,9 +2708,9 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ "heck", "proc-macro2", @@ -2926,11 +2926,10 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.37" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +checksum = "ee2ef2af84856a50c1d430afce2fdded0a4ec7eda868db86409b4543df0797f9" dependencies = [ - "cfg-if", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -2938,9 +2937,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", @@ -2949,9 +2948,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", ] @@ -3290,10 +3289,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] -name = "windows" -version = "0.48.0" +name = "windows-core" +version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ "windows-targets", ] @@ -3389,7 +3388,16 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe", + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe 7.0.0", ] [[package]] @@ -3402,6 +3410,15 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.9+zstd.1.5.5" diff --git a/datafusion/__init__.py b/datafusion/__init__.py index 4a495b46..c854f3f9 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -33,7 +33,6 @@ SessionConfig, RuntimeConfig, ScalarUDF, - WindowFrame, ) from .common import ( @@ -86,6 +85,8 @@ DropTable, Repartition, Partitioning, + Window, + WindowFrame, ) __version__ = importlib_metadata.version(__name__) @@ -99,6 +100,7 @@ "Expr", "AggregateUDF", "ScalarUDF", + "Window", "WindowFrame", "column", "literal", diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 078b8c84..405a5632 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -329,6 +329,7 @@ impl DataTypeMap { } "float" => Ok(DataType::Float32), "double" => Ok(DataType::Float64), + "byte_array" => Ok(DataType::Utf8), _ => Err(PyValueError::new_err(format!( "Unable to determine Arrow Data Type from Parquet String type: {:?}", parquet_str_type @@ -604,13 +605,30 @@ impl PyDataType { /// is presented as a String rather than an actual DataType. This function is used to /// convert that String to a DataType for the Python side to use. pub fn py_map_from_arrow_type_str(arrow_str_type: String) -> PyResult { + // Certain string types contain "metadata" that should be trimmed here. Ex: "datetime64[ns, Europe/Berlin]" + let arrow_str_type = match arrow_str_type.find('[') { + Some(index) => arrow_str_type[0..index].to_string(), + None => arrow_str_type, // Return early if ',' is not found. + }; + let arrow_dtype = match arrow_str_type.to_lowercase().as_str() { + "bool" => Ok(DataType::Boolean), "boolean" => Ok(DataType::Boolean), + "uint8" => Ok(DataType::UInt8), + "uint16" => Ok(DataType::UInt16), + "uint32" => Ok(DataType::UInt32), + "uint64" => Ok(DataType::UInt64), + "int8" => Ok(DataType::Int8), + "int16" => Ok(DataType::Int16), "int32" => Ok(DataType::Int32), "int64" => Ok(DataType::Int64), "float" => Ok(DataType::Float32), "double" => Ok(DataType::Float64), + "float16" => Ok(DataType::Float16), + "float32" => Ok(DataType::Float32), "float64" => Ok(DataType::Float64), + "datetime64" => Ok(DataType::Date64), + "object" => Ok(DataType::Utf8), _ => Err(PyValueError::new_err(format!( "Unable to determine Arrow Data Type from Arrow String type: {:?}", arrow_str_type diff --git a/src/expr.rs b/src/expr.rs index ecf8fae3..e502edce 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -92,6 +92,7 @@ pub mod subquery; pub mod subquery_alias; pub mod table_scan; pub mod union; +pub mod window; /// A PyExpr that can be used on a DataFrame #[pyclass(name = "Expr", module = "datafusion.expr", subclass)] @@ -112,6 +113,11 @@ impl From for PyExpr { } } +/// Convert a list of DataFusion Expr to PyExpr +pub fn py_expr_list(expr: &[Expr]) -> PyResult> { + Ok(expr.iter().map(|e| PyExpr::from(e.clone())).collect()) +} + #[pymethods] impl PyExpr { /// Return the specific expression @@ -542,6 +548,10 @@ impl PyExpr { // appear in projections) so we just delegate to the contained expression instead Self::expr_to_field(expr, input_plan) } + Expr::Wildcard => { + // Since * could be any of the valid column names just return the first one + Ok(input_plan.schema().field(0).clone()) + } _ => { let fields = exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?; @@ -652,5 +662,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/expr/window.rs b/src/expr/window.rs new file mode 100644 index 00000000..6583c97a --- /dev/null +++ b/src/expr/window.rs @@ -0,0 +1,294 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::expr::WindowFunction; +use datafusion_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits}; +use pyo3::prelude::*; +use std::fmt::{self, Display, Formatter}; + +use crate::common::df_schema::PyDFSchema; +use crate::errors::py_type_err; +use crate::expr::logical_node::LogicalNode; +use crate::expr::PyExpr; +use crate::sql::logical::PyLogicalPlan; + +use super::py_expr_list; + +use crate::errors::py_datafusion_err; + +#[pyclass(name = "Window", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyWindow { + window: Window, +} + +#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyWindowFrame { + window_frame: WindowFrame, +} + +impl From for WindowFrame { + fn from(window_frame: PyWindowFrame) -> Self { + window_frame.window_frame + } +} + +impl From for PyWindowFrame { + fn from(window_frame: WindowFrame) -> PyWindowFrame { + PyWindowFrame { window_frame } + } +} + +#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)] +#[derive(Clone)] +pub struct PyWindowFrameBound { + frame_bound: WindowFrameBound, +} + +impl From for Window { + fn from(window: PyWindow) -> Window { + window.window + } +} + +impl From for PyWindow { + fn from(window: Window) -> PyWindow { + PyWindow { window } + } +} + +impl From for PyWindowFrameBound { + fn from(frame_bound: WindowFrameBound) -> Self { + PyWindowFrameBound { frame_bound } + } +} + +impl Display for PyWindow { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "Over\n + Window Expr: {:?} + Schema: {:?}", + &self.window.window_expr, &self.window.schema + ) + } +} + +impl Display for PyWindowFrame { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "OVER ({} BETWEEN {} AND {})", + self.window_frame.units, self.window_frame.start_bound, self.window_frame.end_bound + ) + } +} + +#[pymethods] +impl PyWindow { + /// Returns the schema of the Window + pub fn schema(&self) -> PyResult { + Ok(self.window.schema.as_ref().clone().into()) + } + + /// Returns window expressions + pub fn get_window_expr(&self) -> PyResult> { + py_expr_list(&self.window.window_expr) + } + + /// Returns order by columns in a window function expression + pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult> { + match expr.expr.unalias() { + Expr::WindowFunction(WindowFunction { order_by, .. }) => py_expr_list(&order_by), + other => Err(not_window_function_err(other)), + } + } + + /// Return partition by columns in a window function expression + pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult> { + match expr.expr.unalias() { + Expr::WindowFunction(WindowFunction { partition_by, .. }) => { + py_expr_list(&partition_by) + } + other => Err(not_window_function_err(other)), + } + } + + /// Return input args for window function + pub fn get_args(&self, expr: PyExpr) -> PyResult> { + match expr.expr.unalias() { + Expr::WindowFunction(WindowFunction { args, .. }) => py_expr_list(&args), + other => Err(not_window_function_err(other)), + } + } + + /// Return window function name + pub fn window_func_name(&self, expr: PyExpr) -> PyResult { + match expr.expr.unalias() { + Expr::WindowFunction(WindowFunction { fun, .. }) => Ok(fun.to_string()), + other => Err(not_window_function_err(other)), + } + } + + /// Returns a Pywindow frame for a given window function expression + pub fn get_frame(&self, expr: PyExpr) -> Option { + match expr.expr.unalias() { + Expr::WindowFunction(WindowFunction { window_frame, .. }) => Some(window_frame.into()), + _ => None, + } + } +} + +fn not_window_function_err(expr: Expr) -> PyErr { + py_type_err(format!( + "Provided {} Expr {:?} is not a WindowFunction type", + expr.variant_name(), + expr + )) +} + +#[pymethods] +impl PyWindowFrame { + #[new(unit, start_bound, end_bound)] + pub fn new(units: &str, start_bound: Option, end_bound: Option) -> PyResult { + let units = units.to_ascii_lowercase(); + let units = match units.as_str() { + "rows" => WindowFrameUnits::Rows, + "range" => WindowFrameUnits::Range, + "groups" => WindowFrameUnits::Groups, + _ => { + return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( + "{:?}", + units, + )))); + } + }; + let start_bound = match start_bound { + Some(start_bound) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound))) + } + None => match units { + WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameUnits::Groups => { + return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( + "{:?}", + units, + )))); + } + }, + }; + let end_bound = match end_bound { + Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))), + None => match units { + WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), + WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)), + WindowFrameUnits::Groups => { + return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( + "{:?}", + units, + )))); + } + }, + }; + Ok(PyWindowFrame { + window_frame: WindowFrame { + units, + start_bound, + end_bound, + }, + }) + } + + /// Returns the window frame units for the bounds + pub fn get_frame_units(&self) -> PyResult { + Ok(self.window_frame.units.to_string()) + } + /// Returns starting bound + pub fn get_lower_bound(&self) -> PyResult { + Ok(self.window_frame.start_bound.clone().into()) + } + /// Returns end bound + pub fn get_upper_bound(&self) -> PyResult { + Ok(self.window_frame.end_bound.clone().into()) + } + + /// Get a String representation of this window frame + fn __repr__(&self) -> String { + format!("{}", self) + } +} + +#[pymethods] +impl PyWindowFrameBound { + /// Returns if the frame bound is current row + pub fn is_current_row(&self) -> bool { + matches!(self.frame_bound, WindowFrameBound::CurrentRow) + } + + /// Returns if the frame bound is preceding + pub fn is_preceding(&self) -> bool { + matches!(self.frame_bound, WindowFrameBound::Preceding(_)) + } + + /// Returns if the frame bound is following + pub fn is_following(&self) -> bool { + matches!(self.frame_bound, WindowFrameBound::Following(_)) + } + /// Returns the offset of the window frame + pub fn get_offset(&self) -> PyResult> { + match &self.frame_bound { + WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val { + x if x.is_null() => Ok(None), + ScalarValue::UInt64(v) => Ok(*v), + // The cast below is only safe because window bounds cannot be negative + ScalarValue::Int64(v) => Ok(v.map(|n| n as u64)), + ScalarValue::Utf8(Some(s)) => match s.parse::() { + Ok(s) => Ok(Some(s)), + Err(_e) => Err(DataFusionError::Plan(format!( + "Unable to parse u64 from Utf8 value '{s}'" + )) + .into()), + }, + ref x => { + Err(DataFusionError::Plan(format!("Unexpected window frame bound: {x}")).into()) + } + }, + WindowFrameBound::CurrentRow => Ok(None), + } + } + /// Returns if the frame bound is unbounded + pub fn is_unbounded(&self) -> PyResult { + match &self.frame_bound { + WindowFrameBound::Preceding(v) | WindowFrameBound::Following(v) => Ok(v.is_null()), + WindowFrameBound::CurrentRow => Ok(false), + } + } +} + +impl LogicalNode for PyWindow { + fn inputs(&self) -> Vec { + vec![self.window.input.as_ref().clone().into()] + } + + fn to_variant(&self, py: Python) -> PyResult { + Ok(self.clone().into_py(py)) + } +} diff --git a/src/functions.rs b/src/functions.rs index 42203d7b..be903609 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -20,8 +20,8 @@ use pyo3::{prelude::*, wrap_pyfunction}; use crate::context::PySessionContext; use crate::errors::DataFusionError; use crate::expr::conditional_expr::PyCaseBuilder; +use crate::expr::window::PyWindowFrame; use crate::expr::PyExpr; -use crate::window_frame::PyWindowFrame; use datafusion::execution::FunctionRegistry; use datafusion_common::Column; use datafusion_expr::expr::Alias; diff --git a/src/lib.rs b/src/lib.rs index b9bd5766..2512aefa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,6 @@ mod udaf; #[allow(clippy::borrow_deref_ref)] mod udf; pub mod utils; -mod window_frame; #[cfg(feature = "mimalloc")] #[global_allocator] @@ -84,7 +83,6 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 2183155b..3aa8a699 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -20,17 +20,20 @@ use std::sync::Arc; use crate::errors::py_unsupported_variant_err; use crate::expr::aggregate::PyAggregate; use crate::expr::analyze::PyAnalyze; +use crate::expr::cross_join::PyCrossJoin; use crate::expr::distinct::PyDistinct; use crate::expr::empty_relation::PyEmptyRelation; use crate::expr::explain::PyExplain; use crate::expr::extension::PyExtension; use crate::expr::filter::PyFilter; +use crate::expr::join::PyJoin; use crate::expr::limit::PyLimit; use crate::expr::projection::PyProjection; use crate::expr::sort::PySort; use crate::expr::subquery::PySubquery; use crate::expr::subquery_alias::PySubqueryAlias; use crate::expr::table_scan::PyTableScan; +use crate::expr::window::PyWindow; use datafusion_expr::LogicalPlan; use pyo3::prelude::*; @@ -62,17 +65,20 @@ impl PyLogicalPlan { Python::with_gil(|_| match self.plan.as_ref() { LogicalPlan::Aggregate(plan) => PyAggregate::from(plan.clone()).to_variant(py), LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py), + LogicalPlan::CrossJoin(plan) => PyCrossJoin::from(plan.clone()).to_variant(py), + LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py), LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py), LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py), LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py), - LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py), LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py), + LogicalPlan::Join(plan) => PyJoin::from(plan.clone()).to_variant(py), LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py), LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py), LogicalPlan::Sort(plan) => PySort::from(plan.clone()).to_variant(py), LogicalPlan::TableScan(plan) => PyTableScan::from(plan.clone()).to_variant(py), LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py), LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py), + LogicalPlan::Window(plan) => PyWindow::from(plan.clone()).to_variant(py), other => Err(py_unsupported_variant_err(format!( "Cannot convert this plan to a LogicalNode: {:?}", other diff --git a/src/window_frame.rs b/src/window_frame.rs deleted file mode 100644 index b8f414e6..00000000 --- a/src/window_frame.rs +++ /dev/null @@ -1,110 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion_common::{DataFusionError, ScalarValue}; -use datafusion_expr::window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -use pyo3::prelude::*; -use std::fmt::{Display, Formatter}; - -use crate::errors::py_datafusion_err; - -#[pyclass(name = "WindowFrame", module = "datafusion", subclass)] -#[derive(Clone)] -pub struct PyWindowFrame { - frame: WindowFrame, -} - -impl From for WindowFrame { - fn from(frame: PyWindowFrame) -> Self { - frame.frame - } -} - -impl From for PyWindowFrame { - fn from(frame: WindowFrame) -> PyWindowFrame { - PyWindowFrame { frame } - } -} - -impl Display for PyWindowFrame { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!( - f, - "OVER ({} BETWEEN {} AND {})", - self.frame.units, self.frame.start_bound, self.frame.end_bound - ) - } -} - -#[pymethods] -impl PyWindowFrame { - #[new(unit, start_bound, end_bound)] - pub fn new(units: &str, start_bound: Option, end_bound: Option) -> PyResult { - let units = units.to_ascii_lowercase(); - let units = match units.as_str() { - "rows" => WindowFrameUnits::Rows, - "range" => WindowFrameUnits::Range, - "groups" => WindowFrameUnits::Groups, - _ => { - return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - units, - )))); - } - }; - let start_bound = match start_bound { - Some(start_bound) => { - WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound))) - } - None => match units { - WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), - WindowFrameUnits::Groups => { - return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - units, - )))); - } - }, - }; - let end_bound = match end_bound { - Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))), - None => match units { - WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), - WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)), - WindowFrameUnits::Groups => { - return Err(py_datafusion_err(DataFusionError::NotImplemented(format!( - "{:?}", - units, - )))); - } - }, - }; - Ok(PyWindowFrame { - frame: WindowFrame { - units, - start_bound, - end_bound, - }, - }) - } - - /// Get a String representation of this window frame - fn __repr__(&self) -> String { - format!("{}", self) - } -}