Skip to content

Commit 627abd7

Browse files
authored
Minor: rename GetFieldAccessCharacteristic and add docs (#7220)
* Minor: rename `GetFieldAccessCharacteristic` and add docs * Update datafusion/expr/src/field_util.rs
1 parent 3d917a0 commit 627abd7

File tree

3 files changed

+96
-128
lines changed

3 files changed

+96
-128
lines changed

datafusion/expr/src/expr_schema.rs

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ use crate::expr::{
2121
GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort,
2222
TryCast, WindowFunction,
2323
};
24-
use crate::field_util::{get_indexed_field, GetFieldAccessCharacteristic};
24+
use crate::field_util::GetFieldAccessSchema;
2525
use crate::type_coercion::binary::get_result_type;
2626
use crate::{LogicalPlan, Projection, Subquery};
2727
use arrow::compute::can_cast_types;
28-
use arrow::datatypes::DataType;
28+
use arrow::datatypes::{DataType, Field};
2929
use datafusion_common::{
3030
plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema, Result,
3131
};
@@ -157,26 +157,7 @@ impl ExprSchemable for Expr {
157157
Ok(DataType::Null)
158158
}
159159
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
160-
let expr_dt = expr.get_type(schema)?;
161-
let field_ch = match field {
162-
GetFieldAccess::NamedStructField { name } => {
163-
GetFieldAccessCharacteristic::NamedStructField {
164-
name: name.clone(),
165-
}
166-
}
167-
GetFieldAccess::ListIndex { key } => {
168-
GetFieldAccessCharacteristic::ListIndex {
169-
key_dt: key.get_type(schema)?,
170-
}
171-
}
172-
GetFieldAccess::ListRange { start, stop } => {
173-
GetFieldAccessCharacteristic::ListRange {
174-
start_dt: start.get_type(schema)?,
175-
stop_dt: stop.get_type(schema)?,
176-
}
177-
}
178-
};
179-
get_indexed_field(&expr_dt, &field_ch).map(|x| x.data_type().clone())
160+
field_for_index(expr, field, schema).map(|x| x.data_type().clone())
180161
}
181162
}
182163
}
@@ -285,26 +266,7 @@ impl ExprSchemable for Expr {
285266
.to_owned(),
286267
)),
287268
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
288-
let expr_dt = expr.get_type(input_schema)?;
289-
let field_ch = match field {
290-
GetFieldAccess::NamedStructField { name } => {
291-
GetFieldAccessCharacteristic::NamedStructField {
292-
name: name.clone(),
293-
}
294-
}
295-
GetFieldAccess::ListIndex { key } => {
296-
GetFieldAccessCharacteristic::ListIndex {
297-
key_dt: key.get_type(input_schema)?,
298-
}
299-
}
300-
GetFieldAccess::ListRange { start, stop } => {
301-
GetFieldAccessCharacteristic::ListRange {
302-
start_dt: start.get_type(input_schema)?,
303-
stop_dt: stop.get_type(input_schema)?,
304-
}
305-
}
306-
};
307-
get_indexed_field(&expr_dt, &field_ch).map(|x| x.is_nullable())
269+
field_for_index(expr, field, input_schema).map(|x| x.is_nullable())
308270
}
309271
Expr::GroupingSet(_) => {
310272
// grouping sets do not really have the concept of nullable and do not appear
@@ -373,6 +335,28 @@ impl ExprSchemable for Expr {
373335
}
374336
}
375337

338+
/// Returns the schema [`Field`] for the type referenced by `get_indexed_field`
339+
fn field_for_index<S: ExprSchema>(
340+
expr: &Expr,
341+
field: &GetFieldAccess,
342+
schema: &S,
343+
) -> Result<Field> {
344+
let expr_dt = expr.get_type(schema)?;
345+
match field {
346+
GetFieldAccess::NamedStructField { name } => {
347+
GetFieldAccessSchema::NamedStructField { name: name.clone() }
348+
}
349+
GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex {
350+
key_dt: key.get_type(schema)?,
351+
},
352+
GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange {
353+
start_dt: start.get_type(schema)?,
354+
stop_dt: stop.get_type(schema)?,
355+
},
356+
}
357+
.get_accessed_field(&expr_dt)
358+
}
359+
376360
/// cast subquery in InSubquery/ScalarSubquery to a given type.
377361
pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
378362
if subquery.subquery.schema().field(0).data_type() == cast_to_type {

datafusion/expr/src/field_util.rs

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,64 +20,64 @@
2020
use arrow::datatypes::{DataType, Field};
2121
use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue};
2222

