Skip to content

Commit 376a0b8

Browse files
authored
Fix substr() (#12383)
1 parent dbc7890 commit 376a0b8

File tree

2 files changed

+178
-65
lines changed

2 files changed

+178
-65
lines changed

datafusion/functions/src/unicode/substr.rs

+63-56
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,37 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
118118
}
119119
}
120120

121-
// Return the exact byte index for [start, end), set count to -1 to ignore count
122-
fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) {
121+
// Convert the given `start` and `count` to valid byte indices within `input` string
122+
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`
123+
// `start` is 1-based; if `count` is not provided, count to the end of the string
124+
// Input indices are character-based, and return values are byte indices
125+
// The input bounds can be outside string bounds; this function will return
126+
// the intersection between input bounds and valid string bounds
127+
//
128+
// * Example
129+
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
130+
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
131+
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
132+
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
133+
fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, usize) {
134+
let start = start - 1;
135+
let end = match count {
136+
Some(count) => start + count as i64,
137+
None => input.len() as i64,
138+
};
139+
let count_to_end = count.is_some();
140+
141+
let start = start.clamp(0, input.len() as i64) as usize;
142+
let end = end.clamp(0, input.len() as i64) as usize;
143+
let count = end - start;
144+
123145
let (mut st, mut ed) = (input.len(), input.len());
124146
let mut start_counting = false;
125147
let mut cnt = 0;
126148
for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() {
127149
if char_cnt == start {
128150
st = byte_cnt;
129-
if count != -1 {
151+
if count_to_end {
130152
start_counting = true;
131153
} else {
132154
break;
@@ -153,20 +175,15 @@ fn make_and_append_view(
153175
start: u32,
154176
) {
155177
let substr_len = substr.len();
156-
if substr_len == 0 {
157-
null_builder.append_null();
158-
views_buffer.push(0);
178+
let sub_view = if substr_len > 12 {
179+
let view = ByteView::from(*raw);
180+
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
159181
} else {
160-
let sub_view = if substr_len > 12 {
161-
let view = ByteView::from(*raw);
162-
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
163-
} else {
164-
// inline value does not need block id or offset
165-
make_view(substr.as_bytes(), 0, 0)
166-
};
167-
views_buffer.push(sub_view);
168-
null_builder.append_non_null();
169-
}
182+
// inline value does not need block id or offset
183+
make_view(substr.as_bytes(), 0, 0)
184+
};
185+
views_buffer.push(sub_view);
186+
null_builder.append_non_null();
170187
}
171188

172189
// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
@@ -180,32 +197,26 @@ fn string_view_substr(
180197

181198
let start_array = as_int64_array(&args[0])?;
182199

200+
// In either case of `substr(s, i)` or `substr(s, i, cnt)`
201+
// If any input argument is `NULL`, the result is `NULL`
183202
match args.len() {
184203
1 => {
185-
for (idx, (raw, start)) in string_view_array
186-
.views()
204+
for ((str_opt, raw_view), start_opt) in string_view_array
187205
.iter()
206+
.zip(string_view_array.views().iter())
188207
.zip(start_array.iter())
189-
.enumerate()
190208
{
191-
if let Some(start) = start {
192-
let start = (start - 1).max(0) as usize;
193-
194-
// Safety:
195-
// idx is always smaller or equal to string_view_array.views.len()
196-
unsafe {
197-
let str = string_view_array.value_unchecked(idx);
198-
let (start, end) = get_true_start_end(str, start, -1);
199-
let substr = &str[start..end];
209+
if let (Some(str), Some(start)) = (str_opt, start_opt) {
210+
let (start, end) = get_true_start_end(str, start, None);
211+
let substr = &str[start..end];
200212

201-
make_and_append_view(
202-
&mut views_buf,
203-
&mut null_builder,
204-
raw,
205-
substr,
206-
start as u32,
207-
);
208-
}
213+
make_and_append_view(
214+
&mut views_buf,
215+
&mut null_builder,
216+
raw_view,
217+
substr,
218+
start as u32,
219+
);
209220
} else {
210221
null_builder.append_null();
211222
views_buf.push(0);
@@ -214,35 +225,31 @@ fn string_view_substr(
214225
}
215226
2 => {
216227
let count_array = as_int64_array(&args[1])?;
217-
for (idx, ((raw, start), count)) in string_view_array
218-
.views()
228+
for (((str_opt, raw_view), start_opt), count_opt) in string_view_array
219229
.iter()
230+
.zip(string_view_array.views().iter())
220231
.zip(start_array.iter())
221232
.zip(count_array.iter())
222-
.enumerate()
223233
{
224-
if let (Some(start), Some(count)) = (start, count) {
225-
let start = (start - 1).max(0) as usize;
234+
if let (Some(str), Some(start), Some(count)) =
235+
(str_opt, start_opt, count_opt)
236+
{
226237
if count < 0 {
227238
return exec_err!(
228239
"negative substring length not allowed: substr(<str>, {start}, {count})"
229240
);
230241
} else {
231-
// Safety:
232-
// idx is always smaller or equal to string_view_array.views.len()
233-
unsafe {
234-
let str = string_view_array.value_unchecked(idx);
235-
let (start, end) = get_true_start_end(str, start, count);
236-
let substr = &str[start..end];
237-
238-
make_and_append_view(
239-
&mut views_buf,
240-
&mut null_builder,
241-
raw,
242-
substr,
243-
start as u32,
244-
);
245-
}
242+
let (start, end) =
243+
get_true_start_end(str, start, Some(count as u64));
244+
let substr = &str[start..end];
245+
246+
make_and_append_view(
247+
&mut views_buf,
248+
&mut null_builder,
249+
raw_view,
250+
substr,
251+
start as u32,
252+
);
246253
}
247254
} else {
248255
null_builder.append_null();

datafusion/sqllogictest/test_files/functions.slt

+115-9
Original file line numberDiff line numberDiff line change
@@ -472,11 +472,6 @@ SELECT substr('alphabet', 30)
472472
----
473473
(empty)
474474

475-
query T
476-
SELECT substr('alphabet', CAST(NULL AS int))
477-
----
478-
NULL
479-
480475
query T
481476
SELECT substr('alphabet', 3, 2)
482477
----
@@ -487,22 +482,133 @@ SELECT substr('alphabet', 3, 20)
487482
----
488483
phabet
489484

485+
# test range outside of string length
486+
query TTTTTTTTTTTT
487+
SELECT
488+
substr('hi🌏', 1, 3),
489+
substr('hi🌏', 1, 4),
490+
substr('hi🌏', 1, 100),
491+
substr('hi🌏', 0, 1),
492+
substr('hi🌏', 0, 2),
493+
substr('hi🌏', 0, 4),
494+
substr('hi🌏', 0, 5),
495+
substr('hi🌏', -10, 100),
496+
substr('hi🌏', -10, 12),
497+
substr('hi🌏', -10, 5),
498+
substr('hi🌏', 10, 0),
499+
substr('hi🌏', 10, 10);
500+
----
501+
hi🌏 hi🌏 hi🌏 (empty) h hi🌏 hi🌏 hi🌏 h (empty) (empty) (empty)
502+
503+
query TTTTTTTTTTTT
504+
SELECT
505+
substr('', 1, 3),
506+
substr('', 1, 4),
507+
substr('', 1, 100),
508+
substr('', 0, 1),
509+
substr('', 0, 2),
510+
substr('', 0, 4),
511+
substr('', 0, 5),
512+
substr('', -10, 100),
513+
substr('', -10, 12),
514+
substr('', -10, 5),
515+
substr('', 10, 0),
516+
substr('', 10, 10);
517+
----
518+
(empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty)
519+
520+
# Nulls
521+
query TTTTTTTTTT
522+
SELECT
523+
substr('alphabet', NULL),
524+
substr(NULL, 1),
525+
substr(NULL, NULL),
526+
substr('alphabet', CAST(NULL AS int), -20),
527+
substr('alphabet', 3, CAST(NULL AS int)),
528+
substr(NULL, 3, -4),
529+
substr(NULL, NULL, 4),
530+
substr(NULL, 1, NULL),
531+
substr('', NULL, NULL),
532+
substr(NULL, NULL, NULL);
533+
----
534+
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL
535+
490536
query T
491-
SELECT substr('alphabet', CAST(NULL AS int), 20)
537+
SELECT substr('Hello🌏世界', 5)
492538
----
493-
NULL
539+
o🌏世界
494540

495541
query T
496-
SELECT substr('alphabet', 3, CAST(NULL AS int))
542+
SELECT substr('Hello🌏世界', 5, 3)
497543
----
498-
NULL
544+
o🌏世
545+
546+
statement ok
547+
create table test_substr (
548+
c1 VARCHAR
549+
) as values ('foo'), ('hello🌏世界'), ('💩'), ('ThisIsAVeryLongASCIIString'), (''), (NULL);
550+
551+
statement ok
552+
create table test_substr_stringview as
553+
select c1 as c1, arrow_cast(c1, 'Utf8View') as c1_view from test_substr;
554+
555+
# The `substr()` implementation for `StringViewArray` operates directly on the view's
556+
# logical pointers, so check it's consistent with `StringArray`
557+
query BBBBBBBBBBBBBB
558+
select
559+
substr(c1, 1) = substr(c1_view, 1),
560+
substr(c1, 3) = substr(c1_view, 3),
561+
substr(c1, 100) = substr(c1_view, 100),
562+
substr(c1, -1) = substr(c1_view, -1),
563+
substr(c1, 0, 0) = substr(c1_view, 0, 0),
564+
substr(c1, -1, 2) = substr(c1_view, -1, 2),
565+
substr(c1, -2, 10) = substr(c1_view, -2, 10),
566+
substr(c1, -100, 200) = substr(c1_view, -100, 200),
567+
substr(c1, -10, 10) = substr(c1_view, -10, 10),
568+
substr(c1, -100, 10) = substr(c1_view, -100, 10),
569+
substr(c1, 1, 100) = substr(c1_view, 1, 100),
570+
substr(c1, 5, 3) = substr(c1_view, 5, 3),
571+
substr(c1, 100, 200) = substr(c1_view, 100, 200),
572+
substr(c1, 8, 0) = substr(c1_view, 8, 0)
573+
from test_substr_stringview;
574+
----
575+
true true true true true true true true true true true true true true
576+
true true true true true true true true true true true true true true
577+
true true true true true true true true true true true true true true
578+
true true true true true true true true true true true true true true
579+
true true true true true true true true true true true true true true
580+
NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL
581+
582+
# Check for non-ASCII strings
583+
query TT
584+
select substr(c1_view, 1), substr(c1_view, 5,3) from test_substr_stringview;
585+
----
586+
foo (empty)
587+
hello🌏世界 o🌏世
588+
💩 (empty)
589+
ThisIsAVeryLongASCIIString IsA
590+
(empty) (empty)
591+
NULL NULL
592+
593+
statement ok
594+
drop table test_substr;
595+
596+
statement ok
597+
drop table test_substr_stringview;
598+
499599

500600
statement error The SUBSTR function can only accept strings, but got Int64.
501601
SELECT substr(1, 3)
502602

503603
statement error The SUBSTR function can only accept strings, but got Int64.
504604
SELECT substr(1, 3, 4)
505605

606+
statement error Execution error: negative substring length not allowed
607+
select substr(arrow_cast('foo', 'Utf8View'), 1, -1);
608+
609+
statement error Execution error: negative substring length not allowed
610+
select substr('', 1, -1);
611+
506612
query T
507613
SELECT translate('12345', '143', 'ax')
508614
----

0 commit comments

Comments
 (0)