|
20 | 20 | use arrow::array::ArrayRef;
|
21 | 21 | use arrow::array::NullArray;
|
22 | 22 | use arrow::datatypes::DataType;
|
23 |
| -use datafusion_common::{Result, ScalarValue}; |
| 23 | +use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; |
24 | 24 | use std::sync::Arc;
|
25 | 25 |
|
26 | 26 | /// Represents the result of evaluating an expression: either a single
|
@@ -75,4 +75,166 @@ impl ColumnarValue {
|
75 | 75 | pub fn create_null_array(num_rows: usize) -> Self {
|
76 | 76 | ColumnarValue::Array(Arc::new(NullArray::new(num_rows)))
|
77 | 77 | }
|
| 78 | + |
| 79 | + /// Converts [`ColumnarValue`]s to [`ArrayRef`]s with the same length. |
| 80 | + /// |
| 81 | + /// # Performance Note |
| 82 | + /// |
| 83 | + /// This function expands any [`ScalarValue`] to an array. This expansion |
| 84 | + /// permits using a single function in terms of arrays, but it can be |
| 85 | + /// inefficient compared to handling the scalar value directly. |
| 86 | + /// |
| 87 | + /// Thus, it is recommended to provide specialized implementations for |
| 88 | + /// scalar values if performance is a concern. |
| 89 | + /// |
| 90 | + /// # Errors |
| 91 | + /// |
| 92 | + /// If there are multiple array arguments that have different lengths |
| 93 | + pub fn values_to_arrays(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> { |
| 94 | + if args.is_empty() { |
| 95 | + return Ok(vec![]); |
| 96 | + } |
| 97 | + |
| 98 | + let mut array_len = None; |
| 99 | + for arg in args { |
| 100 | + array_len = match (arg, array_len) { |
| 101 | + (ColumnarValue::Array(a), None) => Some(a.len()), |
| 102 | + (ColumnarValue::Array(a), Some(array_len)) => { |
| 103 | + if array_len == a.len() { |
| 104 | + Some(array_len) |
| 105 | + } else { |
| 106 | + return internal_err!( |
| 107 | + "Arguments has mixed length. Expected length: {array_len}, found length: {}", a.len() |
| 108 | + ); |
| 109 | + } |
| 110 | + } |
| 111 | + (ColumnarValue::Scalar(_), array_len) => array_len, |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + // If array_len is none, it means there are only scalars, so make a 1 element array |
| 116 | + let inferred_length = array_len.unwrap_or(1); |
| 117 | + |
| 118 | + let args = args |
| 119 | + .iter() |
| 120 | + .map(|arg| arg.clone().into_array(inferred_length)) |
| 121 | + .collect::<Result<Vec<_>>>()?; |
| 122 | + |
| 123 | + Ok(args) |
| 124 | + } |
| 125 | +} |
| 126 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Int32` array of length `len` whose elements all equal `val`.
    fn make_array(val: i32, len: usize) -> ArrayRef {
        let values: Vec<i32> = std::iter::repeat(val).take(len).collect();
        Arc::new(arrow::array::Int32Array::from(values))
    }

    /// Wraps `make_array(val, len)` as a [`ColumnarValue::Array`].
    fn array_arg(val: i32, len: usize) -> ColumnarValue {
        ColumnarValue::Array(make_array(val, len))
    }

    /// Wraps `val` as a non-null `Int32` [`ColumnarValue::Scalar`].
    fn scalar_arg(val: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(val)))
    }

    #[test]
    fn values_to_arrays() {
        // Each entry is an `(input, expected)` pair.
        let cases: Vec<(Vec<ColumnarValue>, Vec<ArrayRef>)> = vec![
            // empty
            (vec![], vec![]),
            // one array of length 3
            (vec![array_arg(1, 3)], vec![make_array(1, 3)]),
            // two arrays length 3
            (
                vec![array_arg(1, 3), array_arg(2, 3)],
                vec![make_array(1, 3), make_array(2, 3)],
            ),
            // array and scalar: the scalar is expanded to length 3
            (
                vec![array_arg(1, 3), scalar_arg(100)],
                vec![make_array(1, 3), make_array(100, 3)],
            ),
            // scalar and array: the scalar is expanded to length 3
            (
                vec![scalar_arg(100), array_arg(1, 3)],
                vec![make_array(100, 3), make_array(1, 3)],
            ),
            // multiple scalars and array: both scalars are expanded
            (
                vec![scalar_arg(100), array_arg(1, 3), scalar_arg(200)],
                vec![make_array(100, 3), make_array(1, 3), make_array(200, 3)],
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(
                ColumnarValue::values_to_arrays(&input).unwrap(),
                expected,
                "\ninput: {input:?}\nexpected: {expected:?}"
            );
        }
    }

    #[test]
    #[should_panic(
        expected = "Arguments has mixed length. Expected length: 3, found length: 4"
    )]
    fn values_to_arrays_mixed_length() {
        // Two arrays whose lengths disagree must be rejected.
        let args = [array_arg(1, 3), array_arg(2, 4)];
        ColumnarValue::values_to_arrays(&args).unwrap();
    }

    #[test]
    #[should_panic(
        expected = "Arguments has mixed length. Expected length: 3, found length: 7"
    )]
    fn values_to_arrays_mixed_length_and_scalar() {
        // A scalar between the two arrays must not hide the length mismatch.
        let args = [array_arg(1, 3), scalar_arg(100), array_arg(2, 7)];
        ColumnarValue::values_to_arrays(&args).unwrap();
    }
}
|
0 commit comments