Skip to content

Commit cdad4d4

Browse files
AdamGSrobert3005
andauthored
Wire SumFn for chunked and constant (#2606)
Co-authored-by: Robert Kruszewski <[email protected]>
1 parent 5fe6803 commit cdad4d4

File tree

2 files changed

+61
-5
lines changed
  • vortex-array/src/arrays

2 files changed

+61
-5
lines changed

vortex-array/src/arrays/chunked/compute/mod.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use crate::arrays::ChunkedEncoding;
55
use crate::arrays::chunked::ChunkedArray;
66
use crate::compute::{
77
BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FillNullFn, FilterFn, InvertFn,
8-
IsConstantFn, IsSortedFn, MaskFn, MinMaxFn, ScalarAtFn, SliceFn, TakeFn, UncompressedSizeFn,
9-
try_cast,
8+
IsConstantFn, IsSortedFn, MaskFn, MinMaxFn, ScalarAtFn, SliceFn, SumFn, TakeFn,
9+
UncompressedSizeFn, try_cast,
1010
};
1111
use crate::vtable::ComputeVTable;
1212
use crate::{Array, ArrayRef};
@@ -87,6 +87,10 @@ impl ComputeVTable for ChunkedEncoding {
8787
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
8888
Some(self)
8989
}
90+
91+
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
92+
Some(self)
93+
}
9094
}
9195

9296
impl CastFn<&ChunkedArray> for ChunkedEncoding {

vortex-array/src/arrays/constant/compute/mod.rs

+55-3
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@ mod compare;
55
mod invert;
66
mod search_sorted;
77

8-
use vortex_error::VortexResult;
8+
use num_traits::{CheckedMul, ToPrimitive};
9+
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
10+
use vortex_error::{VortexExpect, VortexResult, vortex_err};
911
use vortex_mask::Mask;
10-
use vortex_scalar::Scalar;
12+
use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
1113

1214
use crate::arrays::ConstantEncoding;
1315
use crate::arrays::constant::ConstantArray;
1416
use crate::compute::{
1517
BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FilterFn, InvertFn, ScalarAtFn,
16-
SearchSortedFn, SliceFn, TakeFn, UncompressedSizeFn,
18+
SearchSortedFn, SliceFn, SumFn, TakeFn, UncompressedSizeFn,
1719
};
20+
use crate::stats::Stat;
1821
use crate::vtable::ComputeVTable;
1922
use crate::{Array, ArrayRef};
2023

@@ -62,6 +65,10 @@ impl ComputeVTable for ConstantEncoding {
6265
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
6366
Some(self)
6467
}
68+
69+
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
70+
Some(self)
71+
}
6572
}
6673

6774
impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
@@ -100,6 +107,51 @@ impl UncompressedSizeFn<&ConstantArray> for ConstantEncoding {
100107
}
101108
}
102109

110+
impl SumFn<&ConstantArray> for ConstantEncoding {
111+
fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
112+
let sum_dtype = Stat::Sum
113+
.dtype(array.dtype())
114+
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
115+
let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
116+
117+
let scalar = array.scalar();
118+
119+
let scalar_value = match_each_native_ptype!(
120+
sum_ptype,
121+
unsigned: |$T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() }
122+
signed: |$T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() }
123+
floating: |$T| { sum_float(scalar.as_primitive(), array.len())?.into() }
124+
);
125+
126+
Ok(Scalar::new(sum_dtype, scalar_value))
127+
}
128+
}
129+
130+
fn sum_integral<T>(
131+
primitive_scalar: PrimitiveScalar<'_>,
132+
array_len: usize,
133+
) -> VortexResult<Option<T>>
134+
where
135+
T: FromPrimitiveOrF16 + NativePType + CheckedMul,
136+
Scalar: From<Option<T>>,
137+
{
138+
let v = primitive_scalar.as_::<T>()?;
139+
let array_len =
140+
T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
141+
let sum = v.and_then(|v| v.checked_mul(&array_len));
142+
143+
Ok(sum)
144+
}
145+
146+
fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
147+
let v = primitive_scalar.as_::<f64>()?;
148+
let array_len = array_len
149+
.to_f64()
150+
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
151+
152+
Ok(v.map(|v| v * array_len))
153+
}
154+
103155
#[cfg(test)]
104156
mod test {
105157
use vortex_dtype::half::f16;

0 commit comments

Comments
 (0)