17
17
18
18
//! Defines physical expressions that can evaluated at runtime during query execution
19
19
20
- use std:: any:: Any ;
21
- use std:: fmt:: Debug ;
22
-
23
20
use arrow:: array:: Float64Array ;
24
21
use arrow:: {
25
22
array:: { ArrayRef , UInt64Array } ,
@@ -29,10 +26,17 @@ use arrow::{
29
26
} ;
30
27
use datafusion_common:: { downcast_value, plan_err, unwrap_or_internal_err, ScalarValue } ;
31
28
use datafusion_common:: { DataFusionError , Result } ;
29
+ use datafusion_expr:: aggregate_doc_sections:: DOC_SECTION_STATISTICAL ;
32
30
use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
33
31
use datafusion_expr:: type_coercion:: aggregates:: NUMERICS ;
34
32
use datafusion_expr:: utils:: format_state_name;
35
- use datafusion_expr:: { Accumulator , AggregateUDFImpl , Signature , Volatility } ;
33
+ use datafusion_expr:: {
34
+ Accumulator , AggregateUDFImpl , Documentation , Signature , Volatility ,
35
+ } ;
36
+ use std:: any:: Any ;
37
+ use std:: collections:: HashMap ;
38
+ use std:: fmt:: Debug ;
39
+ use std:: sync:: OnceLock ;
36
40
37
41
macro_rules! make_regr_udaf_expr_and_func {
38
42
( $EXPR_FN: ident, $AGGREGATE_UDF_FN: ident, $REGR_TYPE: expr) => {
@@ -76,23 +80,7 @@ impl Regr {
76
80
}
77
81
}
78
82
79
- /*
80
- #[derive(Debug)]
81
- pub struct Regr {
82
- name: String,
83
- regr_type: RegrType,
84
- expr_y: Arc<dyn PhysicalExpr>,
85
- expr_x: Arc<dyn PhysicalExpr>,
86
- }
87
-
88
- impl Regr {
89
- pub fn get_regr_type(&self) -> RegrType {
90
- self.regr_type.clone()
91
- }
92
- }
93
- */
94
-
95
- #[ derive( Debug , Clone ) ]
83
+ #[ derive( Debug , Clone , PartialEq , Hash , Eq ) ]
96
84
#[ allow( clippy:: upper_case_acronyms) ]
97
85
pub enum RegrType {
98
86
/// Variant for `regr_slope` aggregate expression
@@ -135,6 +123,148 @@ pub enum RegrType {
135
123
SXY ,
136
124
}
137
125
126
+ impl RegrType {
127
+ /// return the documentation for the `RegrType`
128
+ fn documentation ( & self ) -> Option < & Documentation > {
129
+ get_regr_docs ( ) . get ( self )
130
+ }
131
+ }
132
+
133
+ static DOCUMENTATION : OnceLock < HashMap < RegrType , Documentation > > = OnceLock :: new ( ) ;
134
+ fn get_regr_docs ( ) -> & ' static HashMap < RegrType , Documentation > {
135
+ DOCUMENTATION . get_or_init ( || {
136
+ let mut hash_map = HashMap :: new ( ) ;
137
+ hash_map. insert (
138
+ RegrType :: Slope ,
139
+ Documentation :: builder ( )
140
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
141
+ . with_description (
142
+ "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
143
+ Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
144
+ )
145
+ . with_syntax_example ( "regr_slope(expression_y, expression_x)" )
146
+ . with_standard_argument ( "expression_y" , "Expression" )
147
+ . with_standard_argument ( "expression_x" , "Expression" )
148
+ . build ( )
149
+ . unwrap ( )
150
+ ) ;
151
+
152
+ hash_map. insert (
153
+ RegrType :: Intercept ,
154
+ Documentation :: builder ( )
155
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
156
+ . with_description (
157
+ "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
158
+ this function returns b.",
159
+ )
160
+ . with_syntax_example ( "regr_intercept(expression_y, expression_x)" )
161
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
162
+ . with_standard_argument ( "expression_x" , "Independent variable" )
163
+ . build ( )
164
+ . unwrap ( )
165
+ ) ;
166
+
167
+ hash_map. insert (
168
+ RegrType :: Count ,
169
+ Documentation :: builder ( )
170
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
171
+ . with_description (
172
+ "Counts the number of non-null paired data points." ,
173
+ )
174
+ . with_syntax_example ( "regr_count(expression_y, expression_x)" )
175
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
176
+ . with_standard_argument ( "expression_x" , "Independent variable" )
177
+ . build ( )
178
+ . unwrap ( )
179
+ ) ;
180
+
181
+ hash_map. insert (
182
+ RegrType :: R2 ,
183
+ Documentation :: builder ( )
184
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
185
+ . with_description (
186
+ "Computes the square of the correlation coefficient between the independent and dependent variables." ,
187
+ )
188
+ . with_syntax_example ( "regr_r2(expression_y, expression_x)" )
189
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
190
+ . with_standard_argument ( "expression_x" , "Independent variable" )
191
+ . build ( )
192
+ . unwrap ( )
193
+ ) ;
194
+
195
+ hash_map. insert (
196
+ RegrType :: AvgX ,
197
+ Documentation :: builder ( )
198
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
199
+ . with_description (
200
+ "Computes the average of the independent variable (input) expression_x for the non-null paired data points." ,
201
+ )
202
+ . with_syntax_example ( "regr_avgx(expression_y, expression_x)" )
203
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
204
+ . with_standard_argument ( "expression_x" , "Independent variable" )
205
+ . build ( )
206
+ . unwrap ( )
207
+ ) ;
208
+
209
+ hash_map. insert (
210
+ RegrType :: AvgY ,
211
+ Documentation :: builder ( )
212
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
213
+ . with_description (
214
+ "Computes the average of the dependent variable (output) expression_y for the non-null paired data points." ,
215
+ )
216
+ . with_syntax_example ( "regr_avgy(expression_y, expression_x)" )
217
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
218
+ . with_standard_argument ( "expression_x" , "Independent variable" )
219
+ . build ( )
220
+ . unwrap ( )
221
+ ) ;
222
+
223
+ hash_map. insert (
224
+ RegrType :: SXX ,
225
+ Documentation :: builder ( )
226
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
227
+ . with_description (
228
+ "Computes the sum of squares of the independent variable." ,
229
+ )
230
+ . with_syntax_example ( "regr_sxx(expression_y, expression_x)" )
231
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
232
+ . with_standard_argument ( "expression_x" , "Independent variable" )
233
+ . build ( )
234
+ . unwrap ( )
235
+ ) ;
236
+
237
+ hash_map. insert (
238
+ RegrType :: SYY ,
239
+ Documentation :: builder ( )
240
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
241
+ . with_description (
242
+ "Computes the sum of squares of the dependent variable." ,
243
+ )
244
+ . with_syntax_example ( "regr_syy(expression_y, expression_x)" )
245
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
246
+ . with_standard_argument ( "expression_x" , "Independent variable" )
247
+ . build ( )
248
+ . unwrap ( )
249
+ ) ;
250
+
251
+ hash_map. insert (
252
+ RegrType :: SXY ,
253
+ Documentation :: builder ( )
254
+ . with_doc_section ( DOC_SECTION_STATISTICAL )
255
+ . with_description (
256
+ "Computes the sum of products of paired data points." ,
257
+ )
258
+ . with_syntax_example ( "regr_sxy(expression_y, expression_x)" )
259
+ . with_standard_argument ( "expression_y" , "Dependent variable" )
260
+ . with_standard_argument ( "expression_x" , "Independent variable" )
261
+ . build ( )
262
+ . unwrap ( )
263
+ ) ;
264
+ hash_map
265
+ } )
266
+ }
267
+
138
268
impl AggregateUDFImpl for Regr {
139
269
fn as_any ( & self ) -> & dyn Any {
140
270
self
@@ -198,22 +328,11 @@ impl AggregateUDFImpl for Regr {
198
328
) ,
199
329
] )
200
330
}
201
- }
202
331
203
- /*
204
- impl PartialEq<dyn Any> for Regr {
205
- fn eq(&self, other: &dyn Any) -> bool {
206
- down_cast_any_ref(other)
207
- .downcast_ref::<Self>()
208
- .map(|x| {
209
- self.name == x.name
210
- && self.expr_y.eq(&x.expr_y)
211
- && self.expr_x.eq(&x.expr_x)
212
- })
213
- .unwrap_or(false)
332
+ fn documentation ( & self ) -> Option < & Documentation > {
333
+ self . regr_type . documentation ( )
214
334
}
215
335
}
216
- */
217
336
218
337
/// `RegrAccumulator` is used to compute linear regression aggregate functions
219
338
/// by maintaining statistics needed to compute them in an online fashion.
0 commit comments