Skip to content

Commit 3c20b44

Browse files
committed
feat(mapper): rewrite std SQL functions for EQL cols
e.g. `jsonb_path_query(eql_col, selector)` on an EQL column should be transformed to `eql_v1.jsonb_path_query(..)`
1 parent 27aac0d commit 3c20b44

File tree

8 files changed

+207
-43
lines changed

8 files changed

+207
-43
lines changed

packages/eql-mapper/src/inference/infer_type_impls/function.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use eql_mapper_macros::trace_infer;
22
use sqltk::parser::ast::{Function, FunctionArguments};
33

44
use crate::{
5-
get_type_signature_for_special_cased_sql_function, inference::infer_type::InferType,
5+
get_sql_function_def, inference::infer_type::InferType,
66
CompoundIdent, FunctionSig, TypeError, TypeInferencer,
77
};
88

@@ -23,9 +23,9 @@ impl<'ast> InferType<'ast, Function> for TypeInferencer<'ast> {
2323
let Function { name, args, .. } = function;
2424
let fn_name = CompoundIdent::from(&name.0);
2525

26-
match get_type_signature_for_special_cased_sql_function(&fn_name, args) {
27-
Some(sig) => {
28-
sig.instantiate(&*self).apply_constraints(self, function)?;
26+
match get_sql_function_def(&fn_name, args) {
27+
Some(sql_fn) => {
28+
sql_fn.sig.instantiate(&*self).apply_constraints(self, function)?;
2929
}
3030
None => {
3131
FunctionSig::instantiate_native(function).apply_constraints(self, function)?;

packages/eql-mapper/src/inference/sql_fn_macros.rs

+26-1
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,35 @@ macro_rules! sql_fn_args {
2121

2222
/// Builds a `SqlFunction` definition from a compact signature notation.
///
/// Supported forms:
///
/// - `sql_fn!(name(ARGS) -> RET)`: unqualified function that is never rewritten.
/// - `sql_fn!(name(ARGS) -> RET, rewrite)`: unqualified function tagged with
///   `RewriteRule::AsEqlFunction`.
/// - `sql_fn!(schema.name(ARGS) -> RET)` and `sql_fn!(schema.name(ARGS) -> RET, rewrite)`:
///   the same two forms for a schema-qualified function name.
///
/// Every path in the expansion goes through `$crate` so the expansion does not depend on what
/// the call site happens to have in scope.
#[macro_export]
macro_rules! sql_fn {
    // Internal arm shared by all public forms; not intended to be invoked directly.
    (@build $ident:expr, $args:tt, $return_kind:ident, $rewrite_rule:ident) => {
        $crate::SqlFunction::new(
            $ident,
            $crate::FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)),
            $crate::RewriteRule::$rewrite_rule,
        )
    };

    ($name:ident $args:tt -> $return_kind:ident, rewrite) => {
        $crate::sql_fn!(@build stringify!($name), $args, $return_kind, AsEqlFunction)
    };

    ($name:ident $args:tt -> $return_kind:ident) => {
        $crate::sql_fn!(@build stringify!($name), $args, $return_kind, Ignore)
    };

    // The qualified name is assembled with `concat!` rather than `stringify!($schema . $name)`
    // because the whitespace `stringify!` puts between tokens is not guaranteed to be stable.
    ($schema:ident . $name:ident $args:tt -> $return_kind:ident, rewrite) => {
        $crate::sql_fn!(
            @build concat!(stringify!($schema), ".", stringify!($name)),
            $args,
            $return_kind,
            AsEqlFunction
        )
    };

    ($schema:ident . $name:ident $args:tt -> $return_kind:ident) => {
        $crate::sql_fn!(
            @build concat!(stringify!($schema), ".", stringify!($name)),
            $args,
            $return_kind,
            Ignore
        )
    };
}

packages/eql-mapper/src/inference/sql_functions.rs

+61-34
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,22 @@ use super::TypeError;
1717
///
1818
/// See [`SQL_FUNCTION_SIGNATURES`].
1919
#[derive(Debug)]
20-
pub(crate) struct SqlFunction(CompoundIdent, FunctionSig);
20+
pub(crate) struct SqlFunction {
21+
pub(crate) name: CompoundIdent,
22+
pub(crate) sig: FunctionSig,
23+
pub(crate) rewrite_rule: RewriteRule,
24+
}
25+
26+
/// Describes how a call to a registered SQL function is treated when a statement is transformed.
///
/// The enum carries no data, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RewriteRule {
    /// Leave the function call exactly as written.
    Ignore,
    /// Rewrite the call so that it targets the EQL implementation of the function.
    AsEqlFunction,
}
2131

2232
/// A representation of the type of an argument or return type in a SQL function.
2333
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
2434
pub(crate) enum Kind {
25-
/// A type that mjust be a native type
35+
/// A type that must be a native type
2636
Native,
2737

2838
/// A type that can be a native or EQL type. The `str` is the generic variable name.
@@ -178,8 +188,12 @@ fn get_function_arg_expr(fn_arg: &FunctionArg) -> &FunctionArgExpr {
178188
}
179189

180190
impl SqlFunction {
181-
fn new(ident: &str, sig: FunctionSig) -> Self {
182-
Self(CompoundIdent::from(ident), sig)
191+
fn new(ident: &str, sig: FunctionSig, rewrite_rule: RewriteRule) -> Self {
192+
Self {
193+
name: CompoundIdent::from(ident),
194+
sig,
195+
rewrite_rule,
196+
}
183197
}
184198
}
185199

@@ -202,38 +216,51 @@ impl From<&Vec<Ident>> for CompoundIdent {
202216
}
203217

204218
/// SQL functions that are handled with special case type checking rules.
205-
static SQL_FUNCTION_SIGNATURES: LazyLock<HashMap<CompoundIdent, Vec<FunctionSig>>> =
206-
LazyLock::new(|| {
207-
// Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned
208-
// *the same type variable* and thus must resolve to the same type. (🙏 Haskell)
209-
//
210-
// Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL
211-
// extension in the database. I can imagine supporting type bounds in signatures here, such as: `T: Eq`
212-
let sql_fns = vec![
213-
sql_fn!(count(T) -> NATIVE),
214-
sql_fn!(min(T) -> T),
215-
sql_fn!(max(T) -> T),
216-
sql_fn!(jsonb_path_query(T, T) -> T),
217-
sql_fn!(jsonb_path_query_first(T, T) -> T),
218-
sql_fn!(jsonb_path_exists(T, T) -> T),
219-
];
220-
221-
let mut sql_fns_by_name: HashMap<CompoundIdent, Vec<FunctionSig>> = HashMap::new();
222-
223-
for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.0.clone()) {
224-
sql_fns_by_name.insert(
225-
key.clone(),
226-
chunk.into_iter().map(|sql_fn| sql_fn.1).collect(),
227-
);
228-
}
219+
static SQL_FUNCTIONS: LazyLock<HashMap<CompoundIdent, Vec<SqlFunction>>> = LazyLock::new(|| {
220+
// Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned
221+
// *the same type variable* and thus must resolve to the same type. (🙏 Haskell)
222+
//
223+
// Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL
224+
// extension in the database. I can imagine supporting type bounds in signatures here, such as: `T: Eq`
225+
let sql_fns = vec![
226+
// TODO: when search_path support is added to the resolver we should change these
227+
// to their fully-qualified names.
228+
sql_fn!(count(T) -> NATIVE),
229+
sql_fn!(min(T) -> T, rewrite),
230+
sql_fn!(max(T) -> T, rewrite),
231+
sql_fn!(jsonb_path_query(T, T) -> T, rewrite),
232+
sql_fn!(jsonb_path_query_first(T, T) -> T, rewrite),
233+
sql_fn!(jsonb_path_exists(T, T) -> T, rewrite),
234+
sql_fn!(jsonb_array_length(T) -> T, rewrite),
235+
sql_fn!(jsonb_array_elements(T) -> T, rewrite),
236+
sql_fn!(jsonb_array_elements_text(T) -> T, rewrite),
237+
// These are typings for when customer SQL already contains references to EQL functions.
238+
// They must be type checked but not rewritten.
239+
sql_fn!(eql_v1.min(T) -> T),
240+
sql_fn!(eql_v1.max(T) -> T),
241+
sql_fn!(eql_v1.jsonb_path_query(T, T) -> T),
242+
sql_fn!(eql_v1.jsonb_path_query_first(T, T) -> T),
243+
sql_fn!(eql_v1.jsonb_path_exists(T, T) -> T),
244+
sql_fn!(eql_v1.jsonb_array_length(T) -> T),
245+
sql_fn!(eql_v1.jsonb_array_elements(T) -> T),
246+
sql_fn!(eql_v1.jsonb_array_elements_text(T) -> T),
247+
];
248+
249+
let mut sql_fns_by_name: HashMap<CompoundIdent, Vec<SqlFunction>> = HashMap::new();
250+
251+
for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.name.clone()) {
252+
sql_fns_by_name.insert(key.clone(), chunk.into_iter().collect());
253+
}
229254

230-
sql_fns_by_name
231-
});
255+
sql_fns_by_name
256+
});
232257

233-
pub(crate) fn get_type_signature_for_special_cased_sql_function(
258+
pub(crate) fn get_sql_function_def(
234259
fn_name: &CompoundIdent,
235260
args: &FunctionArguments,
236-
) -> Option<&'static FunctionSig> {
237-
let sigs = SQL_FUNCTION_SIGNATURES.get(fn_name)?;
238-
sigs.iter().find(|sig| sig.is_applicable_to_args(args))
261+
) -> Option<&'static SqlFunction> {
262+
let sql_fns = SQL_FUNCTIONS.get(fn_name)?;
263+
sql_fns
264+
.iter()
265+
.find(|sql_fn| sql_fn.sig.is_applicable_to_args(args))
239266
}

packages/eql-mapper/src/lib.rs

+36
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,42 @@ mod test {
13911391
}
13921392
}
13931393

