|
16 | 16 | // under the License.
|
17 | 17 |
|
18 | 18 | //! Regx expressions
|
19 |
| -use arrow::array::new_null_array; |
20 |
| -use arrow::array::ArrayAccessor; |
21 | 19 | use arrow::array::ArrayDataBuilder;
|
22 | 20 | use arrow::array::BufferBuilder;
|
23 | 21 | use arrow::array::GenericStringArray;
|
24 | 22 | use arrow::array::StringViewBuilder;
|
| 23 | +use arrow::array::{new_null_array, ArrayIter, AsArray}; |
25 | 24 | use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
|
| 25 | +use arrow::array::{ArrayAccessor, StringViewArray}; |
26 | 26 | use arrow::datatypes::DataType;
|
27 | 27 | use datafusion_common::cast::as_string_view_array;
|
28 | 28 | use datafusion_common::exec_err;
|
@@ -59,6 +59,7 @@ impl RegexpReplaceFunc {
|
59 | 59 | Exact(vec![Utf8, Utf8, Utf8]),
|
60 | 60 | Exact(vec![Utf8View, Utf8, Utf8]),
|
61 | 61 | Exact(vec![Utf8, Utf8, Utf8, Utf8]),
|
| 62 | + Exact(vec![Utf8View, Utf8, Utf8, Utf8]), |
62 | 63 | ],
|
63 | 64 | Volatility::Immutable,
|
64 | 65 | ),
|
@@ -187,104 +188,147 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
|
187 | 188 | /// # Ok(())
|
188 | 189 | /// # }
|
189 | 190 | /// ```
|
190 |
| -pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { |
| 191 | +pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>( |
| 192 | + string_array: V, |
| 193 | + pattern_array: B, |
| 194 | + replacement_array: B, |
| 195 | + flags: Option<&ArrayRef>, |
| 196 | +) -> Result<ArrayRef> |
| 197 | +where |
| 198 | + V: ArrayAccessor<Item = &'a str>, |
| 199 | + B: ArrayAccessor<Item = &'a str>, |
| 200 | +{ |
191 | 201 | // Default implementation for regexp_replace, assumes all args are arrays
|
192 | 202 | // and args is a sequence of 3 or 4 elements.
|
193 | 203 |
|
194 | 204 | // creating Regex is expensive so create hashmap for memoization
|
195 | 205 | let mut patterns: HashMap<String, Regex> = HashMap::new();
|
196 | 206 |
|
197 |
| - match args.len() { |
198 |
| - 3 => { |
199 |
| - let string_array = as_generic_string_array::<T>(&args[0])?; |
200 |
| - let pattern_array = as_generic_string_array::<T>(&args[1])?; |
201 |
| - let replacement_array = as_generic_string_array::<T>(&args[2])?; |
202 |
| - |
203 |
| - let result = string_array |
204 |
| - .iter() |
205 |
| - .zip(pattern_array.iter()) |
206 |
| - .zip(replacement_array.iter()) |
207 |
| - .map(|((string, pattern), replacement)| match (string, pattern, replacement) { |
208 |
| - (Some(string), Some(pattern), Some(replacement)) => { |
209 |
| - let replacement = regex_replace_posix_groups(replacement); |
210 |
| - |
211 |
| - // if patterns hashmap already has regexp then use else create and return |
212 |
| - let re = match patterns.get(pattern) { |
213 |
| - Some(re) => Ok(re), |
214 |
| - None => { |
215 |
| - match Regex::new(pattern) { |
216 |
| - Ok(re) => { |
217 |
| - patterns.insert(pattern.to_string(), re); |
218 |
| - Ok(patterns.get(pattern).unwrap()) |
| 207 | + let datatype = string_array.data_type().to_owned(); |
| 208 | + |
| 209 | + let string_array_iter = ArrayIter::new(string_array); |
| 210 | + let pattern_array_iter = ArrayIter::new(pattern_array); |
| 211 | + let replacement_array_iter = ArrayIter::new(replacement_array); |
| 212 | + |
| 213 | + match flags { |
| 214 | + None => { |
| 215 | + let result_iter = string_array_iter |
| 216 | + .zip(pattern_array_iter) |
| 217 | + .zip(replacement_array_iter) |
| 218 | + .map(|((string, pattern), replacement)| { |
| 219 | + match (string, pattern, replacement) { |
| 220 | + (Some(string), Some(pattern), Some(replacement)) => { |
| 221 | + let replacement = regex_replace_posix_groups(replacement); |
| 222 | + // if patterns hashmap already has regexp then use else create and return |
| 223 | + let re = match patterns.get(pattern) { |
| 224 | + Some(re) => Ok(re), |
| 225 | + None => match Regex::new(pattern) { |
| 226 | + Ok(re) => { |
| 227 | + patterns.insert(pattern.to_string(), re); |
| 228 | + Ok(patterns.get(pattern).unwrap()) |
| 229 | + } |
| 230 | + Err(err) => { |
| 231 | + Err(DataFusionError::External(Box::new(err))) |
| 232 | + } |
219 | 233 | },
|
220 |
| - Err(err) => Err(DataFusionError::External(Box::new(err))), |
221 |
| - } |
222 |
| - } |
223 |
| - }; |
| 234 | + }; |
224 | 235 |
|
225 |
| - Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose() |
| 236 | + Some(re.map(|re| re.replace(string, replacement.as_str()))) |
| 237 | + .transpose() |
| 238 | + } |
| 239 | + _ => Ok(None), |
| 240 | + } |
| 241 | + }); |
| 242 | + |
| 243 | + match datatype { |
| 244 | + DataType::Utf8 | DataType::LargeUtf8 => { |
| 245 | + let result = |
| 246 | + result_iter.collect::<Result<GenericStringArray<T>>>()?; |
| 247 | + Ok(Arc::new(result) as ArrayRef) |
226 | 248 | }
|
227 |
| - _ => Ok(None) |
228 |
| - }) |
229 |
| - .collect::<Result<GenericStringArray<T>>>()?; |
230 |
| - |
231 |
| - Ok(Arc::new(result) as ArrayRef) |
| 249 | + DataType::Utf8View => { |
| 250 | + let result = result_iter.collect::<Result<StringViewArray>>()?; |
| 251 | + Ok(Arc::new(result) as ArrayRef) |
| 252 | + } |
| 253 | + other => { |
| 254 | + exec_err!( |
| 255 | + "Unsupported data type {other:?} for function regex_replace" |
| 256 | + ) |
| 257 | + } |
| 258 | + } |
232 | 259 | }
|
233 |
| - 4 => { |
234 |
| - let string_array = as_generic_string_array::<T>(&args[0])?; |
235 |
| - let pattern_array = as_generic_string_array::<T>(&args[1])?; |
236 |
| - let replacement_array = as_generic_string_array::<T>(&args[2])?; |
237 |
| - let flags_array = as_generic_string_array::<T>(&args[3])?; |
238 |
| - |
239 |
| - let result = string_array |
240 |
| - .iter() |
241 |
| - .zip(pattern_array.iter()) |
242 |
| - .zip(replacement_array.iter()) |
243 |
| - .zip(flags_array.iter()) |
244 |
| - .map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) { |
245 |
| - (Some(string), Some(pattern), Some(replacement), Some(flags)) => { |
246 |
| - let replacement = regex_replace_posix_groups(replacement); |
247 |
| - |
248 |
| - // format flags into rust pattern |
249 |
| - let (pattern, replace_all) = if flags == "g" { |
250 |
| - (pattern.to_string(), true) |
251 |
| - } else if flags.contains('g') { |
252 |
| - (format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true) |
253 |
| - } else { |
254 |
| - (format!("(?{flags}){pattern}"), false) |
255 |
| - }; |
256 |
| - |
257 |
| - // if patterns hashmap already has regexp then use else create and return |
258 |
| - let re = match patterns.get(&pattern) { |
259 |
| - Some(re) => Ok(re), |
260 |
| - None => { |
261 |
| - match Regex::new(pattern.as_str()) { |
262 |
| - Ok(re) => { |
263 |
| - patterns.insert(pattern.clone(), re); |
264 |
| - Ok(patterns.get(&pattern).unwrap()) |
| 260 | + Some(flags) => { |
| 261 | + let flags_array = as_generic_string_array::<T>(flags)?; |
| 262 | + |
| 263 | + let result_iter = string_array_iter |
| 264 | + .zip(pattern_array_iter) |
| 265 | + .zip(replacement_array_iter) |
| 266 | + .zip(flags_array.iter()) |
| 267 | + .map(|(((string, pattern), replacement), flags)| { |
| 268 | + match (string, pattern, replacement, flags) { |
| 269 | + (Some(string), Some(pattern), Some(replacement), Some(flags)) => { |
| 270 | + let replacement = regex_replace_posix_groups(replacement); |
| 271 | + |
| 272 | + // format flags into rust pattern |
| 273 | + let (pattern, replace_all) = if flags == "g" { |
| 274 | + (pattern.to_string(), true) |
| 275 | + } else if flags.contains('g') { |
| 276 | + ( |
| 277 | + format!( |
| 278 | + "(?{}){}", |
| 279 | + flags.to_string().replace('g', ""), |
| 280 | + pattern |
| 281 | + ), |
| 282 | + true, |
| 283 | + ) |
| 284 | + } else { |
| 285 | + (format!("(?{flags}){pattern}"), false) |
| 286 | + }; |
| 287 | + |
| 288 | + // if patterns hashmap already has regexp then use else create and return |
| 289 | + let re = match patterns.get(&pattern) { |
| 290 | + Some(re) => Ok(re), |
| 291 | + None => match Regex::new(pattern.as_str()) { |
| 292 | + Ok(re) => { |
| 293 | + patterns.insert(pattern.clone(), re); |
| 294 | + Ok(patterns.get(&pattern).unwrap()) |
| 295 | + } |
| 296 | + Err(err) => { |
| 297 | + Err(DataFusionError::External(Box::new(err))) |
| 298 | + } |
265 | 299 | },
|
266 |
| - Err(err) => Err(DataFusionError::External(Box::new(err))), |
267 |
| - } |
| 300 | + }; |
| 301 | + |
| 302 | + Some(re.map(|re| { |
| 303 | + if replace_all { |
| 304 | + re.replace_all(string, replacement.as_str()) |
| 305 | + } else { |
| 306 | + re.replace(string, replacement.as_str()) |
| 307 | + } |
| 308 | + })) |
| 309 | + .transpose() |
268 | 310 | }
|
269 |
| - }; |
270 |
| - |
271 |
| - Some(re.map(|re| { |
272 |
| - if replace_all { |
273 |
| - re.replace_all(string, replacement.as_str()) |
274 |
| - } else { |
275 |
| - re.replace(string, replacement.as_str()) |
276 |
| - } |
277 |
| - })).transpose() |
| 311 | + _ => Ok(None), |
| 312 | + } |
| 313 | + }); |
| 314 | + |
| 315 | + match datatype { |
| 316 | + DataType::Utf8 | DataType::LargeUtf8 => { |
| 317 | + let result = |
| 318 | + result_iter.collect::<Result<GenericStringArray<T>>>()?; |
| 319 | + Ok(Arc::new(result) as ArrayRef) |
278 | 320 | }
|
279 |
| - _ => Ok(None) |
280 |
| - }) |
281 |
| - .collect::<Result<GenericStringArray<T>>>()?; |
282 |
| - |
283 |
| - Ok(Arc::new(result) as ArrayRef) |
| 321 | + DataType::Utf8View => { |
| 322 | + let result = result_iter.collect::<Result<StringViewArray>>()?; |
| 323 | + Ok(Arc::new(result) as ArrayRef) |
| 324 | + } |
| 325 | + other => { |
| 326 | + exec_err!( |
| 327 | + "Unsupported data type {other:?} for function regex_replace" |
| 328 | + ) |
| 329 | + } |
| 330 | + } |
284 | 331 | }
|
285 |
| - other => exec_err!( |
286 |
| - "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." |
287 |
| - ), |
288 | 332 | }
|
289 | 333 | }
|
290 | 334 |
|
@@ -495,7 +539,47 @@ pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
|
495 | 539 | .iter()
|
496 | 540 | .map(|arg| arg.clone().into_array(inferred_length))
|
497 | 541 | .collect::<Result<Vec<_>>>()?;
|
498 |
| - regexp_replace::<T>(&args) |
| 542 | + |
| 543 | + match args[0].data_type() { |
| 544 | + DataType::Utf8View => { |
| 545 | + let string_array = args[0].as_string_view(); |
| 546 | + let pattern_array = args[1].as_string::<i32>(); |
| 547 | + let replacement_array = args[2].as_string::<i32>(); |
| 548 | + regexp_replace::<i32, _, _>( |
| 549 | + string_array, |
| 550 | + pattern_array, |
| 551 | + replacement_array, |
| 552 | + args.get(3), |
| 553 | + ) |
| 554 | + } |
| 555 | + DataType::Utf8 => { |
| 556 | + let string_array = args[0].as_string::<i32>(); |
| 557 | + let pattern_array = args[1].as_string::<i32>(); |
| 558 | + let replacement_array = args[2].as_string::<i32>(); |
| 559 | + regexp_replace::<i32, _, _>( |
| 560 | + string_array, |
| 561 | + pattern_array, |
| 562 | + replacement_array, |
| 563 | + args.get(3), |
| 564 | + ) |
| 565 | + } |
| 566 | + DataType::LargeUtf8 => { |
| 567 | + let string_array = args[0].as_string::<i64>(); |
| 568 | + let pattern_array = args[1].as_string::<i64>(); |
| 569 | + let replacement_array = args[2].as_string::<i64>(); |
| 570 | + regexp_replace::<i64, _, _>( |
| 571 | + string_array, |
| 572 | + pattern_array, |
| 573 | + replacement_array, |
| 574 | + args.get(3), |
| 575 | + ) |
| 576 | + } |
| 577 | + other => { |
| 578 | + exec_err!( |
| 579 | + "Unsupported data type {other:?} for function regex_replace" |
| 580 | + ) |
| 581 | + } |
| 582 | + } |
499 | 583 | }
|
500 | 584 | }
|
501 | 585 | }
|
|
0 commit comments