From 1f381f42a555c7016a95f33a6af5a2b83ad6bb1a Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Thu, 15 Dec 2022 16:30:55 +0000 Subject: [PATCH 1/7] Add derive based AST visitor --- .gitignore | 1 + Cargo.toml | 4 +- derive/Cargo.toml | 23 ++++ derive/README.md | 79 +++++++++++++ derive/src/lib.rs | 169 +++++++++++++++++++++++++++ src/ast/data_type.rs | 8 ++ src/ast/ddl.rs | 12 ++ src/ast/helpers/stmt_create_table.rs | 4 + src/ast/mod.rs | 102 ++++++++++++++-- src/ast/operator.rs | 6 + src/ast/query.rs | 31 +++++ src/ast/value.rs | 7 ++ src/ast/visitor.rs | 138 ++++++++++++++++++++++ src/keywords.rs | 4 + src/lib.rs | 3 + src/tokenizer.rs | 6 + 16 files changed, 586 insertions(+), 11 deletions(-) create mode 100644 derive/Cargo.toml create mode 100644 derive/README.md create mode 100644 derive/src/lib.rs create mode 100644 src/ast/visitor.rs diff --git a/.gitignore b/.gitignore index baccda415..d41369207 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # will have compiled files and executables /target/ /sqlparser_bench/target/ +/derive/target/ # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here http://doc.crates.io/guide.html#cargotoml-vs-cargolock diff --git a/Cargo.toml b/Cargo.toml index 2355f4646..a3376a673 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ version = "0.28.0" authors = ["Andy Grove "] homepage = "https://github.com/sqlparser-rs/sqlparser-rs" documentation = "https://docs.rs/sqlparser/" -keywords = [ "ansi", "sql", "lexer", "parser" ] +keywords = ["ansi", "sql", "lexer", "parser"] repository = "https://github.com/sqlparser-rs/sqlparser-rs" license = "Apache-2.0" include = [ @@ -23,6 +23,7 @@ default = ["std"] std = [] # Enable JSON output in the `cli` example: json_example = ["serde_json", "serde"] +visitor = ["sqlparser_derive"] [dependencies] bigdecimal = { version = "0.3", features = ["serde"], optional = true } @@ -32,6 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } # of dev-dependencies because of # https://github.com/rust-lang/cargo/issues/1596 serde_json = { version = "1.0", optional = true } +sqlparser_derive = { version = "0.1", path = "derive", optional = true } [dev-dependencies] simple_logger = "4.0" diff --git a/derive/Cargo.toml b/derive/Cargo.toml new file mode 100644 index 000000000..221437a9e --- /dev/null +++ b/derive/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "sqlparser_derive" +description = "proc macro for sqlparser" +version = "0.1.0" +authors = ["Andy Grove "] +homepage = "https://github.com/sqlparser-rs/sqlparser-rs" +documentation = "https://docs.rs/sqlparser/" +keywords = ["ansi", "sql", "lexer", "parser"] +repository = "https://github.com/sqlparser-rs/sqlparser-rs" +license = "Apache-2.0" +include = [ + "src/**/*.rs", + "Cargo.toml", +] +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = "1.0" +proc-macro2 = "1.0" +quote = "1.0" diff --git a/derive/README.md b/derive/README.md new file mode 100644 index 000000000..3cf34faeb --- /dev/null +++ b/derive/README.md @@ -0,0 +1,79 @@ +# SQL Parser Derive Macro + +## Visit + +This crate contains a procedural macro that can automatically derive implementations of the `Visit` trait + +```rust +#[derive(Visit)] +struct Foo { + boolean: bool, + bar: Bar, +} + +#[derive(Visit)] +enum Bar { + A(), + B(String, bool), + C { named: i32 }, +} +``` + +Will generate code akin to + +```rust +impl Visit for Foo { + fn visit(&self, visitor: &mut V) -> ControlFlow { + self.boolean.visit(visitor)?; + self.bar.visit(visitor)?; + ControlFlow::Continue(()) + } +} + +impl Visit for Bar { + fn visit(&self, visitor: &mut V) -> ControlFlow { + match self { + Self::A() => {} + Self::B(_1, _2) => { + _1.visit(visitor)?; + _2.visit(visitor)?; + } + Self::C { named } => { + named.visit(visitor)?; + } + } + ControlFlow::Continue(()) + } +} +``` + +Additionally certain types may wish to call a corresponding method on visitor before recursing + +```rust +#[derive(Visit)] +#[visit(with = "visit_expr")] +enum Expr { + A(), + B(String, #[visit(with = "visit_table")] ObjectName, bool), +} +``` + +Will generate + +```rust +impl Visit for Bar { + fn visit(&self, visitor: &mut V) -> ControlFlow { + visitor.visit_expr(self)?; + match self { + Self::A() => {} + Self::B(_1, _2, _3) => { + _1.visit(visitor)?; + visitor.visit_table(_3)?; + _2.visit(visitor)?; + _3.visit(visitor)?; + } + } + ControlFlow::Continue(()) + } +} +``` diff --git a/derive/src/lib.rs b/derive/src/lib.rs new file mode 100644 index 000000000..59ea359b4 --- /dev/null +++ b/derive/src/lib.rs @@ -0,0 +1,169 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, quote_spanned, ToTokens}; +use syn::spanned::Spanned; +use syn::{ + parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, + Ident, Index, Lit, Meta, MetaNameValue, NestedMeta, +}; + +#[proc_macro_derive(Visit, attributes(visit))] +pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + // Parse the input tokens into a syntax tree. + let input = parse_macro_input!(input as DeriveInput); + let name = input.ident; + + let attributes = Attributes::parse(&input.attrs); + // Add a bound `T: HeapSize` to every type parameter T. + let generics = add_trait_bounds(input.generics); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let with = attributes.with.map(|m| quote!(visitor.#m(self)?;)); + let children = visit_children(&input.data); + + let expanded = quote! { + // The generated impl. + impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause { + fn visit(&self, visitor: &mut V) -> ::std::ops::ControlFlow { + #with + #children + ::std::ops::ControlFlow::Continue(()) + } + } + }; + + proc_macro::TokenStream::from(expanded) +} + +/// Parses attributes that can be provided to this macro +/// +/// `#[visit(leaf, with = "visit_expr")]` +#[derive(Default)] +struct Attributes { + /// Content for the `with` attribute + with: Option, +} + +impl Attributes { + fn parse(attrs: &[Attribute]) -> Self { + let mut out = Self::default(); + for attr in attrs.iter().filter(|a| a.path.is_ident("visit")) { + let meta = attr.parse_meta().expect("visit attribute"); + match meta { + Meta::List(l) => { + for nested in &l.nested { + match nested { + NestedMeta::Meta(Meta::NameValue(v)) => out.parse_name_value(v), + _ => panic!("Expected #[visit(key = \"value\")]"), + } + } + } + _ => panic!("Expected #[visit(...)]"), + } + } + out + } + + /// Updates self with a name value attribute + fn parse_name_value(&mut self, v: &MetaNameValue) { + if v.path.is_ident("with") { + match &v.lit { + Lit::Str(s) => self.with = Some(format_ident!("{}", s.value(), span = s.span())), + _ => panic!("Expected a string value, got {}", v.lit.to_token_stream()), + } + return; + } + panic!("Unrecognised kv attribute {}", v.path.to_token_stream()) + } +} + +// Add a bound `T: Visit` to every type parameter T. +fn add_trait_bounds(mut generics: Generics) -> Generics { + for param in &mut generics.params { + if let GenericParam::Type(ref mut type_param) = *param { + type_param.bounds.push(parse_quote!(sqlparser::ast::Visit)); + } + } + generics +} + +// Generate the body of the visit implementation for the given type +fn visit_children(data: &Data) -> TokenStream { + match data { + Data::Struct(data) => match &data.fields { + Fields::Named(fields) => { + let recurse = fields.named.iter().map(|f| { + let name = &f.ident; + let attributes = Attributes::parse(&f.attrs); + let with = attributes.with.map(|m| quote!(visitor.#m(&self.#name)?;)); + quote_spanned!(f.span() => #with sqlparser::ast::Visit::visit(&self.#name, visitor)?;) + }); + quote! { + #(#recurse)* + } + } + Fields::Unnamed(fields) => { + let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| { + let index = Index::from(i); + let attributes = Attributes::parse(&f.attrs); + let with = attributes.with.map(|m| quote!(visitor.#m(&self.#index)?;)); + quote_spanned!(f.span() => #with sqlparser::ast::Visit::visit(&self.#index, visitor)?;) + }); + quote! { + #(#recurse)* + } + } + Fields::Unit => { + quote!() + } + }, + Data::Enum(data) => { + let statements = data.variants.iter().map(|v| { + let name = &v.ident; + match &v.fields { + Fields::Named(fields) => { + let names = fields.named.iter().map(|f| &f.ident); + let visit = fields.named.iter().map(|f| { + let name = &f.ident; + let attributes = Attributes::parse(&f.attrs); + let with = attributes.with.map(|m| quote!(visitor.#m(&#name)?;)); + quote_spanned!(f.span() => #with sqlparser::ast::Visit::visit(#name, visitor)?) + }); + + quote!( + Self::#name { #(#names),* } => { + #(#visit);* + } + ) + } + Fields::Unnamed(fields) => { + let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span())); + let visit = fields.unnamed.iter().enumerate().map(|(i, f)| { + let name = format_ident!("_{}", i); + let attributes = Attributes::parse(&f.attrs); + let with = attributes.with.map(|m| quote!(visitor.#m(&#name)?;)); + quote_spanned!(f.span() => #with sqlparser::ast::Visit::visit(#name, visitor)?) + }); + + quote! { + Self::#name ( #(#names),*) => { + #(#visit);* + } + } + } + Fields::Unit => { + quote! { + Self::#name => {} + } + } + } + }); + + quote! { + match self { + #(#statements),* + } + } + } + Data::Union(_) => unimplemented!(), + } +} diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index 1353eca90..af8320d8f 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -17,6 +17,9 @@ use core::fmt; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use sqlparser_derive::Visit; + use crate::ast::ObjectName; use super::value::escape_single_quote_string; @@ -24,6 +27,7 @@ use super::value::escape_single_quote_string; /// SQL data types #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum DataType { /// Fixed-length character type e.g. CHARACTER(10) Character(Option), @@ -337,6 +341,7 @@ fn format_datetime_precision_and_tz( /// guarantee compatibility with the input query we must maintain its exact information. #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum TimezoneInfo { /// No information about time zone. E.g., TIMESTAMP None, @@ -384,6 +389,7 @@ impl fmt::Display for TimezoneInfo { /// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ExactNumberInfo { /// No additional information e.g. `DECIMAL` None, @@ -414,6 +420,7 @@ impl fmt::Display for ExactNumberInfo { /// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct CharacterLength { /// Default (if VARYING) or maximum (if not VARYING) length pub length: u64, @@ -436,6 +443,7 @@ impl fmt::Display for CharacterLength { /// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CharLengthUnits { /// CHARACTERS unit Characters, diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 00bf83aba..e5b3d6ded 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -20,6 +20,9 @@ use core::fmt; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use sqlparser_derive::Visit; + use crate::ast::value::escape_single_quote_string; use crate::ast::{display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName}; use crate::tokenizer::Token; @@ -27,6 +30,7 @@ use crate::tokenizer::Token; /// An `ALTER TABLE` (`Statement::AlterTable`) operation #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum AlterTableOperation { /// `ADD ` AddConstraint(TableConstraint), @@ -203,6 +207,7 @@ impl fmt::Display for AlterTableOperation { /// An `ALTER COLUMN` (`Statement::AlterTable`) operation #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum AlterColumnOperation { /// `SET NOT NULL` SetNotNull, @@ -246,6 +251,7 @@ impl fmt::Display for AlterColumnOperation { /// `ALTER TABLE ADD ` statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum TableConstraint { /// `[ CONSTRAINT ] { PRIMARY KEY | UNIQUE } ()` Unique { @@ -409,6 +415,7 @@ impl fmt::Display for TableConstraint { /// [1]: https://dev.mysql.com/doc/refman/8.0/en/create-table.html #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum KeyOrIndexDisplay { /// Nothing to display None, @@ -444,6 +451,7 @@ impl fmt::Display for KeyOrIndexDisplay { /// [3]: https://www.postgresql.org/docs/14/sql-createindex.html #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum IndexType { BTree, Hash, @@ -462,6 +470,7 @@ impl fmt::Display for IndexType { /// SQL column definition #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct ColumnDef { pub name: Ident, pub data_type: DataType, @@ -497,6 +506,7 @@ impl fmt::Display for ColumnDef { /// "column options," and we allow any column option to be named. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct ColumnOptionDef { pub name: Option, pub option: ColumnOption, @@ -512,6 +522,7 @@ impl fmt::Display for ColumnOptionDef { /// TABLE` statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ColumnOption { /// `NULL` Null, @@ -599,6 +610,7 @@ fn display_constraint_name(name: &'_ Option) -> impl fmt::Display + '_ { /// Used in foreign key constraints in `ON UPDATE` and `ON DELETE` options. #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ReferentialAction { Restrict, Cascade, diff --git a/src/ast/helpers/stmt_create_table.rs b/src/ast/helpers/stmt_create_table.rs index 97c567b83..403d91131 100644 --- a/src/ast/helpers/stmt_create_table.rs +++ b/src/ast/helpers/stmt_create_table.rs @@ -4,6 +4,9 @@ use alloc::{boxed::Box, format, string::String, vec, vec::Vec}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use sqlparser_derive::Visit; + use crate::ast::{ ColumnDef, FileFormat, HiveDistributionStyle, HiveFormat, ObjectName, OnCommit, Query, SqlOption, Statement, TableConstraint, @@ -40,6 +43,7 @@ use crate::parser::ParserError; /// [1]: crate::ast::Statement::CreateTable #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct CreateTableBuilder { pub or_replace: bool, pub temporary: bool, diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7f3d15438..af4a69647 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -22,6 +22,9 @@ use core::fmt; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use sqlparser_derive::Visit; + pub use self::data_type::{ CharLengthUnits, CharacterLength, DataType, ExactNumberInfo, TimezoneInfo, }; @@ -38,6 +41,9 @@ pub use self::query::{ }; pub use self::value::{escape_quoted_string, DateTimeField, TrimWhereField, Value}; +#[cfg(feature = "visitor")] +pub use visitor::*; + mod data_type; mod ddl; pub mod helpers; @@ -45,6 +51,9 @@ mod operator; mod query; mod value; +#[cfg(feature = "visitor")] +mod visitor; + struct DisplaySeparated<'a, T> where T: fmt::Display, @@ -85,6 +94,7 @@ where /// An identifier, decomposed into its value or character data and the quote style. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct Ident { /// The value of the identifier without quotes. pub value: String, @@ -145,6 +155,7 @@ impl fmt::Display for Ident { /// A name of a table, view, custom type, etc., possibly multi-part, i.e. db.schema.obj #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct ObjectName(pub Vec); impl fmt::Display for ObjectName { @@ -153,10 +164,11 @@ impl fmt::Display for ObjectName { } } -#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Represents an Array Expression, either /// `ARRAY[..]`, or `[..]` +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct Array { /// The list of expressions between brackets pub elem: Vec, @@ -179,6 +191,7 @@ impl fmt::Display for Array { /// JsonOperator #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum JsonOperator { /// -> keeps the value as json Arrow, @@ -221,6 +234,7 @@ impl fmt::Display for JsonOperator { /// inappropriate type, like `WHERE 1` or `SELECT 1=1`, as necessary. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit), visit(with = "visit_expr"))] pub enum Expr { /// Identifier e.g. table name or column name Identifier(Ident), @@ -861,6 +875,7 @@ impl fmt::Display for Expr { /// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`) #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct WindowSpec { pub partition_by: Vec, pub order_by: Vec, @@ -906,6 +921,7 @@ impl fmt::Display for WindowSpec { /// reject invalid bounds like `ROWS UNBOUNDED FOLLOWING` before execution. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct WindowFrame { pub units: WindowFrameUnits, pub start_bound: WindowFrameBound, @@ -931,6 +947,7 @@ impl Default for WindowFrame { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum WindowFrameUnits { Rows, Range, @@ -950,6 +967,7 @@ impl fmt::Display for WindowFrameUnits { /// Specifies [WindowFrame]'s `start_bound` and `end_bound` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum WindowFrameBound { /// `CURRENT ROW` CurrentRow, @@ -973,6 +991,7 @@ impl fmt::Display for WindowFrameBound { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum AddDropSync { ADD, DROP, @@ -991,6 +1010,7 @@ impl fmt::Display for AddDropSync { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ShowCreateObject { Event, Function, @@ -1015,6 +1035,7 @@ impl fmt::Display for ShowCreateObject { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CommentObject { Column, Table, @@ -1031,6 +1052,7 @@ impl fmt::Display for CommentObject { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum Password { Password(Expr), NullPassword, @@ -1040,9 +1062,11 @@ pub enum Password { #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit), visit(with = "visit_statement"))] pub enum Statement { /// Analyze (Hive) Analyze { + #[visit(with = "visit_table")] table_name: ObjectName, partitions: Option>, for_columns: bool, @@ -1053,11 +1077,13 @@ pub enum Statement { }, /// Truncate (Hive) Truncate { + #[visit(with = "visit_table")] table_name: ObjectName, partitions: Option>, }, /// Msck (Hive) Msck { + #[visit(with = "visit_table")] table_name: ObjectName, repair: bool, partition_action: Option, @@ -1071,6 +1097,7 @@ pub enum Statement { /// INTO - optional keyword into: bool, /// TABLE + #[visit(with = "visit_table")] table_name: ObjectName, /// COLUMNS columns: Vec, @@ -1098,6 +1125,7 @@ pub enum Statement { }, Copy { /// TABLE + #[visit(with = "visit_table")] table_name: ObjectName, /// COLUMNS columns: Vec, @@ -1160,6 +1188,7 @@ pub enum Statement { global: Option, if_not_exists: bool, /// Table name + #[visit(with = "visit_table")] name: ObjectName, /// Optional schema columns: Vec, @@ -1184,6 +1213,7 @@ pub enum Statement { }, /// SQLite's `CREATE VIRTUAL TABLE .. USING ()` CreateVirtualTable { + #[visit(with = "visit_table")] name: ObjectName, if_not_exists: bool, module_name: Ident, @@ -1193,6 +1223,7 @@ pub enum Statement { CreateIndex { /// index name name: ObjectName, + #[visit(with = "visit_table")] table_name: ObjectName, using: Option, columns: Vec, @@ -1226,6 +1257,7 @@ pub enum Statement { /// ALTER TABLE AlterTable { /// Table name + #[visit(with = "visit_table")] name: ObjectName, operation: AlterTableOperation, }, @@ -1350,6 +1382,7 @@ pub enum Statement { ShowColumns { extended: bool, full: bool, + #[visit(with = "visit_table")] table_name: ObjectName, filter: Option, }, @@ -1465,9 +1498,10 @@ pub enum Statement { /// EXPLAIN TABLE /// Note: this is a MySQL-specific statement. See ExplainTable { - // If true, query used the MySQL `DESCRIBE` alias for explain + /// If true, query used the MySQL `DESCRIBE` alias for explain describe_alias: bool, - // Table name + /// Table name + #[visit(with = "visit_table")] table_name: ObjectName, }, /// EXPLAIN / DESCRIBE for select_statement @@ -1501,19 +1535,21 @@ pub enum Statement { /// CACHE [ FLAG ] TABLE [ OPTIONS('K1' = 'V1', 'K2' = V2) ] [ AS ] [ ] /// Based on Spark SQL,see Cache { - // Table flag + /// Table flag table_flag: Option, - // Table name + /// Table name + #[visit(with = "visit_table")] table_name: ObjectName, has_as: bool, - // Table confs + /// Table confs options: Vec, - // Cache table as a Query + /// Cache table as a Query query: Option, }, /// UNCACHE TABLE [ IF EXISTS ] UNCache { - // Table name + /// Table name + #[visit(with = "visit_table")] table_name: ObjectName, if_exists: bool, }, @@ -2621,6 +2657,7 @@ impl fmt::Display for Statement { /// [ START [ WITH ] start ] [ CACHE cache ] [ [ NO ] CYCLE ] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum SequenceOptions { IncrementBy(Expr, bool), MinValue(MinMaxValue), @@ -2685,6 +2722,7 @@ impl fmt::Display for SequenceOptions { /// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum MinMaxValue { // clause is not specified Empty, @@ -2696,6 +2734,7 @@ pub enum MinMaxValue { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] #[non_exhaustive] pub enum OnInsert { /// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead) @@ -2706,12 +2745,14 @@ pub enum OnInsert { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct OnConflict { pub conflict_target: Vec, pub action: OnConflictAction, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum OnConflictAction { DoNothing, DoUpdate(DoUpdate), @@ -2719,6 +2760,7 @@ pub enum OnConflictAction { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct DoUpdate { /// Column assignments pub assignments: Vec, @@ -2772,6 +2814,7 @@ impl fmt::Display for OnConflictAction { /// Privileges granted in a GRANT statement or revoked in a REVOKE statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum Privileges { /// All privileges applicable to the object type All { @@ -2808,6 +2851,7 @@ impl fmt::Display for Privileges { /// Specific direction for FETCH statement #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum FetchDirection { Count { limit: Value }, Next, @@ -2871,6 +2915,7 @@ impl fmt::Display for FetchDirection { /// A privilege on a database object (table, sequence, etc.). #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum Action { Connect, Create, @@ -2920,6 +2965,7 @@ impl fmt::Display for Action { /// Objects on which privileges are granted in a GRANT statement. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum GrantObjects { /// Grant privileges on `ALL SEQUENCES IN SCHEMA [, ...]` AllSequencesInSchema { schemas: Vec }, @@ -2966,6 +3012,7 @@ impl fmt::Display for GrantObjects { /// SQL assignment `foo = expr` as used in SQLUpdate #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct Assignment { pub id: Vec, pub value: Expr, @@ -2979,6 +3026,7 @@ impl fmt::Display for Assignment { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum FunctionArgExpr { Expr(Expr), /// Qualified wildcard, e.g. `alias.*` or `schema.table.*`. @@ -2999,6 +3047,7 @@ impl fmt::Display for FunctionArgExpr { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum FunctionArg { Named { name: Ident, arg: FunctionArgExpr }, Unnamed(FunctionArgExpr), @@ -3015,6 +3064,7 @@ impl fmt::Display for FunctionArg { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CloseCursor { All, Specific { name: Ident }, @@ -3032,6 +3082,7 @@ impl fmt::Display for CloseCursor { /// A function call #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct Function { pub name: ObjectName, pub args: Vec, @@ -3045,6 +3096,7 @@ pub struct Function { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum AnalyzeFormat { TEXT, GRAPHVIZ, @@ -3086,6 +3138,7 @@ impl fmt::Display for Function { /// External table's available file format #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum FileFormat { TEXTFILE, SEQUENCEFILE, @@ -3115,6 +3168,7 @@ impl fmt::Display for FileFormat { /// [ WITHIN GROUP (ORDER BY [, ...] ) ]` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct ListAgg { pub distinct: bool, pub expr: Box, @@ -3152,6 +3206,7 @@ impl fmt::Display for ListAgg { /// The `ON OVERFLOW` clause of a LISTAGG invocation #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ListAggOnOverflow { /// `ON OVERFLOW ERROR` Error, @@ -3189,6 +3244,7 @@ impl fmt::Display for ListAggOnOverflow { /// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct ArrayAgg { pub distinct: bool, pub expr: Box, @@ -3225,6 +3281,7 @@ impl fmt::Display for ArrayAgg { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ObjectType { Table, View, @@ -3249,6 +3306,7 @@ impl fmt::Display for ObjectType { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum KillType { Connection, Query, @@ -3269,6 +3327,7 @@ impl fmt::Display for KillType { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum HiveDistributionStyle { PARTITIONED { columns: Vec, @@ -3288,14 +3347,15 @@ pub enum HiveDistributionStyle { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum HiveRowFormat { SERDE { class: String }, DELIMITED, } -#[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] #[allow(clippy::large_enum_variant)] pub enum HiveIOFormat { IOF { @@ -3309,6 +3369,7 @@ pub enum HiveIOFormat { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct HiveFormat { pub row_format: Option, pub storage: Option, @@ -3317,6 +3378,7 @@ pub struct HiveFormat { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct SqlOption { pub name: Ident, pub value: Value, @@ -3330,6 +3392,7 @@ impl fmt::Display for SqlOption { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum TransactionMode { AccessMode(TransactionAccessMode), IsolationLevel(TransactionIsolationLevel), @@ -3347,6 +3410,7 @@ impl fmt::Display for TransactionMode { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum TransactionAccessMode { ReadOnly, ReadWrite, @@ -3364,6 +3428,7 @@ impl fmt::Display for TransactionAccessMode { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum TransactionIsolationLevel { ReadUncommitted, ReadCommitted, @@ -3385,6 +3450,7 @@ impl fmt::Display for TransactionIsolationLevel { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ShowStatementFilter { Like(String), ILike(String), @@ -3407,6 +3473,7 @@ impl fmt::Display for ShowStatementFilter { /// https://sqlite.org/lang_conflict.html #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum SqliteOnConflict { Rollback, Abort, @@ -3430,6 +3497,7 @@ impl fmt::Display for SqliteOnConflict { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CopyTarget { Stdin, Stdout, @@ -3461,6 +3529,7 @@ impl fmt::Display for CopyTarget { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum OnCommit { DeleteRows, PreserveRows, @@ -3472,6 +3541,7 @@ pub enum OnCommit { /// #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CopyOption { /// FORMAT format_name Format(Ident), @@ -3525,6 +3595,7 @@ impl fmt::Display for CopyOption { /// #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CopyLegacyOption { /// BINARY Binary, @@ -3553,6 +3624,7 @@ impl fmt::Display for CopyLegacyOption { /// #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CopyLegacyCsvOption { /// HEADER Header, @@ -3584,6 +3656,7 @@ impl fmt::Display for CopyLegacyCsvOption { /// #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum MergeClause { MatchedUpdate { predicate: Option, @@ -3645,6 +3718,7 @@ impl fmt::Display for MergeClause { #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum DiscardObject { ALL, PLANS, @@ -3666,6 +3740,7 @@ impl fmt::Display for DiscardObject { /// Optional context modifier for statements that can be or `LOCAL`, or `SESSION`. #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ContextModifier { /// No context defined. Each dialect defines the default in this scenario. None, @@ -3694,6 +3769,7 @@ impl fmt::Display for ContextModifier { /// Function argument in CREATE FUNCTION. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct CreateFunctionArg { pub mode: Option, pub name: Option, @@ -3742,6 +3818,7 @@ impl fmt::Display for CreateFunctionArg { /// The mode of an argument in CREATE FUNCTION. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum ArgMode { In, Out, @@ -3761,6 +3838,7 @@ impl fmt::Display for ArgMode { /// These attributes inform the query optimizer about the behavior of the function. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum FunctionBehavior { Immutable, Stable, @@ -3780,6 +3858,7 @@ impl fmt::Display for FunctionBehavior { /// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html #[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct CreateFunctionBody { /// LANGUAGE lang_name pub language: Option, @@ -3818,6 +3897,7 @@ impl fmt::Display for CreateFunctionBody { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum CreateFunctionUsing { Jar(String), File(String), @@ -3840,6 +3920,7 @@ impl fmt::Display for CreateFunctionUsing { /// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#schema-definition #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum SchemaName { /// Only schema name specified: ``. Simple(ObjectName), @@ -3870,6 +3951,7 @@ impl fmt::Display for SchemaName { /// [1]: https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html#function_match #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum SearchModifier { /// `IN NATURAL LANGUAGE MODE`. InNaturalLanguageMode, diff --git a/src/ast/operator.rs b/src/ast/operator.rs index f22839474..b8f371be3 100644 --- a/src/ast/operator.rs +++ b/src/ast/operator.rs @@ -14,14 +14,19 @@ use core::fmt; #[cfg(not(feature = "std"))] use alloc::{string::String, vec::Vec}; + #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use sqlparser_derive::Visit; + use super::display_separated; /// Unary operators #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum UnaryOperator { Plus, Minus, @@ -59,6 +64,7 @@ impl fmt::Display for UnaryOperator { /// Binary operators #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum BinaryOperator { Plus, Minus, diff --git a/src/ast/query.rs b/src/ast/query.rs index f813f44dd..0fae84201 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -16,12 +16,16 @@ use alloc::{boxed::Box, vec::Vec}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use sqlparser_derive::Visit; + use crate::ast::*; /// The most complete variant of a `SELECT` query expression, optionally /// including `WITH`, `UNION` / other set operations, and `ORDER BY`. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub struct Query { /// WITH (common table expressions, or CTEs) pub with: Option, @@ -69,6 +73,7 @@ impl fmt::Display for Query { #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit))] pub enum SetExpr { /// Restricted SELECT .. FROM .. HAVING (no ORDER BY or set operations) Select(Box