1394+
#[test]
1395+
fn rewrite_standard_sql_fns_on_eql_types() {
1396+
// init_tracing();
1397+
let schema = resolver(schema! {
1398+
tables: {
1399+
employees: {
1400+
id (PK),
1401+
eql_col (EQL),
1402+
native_col,
1403+
}
1404+
}
1405+
});
1406+
1407+
let statement = parse("
1408+
SELECT jsonb_path_query(eql_col, '$.secret'), jsonb_path_query(native_col, '$.not-secret') FROM employees
1409+
");
1410+
1411+
match type_check(schema, &statement) {
1412+
Ok(typed) => {
1413+
match typed.transform(test_helpers::dummy_encrypted_json_selector(
1414+
&statement,
1415+
ast::Value::SingleQuotedString("$.secret".into()),
1416+
)) {
1417+
Ok(statement) => {
1418+
assert_eq!(
1419+
statement.to_string(),
1420+
"SELECT eql_v1.jsonb_path_query(eql_col, ROW('<encrypted-selector($.secret)>'::JSONB)), jsonb_path_query(native_col, '$.not-secret') FROM employees"
1421+
);
1422+
}
1423+
Err(err) => panic!("transformation failed: {err}"),
1424+
}
1425+
}
1426+
Err(err) => panic!("type check failed: {err}"),
1427+
}
1428+
}
1429+
13941430
#[test]
13951431
fn jsonb_operator_arrow() {
13961432
test_jsonb_operator("->");

packages/eql-mapper/src/test_helpers.rs

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ pub(crate) fn find_nodekey_for_value_node(
9393

9494
visitor.found
9595
}
96+
9697
#[macro_export]
9798
macro_rules! col {
9899
((NATIVE)) => {

packages/eql-mapper/src/transformation_rules/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ mod fail_on_placeholder_change;
1515
mod group_by_eql_col;
1616
mod preserve_effective_aliases;
1717
mod replace_plaintext_eql_literals;
18+
mod rewrite_standard_sql_fns_on_eql_types;
1819
mod use_equivalent_eql_fns_on_eql_types;
1920
mod wrap_eql_cols_in_order_by_with_ore_fn;
2021
mod wrap_grouped_eql_col_in_aggregate_fn;
@@ -24,6 +25,7 @@ use std::marker::PhantomData;
2425
pub(crate) use fail_on_placeholder_change::*;
2526
pub(crate) use group_by_eql_col::*;
2627
pub(crate) use preserve_effective_aliases::*;
28+
pub(crate) use rewrite_standard_sql_fns_on_eql_types::*;
2729
pub(crate) use replace_plaintext_eql_literals::*;
2830
pub(crate) use use_equivalent_eql_fns_on_eql_types::*;
2931
pub(crate) use wrap_eql_cols_in_order_by_with_ore_fn::*;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use std::mem;
2+
use std::{collections::HashMap, sync::Arc};
3+
4+
use sqltk::parser::ast::{Expr, Function, Ident, ObjectName};
5+
use sqltk::{AsNodeKey, NodeKey, NodePath, Visitable};
6+
7+
use crate::{
8+
get_sql_function_def, CompoundIdent, EqlMapperError, RewriteRule, SqlFunction, Type, Value,
9+
};
10+
11+
use super::TransformationRule;
12+
13+
#[derive(Debug)]
14+
pub struct RewriteStandardSqlFnsOnEqlTypes<'ast> {
15+
node_types: Arc<HashMap<NodeKey<'ast>, Type>>,
16+
}
17+
18+
impl<'ast> RewriteStandardSqlFnsOnEqlTypes<'ast> {
19+
pub fn new(node_types: Arc<HashMap<NodeKey<'ast>, Type>>) -> Self {
20+
Self { node_types }
21+
}
22+
}
23+
24+
impl<'ast> TransformationRule<'ast> for RewriteStandardSqlFnsOnEqlTypes<'ast> {
25+
fn apply<N: Visitable>(
26+
&mut self,
27+
node_path: &NodePath<'ast>,
28+
target_node: &mut N,
29+
) -> Result<bool, EqlMapperError> {
30+
if self.would_edit(node_path, target_node) {
31+
if let Some((_expr, function)) = node_path.last_2_as::<Expr, Function>() {
32+
if matches!(
33+
self.node_types.get(&function.as_node_key()),
34+
Some(Type::Value(Value::Eql(_)))
35+
) {
36+
let function_name = CompoundIdent::from(&function.name.0);
37+
38+
if let Some(SqlFunction {
39+
rewrite_rule: RewriteRule::AsEqlFunction,
40+
..
41+
}) = get_sql_function_def(&function_name, &function.args)
42+
{
43+
let function = target_node.downcast_mut::<Function>().unwrap();
44+
let mut existing_name = mem::take(&mut function.name.0);
45+
existing_name.insert(0, Ident::new("eql_v1"));
46+
function.name = ObjectName(existing_name);
47+
}
48+
}
49+
}
50+
}
51+
52+
Ok(false)
53+
}
54+
55+
fn would_edit<N: Visitable>(&mut self, node_path: &NodePath<'ast>, _target_node: &N) -> bool {
56+
if let Some((_expr, function)) = node_path.last_2_as::<Expr, Function>() {
57+
if matches!(
58+
self.node_types.get(&function.as_node_key()),
59+
Some(Type::Value(Value::Eql(_)))
60+
) {
61+
let function_name = CompoundIdent::from(&function.name.0);
62+
63+
if let Some(SqlFunction {
64+
rewrite_rule: RewriteRule::AsEqlFunction,
65+
..
66+
}) = get_sql_function_def(&function_name, &function.args)
67+
{
68+
return true;
69+
}
70+
}
71+
}
72+
73+
false
74+
}
75+
}

packages/eql-mapper/src/type_checked_statement.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ use sqltk::parser::ast::{self, Statement};
44
use sqltk::{AsNodeKey, NodeKey, Transformable};
55

66
use crate::{
7-
DryRunnable, EqlMapperError, EqlValue, FailOnPlaceholderChange, GroupByEqlCol, Param,
8-
PreserveEffectiveAliases, Projection, ReplacePlaintextEqlLiterals, TransformationRule, Type,
9-
UseEquivalentSqlFuncForEqlTypes, Value, WrapEqlColsInOrderByWithOreFn,
10-
WrapGroupedEqlColInAggregateFn,
7+
DryRunnable, EqlMapperError, EqlValue, FailOnPlaceholderChange, GroupByEqlCol, Param, PreserveEffectiveAliases, Projection, ReplacePlaintextEqlLiterals, RewriteStandardSqlFnsOnEqlTypes, TransformationRule, Type, UseEquivalentSqlFuncForEqlTypes, Value, WrapEqlColsInOrderByWithOreFn, WrapGroupedEqlColInAggregateFn
118
};
129

1310
/// A `TypeCheckedStatement` is returned from a successful call to [`crate::type_check`].
@@ -140,6 +137,7 @@ impl<'ast> TypeCheckedStatement<'ast> {
140137
encrypted_literals: HashMap<NodeKey<'ast>, sqltk::parser::ast::Value>,
141138
) -> DryRunnable<impl TransformationRule<'_>> {
142139
DryRunnable::new((
140+
RewriteStandardSqlFnsOnEqlTypes::new(Arc::clone(&self.node_types)),
143141
WrapGroupedEqlColInAggregateFn::new(Arc::clone(&self.node_types)),
144142
GroupByEqlCol::new(Arc::clone(&self.node_types)),
145143
WrapEqlColsInOrderByWithOreFn::new(Arc::clone(&self.node_types)),

0 commit comments

Comments
 (0)