diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 8a70b380669c..40d3a4d13e97 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -118,15 +118,37 @@ pub fn substr(args: &[ArrayRef]) -> Result { } } -// Return the exact byte index for [start, end), set count to -1 to ignore count -fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) { +// Convert the given `start` and `count` to valid byte indices within `input` string +// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)` +// `start` is 1-based, if `count` is not provided count to the end of the string +// Input indices are character-based, and return values are byte indices +// The input bounds can be outside string bounds, this function will return +// the intersection between input bounds and valid string bounds +// +// * Example +// 'Hi๐ŸŒ' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] +// `get_true_start_end('Hi๐ŸŒ', 1, None) -> (0, 6)` +// `get_true_start_end('Hi๐ŸŒ', 1, 1) -> (0, 1)` +// `get_true_start_end('Hi๐ŸŒ', -10, 2) -> (0, 0)` +fn get_true_start_end(input: &str, start: i64, count: Option) -> (usize, usize) { + let start = start - 1; + let end = match count { + Some(count) => start + count as i64, + None => input.len() as i64, + }; + let count_to_end = count.is_some(); + + let start = start.clamp(0, input.len() as i64) as usize; + let end = end.clamp(0, input.len() as i64) as usize; + let count = end - start; + let (mut st, mut ed) = (input.len(), input.len()); let mut start_counting = false; let mut cnt = 0; for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() { if char_cnt == start { st = byte_cnt; - if count != -1 { + if count_to_end { start_counting = true; } else { break; @@ -153,20 +175,15 @@ fn make_and_append_view( start: u32, ) { let substr_len = substr.len(); - if substr_len == 0 { - null_builder.append_null(); - views_buffer.push(0); + let sub_view = if substr_len > 12 { + let view = ByteView::from(*raw); + make_view(substr.as_bytes(), view.buffer_index, view.offset + start) } else { - let sub_view = if substr_len > 12 { - let view = ByteView::from(*raw); - make_view(substr.as_bytes(), view.buffer_index, view.offset + start) - } else { - // inline value does not need block id or offset - make_view(substr.as_bytes(), 0, 0) - }; - views_buffer.push(sub_view); - null_builder.append_non_null(); - } + // inline value does not need block id or offset + make_view(substr.as_bytes(), 0, 0) + }; + views_buffer.push(sub_view); + null_builder.append_non_null(); } // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 @@ -180,32 +197,26 @@ fn string_view_substr( let start_array = as_int64_array(&args[0])?; + // In either case of `substr(s, i)` or `substr(s, i, cnt)` + // If any of input argument is `NULL`, the result is `NULL` match args.len() { 1 => { - for (idx, (raw, start)) in string_view_array - .views() + for ((str_opt, raw_view), start_opt) in string_view_array .iter() + .zip(string_view_array.views().iter()) .zip(start_array.iter()) - .enumerate() { - if let Some(start) = start { - let start = (start - 1).max(0) as usize; - - // Safety: - // idx is always smaller or equal to string_view_array.views.len() - unsafe { - let str = string_view_array.value_unchecked(idx); - let (start, end) = get_true_start_end(str, start, -1); - let substr = &str[start..end]; + if let (Some(str), Some(start)) = (str_opt, start_opt) { + let (start, end) = get_true_start_end(str, start, None); + let substr = &str[start..end]; - make_and_append_view( - &mut views_buf, - &mut null_builder, - raw, - substr, - start as u32, - ); - } + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw_view, + substr, + start as u32, + ); } else { null_builder.append_null(); views_buf.push(0); @@ -214,35 +225,31 @@ fn string_view_substr( } 2 => { let count_array = as_int64_array(&args[1])?; - for (idx, ((raw, start), count)) in string_view_array - .views() + for (((str_opt, raw_view), start_opt), count_opt) in string_view_array .iter() + .zip(string_view_array.views().iter()) .zip(start_array.iter()) .zip(count_array.iter()) - .enumerate() { - if let (Some(start), Some(count)) = (start, count) { - let start = (start - 1).max(0) as usize; + if let (Some(str), Some(start), Some(count)) = + (str_opt, start_opt, count_opt) + { if count < 0 { return exec_err!( "negative substring length not allowed: substr(, {start}, {count})" ); } else { - // Safety: - // idx is always smaller or equal to string_view_array.views.len() - unsafe { - let str = string_view_array.value_unchecked(idx); - let (start, end) = get_true_start_end(str, start, count); - let substr = &str[start..end]; - - make_and_append_view( - &mut views_buf, - &mut null_builder, - raw, - substr, - start as u32, - ); - } + let (start, end) = + get_true_start_end(str, start, Some(count as u64)); + let substr = &str[start..end]; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw_view, + substr, + start as u32, + ); } } else { null_builder.append_null(); diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index b8519a463637..9bee9b8184ea 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -472,11 +472,6 @@ SELECT substr('alphabet', 30) ---- (empty) -query T -SELECT substr('alphabet', CAST(NULL AS int)) ----- -NULL - query T SELECT substr('alphabet', 3, 2) ---- @@ -487,15 +482,120 @@ SELECT substr('alphabet', 3, 20) ---- phabet +# test range ouside of string length +query TTTTTTTTTTTT +SELECT + substr('hi๐ŸŒ', 1, 3), + substr('hi๐ŸŒ', 1, 4), + substr('hi๐ŸŒ', 1, 100), + substr('hi๐ŸŒ', 0, 1), + substr('hi๐ŸŒ', 0, 2), + substr('hi๐ŸŒ', 0, 4), + substr('hi๐ŸŒ', 0, 5), + substr('hi๐ŸŒ', -10, 100), + substr('hi๐ŸŒ', -10, 12), + substr('hi๐ŸŒ', -10, 5), + substr('hi๐ŸŒ', 10, 0), + substr('hi๐ŸŒ', 10, 10); +---- +hi๐ŸŒ hi๐ŸŒ hi๐ŸŒ (empty) h hi๐ŸŒ hi๐ŸŒ hi๐ŸŒ h (empty) (empty) (empty) + +query TTTTTTTTTTTT +SELECT + substr('', 1, 3), + substr('', 1, 4), + substr('', 1, 100), + substr('', 0, 1), + substr('', 0, 2), + substr('', 0, 4), + substr('', 0, 5), + substr('', -10, 100), + substr('', -10, 12), + substr('', -10, 5), + substr('', 10, 0), + substr('', 10, 10); +---- +(empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) + +# Nulls +query TTTTTTTTTT +SELECT + substr('alphabet', NULL), + substr(NULL, 1), + substr(NULL, NULL), + substr('alphabet', CAST(NULL AS int), -20), + substr('alphabet', 3, CAST(NULL AS int)), + substr(NULL, 3, -4), + substr(NULL, NULL, 4), + substr(NULL, 1, NULL), + substr('', NULL, NULL), + substr(NULL, NULL, NULL); +---- +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + query T -SELECT substr('alphabet', CAST(NULL AS int), 20) +SELECT substr('Hello๐ŸŒไธ–็•Œ', 5) ---- -NULL +o๐ŸŒไธ–็•Œ query T -SELECT substr('alphabet', 3, CAST(NULL AS int)) +SELECT substr('Hello๐ŸŒไธ–็•Œ', 5, 3) ---- -NULL +o๐ŸŒไธ– + +statement ok +create table test_substr ( + c1 VARCHAR +) as values ('foo'), ('hello๐ŸŒไธ–็•Œ'), ('๐Ÿ’ฉ'), ('ThisIsAVeryLongASCIIString'), (''), (NULL); + +statement ok +create table test_substr_stringview as +select c1 as c1, arrow_cast(c1, 'Utf8View') as c1_view from test_substr; + +# `substr()` on `StringViewArray`'s implementation operates directly on view's +# logical pointers, so check it's consistent with `StringArray` +query BBBBBBBBBBBBBB +select + substr(c1, 1) = substr(c1_view, 1), + substr(c1, 3) = substr(c1_view, 3), + substr(c1, 100) = substr(c1_view, 100), + substr(c1, -1) = substr(c1_view, -1), + substr(c1, 0, 0) = substr(c1_view, 0, 0), + substr(c1, -1, 2) = substr(c1_view, -1, 2), + substr(c1, -2, 10) = substr(c1_view, -2, 10), + substr(c1, -100, 200) = substr(c1_view, -100, 200), + substr(c1, -10, 10) = substr(c1_view, -10, 10), + substr(c1, -100, 10) = substr(c1_view, -100, 10), + substr(c1, 1, 100) = substr(c1_view, 1, 100), + substr(c1, 5, 3) = substr(c1_view, 5, 3), + substr(c1, 100, 200) = substr(c1_view, 100, 200), + substr(c1, 8, 0) = substr(c1_view, 8, 0) +from test_substr_stringview; +---- +true true true true true true true true true true true true true true +true true true true true true true true true true true true true true +true true true true true true true true true true true true true true +true true true true true true true true true true true true true true +true true true true true true true true true true true true true true +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +# Check for non-ASCII strings +query TT +select substr(c1_view, 1), substr(c1_view, 5,3) from test_substr_stringview; +---- +foo (empty) +hello๐ŸŒไธ–็•Œ o๐ŸŒไธ– +๐Ÿ’ฉ (empty) +ThisIsAVeryLongASCIIString IsA +(empty) (empty) +NULL NULL + +statement ok +drop table test_substr; + +statement ok +drop table test_substr_stringview; + statement error The SUBSTR function can only accept strings, but got Int64. SELECT substr(1, 3) @@ -503,6 +603,12 @@ SELECT substr(1, 3) statement error The SUBSTR function can only accept strings, but got Int64. SELECT substr(1, 3, 4) +statement error Execution error: negative substring length not allowed +select substr(arrow_cast('foo', 'Utf8View'), 1, -1); + +statement error Execution error: negative substring length not allowed +select substr('', 1, -1); + query T SELECT translate('12345', '143', 'ax') ----