diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 1286c6b5b1bc..8da154430fc5 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -30,6 +30,7 @@ use datafusion_expr::{ }; use itertools::izip; use regex::Regex; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::{Arc, OnceLock}; @@ -548,16 +549,22 @@ where } } -fn compile_and_cache_regex<'a>( - regex: &'a str, - flags: Option<&'a str>, - regex_cache: &'a mut HashMap, -) -> Result<&'a Regex, ArrowError> { - if !regex_cache.contains_key(regex) { - let compiled = compile_regex(regex, flags)?; - regex_cache.insert(regex.to_string(), compiled); - } - Ok(regex_cache.get(regex).unwrap()) +fn compile_and_cache_regex<'strings, 'cache>( + regex: &'strings str, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result<&'cache Regex, ArrowError> +where + 'strings: 'cache, +{ + let result = match regex_cache.entry((regex, flags)) { + Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), + Entry::Vacant(vacant_entry) => { + let compiled = compile_regex(regex, flags)?; + vacant_entry.insert(compiled) + } + }; + Ok(result) } fn compile_regex(regex: &str, flags: Option<&str>) -> Result { @@ -634,6 +641,8 @@ mod tests { test_case_sensitive_regexp_count_array_complex::>(); test_case_sensitive_regexp_count_array_complex::>(); test_case_sensitive_regexp_count_array_complex::(); + + test_case_regexp_count_cache_check::>(); } fn test_case_sensitive_regexp_count_scalar() { @@ -977,4 +986,25 @@ mod tests { .unwrap(); assert_eq!(re.as_ref(), &expected); } + + fn test_case_regexp_count_cache_check() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["aaa", "Aaa", "aaa"]); + let regex = A::from(vec!["aaa", "aaa", "aaa"]); + let start = Int64Array::from(vec![1, 1, 1]); + let flags = A::from(vec!["", "i", ""]); + + let expected = Int64Array::from(vec![1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } }