@@ -246,3 +246,218 @@ def udaf(
246
246
state_type = state_type ,
247
247
volatility = volatility ,
248
248
)
249
+
250
+
251
+ class WindowEvaluator (metaclass = ABCMeta ):
252
+ """Evaluator class for user defined window functions (UDWF).
253
+
254
+ Users should inherit from this class and implement ``evaluate``, ``evaluate_all``,
255
+ and/or ``evaluate_all_with_rank``. If using `evaluate` only you will need to
256
+ override ``supports_bounded_execution``.
257
+ """
258
+
259
+ def memoize (self ) -> None :
260
+ """Perform a memoize operation to improve performance.
261
+
262
+ When the window frame has a fixed beginning (e.g UNBOUNDED
263
+ PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
264
+ NTH_VALUE do not need the (unbounded) input once they have
265
+ seen a certain amount of input.
266
+
267
+ `memoize` is called after each input batch is processed, and
268
+ such functions can save whatever they need
269
+ """
270
+ pass
271
+
272
+ def get_range (self , idx : int , n_rows : int ) -> tuple [int , int ]:
273
+ """Return the range for the window fuction.
274
+
275
+ If `uses_window_frame` flag is `false`. This method is used to
276
+ calculate required range for the window function during
277
+ stateful execution.
278
+
279
+ Generally there is no required range, hence by default this
280
+ returns smallest range(current row). e.g seeing current row is
281
+ enough to calculate window result (such as row_number, rank,
282
+ etc)
283
+
284
+ Args:
285
+ idx:: Current index
286
+ n_rows: Number of rows.
287
+ """
288
+ return (idx , idx + 1 )
289
+
290
+ def is_causal (self ) -> bool :
291
+ """Get whether evaluator needs future data for its result."""
292
+ return False
293
+
294
+ def evaluate_all (self , values : pyarrow .Array , num_rows : int ) -> pyarrow .Array :
295
+ """Evaluate a window function on an entire input partition.
296
+
297
+ This function is called once per input *partition* for window
298
+ functions that *do not use* values from the window frame,
299
+ such as `ROW_NUMBER`, `RANK`, `DENSE_RANK`, `PERCENT_RANK`,
300
+ `CUME_DIST`, `LEAD`, `LAG`).
301
+
302
+ It produces the result of all rows in a single pass. It
303
+ expects to receive the entire partition as the `value` and
304
+ must produce an output column with one output row for every
305
+ input row.
306
+
307
+ `num_rows` is required to correctly compute the output in case
308
+ `values.len() == 0`
309
+
310
+ Implementing this function is an optimization: certain window
311
+ functions are not affected by the window frame definition or
312
+ the query doesn't have a frame, and `evaluate` skips the
313
+ (costly) window frame boundary calculation and the overhead of
314
+ calling `evaluate` for each output row.
315
+
316
+ For example, the `LAG` built in window function does not use
317
+ the values of its window frame (it can be computed in one shot
318
+ on the entire partition with `Self::evaluate_all` regardless of the
319
+ window defined in the `OVER` clause)
320
+
321
+ ```sql
322
+ lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
323
+ ```
324
+
325
+ However, `avg()` computes the average in the window and thus
326
+ does use its window frame
327
+
328
+ ```sql
329
+ avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
330
+ ```
331
+ """
332
+ if self .supports_bounded_execution () and not self .uses_window_frame ():
333
+ res = []
334
+ for idx in range (0 , num_rows ):
335
+ res .append (self .evaluate (values , self .get_range (idx , num_rows )))
336
+ return pyarrow .array (res )
337
+ else :
338
+ raise
339
+
340
+ @abstractmethod
341
+ def evaluate (self , values : pyarrow .Array , range : tuple [int , int ]) -> pyarrow .Scalar :
342
+ """Evaluate window function on a range of rows in an input partition.
343
+
344
+ This is the simplest and most general function to implement
345
+ but also the least performant as it creates output one row at
346
+ a time. It is typically much faster to implement stateful
347
+ evaluation using one of the other specialized methods on this
348
+ trait.
349
+
350
+ Returns a [`ScalarValue`] that is the value of the window
351
+ function within `range` for the entire partition. Argument
352
+ `values` contains the evaluation result of function arguments
353
+ and evaluation results of ORDER BY expressions. If function has a
354
+ single argument, `values[1..]` will contain ORDER BY expression results.
355
+ """
356
+ pass
357
+
358
+ @abstractmethod
359
+ def evaluate_all_with_rank (
360
+ self , num_rows : int , ranks_in_partition : list [tuple [int , int ]]
361
+ ) -> pyarrow .Array :
362
+ """Called for window functions that only need the rank of a row.
363
+
364
+ Evaluate the partition evaluator against the partition using
365
+ the row ranks. For example, `RANK(col)` produces
366
+
367
+ ```text
368
+ col | rank
369
+ --- + ----
370
+ A | 1
371
+ A | 1
372
+ C | 3
373
+ D | 4
374
+ D | 5
375
+ ```
376
+
377
+ For this case, `num_rows` would be `5` and the
378
+ `ranks_in_partition` would be called with
379
+
380
+ ```text
381
+ [
382
+ (0,1),
383
+ (2,2),
384
+ (3,4),
385
+ ]
386
+ """
387
+ pass
388
+
389
+ def supports_bounded_execution (self ) -> bool :
390
+ """Can the window function be incrementally computed using bounded memory?"""
391
+ return False
392
+
393
+ def uses_window_frame (self ) -> bool :
394
+ """Does the window function use the values from the window frame?"""
395
+ return False
396
+
397
+ def include_rank (self ) -> bool :
398
+ """Can this function be evaluated with (only) rank?"""
399
+ return False
400
+
401
+
402
+ class WindowUDF :
403
+ """Class for performing window user defined functions (UDF).
404
+
405
+ Window UDFs operate on a partition of rows. See
406
+ also :py:class:`ScalarUDF` for operating on a row by row basis.
407
+ """
408
+
409
+ def __init__ (
410
+ self ,
411
+ name : str | None ,
412
+ func : WindowEvaluator ,
413
+ input_type : pyarrow .DataType ,
414
+ return_type : _R ,
415
+ volatility : Volatility | str ,
416
+ ) -> None :
417
+ """Instantiate a user defined window function (UDWF).
418
+
419
+ See :py:func:`udwf` for a convenience function and argument
420
+ descriptions.
421
+ """
422
+ self ._udwf = df_internal .WindowUDF (
423
+ name , func , input_type , return_type , str (volatility )
424
+ )
425
+
426
+ def __call__ (self , * args : Expr ) -> Expr :
427
+ """Execute the UDWF.
428
+
429
+ This function is not typically called by an end user. These calls will
430
+ occur during the evaluation of the dataframe.
431
+ """
432
+ args_raw = [arg .expr for arg in args ]
433
+ return Expr (self ._udwf .__call__ (* args_raw ))
434
+
435
+ @staticmethod
436
+ def udwf (
437
+ func : Callable [..., _R ],
438
+ input_type : pyarrow .DataType ,
439
+ return_type : _R ,
440
+ volatility : Volatility | str ,
441
+ name : str | None = None ,
442
+ ) -> WindowUDF :
443
+ """Create a new User Defined Window Function.
444
+
445
+ Args:
446
+ func: The python function.
447
+ input_type: The data type of the arguments to ``func``.
448
+ return_type: The data type of the return value.
449
+ volatility: See :py:class:`Volatility` for allowed values.
450
+ name: A descriptive name for the function.
451
+
452
+ Returns:
453
+ A user defined window function.
454
+ """
455
+ if name is None :
456
+ name = func .__qualname__ .lower ()
457
+ return WindowUDF (
458
+ name = name ,
459
+ func = func ,
460
+ input_type = input_type ,
461
+ return_type = return_type ,
462
+ volatility = volatility ,
463
+ )
0 commit comments