|
17 | 17 |
|
18 | 18 | //! Common utilities for implementing string functions
|
19 | 19 |
|
| 20 | +use std::collections::HashMap; |
20 | 21 | use std::fmt::{Display, Formatter};
|
21 | 22 | use std::sync::Arc;
|
22 | 23 |
|
23 | 24 | use arrow::array::{
|
24 | 25 | new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef,
|
25 |
| - GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, |
| 26 | + BooleanArray, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, |
26 | 27 | StringBuilder, StringViewArray,
|
27 | 28 | };
|
28 | 29 | use arrow::buffer::{Buffer, MutableBuffer, NullBuffer};
|
29 | 30 | use arrow::datatypes::DataType;
|
| 31 | +use arrow_buffer::BooleanBufferBuilder; |
30 | 32 | use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
|
31 |
| -use datafusion_common::Result; |
32 | 33 | use datafusion_common::{exec_err, ScalarValue};
|
| 34 | +use datafusion_common::{DataFusionError, Result}; |
33 | 35 | use datafusion_expr::ColumnarValue;
|
| 36 | +use regex::Regex; |
34 | 37 |
|
35 | 38 | pub(crate) enum TrimType {
|
36 | 39 | Left,
|
@@ -478,3 +481,93 @@ where
|
478 | 481 | GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
|
479 | 482 | }))
|
480 | 483 | }
|
| 484 | + |
| 485 | +/// Perform SQL `array ~ regex_array` operation on |
| 486 | +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. |
| 487 | +/// If `regex_array` element has an empty value, the corresponding result value is always true. |
| 488 | +/// |
| 489 | +/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, |
| 490 | +/// which allow special search modes, such as case-insensitive and multi-line mode. |
| 491 | +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) |
| 492 | +/// for more information. |
| 493 | +/// |
| 494 | +/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. |
| 495 | +/// |
| 496 | +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 |
| 497 | +pub fn regexp_is_match<'a, ArrayType1, ArrayType2, ArrayType3>( |
| 498 | + array: &'a ArrayType1, |
| 499 | + regex_array: &'a ArrayType2, |
| 500 | + flags_array: Option<&'a ArrayType3>, |
| 501 | +) -> datafusion_common::Result<BooleanArray, DataFusionError> |
| 502 | +where |
| 503 | + &'a ArrayType1: StringArrayType<'a>, |
| 504 | + &'a ArrayType2: StringArrayType<'a>, |
| 505 | + &'a ArrayType3: StringArrayType<'a>, |
| 506 | +{ |
| 507 | + if array.len() != regex_array.len() { |
| 508 | + return Err(DataFusionError::Execution( |
| 509 | + "Cannot perform comparison operation on arrays of different length" |
| 510 | + .to_string(), |
| 511 | + )); |
| 512 | + } |
| 513 | + |
| 514 | + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); |
| 515 | + |
| 516 | + let mut patterns: HashMap<String, Regex> = HashMap::new(); |
| 517 | + let mut result = BooleanBufferBuilder::new(array.len()); |
| 518 | + |
| 519 | + let complete_pattern = match flags_array { |
| 520 | + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( |
| 521 | + |(pattern, flags)| { |
| 522 | + pattern.map(|pattern| match flags { |
| 523 | + Some(flag) => format!("(?{flag}){pattern}"), |
| 524 | + None => pattern.to_string(), |
| 525 | + }) |
| 526 | + }, |
| 527 | + )) as Box<dyn Iterator<Item = Option<String>>>, |
| 528 | + None => Box::new( |
| 529 | + regex_array |
| 530 | + .iter() |
| 531 | + .map(|pattern| pattern.map(|pattern| pattern.to_string())), |
| 532 | + ), |
| 533 | + }; |
| 534 | + |
| 535 | + array |
| 536 | + .iter() |
| 537 | + .zip(complete_pattern) |
| 538 | + .map(|(value, pattern)| { |
| 539 | + match (value, pattern) { |
| 540 | + (Some(_), Some(pattern)) if pattern == *"" => { |
| 541 | + result.append(true); |
| 542 | + } |
| 543 | + (Some(value), Some(pattern)) => { |
| 544 | + let existing_pattern = patterns.get(&pattern); |
| 545 | + let re = match existing_pattern { |
| 546 | + Some(re) => re, |
| 547 | + None => { |
| 548 | + let re = Regex::new(pattern.as_str()).map_err(|e| { |
| 549 | + DataFusionError::Execution(format!( |
| 550 | + "Regular expression did not compile: {e:?}" |
| 551 | + )) |
| 552 | + })?; |
| 553 | + patterns.entry(pattern).or_insert(re) |
| 554 | + } |
| 555 | + }; |
| 556 | + result.append(re.is_match(value)); |
| 557 | + } |
| 558 | + _ => result.append(false), |
| 559 | + } |
| 560 | + Ok(()) |
| 561 | + }) |
| 562 | + .collect::<datafusion_common::Result<Vec<()>, DataFusionError>>()?; |
| 563 | + |
| 564 | + let data = unsafe { |
| 565 | + ArrayDataBuilder::new(DataType::Boolean) |
| 566 | + .len(array.len()) |
| 567 | + .buffers(vec![result.into()]) |
| 568 | + .nulls(nulls) |
| 569 | + .build_unchecked() |
| 570 | + }; |
| 571 | + |
| 572 | + Ok(BooleanArray::from(data)) |
| 573 | +} |
0 commit comments