Skip to content

Commit 761e167

Browse files
authored
Improve Error Handling and Readability for downcasting StructArray (#4061)
* improve error messages for StructArray
* refactor newly added Date32Array downcasting and correct error string
* beautify code
* changes after code review
* fix formatting
1 parent 61429f8 commit 761e167

File tree

4 files changed

+20
-17
lines changed

4 files changed

+20
-17
lines changed

benchmarks/src/tpch.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
// under the License.
1717

1818
use arrow::array::{
19-
Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
20-
StringArray,
19+
Array, ArrayRef, Decimal128Array, Float64Array, Int32Array, Int64Array, StringArray,
2120
};
2221
use arrow::datatypes::SchemaRef;
2322
use arrow::record_batch::RecordBatch;
@@ -27,6 +26,7 @@ use std::path::Path;
2726
use std::sync::Arc;
2827
use std::time::Instant;
2928

29+
use datafusion::common::cast::as_date32_array;
3030
use datafusion::common::ScalarValue;
3131
use datafusion::logical_expr::Cast;
3232
use datafusion::prelude::*;
@@ -440,7 +440,7 @@ fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
440440
ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
441441
}
442442
DataType::Date32 => {
443-
let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
443+
let array = as_date32_array(column).unwrap();
444444
ScalarValue::Date32(Some(array.value(row_index)))
445445
}
446446
DataType::Utf8 => {

datafusion/common/src/cast.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
//! kernels in arrow-rs such as `as_boolean_array` do.
2222
2323
use crate::DataFusionError;
24-
use arrow::array::{Array, Date32Array};
24+
use arrow::array::{Array, Date32Array, StructArray};
2525

2626
// Downcast ArrayRef to Date32Array
2727
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionError> {
@@ -32,3 +32,13 @@ pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, DataFusionErro
3232
))
3333
})
3434
}
35+
36+
// Downcast a `&dyn Array` to a `&StructArray`.
//
// Returns a reference to the same array viewed as a `StructArray` when the
// downcast succeeds. When it does not, returns `DataFusionError::Internal`
// whose message reports the array's actual data type, instead of panicking.
37+
pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray, DataFusionError> {
38+
array.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
39+
DataFusionError::Internal(format!(
40+
"Expected a StructArray, got: {}",
41+
array.data_type()
42+
))
43+
})
44+
}

datafusion/common/src/scalar.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use arrow::{
3939
use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
4040
use ordered_float::OrderedFloat;
4141

42+
use crate::cast::as_struct_array;
4243
use crate::delta::shift_months;
4344
use crate::error::{DataFusionError, Result};
4445

@@ -2008,15 +2009,7 @@ impl ScalarValue {
20082009
Self::Dictionary(key_type.clone(), Box::new(value))
20092010
}
20102011
DataType::Struct(fields) => {
2011-
let array =
2012-
array
2013-
.as_any()
2014-
.downcast_ref::<StructArray>()
2015-
.ok_or_else(|| {
2016-
DataFusionError::Internal(
2017-
"Failed to downcast ArrayRef to StructArray".to_string(),
2018-
)
2019-
})?;
2012+
let array = as_struct_array(array)?;
20202013
let mut field_values: Vec<ScalarValue> = Vec::new();
20212014
for col_index in 0..array.num_columns() {
20222015
let col_array = array.column(col_index);
@@ -3611,8 +3604,7 @@ mod tests {
36113604
// iter_to_array for struct scalars
36123605
let array =
36133606
ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap();
3614-
let array = array.as_any().downcast_ref::<StructArray>().unwrap();
3615-
3607+
let array = as_struct_array(&array).unwrap();
36163608
let expected = StructArray::from(vec![
36173609
(
36183610
field_a.clone(),

datafusion/physical-expr/src/expressions/get_indexed_field.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
2020
use crate::PhysicalExpr;
2121
use arrow::array::Array;
22-
use arrow::array::{ListArray, StructArray};
22+
use arrow::array::ListArray;
2323
use arrow::compute::concat;
2424

2525
use crate::physical_expr::down_cast_any_ref;
2626
use arrow::{
2727
datatypes::{DataType, Schema},
2828
record_batch::RecordBatch,
2929
};
30+
use datafusion_common::cast::as_struct_array;
3031
use datafusion_common::DataFusionError;
3132
use datafusion_common::Result;
3233
use datafusion_common::ScalarValue;
@@ -122,7 +123,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
122123
}
123124
}
124125
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
125-
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
126+
let as_struct_array = as_struct_array(&array)?;
126127
match as_struct_array.column_by_name(k) {
127128
None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))),
128129
Some(col) => Ok(ColumnarValue::Array(col.clone()))

0 commit comments

Comments
 (0)