diff --git a/src/ast/mod.rs b/src/ast/mod.rs index a1d8ff6fc..5d55db528 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2982,6 +2982,36 @@ impl From for Statement { } } +/// A representation of a `WHEN` arm with all the identifiers catched and the statements to execute +/// for the arm. +/// +/// Snowflake: +/// BigQuery: +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ExceptionWhen { + pub idents: Vec, + pub statements: Vec, +} + +impl Display for ExceptionWhen { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "WHEN {idents} THEN", + idents = display_separated(&self.idents, " OR ") + )?; + + if !self.statements.is_empty() { + write!(f, " ")?; + format_statement_list(f, &self.statements)?; + } + + Ok(()) + } +} + /// A top-level statement (SELECT, INSERT, CREATE, etc.) #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] @@ -3670,17 +3700,20 @@ pub enum Statement { /// END; /// ``` statements: Vec, - /// Statements of an exception clause. + /// Exception handling with exception clauses. /// Example: /// ```sql - /// BEGIN - /// SELECT 1; - /// EXCEPTION WHEN ERROR THEN - /// SELECT 2; - /// SELECT 3; - /// END; + /// EXCEPTION + /// WHEN EXCEPTION_1 THEN + /// SELECT 2; + /// WHEN EXCEPTION_2 OR EXCEPTION_3 THEN + /// SELECT 3; + /// WHEN OTHER THEN + /// SELECT 4; + /// ``` /// - exception_statements: Option>, + /// + exception: Option>, /// TRUE if the statement has an `END` keyword. has_end_keyword: bool, }, @@ -5525,7 +5558,7 @@ impl fmt::Display for Statement { transaction, modifier, statements, - exception_statements, + exception, has_end_keyword, } => { if *syntax_begin { @@ -5547,11 +5580,10 @@ impl fmt::Display for Statement { write!(f, " ")?; format_statement_list(f, statements)?; } - if let Some(exception_statements) = exception_statements { - write!(f, " EXCEPTION WHEN ERROR THEN")?; - if !exception_statements.is_empty() { - write!(f, " ")?; - format_statement_list(f, exception_statements)?; + if let Some(exception_when) = exception { + write!(f, " EXCEPTION")?; + for when in exception_when { + write!(f, " {when}")?; } } if *has_end_keyword { diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index 68ca1390a..c2cd507ce 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -46,7 +46,11 @@ pub struct BigQueryDialect; impl Dialect for BigQueryDialect { fn parse_statement(&self, parser: &mut Parser) -> Option> { - self.maybe_parse_statement(parser) + if parser.parse_keyword(Keyword::BEGIN) { + return Some(parser.parse_begin_exception_end()); + } + + None } /// See @@ -141,48 +145,3 @@ impl Dialect for BigQueryDialect { true } } - -impl BigQueryDialect { - fn maybe_parse_statement(&self, parser: &mut Parser) -> Option> { - if parser.peek_keyword(Keyword::BEGIN) { - return Some(self.parse_begin(parser)); - } - None - } - - /// Parse a `BEGIN` statement. - /// - fn parse_begin(&self, parser: &mut Parser) -> Result { - parser.expect_keyword(Keyword::BEGIN)?; - - let statements = parser.parse_statement_list(&[Keyword::EXCEPTION, Keyword::END])?; - - let has_exception_when_clause = parser.parse_keywords(&[ - Keyword::EXCEPTION, - Keyword::WHEN, - Keyword::ERROR, - Keyword::THEN, - ]); - let exception_statements = if has_exception_when_clause { - if !parser.peek_keyword(Keyword::END) { - Some(parser.parse_statement_list(&[Keyword::END])?) - } else { - Some(Default::default()) - } - } else { - None - }; - - parser.expect_keyword(Keyword::END)?; - - Ok(Statement::StartTransaction { - begin: true, - statements, - exception_statements, - has_end_keyword: true, - transaction: None, - modifier: None, - modes: Default::default(), - }) - } -} diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 990e2ea29..cbd7bcf2e 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -131,6 +131,10 @@ impl Dialect for SnowflakeDialect { } fn parse_statement(&self, parser: &mut Parser) -> Option> { + if parser.parse_keyword(Keyword::BEGIN) { + return Some(parser.parse_begin_exception_end()); + } + if parser.parse_keywords(&[Keyword::ALTER, Keyword::SESSION]) { // ALTER SESSION let set = match parser.parse_one_of_keywords(&[Keyword::SET, Keyword::UNSET]) { diff --git a/src/keywords.rs b/src/keywords.rs index f5c5e567e..cb6c7a6e2 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -646,6 +646,7 @@ define_keywords!( ORDER, ORDINALITY, ORGANIZATION, + OTHER, OUT, OUTER, OUTPUT, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 2c208e2e5..43b4765ae 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -15096,7 +15096,7 @@ impl<'a> Parser<'a> { transaction: Some(BeginTransactionKind::Transaction), modifier: None, statements: vec![], - exception_statements: None, + exception: None, has_end_keyword: false, }) } @@ -15128,11 +15128,56 @@ impl<'a> Parser<'a> { transaction, modifier, statements: vec![], - exception_statements: None, + exception: None, has_end_keyword: false, }) } + pub fn parse_begin_exception_end(&mut self) -> Result { + let statements = self.parse_statement_list(&[Keyword::EXCEPTION, Keyword::END])?; + + let exception = if self.parse_keyword(Keyword::EXCEPTION) { + let mut when = Vec::new(); + + // We can have multiple `WHEN` arms so we consume all cases until `END` + while !self.peek_keyword(Keyword::END) { + self.expect_keyword(Keyword::WHEN)?; + + // Each `WHEN` case can have one or more conditions, e.g. + // WHEN EXCEPTION_1 [OR EXCEPTION_2] THEN + // So we parse identifiers until the `THEN` keyword. + let mut idents = Vec::new(); + + while !self.parse_keyword(Keyword::THEN) { + let ident = self.parse_identifier()?; + idents.push(ident); + + self.maybe_parse(|p| p.expect_keyword(Keyword::OR))?; + } + + let statements = self.parse_statement_list(&[Keyword::WHEN, Keyword::END])?; + + when.push(ExceptionWhen { idents, statements }); + } + + Some(when) + } else { + None + }; + + self.expect_keyword(Keyword::END)?; + + Ok(Statement::StartTransaction { + begin: true, + statements, + exception, + has_end_keyword: true, + transaction: None, + modifier: None, + modes: Default::default(), + }) + } + pub fn parse_end(&mut self) -> Result { let modifier = if !self.dialect.supports_end_transaction_modifier() { None diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index b64f190f6..a4bdfc07f 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -261,10 +261,10 @@ fn parse_at_at_identifier() { #[test] fn parse_begin() { - let sql = r#"BEGIN SELECT 1; EXCEPTION WHEN ERROR THEN SELECT 2; END"#; + let sql = r#"BEGIN SELECT 1; EXCEPTION WHEN ERROR THEN SELECT 2; RAISE USING MESSAGE = FORMAT('ERR: %s', 'Bad'); END"#; let Statement::StartTransaction { statements, - exception_statements, + exception, has_end_keyword, .. } = bigquery().verified_stmt(sql) @@ -272,7 +272,10 @@ fn parse_begin() { unreachable!(); }; assert_eq!(1, statements.len()); - assert_eq!(1, exception_statements.unwrap().len()); + assert!(exception.is_some()); + + let exception = exception.unwrap(); + assert_eq!(1, exception.len()); assert!(has_end_keyword); bigquery().verified_stmt( diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index abcadb458..e4363ff6e 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8592,8 +8592,11 @@ fn lateral_function() { #[test] fn parse_start_transaction() { let dialects = all_dialects_except(|d| - // BigQuery does not support this syntax - d.is::()); + // BigQuery and Snowflake does not support this syntax + // + // BigQuery: + // Snowflake: + d.is::() || d.is::()); match dialects .verified_stmt("START TRANSACTION READ ONLY, READ WRITE, ISOLATION LEVEL SERIALIZABLE") { diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 7a164c248..453337c42 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -4066,3 +4066,67 @@ fn parse_connect_by_root_operator() { "sql parser error: Expected an expression, found: FROM" ); } + +#[test] +fn test_begin_exception_end() { + for sql in [ + "BEGIN SELECT 1; EXCEPTION WHEN OTHER THEN SELECT 2; RAISE; END", + "BEGIN SELECT 1; EXCEPTION WHEN OTHER THEN SELECT 2; RAISE EX_1; END", + "BEGIN SELECT 1; EXCEPTION WHEN FOO THEN SELECT 2; WHEN OTHER THEN SELECT 3; RAISE; END", + "BEGIN BEGIN SELECT 1; EXCEPTION WHEN OTHER THEN SELECT 2; RAISE; END; END", + ] { + snowflake().verified_stmt(sql); + } + + let sql = r#" +DECLARE + EXCEPTION_1 EXCEPTION (-20001, 'I caught the expected exception.'); + EXCEPTION_2 EXCEPTION (-20002, 'Not the expected exception!'); + EXCEPTION_3 EXCEPTION (-20003, 'The worst exception...'); +BEGIN + BEGIN + SELECT 1; + EXCEPTION + WHEN EXCEPTION_1 THEN + SELECT 1; + WHEN EXCEPTION_2 OR EXCEPTION_3 THEN + SELECT 2; + SELECT 3; + WHEN OTHER THEN + SELECT 4; + RAISE; + END; +END +"#; + + // Outer `BEGIN` of the two nested `BEGIN` statements. + let Statement::StartTransaction { mut statements, .. } = snowflake() + .parse_sql_statements(sql) + .unwrap() + .pop() + .unwrap() + else { + unreachable!(); + }; + + // Inner `BEGIN` of the two nested `BEGIN` statements. + let Statement::StartTransaction { + statements, + exception, + has_end_keyword, + .. + } = statements.pop().unwrap() + else { + unreachable!(); + }; + + assert_eq!(1, statements.len()); + assert!(has_end_keyword); + + let exception = exception.unwrap(); + assert_eq!(3, exception.len()); + assert_eq!(1, exception[0].idents.len()); + assert_eq!(1, exception[0].statements.len()); + assert_eq!(2, exception[1].idents.len()); + assert_eq!(2, exception[1].statements.len()); +}