@@ -5,16 +5,19 @@ mod compare;
5
5
mod invert;
6
6
mod search_sorted;
7
7
8
- use vortex_error:: VortexResult ;
8
+ use num_traits:: { CheckedMul , ToPrimitive } ;
9
+ use vortex_dtype:: { NativePType , PType , match_each_native_ptype} ;
10
+ use vortex_error:: { VortexExpect , VortexResult , vortex_err} ;
9
11
use vortex_mask:: Mask ;
10
- use vortex_scalar:: Scalar ;
12
+ use vortex_scalar:: { FromPrimitiveOrF16 , PrimitiveScalar , Scalar } ;
11
13
12
14
use crate :: arrays:: ConstantEncoding ;
13
15
use crate :: arrays:: constant:: ConstantArray ;
14
16
use crate :: compute:: {
15
17
BinaryBooleanFn , BinaryNumericFn , CastFn , CompareFn , FilterFn , InvertFn , ScalarAtFn ,
16
- SearchSortedFn , SliceFn , TakeFn , UncompressedSizeFn ,
18
+ SearchSortedFn , SliceFn , SumFn , TakeFn , UncompressedSizeFn ,
17
19
} ;
20
+ use crate :: stats:: Stat ;
18
21
use crate :: vtable:: ComputeVTable ;
19
22
use crate :: { Array , ArrayRef } ;
20
23
@@ -62,6 +65,10 @@ impl ComputeVTable for ConstantEncoding {
62
65
fn uncompressed_size_fn ( & self ) -> Option < & dyn UncompressedSizeFn < & dyn Array > > {
63
66
Some ( self )
64
67
}
68
+
69
+ fn sum_fn ( & self ) -> Option < & dyn SumFn < & dyn Array > > {
70
+ Some ( self )
71
+ }
65
72
}
66
73
67
74
impl ScalarAtFn < & ConstantArray > for ConstantEncoding {
@@ -100,6 +107,51 @@ impl UncompressedSizeFn<&ConstantArray> for ConstantEncoding {
100
107
}
101
108
}
102
109
110
+ impl SumFn < & ConstantArray > for ConstantEncoding {
111
+ fn sum ( & self , array : & ConstantArray ) -> VortexResult < Scalar > {
112
+ let sum_dtype = Stat :: Sum
113
+ . dtype ( array. dtype ( ) )
114
+ . ok_or_else ( || vortex_err ! ( "Sum not supported for dtype {}" , array. dtype( ) ) ) ?;
115
+ let sum_ptype = PType :: try_from ( & sum_dtype) . vortex_expect ( "sum dtype must be primitive" ) ;
116
+
117
+ let scalar = array. scalar ( ) ;
118
+
119
+ let scalar_value = match_each_native_ptype ! (
120
+ sum_ptype,
121
+ unsigned: |$T | { sum_integral:: <u64 >( scalar. as_primitive( ) , array. len( ) ) ?. into( ) }
122
+ signed: |$T | { sum_integral:: <i64 >( scalar. as_primitive( ) , array. len( ) ) ?. into( ) }
123
+ floating: |$T | { sum_float( scalar. as_primitive( ) , array. len( ) ) ?. into( ) }
124
+ ) ;
125
+
126
+ Ok ( Scalar :: new ( sum_dtype, scalar_value) )
127
+ }
128
+ }
129
+
130
+ fn sum_integral < T > (
131
+ primitive_scalar : PrimitiveScalar < ' _ > ,
132
+ array_len : usize ,
133
+ ) -> VortexResult < Option < T > >
134
+ where
135
+ T : FromPrimitiveOrF16 + NativePType + CheckedMul ,
136
+ Scalar : From < Option < T > > ,
137
+ {
138
+ let v = primitive_scalar. as_ :: < T > ( ) ?;
139
+ let array_len =
140
+ T :: from ( array_len) . ok_or_else ( || vortex_err ! ( "array_len must fit the sum type" ) ) ?;
141
+ let sum = v. and_then ( |v| v. checked_mul ( & array_len) ) ;
142
+
143
+ Ok ( sum)
144
+ }
145
+
146
+ fn sum_float ( primitive_scalar : PrimitiveScalar < ' _ > , array_len : usize ) -> VortexResult < Option < f64 > > {
147
+ let v = primitive_scalar. as_ :: < f64 > ( ) ?;
148
+ let array_len = array_len
149
+ . to_f64 ( )
150
+ . ok_or_else ( || vortex_err ! ( "array_len must fit the sum type" ) ) ?;
151
+
152
+ Ok ( v. map ( |v| v * array_len) )
153
+ }
154
+
103
155
#[ cfg( test) ]
104
156
mod test {
105
157
use vortex_dtype:: half:: f16;
0 commit comments