Skip to content

Commit ca62ac9

Browse files
committed
Casting between REE arrays with different run types and from primitive values to primitive arrays
Added copyright header
1 parent b1f5c25 commit ca62ac9

File tree

1 file changed

+313
-0
lines changed

1 file changed

+313
-0
lines changed

arrow-cast/src/runend.rs

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow_array::{
21+
types::RunEndIndexType, Array, ArrayRef, ArrowPrimitiveType, Date32Array, Date64Array,
22+
Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray,
23+
DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, Float64Array,
24+
Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, IntervalYearMonthArray,
25+
PrimitiveArray, RunArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
26+
Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
27+
TimestampNanosecondArray, TimestampSecondArray, TypedRunArray, UInt16Array, UInt32Array,
28+
UInt64Array, UInt8Array,
29+
};
30+
use arrow_buffer::ArrowNativeType;
31+
use arrow_schema::{ArrowError, DataType};
32+
33+
use crate::cast_with_options;
34+
35+
use super::CastOptions;
36+
37+
/// Attempt to cast a run-encoded array into a new type.
38+
///
39+
/// `K` is the *current* run end index type
40+
pub(crate) fn run_end_cast<K: RunEndIndexType>(
41+
array: &dyn Array,
42+
to_type: &DataType,
43+
cast_options: &CastOptions,
44+
) -> Result<ArrayRef, ArrowError> {
45+
let ree_array = array
46+
.as_any()
47+
.downcast_ref::<RunArray<K>>()
48+
.ok_or_else(|| {
49+
ArrowError::ComputeError(
50+
"Internal Error: Cannot cast run end array to RunArray of the expected type"
51+
.to_string(),
52+
)
53+
})?;
54+
55+
match to_type {
56+
// Potentially convert to a new value or run end type
57+
DataType::RunEndEncoded(re_t, dt) => {
58+
let values = cast_with_options(ree_array.values(), dt.data_type(), cast_options)?;
59+
let re = PrimitiveArray::<K>::new(ree_array.run_ends().inner().clone(), None);
60+
let re = cast_with_options(&re, re_t.data_type(), cast_options)?;
61+
62+
// TODO: we shouldn't need to validate the new run length array
63+
// since we can assume we are converting from a valid one, but
64+
// there's no "unchecked" variant yet
65+
let result: Arc<dyn Array> = match re.data_type() {
66+
DataType::Int16 => Arc::new(RunArray::try_new(
67+
re.as_any().downcast_ref::<Int16Array>().unwrap(),
68+
&values,
69+
)?),
70+
DataType::Int32 => Arc::new(RunArray::try_new(
71+
re.as_any().downcast_ref::<Int32Array>().unwrap(),
72+
&values,
73+
)?),
74+
DataType::Int64 => Arc::new(RunArray::try_new(
75+
re.as_any().downcast_ref::<Int64Array>().unwrap(),
76+
&values,
77+
)?),
78+
_ => Err(ArrowError::ComputeError(format!(
79+
"Invalid run end type requested during cast: {:?}",
80+
re.data_type()
81+
)))?,
82+
};
83+
84+
Ok(result.slice(ree_array.run_ends().offset(), ree_array.run_ends().len()))
85+
}
86+
// Convert to a primitive value
87+
DataType::Date32
88+
| DataType::Date64
89+
| DataType::Time32(_)
90+
| DataType::Time64(_)
91+
| DataType::Decimal128(_, _)
92+
| DataType::Decimal256(_, _)
93+
| DataType::Timestamp(_, _)
94+
| DataType::Duration(_)
95+
| DataType::Interval(_)
96+
| DataType::Int8
97+
| DataType::Int16
98+
| DataType::Int32
99+
| DataType::Int64
100+
| DataType::UInt8
101+
| DataType::UInt16
102+
| DataType::UInt32
103+
| DataType::UInt64
104+
| DataType::Float16
105+
| DataType::Float32
106+
| DataType::Float64 => {
107+
// TODO this could be somewhat inefficent, since the run encoded
108+
// array is initially transformed into a primitive array of the same
109+
// type, then casted to the (potentially) new type. For example,
110+
// casting a run encoded array of Float32 to Float64 will first
111+
// create a primitive array of Float32s, then convert that primitive
112+
// array to Float64.
113+
cast_with_options(&run_array_to_primitive(ree_array)?, to_type, cast_options)
114+
}
115+
_ => todo!(),
116+
}
117+
}
118+
119+
/// Converts a run array of primitive values into a primitive array, without changing the type
120+
fn run_array_to_primitive<R: RunEndIndexType>(ra: &RunArray<R>) -> Result<ArrayRef, ArrowError> {
121+
let prim = match ra.values().data_type() {
122+
DataType::Int8 => typed_run_array_to_primitive(ra.downcast::<Int8Array>().unwrap()),
123+
DataType::Int16 => typed_run_array_to_primitive(ra.downcast::<Int16Array>().unwrap()),
124+
DataType::Int32 => typed_run_array_to_primitive(ra.downcast::<Int32Array>().unwrap()),
125+
DataType::Int64 => typed_run_array_to_primitive(ra.downcast::<Int64Array>().unwrap()),
126+
DataType::UInt8 => typed_run_array_to_primitive(ra.downcast::<UInt8Array>().unwrap()),
127+
DataType::UInt16 => typed_run_array_to_primitive(ra.downcast::<UInt16Array>().unwrap()),
128+
DataType::UInt32 => typed_run_array_to_primitive(ra.downcast::<UInt32Array>().unwrap()),
129+
DataType::UInt64 => typed_run_array_to_primitive(ra.downcast::<UInt64Array>().unwrap()),
130+
DataType::Float16 => typed_run_array_to_primitive(ra.downcast::<Float16Array>().unwrap()),
131+
DataType::Float32 => typed_run_array_to_primitive(ra.downcast::<Float32Array>().unwrap()),
132+
DataType::Float64 => typed_run_array_to_primitive(ra.downcast::<Float64Array>().unwrap()),
133+
DataType::Date32 => typed_run_array_to_primitive(ra.downcast::<Date32Array>().unwrap()),
134+
DataType::Date64 => typed_run_array_to_primitive(ra.downcast::<Date64Array>().unwrap()),
135+
DataType::Time32(arrow_schema::TimeUnit::Second) => {
136+
typed_run_array_to_primitive(ra.downcast::<Time32SecondArray>().unwrap())
137+
}
138+
DataType::Time32(arrow_schema::TimeUnit::Millisecond) => {
139+
typed_run_array_to_primitive(ra.downcast::<Time32MillisecondArray>().unwrap())
140+
}
141+
DataType::Time64(arrow_schema::TimeUnit::Microsecond) => {
142+
typed_run_array_to_primitive(ra.downcast::<Time64MicrosecondArray>().unwrap())
143+
}
144+
DataType::Time64(arrow_schema::TimeUnit::Nanosecond) => {
145+
typed_run_array_to_primitive(ra.downcast::<Time64NanosecondArray>().unwrap())
146+
}
147+
DataType::Decimal128(_, _) => {
148+
typed_run_array_to_primitive(ra.downcast::<Decimal128Array>().unwrap())
149+
}
150+
DataType::Decimal256(_, _) => {
151+
typed_run_array_to_primitive(ra.downcast::<Decimal256Array>().unwrap())
152+
}
153+
DataType::Timestamp(arrow_schema::TimeUnit::Second, _) => {
154+
typed_run_array_to_primitive(ra.downcast::<TimestampSecondArray>().unwrap())
155+
}
156+
DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
157+
typed_run_array_to_primitive(ra.downcast::<TimestampMillisecondArray>().unwrap())
158+
}
159+
DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, _) => {
160+
typed_run_array_to_primitive(ra.downcast::<TimestampMicrosecondArray>().unwrap())
161+
}
162+
163+
DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, _) => {
164+
typed_run_array_to_primitive(ra.downcast::<TimestampNanosecondArray>().unwrap())
165+
}
166+
DataType::Duration(arrow_schema::TimeUnit::Second) => {
167+
typed_run_array_to_primitive(ra.downcast::<DurationSecondArray>().unwrap())
168+
}
169+
DataType::Duration(arrow_schema::TimeUnit::Millisecond) => {
170+
typed_run_array_to_primitive(ra.downcast::<DurationMillisecondArray>().unwrap())
171+
}
172+
DataType::Duration(arrow_schema::TimeUnit::Microsecond) => {
173+
typed_run_array_to_primitive(ra.downcast::<DurationMicrosecondArray>().unwrap())
174+
}
175+
DataType::Duration(arrow_schema::TimeUnit::Nanosecond) => {
176+
typed_run_array_to_primitive(ra.downcast::<DurationNanosecondArray>().unwrap())
177+
}
178+
DataType::Interval(arrow_schema::IntervalUnit::YearMonth) => {
179+
typed_run_array_to_primitive(ra.downcast::<IntervalYearMonthArray>().unwrap())
180+
}
181+
DataType::Interval(arrow_schema::IntervalUnit::DayTime) => {
182+
typed_run_array_to_primitive(ra.downcast::<IntervalDayTimeArray>().unwrap())
183+
}
184+
DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano) => {
185+
typed_run_array_to_primitive(ra.downcast::<IntervalYearMonthArray>().unwrap())
186+
}
187+
_ => {
188+
return Err(ArrowError::ComputeError(format!(
189+
"Cannot convert run-end encoded array of type {:?} to primitive type",
190+
ra.values().data_type()
191+
)))
192+
}
193+
};
194+
195+
Ok(prim)
196+
}
197+
198+
/// "Unroll" a run-end encoded array of primitive values into a primitive array.
199+
/// This function should be efficient for long run lenghts due to the use of
200+
/// Builder's `append_value_n`
201+
fn typed_run_array_to_primitive<R: RunEndIndexType, T: ArrowPrimitiveType>(
202+
arr: TypedRunArray<R, PrimitiveArray<T>>,
203+
) -> ArrayRef {
204+
let mut builder = PrimitiveArray::<T>::builder(
205+
arr.run_ends()
206+
.values()
207+
.last()
208+
.map(|end| end.as_usize())
209+
.unwrap_or(0),
210+
);
211+
212+
let mut last = 0;
213+
for (run_end, val) in arr
214+
.run_ends()
215+
.values()
216+
.iter()
217+
.zip(arr.values().values().iter().copied())
218+
{
219+
let run_end = run_end.as_usize();
220+
let run_length = run_end - last;
221+
builder.append_value_n(val, run_length);
222+
last = run_end;
223+
}
224+
225+
// TODO: this slice could be optimized by only copying the relevant parts of
226+
// the array, but this might be tricky to get right because a slice can
227+
// start or end in the middle of a run.
228+
Arc::new(builder.finish().slice(arr.offset(), arr.len()))
229+
}
230+
231+
#[cfg(test)]
232+
mod tests {
233+
use arrow_schema::Field;
234+
235+
use crate::can_cast_types;
236+
237+
use super::*;
238+
239+
#[test]
240+
fn test_can_cast_run_ends() {
241+
let re_i64 = Arc::new(Field::new("run ends", DataType::Int64, false));
242+
let re_i32 = Arc::new(Field::new("run ends", DataType::Int64, false));
243+
let va_f64 = Arc::new(Field::new("values", DataType::Float64, true));
244+
let va_str = Arc::new(Field::new("values", DataType::Utf8, true));
245+
246+
// can change run end type of non-primitive
247+
assert!(can_cast_types(
248+
&DataType::RunEndEncoded(re_i32.clone(), va_str.clone()),
249+
&DataType::RunEndEncoded(re_i64.clone(), va_str.clone())
250+
));
251+
252+
// can cast from primitive type to primitive
253+
assert!(can_cast_types(
254+
&DataType::RunEndEncoded(re_i32.clone(), va_f64.clone()),
255+
&DataType::Float64
256+
));
257+
258+
// cannot cast from non-primitive to flat array
259+
assert!(!can_cast_types(
260+
&DataType::RunEndEncoded(re_i32.clone(), va_str.clone()),
261+
&DataType::Utf8
262+
));
263+
}
264+
265+
#[test]
266+
fn test_run_end_to_primitive() {
267+
let run_ends = vec![2, 4, 5];
268+
let values = vec![10, 20, 30];
269+
let ree =
270+
RunArray::try_new(&Int32Array::from(run_ends), &Int32Array::from(values)).unwrap();
271+
272+
let result = cast_with_options(&ree, &DataType::Int32, &CastOptions::default()).unwrap();
273+
274+
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
275+
assert_eq!(result.values(), &[10, 10, 20, 20, 30]);
276+
}
277+
278+
#[test]
279+
fn test_run_end_sliced_to_primitive() {
280+
let run_ends = vec![2, 4, 5];
281+
let values = vec![10, 20, 30];
282+
let ree = RunArray::try_new(&Int32Array::from(run_ends), &Int32Array::from(values))
283+
.unwrap()
284+
.slice(1, 3);
285+
286+
let result = cast_with_options(&ree, &DataType::Int32, &CastOptions::default()).unwrap();
287+
288+
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
289+
assert_eq!(result.values(), &[10, 20, 20]);
290+
}
291+
292+
#[test]
293+
fn test_run_end_to_run_end() {
294+
let run_ends = vec![2, 4, 5];
295+
let values = vec![10, 20, 30];
296+
let ree =
297+
RunArray::try_new(&Int32Array::from(run_ends), &Int32Array::from(values)).unwrap();
298+
299+
let new_re_type = Field::new("run ends", DataType::Int64, false);
300+
let new_va_type = Field::new("values", DataType::Float64, true);
301+
let result = cast_with_options(
302+
&ree,
303+
&DataType::RunEndEncoded(Arc::new(new_re_type), Arc::new(new_va_type)),
304+
&CastOptions::default(),
305+
)
306+
.unwrap();
307+
308+
let result =
309+
cast_with_options(&result, &DataType::Float64, &CastOptions::default()).unwrap();
310+
let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
311+
assert_eq!(result.values(), &[10.0, 10.0, 20.0, 20.0, 30.0]);
312+
}
313+
}

0 commit comments

Comments
 (0)