Skip to content

Commit

Permalink
Fix substr() (#12383)
Browse files Browse the repository at this point in the history
  • Loading branch information
2010YOUY01 committed Sep 10, 2024
1 parent dbc7890 commit 376a0b8
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 65 deletions.
119 changes: 63 additions & 56 deletions datafusion/functions/src/unicode/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,37 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

// Return the exact byte index for [start, end), set count to -1 to ignore count
fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) {
// Convert the given `start` and `count` to valid byte indices within `input` string
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`
// `start` is 1-based; if `count` is not provided, the range extends to the end of the string
// Input indices are character-based, and return values are byte indices
// The input bounds can be outside string bounds, this function will return
// the intersection between input bounds and valid string bounds
//
// * Example
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) {
let start = start - 1;
let end = match count {
Some(count) => start + count as i64,
None => input.len() as i64,
};
let count_to_end = count.is_some();

let start = start.clamp(0, input.len() as i64) as usize;
let end = end.clamp(0, input.len() as i64) as usize;
let count = end - start;

let (mut st, mut ed) = (input.len(), input.len());
let mut start_counting = false;
let mut cnt = 0;
for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() {
if char_cnt == start {
st = byte_cnt;
if count != -1 {
if count_to_end {
start_counting = true;
} else {
break;
Expand All @@ -153,20 +175,15 @@ fn make_and_append_view(
start: u32,
) {
let substr_len = substr.len();
if substr_len == 0 {
null_builder.append_null();
views_buffer.push(0);
let sub_view = if substr_len > 12 {
let view = ByteView::from(*raw);
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
} else {
let sub_view = if substr_len > 12 {
let view = ByteView::from(*raw);
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
} else {
// inline value does not need block id or offset
make_view(substr.as_bytes(), 0, 0)
};
views_buffer.push(sub_view);
null_builder.append_non_null();
}
// inline value does not need block id or offset
make_view(substr.as_bytes(), 0, 0)
};
views_buffer.push(sub_view);
null_builder.append_non_null();
}

// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
Expand All @@ -180,32 +197,26 @@ fn string_view_substr(

let start_array = as_int64_array(&args[0])?;

// In either case of `substr(s, i)` or `substr(s, i, cnt)`
// If any input argument is `NULL`, the result is `NULL`
match args.len() {
1 => {
for (idx, (raw, start)) in string_view_array
.views()
for ((str_opt, raw_view), start_opt) in string_view_array
.iter()
.zip(string_view_array.views().iter())
.zip(start_array.iter())
.enumerate()
{
if let Some(start) = start {
let start = (start - 1).max(0) as usize;

// Safety:
// idx is always smaller or equal to string_view_array.views.len()
unsafe {
let str = string_view_array.value_unchecked(idx);
let (start, end) = get_true_start_end(str, start, -1);
let substr = &str[start..end];
if let (Some(str), Some(start)) = (str_opt, start_opt) {
let (start, end) = get_true_start_end(str, start, None);
let substr = &str[start..end];

make_and_append_view(
&mut views_buf,
&mut null_builder,
raw,
substr,
start as u32,
);
}
make_and_append_view(
&mut views_buf,
&mut null_builder,
raw_view,
substr,
start as u32,
);
} else {
null_builder.append_null();
views_buf.push(0);
Expand All @@ -214,35 +225,31 @@ fn string_view_substr(
}
2 => {
let count_array = as_int64_array(&args[1])?;
for (idx, ((raw, start), count)) in string_view_array
.views()
for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
.iter()
.zip(string_view_array.views().iter())
.zip(start_array.iter())
.zip(count_array.iter())
.enumerate()
{
if let (Some(start), Some(count)) = (start, count) {
let start = (start - 1).max(0) as usize;
if let (Some(str), Some(start), Some(count)) =
(str_opt, start_opt, count_opt)
{
if count < 0 {
return exec_err!(
"negative substring length not allowed: substr(<str>, {start}, {count})"
);
} else {
// Safety:
// idx is always smaller or equal to string_view_array.views.len()
unsafe {
let str = string_view_array.value_unchecked(idx);
let (start, end) = get_true_start_end(str, start, count);
let substr = &str[start..end];

make_and_append_view(
&mut views_buf,
&mut null_builder,
raw,
substr,
start as u32,
);
}
let (start, end) =
get_true_start_end(str, start, Some(count as u64));
let substr = &str[start..end];

make_and_append_view(
&mut views_buf,
&mut null_builder,
raw_view,
substr,
start as u32,
);
}
} else {
null_builder.append_null();
Expand Down
124 changes: 115 additions & 9 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,6 @@ SELECT substr('alphabet', 30)
----
(empty)

query T
SELECT substr('alphabet', CAST(NULL AS int))
----
NULL

query T
SELECT substr('alphabet', 3, 2)
----
Expand All @@ -487,22 +482,133 @@ SELECT substr('alphabet', 3, 20)
----
phabet

# test range outside of string length
query TTTTTTTTTTTT
SELECT
substr('hi🌏', 1, 3),
substr('hi🌏', 1, 4),
substr('hi🌏', 1, 100),
substr('hi🌏', 0, 1),
substr('hi🌏', 0, 2),
substr('hi🌏', 0, 4),
substr('hi🌏', 0, 5),
substr('hi🌏', -10, 100),
substr('hi🌏', -10, 12),
substr('hi🌏', -10, 5),
substr('hi🌏', 10, 0),
substr('hi🌏', 10, 10);
----
hi🌏 hi🌏 hi🌏 (empty) h hi🌏 hi🌏 hi🌏 h (empty) (empty) (empty)

query TTTTTTTTTTTT
SELECT
substr('', 1, 3),
substr('', 1, 4),
substr('', 1, 100),
substr('', 0, 1),
substr('', 0, 2),
substr('', 0, 4),
substr('', 0, 5),
substr('', -10, 100),
substr('', -10, 12),
substr('', -10, 5),
substr('', 10, 0),
substr('', 10, 10);
----
(empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty)

# Nulls
query TTTTTTTTTT
SELECT
substr('alphabet', NULL),
substr(NULL, 1),
substr(NULL, NULL),
substr('alphabet', CAST(NULL AS int), -20),
substr('alphabet', 3, CAST(NULL AS int)),
substr(NULL, 3, -4),
substr(NULL, NULL, 4),
substr(NULL, 1, NULL),
substr('', NULL, NULL),
substr(NULL, NULL, NULL);
----
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL

query T
SELECT substr('alphabet', CAST(NULL AS int), 20)
SELECT substr('Hello🌏世界', 5)
----
NULL
o🌏世界

query T
SELECT substr('alphabet', 3, CAST(NULL AS int))
SELECT substr('Hello🌏世界', 5, 3)
----
NULL
o🌏世

statement ok
create table test_substr (
c1 VARCHAR
) as values ('foo'), ('hello🌏世界'), ('💩'), ('ThisIsAVeryLongASCIIString'), (''), (NULL);

statement ok
create table test_substr_stringview as
select c1 as c1, arrow_cast(c1, 'Utf8View') as c1_view from test_substr;

# The `substr()` implementation for `StringViewArray` operates directly on the
# views' logical pointers, so check that it is consistent with `StringArray`
query BBBBBBBBBBBBBB
select
substr(c1, 1) = substr(c1_view, 1),
substr(c1, 3) = substr(c1_view, 3),
substr(c1, 100) = substr(c1_view, 100),
substr(c1, -1) = substr(c1_view, -1),
substr(c1, 0, 0) = substr(c1_view, 0, 0),
substr(c1, -1, 2) = substr(c1_view, -1, 2),
substr(c1, -2, 10) = substr(c1_view, -2, 10),
substr(c1, -100, 200) = substr(c1_view, -100, 200),
substr(c1, -10, 10) = substr(c1_view, -10, 10),
substr(c1, -100, 10) = substr(c1_view, -100, 10),
substr(c1, 1, 100) = substr(c1_view, 1, 100),
substr(c1, 5, 3) = substr(c1_view, 5, 3),
substr(c1, 100, 200) = substr(c1_view, 100, 200),
substr(c1, 8, 0) = substr(c1_view, 8, 0)
from test_substr_stringview;
----
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
true true true true true true true true true true true true true true
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL

# Check for non-ASCII strings
query TT
select substr(c1_view, 1), substr(c1_view, 5,3) from test_substr_stringview;
----
foo (empty)
hello🌏世界 o🌏世
💩 (empty)
ThisIsAVeryLongASCIIString IsA
(empty) (empty)
NULL NULL

statement ok
drop table test_substr;

statement ok
drop table test_substr_stringview;


statement error The SUBSTR function can only accept strings, but got Int64.
SELECT substr(1, 3)

statement error The SUBSTR function can only accept strings, but got Int64.
SELECT substr(1, 3, 4)

statement error Execution error: negative substring length not allowed
select substr(arrow_cast('foo', 'Utf8View'), 1, -1);

statement error Execution error: negative substring length not allowed
select substr('', 1, -1);

query T
SELECT translate('12345', '143', 'ax')
----
Expand Down

0 comments on commit 376a0b8

Please sign in to comment.