Skip to content

Commit

Permalink
feat: sql dialect for different protocols (#1631)
Browse files Browse the repository at this point in the history
* feat: add SqlDialect to query context

* feat: use session in postgrel handlers

* chore: refactor sql dialect

* feat: use different dialects for different sql protocols

* feat: adds GreptimeDbDialect

* refactor: replace GenericDialect with GreptimeDbDialect

* feat: save user info to session

* fix: compile error

* fix: test
  • Loading branch information
killme2008 authored May 30, 2023
1 parent 563ce59 commit ab5dfd3
Show file tree
Hide file tree
Showing 31 changed files with 285 additions and 185 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/datanode/src/sql/alter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ mod tests {
use query::parser::{QueryLanguageParser, QueryStatement};
use query::query_engine::SqlStatementExecutor;
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;

use super::*;
use crate::tests::test_util::MockInstance;

fn parse_sql(sql: &str) -> AlterTable {
let mut stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmt.len());
let stmt = stmt.remove(0);
assert_matches!(stmt, Statement::Alter(_));
Expand Down
4 changes: 2 additions & 2 deletions src/datanode/src/sql/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ mod tests {
use query::parser::{QueryLanguageParser, QueryStatement};
use query::query_engine::SqlStatementExecutor;
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;

Expand All @@ -262,7 +262,7 @@ mod tests {
use crate::tests::test_util::MockInstance;

fn sql_to_statement(sql: &str) -> CreateTable {
let mut res = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut res = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, res.len());
match res.pop().unwrap() {
Statement::CreateTable(c) => c,
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/expr_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ pub(crate) fn to_alter_expr(
#[cfg(test)]
mod tests {
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;

Expand All @@ -331,7 +331,7 @@ mod tests {
#[test]
fn test_create_to_expr() {
let sql = "CREATE TABLE monitor (host STRING,ts TIMESTAMP,TIME INDEX (ts),PRIMARY KEY(host)) ENGINE=mito WITH(regions=1, ttl='3days', write_buffer_size='1024KB');";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.pop()
.unwrap();
Expand Down
21 changes: 11 additions & 10 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ use servers::query_handler::{
};
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::dialect::GenericDialect;
use sql::dialect::Dialect;
use sql::parser::ParserContext;
use sql::statements::copy::CopyTable;
use sql::statements::statement::Statement;
Expand Down Expand Up @@ -447,8 +447,8 @@ impl FrontendInstance for Instance {
}
}

fn parse_stmt(sql: &str) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu)
fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, dialect).context(ParseSqlSnafu)
}

impl Instance {
Expand All @@ -473,7 +473,7 @@ impl SqlQueryHandler for Instance {
Err(e) => return vec![Err(e)],
};

match parse_stmt(query.as_ref())
match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
{
Ok(stmts) => {
Expand Down Expand Up @@ -664,6 +664,7 @@ mod tests {
use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema};
use query::query_engine::options::QueryOptions;
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use strfmt::Format;

use super::*;
Expand Down Expand Up @@ -748,7 +749,7 @@ mod tests {
CREATE DATABASE test_database;
SHOW DATABASES;
"#;
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(stmts.len(), 4);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
Expand All @@ -759,15 +760,15 @@ mod tests {
SHOW CREATE TABLE demo;
ALTER TABLE demo ADD COLUMN new_col INT;
"#;
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(stmts.len(), 2);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
assert!(re.is_ok());
}

let sql = "USE randomschema";
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmts[0], &query_ctx);
assert!(re.is_ok());

Expand Down Expand Up @@ -800,7 +801,7 @@ mod tests {
}

fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
let stmt = &parse_stmt(sql).unwrap()[0];
let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0];
let re = check_permission(plugins, stmt, query_ctx);
if is_ok {
assert!(re.is_ok());
Expand Down Expand Up @@ -828,12 +829,12 @@ mod tests {

// test show tables
let sql = "SHOW TABLES FROM public";
let stmt = parse_stmt(sql).unwrap();
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_ok());

let sql = "SHOW TABLES FROM wrongschema";
let stmt = parse_stmt(sql).unwrap();
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());

Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/instance/distributed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ fn find_partition_columns(
#[cfg(test)]
mod test {
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;

Expand Down Expand Up @@ -908,7 +908,7 @@ ENGINE=mito",
),
];
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
match &result[0] {
Statement::CreateTable(c) => {
let expr = expr_factory::create_to_expr(c, QueryContext::arc()).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/query/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use promql_parser::parser::ast::{Extension as NodeExtension, ExtensionExpr};
use promql_parser::parser::Expr::Extension;
use promql_parser::parser::{EvalStmt, Expr, ValueType};
use snafu::ResultExt;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;

Expand Down Expand Up @@ -108,7 +108,7 @@ pub struct QueryLanguageParser {}
impl QueryLanguageParser {
pub fn parse_sql(sql: &str) -> Result<QueryStatement> {
let _timer = timer!(METRIC_PARSE_SQL_ELAPSED);
let mut statement = ParserContext::create_with_dialect(sql, &GenericDialect {})
let mut statement = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.map_err(BoxedError::new)
.context(QueryParseSnafu {
query: sql.to_string(),
Expand Down
4 changes: 2 additions & 2 deletions src/query/src/sql/show.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use sql::ast::{
ColumnDef, ColumnOption, ColumnOptionDef, Expr, ObjectName, SqlOption, TableConstraint,
Value as SqlValue,
};
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::create::{CreateTable, TIME_INDEX};
use sql::statements::{self};
Expand Down Expand Up @@ -108,7 +108,7 @@ fn create_column_def(column_schema: &ColumnSchema) -> Result<ColumnDef> {
.with_context(|_| ConvertSqlValueSnafu { value: v.clone() })?,
),
ColumnDefaultConstraint::Function(expr) => {
ParserContext::parse_function(expr, &GenericDialect {}).context(SqlSnafu)?
ParserContext::parse_function(expr, &GreptimeDbDialect {}).context(SqlSnafu)?
}
};

Expand Down
18 changes: 11 additions & 7 deletions src/servers/src/mysql/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use opensrv_mysql::{
use parking_lot::RwLock;
use rand::RngCore;
use session::context::Channel;
use session::Session;
use session::{Session, SessionRef};
use snafu::ensure;
use sql::dialect::GenericDialect;
use sql::dialect::MySqlDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use tokio::io::AsyncWrite;
Expand All @@ -48,7 +48,7 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
pub struct MysqlInstanceShim {
query_handler: ServerSqlQueryHandlerRef,
salt: [u8; 20],
session: Arc<Session>,
session: SessionRef,
user_provider: Option<UserProviderRef>,
// TODO(SSebo): use something like moka to achieve TTL or LRU
prepared_stmts: Arc<RwLock<HashMap<u32, String>>>,
Expand Down Expand Up @@ -77,7 +77,7 @@ impl MysqlInstanceShim {
MysqlInstanceShim {
query_handler,
salt: scramble,
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)),
user_provider,
prepared_stmts: Default::default(),
prepared_stmts_counter: AtomicU32::new(1),
Expand Down Expand Up @@ -140,9 +140,13 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
let username = String::from_utf8_lossy(username);

let mut user_info = None;
let addr = self.session.conn_info().client_host.to_string();
let addr = self
.session
.conn_info()
.client_addr
.map(|addr| addr.to_string());
if let Some(user_provider) = &self.user_provider {
let user_id = Identity::UserId(&username, Some(addr.as_str()));
let user_id = Identity::UserId(&username, addr.as_deref());

let password = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
Expand Down Expand Up @@ -331,7 +335,7 @@ fn format_duration(duration: Duration) -> String {
}

async fn validate_query(query: &str) -> Result<Statement> {
let statement = ParserContext::create_with_dialect(query, &GenericDialect {});
let statement = ParserContext::create_with_dialect(query, &MySqlDialect {});
let mut statement = statement.map_err(|e| {
InvalidPrepareStatementSnafu {
err_msg: e.to_string(),
Expand Down
13 changes: 7 additions & 6 deletions src/servers/src/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ use pgwire::api::auth::ServerParameterProvider;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, MakeHandler};
pub use server::PostgresServer;
use session::context::{QueryContext, QueryContextRef};
use session::context::Channel;
use session::Session;
use sql::statements::statement::Statement;

use self::auth_handler::PgLoginVerifier;
Expand Down Expand Up @@ -73,7 +74,7 @@ pub struct PostgresServerHandler {
force_tls: bool,
param_provider: Arc<GreptimeDBStartupParameters>,

query_ctx: QueryContextRef,
session: Session,
portal_store: Arc<MemPortalStore<(Statement, String)>>,
query_parser: Arc<POCQueryParser>,
}
Expand All @@ -90,18 +91,18 @@ pub(crate) struct MakePostgresServerHandler {
}

impl MakeHandler for MakePostgresServerHandler {
type Handler = Arc<PostgresServerHandler>;
type Handler = PostgresServerHandler;

fn make(&self) -> Self::Handler {
Arc::new(PostgresServerHandler {
PostgresServerHandler {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
force_tls: self.force_tls,
param_provider: self.param_provider.clone(),

query_ctx: QueryContext::arc(),
session: Session::new(None, Channel::Postgres),
portal_store: Arc::new(MemPortalStore::new()),
query_parser: self.query_parser.clone(),
})
}
}
}
17 changes: 11 additions & 6 deletions src/servers/src/postgres/auth_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::QueryContextRef;
use session::context::UserInfo;
use session::Session;

use super::PostgresServerHandler;
use crate::auth::{Identity, Password, UserProviderRef};
Expand Down Expand Up @@ -112,15 +113,19 @@ impl PgLoginVerifier {
}
}

fn set_query_context_from_client_info<C>(client: &C, query_context: QueryContextRef)
fn set_client_info<C>(client: &C, session: &Session)
where
C: ClientInfo,
{
let ctx = session.context();
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
query_context.set_current_catalog(current_catalog);
ctx.set_current_catalog(current_catalog);
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
query_context.set_current_schema(current_schema);
ctx.set_current_schema(current_schema);
}
if let Some(username) = client.metadata().get(super::METADATA_USER) {
session.set_user_info(UserInfo::new(username));
}
}

Expand Down Expand Up @@ -170,7 +175,7 @@ impl StartupHandler for PostgresServerHandler {
))
.await?;
} else {
set_query_context_from_client_info(client, self.query_ctx.clone());
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
}
Expand All @@ -193,7 +198,7 @@ impl StartupHandler for PostgresServerHandler {
)
.await;
}
set_query_context_from_client_info(client, self.query_ctx.clone());
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
_ => {}
Expand Down
Loading

0 comments on commit ab5dfd3

Please sign in to comment.