Skip to content

Commit

Permalink
Merge pull request #427 from supabase/feat/default-args
Browse files Browse the repository at this point in the history
add support for calling UDFs with default arguments
  • Loading branch information
imor authored Oct 5, 2023
2 parents 7d708e1 + f26b2b1 commit aeb1801
Show file tree
Hide file tree
Showing 10 changed files with 439 additions and 96 deletions.
3 changes: 2 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
## 1.4.0
- feature: citext type represented as a GraphQL String
- feature: Support for Postgres 16
- feature: Support for user defined function
- feature: Support for user defined functions

## master
- feature: Support for user defined functions with default arguments
42 changes: 41 additions & 1 deletion docs/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,47 @@ Functions returning multiple rows of a table or view are exposed as [collections

A set returning function with any of its argument names clashing with argument names of a collection (`first`, `last`, `before`, `after`, `filter`, or `orderBy`) will not be exposed.

## Default Arguments

Functions with default arguments can have their default arguments omitted.

=== "Function"

```sql
create function "addNums"(a int default 1, b int default 2)
returns int
immutable
language sql
as $$ select a + b; $$;
```

=== "QueryType"

```graphql
type Query {
addNums(a: Int, b: Int): Int
}
```


=== "Query"

```graphql
query {
addNums(b: 20)
}
```

=== "Response"

```json
{
"data": {
"addNums": 21
}
}
```

## Limitations

The following features are not yet supported. Any function using these features is not exposed in the API:
Expand All @@ -239,7 +280,6 @@ The following features are not yet supported. Any function using these features
* Functions that accept a table's tuple type
* Overloaded functions
* Functions with a nameless argument
* Functions with a default argument
* Functions returning void
* Variadic functions
* Function that accept or return an array type
1 change: 1 addition & 0 deletions sql/load_sql_context.sql
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ select
'schema_name', pronamespace::regnamespace::text,
'arg_types', proargtypes::int[],
'arg_names', proargnames::text[],
'arg_defaults', pg_get_expr(proargdefaults, 0)::text,
'num_args', pronargs,
'num_default_args', pronargdefaults,
'arg_type_names', pp.proargtypes::regtype[]::text[],
Expand Down
32 changes: 17 additions & 15 deletions src/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ impl FuncCallResponseType {
let inflected_name_to_sql_name: HashMap<String, (String, String)> = self
.function
.args()
.filter_map(|(_, arg_type_name, arg_name)| {
.filter_map(|(_, arg_type_name, arg_name, _)| {
arg_name.map(|arg_name| (arg_type_name, arg_name))
})
.map(|(arg_type_name, arg_name)| {
Expand Down Expand Up @@ -1316,27 +1316,29 @@ fn function_fields(schema: &Arc<__Schema>, volatilities: &[FunctionVolatility])
fn function_args(schema: &Arc<__Schema>, func: &Arc<Function>) -> Vec<__InputValue> {
let sql_types = &schema.context.types;
func.args()
.filter(|(_, _, arg_name)| !arg_name.is_none())
.filter_map(|(arg_type, _, arg_name)| match sql_types.get(&arg_type) {
Some(t) => {
if matches!(t.category, TypeCategory::Pseudo) {
None
} else {
Some((t, arg_name.unwrap()))
.filter(|(_, _, arg_name, _)| !arg_name.is_none())
.filter_map(
|(arg_type, _, arg_name, arg_default)| match sql_types.get(&arg_type) {
Some(t) => {
if matches!(t.category, TypeCategory::Pseudo) {
None
} else {
Some((t, arg_name.unwrap(), arg_default))
}
}
}
None => None,
})
.filter_map(|(arg_type, arg_name)| {
None => None,
},
)
.filter_map(|(arg_type, arg_name, arg_default)| {
arg_type
.to_graphql_type(None, false, schema)
.map(|t| (t, arg_name))
.map(|t| (t, arg_name, arg_default))
})
.map(|(arg_type, arg_name)| __InputValue {
.map(|(arg_type, arg_name, arg_default)| __InputValue {
name_: schema.graphql_function_arg_name(func, arg_name),
type_: arg_type,
description: None,
default_value: None,
default_value: arg_default,
sql_type: None,
})
.collect()
Expand Down
99 changes: 83 additions & 16 deletions src/sql_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub struct Function {
pub schema_name: String,
pub arg_types: Vec<u32>,
pub arg_names: Option<Vec<String>>,
pub arg_defaults: Option<String>,
pub num_args: u32,
pub num_default_args: u32,
pub arg_type_names: Vec<String>,
Expand All @@ -87,13 +88,14 @@ pub struct Function {
}

impl Function {
pub fn args(&self) -> impl Iterator<Item = (u32, &str, Option<&str>)> {
ArgsIterator {
index: 0,
arg_types: &self.arg_types,
arg_type_names: &self.arg_type_names,
arg_names: &self.arg_names,
}
pub fn args(&self) -> impl Iterator<Item = (u32, &str, Option<&str>, Option<String>)> {
ArgsIterator::new(
&self.arg_types,
&self.arg_type_names,
&self.arg_names,
&self.arg_defaults,
self.num_default_args,
)
}

pub fn function_names_to_count(all_functions: &[Arc<Function>]) -> HashMap<&String, u32> {
Expand All @@ -115,13 +117,12 @@ impl Function {
&& self.arg_types_are_supported(types)
&& !self.is_function_overloaded(function_name_to_count)
&& !self.has_a_nameless_arg()
&& !self.has_a_default_arg()
&& self.permissions.is_executable
&& !self.is_in_a_system_schema()
}

fn arg_types_are_supported(&self, types: &HashMap<u32, Arc<Type>>) -> bool {
self.args().all(|(arg_type, _, _)| {
self.args().all(|(arg_type, _, _, _)| {
if let Some(return_type) = types.get(&arg_type) {
return_type.category == TypeCategory::Other
} else {
Expand Down Expand Up @@ -149,11 +150,7 @@ impl Function {
}

fn has_a_nameless_arg(&self) -> bool {
self.args().any(|(_, _, arg_name)| arg_name.is_none())
}

fn has_a_default_arg(&self) -> bool {
self.num_default_args > 0
self.args().any(|(_, _, arg_name, _)| arg_name.is_none())
}

fn is_in_a_system_schema(&self) -> bool {
Expand All @@ -168,14 +165,83 @@ struct ArgsIterator<'a> {
arg_types: &'a [u32],
arg_type_names: &'a Vec<String>,
arg_names: &'a Option<Vec<String>>,
arg_defaults: Vec<Option<String>>,
}

impl<'a> ArgsIterator<'a> {
fn new(
arg_types: &'a [u32],
arg_type_names: &'a Vec<String>,
arg_names: &'a Option<Vec<String>>,
arg_defaults: &'a Option<String>,
num_default_args: u32,
) -> ArgsIterator<'a> {
ArgsIterator {
index: 0,
arg_types,
arg_type_names,
arg_names,
arg_defaults: Self::defaults(
arg_types,
arg_defaults,
num_default_args as usize,
arg_types.len(),
),
}
}

fn defaults(
arg_types: &'a [u32],
arg_defaults: &'a Option<String>,
num_default_args: usize,
num_total_args: usize,
) -> Vec<Option<String>> {
let mut defaults = vec![None; num_total_args];
let Some(arg_defaults) = arg_defaults else {
return defaults;
};

if num_default_args == 0 {
return defaults;
}

let default_strs: Vec<&str> = arg_defaults.split(',').collect();

if default_strs.len() != num_default_args {
return defaults;
}

debug_assert!(num_default_args <= num_total_args);
let start_idx = num_total_args - num_default_args;
for i in start_idx..num_total_args {
defaults[i] =
Self::sql_to_graphql_default(default_strs[i - start_idx], arg_types[i - start_idx])
}

defaults
}

fn sql_to_graphql_default(default_str: &str, type_oid: u32) -> Option<String> {
let trimmed = default_str.trim();
match type_oid {
21 | 23 => trimmed.parse::<i32>().ok().map(|i| i.to_string()),
16 => trimmed.parse::<bool>().ok().map(|i| i.to_string()),
700 | 701 => trimmed.parse::<f64>().ok().map(|i| i.to_string()),
25 => trimmed
.strip_suffix("::text")
.to_owned()
.map(|i| i.trim_matches(',').to_string()),
_ => None,
}
}
}

lazy_static! {
static ref TEXT_TYPE: String = "text".to_string();
}

impl<'a> Iterator for ArgsIterator<'a> {
type Item = (u32, &'a str, Option<&'a str>);
type Item = (u32, &'a str, Option<&'a str>, Option<String>);

fn next(&mut self) -> Option<Self::Item> {
if self.index < self.arg_types.len() {
Expand All @@ -196,8 +262,9 @@ impl<'a> Iterator for ArgsIterator<'a> {
if arg_type_name == "character" {
arg_type_name = &TEXT_TYPE;
}
let arg_default = self.arg_defaults[self.index].clone();
self.index += 1;
Some((arg_type, arg_type_name, arg_name))
Some((arg_type, arg_type_name, arg_name, arg_default))
} else {
None
}
Expand Down
3 changes: 2 additions & 1 deletion src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ impl FunctionCallBuilder {
for (arg, arg_value) in &self.args_builder.args {
if let Some(arg) = arg {
let arg_clause = param_context.clause_for(arg_value, &arg.type_name)?;
arg_clauses.push(arg_clause);
let named_arg_clause = format!("{} => {}", quote_ident(&arg.name), arg_clause);
arg_clauses.push(named_arg_clause);
}
}

Expand Down
Loading

0 comments on commit aeb1801

Please sign in to comment.