Skip to content

Commit 2eb1e7b

Browse files
authored
Add CREATE FUNCTION support for SQL Server (#1808)
1 parent 945f8e0 commit 2eb1e7b

10 files changed

+313
-50
lines changed

src/ast/ddl.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -2157,6 +2157,10 @@ impl fmt::Display for ClusteredBy {
21572157
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
21582158
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
21592159
pub struct CreateFunction {
2160+
/// True if this is a `CREATE OR ALTER FUNCTION` statement
2161+
///
2162+
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#or-alter)
2163+
pub or_alter: bool,
21602164
pub or_replace: bool,
21612165
pub temporary: bool,
21622166
pub if_not_exists: bool,
@@ -2219,9 +2223,10 @@ impl fmt::Display for CreateFunction {
22192223
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
22202224
write!(
22212225
f,
2222-
"CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}",
2226+
"CREATE {or_alter}{or_replace}{temp}FUNCTION {if_not_exists}{name}",
22232227
name = self.name,
22242228
temp = if self.temporary { "TEMPORARY " } else { "" },
2229+
or_alter = if self.or_alter { "OR ALTER " } else { "" },
22252230
or_replace = if self.or_replace { "OR REPLACE " } else { "" },
22262231
if_not_exists = if self.if_not_exists {
22272232
"IF NOT EXISTS "
@@ -2272,6 +2277,9 @@ impl fmt::Display for CreateFunction {
22722277
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22732278
write!(f, " AS {function_body}")?;
22742279
}
2280+
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
2281+
write!(f, " AS {bes}")?;
2282+
}
22752283
Ok(())
22762284
}
22772285
}

src/ast/mod.rs

+89-11
Original file line numberDiff line numberDiff line change
@@ -2293,18 +2293,14 @@ pub enum ConditionalStatements {
22932293
/// SELECT 1; SELECT 2; SELECT 3; ...
22942294
Sequence { statements: Vec<Statement> },
22952295
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
2296-
BeginEnd {
2297-
begin_token: AttachedToken,
2298-
statements: Vec<Statement>,
2299-
end_token: AttachedToken,
2300-
},
2296+
BeginEnd(BeginEndStatements),
23012297
}
23022298

23032299
impl ConditionalStatements {
23042300
pub fn statements(&self) -> &Vec<Statement> {
23052301
match self {
23062302
ConditionalStatements::Sequence { statements } => statements,
2307-
ConditionalStatements::BeginEnd { statements, .. } => statements,
2303+
ConditionalStatements::BeginEnd(bes) => &bes.statements,
23082304
}
23092305
}
23102306
}
@@ -2318,15 +2314,44 @@ impl fmt::Display for ConditionalStatements {
23182314
}
23192315
Ok(())
23202316
}
2321-
ConditionalStatements::BeginEnd { statements, .. } => {
2322-
write!(f, "BEGIN ")?;
2323-
format_statement_list(f, statements)?;
2324-
write!(f, " END")
2325-
}
2317+
ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
23262318
}
23272319
}
23282320
}
23292321

2322+
/// Represents a list of statements enclosed within `BEGIN` and `END` keywords.
2323+
/// Example:
2324+
/// ```sql
2325+
/// BEGIN
2326+
/// SELECT 1;
2327+
/// SELECT 2;
2328+
/// END
2329+
/// ```
2330+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2331+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2332+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2333+
pub struct BeginEndStatements {
2334+
pub begin_token: AttachedToken,
2335+
pub statements: Vec<Statement>,
2336+
pub end_token: AttachedToken,
2337+
}
2338+
2339+
impl fmt::Display for BeginEndStatements {
2340+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2341+
let BeginEndStatements {
2342+
begin_token: AttachedToken(begin_token),
2343+
statements,
2344+
end_token: AttachedToken(end_token),
2345+
} = self;
2346+
2347+
write!(f, "{begin_token} ")?;
2348+
if !statements.is_empty() {
2349+
format_statement_list(f, statements)?;
2350+
}
2351+
write!(f, " {end_token}")
2352+
}
2353+
}
2354+
23302355
/// A `RAISE` statement.
23312356
///
23322357
/// Examples:
@@ -3615,6 +3640,7 @@ pub enum Statement {
36153640
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
36163641
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
36173642
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
3643+
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
36183644
CreateFunction(CreateFunction),
36193645
/// CREATE TRIGGER
36203646
///
@@ -4061,6 +4087,12 @@ pub enum Statement {
40614087
///
40624088
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql>
40634089
Print(PrintStatement),
4090+
/// ```sql
4091+
/// RETURN [ expression ]
4092+
/// ```
4093+
///
4094+
/// See [ReturnStatement]
4095+
Return(ReturnStatement),
40644096
}
40654097

40664098
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -5753,6 +5785,7 @@ impl fmt::Display for Statement {
57535785
Ok(())
57545786
}
57555787
Statement::Print(s) => write!(f, "{s}"),
5788+
Statement::Return(r) => write!(f, "{r}"),
57565789
Statement::List(command) => write!(f, "LIST {command}"),
57575790
Statement::Remove(command) => write!(f, "REMOVE {command}"),
57585791
}
@@ -8355,6 +8388,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
83558388
///
83568389
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83578390
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
8391+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
83588392
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
83598393
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
83608394
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -8383,6 +8417,22 @@ pub enum CreateFunctionBody {
83838417
///
83848418
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83858419
AsAfterOptions(Expr),
8420+
/// Function body with statements before the `RETURN` keyword.
8421+
///
8422+
/// Example:
8423+
/// ```sql
8424+
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
8425+
/// RETURNS INT
8426+
/// AS
8427+
/// BEGIN
8428+
/// DECLARE c INT;
8429+
/// SET c = a + b;
8430+
/// RETURN c;
8431+
/// END
8432+
/// ```
8433+
///
8434+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8435+
AsBeginEnd(BeginEndStatements),
83868436
/// Function body expression using the 'RETURN' keyword.
83878437
///
83888438
/// Example:
@@ -9231,6 +9281,34 @@ impl fmt::Display for PrintStatement {
92319281
}
92329282
}
92339283

