Skip to content

Commit

Permalink
Fix regex cache on pattern, less alloc, hash less often (#13414)
Browse files Browse the repository at this point in the history
* cache on pattern, less alloc, hash less often

* inline get_pattern

* reduce to one hash

* remove unnecessary lifetimes
  • Loading branch information
Dimchikkk authored Nov 15, 2024
1 parent 75a27a8 commit 8c35270
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions datafusion/functions/src/regex/regexpcount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use datafusion_expr::{
};
use itertools::izip;
use regex::Regex;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};

Expand Down Expand Up @@ -548,16 +549,22 @@ where
}
}

fn compile_and_cache_regex<'a>(
regex: &'a str,
flags: Option<&'a str>,
regex_cache: &'a mut HashMap<String, Regex>,
) -> Result<&'a Regex, ArrowError> {
if !regex_cache.contains_key(regex) {
let compiled = compile_regex(regex, flags)?;
regex_cache.insert(regex.to_string(), compiled);
}
Ok(regex_cache.get(regex).unwrap())
fn compile_and_cache_regex<'strings, 'cache>(
regex: &'strings str,
flags: Option<&'strings str>,
regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
) -> Result<&'cache Regex, ArrowError>
where
'strings: 'cache,
{
let result = match regex_cache.entry((regex, flags)) {
Entry::Occupied(occupied_entry) => occupied_entry.into_mut(),
Entry::Vacant(vacant_entry) => {
let compiled = compile_regex(regex, flags)?;
vacant_entry.insert(compiled)
}
};
Ok(result)
}

fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError> {
Expand Down Expand Up @@ -634,6 +641,8 @@ mod tests {
test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
test_case_sensitive_regexp_count_array_complex::<StringViewArray>();

test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
}

fn test_case_sensitive_regexp_count_scalar() {
Expand Down Expand Up @@ -977,4 +986,25 @@ mod tests {
.unwrap();
assert_eq!(re.as_ref(), &expected);
}

fn test_case_regexp_count_cache_check<A>()
where
A: From<Vec<&'static str>> + Array + 'static,
{
let values = A::from(vec!["aaa", "Aaa", "aaa"]);
let regex = A::from(vec!["aaa", "aaa", "aaa"]);
let start = Int64Array::from(vec![1, 1, 1]);
let flags = A::from(vec!["", "i", ""]);

let expected = Int64Array::from(vec![1, 1, 1]);

let re = regexp_count_func(&[
Arc::new(values),
Arc::new(regex),
Arc::new(start),
Arc::new(flags),
])
.unwrap();
assert_eq!(re.as_ref(), &expected);
}
}

0 comments on commit 8c35270

Please sign in to comment.