Skip to content

Commit 04e8e53

Browse files
authored
Fix: Internal error in regexp_replace() for some StringView input (#12203)
* Fix: Internal error in regexp_replace() for some StringView input * fix regex bench * fmt * fix bench regx * clippy * fmt * adds tests for flags + includes type signature for utf8view with flag * fix: adding collect for string view type
1 parent 533fbc4 commit 04e8e53

File tree

3 files changed

+268
-94
lines changed

3 files changed

+268
-94
lines changed

datafusion/functions/benches/regx.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
extern crate criterion;
1919

2020
use arrow::array::builder::StringBuilder;
21-
use arrow::array::{ArrayRef, StringArray};
21+
use arrow::array::{ArrayRef, AsArray, StringArray};
2222
use criterion::{black_box, criterion_group, criterion_main, Criterion};
2323
use datafusion_functions::regex::regexplike::regexp_like;
2424
use datafusion_functions::regex::regexpmatch::regexp_match;
@@ -122,12 +122,12 @@ fn criterion_benchmark(c: &mut Criterion) {
122122

123123
b.iter(|| {
124124
black_box(
125-
regexp_replace::<i32>(&[
126-
Arc::clone(&data),
127-
Arc::clone(&regex),
128-
Arc::clone(&replacement),
129-
Arc::clone(&flags),
130-
])
125+
regexp_replace::<i32, _, _>(
126+
data.as_string::<i32>(),
127+
regex.as_string::<i32>(),
128+
replacement.as_string::<i32>(),
129+
Some(&flags),
130+
)
131131
.expect("regexp_replace should work on valid values"),
132132
)
133133
})

datafusion/functions/src/regex/regexpreplace.rs

+171-87
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
// under the License.
1717

1818
//! Regx expressions
19-
use arrow::array::new_null_array;
20-
use arrow::array::ArrayAccessor;
2119
use arrow::array::ArrayDataBuilder;
2220
use arrow::array::BufferBuilder;
2321
use arrow::array::GenericStringArray;
2422
use arrow::array::StringViewBuilder;
23+
use arrow::array::{new_null_array, ArrayIter, AsArray};
2524
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
25+
use arrow::array::{ArrayAccessor, StringViewArray};
2626
use arrow::datatypes::DataType;
2727
use datafusion_common::cast::as_string_view_array;
2828
use datafusion_common::exec_err;
@@ -59,6 +59,7 @@ impl RegexpReplaceFunc {
5959
Exact(vec![Utf8, Utf8, Utf8]),
6060
Exact(vec![Utf8View, Utf8, Utf8]),
6161
Exact(vec![Utf8, Utf8, Utf8, Utf8]),
62+
Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
6263
],
6364
Volatility::Immutable,
6465
),
@@ -187,104 +188,147 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
187188
/// # Ok(())
188189
/// # }
189190
/// ```
190-
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
191+
pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>(
192+
string_array: V,
193+
pattern_array: B,
194+
replacement_array: B,
195+
flags: Option<&ArrayRef>,
196+
) -> Result<ArrayRef>
197+
where
198+
V: ArrayAccessor<Item = &'a str>,
199+
B: ArrayAccessor<Item = &'a str>,
200+
{
191201
// Default implementation for regexp_replace, assumes all args are arrays
192202
// and args is a sequence of 3 or 4 elements.
193203

194204
// creating Regex is expensive so create hashmap for memoization
195205
let mut patterns: HashMap<String, Regex> = HashMap::new();
196206

197-
match args.len() {
198-
3 => {
199-
let string_array = as_generic_string_array::<T>(&args[0])?;
200-
let pattern_array = as_generic_string_array::<T>(&args[1])?;
201-
let replacement_array = as_generic_string_array::<T>(&args[2])?;
202-
203-
let result = string_array
204-
.iter()
205-
.zip(pattern_array.iter())
206-
.zip(replacement_array.iter())
207-
.map(|((string, pattern), replacement)| match (string, pattern, replacement) {
208-
(Some(string), Some(pattern), Some(replacement)) => {
209-
let replacement = regex_replace_posix_groups(replacement);
210-
211-
// if patterns hashmap already has regexp then use else create and return
212-
let re = match patterns.get(pattern) {
213-
Some(re) => Ok(re),
214-
None => {
215-
match Regex::new(pattern) {
216-
Ok(re) => {
217-
patterns.insert(pattern.to_string(), re);
218-
Ok(patterns.get(pattern).unwrap())
207+
let datatype = string_array.data_type().to_owned();
208+
209+
let string_array_iter = ArrayIter::new(string_array);
210+
let pattern_array_iter = ArrayIter::new(pattern_array);
211+
let replacement_array_iter = ArrayIter::new(replacement_array);
212+
213+
match flags {
214+
None => {
215+
let result_iter = string_array_iter
216+
.zip(pattern_array_iter)
217+
.zip(replacement_array_iter)
218+
.map(|((string, pattern), replacement)| {
219+
match (string, pattern, replacement) {
220+
(Some(string), Some(pattern), Some(replacement)) => {
221+
let replacement = regex_replace_posix_groups(replacement);
222+
// if patterns hashmap already has regexp then use else create and return
223+
let re = match patterns.get(pattern) {
224+
Some(re) => Ok(re),
225+
None => match Regex::new(pattern) {
226+
Ok(re) => {
227+
patterns.insert(pattern.to_string(), re);
228+
Ok(patterns.get(pattern).unwrap())
229+
}
230+
Err(err) => {
231+
Err(DataFusionError::External(Box::new(err)))
232+
}
219233
},
220-
Err(err) => Err(DataFusionError::External(Box::new(err))),
221-
}
222-
}
223-
};
234+
};
224235

225-
Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose()
236+
Some(re.map(|re| re.replace(string, replacement.as_str())))
237+
.transpose()
238+
}
239+
_ => Ok(None),
240+
}
241+
});
242+
243+
match datatype {
244+
DataType::Utf8 | DataType::LargeUtf8 => {
245+
let result =
246+
result_iter.collect::<Result<GenericStringArray<T>>>()?;
247+
Ok(Arc::new(result) as ArrayRef)
226248
}
227-
_ => Ok(None)
228-
})
229-
.collect::<Result<GenericStringArray<T>>>()?;
230-
231-
Ok(Arc::new(result) as ArrayRef)
249+
DataType::Utf8View => {
250+
let result = result_iter.collect::<Result<StringViewArray>>()?;
251+
Ok(Arc::new(result) as ArrayRef)
252+
}
253+
other => {
254+
exec_err!(
255+
"Unsupported data type {other:?} for function regex_replace"
256+
)
257+
}
258+
}
232259
}
233-
4 => {
234-
let string_array = as_generic_string_array::<T>(&args[0])?;
235-
let pattern_array = as_generic_string_array::<T>(&args[1])?;
236-
let replacement_array = as_generic_string_array::<T>(&args[2])?;
237-
let flags_array = as_generic_string_array::<T>(&args[3])?;
238-
239-
let result = string_array
240-
.iter()
241-
.zip(pattern_array.iter())
242-
.zip(replacement_array.iter())
243-
.zip(flags_array.iter())
244-
.map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) {
245-
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
246-
let replacement = regex_replace_posix_groups(replacement);
247-
248-
// format flags into rust pattern
249-
let (pattern, replace_all) = if flags == "g" {
250-
(pattern.to_string(), true)
251-
} else if flags.contains('g') {
252-
(format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true)
253-
} else {
254-
(format!("(?{flags}){pattern}"), false)
255-
};
256-
257-
// if patterns hashmap already has regexp then use else create and return
258-
let re = match patterns.get(&pattern) {
259-
Some(re) => Ok(re),
260-
None => {
261-
match Regex::new(pattern.as_str()) {
262-
Ok(re) => {
263-
patterns.insert(pattern.clone(), re);
264-
Ok(patterns.get(&pattern).unwrap())
260+
Some(flags) => {
261+
let flags_array = as_generic_string_array::<T>(flags)?;
262+
263+
let result_iter = string_array_iter
264+
.zip(pattern_array_iter)
265+
.zip(replacement_array_iter)
266+
.zip(flags_array.iter())
267+
.map(|(((string, pattern), replacement), flags)| {
268+
match (string, pattern, replacement, flags) {
269+
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
270+
let replacement = regex_replace_posix_groups(replacement);
271+
272+
// format flags into rust pattern
273+
let (pattern, replace_all) = if flags == "g" {
274+
(pattern.to_string(), true)
275+
} else if flags.contains('g') {
276+
(
277+
format!(
278+
"(?{}){}",
279+
flags.to_string().replace('g', ""),
280+
pattern
281+
),
282+
true,
283+
)
284+
} else {
285+
(format!("(?{flags}){pattern}"), false)
286+
};
287+
288+
// if patterns hashmap already has regexp then use else create and return
289+
let re = match patterns.get(&pattern) {
290+
Some(re) => Ok(re),
291+
None => match Regex::new(pattern.as_str()) {
292+
Ok(re) => {
293+
patterns.insert(pattern.clone(), re);
294+
Ok(patterns.get(&pattern).unwrap())
295+
}
296+
Err(err) => {
297+
Err(DataFusionError::External(Box::new(err)))
298+
}
265299
},
266-
Err(err) => Err(DataFusionError::External(Box::new(err))),
267-
}
300+
};
301+
302+
Some(re.map(|re| {
303+
if replace_all {
304+
re.replace_all(string, replacement.as_str())
305+
} else {
306+
re.replace(string, replacement.as_str())
307+
}
308+
}))
309+
.transpose()
268310
}
269-
};
270-
271-
Some(re.map(|re| {
272-
if replace_all {
273-
re.replace_all(string, replacement.as_str())
274-
} else {
275-
re.replace(string, replacement.as_str())
276-
}
277-
})).transpose()
311+
_ => Ok(None),
312+
}
313+
});
314+
315+
match datatype {
316+
DataType::Utf8 | DataType::LargeUtf8 => {
317+
let result =
318+
result_iter.collect::<Result<GenericStringArray<T>>>()?;
319+
Ok(Arc::new(result) as ArrayRef)
278320
}
279-
_ => Ok(None)
280-
})
281-
.collect::<Result<GenericStringArray<T>>>()?;
282-
283-
Ok(Arc::new(result) as ArrayRef)
321+
DataType::Utf8View => {
322+
let result = result_iter.collect::<Result<StringViewArray>>()?;
323+
Ok(Arc::new(result) as ArrayRef)
324+
}
325+
other => {
326+
exec_err!(
327+
"Unsupported data type {other:?} for function regex_replace"
328+
)
329+
}
330+
}
284331
}
285-
other => exec_err!(
286-
"regexp_replace was called with {other} arguments. It requires at least 3 and at most 4."
287-
),
288332
}
289333
}
290334

@@ -495,7 +539,47 @@ pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
495539
.iter()
496540
.map(|arg| arg.clone().into_array(inferred_length))
497541
.collect::<Result<Vec<_>>>()?;
498-
regexp_replace::<T>(&args)
542+
543+
match args[0].data_type() {
544+
DataType::Utf8View => {
545+
let string_array = args[0].as_string_view();
546+
let pattern_array = args[1].as_string::<i32>();
547+
let replacement_array = args[2].as_string::<i32>();
548+
regexp_replace::<i32, _, _>(
549+
string_array,
550+
pattern_array,
551+
replacement_array,
552+
args.get(3),
553+
)
554+
}
555+
DataType::Utf8 => {
556+
let string_array = args[0].as_string::<i32>();
557+
let pattern_array = args[1].as_string::<i32>();
558+
let replacement_array = args[2].as_string::<i32>();
559+
regexp_replace::<i32, _, _>(
560+
string_array,
561+
pattern_array,
562+
replacement_array,
563+
args.get(3),
564+
)
565+
}
566+
DataType::LargeUtf8 => {
567+
let string_array = args[0].as_string::<i64>();
568+
let pattern_array = args[1].as_string::<i64>();
569+
let replacement_array = args[2].as_string::<i64>();
570+
regexp_replace::<i64, _, _>(
571+
string_array,
572+
pattern_array,
573+
replacement_array,
574+
args.get(3),
575+
)
576+
}
577+
other => {
578+
exec_err!(
579+
"Unsupported data type {other:?} for function regex_replace"
580+
)
581+
}
582+
}
499583
}
500584
}
501585
}

0 commit comments

Comments
 (0)