16
16
// under the License.
17
17
18
18
use arrow:: compute:: kernels:: numeric:: add;
19
- use arrow_array:: { Array , ArrayRef , Float64Array , Int32Array , RecordBatch , UInt8Array } ;
19
+ use arrow_array:: {
20
+ Array , ArrayRef , Float32Array , Float64Array , Int32Array , RecordBatch ,
21
+ UInt8Array ,
22
+ } ;
20
23
use arrow_schema:: DataType :: Float64 ;
21
24
use arrow_schema:: { DataType , Field , Schema } ;
22
25
use datafusion:: prelude:: * ;
@@ -26,12 +29,15 @@ use datafusion_common::{
26
29
assert_batches_eq, assert_batches_sorted_eq, cast:: as_int32_array, not_impl_err,
27
30
plan_err, DataFusionError , ExprSchema , Result , ScalarValue ,
28
31
} ;
32
+ use datafusion_common:: { DFField , DFSchema } ;
29
33
use datafusion_expr:: {
30
34
create_udaf, create_udf, Accumulator , ColumnarValue , ExprSchemable ,
31
- LogicalPlanBuilder , ScalarUDF , ScalarUDFImpl , Signature , Volatility ,
35
+ LogicalPlanBuilder , ScalarUDF , ScalarUDFImpl , Signature , Simplified , Volatility ,
32
36
} ;
37
+
33
38
use rand:: { thread_rng, Rng } ;
34
39
use std:: any:: Any ;
40
+ use std:: collections:: HashMap ;
35
41
use std:: iter;
36
42
use std:: sync:: Arc ;
37
43
@@ -498,6 +504,81 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
498
504
Ok ( ( ) )
499
505
}
500
506
507
+ #[ derive( Debug ) ]
508
+ struct CastToI64UDF {
509
+ signature : Signature ,
510
+ }
511
+
512
+ impl CastToI64UDF {
513
+ fn new ( ) -> Self {
514
+ Self {
515
+ signature : Signature :: any ( 1 , Volatility :: Immutable ) ,
516
+ }
517
+ }
518
+ }
519
+
520
+ impl ScalarUDFImpl for CastToI64UDF {
521
+ fn as_any ( & self ) -> & dyn Any {
522
+ self
523
+ }
524
+ fn name ( & self ) -> & str {
525
+ "cast_to_i64"
526
+ }
527
+ fn signature ( & self ) -> & Signature {
528
+ & self . signature
529
+ }
530
+ fn return_type ( & self , _args : & [ DataType ] ) -> Result < DataType > {
531
+ Ok ( DataType :: Int64 )
532
+ }
533
+ // Wrap with Expr::Cast() to Int64
534
+ fn simplify ( & self , args : Vec < Expr > ) -> Result < Simplified > {
535
+ let dfs = DFSchema :: new_with_metadata (
536
+ vec ! [ DFField :: new( Some ( "t" ) , "x" , DataType :: Float32 , true ) ] ,
537
+ HashMap :: default ( ) ,
538
+ ) ?;
539
+ let e = args[ 0 ] . clone ( ) ;
540
+ let casted_expr = e. cast_to ( & DataType :: Int64 , & dfs) ?;
541
+ Ok ( Simplified :: Rewritten ( casted_expr) )
542
+ }
543
+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
544
+ Ok ( args. get ( 0 ) . unwrap ( ) . clone ( ) )
545
+ }
546
+ }
547
+
548
+ #[ tokio:: test]
549
+ async fn test_user_defined_functions_cast_to_i64 ( ) -> Result < ( ) > {
550
+ let ctx = SessionContext :: new ( ) ;
551
+
552
+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "x" , DataType :: Float32 , false ) ] ) ) ;
553
+
554
+ let batch = RecordBatch :: try_new (
555
+ schema,
556
+ vec ! [ Arc :: new( Float32Array :: from( vec![ 1.0 , 2.0 , 3.0 ] ) ) ] ,
557
+ ) ?;
558
+
559
+ ctx. register_batch ( "t" , batch) ?;
560
+
561
+ let cast_to_i64_udf = ScalarUDF :: from ( CastToI64UDF :: new ( ) ) ;
562
+ ctx. register_udf ( cast_to_i64_udf) ;
563
+
564
+ let result = plan_and_collect ( & ctx, "SELECT cast_to_i64(x) FROM t" ) . await ?;
565
+
566
+ assert_batches_eq ! (
567
+ & [
568
+ "+------------------+" ,
569
+ "| cast_to_i64(t.x) |" ,
570
+ "+------------------+" ,
571
+ "| 1 |" ,
572
+ "| 2 |" ,
573
+ "| 3 |" ,
574
+ "+------------------+"
575
+ ] ,
576
+ & result
577
+ ) ;
578
+
579
+ Ok ( ( ) )
580
+ }
581
+
501
582
#[ derive( Debug ) ]
502
583
struct TakeUDF {
503
584
signature : Signature ,
0 commit comments