diff --git a/Cargo.lock b/Cargo.lock index 2e196209b..156a4fb3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1015,6 +1015,7 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", + "serde", "serde_json", "sqlparser 0.37.0", "thiserror", @@ -4523,9 +4524,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.195" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" dependencies = [ "serde_derive", ] @@ -4542,9 +4543,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", diff --git a/connectorx-python/Cargo.lock b/connectorx-python/Cargo.lock index b7b7170eb..4303ca018 100644 --- a/connectorx-python/Cargo.lock +++ b/connectorx-python/Cargo.lock @@ -1058,6 +1058,7 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", + "serde", "serde_json", "sqlparser 0.37.0", "thiserror", diff --git a/connectorx-python/connectorx/tests/test_trino.py b/connectorx-python/connectorx/tests/test_trino.py index 12d6c8d9a..5c883fee6 100644 --- a/connectorx-python/connectorx/tests/test_trino.py +++ b/connectorx-python/connectorx/tests/test_trino.py @@ -154,7 +154,7 @@ def test_trino_limit_large_with_partition(trino_url: str) -> None: def test_trino_with_partition_without_partition_range(trino_url: str) -> None: - query = "SELECT * FROM test.test_table where test_float > 3 order by test_int" + query = "SELECT * FROM test.test_table where test_float > 3" df = read_sql( trino_url, query, @@ -170,6 +170,7 @@ def test_trino_with_partition_without_partition_range(trino_url: str) -> None: }, ) df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) @@ -210,7 +211,7 @@ def test_trino_selection_and_projection(trino_url: str) -> None: def test_trino_join(trino_url: str) -> None: - query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int order by T.test_int" + query = "SELECT T.test_int, T.test_float, S.test_str FROM test.test_table T INNER JOIN test.test_table_extra S ON T.test_int = S.test_int order by T.test_int" df = read_sql( trino_url, query, @@ -262,7 +263,7 @@ def test_trino_types_binary(trino_url: str) -> None: "test_real": pd.Series([123.456, 123.456, None], dtype="float64"), "test_double": pd.Series([123.4567890123, 123.4567890123, None], dtype="float64"), "test_decimal": pd.Series([1234567890.12, 1234567890.12, None], dtype="float64"), - "test_date": pd.Series([None, "2023-01-01", "2023-01-01"], dtype="datetime64[ns]"), + "test_date": pd.Series(["2023-01-01", "2023-01-01", None], dtype="datetime64[ns]"), "test_time": pd.Series(["12:00:00", "12:00:00", None], dtype="object"), "test_timestamp": pd.Series(["2023-01-01 12:00:00.123456", "2023-01-01 12:00:00.123456", None], dtype="datetime64[ns]"), "test_varchar": pd.Series(["Sample text", "Sample text", None], dtype="object"), @@ -299,7 +300,7 @@ def test_empty_result_on_partition(trino_url: str) -> None: def test_empty_result_on_some_partition(trino_url: str) -> None: - query = "SELECT * FROM test_table where test_int = 6" + query = "SELECT * FROM test.test_table where test_int = 6" df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) expected = pd.DataFrame( index=range(1), diff --git a/connectorx-python/src/pandas/get_meta.rs b/connectorx-python/src/pandas/get_meta.rs index bc5e7de95..7ee648e7d 100644 --- a/connectorx-python/src/pandas/get_meta.rs +++ b/connectorx-python/src/pandas/get_meta.rs @@ -2,7 +2,7 @@ use super::{ destination::PandasDestination, transports::{ BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport, - PostgresPandasTransport, SqlitePandasTransport, + PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport, }, }; use crate::errors::ConnectorXPythonError; @@ -18,6 +18,7 @@ use connectorx::{ PostgresSource, SimpleProtocol, }, sqlite::SQLiteSource, + trino::TrinoSource, }, sql::CXQuery, }; @@ -223,6 +224,17 @@ pub fn get_meta<'a>(py: Python<'a>, conn: &str, protocol: &str, query: String) - debug!("Running dispatcher"); dispatcher.get_meta()?; } + SourceType::Trino => { + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let source = TrinoSource::new(rt, &source_conn.conn[..])?; + let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new( + source, + &mut destination, + queries, + None, + ); + dispatcher.run()?; + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), } diff --git a/connectorx/Cargo.toml b/connectorx/Cargo.toml index 44409c2cb..f7b200e39 100644 --- a/connectorx/Cargo.toml +++ b/connectorx/Cargo.toml @@ -58,6 +58,7 @@ uuid = {version = "0.8", optional = true} j4rs = {version = "0.15", optional = true} datafusion = {version = "31", optional = true} prusto = {version = "0.5.1", optional = true} +serde = {optional = true} [lib] crate-type = ["cdylib", "rlib"] @@ -98,7 +99,7 @@ src_postgres = [ "postgres-openssl", ] src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"] -src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits"] +src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits", "serde"] federation = ["j4rs"] fed_exec = ["datafusion", "tokio"] integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] diff --git a/connectorx/src/partition.rs b/connectorx/src/partition.rs index 370120e2a..fedd34fe7 100644 --- a/connectorx/src/partition.rs +++ b/connectorx/src/partition.rs @@ -10,6 +10,7 @@ use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem}; use crate::sources::oracle::{connect_oracle, OracleDialect}; #[cfg(feature = "src_postgres")] use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem}; +use crate::sources::trino::TrinoDialect; #[cfg(feature = "src_sqlite")] use crate::sql::get_partition_range_query_sep; use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery}; @@ -35,7 +36,7 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::dialect::SQLiteDialect; #[cfg(feature = "src_mssql")] use tiberius::Client; -#[cfg(any(feature = "src_bigquery", feature = "src_mssql"))] +#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))] use tokio::{net::TcpStream, runtime::Runtime}; #[cfg(feature = "src_mssql")] use tokio_util::compat::TokioAsyncWriteCompatExt; @@ -100,6 +101,8 @@ pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutRes SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col), #[cfg(feature = "src_bigquery")] SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col), + #[cfg(feature = "src_trino")] + SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col), _ => unimplemented!("{:?} not implemented!", source_conn.ty), } } @@ -137,6 +140,10 @@ pub fn get_part_query( SourceType::BigQuery => { single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})? } + #[cfg(feature = "src_trino")] + SourceType::Trino => { + single_col_partition_query(query, col, lower, upper, &TrinoDialect {})? + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), }; CXQuery::Wrapped(query) @@ -481,3 +488,52 @@ fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64 (min_v, max_v) } + +#[cfg(feature = "src_trino")] +#[throws(ConnectorXOutError)] +fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) { + use prusto::{auth::Auth, ClientBuilder}; + + use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult}; + + let rt = Runtime::new().expect("Failed to create runtime"); + + let username = match conn.username() { + "" => "connectorx", + username => username, + }; + + let builder = ClientBuilder::new(username, conn.host().unwrap().to_owned()) + .port(conn.port().unwrap_or(8080)) + .ssl(prusto::ssl::Ssl { root_cert: None }) + .secure(conn.scheme() == "trino+https") + .catalog(conn.path_segments().unwrap().last().unwrap_or("hive")); + + let builder = match conn.password() { + None => builder, + Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))), + }; + + let client = builder + .build() + .map_err(|e| anyhow!("Failed to build client: {}", e))?; + + let range_query = get_partition_range_query(query, col, &TrinoDialect {})?; + let query_result = rt.block_on(client.get_all::(range_query)); + + let query_result = match query_result { + Ok(query_result) => Ok(query_result.into_vec()), + Err(e) => match e { + prusto::error::Error::EmptyData => { + Ok(vec![TrinoPartitionQueryResult { _col0: 0, _col1: 0 }]) + } + _ => Err(anyhow!("Failed to get query result: {}", e)), + }, + }?; + + let result = query_result + .first() + .unwrap_or(&TrinoPartitionQueryResult { _col0: 0, _col1: 0 }); + + (result._col0, result._col1) +} diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index 8a75245a1..4f463fb28 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -4,7 +4,7 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use fehler::{throw, throws}; use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row}; use serde_json::Value; -use sqlparser::dialect::GenericDialect; +use sqlparser::dialect::{Dialect, GenericDialect}; use std::convert::TryFrom; use tokio::runtime::Runtime; @@ -32,6 +32,26 @@ fn get_total_rows(rt: Arc, client: Arc, query: &CXQuery .len() } +#[derive(Presto, Debug)] +pub struct TrinoPartitionQueryResult { + pub _col0: i64, + pub _col1: i64, +} + +#[derive(Debug)] +pub struct TrinoDialect {} + +// implementation copy from AnsiDialect +impl Dialect for TrinoDialect { + fn is_identifier_start(&self, ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_uppercase() + } + + fn is_identifier_part(&self, ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_' + } +} + pub struct TrinoSource { client: Arc, rt: Arc, @@ -282,12 +302,12 @@ macro_rules! impl_produce_int { match value { Value::Number(x) => { if (x.is_i64()) { - <$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))? + <$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))? } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } @@ -304,12 +324,12 @@ macro_rules! impl_produce_int { Value::Null => None, Value::Number(x) => { if (x.is_i64()) { - Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?) + Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?) } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } @@ -333,10 +353,11 @@ macro_rules! impl_produce_float { if (x.is_f64()) { x.as_f64().unwrap() as $t } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + Value::String(x) => x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?, + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } @@ -355,10 +376,11 @@ macro_rules! impl_produce_float { if (x.is_f64()) { Some(x.as_f64().unwrap() as $t) } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + Value::String(x) => Some(x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?), + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } diff --git a/connectorx/src/sources/trino/typesystem.rs b/connectorx/src/sources/trino/typesystem.rs index d5a492f03..f739c9a42 100644 --- a/connectorx/src/sources/trino/typesystem.rs +++ b/connectorx/src/sources/trino/typesystem.rs @@ -1,7 +1,7 @@ use super::errors::TrinoSourceError; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use fehler::{throw, throws}; -use prusto::{PrestoFloat, PrestoInt, PrestoTy}; +use prusto::{Presto, PrestoFloat, PrestoInt, PrestoTy}; use std::convert::TryFrom; // TODO: implement Tuple, Row, Array and Map as well as UUID @@ -64,6 +64,7 @@ impl TryFrom for TrinoTypeSystem { PrestoTy::Map(_, _) => Varchar(true), PrestoTy::Decimal(_, _) => Double(true), PrestoTy::IpAddress => Varchar(true), + PrestoTy::Uuid => Varchar(true), _ => throw!(TrinoSourceError::InferTypeFromNull), } } @@ -97,6 +98,7 @@ impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem { "map" => Varchar(true), "decimal" => Double(true), "ipaddress" => Varchar(true), + "uuid" => Varchar(true), _ => TrinoTypeSystem::try_from(ty)?, } }