Skip to content

Commit 9f772f0

Browse files
authored
Add support for Recursive CTEs (apache#278)
i.e. `WITH RECURSIVE ... AS ( ... ) SELECT` (see https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#with-clause). Fixes apache#277.
1 parent 54be391 commit 9f772f0

File tree

5 files changed

+73
-19
lines changed

5 files changed

+73
-19
lines changed

src/ast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub use self::ddl::{
3030
pub use self::operator::{BinaryOperator, UnaryOperator};
3131
pub use self::query::{
3232
Cte, Fetch, Join, JoinConstraint, JoinOperator, Offset, OffsetRows, OrderByExpr, Query, Select,
33-
SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values,
33+
SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, With,
3434
};
3535
pub use self::value::{DateTimeField, Value};
3636

src/ast/query.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use serde::{Deserialize, Serialize};
2020
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2121
pub struct Query {
2222
/// WITH (common table expressions, or CTEs)
23-
pub ctes: Vec<Cte>,
24-
/// SELECT or UNION / EXCEPT / INTECEPT
23+
pub with: Option<With>,
24+
/// SELECT or UNION / EXCEPT / INTERSECT
2525
pub body: SetExpr,
2626
/// ORDER BY
2727
pub order_by: Vec<OrderByExpr>,
@@ -35,8 +35,8 @@ pub struct Query {
3535

3636
impl fmt::Display for Query {
3737
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38-
if !self.ctes.is_empty() {
39-
write!(f, "WITH {} ", display_comma_separated(&self.ctes))?;
38+
if let Some(ref with) = self.with {
39+
write!(f, "{} ", with)?;
4040
}
4141
write!(f, "{}", self.body)?;
4242
if !self.order_by.is_empty() {
@@ -157,6 +157,24 @@ impl fmt::Display for Select {
157157
}
158158
}
159159

160+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
161+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
162+
pub struct With {
163+
pub recursive: bool,
164+
pub cte_tables: Vec<Cte>,
165+
}
166+
167+
impl fmt::Display for With {
168+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
169+
write!(
170+
f,
171+
"WITH {}{}",
172+
if self.recursive { "RECURSIVE " } else { "" },
173+
display_comma_separated(&self.cte_tables)
174+
)
175+
}
176+
}
177+
160178
/// A single CTE (used after `WITH`): `alias [(col1, col2, ...)] AS ( query )`
161179
/// The names in the column list before `AS`, when specified, replace the names
162180
/// of the columns returned by the query. The parser does not validate that the

src/parser.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,11 +1795,13 @@ impl<'a> Parser<'a> {
17951795
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
17961796
/// expect the initial keyword to be already consumed
17971797
pub fn parse_query(&mut self) -> Result<Query, ParserError> {
1798-
let ctes = if self.parse_keyword(Keyword::WITH) {
1799-
// TODO: optional RECURSIVE
1800-
self.parse_comma_separated(Parser::parse_cte)?
1798+
let with = if self.parse_keyword(Keyword::WITH) {
1799+
Some(With {
1800+
recursive: self.parse_keyword(Keyword::RECURSIVE),
1801+
cte_tables: self.parse_comma_separated(Parser::parse_cte)?,
1802+
})
18011803
} else {
1802-
vec![]
1804+
None
18031805
};
18041806

18051807
let body = self.parse_query_body(0)?;
@@ -1829,7 +1831,7 @@ impl<'a> Parser<'a> {
18291831
};
18301832

18311833
Ok(Query {
1832-
ctes,
1834+
with,
18331835
body,
18341836
limit,
18351837
order_by,

src/tokenizer.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,7 @@ impl<'a> Tokenizer<'a> {
382382
// numbers
383383
'0'..='9' => {
384384
// TODO: https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#unsigned-numeric-literal
385-
let s = peeking_take_while(chars, |ch| match ch {
386-
'0'..='9' | '.' => true,
387-
_ => false,
388-
});
385+
let s = peeking_take_while(chars, |ch| matches!(ch, '0'..='9' | '.'));
389386
Ok(Some(Token::Number(s)))
390387
}
391388
// punctuation

tests/sqlparser_common.rs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2389,7 +2389,7 @@ fn parse_ctes() {
23892389

23902390
fn assert_ctes_in_select(expected: &[&str], sel: &Query) {
23912391
for (i, exp) in expected.iter().enumerate() {
2392-
let Cte { alias, query } = &sel.ctes[i];
2392+
let Cte { alias, query } = &sel.with.as_ref().unwrap().cte_tables[i];
23932393
assert_eq!(*exp, query.to_string());
23942394
assert_eq!(
23952395
if i == 0 {
@@ -2432,7 +2432,7 @@ fn parse_ctes() {
24322432
// CTE in a CTE...
24332433
let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with);
24342434
let select = verified_query(sql);
2435-
assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query);
2435+
assert_ctes_in_select(&cte_sqls, &only(&select.with.unwrap().cte_tables).query);
24362436
}
24372437

24382438
#[test]
@@ -2441,10 +2441,47 @@ fn parse_cte_renamed_columns() {
24412441
let query = all_dialects().verified_query(sql);
24422442
assert_eq!(
24432443
vec![Ident::new("col1"), Ident::new("col2")],
2444-
query.ctes.first().unwrap().alias.columns
2444+
query
2445+
.with
2446+
.unwrap()
2447+
.cte_tables
2448+
.first()
2449+
.unwrap()
2450+
.alias
2451+
.columns
24452452
);
24462453
}
24472454

2455+
#[test]
2456+
fn parse_recursive_cte() {
2457+
let cte_query = "SELECT 1 UNION ALL SELECT val + 1 FROM nums WHERE val < 10".to_owned();
2458+
let sql = &format!(
2459+
"WITH RECURSIVE nums (val) AS ({}) SELECT * FROM nums",
2460+
cte_query
2461+
);
2462+
2463+
let cte_query = verified_query(&cte_query);
2464+
let query = verified_query(sql);
2465+
2466+
let with = query.with.as_ref().unwrap();
2467+
assert!(with.recursive);
2468+
assert_eq!(with.cte_tables.len(), 1);
2469+
let expected = Cte {
2470+
alias: TableAlias {
2471+
name: Ident {
2472+
value: "nums".to_string(),
2473+
quote_style: None,
2474+
},
2475+
columns: vec![Ident {
2476+
value: "val".to_string(),
2477+
quote_style: None,
2478+
}],
2479+
},
2480+
query: cte_query,
2481+
};
2482+
assert_eq!(with.cte_tables.first().unwrap(), &expected);
2483+
}
2484+
24482485
#[test]
24492486
fn parse_derived_tables() {
24502487
let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b";
@@ -3266,8 +3303,8 @@ fn parse_drop_index() {
32663303
fn all_keywords_sorted() {
32673304
// assert!(ALL_KEYWORDS.is_sorted())
32683305
let mut copy = Vec::from(ALL_KEYWORDS);
3269-
copy.sort();
3270-
assert!(copy == ALL_KEYWORDS)
3306+
copy.sort_unstable();
3307+
assert_eq!(copy, ALL_KEYWORDS)
32713308
}
32723309

32733310
fn parse_sql_statements(sql: &str) -> Result<Vec<Statement>, ParserError> {

0 commit comments

Comments
 (0)