diff --git a/docs/changelog.md b/docs/changelog.md index edcc69c8..5d44b821 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -45,6 +45,7 @@ ## 1.4.0 - feature: citext type represented as a GraphQL String - feature: Support for Postgres 16 -- feature: Support for user defined function +- feature: Support for user defined functions ## master +- feature: Support for user defined functions with default arguments diff --git a/docs/functions.md b/docs/functions.md index 9ea90963..3b9240f1 100644 --- a/docs/functions.md +++ b/docs/functions.md @@ -231,6 +231,47 @@ Functions returning multiple rows of a table or view are exposed as [collections A set returning function with any of its argument names clashing with argument names of a collection (`first`, `last`, `before`, `after`, `filter`, or `orderBy`) will not be exposed. +## Default Arguments + +Functions with default arguments can have their default arguments omitted. + +=== "Function" + + ```sql + create function "addNums"(a int default 1, b int default 2) + returns int + immutable + language sql + as $$ select a + b; $$; + ``` + +=== "QueryType" + + ```graphql + type Query { + addNums(a: Int, b: Int): Int + } + ``` + + +=== "Query" + + ```graphql + query { + addNums(b: 20) + } + ``` + +=== "Response" + + ```json + { + "data": { + "addNums": 21 + } + } + ``` + ## Limitations The following features are not yet supported. Any function using these features is not exposed in the API: @@ -239,7 +280,6 @@ The following features are not yet supported. Any function using these features * Functions that accept a table's tuple type * Overloaded functions * Functions with a nameless argument -* Functions with a default argument * Functions returning void * Variadic functions * Function that accept or return an array type diff --git a/sql/load_sql_context.sql b/sql/load_sql_context.sql index da5d2c0f..565e4e39 100644 --- a/sql/load_sql_context.sql +++ b/sql/load_sql_context.sql @@ -379,6 +379,7 @@ select 'schema_name', pronamespace::regnamespace::text, 'arg_types', proargtypes::int[], 'arg_names', proargnames::text[], + 'arg_defaults', pg_get_expr(proargdefaults, 0)::text, 'num_args', pronargs, 'num_default_args', pronargdefaults, 'arg_type_names', pp.proargtypes::regtype[]::text[], diff --git a/src/graphql.rs b/src/graphql.rs index 96c39d5d..c2b714d3 100644 --- a/src/graphql.rs +++ b/src/graphql.rs @@ -989,7 +989,7 @@ impl FuncCallResponseType { let inflected_name_to_sql_name: HashMap = self .function .args() - .filter_map(|(_, arg_type_name, arg_name)| { + .filter_map(|(_, arg_type_name, arg_name, _)| { arg_name.map(|arg_name| (arg_type_name, arg_name)) }) .map(|(arg_type_name, arg_name)| { @@ -1316,27 +1316,29 @@ fn function_fields(schema: &Arc<__Schema>, volatilities: &[FunctionVolatility]) fn function_args(schema: &Arc<__Schema>, func: &Arc) -> Vec<__InputValue> { let sql_types = &schema.context.types; func.args() - .filter(|(_, _, arg_name)| !arg_name.is_none()) - .filter_map(|(arg_type, _, arg_name)| match sql_types.get(&arg_type) { - Some(t) => { - if matches!(t.category, TypeCategory::Pseudo) { - None - } else { - Some((t, arg_name.unwrap())) + .filter(|(_, _, arg_name, _)| !arg_name.is_none()) + .filter_map( + |(arg_type, _, arg_name, arg_default)| match sql_types.get(&arg_type) { + Some(t) => { + if matches!(t.category, TypeCategory::Pseudo) { + None + } else { + Some((t, arg_name.unwrap(), arg_default)) + } } - } - None => None, - }) - .filter_map(|(arg_type, arg_name)| { + None => None, + }, + ) + .filter_map(|(arg_type, arg_name, arg_default)| { arg_type .to_graphql_type(None, false, schema) - .map(|t| (t, arg_name)) + .map(|t| (t, arg_name, arg_default)) }) - .map(|(arg_type, arg_name)| __InputValue { + .map(|(arg_type, arg_name, arg_default)| __InputValue { name_: schema.graphql_function_arg_name(func, arg_name), type_: arg_type, description: None, - default_value: None, + default_value: arg_default, sql_type: None, }) .collect() diff --git a/src/sql_types.rs b/src/sql_types.rs index d9aed396..9b973d00 100644 --- a/src/sql_types.rs +++ b/src/sql_types.rs @@ -74,6 +74,7 @@ pub struct Function { pub schema_name: String, pub arg_types: Vec, pub arg_names: Option>, + pub arg_defaults: Option, pub num_args: u32, pub num_default_args: u32, pub arg_type_names: Vec, @@ -87,13 +88,14 @@ pub struct Function { } impl Function { - pub fn args(&self) -> impl Iterator)> { - ArgsIterator { - index: 0, - arg_types: &self.arg_types, - arg_type_names: &self.arg_type_names, - arg_names: &self.arg_names, - } + pub fn args(&self) -> impl Iterator, Option)> { + ArgsIterator::new( + &self.arg_types, + &self.arg_type_names, + &self.arg_names, + &self.arg_defaults, + self.num_default_args, + ) } pub fn function_names_to_count(all_functions: &[Arc]) -> HashMap<&String, u32> { @@ -115,13 +117,12 @@ impl Function { && self.arg_types_are_supported(types) && !self.is_function_overloaded(function_name_to_count) && !self.has_a_nameless_arg() - && !self.has_a_default_arg() && self.permissions.is_executable && !self.is_in_a_system_schema() } fn arg_types_are_supported(&self, types: &HashMap>) -> bool { - self.args().all(|(arg_type, _, _)| { + self.args().all(|(arg_type, _, _, _)| { if let Some(return_type) = types.get(&arg_type) { return_type.category == TypeCategory::Other } else { @@ -149,11 +150,7 @@ impl Function { } fn has_a_nameless_arg(&self) -> bool { - self.args().any(|(_, _, arg_name)| arg_name.is_none()) - } - - fn has_a_default_arg(&self) -> bool { - self.num_default_args > 0 + self.args().any(|(_, _, arg_name, _)| arg_name.is_none()) } fn is_in_a_system_schema(&self) -> bool { @@ -168,6 +165,75 @@ struct ArgsIterator<'a> { arg_types: &'a [u32], arg_type_names: &'a Vec, arg_names: &'a Option>, + arg_defaults: Vec>, +} + +impl<'a> ArgsIterator<'a> { + fn new( + arg_types: &'a [u32], + arg_type_names: &'a Vec, + arg_names: &'a Option>, + arg_defaults: &'a Option, + num_default_args: u32, + ) -> ArgsIterator<'a> { + ArgsIterator { + index: 0, + arg_types, + arg_type_names, + arg_names, + arg_defaults: Self::defaults( + arg_types, + arg_defaults, + num_default_args as usize, + arg_types.len(), + ), + } + } + + fn defaults( + arg_types: &'a [u32], + arg_defaults: &'a Option, + num_default_args: usize, + num_total_args: usize, + ) -> Vec> { + let mut defaults = vec![None; num_total_args]; + let Some(arg_defaults) = arg_defaults else { + return defaults; + }; + + if num_default_args == 0 { + return defaults; + } + + let default_strs: Vec<&str> = arg_defaults.split(',').collect(); + + if default_strs.len() != num_default_args { + return defaults; + } + + debug_assert!(num_default_args <= num_total_args); + let start_idx = num_total_args - num_default_args; + for i in start_idx..num_total_args { + defaults[i] = + Self::sql_to_graphql_default(default_strs[i - start_idx], arg_types[i - start_idx]) + } + + defaults + } + + fn sql_to_graphql_default(default_str: &str, type_oid: u32) -> Option { + let trimmed = default_str.trim(); + match type_oid { + 21 | 23 => trimmed.parse::().ok().map(|i| i.to_string()), + 16 => trimmed.parse::().ok().map(|i| i.to_string()), + 700 | 701 => trimmed.parse::().ok().map(|i| i.to_string()), + 25 => trimmed + .strip_suffix("::text") + .to_owned() + .map(|i| i.trim_matches(',').to_string()), + _ => None, + } + } } lazy_static! { @@ -175,7 +241,7 @@ lazy_static! { } impl<'a> Iterator for ArgsIterator<'a> { - type Item = (u32, &'a str, Option<&'a str>); + type Item = (u32, &'a str, Option<&'a str>, Option); fn next(&mut self) -> Option { if self.index < self.arg_types.len() { @@ -196,8 +262,9 @@ impl<'a> Iterator for ArgsIterator<'a> { if arg_type_name == "character" { arg_type_name = &TEXT_TYPE; } + let arg_default = self.arg_defaults[self.index].clone(); self.index += 1; - Some((arg_type, arg_type_name, arg_name)) + Some((arg_type, arg_type_name, arg_name, arg_default)) } else { None } diff --git a/src/transpile.rs b/src/transpile.rs index 36eeeb54..b7094580 100644 --- a/src/transpile.rs +++ b/src/transpile.rs @@ -548,7 +548,8 @@ impl FunctionCallBuilder { for (arg, arg_value) in &self.args_builder.args { if let Some(arg) = arg { let arg_clause = param_context.clause_for(arg_value, &arg.type_name)?; - arg_clauses.push(arg_clause); + let named_arg_clause = format!("{} => {}", quote_ident(&arg.name), arg_clause); + arg_clauses.push(named_arg_clause); } } diff --git a/test/expected/function_calls.out b/test/expected/function_calls.out index 5cede6d0..5a178290 100644 --- a/test/expected/function_calls.out +++ b/test/expected/function_calls.out @@ -1940,4 +1940,219 @@ begin; (1 row) set search_path to default; + rollback to savepoint a; + create function add_smallints(a smallint default 1, b smallint default 2) + returns smallint language sql immutable + as $$ select a + b; $$; + create function func_with_defaults( + a smallint default 1, + b integer default 2, + c boolean default false, + d real default 3.14, + e double precision default 2.718, + f text default 'hello' + ) + returns smallint language sql immutable + as $$ select 0; $$; + select jsonb_pretty( + graphql.resolve($$ + query IntrospectionQuery { + __schema { + queryType { + name + fields { + name + args { + name + defaultValue + type { + name + } + } + } + } + } + } + $$) + ); + jsonb_pretty +----------------------------------------------------------- + { + + "data": { + + "__schema": { + + "queryType": { + + "name": "Query", + + "fields": [ + + { + + "args": [ + + { + + "name": "a", + + "type": { + + "name": "Int" + + }, + + "defaultValue": "1" + + }, + + { + + "name": "b", + + "type": { + + "name": "Int" + + }, + + "defaultValue": "2" + + } + + ], + + "name": "addSmallints" + + }, + + { + + "args": [ + + { + + "name": "a", + + "type": { + + "name": "Int" + + }, + + "defaultValue": "1" + + }, + + { + + "name": "b", + + "type": { + + "name": "Int" + + }, + + "defaultValue": "2" + + }, + + { + + "name": "c", + + "type": { + + "name": "Boolean" + + }, + + "defaultValue": "false" + + }, + + { + + "name": "d", + + "type": { + + "name": "Float" + + }, + + "defaultValue": "3.14" + + }, + + { + + "name": "e", + + "type": { + + "name": "Float" + + }, + + "defaultValue": "2.718" + + }, + + { + + "name": "f", + + "type": { + + "name": "String" + + }, + + "defaultValue": "'hello'"+ + } + + ], + + "name": "funcWithDefaults" + + }, + + { + + "args": [ + + { + + "name": "nodeId", + + "type": { + + "name": null + + }, + + "defaultValue": null + + } + + ], + + "name": "node" + + } + + ] + + } + + } + + } + + } +(1 row) + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints(a: 10, b: 20) + } + $$)); + jsonb_pretty +---------------------------- + { + + "data": { + + "addSmallints": 30+ + } + + } +(1 row) + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints(a: 10) + } + $$)); + jsonb_pretty +---------------------------- + { + + "data": { + + "addSmallints": 12+ + } + + } +(1 row) + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints(b: 20) + } + $$)); + jsonb_pretty +---------------------------- + { + + "data": { + + "addSmallints": 21+ + } + + } +(1 row) + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints + } + $$)); + jsonb_pretty +--------------------------- + { + + "data": { + + "addSmallints": 3+ + } + + } +(1 row) + + create function concat_text(a text, b text default 'world') + returns text language sql immutable + as $$ select a || b; $$; + select jsonb_pretty(graphql.resolve($$ + query { + concatText(b: "world!", a: "hello ") + } + $$)); + jsonb_pretty +-------------------------------------- + { + + "data": { + + "concatText": "hello world!"+ + } + + } +(1 row) + + select jsonb_pretty(graphql.resolve($$ + query { + concatText(a: "hello ") + } + $$)); + jsonb_pretty +------------------------------------- + { + + "data": { + + "concatText": "hello world"+ + } + + } +(1 row) + rollback; diff --git a/test/expected/function_calls_unsupported.out b/test/expected/function_calls_unsupported.out index 1e68dd35..bb567df3 100644 --- a/test/expected/function_calls_unsupported.out +++ b/test/expected/function_calls_unsupported.out @@ -184,47 +184,6 @@ begin; } (1 row) - -- functions with a default value - create function func_with_a_default_int(a int default 42) - returns int language sql immutable - as $$ select a; $$; - select jsonb_pretty(graphql.resolve($$ - query { - funcWithADefaultInt - } - $$)); - jsonb_pretty ------------------------------------------------------------------------------- - { + - "data": null, + - "errors": [ + - { + - "message": "Unknown field \"funcWithADefaultInt\" on type Query"+ - } + - ] + - } -(1 row) - - create function func_with_a_default_null_text(a text default null) - returns text language sql immutable - as $$ select a; $$; - select jsonb_pretty(graphql.resolve($$ - query { - funcWithADefaultNullText - } - $$)); - jsonb_pretty ------------------------------------------------------------------------------------ - { + - "data": null, + - "errors": [ + - { + - "message": "Unknown field \"funcWithADefaultNullText\" on type Query"+ - } + - ] + - } -(1 row) - create function func_accepting_array(a int[]) returns int language sql immutable as $$ select 0; $$; diff --git a/test/sql/function_calls.sql b/test/sql/function_calls.sql index ad457323..fee57898 100644 --- a/test/sql/function_calls.sql +++ b/test/sql/function_calls.sql @@ -684,4 +684,82 @@ begin; set search_path to default; + rollback to savepoint a; + + create function add_smallints(a smallint default 1, b smallint default 2) + returns smallint language sql immutable + as $$ select a + b; $$; + + create function func_with_defaults( + a smallint default 1, + b integer default 2, + c boolean default false, + d real default 3.14, + e double precision default 2.718, + f text default 'hello' + ) + returns smallint language sql immutable + as $$ select 0; $$; + + select jsonb_pretty( + graphql.resolve($$ + query IntrospectionQuery { + __schema { + queryType { + name + fields { + name + args { + name + defaultValue + type { + name + } + } + } + } + } + } + $$) + ); + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints(a: 10, b: 20) + } + $$)); + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints(a: 10) + } + $$)); + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints(b: 20) + } + $$)); + + select jsonb_pretty(graphql.resolve($$ + query { + addSmallints + } + $$)); + + create function concat_text(a text, b text default 'world') + returns text language sql immutable + as $$ select a || b; $$; + + select jsonb_pretty(graphql.resolve($$ + query { + concatText(b: "world!", a: "hello ") + } + $$)); + + select jsonb_pretty(graphql.resolve($$ + query { + concatText(a: "hello ") + } + $$)); rollback; diff --git a/test/sql/function_calls_unsupported.sql b/test/sql/function_calls_unsupported.sql index 51a619f7..b78d4291 100644 --- a/test/sql/function_calls_unsupported.sql +++ b/test/sql/function_calls_unsupported.sql @@ -107,27 +107,6 @@ begin; } $$)); - -- functions with a default value - create function func_with_a_default_int(a int default 42) - returns int language sql immutable - as $$ select a; $$; - - select jsonb_pretty(graphql.resolve($$ - query { - funcWithADefaultInt - } - $$)); - - create function func_with_a_default_null_text(a text default null) - returns text language sql immutable - as $$ select a; $$; - - select jsonb_pretty(graphql.resolve($$ - query { - funcWithADefaultNullText - } - $$)); - create function func_accepting_array(a int[]) returns int language sql immutable as $$ select 0; $$;