23-
pub enum GetFieldAccessCharacteristic {
24-
/// returns the field `struct[field]`. For example `struct["name"]`
23+
/// Types of the field access expression of a nested type, such as `Struct` or `List`
24+
pub enum GetFieldAccessSchema {
25+
/// Named field, for example `struct["name"]`
2526
NamedStructField { name: ScalarValue },
26-
/// single list index
27-
// list[i]
27+
/// Single list index, for example: `list[i]`
2828
ListIndex { key_dt: DataType },
29-
/// list range `list[i:j]`
29+
/// List range, for example `list[i:j]`
3030
ListRange {
3131
start_dt: DataType,
3232
stop_dt: DataType,
3333
},
3434
}
3535

36-
/// Returns the field access indexed by `key` and/or `extra_key` from a [`DataType::List`] or [`DataType::Struct`]
37-
/// # Error
38-
/// Errors if
39-
/// * the `data_type` is not a Struct or a List,
40-
/// * the `data_type` of extra key does not match with `data_type` of key
41-
/// * there is no field key is not of the required index type
42-
pub fn get_indexed_field(
43-
data_type: &DataType,
44-
field_characteristic: &GetFieldAccessCharacteristic,
45-
) -> Result<Field> {
46-
match field_characteristic {
47-
GetFieldAccessCharacteristic::NamedStructField{ name } => {
48-
match (data_type, name) {
49-
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
50-
if s.is_empty() {
51-
plan_err!(
52-
"Struct based indexed access requires a non empty string"
53-
)
54-
} else {
55-
let field = fields.iter().find(|f| f.name() == s);
56-
field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone())
36+
impl GetFieldAccessSchema {
37+
/// Returns the schema [`Field`] from a [`DataType::List`] or
38+
/// [`DataType::Struct`] indexed by this structure
39+
///
40+
/// # Errors
41+
/// Errors if
42+
/// * the `data_type` is not a Struct or a List,
43+
/// * the `data_type` of the name/index/start-stop does not match a supported index type
44+
pub fn get_accessed_field(&self, data_type: &DataType) -> Result<Field> {
45+
match self {
46+
Self::NamedStructField{ name } => {
47+
match (data_type, name) {
48+
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
49+
if s.is_empty() {
50+
plan_err!(
51+
"Struct based indexed access requires a non empty string"
52+
)
53+
} else {
54+
let field = fields.iter().find(|f| f.name() == s);
55+
field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone())
56+
}
5757
}
58+
(DataType::Struct(_), _) => plan_err!(
59+
"Only utf8 strings are valid as an indexed field in a struct"
60+
),
61+
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
5862
}
59-
(DataType::Struct(_), _) => plan_err!(
60-
"Only utf8 strings are valid as an indexed field in a struct"
61-
),
62-
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
6363
}
64-
}
65-
GetFieldAccessCharacteristic::ListIndex{ key_dt } => {
66-
match (data_type, key_dt) {
67-
(DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)),
68-
(DataType::List(_), _) => plan_err!(
69-
"Only ints are valid as an indexed field in a list"
70-
),
71-
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
64+
Self::ListIndex{ key_dt } => {
65+
match (data_type, key_dt) {
66+
(DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)),
67+
(DataType::List(_), _) => plan_err!(
68+
"Only ints are valid as an indexed field in a list"
69+
),
70+
(other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
71+
}
7272
}
73-
}
74-
GetFieldAccessCharacteristic::ListRange{ start_dt, stop_dt } => {
75-
match (data_type, start_dt, stop_dt) {
76-
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
77-
(DataType::List(_), _, _) => plan_err!(
78-
"Only ints are valid as an indexed field in a list"
79-
),
80-
(other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
73+
Self::ListRange{ start_dt, stop_dt } => {
74+
match (data_type, start_dt, stop_dt) {
75+
(DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)),
76+
(DataType::List(_), _, _) => plan_err!(
77+
"Only ints are valid as an indexed field in a list"
78+
),
79+
(other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"),
80+
}
8181
}
8282
}
8383
}

datafusion/physical-expr/src/expressions/get_indexed_field.rs

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,7 @@ use arrow::{
2727
record_batch::RecordBatch,
2828
};
2929
use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue};
30-
use datafusion_expr::{
31-
field_util::{
32-
get_indexed_field as get_data_type_field, GetFieldAccessCharacteristic,
33-
},
34-
ColumnarValue,
35-
};
30+
use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue};
3631
use std::fmt::Debug;
3732
use std::hash::{Hash, Hasher};
3833
use std::{any::Any, sync::Arc};
@@ -120,6 +115,23 @@ impl GetIndexedFieldExpr {
120115
pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
121116
&self.arg
122117
}
118+
119+
fn schema_access(&self, input_schema: &Schema) -> Result<GetFieldAccessSchema> {
120+
Ok(match &self.field {
121+
GetFieldAccessExpr::NamedStructField { name } => {
122+
GetFieldAccessSchema::NamedStructField { name: name.clone() }
123+
}
124+
GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex {
125+
key_dt: key.data_type(input_schema)?,
126+
},
127+
GetFieldAccessExpr::ListRange { start, stop } => {
128+
GetFieldAccessSchema::ListRange {
129+
start_dt: start.data_type(input_schema)?,
130+
stop_dt: stop.data_type(input_schema)?,
131+
}
132+
}
133+
})
134+
}
123135
}
124136

125137
impl std::fmt::Display for GetIndexedFieldExpr {
@@ -135,44 +147,16 @@ impl PhysicalExpr for GetIndexedFieldExpr {
135147

136148
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
137149
let arg_dt = self.arg.data_type(input_schema)?;
138-
let field_ch = match &self.field {
139-
GetFieldAccessExpr::NamedStructField { name } => {
140-
GetFieldAccessCharacteristic::NamedStructField { name: name.clone() }
141-
}
142-
GetFieldAccessExpr::ListIndex { key } => {
143-
GetFieldAccessCharacteristic::ListIndex {
144-
key_dt: key.data_type(input_schema)?,
145-
}
146-
}
147-
GetFieldAccessExpr::ListRange { start, stop } => {
148-
GetFieldAccessCharacteristic::ListRange {
149-
start_dt: start.data_type(input_schema)?,
150-
stop_dt: stop.data_type(input_schema)?,
151-
}
152-
}
153-
};
154-
get_data_type_field(&arg_dt, &field_ch).map(|f| f.data_type().clone())
150+
self.schema_access(input_schema)?
151+
.get_accessed_field(&arg_dt)
152+
.map(|f| f.data_type().clone())
155153
}
156154

157155
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
158156
let arg_dt = self.arg.data_type(input_schema)?;
159-
let field_ch = match &self.field {
160-
GetFieldAccessExpr::NamedStructField { name } => {
161-
GetFieldAccessCharacteristic::NamedStructField { name: name.clone() }
162-
}
163-
GetFieldAccessExpr::ListIndex { key } => {
164-
GetFieldAccessCharacteristic::ListIndex {
165-
key_dt: key.data_type(input_schema)?,
166-
}
167-
}
168-
GetFieldAccessExpr::ListRange { start, stop } => {
169-
GetFieldAccessCharacteristic::ListRange {
170-
start_dt: start.data_type(input_schema)?,
171-
stop_dt: stop.data_type(input_schema)?,
172-
}
173-
}
174-
};
175-
get_data_type_field(&arg_dt, &field_ch).map(|f| f.is_nullable())
157+
self.schema_access(input_schema)?
158+
.get_accessed_field(&arg_dt)
159+
.map(|f| f.is_nullable())
176160
}
177161

178162
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {

0 commit comments

Comments
 (0)