diff --git a/Cargo.lock b/Cargo.lock index 2ab1c5274c46..8f236e435414 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8216,6 +8216,7 @@ dependencies = [ "common-catalog", "common-telemetry", "common-time", + "sql", ] [[package]] @@ -8487,7 +8488,6 @@ name = "sql" version = "0.2.0" dependencies = [ "api", - "catalog", "common-base", "common-catalog", "common-datasource", diff --git a/src/datanode/src/sql/alter.rs b/src/datanode/src/sql/alter.rs index 88b25c428c7a..fe1c6974efdc 100644 --- a/src/datanode/src/sql/alter.rs +++ b/src/datanode/src/sql/alter.rs @@ -100,7 +100,7 @@ mod tests { use query::parser::{QueryLanguageParser, QueryStatement}; use query::query_engine::SqlStatementExecutor; use session::context::QueryContext; - use sql::dialect::GenericDialect; + use sql::dialect::GreptimeDbDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -108,7 +108,7 @@ mod tests { use crate::tests::test_util::MockInstance; fn parse_sql(sql: &str) -> AlterTable { - let mut stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmt.len()); let stmt = stmt.remove(0); assert_matches!(stmt, Statement::Alter(_)); diff --git a/src/datanode/src/sql/create.rs b/src/datanode/src/sql/create.rs index 104a60bb5c37..682552f8eded 100644 --- a/src/datanode/src/sql/create.rs +++ b/src/datanode/src/sql/create.rs @@ -253,7 +253,7 @@ mod tests { use query::parser::{QueryLanguageParser, QueryStatement}; use query::query_engine::SqlStatementExecutor; use session::context::QueryContext; - use sql::dialect::GenericDialect; + use sql::dialect::GreptimeDbDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -262,7 +262,7 @@ mod tests { use crate::tests::test_util::MockInstance; fn sql_to_statement(sql: &str) -> CreateTable { - let mut res = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut res = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, res.len()); match res.pop().unwrap() { Statement::CreateTable(c) => c, diff --git a/src/frontend/src/expr_factory.rs b/src/frontend/src/expr_factory.rs index f2541f33e525..6f00204c2933 100644 --- a/src/frontend/src/expr_factory.rs +++ b/src/frontend/src/expr_factory.rs @@ -322,7 +322,7 @@ pub(crate) fn to_alter_expr( #[cfg(test)] mod tests { use session::context::QueryContext; - use sql::dialect::GenericDialect; + use sql::dialect::GreptimeDbDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -331,7 +331,7 @@ mod tests { #[test] fn test_create_to_expr() { let sql = "CREATE TABLE monitor (host STRING,ts TIMESTAMP,TIME INDEX (ts),PRIMARY KEY(host)) ENGINE=mito WITH(regions=1, ttl='3days', write_buffer_size='1024KB');"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .pop() .unwrap(); diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 974b9f5fc0be..7d6ce82209c8 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -64,7 +64,7 @@ use servers::query_handler::{ }; use session::context::QueryContextRef; use snafu::prelude::*; -use sql::dialect::GenericDialect; +use sql::dialect::Dialect; use sql::parser::ParserContext; use sql::statements::copy::CopyTable; use sql::statements::statement::Statement; @@ -447,8 +447,8 @@ impl FrontendInstance for Instance { } } -fn parse_stmt(sql: &str) -> Result> { - ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu) +fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result> { + ParserContext::create_with_dialect(sql, dialect).context(ParseSqlSnafu) } impl Instance { @@ -473,7 +473,7 @@ impl SqlQueryHandler for Instance { Err(e) => return vec![Err(e)], }; - match parse_stmt(query.as_ref()) + match parse_stmt(query.as_ref(), query_ctx.sql_dialect()) .and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone())) { Ok(stmts) => { @@ -664,6 +664,7 @@ mod tests { use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema}; use query::query_engine::options::QueryOptions; use session::context::QueryContext; + use sql::dialect::GreptimeDbDialect; use strfmt::Format; use super::*; @@ -748,7 +749,7 @@ mod tests { CREATE DATABASE test_database; SHOW DATABASES; "#; - let stmts = parse_stmt(sql).unwrap(); + let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(stmts.len(), 4); for stmt in stmts { let re = check_permission(plugins.clone(), &stmt, &query_ctx); @@ -759,7 +760,7 @@ mod tests { SHOW CREATE TABLE demo; ALTER TABLE demo ADD COLUMN new_col INT; "#; - let stmts = parse_stmt(sql).unwrap(); + let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(stmts.len(), 2); for stmt in stmts { let re = check_permission(plugins.clone(), &stmt, &query_ctx); @@ -767,7 +768,7 @@ mod tests { } let sql = "USE randomschema"; - let stmts = parse_stmt(sql).unwrap(); + let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap(); let re = check_permission(plugins.clone(), &stmts[0], &query_ctx); assert!(re.is_ok()); @@ -800,7 +801,7 @@ mod tests { } fn do_test(sql: &str, plugins: Arc, query_ctx: &QueryContextRef, is_ok: bool) { - let stmt = &parse_stmt(sql).unwrap()[0]; + let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0]; let re = check_permission(plugins, stmt, query_ctx); if is_ok { assert!(re.is_ok()); @@ -828,12 +829,12 @@ mod tests { // test show tables let sql = "SHOW TABLES FROM public"; - let stmt = parse_stmt(sql).unwrap(); + let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap(); let re = check_permission(plugins.clone(), &stmt[0], &query_ctx); assert!(re.is_ok()); let sql = "SHOW TABLES FROM wrongschema"; - let stmt = parse_stmt(sql).unwrap(); + let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap(); let re = check_permission(plugins.clone(), &stmt[0], &query_ctx); assert!(re.is_err()); diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 8d5cb9289dcf..e497f9a187ca 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -874,7 +874,7 @@ fn find_partition_columns( #[cfg(test)] mod test { use session::context::QueryContext; - use sql::dialect::GenericDialect; + use sql::dialect::GreptimeDbDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -908,7 +908,7 @@ ENGINE=mito", ), ]; for (sql, expected) in cases { - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); match &result[0] { Statement::CreateTable(c) => { let expr = expr_factory::create_to_expr(c, QueryContext::arc()).unwrap(); diff --git a/src/query/src/parser.rs b/src/query/src/parser.rs index 05eefa14da87..14de3e026ead 100644 --- a/src/query/src/parser.rs +++ b/src/query/src/parser.rs @@ -26,7 +26,7 @@ use promql_parser::parser::ast::{Extension as NodeExtension, ExtensionExpr}; use promql_parser::parser::Expr::Extension; use promql_parser::parser::{EvalStmt, Expr, ValueType}; use snafu::ResultExt; -use sql::dialect::GenericDialect; +use sql::dialect::GreptimeDbDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -108,7 +108,7 @@ pub struct QueryLanguageParser {} impl QueryLanguageParser { pub fn parse_sql(sql: &str) -> Result { let _timer = timer!(METRIC_PARSE_SQL_ELAPSED); - let mut statement = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let mut statement = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .map_err(BoxedError::new) .context(QueryParseSnafu { query: sql.to_string(), diff --git a/src/query/src/sql/show.rs b/src/query/src/sql/show.rs index 54e8c20cd40e..51160e350bb8 100644 --- a/src/query/src/sql/show.rs +++ b/src/query/src/sql/show.rs @@ -20,7 +20,7 @@ use sql::ast::{ ColumnDef, ColumnOption, ColumnOptionDef, Expr, ObjectName, SqlOption, TableConstraint, Value as SqlValue, }; -use sql::dialect::GenericDialect; +use sql::dialect::GreptimeDbDialect; use sql::parser::ParserContext; use sql::statements::create::{CreateTable, TIME_INDEX}; use sql::statements::{self}; @@ -108,7 +108,7 @@ fn create_column_def(column_schema: &ColumnSchema) -> Result { .with_context(|_| ConvertSqlValueSnafu { value: v.clone() })?, ), ColumnDefaultConstraint::Function(expr) => { - ParserContext::parse_function(expr, &GenericDialect {}).context(SqlSnafu)? + ParserContext::parse_function(expr, &GreptimeDbDialect {}).context(SqlSnafu)? } }; diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 40f6d0801a17..6e40f976b930 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -32,9 +32,9 @@ use opensrv_mysql::{ use parking_lot::RwLock; use rand::RngCore; use session::context::Channel; -use session::Session; +use session::{Session, SessionRef}; use snafu::ensure; -use sql::dialect::GenericDialect; +use sql::dialect::MySqlDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; use tokio::io::AsyncWrite; @@ -48,7 +48,7 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef; pub struct MysqlInstanceShim { query_handler: ServerSqlQueryHandlerRef, salt: [u8; 20], - session: Arc, + session: SessionRef, user_provider: Option, // TODO(SSebo): use something like moka to achieve TTL or LRU prepared_stmts: Arc>>, @@ -77,7 +77,7 @@ impl MysqlInstanceShim { MysqlInstanceShim { query_handler, salt: scramble, - session: Arc::new(Session::new(client_addr, Channel::Mysql)), + session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)), user_provider, prepared_stmts: Default::default(), prepared_stmts_counter: AtomicU32::new(1), @@ -140,9 +140,13 @@ impl AsyncMysqlShim for MysqlInstanceShi let username = String::from_utf8_lossy(username); let mut user_info = None; - let addr = self.session.conn_info().client_host.to_string(); + let addr = self + .session + .conn_info() + .client_addr + .map(|addr| addr.to_string()); if let Some(user_provider) = &self.user_provider { - let user_id = Identity::UserId(&username, Some(addr.as_str())); + let user_id = Identity::UserId(&username, addr.as_deref()); let password = match auth_plugin { "mysql_native_password" => Password::MysqlNativePassword(auth_data, salt), @@ -331,7 +335,7 @@ fn format_duration(duration: Duration) -> String { } async fn validate_query(query: &str) -> Result { - let statement = ParserContext::create_with_dialect(query, &GenericDialect {}); + let statement = ParserContext::create_with_dialect(query, &MySqlDialect {}); let mut statement = statement.map_err(|e| { InvalidPrepareStatementSnafu { err_msg: e.to_string(), diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index e4f271dffeca..c13b986d621e 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -31,7 +31,8 @@ use pgwire::api::auth::ServerParameterProvider; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, MakeHandler}; pub use server::PostgresServer; -use session::context::{QueryContext, QueryContextRef}; +use session::context::Channel; +use session::Session; use sql::statements::statement::Statement; use self::auth_handler::PgLoginVerifier; @@ -73,7 +74,7 @@ pub struct PostgresServerHandler { force_tls: bool, param_provider: Arc, - query_ctx: QueryContextRef, + session: Session, portal_store: Arc>, query_parser: Arc, } @@ -90,18 +91,18 @@ pub(crate) struct MakePostgresServerHandler { } impl MakeHandler for MakePostgresServerHandler { - type Handler = Arc; + type Handler = PostgresServerHandler; fn make(&self) -> Self::Handler { - Arc::new(PostgresServerHandler { + PostgresServerHandler { query_handler: self.query_handler.clone(), login_verifier: PgLoginVerifier::new(self.user_provider.clone()), force_tls: self.force_tls, param_provider: self.param_provider.clone(), - query_ctx: QueryContext::arc(), + session: Session::new(None, Channel::Postgres), portal_store: Arc::new(MemPortalStore::new()), query_parser: self.query_parser.clone(), - }) + } } } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 5613304516e9..688d47ac53cb 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -24,7 +24,8 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::response::ErrorResponse; use pgwire::messages::startup::Authentication; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; -use session::context::QueryContextRef; +use session::context::UserInfo; +use session::Session; use super::PostgresServerHandler; use crate::auth::{Identity, Password, UserProviderRef}; @@ -112,15 +113,19 @@ impl PgLoginVerifier { } } -fn set_query_context_from_client_info(client: &C, query_context: QueryContextRef) +fn set_client_info(client: &C, session: &Session) where C: ClientInfo, { + let ctx = session.context(); if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) { - query_context.set_current_catalog(current_catalog); + ctx.set_current_catalog(current_catalog); } if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) { - query_context.set_current_schema(current_schema); + ctx.set_current_schema(current_schema); + } + if let Some(username) = client.metadata().get(super::METADATA_USER) { + session.set_user_info(UserInfo::new(username)); } } @@ -170,7 +175,7 @@ impl StartupHandler for PostgresServerHandler { )) .await?; } else { - set_query_context_from_client_info(client, self.query_ctx.clone()); + set_client_info(client, &self.session); auth::finish_authentication(client, self.param_provider.as_ref()).await; } } @@ -193,7 +198,7 @@ impl StartupHandler for PostgresServerHandler { ) .await; } - set_query_context_from_client_info(client, self.query_ctx.clone()); + set_client_info(client, &self.session); auth::finish_authentication(client, self.param_provider.as_ref()).await; } _ => {} diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 1d4112e449c1..583e9884743f 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -33,7 +33,7 @@ use pgwire::api::stmt::QueryParser; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; -use sql::dialect::GenericDialect; +use sql::dialect::PostgreSqlDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -55,13 +55,13 @@ impl SimpleQueryHandler for PostgresServerHandler { ), ( crate::metrics::METRIC_DB_LABEL, - self.query_ctx.get_db_string() + self.session.context().get_db_string() ) ] ); let outputs = self .query_handler - .do_query(query, self.query_ctx.clone()) + .do_query(query, self.session.context()) .await; let mut results = Vec::with_capacity(outputs.len()); @@ -260,7 +260,7 @@ impl QueryParser for POCQueryParser { fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult { increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT); - let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; if stmts.len() != 1 { Err(PgWireError::UserError(Box::new(ErrorInfo::new( @@ -361,7 +361,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { ), ( crate::metrics::METRIC_DB_LABEL, - self.query_ctx.get_db_string() + self.session.context().get_db_string() ) ] ); @@ -376,7 +376,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { let output = self .query_handler - .do_query(&sql, self.query_ctx.clone()) + .do_query(&sql, self.session.context()) .await .remove(0); @@ -407,7 +407,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { if let Some(schema) = self .query_handler - .do_describe(stmt.clone(), self.query_ctx.clone()) + .do_describe(stmt.clone(), self.session.context()) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))? { diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index f2efc8093dce..106ad24e4164 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -73,19 +73,22 @@ impl PostgresServer { accepting_stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); let tls_acceptor = tls_acceptor.clone(); - let handler = handler.make(); - + let mut handler = handler.make(); async move { match tcp_stream { Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt. Ok(io_stream) => { match io_stream.peer_addr() { - Ok(addr) => debug!("PostgreSQL client coming from {}", addr), + Ok(addr) => { + handler.session.mut_conn_info().client_addr = Some(addr); + debug!("PostgreSQL client coming from {}", addr) + } Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e), } io_runtime.spawn(async move { increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0); + let handler = Arc::new(handler); let r = process_socket( io_stream, tls_acceptor.clone(), diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index 06224ac8ef76..bdfc4bc7adc6 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -9,3 +9,4 @@ arc-swap = "1.5" common-catalog = { path = "../common/catalog" } common-telemetry = { path = "../common/telemetry" } common-time = { path = "../common/time" } +sql = { path = "../sql" } diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 5adbe84ae9e1..1a9f38d2d27b 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -21,6 +21,7 @@ use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_telemetry::debug; use common_time::TimeZone; +use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; pub type QueryContextRef = Arc; pub type ConnInfoRef = Arc; @@ -30,6 +31,7 @@ pub struct QueryContext { current_catalog: ArcSwap, current_schema: ArcSwap, time_zone: ArcSwap>, + sql_dialect: Box, } impl Default for QueryContext { @@ -59,25 +61,42 @@ impl QueryContext { current_catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.to_string())), current_schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.to_string())), time_zone: ArcSwap::new(Arc::new(None)), + sql_dialect: Box::new(GreptimeDbDialect {}), } } pub fn with(catalog: &str, schema: &str) -> Self { + Self::with_sql_dialect(catalog, schema, Box::new(GreptimeDbDialect {})) + } + + pub fn with_sql_dialect( + catalog: &str, + schema: &str, + sql_dialect: Box, + ) -> Self { Self { current_catalog: ArcSwap::new(Arc::new(catalog.to_string())), current_schema: ArcSwap::new(Arc::new(schema.to_string())), time_zone: ArcSwap::new(Arc::new(None)), + sql_dialect, } } + #[inline] pub fn current_schema(&self) -> String { self.current_schema.load().as_ref().clone() } + #[inline] pub fn current_catalog(&self) -> String { self.current_catalog.load().as_ref().clone() } + #[inline] + pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) { + &*self.sql_dialect + } + pub fn set_current_schema(&self, schema: &str) { let last = self.current_schema.swap(Arc::new(schema.to_string())); if schema != last.as_str() { @@ -142,15 +161,30 @@ impl UserInfo { } } +#[derive(Debug)] pub struct ConnInfo { - pub client_host: SocketAddr, + pub client_addr: Option, pub channel: Channel, } +impl std::fmt::Display for ConnInfo { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "{}[{}]", + self.channel, + self.client_addr + .map(|addr| addr.to_string()) + .as_deref() + .unwrap_or("unknown client addr") + ) + } +} + impl ConnInfo { - pub fn new(client_host: SocketAddr, channel: Channel) -> Self { + pub fn new(client_addr: Option, channel: Channel) -> Self { Self { - client_host, + client_addr, channel, } } @@ -158,13 +192,26 @@ impl ConnInfo { #[derive(Debug, PartialEq)] pub enum Channel { - Grpc, - Http, Mysql, Postgres, - Opentsdb, - Influxdb, - Prometheus, +} + +impl Channel { + pub fn dialect(&self) -> Box { + match self { + Channel::Mysql => Box::new(MySqlDialect {}), + Channel::Postgres => Box::new(PostgreSqlDialect {}), + } + } +} + +impl std::fmt::Display for Channel { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Channel::Mysql => write!(f, "mysql"), + Channel::Postgres => write!(f, "postgres"), + } + } } #[cfg(test)] @@ -175,7 +222,7 @@ mod test { #[test] fn test_session() { - let session = Session::new("127.0.0.1:9000".parse().unwrap(), Channel::Mysql); + let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql); // test user_info assert_eq!(session.user_info().username(), "greptime"); session.set_user_info(UserInfo::new("root")); @@ -183,11 +230,11 @@ mod test { // test channel assert_eq!(session.conn_info().channel, Channel::Mysql); - assert_eq!( - session.conn_info().client_host.ip().to_string(), - "127.0.0.1" - ); - assert_eq!(session.conn_info().client_host.port(), 9000); + let client_addr = session.conn_info().client_addr.as_ref().unwrap(); + assert_eq!(client_addr.ip().to_string(), "127.0.0.1"); + assert_eq!(client_addr.port(), 9000); + + assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string()); } #[test] diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 875c6769888b..12b663e21aa4 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -18,33 +18,54 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwap; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use crate::context::{Channel, ConnInfo, ConnInfoRef, QueryContext, QueryContextRef, UserInfo}; +use crate::context::{Channel, ConnInfo, QueryContext, QueryContextRef, UserInfo}; +/// Session for persistent connection such as MySQL, PostgreSQL etc. +#[derive(Debug)] pub struct Session { query_ctx: QueryContextRef, user_info: ArcSwap, - conn_info: ConnInfoRef, + conn_info: ConnInfo, } +pub type SessionRef = Arc; + impl Session { - pub fn new(addr: SocketAddr, channel: Channel) -> Self { + pub fn new(addr: Option, channel: Channel) -> Self { Session { - query_ctx: Arc::new(QueryContext::new()), + query_ctx: Arc::new(QueryContext::with_sql_dialect( + DEFAULT_CATALOG_NAME, + DEFAULT_SCHEMA_NAME, + channel.dialect(), + )), user_info: ArcSwap::new(Arc::new(UserInfo::default())), - conn_info: Arc::new(ConnInfo::new(addr, channel)), + conn_info: ConnInfo::new(addr, channel), } } + #[inline] pub fn context(&self) -> QueryContextRef { self.query_ctx.clone() } - pub fn conn_info(&self) -> ConnInfoRef { - self.conn_info.clone() + + #[inline] + pub fn conn_info(&self) -> &ConnInfo { + &self.conn_info + } + + #[inline] + pub fn mut_conn_info(&mut self) -> &mut ConnInfo { + &mut self.conn_info } + + #[inline] pub fn user_info(&self) -> Arc { self.user_info.load().clone() } + + #[inline] pub fn set_user_info(&self, user_info: UserInfo) { self.user_info.store(Arc::new(user_info)); } diff --git a/src/sql/Cargo.toml b/src/sql/Cargo.toml index 9ccadcbc3739..669471d55831 100644 --- a/src/sql/Cargo.toml +++ b/src/sql/Cargo.toml @@ -6,7 +6,6 @@ license.workspace = true [dependencies] api = { path = "../api" } -catalog = { path = "../catalog" } common-base = { path = "../common/base" } common-catalog = { path = "../common/catalog" } common-datasource = { path = "../common/datasource" } diff --git a/src/sql/src/dialect.rs b/src/sql/src/dialect.rs index 078b25d849e5..5060444d0bf8 100644 --- a/src/sql/src/dialect.rs +++ b/src/sql/src/dialect.rs @@ -12,6 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -// todo(hl) wrap sqlparser dialects +pub use sqlparser::dialect::{Dialect, MySqlDialect, PostgreSqlDialect}; -pub use sqlparser::dialect::{Dialect, GenericDialect}; +/// GreptimeDb dialect +#[derive(Debug, Clone)] +pub struct GreptimeDbDialect {} + +impl Dialect for GreptimeDbDialect { + fn is_identifier_start(&self, ch: char) -> bool { + ch.is_alphabetic() || ch == '_' || ch == '#' || ch == '@' + } + + fn is_identifier_part(&self, ch: char) -> bool { + ch.is_alphabetic() + || ch.is_ascii_digit() + || ch == '@' + || ch == '$' + || ch == '#' + || ch == '_' + } + + // Accepts both `identifier` and "identifier". + fn is_delimited_identifier_start(&self, ch: char) -> bool { + ch == '`' || ch == '"' + } + + fn supports_filter_during_aggregation(&self) -> bool { + true + } +} diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 5c9099707976..a6c350144074 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -393,16 +393,16 @@ mod tests { use sqlparser::ast::{ Ident, ObjectName, Query as SpQuery, Statement as SpStatement, WildcardAdditionalOptions, }; - use sqlparser::dialect::GenericDialect; use super::*; + use crate::dialect::GreptimeDbDialect; use crate::statements::create::CreateTable; use crate::statements::sql_data_type_to_concrete_data_type; #[test] pub fn test_show_database_all() { let sql = "SHOW DATABASES"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -417,7 +417,7 @@ mod tests { #[test] pub fn test_show_database_like() { let sql = "SHOW DATABASES LIKE test_database"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -435,7 +435,7 @@ mod tests { #[test] pub fn test_show_database_where() { let sql = "SHOW DATABASES WHERE Database LIKE '%whatever1%' OR Database LIKE '%whatever2%'"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -454,7 +454,7 @@ mod tests { #[test] pub fn test_show_tables_all() { let sql = "SHOW TABLES"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -470,7 +470,7 @@ mod tests { #[test] pub fn test_show_tables_like() { let sql = "SHOW TABLES LIKE test_table"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -486,7 +486,7 @@ mod tests { ); let sql = "SHOW TABLES in test_db LIKE test_table"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -505,7 +505,7 @@ mod tests { #[test] pub fn test_show_tables_where() { let sql = "SHOW TABLES where name like test_table"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -518,7 +518,7 @@ mod tests { ); let sql = "SHOW TABLES in test_db where name LIKE test_table"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -534,7 +534,7 @@ mod tests { #[test] pub fn test_explain() { let sql = "EXPLAIN select * from foo"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let stmts = result.unwrap(); assert_eq!(1, stmts.len()); @@ -589,7 +589,7 @@ mod tests { #[test] pub fn test_drop_table() { let sql = "DROP TABLE foo"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), @@ -597,7 +597,7 @@ mod tests { ); let sql = "DROP TABLE my_schema.foo"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), @@ -608,7 +608,7 @@ mod tests { ); let sql = "DROP TABLE my_catalog.my_schema.foo"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), @@ -621,7 +621,7 @@ mod tests { } fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) { - match ParserContext::create_with_dialect(sql, &GenericDialect {}) + match ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .pop() .unwrap() @@ -673,7 +673,7 @@ mod tests { #[test] fn test_parse_function() { let expr = - ParserContext::parse_function("current_timestamp()", &GenericDialect {}).unwrap(); + ParserContext::parse_function("current_timestamp()", &GreptimeDbDialect {}).unwrap(); assert!(matches!(expr, Expr::Function(_))); } } diff --git a/src/sql/src/parsers/alter_parser.rs b/src/sql/src/parsers/alter_parser.rs index cbd752d3874c..d71e9772c833 100644 --- a/src/sql/src/parsers/alter_parser.rs +++ b/src/sql/src/parsers/alter_parser.rs @@ -79,14 +79,14 @@ mod tests { use std::assert_matches::assert_matches; use sqlparser::ast::{ColumnOption, DataType}; - use sqlparser::dialect::GenericDialect; use super::*; + use crate::dialect::GreptimeDbDialect; #[test] fn test_parse_alter_add_column() { let sql = "ALTER TABLE my_metric_1 ADD tagk_i STRING Null;"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -116,13 +116,13 @@ mod tests { #[test] fn test_parse_alter_drop_column() { let sql = "ALTER TABLE my_metric_1 DROP a"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err(); assert!(result .to_string() .contains("expect keyword COLUMN after ALTER TABLE DROP")); let sql = "ALTER TABLE my_metric_1 DROP COLUMN a"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -147,13 +147,13 @@ mod tests { #[test] fn test_parse_alter_rename_table() { let sql = "ALTER TABLE test_table table_t"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err(); assert!(result .to_string() .contains("expect keyword ADD or DROP or RENAME after ALTER TABLE")); let sql = "ALTER TABLE test_table RENAME table_t"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); diff --git a/src/sql/src/parsers/copy_parser.rs b/src/sql/src/parsers/copy_parser.rs index 2b43df6eb062..0e56f90757e3 100644 --- a/src/sql/src/parsers/copy_parser.rs +++ b/src/sql/src/parsers/copy_parser.rs @@ -139,16 +139,15 @@ mod tests { use std::assert_matches::assert_matches; use std::collections::HashMap; - use sqlparser::dialect::GenericDialect; - use super::*; + use crate::dialect::GreptimeDbDialect; #[test] fn test_parse_copy_table() { let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'"; let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')"; - let result0 = ParserContext::create_with_dialect(sql0, &GenericDialect {}).unwrap(); - let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap(); + let result0 = ParserContext::create_with_dialect(sql0, &GreptimeDbDialect {}).unwrap(); + let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap(); for mut result in vec![result0, result1] { assert_eq!(1, result.len()); @@ -190,7 +189,7 @@ mod tests { "COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')", ] .iter() - .map(|sql| ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap()) + .map(|sql| ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap()) .collect::>(); for mut result in results { @@ -249,7 +248,7 @@ mod tests { for test in tests { let mut result = - ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(test.sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -290,7 +289,7 @@ mod tests { for test in tests { let mut result = - ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(test.sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); diff --git a/src/sql/src/parsers/create_parser.rs b/src/sql/src/parsers/create_parser.rs index 9c4fc8a90686..3daec82d6352 100644 --- a/src/sql/src/parsers/create_parser.rs +++ b/src/sql/src/parsers/create_parser.rs @@ -784,9 +784,9 @@ mod tests { use common_catalog::consts::IMMUTABLE_FILE_ENGINE; use sqlparser::ast::ColumnOption::NotNull; - use sqlparser::dialect::GenericDialect; use super::*; + use crate::dialect::GreptimeDbDialect; #[test] fn test_parse_create_external_table() { @@ -822,7 +822,8 @@ mod tests { ]; for test in tests { - let stmts = ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap(); + let stmts = + ParserContext::create_with_dialect(test.sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); match &stmts[0] { Statement::CreateExternalTable(c) => { @@ -852,7 +853,7 @@ mod tests { ("format".to_string(), "csv".to_string()), ]); - let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); match &stmts[0] { Statement::CreateExternalTable(c) => { @@ -888,14 +889,14 @@ mod tests { #[test] fn test_parse_create_database() { let sql = "create database"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() .contains("Unexpected token while parsing SQL statement")); let sql = "create database prometheus"; - let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); match &stmts[0] { @@ -907,7 +908,7 @@ mod tests { } let sql = "create database if not exists prometheus"; - let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); match &stmts[0] { @@ -929,7 +930,7 @@ PARTITION BY RANGE COLUMNS(b, a) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_ok()); let sql = r" @@ -940,7 +941,7 @@ PARTITION BY RANGE COLUMNS(b, x) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -955,7 +956,7 @@ PARTITION BY RANGE COLUMNS(b, a) ( PARTITION r1 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -969,7 +970,7 @@ PARTITION BY RANGE COLUMNS(b, a) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -1010,7 +1011,7 @@ PARTITION BY RANGE COLUMNS(b, a) ( ENGINE=mito", ]; for sql in cases { - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -1025,7 +1026,7 @@ PARTITION BY RANGE COLUMNS(b, a) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, 9999), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -1051,7 +1052,7 @@ PARTITION BY RANGE COLUMNS(idc, host_id) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(result.len(), 1); match &result[0] { Statement::CreateTable(c) => { @@ -1117,7 +1118,7 @@ CREATE TABLE monitor ( PRIMARY KEY (host), ) ENGINE=mito"; - let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap(); + let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap(); if let Statement::CreateTable(c) = &result1[0] { assert_eq!(c.constraints.len(), 2); @@ -1152,7 +1153,7 @@ CREATE TABLE monitor ( PRIMARY KEY (host), ) ENGINE=mito"; - let result2 = ParserContext::create_with_dialect(sql2, &GenericDialect {}).unwrap(); + let result2 = ParserContext::create_with_dialect(sql2, &GreptimeDbDialect {}).unwrap(); assert_eq!(result1, result2); @@ -1169,7 +1170,7 @@ CREATE TABLE monitor ( ) ENGINE=mito"; - let result3 = ParserContext::create_with_dialect(sql3, &GenericDialect {}).unwrap(); + let result3 = ParserContext::create_with_dialect(sql3, &GreptimeDbDialect {}).unwrap(); assert_ne!(result1, result3); @@ -1184,7 +1185,7 @@ CREATE TABLE monitor ( PRIMARY KEY (host), ) ENGINE=mito"; - let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap(); + let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap(); if let Statement::CreateTable(c) = &result1[0] { assert_eq!(c.constraints.len(), 2); @@ -1220,7 +1221,7 @@ CREATE TABLE monitor ( PRIMARY KEY (host), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(result.len(), 1); if let Statement::CreateTable(c) = &result[0] { @@ -1243,7 +1244,7 @@ CREATE TABLE monitor ( ) ENGINE=mito"; - let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap(); + let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap(); assert_eq!(result, result1); let sql2 = r" @@ -1258,7 +1259,7 @@ CREATE TABLE monitor ( ) ENGINE=mito"; - let result2 = ParserContext::create_with_dialect(sql2, &GenericDialect {}).unwrap(); + let result2 = ParserContext::create_with_dialect(sql2, &GreptimeDbDialect {}).unwrap(); assert_eq!(result, result2); let sql3 = r" @@ -1273,7 +1274,7 @@ CREATE TABLE monitor ( ) ENGINE=mito"; - let result3 = ParserContext::create_with_dialect(sql3, &GenericDialect {}); + let result3 = ParserContext::create_with_dialect(sql3, &GreptimeDbDialect {}); assert!(result3.is_err()); let sql4 = r" @@ -1288,7 +1289,7 @@ CREATE TABLE monitor ( ) ENGINE=mito"; - let result4 = ParserContext::create_with_dialect(sql4, &GenericDialect {}); + let result4 = ParserContext::create_with_dialect(sql4, &GreptimeDbDialect {}); assert!(result4.is_err()); let sql = r" @@ -1303,7 +1304,7 @@ CREATE TABLE monitor ( ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); if let Statement::CreateTable(c) = &result[0] { let tc = c.constraints[0].clone(); @@ -1339,7 +1340,7 @@ PARTITION RANGE COLUMNS(b, a) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -1353,7 +1354,7 @@ PARTITION BY RANGE COLUMNS(b, a) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -1367,11 +1368,11 @@ PARTITION BY RANGE COLUMNS(b, a) ( PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALU), ) ENGINE=mito"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() - .contains("Please provide an extra partition that is bounded by 'MAXVALUE'.")); + .contains("Expected a concrete value, found: MAXVALU")); } fn assert_column_def(column: &ColumnDef, name: &str, data_type: &str) { @@ -1390,7 +1391,7 @@ ENGINE=mito"; PRIMARY KEY(ts, host)) engine=mito with(regions=1); "; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); match &result[0] { Statement::CreateTable(c) => { @@ -1438,7 +1439,7 @@ ENGINE=mito"; PRIMARY KEY(ts, host)) engine=mito with(regions=1); "; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_err()); assert_matches!(result, Err(crate::error::Error::InvalidTimeIndex { .. })); } @@ -1455,7 +1456,7 @@ ENGINE=mito"; PRIMARY KEY(ts, host)) engine=mito with(regions=1); "; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_err()); assert_matches!(result, Err(crate::error::Error::InvalidColumnOption { .. })); @@ -1469,7 +1470,7 @@ ENGINE=mito"; PRIMARY KEY(ts, host)) engine=mito with(regions=1); "; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_err()); assert_matches!(result, Err(crate::error::Error::InvalidTimeIndex { .. })); } @@ -1477,7 +1478,7 @@ ENGINE=mito"; #[test] fn test_invalid_column_name() { let sql = "create table foo(user string, i bigint time index)"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result .unwrap_err() .to_string() @@ -1487,7 +1488,7 @@ ENGINE=mito"; let sql = r#" create table foo("user" string, i bigint time index) "#; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_ok()); } } diff --git a/src/sql/src/parsers/delete_parser.rs b/src/sql/src/parsers/delete_parser.rs index 1d52f1839fee..e48f3df61ab0 100644 --- a/src/sql/src/parsers/delete_parser.rs +++ b/src/sql/src/parsers/delete_parser.rs @@ -46,14 +46,13 @@ impl<'a> ParserContext<'a> { mod tests { use std::assert_matches::assert_matches; - use sqlparser::dialect::GenericDialect; - use super::*; + use crate::dialect::GreptimeDbDialect; #[test] pub fn test_parse_insert() { let sql = r"delete from my_table where k1 = xxx and k2 = xxx and timestamp = xxx;"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); assert_matches!(result[0], Statement::Delete { .. }) } @@ -61,7 +60,7 @@ mod tests { #[test] pub fn test_parse_invalid_insert() { let sql = r"delete my_table where "; // intentionally a bad sql - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_err(), "result is: {result:?}"); } } diff --git a/src/sql/src/parsers/insert_parser.rs b/src/sql/src/parsers/insert_parser.rs index 3c40389b0cfc..6f035a72187c 100644 --- a/src/sql/src/parsers/insert_parser.rs +++ b/src/sql/src/parsers/insert_parser.rs @@ -46,9 +46,8 @@ impl<'a> ParserContext<'a> { mod tests { use std::assert_matches::assert_matches; - use sqlparser::dialect::GenericDialect; - use super::*; + use crate::dialect::GreptimeDbDialect; #[test] pub fn test_parse_insert() { @@ -56,7 +55,7 @@ mod tests { 'test1',1,'true', 'test2',2,'false') "; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); assert_matches!(result[0], Statement::Insert { .. }) } @@ -64,7 +63,7 @@ mod tests { #[test] pub fn test_parse_invalid_insert() { let sql = r"INSERT INTO table_1 VALUES ("; // intentionally a bad sql - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_err(), "result is: {result:?}"); } } diff --git a/src/sql/src/parsers/query_parser.rs b/src/sql/src/parsers/query_parser.rs index 65b61c5638b8..82224c6cc5ba 100644 --- a/src/sql/src/parsers/query_parser.rs +++ b/src/sql/src/parsers/query_parser.rs @@ -33,8 +33,7 @@ impl<'a> ParserContext<'a> { #[cfg(test)] mod tests { - use sqlparser::dialect::GenericDialect; - + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; #[test] @@ -44,13 +43,13 @@ mod tests { WHERE a > b AND b < 100 \ ORDER BY a DESC, b"; - let _ = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let _ = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); } #[test] pub fn test_parse_invalid_query() { let sql = "SELECT * FROM table_1 WHERE"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}); assert!(result.is_err()); assert!(result .unwrap_err() diff --git a/src/sql/src/parsers/tql_parser.rs b/src/sql/src/parsers/tql_parser.rs index 9ef6d97ea4ba..5a1e9b9afeae 100644 --- a/src/sql/src/parsers/tql_parser.rs +++ b/src/sql/src/parsers/tql_parser.rs @@ -166,14 +166,13 @@ impl<'a> ParserContext<'a> { #[cfg(test)] mod tests { - use sqlparser::dialect::GenericDialect; - use super::*; + use crate::dialect::GreptimeDbDialect; #[test] fn test_parse_tql_eval() { let sql = "TQL EVAL (1676887657, 1676887659, '1m') http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -189,7 +188,7 @@ mod tests { let sql = "TQL EVAL (1676887657.1, 1676887659.5, 30.3) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -205,7 +204,7 @@ mod tests { let sql = "TQL EVALUATE (1676887657.1, 1676887659.5, 30.3) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement2 = result.remove(0); @@ -213,7 +212,7 @@ mod tests { let sql = "tql eval ('2015-07-01T20:10:30.781Z', '2015-07-01T20:11:00.781Z', '30s') http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -232,7 +231,7 @@ mod tests { fn test_parse_tql_explain() { let sql = "TQL EXPLAIN http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -248,7 +247,7 @@ mod tests { let sql = "TQL EXPLAIN (20,100,10) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); @@ -266,7 +265,7 @@ mod tests { #[test] fn test_parse_tql_analyze() { let sql = "TQL ANALYZE (1676887657.1, 1676887659.5, 30.3) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); let statement = result.remove(0); match statement { @@ -284,12 +283,12 @@ mod tests { fn test_parse_tql_error() { // Invalid duration let sql = "TQL EVAL (1676887657, 1676887659, 1m) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err(); assert!(result.to_string().contains("Expected ), found: m")); // missing end let sql = "TQL EVAL (1676887657, '1m') http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m"; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err(); assert!(result.to_string().contains("Expected ,, found: )")); } } diff --git a/src/sql/src/statements/create.rs b/src/sql/src/statements/create.rs index 6e4c5b02549d..4fa987a9ff85 100644 --- a/src/sql/src/statements/create.rs +++ b/src/sql/src/statements/create.rs @@ -206,8 +206,7 @@ pub struct CreateExternalTable { #[cfg(test)] mod tests { - use sqlparser::dialect::GenericDialect; - + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; use crate::statements::statement::Statement; @@ -229,7 +228,7 @@ mod tests { engine=mito with(regions=1, ttl='7d'); "; - let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, result.len()); match &result[0] { @@ -259,7 +258,7 @@ WITH( ); let new_result = - ParserContext::create_with_dialect(&new_sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(&new_sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(result, new_result); } _ => unreachable!(), diff --git a/src/sql/src/statements/describe.rs b/src/sql/src/statements/describe.rs index f435f628a9fc..5bc5b4fb0632 100644 --- a/src/sql/src/statements/describe.rs +++ b/src/sql/src/statements/describe.rs @@ -35,8 +35,7 @@ impl DescribeTable { mod tests { use std::assert_matches::assert_matches; - use sqlparser::dialect::GenericDialect; - + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; use crate::statements::statement::Statement; @@ -44,7 +43,7 @@ mod tests { pub fn test_describe_table() { let sql = "DESCRIBE TABLE test"; let stmts: Vec = - ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); assert_matches!(&stmts[0], Statement::DescribeTable { .. }); match &stmts[0] { @@ -61,7 +60,7 @@ mod tests { pub fn test_describe_schema_table() { let sql = "DESCRIBE TABLE test_schema.test"; let stmts: Vec = - ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); assert_matches!(&stmts[0], Statement::DescribeTable { .. }); match &stmts[0] { @@ -78,7 +77,7 @@ mod tests { pub fn test_describe_catalog_schema_table() { let sql = "DESCRIBE TABLE test_catalog.test_schema.test"; let stmts: Vec = - ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); assert_matches!(&stmts[0], Statement::DescribeTable { .. }); match &stmts[0] { @@ -94,6 +93,6 @@ mod tests { #[test] pub fn test_describe_missing_table_name() { let sql = "DESCRIBE TABLE"; - ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err(); + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err(); } } diff --git a/src/sql/src/statements/insert.rs b/src/sql/src/statements/insert.rs index dd42db11201f..f8d69d469732 100644 --- a/src/sql/src/statements/insert.rs +++ b/src/sql/src/statements/insert.rs @@ -136,9 +136,8 @@ impl TryFrom for Insert { #[cfg(test)] mod tests { - use sqlparser::dialect::GenericDialect; - use super::*; + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; use crate::statements::statement::Statement; @@ -146,7 +145,7 @@ mod tests { fn test_insert_value_with_unary_op() { // insert "-1" let sql = "INSERT INTO my_table VALUES(-1)"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0); match stmt { @@ -159,7 +158,7 @@ mod tests { // insert "+1" let sql = "INSERT INTO my_table VALUES(+1)"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0); match stmt { @@ -175,7 +174,7 @@ mod tests { fn test_insert_value_with_default() { // insert "default" let sql = "INSERT INTO my_table VALUES(default)"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0); match stmt { @@ -191,7 +190,7 @@ mod tests { fn test_insert_value_with_default_uppercase() { // insert "DEFAULT" let sql = "INSERT INTO my_table VALUES(DEFAULT)"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0); match stmt { @@ -207,7 +206,7 @@ mod tests { fn test_insert_value_with_quoted_string() { // insert "'default'" let sql = "INSERT INTO my_table VALUES('default')"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0); match stmt { @@ -225,7 +224,7 @@ mod tests { #[test] fn test_insert_select() { let sql = "INSERT INTO my_table select * from other_table"; - let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) + let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0); match stmt { diff --git a/src/sql/src/statements/query.rs b/src/sql/src/statements/query.rs index 6b0b9bbde0e5..ef0e17665ad5 100644 --- a/src/sql/src/statements/query.rs +++ b/src/sql/src/statements/query.rs @@ -74,14 +74,13 @@ impl fmt::Display for Query { #[cfg(test)] mod test { - use sqlparser::dialect::GenericDialect; - use super::Query; + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; use crate::statements::statement::Statement; fn create_query(sql: &str) -> Option> { - match ParserContext::create_with_dialect(sql, &GenericDialect {}) + match ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}) .unwrap() .remove(0) { diff --git a/src/sql/src/statements/show.rs b/src/sql/src/statements/show.rs index e77eaf3ea4a8..ef36770ac9ce 100644 --- a/src/sql/src/statements/show.rs +++ b/src/sql/src/statements/show.rs @@ -65,9 +65,9 @@ mod tests { use std::assert_matches::assert_matches; use sqlparser::ast::UnaryOperator; - use sqlparser::dialect::GenericDialect; use super::*; + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; use crate::statements::statement::Statement; @@ -102,7 +102,7 @@ mod tests { #[test] pub fn test_show_database() { let sql = "SHOW DATABASES"; - let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); assert_matches!(&stmts[0], Statement::ShowDatabases { .. }); match &stmts[0] { @@ -119,7 +119,7 @@ mod tests { pub fn test_show_create_table() { let sql = "SHOW CREATE TABLE test"; let stmts: Vec = - ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap(); assert_eq!(1, stmts.len()); assert_matches!(&stmts[0], Statement::ShowCreateTable { .. }); match &stmts[0] { @@ -135,6 +135,6 @@ mod tests { #[test] pub fn test_show_create_missing_table_name() { let sql = "SHOW CREATE TABLE"; - ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err(); + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err(); } }