diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 947a3f43870e..509dab5474a0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -181,9 +181,28 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let expr = create_physical_name(expr, false)?; Ok(format!("{expr} IS NOT UNKNOWN")) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => { let expr = create_physical_name(expr, false)?; - Ok(format!("{expr}[{key}]")) + + if let (Some(list_key), Some(extra_key)) = (&key.list_key, extra_key) { + let key = create_physical_name(list_key, false)?; + let extra_key = create_physical_name(extra_key, false)?; + Ok(format!("{expr}[{key}:{extra_key}]")) + } else { + let key = if let Some(list_key) = &key.list_key { + create_physical_name(list_key, false)? + } else if let Some(ScalarValue::Utf8(Some(struct_key))) = &key.struct_key + { + struct_key.to_string() + } else { + String::from("") + }; + Ok(format!("{expr}[{key}]")) + } } Expr::ScalarFunction(func) => { create_function_physical_name(&func.fun.to_string(), false, &func.args) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 27d288cf60f0..4088414296f9 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -16,7 +16,7 @@ # under the License. 
############# -## Array expressions Tests +## Array Expressions Tests ############# @@ -55,6 +55,18 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE slices +AS VALUES + (make_array(NULL, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1), + (make_array(11, 12, 13, 14, 15, 16, 17, 18, NULL, 20), 2, -4), + (make_array(21, 22, 23, NULL, 25, 26, 27, 28, 29, 30), 0, 0), + (make_array(31, 32, 33, 34, 35, NULL, 37, 38, 39, 40), -4, -7), + (NULL, 4, 5), + (make_array(41, 42, 43, 44, 45, 46, 47, 48, 49, 50), NULL, 6), + (make_array(51, 52, NULL, 54, 55, 56, 57, 58, 59, 60), 5, NULL) +; + statement ok CREATE TABLE nested_arrays AS VALUES @@ -213,6 +225,18 @@ NULL 44 5 @ [51, 52, , 54, 55, 56, 57, 58, 59, 60] 55 NULL ^ [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] 66 7 NULL +# slices table +query ?II +select column1, column2, column3 from slices; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1 1 +[11, 12, 13, 14, 15, 16, 17, 18, , 20] 2 -4 +[21, 22, 23, , 25, 26, 27, 28, 29, 30] 0 0 +[31, 32, 33, 34, 35, , 37, 38, 39, 40] -4 -7 +NULL 4 5 +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] NULL 6 +[51, 52, , 54, 55, 56, 57, 58, 59, 60] 5 NULL + query ??I? 
select column1, column2, column3, column4 from arrays_values_v2; ---- @@ -250,6 +274,178 @@ select column1, column2, column3, column4 from nested_arrays_with_repeating_elem [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [19, 20, 21] [28, 29, 30] 5 [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [28, 29, 30] [37, 38, 39] 10 + +### Array index + + +## array[i] + +# single index with scalars #1 (positive index) +query IRT +select make_array(1, 2, 3)[1], make_array(1.0, 2.0, 3.0)[2], make_array('h', 'e', 'l', 'l', 'o')[3]; +---- +1 2 l + +# single index with scalars #2 (zero index) +query I +select make_array(1, 2, 3)[0]; +---- +NULL + +# single index with scalars #3 (negative index) +query IRT +select make_array(1, 2, 3)[-1], make_array(1.0, 2.0, 3.0)[-2], make_array('h', 'e', 'l', 'l', 'o')[-3]; +---- +3 2 l + +# single index with scalars #4 (complex index) +query IRT +select make_array(1, 2, 3)[1 + 2 - 1], make_array(1.0, 2.0, 3.0)[2 * 1 * 0 - 2], make_array('h', 'e', 'l', 'l', 'o')[2 - 3]; +---- +2 2 o + +# single index with columns #1 (positive index) +query ?RT +select column1[2], column2[3], column3[1] from arrays; +---- +[3, ] 3.3 L +[5, 6] 6.6 i +[7, 8] 9.9 d +[9, 10] 12.2 s +NULL 15.5 a +[13, 14] NULL , +[, 18] 18.8 NULL + +# single index with columns #2 (zero index) +query ?RT +select column1[0], column2[0], column3[0] from arrays; +---- +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + +# single index with columns #3 (negative index) +query ?RT +select column1[-2], column2[-3], column3[-1] from arrays; +---- +[, 2] 1.1 m +[3, 4] NULL m +[5, 6] 7.7 r +[7, ] 10.1 t +NULL 13.3 t +[11, 12] NULL , +[15, 16] 16.6 NULL + +# single index with columns #4 (complex index) +query ?RT +select column1[9 - 7], column2[2 * 0], 
+column3[1 - 3] from arrays; +---- +[3, ] NULL e +[5, 6] NULL u +[7, 8] NULL o +[9, 10] NULL i +NULL NULL e +[13, 14] NULL NULL +[, 18] NULL NULL + +# TODO: support index as column +# single index with columns #5 (index as column) +# query ? +# select make_array(1, 2, 3, 4, 5)[column2] from arrays_with_repeating_elements; +# ---- + +# TODO: support argument and index as columns +# single index with columns #6 (argument and index as columns) +# query I +# select column1[column2] from arrays_with_repeating_elements; +# ---- + +## array[i:j] + +# multiple index with scalars #1 (positive index) +query ??? +select make_array(1, 2, 3)[1:2], make_array(1.0, 2.0, 3.0)[2:3], make_array('h', 'e', 'l', 'l', 'o')[2:4]; +---- +[1, 2] [2.0, 3.0] [e, l, l] + +# multiple index with scalars #2 (zero index) +query ??? +select make_array(1, 2, 3)[0:0], make_array(1.0, 2.0, 3.0)[0:2], make_array('h', 'e', 'l', 'l', 'o')[0:6]; +---- +[] [1.0, 2.0] [h, e, l, l, o] + +# TODO: support multiple negative index +# multiple index with scalars #3 (negative index) +# query II +# select make_array(1, 2, 3)[-3:-1], make_array(1.0, 2.0, 3.0)[-3:-1], make_array('h', 'e', 'l', 'l', 'o')[-2:0]; +# ---- + +# TODO: support complex index +# multiple index with scalars #4 (complex index) +# query III +# select make_array(1, 2, 3)[2 + 1 - 1:10], make_array(1.0, 2.0, 3.0)[2 | 2:10], make_array('h', 'e', 'l', 'l', 'o')[6 ^ 6:10]; +# ---- + +# multiple index with columns #1 (positive index) +query ??? +select column1[2:4], column2[1:4], column3[3:4] from arrays; +---- +[[3, ]] [1.1, 2.2, 3.3] [r, e] +[[5, 6]] [, 5.5, 6.6] [, u] +[[7, 8]] [7.7, 8.8, 9.9] [l, o] +[[9, 10]] [10.1, , 12.2] [t] +[] [13.3, 14.4, 15.5] [e, t] +[[13, 14]] [] [] +[[, 18]] [16.6, 17.7, 18.8] [] + +# multiple index with columns #2 (zero index) +query ??? 
+select column1[0:5], column2[0:3], column3[0:9] from arrays; +---- +[[, 2], [3, ]] [1.1, 2.2, 3.3] [L, o, r, e, m] +[[3, 4], [5, 6]] [, 5.5, 6.6] [i, p, , u, m] +[[5, 6], [7, 8]] [7.7, 8.8, 9.9] [d, , l, o, r] +[[7, ], [9, 10]] [10.1, , 12.2] [s, i, t] +[] [13.3, 14.4, 15.5] [a, m, e, t] +[[11, 12], [13, 14]] [] [,] +[[15, 16], [, 18]] [16.6, 17.7, 18.8] [] + +# TODO: support negative index +# multiple index with columns #3 (negative index) +# query ?RT +# select column1[-2:-4], column2[-3:-5], column3[-1:-4] from arrays; +# ---- +# [, 2] 1.1 m + +# TODO: support complex index +# multiple index with columns #4 (complex index) +# query ?RT +# select column1[9 - 7:2 + 2], column2[1 * 0:2 * 3], column3[1 + 1 - 0:5 % 3] from arrays; +# ---- + +# TODO: support first index as column +# multiple index with columns #5 (first index as column) +# query ? +# select make_array(1, 2, 3, 4, 5)[column2:4] from arrays_with_repeating_elements +# ---- + +# TODO: support last index as column +# multiple index with columns #6 (last index as column) +# query ?RT +# select make_array(1, 2, 3, 4, 5)[2:column3] from arrays_with_repeating_elements; +# ---- + +# TODO: support argument and indices as column +# multiple index with columns #7 (argument and indices as column) +# query ?RT +# select column1[column2:column3] from arrays_with_repeating_elements; +# ---- + + ### Array function tests @@ -363,23 +559,6 @@ select make_array(a, b, c, d) from values; [7.0, 13.0, 14.0, ] [8.0, 15.0, 16.0, 8.8] -# make_array null handling -query ?B?BB -select - make_array(a), make_array(a)[1] IS NULL, - make_array(e, f), make_array(e, f)[1] IS NULL, make_array(e, f)[2] IS NULL -from values; ----- -[1] false [Lorem, A] false false -[2] false [ipsum, ] false false -[3] false [dolor, BB] false false -[4] false [sit, ] false true -[] true [amet, CCC] false false -[5] false [,, DD] false false -[6] false [consectetur, E] false false -[7] false [adipiscing, F] false false -[8] false [, ] true false - # 
make_array with column of list query ?? select column1, column5 from arrays_values_without_nulls; @@ -400,6 +579,257 @@ from arrays_values_without_nulls; [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [6, 7]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [50, 51, 52]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [8, 9]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [50, 51, 52]] +## array_element (aliases: array_extract, list_extract, list_element) + +# array_element scalar function #1 (with positive index) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# array_element scalar function #2 (with positive index; out of bounds) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); +---- +NULL NULL + +# array_element scalar function #3 (with zero) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); +---- +NULL NULL + +# array_element scalar function #4 (with NULL) +query error +select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +---- +NULL NULL + +# array_element scalar function #5 (with negative index) +query IT +select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); +---- +4 l + +# array_element scalar function #6 (with negative index; out of bounds) +query IT +select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); +---- +NULL NULL + +# array_element scalar function #7 (nested array) +query ? 
+select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); +---- +[1, 2, 3, 4, 5] + +# array_extract scalar function #8 (function alias `array_element`) +query IT +select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# list_element scalar function #9 (function alias `array_element`) +query IT +select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# list_extract scalar function #10 (function alias `array_element`) +query IT +select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# array_element with columns +query I +select array_element(column1, column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + +# array_element with columns and scalars +query II +select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + +## array_slice (aliases: list_slice) + +# array_slice scalar function #1 (with positive indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); +---- +[2, 3, 4] [h, e] + +# array_slice scalar function #2 (with positive indexes; full array) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + +# array_slice scalar function #3 (with positive indexes; first index = second index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); +---- +[4] [l] + +# array_slice scalar function #4 (with positive indexes; first index > second_index) +query ?? 
+select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); +---- +[] [] + +# array_slice scalar function #5 (with positive indexes; out of bounds) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #6 (with positive indexes; nested array) +query ? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); +---- +[[1, 2, 3, 4, 5]] + +# array_slice scalar function #7 (with zero and positive number) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + +# array_slice scalar function #8 (with NULL and positive number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +---- +[1, 2, 3, 4] [h, e, l] + +# array_slice scalar function #9 (with positive number and NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #10 (with zero-zero) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); +---- +[] [] + +# array_slice scalar function #11 (with NULL-NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +---- +[] [] + +# array_slice scalar function #12 (with zero and negative number) +query ?? 
+select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); +---- +[1] [h, e] + +# array_slice scalar function #13 (with negative number and NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #14 (with NULL and negative number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +---- +[1] [h, e] + +# array_slice scalar function #15 (with negative indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); +---- +[2, 3, 4] [l, l] + +# array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + +# array_slice scalar function #17 (with negative indexes; first index = second index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); +---- +[] [] + +# array_slice scalar function #18 (with negative indexes; first index > second_index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); +---- +[] [] + +# array_slice scalar function #19 (with negative indexes; out of bounds) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); +---- +[] [] + +# array_slice scalar function #20 (with negative indexes; nested array) +query ? 
+select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1); +---- +[[1, 2, 3, 4, 5]] + +# array_slice scalar function #21 (with first positive index and last negative index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); +---- +[2] [e, l] + +# array_slice scalar function #22 (with first negative index and last positive index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); +---- +[4, 5] [l, l] + +# list_slice scalar function #23 (function alias `array_slice`) +query ?? +select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); +---- +[2, 3, 4] [h, e] + +# array_slice with columns +query ? +select array_slice(column1, column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + +# TODO: support NULLS in output instead of `[]` +# array_slice with columns and scalars +query ??? 
+select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(column1, 3, column3), array_slice(column1, column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) # array_append scalar function #1 @@ -1430,31 +1860,7 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] -## trim_array - -# trim_array scalar function #1 -query ??? -select trim_array(make_array(1, 2, 3, 4, 5), 2), trim_array(['h', 'e', 'l', 'l', 'o'], 3), trim_array([1.0, 2.0, 3.0], 2); ----- -[1, 2, 3] [h, e] [1.0] - -# trim_array scalar function #2 -query ?? -select trim_array([[1, 2], [3, 4], [5, 6]], 2), trim_array(array_fill(4, [3, 4, 2]), 2); ----- -[[1, 2]] [[[4, 4], [4, 4], [4, 4], [4, 4]]] - -# trim_array scalar function #3 -query ? -select array_concat(trim_array(make_array(1, 2, 3), 3), make_array(4, 5), make_array()); ----- -[4, 5] - -# trim_array scalar function #4 -query ?? 
-select trim_array(make_array(), 0), trim_array(make_array(), 1) ----- -[] [] +## trim_array (deprecated) ## array_length (aliases: `list_length`) @@ -1864,6 +2270,9 @@ drop table nested_arrays; statement ok drop table arrays; +statement ok +drop table slices; + statement ok drop table arrays_values; diff --git a/datafusion/core/tests/sqllogictests/test_files/struct.slt b/datafusion/core/tests/sqllogictests/test_files/struct.slt new file mode 100644 index 000000000000..2629b6b038a3 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/struct.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Struct Expressions Tests +############# + +statement ok +CREATE TABLE values( + a INT, + b FLOAT, + c VARCHAR +) AS VALUES + (1, 1.1, 'a'), + (2, 2.2, 'b'), + (3, 3.3, 'c') +; + +# struct[i] +query IRT +select struct(1, 3.14, 'h')['c0'], struct(3, 2.55, 'b')['c1'], struct(2, 6.43, 'a')['c2']; +---- +1 2.55 a + +# struct[i] with columns +query R +select struct(a, b, c)['c1'] from values; +---- +1.1 +2.2 +3.3 + +# struct scalar function #1 +query ? +select struct(1, 3.14, 'e'); +---- +{c0: 1, c1: 3.14, c2: e} + +# struct scalar function with columns #1 +query ? 
+select struct(a, b, c) from values; +---- +{c0: 1, c1: 1.1, c2: a} +{c0: 2, c1: 2.2, c2: b} +{c0: 3, c1: 3.3, c2: c} + +statement ok +drop table values; \ No newline at end of file diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 3bcf0ec4c54c..2886617c0d45 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -129,6 +129,8 @@ pub enum BuiltinScalarFunction { ArrayHasAny, /// array_dims ArrayDims, + /// array_element + ArrayElement, /// array_fill ArrayFill, /// array_length @@ -153,14 +155,18 @@ pub enum BuiltinScalarFunction { ArrayReplaceN, /// array_replace_all ArrayReplaceAll, + /// array_slice + ArraySlice, /// array_to_string ArrayToString, /// cardinality Cardinality, /// construct an array from columns MakeArray, - /// trim_array - TrimArray, + + // struct functions + /// struct + Struct, // string functions /// ascii @@ -259,8 +265,6 @@ pub enum BuiltinScalarFunction { Uuid, /// regexp_match RegexpMatch, - /// struct - Struct, /// arrow_typeof ArrowTypeof, } @@ -349,6 +353,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, + BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayFill => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, @@ -361,10 +366,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, + BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, 
BuiltinScalarFunction::MakeArray => Volatility::Immutable, - BuiltinScalarFunction::TrimArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, BuiltinScalarFunction::Btrim => Volatility::Immutable, @@ -525,6 +530,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } + BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { + List(field) => Ok(field.data_type().clone()), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( "item", input_expr_types[1].clone(), @@ -543,6 +554,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { @@ -559,16 +571,6 @@ impl BuiltinScalarFunction { Ok(List(Arc::new(Field::new("item", expr_type, true)))) } }, - BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {self} function can only accept list as the first argument" - ))), - }, BuiltinScalarFunction::Ascii => Ok(Int32), BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") @@ -819,6 +821,7 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::ArrayHasAny | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), 
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) @@ -837,6 +840,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => { Signature::any(3, self.volatility()) } + BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } @@ -844,7 +848,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::MakeArray => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::TrimArray => Signature::any(2, self.volatility()), BuiltinScalarFunction::Struct => Signature::variadic( struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), self.volatility(), @@ -1283,7 +1286,6 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Decode => &["decode"], // other functions - BuiltinScalarFunction::Struct => &["struct"], BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], // array functions @@ -1297,6 +1299,12 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { &["array_concat", "array_cat", "list_concat", "list_cat"] } BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayElement => &[ + "array_element", + "array_extract", + "list_element", + "list_extract", + ], BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], BuiltinScalarFunction::ArrayHas => { @@ -1326,6 +1334,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::ArrayReplaceAll => { &["array_replace_all", "list_replace_all"] } + BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], 
BuiltinScalarFunction::ArrayToString => &[ "array_to_string", "list_to_string", @@ -1334,7 +1343,9 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { ], BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], - BuiltinScalarFunction::TrimArray => &["trim_array"], + + // struct functions + BuiltinScalarFunction::Struct => &["struct"], } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ce916e173104..4adeb0653fc1 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -116,6 +116,7 @@ pub enum Expr { /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), /// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key + /// GetIndexedField(GetIndexedField), /// Whether an expression is between a given range. Between(Between), @@ -358,19 +359,57 @@ impl ScalarUDF { } } -/// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key +/// Key of a `GetIndexedField` expression. +/// This structure is needed to separate the responsibilities of the key for `DataType::List` and `DataType::Struct`. +/// When indexing into a `DataType::List`, the `list_key` argument is used and `struct_key` is `None`. +/// When indexing into a `DataType::Struct`, the `struct_key` argument is used and `list_key` is `None`. +/// `list_key` can be any expression, unlike `struct_key` which can only be `ScalarValue::Utf8`. 
+#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct GetIndexedFieldKey { + /// The key expression for `DataType::List` + pub list_key: Option, + /// The key expression for `DataType::Struct` + pub struct_key: Option, +} + +impl GetIndexedFieldKey { + /// Create a new GetIndexedFieldKey expression + pub fn new(list_key: Option, struct_key: Option) -> Self { + // value must be either `list_key` or `struct_key` + assert_ne!(list_key.is_some(), struct_key.is_some()); + assert_ne!(list_key.is_none(), struct_key.is_none()); + + Self { + list_key, + struct_key, + } + } +} + +/// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by `key`. +/// If `extra_key` is not `None`, returns the slice of a [`arrow::array::ListArray`] in the range from `key` to `extra_key`. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct GetIndexedField { - /// the expression to take the field from + /// The expression to take the field from pub expr: Box, /// The name of the field to take - pub key: ScalarValue, + pub key: Box, + /// The right border of the field to take + pub extra_key: Option>, } impl GetIndexedField { /// Create a new GetIndexedField expression - pub fn new(expr: Box, key: ScalarValue) -> Self { - Self { expr, key } + pub fn new( + expr: Box, + key: Box, + extra_key: Option>, + ) -> Self { + Self { + expr, + key, + extra_key, + } } } @@ -1139,8 +1178,21 @@ impl fmt::Display for Expr { } Expr::Wildcard => write!(f, "*"), Expr::QualifiedWildcard { qualifier } => write!(f, "{qualifier}.*"), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - write!(f, "({expr})[{key}]") + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => { + let key = if let Some(list_key) = &key.list_key { + format!("{list_key}") + } else { + format!("{0}", key.struct_key.clone().unwrap()) + }; + if let Some(extra_key) = extra_key { + write!(f, "({expr})[{key}:{extra_key}]") + } else { + write!(f, "({expr})[{key}]") + } } 
Expr::GroupingSet(grouping_sets) => match grouping_sets { GroupingSet::Rollup(exprs) => { @@ -1330,9 +1382,25 @@ fn create_name(e: &Expr) -> Result { Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).name().clone()) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => { let expr = create_name(expr)?; - Ok(format!("{expr}[{key}]")) + let key = if let Some(list_key) = &key.list_key { + create_name(list_key)? + } else if let Some(ScalarValue::Utf8(Some(struct_key))) = &key.struct_key { + struct_key.to_string() + } else { + String::new() + }; + if let Some(extra_key) = extra_key { + let extra_key = create_name(extra_key)?; + Ok(format!("{expr}[{key}:{extra_key}]")) + } else { + Ok(format!("{expr}[{key}]")) + } } Expr::ScalarFunction(func) => { create_function_name(&func.fun.to_string(), false, &func.args) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4a59e92999d5..6d0b6c1d6535 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -570,6 +570,12 @@ scalar_expr!( array, "returns an array of the array's dimensions." ); +scalar_expr!( + ArrayElement, + array_element, + array element, + "extracts the element with the index n from the array." +); scalar_expr!( ArrayFill, array_fill, @@ -642,6 +648,12 @@ scalar_expr!( array from to, "replaces all occurrences of the specified element with another specified element." ); +scalar_expr!( + ArraySlice, + array_slice, + array offset length, + "returns a slice of the array." +); scalar_expr!( ArrayToString, array_to_string, @@ -659,12 +671,6 @@ nary_scalar_expr!( array, "returns an Arrow array using the specified input expressions." ); -scalar_expr!( - TrimArray, - trim_array, - array n, - "removes the last n elements from the array." 
-); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); @@ -1071,7 +1077,6 @@ mod test { test_scalar_expr!(ArrayToString, array_to_string, array, delimiter); test_unary_scalar_expr!(Cardinality, cardinality); test_nary_scalar_expr!(MakeArray, array, input); - test_scalar_expr!(TrimArray, trim_array, array, n); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e988ed8c3f3e..6ca32c4377fc 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -155,10 +155,24 @@ impl ExprSchemable for Expr { // grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let data_type = expr.get_type(schema)?; - - get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => { + let expr_dt = expr.get_type(schema)?; + let key = if let Some(list_key) = &key.list_key { + (Some(list_key.get_type(schema)?), None) + } else { + (None, key.struct_key.clone()) + }; + let extra_key_dt = if let Some(extra_key) = extra_key { + Some(extra_key.get_type(schema)?) 
+ } else { + None + }; + get_indexed_field(&expr_dt, &key, &extra_key_dt) + .map(|x| x.data_type().clone()) } } } @@ -266,9 +280,23 @@ impl ExprSchemable for Expr { "QualifiedWildcard expressions are not valid in a logical query plan" .to_owned(), )), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let data_type = expr.get_type(input_schema)?; - get_indexed_field(&data_type, key).map(|x| x.is_nullable()) + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => { + let expr_dt = expr.get_type(input_schema)?; + let key = if let Some(list_key) = &key.list_key { + (Some(list_key.get_type(input_schema)?), None) + } else { + (None, key.struct_key.clone()) + }; + let extra_key_dt = if let Some(extra_key) = extra_key { + Some(extra_key.get_type(input_schema)?) + } else { + None + }; + get_indexed_field(&expr_dt, &key, &extra_key_dt).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { // grouping sets do not really have the concept of nullable and do not appear diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 2af242229717..b490a57cb582 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -20,17 +20,28 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; -/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`] +/// Returns the field access indexed by `key` and/or `extra_key` from a [`DataType::List`] or [`DataType::Struct`] /// # Error /// Errors if -/// * the `data_type` is not a Struct or, +/// * the `data_type` is not a Struct or a List, +/// * the `data_type` of extra key does not match with `data_type` of key /// * there is no field key is not of the required index type -pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { +pub fn get_indexed_field( + data_type: &DataType, + key: &(Option, Option), + extra_key: &Option, +) -> Result 
{ match (data_type, key) { - (DataType::List(lt), ScalarValue::Int64(Some(i))) => { - Ok(Field::new(i.to_string(), lt.data_type().clone(), true)) + (DataType::List(lt), (Some(DataType::Int64), None)) => { + match extra_key { + Some(DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), + None => Ok(Field::new("list", lt.data_type().clone(), true)), + _ => Err(DataFusionError::Plan( + "Only ints are valid as an indexed field in a list".to_string(), + )), + } } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + (DataType::Struct(fields), (None, Some(ScalarValue::Utf8(Some(s))))) => { if s.is_empty() { plan_err!( "Struct based indexed access requires a non empty string" diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 9eba7124b7c5..8a4eae74fe9b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -60,7 +60,8 @@ pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::ColumnarValue; pub use expr::{ - Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, GroupingSet, Like, TryCast, + Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, GetIndexedFieldKey, + GroupingSet, Like, TryCast, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index c56228d40b0d..d057260fa582 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -342,12 +342,15 @@ impl TreeNode for Expr { Expr::QualifiedWildcard { qualifier } => { Expr::QualifiedWildcard { qualifier } } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - Expr::GetIndexedField(GetIndexedField::new( - transform_boxed(expr, &mut transform)?, - key, - )) - } + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => Expr::GetIndexedField(GetIndexedField::new( + transform_boxed(expr, &mut transform)?, + key, + extra_key, + 
)), Expr::Placeholder(Placeholder { id, data_type }) => { Expr::Placeholder(Placeholder { id, data_type }) } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 4e32a00471b1..74eca0b8cd74 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -361,6 +361,177 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { } } +fn return_empty(return_null: bool, data_type: DataType) -> Arc { + if return_null { + new_null_array(&data_type, 1) + } else { + new_empty_array(&data_type) + } +} + +macro_rules! list_slice { + ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + if $I == 0 && $J == 0 || $ARRAY.is_empty() { + return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); + } + + let i = if $I < 0 { + if $I.abs() as usize > array.len() { + return return_empty(true, $ARRAY.data_type().clone()); + } + + (array.len() as i64 + $I + 1) as usize + } else { + if $I == 0 { + 1 + } else { + $I as usize + } + }; + let j = if $J < 0 { + if $J.abs() as usize > array.len() { + return return_empty(true, $ARRAY.data_type().clone()); + } + + if $RETURN_ELEMENT { + (array.len() as i64 + $J + 1) as usize + } else { + (array.len() as i64 + $J) as usize + } + } else { + if $J == 0 { + 1 + } else { + if $J as usize > array.len() { + array.len() + } else { + $J as usize + } + } + }; + + if i > j || i as usize > $ARRAY.len() { + return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) + } else { + Arc::new(array.slice((i - 1), (j + 1 - i))) + } + }}; +} + +macro_rules! 
slice { + ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ + let sliced_array: Vec> = $ARRAY + .iter() + .zip($KEY.iter()) + .zip($EXTRA_KEY.iter()) + .map(|((arr, i), j)| match (arr, i, j) { + (Some(arr), Some(i), Some(j)) => { + list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) + } + (Some(arr), None, Some(j)) => { + list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) + } + (Some(arr), Some(i), None) => { + list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) + } + (Some(arr), None, None) if !$RETURN_ELEMENT => arr, + _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), + }) + .collect(); + + // concat requires input of at least one array + if sliced_array.is_empty() { + Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + } else { + let vec = sliced_array + .iter() + .map(|a| a.as_ref()) + .collect::>(); + let mut i: i32 = 0; + let mut offsets = vec![i]; + offsets.extend( + vec.iter() + .map(|a| { + i += a.len() as i32; + i + }) + .collect::>(), + ); + let values = compute::concat(vec.as_slice()).unwrap(); + + if $RETURN_ELEMENT { + Ok(values) + } else { + let field = + Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + values, + None, + )?)) + } + } + }}; +} + +fn define_array_slice( + list_array: &ListArray, + key: &Int64Array, + extra_key: &Int64Array, + return_element: bool, +) -> Result { + match list_array.value_type() { + DataType::List(_) => { + slice!(list_array, key, extra_key, return_element, ListArray) + } + DataType::Utf8 => slice!(list_array, key, extra_key, return_element, StringArray), + DataType::LargeUtf8 => { + slice!(list_array, key, extra_key, return_element, LargeStringArray) + } + DataType::Boolean => { + slice!(list_array, key, extra_key, return_element, BooleanArray) + } + DataType::Float32 => { + slice!(list_array, key, extra_key, return_element, Float32Array) + } + 
DataType::Float64 => { + slice!(list_array, key, extra_key, return_element, Float64Array) + } + DataType::Int8 => slice!(list_array, key, extra_key, return_element, Int8Array), + DataType::Int16 => slice!(list_array, key, extra_key, return_element, Int16Array), + DataType::Int32 => slice!(list_array, key, extra_key, return_element, Int32Array), + DataType::Int64 => slice!(list_array, key, extra_key, return_element, Int64Array), + DataType::UInt8 => slice!(list_array, key, extra_key, return_element, UInt8Array), + DataType::UInt16 => { + slice!(list_array, key, extra_key, return_element, UInt16Array) + } + DataType::UInt32 => { + slice!(list_array, key, extra_key, return_element, UInt32Array) + } + DataType::UInt64 => { + slice!(list_array, key, extra_key, return_element, UInt64Array) + } + data_type => Err(DataFusionError::NotImplemented(format!( + "array is not implemented for types '{data_type:?}'" + ))), + } +} + +pub fn array_element(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let key = as_int64_array(&args[1])?; + define_array_slice(list_array, key, key, true) +} + +pub fn array_slice(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let key = as_int64_array(&args[1])?; + let extra_key = as_int64_array(&args[2])?; + define_array_slice(list_array, key, extra_key, false) +} + macro_rules! 
append { ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ let mut offsets: Vec = vec![0]; @@ -1481,25 +1652,6 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { Ok(Arc::new(StringArray::from(res))) } -/// Trim_array SQL function -pub fn trim_array(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let n = as_int64_array(&args[1])?.value(0) as usize; - - let values = list_array.value(0); - if values.len() <= n { - return Ok(array(&[ColumnarValue::Scalar(ScalarValue::Null)])?.into_array(1)); - } - - let res = values.slice(0, values.len() - n); - let mut scalars = vec![]; - for i in 0..res.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&res, i)?)); - } - - Ok(array(scalars.as_slice())?.into_array(1)) -} - /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?.clone(); @@ -1950,6 +2102,364 @@ mod tests { ) } + #[test] + fn test_array_element() { + // array_element([1, 2, 3, 4], 1) = 1 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(1, 1)); + + // array_element([1, 2, 3, 4], 3) = 3 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(3, 1)); + + // array_element([1, 2, 3, 4], 0) = NULL + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(0, 1))]) + .expect("failed to initialize function array_element"); + let result = + 
as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from(vec![None])); + + // array_element([1, 2, 3, 4], NULL) = NULL + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from(vec![None]))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from(vec![None])); + + // array_element([1, 2, 3, 4], -1) = 4 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-1, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(4, 1)); + + // array_element([1, 2, 3, 4], -3) = 2 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-3, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(2, 1)); + + // array_element([1, 2, 3, 4], 10) = NULL + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(10, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from(vec![None])); + } + + #[test] + fn test_nested_array_element() { + // array_element([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = [5, 6, 7, 8] + let list_array = return_nested_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(2, 1))]) + .expect("failed to initialize function array_element"); + let result = + 
as_list_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!( + &[5, 6, 7, 8], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + + #[test] + fn test_array_slice() { + // array_slice([1, 2, 3, 4], 1, 3) = [1, 2, 3] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(1, 1)), + Arc::new(Int64Array::from_value(3, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[1, 2, 3], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 2, 2) = [2] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(2, 1)), + Arc::new(Int64Array::from_value(2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 0, 0) = [] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(0, 1)), + Arc::new(Int64Array::from_value(0, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([1, 2, 3, 4], 0, 6) = [1, 2, 3, 4] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(0, 1)), + Arc::new(Int64Array::from_value(6, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to 
initialize function array_slice"); + + assert_eq!( + &[1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], -2, -2) = [] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-2, 1)), + Arc::new(Int64Array::from_value(-2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([1, 2, 3, 4], -3, -1) = [2, 3] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-3, 1)), + Arc::new(Int64Array::from_value(-1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2, 3], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], -3, 2) = [2] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-3, 1)), + Arc::new(Int64Array::from_value(2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 2, 11) = [2, 3, 4] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(2, 1)), + Arc::new(Int64Array::from_value(11, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2, 3, 4], + result + 
.value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 3, 1) = [] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(3, 1)), + Arc::new(Int64Array::from_value(1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([1, 2, 3, 4], -7, -2) = NULL + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-7, 1)), + Arc::new(Int64Array::from_value(-2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_null(0)); + } + + #[test] + fn test_nested_array_slice() { + // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], 1, 1) = [[1, 2, 3, 4]] + let list_array = return_nested_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(1, 1)), + Arc::new(Int64Array::from_value(1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, -1) = [] + let list_array = return_nested_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-1, 1)), + Arc::new(Int64Array::from_value(-1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function 
array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, 2) = [[5, 6, 7, 8]] + let list_array = return_nested_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-1, 1)), + Arc::new(Int64Array::from_value(2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[5, 6, 7, 8], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + #[test] fn test_array_append() { // array_append([1, 2, 3], 4) = [1, 2, 3, 4] @@ -2548,69 +3058,6 @@ mod tests { assert_eq!("1-*-3-*-*-6-7-*", result.value(0)); } - #[test] - fn test_trim_array() { - // trim_array([1, 2, 3, 4], 1) = [1, 2, 3] - let list_array = return_array().into_array(1); - let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(1)]))]) - .expect("failed to initialize function trim_array"); - let result = - as_list_array(&arr).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // trim_array([1, 2, 3, 4], 3) = [1] - let list_array = return_array().into_array(1); - let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(3)]))]) - .expect("failed to initialize function trim_array"); - let result = - as_list_array(&arr).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_trim_array() { - // trim_array([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array().into_array(1); - let arr = 
trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(1)]))]) - .expect("failed to initialize function trim_array"); - let binding = as_list_array(&arr) - .expect("failed to initialize function trim_array") - .value(0); - let result = - as_list_array(&binding).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - #[test] fn test_cardinality() { // cardinality([1, 2, 3, 4]) = 4 diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 090cfe5a6e64..4386913d368e 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -19,43 +19,120 @@ use crate::PhysicalExpr; use arrow::array::Array; -use arrow::compute::concat; +use crate::array_expressions::{array_element, array_slice}; use crate::physical_expr::down_cast_any_ref; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::cast::{as_list_array, as_struct_array}; -use datafusion_common::DataFusionError; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ field_util::get_indexed_field as get_data_type_field, ColumnarValue, }; -use std::convert::TryInto; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -/// expression to get a field of a struct array. +/// Key of `GetIndexedFieldExpr`. +/// This structure is needed to separate the responsibilities of the key for `DataType::List` and `DataType::Struct`. +/// If we use index with `DataType::List`, then we use the `list_key` argument with `struct_key` equal to `None`. 
+/// If we use index with `DataType::Struct`, then we use the `struct_key` argument with `list_key` equal to `None`. +/// `list_key` can be any expression, unlike `struct_key` which can only be `ScalarValue::Utf8`. +#[derive(Clone, Hash, Debug)] +pub struct GetIndexedFieldExprKey { + /// The key expression for `DataType::List` + list_key: Option>, + /// The key expression for `DataType::Struct` + struct_key: Option, +} + +impl GetIndexedFieldExprKey { + /// Create new get field expression key + pub fn new( + list_key: Option>, + struct_key: Option, + ) -> Self { + Self { + list_key, + struct_key, + } + } + + /// Get the key expression for `DataType::List` + pub fn list_key(&self) -> &Option> { + &self.list_key + } + + /// Get the key expression for `DataType::Struct` + pub fn struct_key(&self) -> &Option { + &self.struct_key + } +} + +impl std::fmt::Display for GetIndexedFieldExprKey { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if let Some(list_key) = &self.list_key { + write!(f, "{}", list_key) + } else { + write!(f, "{}", self.struct_key.clone().unwrap()) + } + } +} + +impl PartialEq for GetIndexedFieldExprKey { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + if let Some(list_key) = &self.list_key { + list_key.eq(&x.list_key.clone().unwrap()) + } else { + self.struct_key + .clone() + .unwrap() + .eq(&x.struct_key.clone().unwrap()) + } + }) + .unwrap_or(false) + } +} + +/// Expression to get a field of a struct array. 
#[derive(Debug, Hash)] pub struct GetIndexedFieldExpr { + /// The expression to find arg: Arc, - key: ScalarValue, + /// The key statement + key: GetIndexedFieldExprKey, + /// The extra key (it can be used only with `DataType::List`) + extra_key: Option>, } impl GetIndexedFieldExpr { /// Create new get field expression - pub fn new(arg: Arc, key: ScalarValue) -> Self { - Self { arg, key } + pub fn new( + arg: Arc, + key: GetIndexedFieldExprKey, + extra_key: Option>, + ) -> Self { + Self { + arg, + key, + extra_key, + } } /// Get the input key - pub fn key(&self) -> &ScalarValue { + pub fn key(&self) -> &GetIndexedFieldExprKey { &self.key } + /// Get the input extra key + pub fn extra_key(&self) -> &Option> { + &self.extra_key + } + /// Get the input expression pub fn arg(&self) -> &Arc { &self.arg @@ -64,7 +141,11 @@ impl GetIndexedFieldExpr { impl std::fmt::Display for GetIndexedFieldExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "({}).[{}]", self.arg, self.key) + if let Some(extra_key) = &self.extra_key { + write!(f, "({}).[{}:{}]", self.arg, self.key, extra_key) + } else { + write!(f, "({}).[{}]", self.arg, self.key) + } } } @@ -74,70 +155,88 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.data_type().clone()) + let arg_dt = self.arg.data_type(input_schema)?; + let key = if let Some(list_key) = &self.key.list_key { + (Some(list_key.data_type(input_schema)?), None) + } else { + (None, self.key.struct_key.clone()) + }; + let extra_key_dt = if let Some(extra_key) = &self.extra_key { + Some(extra_key.data_type(input_schema)?) 
+ } else { + None + }; + get_data_type_field(&arg_dt, &key, &extra_key_dt).map(|f| f.data_type().clone()) } fn nullable(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.is_nullable()) + let arg_dt = self.arg.data_type(input_schema)?; + let key = if let Some(list_key) = &self.key.list_key { + (Some(list_key.data_type(input_schema)?), None) + } else { + (None, self.key.struct_key.clone()) + }; + let extra_key_dt = if let Some(extra_key) = &self.extra_key { + Some(extra_key.data_type(input_schema)?) + } else { + None + }; + get_data_type_field(&arg_dt, &key, &extra_key_dt).map(|f| f.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(1); - match (array.data_type(), &self.key) { - (DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => { - let scalar_null: ScalarValue = array.data_type().try_into()?; - Ok(ColumnarValue::Scalar(scalar_null)) + let array = self.arg.evaluate(batch)?.into_array(batch.num_rows()); + if let Some(extra_key) = &self.extra_key { + let list_key = self + .key + .list_key + .clone() + .unwrap() + .evaluate(batch)? + .into_array(batch.num_rows()); + let extra_key = extra_key.evaluate(batch)?.into_array(batch.num_rows()); + match (array.data_type(), list_key.data_type(), extra_key.data_type()) { + (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ + array, list_key, extra_key + ])?)), + (DataType::List(_), key, extra_key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes. \ + Tried with {key:?} and {extra_key:?} indices"))), + (dt, key, extra_key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. 
Tried {dt:?} with {key:?} and {extra_key:?} indices"))), } - (DataType::List(lst), ScalarValue::Int64(Some(i))) => { - let as_list_array = as_list_array(&array)?; - - if *i < 1 || as_list_array.is_empty() { - let scalar_null: ScalarValue = lst.data_type().try_into()?; - return Ok(ColumnarValue::Scalar(scalar_null)) + } else if let Some(list_key) = &self.key.list_key { + let list_key = list_key.evaluate(batch)?.into_array(batch.num_rows()); + match (array.data_type(), list_key.data_type()) { + (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ + array, list_key + ])?)), + (DataType::List(_), key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes. \ + Tried with {key:?} index"))), + (dt, key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {key:?} index"))), } - - let sliced_array: Vec> = as_list_array - .iter() - .filter_map(|o| match o { - Some(list) => if *i as usize > list.len() { - None - } else { - Some(list.slice((*i -1) as usize, 1)) - }, - None => None - }) - .collect(); - - // concat requires input of at least one array - if sliced_array.is_empty() { - let scalar_null: ScalarValue = lst.data_type().try_into()?; - Ok(ColumnarValue::Scalar(scalar_null)) - } else { - let vec = sliced_array.iter().map(|a| a.as_ref()).collect::>(); - let iter = concat(vec.as_slice()).unwrap(); - - Ok(ColumnarValue::Array(iter)) - } - } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { - None => Err(DataFusionError::Execution( - format!("get indexed field {k} not found in struct"))), - Some(col) => Ok(ColumnarValue::Array(col.clone())) + } else { + let struct_key = self.key.struct_key.clone().unwrap(); + match (array.data_type(), struct_key) { + (DataType::Struct(_), 
ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(&k) { + None => Err(DataFusionError::Execution( + format!("get indexed field {k} not found in struct"))), + Some(col) => Ok(ColumnarValue::Array(col.clone())) + } + } + (DataType::Struct(_), key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on struct with utf8 indexes. \ + Tried with {key:?} index"))), + (dt, key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {key:?} index"))), } - } - (DataType::List(_), key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on lists with int64 indexes. \ - Tried with {key:?} index"))), - (DataType::Struct(_), key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on struct with utf8 indexes. \ - Tried with {key:?} index"))), - (dt, key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. 
Tried {dt:?} with {key:?} index"))), } } @@ -152,6 +251,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { Ok(Arc::new(GetIndexedFieldExpr::new( children[0].clone(), self.key.clone(), + self.extra_key.clone(), ))) } @@ -165,7 +265,15 @@ impl PartialEq for GetIndexedFieldExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.arg.eq(&x.arg) && self.key == x.key) + .map(|x| { + if let Some(extra_key) = &self.extra_key { + self.arg.eq(&x.arg) + && self.key.eq(&x.key) + && extra_key.eq(&x.extra_key) + } else { + self.arg.eq(&x.arg) && self.key.eq(&x.key) + } + }) .unwrap_or(false) } } @@ -173,301 +281,217 @@ impl PartialEq for GetIndexedFieldExpr { #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit}; - use arrow::array::{ArrayRef, Float64Array, GenericListArray, PrimitiveBuilder}; + use crate::expressions::col; + use arrow::array::new_empty_array; + use arrow::array::{ArrayRef, GenericListArray}; use arrow::array::{ - Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder, + BooleanArray, Int64Array, ListBuilder, StringBuilder, StructArray, }; - use arrow::datatypes::{Float64Type, Int64Type}; + use arrow::datatypes::Fields; use arrow::{array::StringArray, datatypes::Field}; - use datafusion_common::cast::{as_int64_array, as_string_array}; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_string_array}; use datafusion_common::Result; - fn build_utf8_lists(list_of_lists: Vec>>) -> GenericListArray { + fn build_list_arguments( + list_of_lists: Vec>>, + list_of_keys: Vec>, + list_of_extra_keys: Vec>, + ) -> (GenericListArray, Int64Array, Int64Array) { let builder = StringBuilder::with_capacity(list_of_lists.len(), 1024); - let mut lb = ListBuilder::new(builder); + let mut list_builder = ListBuilder::new(builder); for values in list_of_lists { - let builder = lb.values(); + let builder = list_builder.values(); for value in values { match value { None => 
builder.append_null(), Some(v) => builder.append_value(v), } } - lb.append(true); + list_builder.append(true); } - lb.finish() + let key_array = Int64Array::from(list_of_keys); + let extra_key_array = Int64Array::from(list_of_extra_keys); + (list_builder.finish(), key_array, extra_key_array) } - fn get_indexed_field_test( - list_of_lists: Vec>>, - index: i64, - expected: Vec>, - ) -> Result<()> { - let schema = list_schema("l"); - let list_col = build_utf8_lists(list_of_lists); - let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?; - let key = ScalarValue::Int64(Some(index)); - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = as_string_array(&result).expect("failed to downcast to StringArray"); - let expected = &StringArray::from(expected); - assert_eq!(expected, result); + #[test] + fn get_indexed_field_struct() -> Result<()> { + let schema = struct_schema(); + let boolean = BooleanArray::from(vec![false, false, true, true]); + let int = Int64Array::from(vec![42, 28, 19, 31]); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, true)), + Arc::new(boolean.clone()) as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int64, true)), + Arc::new(int) as ArrayRef, + ), + ]); + let expr = col("str", &schema).unwrap(); + // only one row should be processed + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; + let expr = Arc::new(GetIndexedFieldExpr::new( + expr, + GetIndexedFieldExprKey::new( + None, + Some(ScalarValue::Utf8(Some(String::from("a")))), + ), + None, + )); + let result = expr.evaluate(&batch)?.into_array(1); + let result = + as_boolean_array(&result).expect("failed to downcast to BooleanArray"); + assert_eq!(boolean, result.clone()); Ok(()) } - fn list_schema(col: &str) -> Schema { - Schema::new(vec![Field::new_list( - 
col, - Field::new("item", DataType::Utf8, true), + fn struct_schema() -> Schema { + Schema::new(vec![Field::new_struct( + "str", + Fields::from(vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int64, true), + ]), true, )]) } + fn list_schema(cols: &[&str]) -> Schema { + if cols.len() == 2 { + Schema::new(vec![ + Field::new_list(cols[0], Field::new("item", DataType::Utf8, true), true), + Field::new(cols[1], DataType::Int64, true), + ]) + } else { + Schema::new(vec![ + Field::new_list(cols[0], Field::new("item", DataType::Utf8, true), true), + Field::new(cols[1], DataType::Int64, true), + Field::new(cols[2], DataType::Int64, true), + ]) + } + } + #[test] - fn get_indexed_field_list() -> Result<()> { + fn get_indexed_field_list_without_extra_key() -> Result<()> { let list_of_lists = vec![ vec![Some("a"), Some("b"), None], vec![None, Some("c"), Some("d")], vec![Some("e"), None, Some("f")], ]; - let expected_list = vec![ - vec![Some("a"), None, Some("e")], - vec![Some("b"), Some("c"), None], - vec![None, Some("d"), Some("f")], - ]; - - for (i, expected) in expected_list.into_iter().enumerate() { - get_indexed_field_test(list_of_lists.clone(), (i + 1) as i64, expected)?; - } - Ok(()) - } + let list_of_keys = vec![Some(1), Some(2), None]; + let list_of_extra_keys = vec![None]; + let expected_list = vec![Some("a"), Some("c"), None]; - #[test] - fn get_indexed_field_empty_list() -> Result<()> { - let schema = list_schema("l"); - let builder = StringBuilder::new(); - let mut lb = ListBuilder::new(builder); + let schema = list_schema(&["l", "k"]); + let (list_col, key_col, _) = + build_list_arguments(list_of_lists, list_of_keys, list_of_extra_keys); let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let key = ScalarValue::Int64(Some(1)); - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - 
assert!(result.is_empty()); - Ok(()) - } - - fn get_indexed_field_test_failure( - schema: Schema, - expr: Arc, - key: ScalarValue, - expected: &str, - ) -> Result<()> { - let builder = StringBuilder::with_capacity(3, 1024); - let mut lb = ListBuilder::new(builder); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let r = expr.evaluate(&batch).map(|_| ()); - assert!(r.is_err()); - assert_eq!(format!("{}", r.unwrap_err()), expected); + let key = col("k", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_col), Arc::new(key_col)], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new( + expr, + GetIndexedFieldExprKey::new(Some(key), None), + None, + )); + let result = expr.evaluate(&batch)?.into_array(1); + let result = as_string_array(&result).expect("failed to downcast to ListArray"); + let expected = StringArray::from(expected_list); + assert_eq!(expected, result.clone()); Ok(()) } #[test] - fn get_indexed_field_invalid_scalar() -> Result<()> { - let schema = list_schema("l"); - let expr = lit("a"); - get_indexed_field_test_failure( - schema, expr, ScalarValue::Int64(Some(0)), - "Execution error: get indexed field is only possible on lists with int64 indexes or \ - struct with utf8 indexes. Tried Utf8 with Int64(0) index") - } - - #[test] - fn get_indexed_field_invalid_list_index() -> Result<()> { - let schema = list_schema("l"); - let expr = col("l", &schema).unwrap(); - get_indexed_field_test_failure( - schema, expr, ScalarValue::Int8(Some(0)), - "Execution error: get indexed field is only possible on lists with int64 indexes. 
\ - Tried with Int8(0) index") - } - - fn build_struct( - fields: Vec, - list_of_tuples: Vec<(Option, Vec>)>, - ) -> StructArray { - let foo_builder = Int64Array::builder(list_of_tuples.len()); - let str_builder = StringBuilder::with_capacity(list_of_tuples.len(), 1024); - let bar_builder = ListBuilder::new(str_builder); - let mut builder = StructBuilder::new( - fields, - vec![Box::new(foo_builder), Box::new(bar_builder)], - ); - for (int_value, list_value) in list_of_tuples { - let fb = builder.field_builder::(0).unwrap(); - match int_value { - None => fb.append_null(), - Some(v) => fb.append_value(v), - }; - builder.append(true); - let lb = builder - .field_builder::>(1) - .unwrap(); - for str_value in list_value { - match str_value { - None => lb.values().append_null(), - Some(v) => lb.values().append_value(v), - }; - } - lb.append(true); - } - builder.finish() - } - - fn get_indexed_field_mixed_test( - list_of_tuples: Vec<(Option, Vec>)>, - expected_strings: Vec>>, - expected_ints: Vec>, - ) -> Result<()> { - let struct_col = "s"; - let fields = vec![ - Field::new("foo", DataType::Int64, true), - Field::new_list("bar", Field::new("item", DataType::Utf8, true), true), + fn get_indexed_field_list_with_extra_key() -> Result<()> { + let list_of_lists = vec![ + vec![Some("a"), Some("b"), None], + vec![None, Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], + ]; + let list_of_keys = vec![Some(1), Some(2), None]; + let list_of_extra_keys = vec![Some(2), None, Some(3)]; + let expected_list = vec![ + vec![Some("a"), Some("b")], + vec![Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], ]; - let schema = Schema::new(vec![Field::new( - struct_col, - DataType::Struct(fields.clone().into()), - true, - )]); - let struct_col = build_struct(fields, list_of_tuples.clone()); - - let struct_col_expr = col("s", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_col)])?; - let int_field_key = 
ScalarValue::Utf8(Some("foo".to_string())); - let get_field_expr = Arc::new(GetIndexedFieldExpr::new( - struct_col_expr.clone(), - int_field_key, + let schema = list_schema(&["l", "k", "ek"]); + let (list_col, key_col, extra_key_col) = + build_list_arguments(list_of_lists, list_of_keys, list_of_extra_keys); + let expr = col("l", &schema).unwrap(); + let key = col("k", &schema).unwrap(); + let extra_key = col("ek", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(list_col), + Arc::new(key_col), + Arc::new(extra_key_col), + ], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new( + expr, + GetIndexedFieldExprKey::new(Some(key), None), + Some(extra_key), )); - let result = get_field_expr - .evaluate(&batch)? - .into_array(batch.num_rows()); - let result = as_int64_array(&result)?; - let expected = &Int64Array::from(expected_ints); - assert_eq!(expected, result); - - let list_field_key = ScalarValue::Utf8(Some("bar".to_string())); - let get_list_expr = - Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key)); - let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = as_list_array(&result)?; - let expected = - &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect()); - assert_eq!(expected, result); - - for (i, expected) in expected_strings.into_iter().enumerate() { - let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new( - get_list_expr.clone(), - ScalarValue::Int64(Some((i + 1) as i64)), - )); - let result = get_nested_str_expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()); - let result = as_string_array(&result)?; - let expected = &StringArray::from(expected); - assert_eq!(expected, result); - } + let result = expr.evaluate(&batch)?.into_array(1); + let result = as_list_array(&result).expect("failed to downcast to ListArray"); + let (expected, _, _) = + build_list_arguments(expected_list, vec![None], vec![None]); + assert_eq!(expected, result.clone()); Ok(()) } #[test] - fn get_indexed_field_struct() -> Result<()> { - let list_of_structs = vec![ - (Some(10), vec![Some("a"), Some("b"), None]), - (Some(15), vec![None, Some("c"), Some("d")]), - (None, vec![Some("e"), None, Some("f")]), - ]; - - let expected_list = vec![ - vec![Some("a"), None, Some("e")], - vec![Some("b"), Some("c"), None], - vec![None, Some("d"), Some("f")], - ]; - - let expected_ints = vec![Some(10), Some(15), None]; - - get_indexed_field_mixed_test( - list_of_structs.clone(), - expected_list, - expected_ints, + fn get_indexed_field_empty_list() -> Result<()> { + let schema = list_schema(&["l", "k"]); + let builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(builder); + let key_array = new_empty_array(&DataType::Int64); + let expr = col("l", &schema).unwrap(); + let key = col("k", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_builder.finish()), key_array], )?; + let expr = Arc::new(GetIndexedFieldExpr::new( + expr, + GetIndexedFieldExprKey::new(Some(key), None), + None, + )); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + assert!(result.is_null(0)); Ok(()) } #[test] - fn get_indexed_field_list_out_of_bounds() { - let fields = vec![ - Field::new("id", DataType::Int64, true), - Field::new_list("a", Field::new("item", DataType::Float64, true), true), - ]; - - let schema = Schema::new(fields); - let mut int_builder = PrimitiveBuilder::::new(); - int_builder.append_value(1); - - let mut lb = ListBuilder::new(PrimitiveBuilder::::new()); - 
lb.values().append_value(1.0); - lb.values().append_null(); - lb.values().append_value(3.0); - lb.append(true); + fn get_indexed_field_invalid_list_index() -> Result<()> { + let schema = list_schema(&["l", "e"]); + let expr = col("l", &schema).unwrap(); + let key_expr = col("e", &schema).unwrap(); + let builder = StringBuilder::with_capacity(3, 1024); + let mut list_builder = ListBuilder::new(builder); + list_builder.values().append_value("hello"); + list_builder.append(true); + let key_array = Int64Array::from(vec![Some(3)]); let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(int_builder.finish()), Arc::new(lb.finish())], - ) - .unwrap(); - - let col_a = col("a", &schema).unwrap(); - // out of bounds index - verify_index_evaluation(&batch, col_a.clone(), 0, float64_array(None)); - - verify_index_evaluation(&batch, col_a.clone(), 1, float64_array(Some(1.0))); - verify_index_evaluation(&batch, col_a.clone(), 2, float64_array(None)); - verify_index_evaluation(&batch, col_a.clone(), 3, float64_array(Some(3.0))); - - // out of bounds index - verify_index_evaluation(&batch, col_a.clone(), 100, float64_array(None)); - } - - fn verify_index_evaluation( - batch: &RecordBatch, - arg: Arc, - index: i64, - expected_result: ArrayRef, - ) { + Arc::new(schema), + vec![Arc::new(list_builder.finish()), Arc::new(key_array)], + )?; let expr = Arc::new(GetIndexedFieldExpr::new( - arg, - ScalarValue::Int64(Some(index)), + expr, + GetIndexedFieldExprKey::new(Some(key_expr), None), + None, )); - let result = expr.evaluate(batch).unwrap().into_array(batch.num_rows()); - assert!( - result == expected_result.clone(), - "result: {result:?} != expected result: {expected_result:?}" - ); - assert_eq!(result.data_type(), &DataType::Float64); - } - - fn float64_array(value: Option) -> ArrayRef { - match value { - Some(v) => Arc::new(Float64Array::from_value(v, 1)), - None => { - let mut b = PrimitiveBuilder::::new(); - b.append_null(); - Arc::new(b.finish()) - } - } + 
let result = expr.evaluate(&batch)?.into_array(1); + assert!(result.is_null(0)); + Ok(()) } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c56c63db7b4d..7f18f748aa91 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -82,7 +82,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, cast_column, cast_with_options, CastExpr}; pub use column::{col, Column, UnKnownColumn}; pub use datetime::{date_time_interval_expr, DateTimeIntervalExpr}; -pub use get_indexed_field::GetIndexedFieldExpr; +pub use get_indexed_field::{GetIndexedFieldExpr, GetIndexedFieldExprKey}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 080cbcd62f02..e877faa3c12e 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -431,6 +431,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) } + BuiltinScalarFunction::ArrayElement => { + Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) + } BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill), BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) @@ -465,6 +468,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { make_scalar_function(array_expressions::array_replace_all)(args) }), + BuiltinScalarFunction::ArraySlice => { + Arc::new(|args| make_scalar_function(array_expressions::array_slice)(args)) + } BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), @@ -474,12 +480,11 
@@ pub fn create_physical_fun( BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - BuiltinScalarFunction::TrimArray => { - Arc::new(|args| make_scalar_function(array_expressions::trim_array)(args)) - } - // string functions + // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), + + // string functions BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ascii::)(args) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f4e4e8b1343b..5de76ac741cb 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -19,7 +19,8 @@ use crate::var_provider::is_system_variables; use crate::{ execution_props::ExecutionProps, expressions::{ - self, binary, date_time_interval_expr, like, Column, GetIndexedFieldExpr, Literal, + self, binary, date_time_interval_expr, like, Column, GetIndexedFieldExpr, + GetIndexedFieldExprKey, Literal, }, functions, udf, var_provider::VarType, @@ -338,7 +339,34 @@ pub fn create_physical_expr( input_schema, execution_props, )?), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => { + let extra_key_expr = if let Some(extra_key) = extra_key { + Some(create_physical_expr( + extra_key, + input_dfschema, + input_schema, + execution_props, + )?) 
+ } else { + None + }; + let key_expr = if let Some(list_key) = &key.list_key { + GetIndexedFieldExprKey::new( + Some(create_physical_expr( + list_key, + input_dfschema, + input_schema, + execution_props, + )?), + None, + ) + } else { + GetIndexedFieldExprKey::new(None, Some(key.struct_key.clone().unwrap())) + }; Ok(Arc::new(GetIndexedFieldExpr::new( create_physical_expr( expr, @@ -346,7 +374,8 @@ pub fn create_physical_expr( input_schema, execution_props, )?, - key.clone(), + key_expr, + extra_key_expr, ))) } diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index d4cfa309ff44..e2dd98cdf1ba 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -78,3 +78,55 @@ pub fn struct_expr(values: &[ColumnarValue]) -> Result { .collect(); Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_struct_array; + use datafusion_common::ScalarValue; + + #[test] + fn test_struct() { + // struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3} + let args = [ + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ]; + let struc = struct_expr(&args) + .expect("failed to initialize function struct") + .into_array(1); + let result = + as_struct_array(&struc).expect("failed to initialize function struct"); + assert_eq!( + &Int64Array::from(vec![1]), + result + .column_by_name("c0") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &Int64Array::from(vec![2]), + result + .column_by_name("c1") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &Int64Array::from(vec![3]), + result + .column_by_name("c2") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + } +} diff --git 
a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 19dace31328d..3ef5850b546b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -404,7 +404,8 @@ message RollupNode { message GetIndexedField { LogicalExprNode expr = 1; - ScalarValue key = 2; + LogicalExprNode key = 2; + LogicalExprNode extra_key = 3; } message IsNull { @@ -566,7 +567,8 @@ enum ScalarFunction { ArrayReplace = 96; ArrayToString = 97; Cardinality = 98; - TrimArray = 99; + ArrayElement = 99; + ArraySlice = 100; Encode = 101; Decode = 102; Cot = 103; @@ -1489,5 +1491,5 @@ message ColumnStats { message PhysicalGetIndexedFieldExprNode { PhysicalExprNode arg = 1; - ScalarValue key = 2; + PhysicalExprNode key = 2; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index dbdca6f28251..df70a70f6c4f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7276,6 +7276,9 @@ impl serde::Serialize for GetIndexedField { if self.key.is_some() { len += 1; } + if self.extra_key.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -7283,6 +7286,9 @@ impl serde::Serialize for GetIndexedField { if let Some(v) = self.key.as_ref() { struct_ser.serialize_field("key", v)?; } + if let Some(v) = self.extra_key.as_ref() { + struct_ser.serialize_field("extraKey", v)?; + } struct_ser.end() } } @@ -7295,12 +7301,15 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { const FIELDS: &[&str] = &[ "expr", "key", + "extra_key", + "extraKey", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, Key, + ExtraKey, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7324,6 +7333,7 @@ impl<'de> serde::Deserialize<'de> for 
GetIndexedField { match value { "expr" => Ok(GeneratedField::Expr), "key" => Ok(GeneratedField::Key), + "extraKey" | "extra_key" => Ok(GeneratedField::ExtraKey), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7345,6 +7355,7 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { { let mut expr__ = None; let mut key__ = None; + let mut extra_key__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::Expr => { @@ -7359,11 +7370,18 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { } key__ = map.next_value()?; } + GeneratedField::ExtraKey => { + if extra_key__.is_some() { + return Err(serde::de::Error::duplicate_field("extraKey")); + } + extra_key__ = map.next_value()?; + } } } Ok(GetIndexedField { expr: expr__, key: key__, + extra_key: extra_key__, }) } } @@ -18273,7 +18291,8 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplace => "ArrayReplace", Self::ArrayToString => "ArrayToString", Self::Cardinality => "Cardinality", - Self::TrimArray => "TrimArray", + Self::ArrayElement => "ArrayElement", + Self::ArraySlice => "ArraySlice", Self::Encode => "Encode", Self::Decode => "Decode", Self::Cot => "Cot", @@ -18395,7 +18414,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplace", "ArrayToString", "Cardinality", - "TrimArray", + "ArrayElement", + "ArraySlice", "Encode", "Decode", "Cot", @@ -18548,7 +18568,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), "ArrayToString" => Ok(ScalarFunction::ArrayToString), "Cardinality" => Ok(ScalarFunction::Cardinality), - "TrimArray" => Ok(ScalarFunction::TrimArray), + "ArrayElement" => Ok(ScalarFunction::ArrayElement), + "ArraySlice" => Ok(ScalarFunction::ArraySlice), "Encode" => Ok(ScalarFunction::Encode), "Decode" => Ok(ScalarFunction::Decode), "Cot" => Ok(ScalarFunction::Cot), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 
605bc2033e1c..a5f4d27e5fad 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -614,8 +614,10 @@ pub struct RollupNode { pub struct GetIndexedField { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub key: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "3")] + pub extra_key: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2124,8 +2126,8 @@ pub struct ColumnStats { pub struct PhysicalGetIndexedFieldExprNode { #[prost(message, optional, boxed, tag = "1")] pub arg: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub key: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -2299,7 +2301,8 @@ pub enum ScalarFunction { ArrayReplace = 96, ArrayToString = 97, Cardinality = 98, - TrimArray = 99, + ArrayElement = 99, + ArraySlice = 100, Encode = 101, Decode = 102, Cot = 103, @@ -2418,7 +2421,8 @@ impl ScalarFunction { ScalarFunction::ArrayReplace => "ArrayReplace", ScalarFunction::ArrayToString => "ArrayToString", ScalarFunction::Cardinality => "Cardinality", - ScalarFunction::TrimArray => "TrimArray", + ScalarFunction::ArrayElement => "ArrayElement", + ScalarFunction::ArraySlice => "ArraySlice", ScalarFunction::Encode => "Encode", ScalarFunction::Decode => "Decode", ScalarFunction::Cot => "Cot", @@ -2534,7 +2538,8 @@ impl ScalarFunction { "ArrayReplace" => Some(Self::ArrayReplace), "ArrayToString" => Some(Self::ArrayToString), "Cardinality" => 
Some(Self::Cardinality), - "TrimArray" => Some(Self::TrimArray), + "ArrayElement" => Some(Self::ArrayElement), + "ArraySlice" => Some(Self::ArraySlice), "Encode" => Some(Self::Encode), "Decode" => Some(Self::Decode), "Cot" => Some(Self::Cot), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 86d2a683cfe8..add4e78548d8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -36,13 +36,14 @@ use datafusion_common::{ }; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill, - array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, - array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, - array_replace, array_replace_all, array_replace_n, array_to_string, ascii, asin, - asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, - character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, - current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, exp, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, + array_fill, array_has, array_has_all, array_has_any, array_length, array_ndims, + array_position, array_positions, array_prepend, array_remove, array_remove_all, + array_remove_n, array_replace, array_replace_all, array_replace_n, array_slice, + array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, + cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, + date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -50,11 +51,10 @@ use datafusion_expr::{ random, 
regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_seconds, translate, trim, trim_array, trunc, upper, - uuid, + to_timestamp_millis, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetIndexedField, GroupingSet, + Case, Cast, Expr, GetIndexedField, GetIndexedFieldKey, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -456,6 +456,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, + ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::ArrayFill => Self::ArrayFill, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, @@ -468,10 +469,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayReplace => Self::ArrayReplace, ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, + ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, - ScalarFunction::TrimArray => Self::TrimArray, ScalarFunction::NullIf => Self::NullIf, ScalarFunction::DatePart => Self::DatePart, ScalarFunction::DateTrunc => Self::DateTrunc, @@ -932,17 +933,13 @@ pub fn parse_expr( .expect("Binary expression could not be reduced to a single expression.")) } ExprType::GetIndexedField(field) => { - let key = field - .key - 
.as_ref() - .ok_or_else(|| Error::required("value"))? - .try_into()?; - let expr = parse_required_expr(field.expr.as_deref(), registry, "expr")?; + let key = parse_required_expr(field.key.as_deref(), registry, "key")?; Ok(Expr::GetIndexedField(GetIndexedField::new( Box::new(expr), - key, + Box::new(GetIndexedFieldKey::new(Some(key), None)), + None, ))) } ExprType::Column(column) => Ok(Expr::Column(column.into())), @@ -1290,6 +1287,11 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, )), + ScalarFunction::ArraySlice => Ok(array_slice( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::ArrayToString => Ok(array_to_string( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1297,10 +1299,6 @@ pub fn parse_expr( ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } - ScalarFunction::TrimArray => Ok(trim_array( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::ArrayLength => Ok(array_length( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1308,6 +1306,10 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayElement => Ok(array_element( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e90ba317b145..78588ba536b0 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -955,11 +955,28 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | 
Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); } - Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { + Expr::GetIndexedField(GetIndexedField { + key, + expr, + extra_key, + }) => Self { expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - key: Some(key.try_into()?), - expr: Some(Box::new(expr.as_ref().try_into()?)), + if let Some(extra_key) = extra_key { + protobuf::GetIndexedField { + extra_key: Some(Box::new(extra_key.as_ref().try_into()?)), + key: Some(Box::new( + (&key.as_ref().list_key.clone().unwrap()).try_into()?, + )), + expr: Some(Box::new(expr.as_ref().try_into()?)), + } + } else { + protobuf::GetIndexedField { + extra_key: None, + key: Some(Box::new( + (&key.as_ref().list_key.clone().unwrap()).try_into()?, + )), + expr: Some(Box::new(expr.as_ref().try_into()?)), + } }, ))), }, @@ -1408,6 +1425,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, + BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::ArrayFill => Self::ArrayFill, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, @@ -1420,10 +1438,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayReplace => Self::ArrayReplace, BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, + BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, - BuiltinScalarFunction::TrimArray => Self::TrimArray, BuiltinScalarFunction::NullIf => Self::NullIf, 
BuiltinScalarFunction::DatePart => Self::DatePart, BuiltinScalarFunction::DateTrunc => Self::DateTrunc, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 7a52e5f0d09f..c4c897b13ef9 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -30,7 +30,7 @@ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::window_function::WindowFunction; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - date_time_interval_expr, GetIndexedFieldExpr, + date_time_interval_expr, GetIndexedFieldExpr, GetIndexedFieldExprKey, }; use datafusion::physical_plan::expressions::{in_list, LikeExpr}; use datafusion::physical_plan::{ @@ -317,7 +317,16 @@ pub fn parse_physical_expr( "arg", input_schema, )?, - convert_required!(get_indexed_field_expr.key)?, + GetIndexedFieldExprKey::new( + Some(parse_required_physical_expr( + get_indexed_field_expr.key.as_deref(), + registry, + "key", + input_schema, + )?), + None, + ), + None, )) } }; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index e97a773d3472..34d0d747412c 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1401,6 +1401,7 @@ mod roundtrip_tests { use datafusion::physical_plan::aggregates::PhysicalGroupBy; use datafusion::physical_plan::expressions::{ date_time_interval_expr, like, BinaryExpr, GetIndexedFieldExpr, + GetIndexedFieldExprKey, }; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::projection::ProjectionExec; @@ -1923,14 +1924,19 @@ mod roundtrip_tests { let fields = vec![ Field::new("id", DataType::Int64, true), Field::new_list("a", Field::new("item", DataType::Float64, true), true), + Field::new("b", DataType::Int64, true), ]; let schema = Schema::new(fields); let 
input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); let col_a = col("a", &schema)?; - let key = ScalarValue::Int64(Some(1)); - let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new(col_a, key)); + let col_b = col("b", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_a, + GetIndexedFieldExprKey::new(Some(col_b), None), + None, + )); let plan = Arc::new(ProjectionExec::try_new( vec![(get_indexed_field_expr, "result".to_string())], diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index aaf3569d1634..f80f44a0e36b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -45,7 +45,6 @@ use datafusion::physical_plan::{AggregateExpr, PhysicalExpr}; use crate::protobuf; use crate::protobuf::{ physical_aggregate_expr_node, PhysicalSortExprNode, PhysicalSortExprNodeCollection, - ScalarValue, }; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr}; @@ -394,7 +393,14 @@ impl TryFrom> for protobuf::PhysicalExprNode { protobuf::physical_expr_node::ExprType::GetIndexedFieldExpr( Box::new(protobuf::PhysicalGetIndexedFieldExprNode { arg: Some(Box::new(expr.arg().to_owned().try_into()?)), - key: Some(ScalarValue::try_from(expr.key())?), + key: Some(Box::new( + expr.key() + .list_key() + .clone() + .unwrap() + .to_owned() + .try_into()?, + )), }), ), ), diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index c18587d8340c..084ee38ae367 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -19,7 +19,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ Column, DFField, DFSchema, DataFusionError, Result, ScalarValue, TableReference, }; -use datafusion_expr::{Case, Expr, GetIndexedField}; +use 
datafusion_expr::{Case, Expr, GetIndexedField, GetIndexedFieldKey}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -138,7 +138,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let nested_name = nested_names[0].to_string(); Ok(Expr::GetIndexedField(GetIndexedField::new( Box::new(Expr::Column(field.qualified_column())), - ScalarValue::Utf8(Some(nested_name)), + Box::new(GetIndexedFieldKey::new( + None, + Some(ScalarValue::Utf8(Some(nested_name))), + )), + None, ))) } // found matching field with no spare identifier(s) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 126ae4d11bb6..83aeb67dd1a9 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -36,9 +36,9 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{InList, Placeholder}; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, - Expr, ExprSchemable, GetIndexedField, Like, Operator, TryCast, + Expr, ExprSchemable, GetIndexedField, GetIndexedFieldKey, Like, Operator, TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, TrimWhereField, Value}; +use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; use sqlparser::parser::ParserError::ParserError; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -183,7 +183,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::MapAccess { column, keys } => { if let SQLExpr::Identifier(id) = *column { - plan_indexed(col(self.normalizer.normalize(id)), keys) + self.plan_indexed(col(self.normalizer.normalize(id)), keys, schema, planner_context) } else { Err(DataFusionError::NotImplemented(format!( "map access requires an identifier, found column {column} instead" @@ -193,7 +193,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::ArrayIndex { obj, indexes } => { let expr = self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; - 
plan_indexed(expr, indexes) + self.plan_indexed(expr, indexes, schema, planner_context) } SQLExpr::CompoundIdentifier(ids) => self.sql_compound_identifier_to_expr(ids, schema, planner_context), @@ -505,6 +505,71 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ), } } + + fn plan_indices( + &self, + expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<(Box, Option>)> { + let (key, extra_key) = match expr.clone() { + SQLExpr::JsonAccess { + left, + operator: JsonOperator::Colon, + right, + } => { + let left = + self.sql_expr_to_logical_expr(*left, schema, planner_context)?; + let right = + self.sql_expr_to_logical_expr(*right, schema, planner_context)?; + + ( + GetIndexedFieldKey::new(Some(left), None), + Some(Box::new(right)), + ) + } + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ) => ( + GetIndexedFieldKey::new(None, Some(ScalarValue::Utf8(Some(s)))), + None, + ), + _ => ( + GetIndexedFieldKey::new( + Some(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), + None, + ), + None, + ), + }; + + Ok((Box::new(key), extra_key)) + } + + fn plan_indexed( + &self, + expr: Expr, + mut keys: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let indices = keys.pop().ok_or_else(|| { + ParserError("Internal error: Missing index key expression".to_string()) + })?; + + let expr = if !keys.is_empty() { + self.plan_indexed(expr, keys, schema, planner_context)? 
+ } else { + expr + }; + + let (key, extra_key) = self.plan_indices(indices, schema, planner_context)?; + Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(expr), + key, + extra_key, + ))) + } } // modifies expr if it is a placeholder with datatype of right @@ -540,42 +605,6 @@ fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { }) } -fn plan_key(key: SQLExpr) -> Result { - let scalar = match key { - SQLExpr::Value(Value::Number(s, _)) => ScalarValue::Int64(Some( - s.parse() - .map_err(|_| ParserError(format!("Cannot parse {s} as i64.")))?, - )), - SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => { - ScalarValue::Utf8(Some(s)) - } - _ => { - return Err(DataFusionError::SQL(ParserError(format!( - "Unsupported index key expression: {key:?}" - )))); - } - }; - - Ok(scalar) -} - -fn plan_indexed(expr: Expr, mut keys: Vec) -> Result { - let key = keys.pop().ok_or_else(|| { - ParserError("Internal error: Missing index key expression".to_string()) - })?; - - let expr = if !keys.is_empty() { - plan_indexed(expr, keys)? - } else { - expr - }; - - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - plan_key(key)?, - ))) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 57815989ab5b..36503d70dfc9 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -377,12 +377,15 @@ where ))), Expr::Wildcard => Ok(Expr::Wildcard), Expr::QualifiedWildcard { .. 
} => Ok(expr.clone()), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), - key.clone(), - ))) - } + Expr::GetIndexedField(GetIndexedField { + key, + extra_key, + expr, + }) => Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), + key.clone(), + extra_key.clone(), + ))), Expr::GroupingSet(set) => match set { GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup( exprs diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 139e968eccfb..0278f779388d 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -187,6 +187,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | | array_fill(element, array) | Returns an array filled with copies of the given value. | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | @@ -199,10 +200,11 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_replace(array, from, to) | Replaces the first occurrence of the specified element with another specified element. 
`array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 2, 3, 2, 1, 4]` | | array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | +| array_slice(array, begin, end) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimeter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | -| trim_array(array, n) | Removes the last n elements from the array. | +| trim_array(array, n) | Deprecated | ## Regular Expressions diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d8337832d464..fa1bac5d9426 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -413,7 +413,6 @@ radians(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. 
- ======= ### `random` @@ -1447,6 +1446,8 @@ from_unixtime(expression) - [array_concat](#array_concat) - [array_contains](#array_contains) - [array_dims](#array_dims) +- [array_element](#array_element) +- [array_extract](#array_extract) - [array_fill](#array_fill) - [array_indexof](#array_indexof) - [array_join](#array_join) @@ -1463,12 +1464,15 @@ from_unixtime(expression) - [array_replace](#array_replace) - [array_replace_n](#array_replace_n) - [array_replace_all](#array_replace_all) +- [array_slice](#array_slice) - [array_to_string](#array_to_string) - [cardinality](#cardinality) - [list_append](#list_append) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) +- [list_element](#list_element) +- [list_extract](#list_extract) - [list_indexof](#list_indexof) - [list_join](#list_join) - [list_length](#list_length) @@ -1484,6 +1488,7 @@ from_unixtime(expression) - [list_replace](#list_replace) - [list_replace_n](#list_replace_n) - [list_replace_all](#list_replace_all) +- [list_slice](#list_slice) - [list_to_string](#list_to_string) - [make_array](#make_array) - [make_list](#make_list) @@ -1628,6 +1633,41 @@ array_dims(array) - list_dims +### `array_element` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. + +#### Example + +``` +❯ select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- array_extract +- list_element +- list_extract + +### `array_extract` + +_Alias of [array_element](#array_element)._ + ### `array_fill` Returns an array filled with copies of the given value. 
@@ -1833,6 +1873,10 @@ array_remove(array, element) +----------------------------------------------+ ``` +#### Aliases + +- list_remove + ### `array_remove_n` Removes the first `max` elements from the array equal to the given value. @@ -1859,6 +1903,10 @@ array_remove_n(array, element, max) +---------------------------------------------------------+ ``` +#### Aliases + +- list_remove_n + ### `array_remove_all` Removes all elements from the array equal to the given value. @@ -1884,6 +1932,10 @@ array_remove_all(array, element) +--------------------------------------------------+ ``` +#### Aliases + +- list_remove_all + ### `array_replace` Replaces the first occurrence of the specified element with another specified element. @@ -1910,6 +1962,10 @@ array_replace(array, from, to) +--------------------------------------------------------+ ``` +#### Aliases + +- list_replace + ### `array_replace_n` Replaces the first `max` occurrences of the specified element with another specified element. @@ -1937,6 +1993,10 @@ array_replace_n(array, from, to, max) +-------------------------------------------------------------------+ ``` +#### Aliases + +- list_replace_n + ### `array_replace_all` Replaces all occurrences of the specified element with another specified element. @@ -1963,6 +2023,33 @@ array_replace_all(array, from, to) +------------------------------------------------------------+ ``` +#### Aliases + +- list_replace_all + +### `array_slice` + +Returns a slice of the array. + +``` +array_slice(array, begin, end) +``` + +#### Example + +``` +❯ select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); ++--------------------------------------------------------+ +| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | ++--------------------------------------------------------+ +| [3, 4, 5, 6] | ++--------------------------------------------------------+ +``` + +#### Aliases + +- list_slice + ### `array_to_string` Converts each element to its text representation. 
@@ -2034,6 +2121,14 @@ _Alias of [array_concat](#array_concat)._ _Alias of [array_dims](#array_dims)._ +### `list_element` + +_Alias of [array_element](#array_element)._ + +### `list_extract` + +_Alias of [array_element](#array_element)._ + ### `list_indexof` _Alias of [array_position](#array_position)._ @@ -2094,6 +2189,10 @@ _Alias of [array_replace_n](#array_replace_n)._ _Alias of [array_replace_all](#array_replace_all)._ +### `list_slice` + +_Alias of [array_slice](#array_slice)._ + ### `list_to_string` _Alias of [list_to_string](#list_to_string)._ @@ -2135,6 +2234,8 @@ _Alias of [make_array](#make_array)._ Removes the last n elements from the array. +DEPRECATED: use `array_slice` instead! + ``` trim_array(array, n) ```