9284+
/// Represents a `Return` statement.
9285+
///
9286+
/// [MsSql triggers](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql)
9287+
/// [MsSql functions](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
9288+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9289+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9290+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9291+
pub struct ReturnStatement {
9292+
pub value: Option<ReturnStatementValue>,
9293+
}
9294+
9295+
impl fmt::Display for ReturnStatement {
9296+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
9297+
match &self.value {
9298+
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
9299+
None => write!(f, "RETURN"),
9300+
}
9301+
}
9302+
}
9303+
9304+
/// Variants of a `RETURN` statement
9305+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9306+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9307+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9308+
pub enum ReturnStatementValue {
9309+
Expr(Expr),
9310+
}
9311+
92349312
#[cfg(test)]
92359313
mod tests {
92369314
use super::*;

src/ast/spans.rs

+19-7
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use crate::tokenizer::Span;
2323
use super::{
2424
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
2525
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
26-
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
27-
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
26+
BeginEndStatements, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption,
27+
ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
2828
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
2929
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
3030
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
@@ -520,6 +520,7 @@ impl Spanned for Statement {
520520
Statement::RenameTable { .. } => Span::empty(),
521521
Statement::RaisError { .. } => Span::empty(),
522522
Statement::Print { .. } => Span::empty(),
523+
Statement::Return { .. } => Span::empty(),
523524
Statement::List(..) | Statement::Remove(..) => Span::empty(),
524525
}
525526
}
@@ -778,11 +779,7 @@ impl Spanned for ConditionalStatements {
778779
ConditionalStatements::Sequence { statements } => {
779780
union_spans(statements.iter().map(|s| s.span()))
780781
}
781-
ConditionalStatements::BeginEnd {
782-
begin_token: AttachedToken(start),
783-
statements: _,
784-
end_token: AttachedToken(end),
785-
} => union_spans([start.span, end.span].into_iter()),
782+
ConditionalStatements::BeginEnd(bes) => bes.span(),
786783
}
787784
}
788785
}
@@ -2282,6 +2279,21 @@ impl Spanned for TableObject {
22822279
}
22832280
}
22842281

2282+
impl Spanned for BeginEndStatements {
2283+
fn span(&self) -> Span {
2284+
let BeginEndStatements {
2285+
begin_token,
2286+
statements,
2287+
end_token,
2288+
} = self;
2289+
union_spans(
2290+
core::iter::once(begin_token.0.span)
2291+
.chain(statements.iter().map(|i| i.span()))
2292+
.chain(core::iter::once(end_token.0.span)),
2293+
)
2294+
}
2295+
}
2296+
22852297
#[cfg(test)]
22862298
pub mod tests {
22872299
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};

src/dialect/mssql.rs

+11-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use crate::ast::helpers::attached_token::AttachedToken;
19-
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement};
19+
use crate::ast::{
20+
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
21+
};
2022
use crate::dialect::Dialect;
2123
use crate::keywords::{self, Keyword};
2224
use crate::parser::{Parser, ParserError};
@@ -149,11 +151,11 @@ impl MsSqlDialect {
149151
start_token: AttachedToken(if_token),
150152
condition: Some(condition),
151153
then_token: None,
152-
conditional_statements: ConditionalStatements::BeginEnd {
154+
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
153155
begin_token: AttachedToken(begin_token),
154156
statements,
155157
end_token: AttachedToken(end_token),
156-
},
158+
}),
157159
}
158160
} else {
159161
let stmt = parser.parse_statement()?;
@@ -167,8 +169,10 @@ impl MsSqlDialect {
167169
}
168170
};
169171

172+
let mut prior_statement_ended_with_semi_colon = false;
170173
while let Token::SemiColon = parser.peek_token_ref().token {
171174
parser.advance_token();
175+
prior_statement_ended_with_semi_colon = true;
172176
}
173177

174178
let mut else_block = None;
@@ -182,11 +186,11 @@ impl MsSqlDialect {
182186
start_token: AttachedToken(else_token),
183187
condition: None,
184188
then_token: None,
185-
conditional_statements: ConditionalStatements::BeginEnd {
189+
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
186190
begin_token: AttachedToken(begin_token),
187191
statements,
188192
end_token: AttachedToken(end_token),
189-
},
193+
}),
190194
});
191195
} else {
192196
let stmt = parser.parse_statement()?;
@@ -199,6 +203,8 @@ impl MsSqlDialect {
199203
},
200204
});
201205
}
206+
} else if prior_statement_ended_with_semi_colon {
207+
parser.prev_token();
202208
}
203209

204210
Ok(Statement::If(IfStatement {

0 commit comments

Comments
 (0)