From 925423a60096258c8b6341b0ae9c1ccf3f0de06c Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 19 Jan 2024 16:06:22 +0100 Subject: [PATCH 01/10] implemented type system for Trino --- Cargo.lock | 68 +++++++++++++ Justfile | 1 + connectorx-cpp/Cargo.toml | 1 + connectorx-python/Cargo.toml | 1 + connectorx/Cargo.toml | 4 +- connectorx/src/lib.rs | 2 + connectorx/src/source_router.rs | 2 + connectorx/src/sources/mod.rs | 2 + connectorx/src/sources/trino/errors.rs | 14 +++ connectorx/src/sources/trino/mod.rs | 13 +++ connectorx/src/sources/trino/typesystem.rs | 112 +++++++++++++++++++++ 11 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 connectorx/src/sources/trino/errors.rs create mode 100644 connectorx/src/sources/trino/mod.rs create mode 100644 connectorx/src/sources/trino/typesystem.rs diff --git a/Cargo.lock b/Cargo.lock index c640e9873f..4008e8233d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1019,6 +1019,7 @@ dependencies = [ "postgres-native-tls", "postgres-openssl", "pprof", + "prusto", "r2d2", "r2d2-oracle", "r2d2_mysql", @@ -1078,6 +1079,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "core-foundation" version = "0.9.3" @@ -1523,6 +1530,19 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "derive_utils" version = "0.13.2" @@ -2411,6 +2431,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +[[package]] +name = "iterable" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c151dfd6ab7dff5ca5567d82041bb286f07469ece85c1e2444a6d26d7057a65f" +dependencies = [ + "itertools 0.10.5", +] + [[package]] name = "itertools" version = "0.10.5" @@ -3823,6 +3852,43 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prusto" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4b88a35eb608a69482012e38b818a77c23bd1f3fe952143217609ad6c43f94" +dependencies = [ + "bigdecimal", + "chrono", + "chrono-tz", + "derive_more", + "futures", + "http", + "iterable", + "lazy_static", + "log", + "prusto-macros", + "regex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "urlencoding", + "uuid 1.4.1", +] + +[[package]] +name = "prusto-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "729a73ec40e80da961c846455ec579c521346392d6f9f5a8c8aadfb5c99f9cf8" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -5013,6 +5079,7 @@ dependencies = [ "num_cpus", "parking_lot 0.12.1", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.4", "tokio-macros", "windows-sys", @@ -5283,6 +5350,7 @@ checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ "getrandom 0.2.10", "rand 0.8.5", + "serde", ] [[package]] diff --git a/Justfile b/Justfile index 07ac50601d..6a2cb19548 100644 --- a/Justfile +++ b/Justfile @@ -23,6 +23,7 @@ test-feature-gate: cargo c --features src_oracle cargo c --features src_csv cargo c --features src_dummy + cargo c --features src_trino cargo c --features dst_arrow cargo c --features dst_arrow2 diff --git a/connectorx-cpp/Cargo.toml b/connectorx-cpp/Cargo.toml index 05944b4ed0..ea7e1a5faf 100644 --- a/connectorx-cpp/Cargo.toml +++ b/connectorx-cpp/Cargo.toml @@ -34,4 +34,5 @@ srcs = [ "connectorx/src_mssql", "connectorx/src_oracle", "connectorx/src_bigquery", + "connectorx/src_trino", ] diff --git a/connectorx-python/Cargo.toml b/connectorx-python/Cargo.toml index 0b010af1df..aaa5245d00 100644 --- a/connectorx-python/Cargo.toml +++ b/connectorx-python/Cargo.toml @@ -75,5 +75,6 @@ srcs = [ "connectorx/src_mssql", "connectorx/src_oracle", "connectorx/src_bigquery", + "connectorx/src_trino", ] integrated-auth-gssapi = ["connectorx/integrated-auth-gssapi"] diff --git a/connectorx/Cargo.toml b/connectorx/Cargo.toml index e64f2d2e51..244ae4f221 100644 --- a/connectorx/Cargo.toml +++ b/connectorx/Cargo.toml @@ -57,6 +57,7 @@ urlencoding = {version = "2.1", optional = true} uuid = {version = "0.8", optional = true} j4rs = {version = "0.15", optional = true} datafusion = {version = "31", optional = true} +prusto = {version = "0.5.1", optional = true} [lib] crate-type = ["cdylib", "rlib"] @@ -69,7 +70,7 @@ iai = "0.1" pprof = {version = "0.5", features = ["flamegraph"]} [features] -all = ["src_sqlite", "src_postgres", "src_mysql", "src_mssql", "src_oracle", "src_bigquery", "src_csv", "src_dummy", "dst_arrow", "dst_arrow2", "federation", "fed_exec"] +all = ["src_sqlite", "src_postgres", "src_mysql", "src_mssql", "src_oracle", "src_bigquery", "src_csv", "src_dummy", "src_trino", "dst_arrow", "dst_arrow2", "federation", "fed_exec"] branch = [] default = ["fptr"] dst_arrow = ["arrow"] @@ -97,6 +98,7 @@ src_postgres = [ "postgres-openssl", ] src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"] +src_trino = ["prusto", "uuid"] federation = ["j4rs"] fed_exec = ["datafusion", "tokio"] integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] diff --git a/connectorx/src/lib.rs b/connectorx/src/lib.rs index 84b043be89..5b4ce73868 100644 --- a/connectorx/src/lib.rs +++ b/connectorx/src/lib.rs @@ -208,6 +208,8 @@ pub mod prelude { pub use crate::sources::postgres::PostgresSource; #[cfg(feature = "src_sqlite")] pub use crate::sources::sqlite::SQLiteSource; + #[cfg(feature = "src_trino")] + pub use crate::sources::trino::TrinoSource; pub use crate::sources::{PartitionParser, Produce, Source, SourcePartition}; pub use crate::sql::CXQuery; pub use crate::transports::*; diff --git a/connectorx/src/source_router.rs b/connectorx/src/source_router.rs index d307967668..ad3aec489e 100644 --- a/connectorx/src/source_router.rs +++ b/connectorx/src/source_router.rs @@ -14,6 +14,7 @@ pub enum SourceType { Oracle, BigQuery, DuckDB, + Trino, Unknown, } @@ -58,6 +59,7 @@ impl TryFrom<&str> for SourceConn { "oracle" => Ok(SourceConn::new(SourceType::Oracle, url, proto)), "bigquery" => Ok(SourceConn::new(SourceType::BigQuery, url, proto)), "duckdb" => Ok(SourceConn::new(SourceType::DuckDB, url, proto)), + "trino" => Ok(SourceConn::new(SourceType::Trino, url, proto)), _ => Ok(SourceConn::new(SourceType::Unknown, url, proto)), } } diff --git a/connectorx/src/sources/mod.rs b/connectorx/src/sources/mod.rs index 0afb6416ce..dd86b524dc 100644 --- a/connectorx/src/sources/mod.rs +++ b/connectorx/src/sources/mod.rs @@ -17,6 +17,8 @@ pub mod oracle; pub mod postgres; #[cfg(feature = "src_sqlite")] pub mod sqlite; +#[cfg(feature = "src_trino")] +pub mod trino; use crate::data_order::DataOrder; use crate::errors::ConnectorXError; diff --git a/connectorx/src/sources/trino/errors.rs b/connectorx/src/sources/trino/errors.rs new file mode 100644 index 0000000000..6f273c6199 --- /dev/null +++ b/connectorx/src/sources/trino/errors.rs @@ -0,0 +1,14 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoSourceError { + #[error("Cannot infer type from null for SQLite")] + InferTypeFromNull, + + #[error(transparent)] + ConnectorXError(#[from] crate::errors::ConnectorXError), + + /// Any other errors that are too trivial to be put here explicitly. + #[error(transparent)] + Other(#[from] anyhow::Error), +} diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs new file mode 100644 index 0000000000..d26b85890c --- /dev/null +++ b/connectorx/src/sources/trino/mod.rs @@ -0,0 +1,13 @@ +use crate::prelude::CXQuery; + +use self::typesystem::TrinoTypeSystem; + +mod errors; +mod typesystem; + +pub struct TrinoSource { + origin_query: Option, + queries: Vec>, + names: Vec, + schema: Vec, +} diff --git a/connectorx/src/sources/trino/typesystem.rs b/connectorx/src/sources/trino/typesystem.rs new file mode 100644 index 0000000000..03111ee91b --- /dev/null +++ b/connectorx/src/sources/trino/typesystem.rs @@ -0,0 +1,112 @@ +use super::errors::TrinoSourceError; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use fehler::{throw, throws}; +use prusto::{PrestoFloat, PrestoInt, PrestoTy}; +use std::convert::TryFrom; +use uuid::Uuid; + +// TODO: implement Tuple, Row, Array and Map +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum TrinoTypeSystem { + Date(bool), + Time(bool), + Timestamp(bool), + Uuid(bool), + Boolean(bool), + Bigint(bool), + Integer(bool), + Smallint(bool), + Tinyint(bool), + Double(bool), + Real(bool), + Varchar(bool), + Char(bool), +} + +impl_typesystem! { + system = TrinoTypeSystem, + mappings = { + { Date => NaiveDate } + { Time => NaiveTime } + { Timestamp => NaiveDateTime } + { Uuid => Uuid } + { Boolean => bool } + { Bigint => i64 } + { Integer => i32 } + { Smallint => i16 } + { Tinyint => i8 } + { Double => f64 } + { Real => f32 } + { Varchar => Box } + { Char => char } + } +} + +impl TryFrom for TrinoTypeSystem { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn try_from(ty: PrestoTy) -> Self { + use TrinoTypeSystem::*; + match ty { + PrestoTy::Date => Date(true), + PrestoTy::Time => Time(true), + PrestoTy::Timestamp => Timestamp(true), + PrestoTy::Uuid => Uuid(true), + PrestoTy::Boolean => Boolean(true), + PrestoTy::PrestoInt(PrestoInt::I64) => Bigint(true), + PrestoTy::PrestoInt(PrestoInt::I32) => Integer(true), + PrestoTy::PrestoInt(PrestoInt::I16) => Smallint(true), + PrestoTy::PrestoInt(PrestoInt::I8) => Tinyint(true), + PrestoTy::PrestoFloat(PrestoFloat::F64) => Double(true), + PrestoTy::PrestoFloat(PrestoFloat::F32) => Real(true), + PrestoTy::Varchar => Varchar(true), + PrestoTy::Char(_) => Char(true), + PrestoTy::Tuple(_) => Varchar(true), + PrestoTy::Row(_) => Varchar(true), + PrestoTy::Array(_) => Varchar(true), + PrestoTy::Map(_, _) => Varchar(true), + PrestoTy::Decimal(_, _) => Double(true), + PrestoTy::IpAddress => Varchar(true), + _ => throw!(TrinoSourceError::InferTypeFromNull), + } + } +} + +impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn try_from(types: (Option<&str>, PrestoTy)) -> Self { + use TrinoTypeSystem::*; + match types { + (Some(decl_type), ty) => { + let decl_type = decl_type.to_lowercase(); + match decl_type.as_str() { + "date" => Date(true), + "time" => Time(true), + "timestamp" => Timestamp(true), + "uuid" => Uuid(true), + "boolean" => Boolean(true), + "bigint" => Bigint(true), + "int" | "integer" => Integer(true), + "smallint" => Smallint(true), + "tinyint" => Tinyint(true), + "double" => Double(true), + "real" | "float" => Real(true), + "varchar" | "varbinary" | "json" => Varchar(true), + "char" => Char(true), + "tuple" => Varchar(true), + "row" => Varchar(true), + "array" => Varchar(true), + "map" => Varchar(true), + "decimal" => Double(true), + "ipaddress" => Varchar(true), + _ => TrinoTypeSystem::try_from(ty)?, + } + } + // derive from value type directly if no declare type available + (None, ty) => TrinoTypeSystem::try_from(ty)?, + } + } +} From 95630aabeff4c0c2a9c7a37b6946d924386bccac Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 2 Feb 2024 21:59:45 +0100 Subject: [PATCH 02/10] hard-coded schema --- connectorx/Cargo.toml | 2 +- connectorx/src/sources/trino/errors.rs | 11 + connectorx/src/sources/trino/mod.rs | 430 ++++++++++++++++++++- connectorx/src/sources/trino/typesystem.rs | 9 +- connectorx/src/transports/mod.rs | 5 +- connectorx/src/transports/trino_arrow2.rs | 64 +++ 6 files changed, 508 insertions(+), 13 deletions(-) create mode 100644 connectorx/src/transports/trino_arrow2.rs diff --git a/connectorx/Cargo.toml b/connectorx/Cargo.toml index 244ae4f221..ee54b266b6 100644 --- a/connectorx/Cargo.toml +++ b/connectorx/Cargo.toml @@ -98,7 +98,7 @@ src_postgres = [ "postgres-openssl", ] src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"] -src_trino = ["prusto", "uuid"] +src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits"] federation = ["j4rs"] fed_exec = ["datafusion", "tokio"] integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] diff --git a/connectorx/src/sources/trino/errors.rs b/connectorx/src/sources/trino/errors.rs index 6f273c6199..d6862bbb80 100644 --- a/connectorx/src/sources/trino/errors.rs +++ b/connectorx/src/sources/trino/errors.rs @@ -1,3 +1,5 @@ +use std::string::FromUtf8Error; + use thiserror::Error; #[derive(Error, Debug)] @@ -8,6 +10,15 @@ pub enum TrinoSourceError { #[error(transparent)] ConnectorXError(#[from] crate::errors::ConnectorXError), + #[error(transparent)] + PrustoError(prusto::error::Error), + + #[error(transparent)] + UrlParseError(#[from] url::ParseError), + + #[error(transparent)] + TrinoUrlDecodeError(#[from] FromUtf8Error), + /// Any other errors that are too trivial to be put here explicitly. #[error(transparent)] Other(#[from] anyhow::Error), diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index d26b85890c..1f0d86b746 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -1,13 +1,435 @@ -use crate::prelude::CXQuery; +use std::{marker::PhantomData, sync::Arc}; -use self::typesystem::TrinoTypeSystem; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use fehler::{throw, throws}; +use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row}; +use serde_json::Value; +use sqlparser::dialect::GenericDialect; +use std::convert::TryFrom; +use tokio::runtime::Runtime; -mod errors; -mod typesystem; +use crate::{ + data_order::DataOrder, + errors::ConnectorXError, + sources::Produce, + sql::{count_query, limit1_query, CXQuery}, +}; + +pub use self::{errors::TrinoSourceError, typesystem::TrinoTypeSystem}; +use urlencoding::decode; + +use super::{PartitionParser, Source, SourcePartition}; + +use anyhow::anyhow; + +pub mod errors; +pub mod typesystem; + +#[throws(TrinoSourceError)] +async fn get_total_rows(client: Arc, query: &CXQuery) -> usize { + let result = client + .get_all::(count_query(query, &GenericDialect {})?.to_string()) + .await + .map_err(TrinoSourceError::PrustoError)?; + + usize::from(result.as_slice()[0]) +} pub struct TrinoSource { + client: Arc, + rt: Arc, origin_query: Option, queries: Vec>, names: Vec, schema: Vec, } + +impl TrinoSource { + #[throws(TrinoSourceError)] + pub fn new(rt: Arc, conn: &str) -> Self { + let decoded_conn = decode(conn)?.into_owned(); + + let url = decoded_conn + .parse::() + .map_err(TrinoSourceError::UrlParseError)?; + + let client = ClientBuilder::new(url.username(), url.host().unwrap().to_owned()) + .port(url.port().unwrap_or(8080)) + .auth(Auth::Basic( + url.username().to_owned(), + url.password().map(|x| x.to_owned()), + )) + .ssl(prusto::ssl::Ssl { root_cert: None }) + .secure(url.scheme() == "trino+https") + .catalog(url.path_segments().unwrap().last().unwrap_or("hive")) + .build() + .map_err(TrinoSourceError::PrustoError)?; + + Self { + client: Arc::new(client), + rt, + origin_query: None, + queries: vec![], + names: vec![], + schema: vec![], + } + } +} + +impl Source for TrinoSource +where + TrinoSourcePartition: SourcePartition, +{ + const DATA_ORDERS: &'static [DataOrder] = &[DataOrder::RowMajor]; + type TypeSystem = TrinoTypeSystem; + type Partition = TrinoSourcePartition; + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn set_data_order(&mut self, data_order: DataOrder) { + if !matches!(data_order, DataOrder::RowMajor) { + throw!(ConnectorXError::UnsupportedDataOrder(data_order)); + } + } + + fn set_queries(&mut self, queries: &[CXQuery]) { + self.queries = queries.iter().map(|q| q.map(Q::to_string)).collect(); + } + + fn set_origin_query(&mut self, query: Option) { + self.origin_query = query; + } + + #[throws(TrinoSourceError)] + fn fetch_metadata(&mut self) { + assert!(!self.queries.is_empty()); + + match &self.origin_query { + Some(q) => { + /*let cxq = CXQuery::Naked(q.clone()); + let cxq = limit1_query(&cxq, &GenericDialect {})?; + let data_set: DataSet<_> = self + .rt + .block_on(self.client.get_all::(cxq.to_string())) + .map_err(TrinoSourceError::PrustoError)?; + + let x = data_set.into_vec().first().unwrap(); + let ncols = x.value().to_vec().len(); + + let mut parser = + TrinoSourceParser::new(self.rt.clone(), self.client.clone(), cxq, ncols)?; + + // produce the first row + for x in 0..ncols { + let x: TrinoTypeSystem = parser.produce()?; + } + + data_set.into_vec().iter().for_each(|row| { + row.value().iter().for_each(|x| { + println!("{:?}", x); + }); + + println!("{:?}", row); + });*/ + + // TODO: remove hard-coded + self.schema = vec![ + TrinoTypeSystem::Integer(true), + TrinoTypeSystem::Double(true), + TrinoTypeSystem::Varchar(true), + ]; + self.names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + } + None => {} + } + } + + #[throws(TrinoSourceError)] + fn result_rows(&mut self) -> Option { + match &self.origin_query { + Some(q) => { + let cxq = CXQuery::Naked(q.clone()); + let client = self.client.clone(); + let nrows = self.rt.block_on(get_total_rows(client, &cxq))?; + Some(nrows) + } + None => None, + } + } + + fn names(&self) -> Vec { + self.names.clone() + } + + fn schema(&self) -> Vec { + self.schema.clone() + } + + #[throws(TrinoSourceError)] + fn partition(self) -> Vec { + let mut ret = vec![]; + + for query in self.queries { + ret.push( + TrinoSourcePartition::new( + self.client.clone(), + query, + self.schema.clone(), + self.rt.clone(), + ) + .unwrap(), // TODO: handle error + ); + } + ret + } +} + +pub struct TrinoSourcePartition { + client: Arc, + query: CXQuery, + schema: Vec, + rt: Arc, + nrows: usize, + ncols: usize, +} + +impl TrinoSourcePartition { + #[throws(TrinoSourceError)] + pub fn new( + client: Arc, + query: CXQuery, + schema: Vec, + rt: Arc, + ) -> Self { + Self { + client, + query, + schema: schema.clone(), + rt, + nrows: 0, + ncols: schema.len(), + } + } +} + +impl SourcePartition for TrinoSourcePartition { + type TypeSystem = TrinoTypeSystem; + type Parser<'a> = TrinoSourceParser<'a>; + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn result_rows(&mut self) { + self.nrows = self + .rt + .block_on(get_total_rows(self.client.clone(), &self.query))? + } + + #[throws(TrinoSourceError)] + fn parser(&mut self) -> Self::Parser<'_> { + let query = self.query.clone(); + TrinoSourceParser::new(self.rt.clone(), self.client.clone(), query, &self.schema)? + } + + fn nrows(&self) -> usize { + self.nrows + } + + fn ncols(&self) -> usize { + self.ncols + } +} + +pub struct TrinoSourceParser<'a> { + rows: Vec, + nrows: usize, + ncols: usize, + current_col: usize, + current_row: usize, + _phantom: &'a PhantomData>, +} + +impl<'a> TrinoSourceParser<'a> { + #[throws(TrinoSourceError)] + pub fn new( + rt: Arc, + client: Arc, + query: CXQuery, + schema: &[TrinoTypeSystem], + ) -> Self { + let rows = client.get_all::(query.to_string()); + let data = rt.block_on(rows).map_err(TrinoSourceError::PrustoError)?; + let rows = data.clone().into_vec(); + + Self { + rows, + nrows: data.len(), + ncols: schema.len(), + current_col: 0, + current_row: 0, + _phantom: &PhantomData, + } + } + + #[throws(TrinoSourceError)] + fn next_loc(&mut self) -> (usize, usize) { + let ret = (self.current_row, self.current_col); + self.current_row += (self.current_col + 1) / self.ncols; + self.current_col = (self.current_col + 1) % self.ncols; + ret + } +} + +impl<'a> PartitionParser<'a> for TrinoSourceParser<'a> { + type TypeSystem = TrinoTypeSystem; + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn fetch_next(&mut self) -> (usize, bool) { + assert!(self.current_col == 0); + + (self.nrows, true) + } +} + +macro_rules! impl_produce_int { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourceParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Number(x) => { + if (x.is_i64()) { + <$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))? + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourceParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::Number(x) => { + if (x.is_i64()) { + Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?) + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_float { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourceParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Number(x) => { + if (x.is_f64()) { + x.as_f64().unwrap() as $t + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourceParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::Number(x) => { + if (x.is_f64()) { + Some(x.as_f64().unwrap() as $t) + } else { + throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_text { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourceParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => { + x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx))? + } + _ => throw!(anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourceParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => { + Some(x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx))?) + } + _ => throw!(anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx)) + } + } + } + )+ + }; +} + +impl_produce_int!(i8, i16, i32, i64,); +impl_produce_float!(f32, f64,); +impl_produce_text!(NaiveDate, NaiveTime, NaiveDateTime, String, bool, char,); diff --git a/connectorx/src/sources/trino/typesystem.rs b/connectorx/src/sources/trino/typesystem.rs index 03111ee91b..d5a492f038 100644 --- a/connectorx/src/sources/trino/typesystem.rs +++ b/connectorx/src/sources/trino/typesystem.rs @@ -3,15 +3,13 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use fehler::{throw, throws}; use prusto::{PrestoFloat, PrestoInt, PrestoTy}; use std::convert::TryFrom; -use uuid::Uuid; -// TODO: implement Tuple, Row, Array and Map +// TODO: implement Tuple, Row, Array and Map as well as UUID #[derive(Copy, Clone, Debug, PartialEq)] pub enum TrinoTypeSystem { Date(bool), Time(bool), Timestamp(bool), - Uuid(bool), Boolean(bool), Bigint(bool), Integer(bool), @@ -29,7 +27,6 @@ impl_typesystem! { { Date => NaiveDate } { Time => NaiveTime } { Timestamp => NaiveDateTime } - { Uuid => Uuid } { Boolean => bool } { Bigint => i64 } { Integer => i32 } @@ -37,7 +34,7 @@ impl_typesystem! { { Tinyint => i8 } { Double => f64 } { Real => f32 } - { Varchar => Box } + { Varchar => String } { Char => char } } } @@ -52,7 +49,6 @@ impl TryFrom for TrinoTypeSystem { PrestoTy::Date => Date(true), PrestoTy::Time => Time(true), PrestoTy::Timestamp => Timestamp(true), - PrestoTy::Uuid => Uuid(true), PrestoTy::Boolean => Boolean(true), PrestoTy::PrestoInt(PrestoInt::I64) => Bigint(true), PrestoTy::PrestoInt(PrestoInt::I32) => Integer(true), @@ -86,7 +82,6 @@ impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem { "date" => Date(true), "time" => Time(true), "timestamp" => Timestamp(true), - "uuid" => Uuid(true), "boolean" => Boolean(true), "bigint" => Bigint(true), "int" | "integer" => Integer(true), diff --git a/connectorx/src/transports/mod.rs b/connectorx/src/transports/mod.rs index 8be61dc2c0..30a209c255 100644 --- a/connectorx/src/transports/mod.rs +++ b/connectorx/src/transports/mod.rs @@ -44,7 +44,8 @@ mod sqlite_arrow; mod sqlite_arrow2; #[cfg(all(feature = "src_sqlite", feature = "dst_arrow"))] mod sqlite_arrowstream; - +#[cfg(all(feature = "src_trino", feature = "dst_arrow2"))] +mod trino_arrow2; #[cfg(all(feature = "src_bigquery", feature = "dst_arrow"))] pub use bigquery_arrow::{BigQueryArrowTransport, BigQueryArrowTransportError}; #[cfg(all(feature = "src_bigquery", feature = "dst_arrow2"))] @@ -105,3 +106,5 @@ pub use sqlite_arrowstream::{ SQLiteArrowTransport as SQLiteArrowStreamTransport, SQLiteArrowTransportError as SQLiteArrowStreamTransportError, }; +#[cfg(all(feature = "src_trino", feature = "dst_arrow2"))] +pub use trino_arrow2::{TrinoArrow2Transport, TrinoArrow2TransportError}; diff --git a/connectorx/src/transports/trino_arrow2.rs b/connectorx/src/transports/trino_arrow2.rs new file mode 100644 index 0000000000..27290d4a90 --- /dev/null +++ b/connectorx/src/transports/trino_arrow2.rs @@ -0,0 +1,64 @@ +//! Transport from Trino Source to Arrow2 Destination. + +use crate::{ + destinations::arrow2::{ + typesystem::Arrow2TypeSystem, Arrow2Destination, Arrow2DestinationError, + }, + impl_transport, + sources::trino::{TrinoSource, TrinoSourceError, TrinoTypeSystem}, + typesystem::TypeConversion, +}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use num_traits::ToPrimitive; +use rust_decimal::Decimal; +use serde_json::{to_string, Value}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoArrow2TransportError { + #[error(transparent)] + Source(#[from] TrinoSourceError), + + #[error(transparent)] + Destination(#[from] Arrow2DestinationError), + + #[error(transparent)] + ConnectorX(#[from] crate::errors::ConnectorXError), +} + +/// Convert Trino data types to Arrow2 data types. +pub struct TrinoArrow2Transport(); + +impl_transport!( + name = TrinoArrow2Transport, + error = TrinoArrow2TransportError, + systems = TrinoTypeSystem => Arrow2TypeSystem, + route = TrinoSource => Arrow2Destination, + mappings = { + { Date[NaiveDate] => Date64[NaiveDate] | conversion auto } + { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } + { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } + { Boolean[bool] => Boolean[bool] | conversion auto } + { Bigint[i32] => Int64[i64] | conversion auto } + { Integer[i32] => Int64[i64] | conversion none } + { Smallint[i16] => Int64[i64] | conversion auto } + { Tinyint[i8] => Int64[i64] | conversion auto } + { Double[f64] => Float64[f64] | conversion auto } + { Real[f32] => Float64[f64] | conversion auto } + { Varchar[String] => LargeUtf8[String] | conversion auto } + { Char[String] => LargeUtf8[String] | conversion none } + } +); + +impl TypeConversion for TrinoArrow2Transport { + fn convert(val: Decimal) -> f64 { + val.to_f64() + .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) + } +} + +impl TypeConversion for TrinoArrow2Transport { + fn convert(val: Value) -> String { + to_string(&val).unwrap() + } +} From b33802291d9b30b6a9604877e1299d30c48cfc94 Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 9 Feb 2024 17:12:58 +0100 Subject: [PATCH 03/10] WIP Trino in Python --- README.md | 1 + connectorx-python/Cargo.lock | 450 ++++++++++------- .../connectorx/tests/test_trino.py | 470 ++++++++++++++++++ connectorx-python/src/errors.rs | 3 + connectorx-python/src/pandas/mod.rs | 13 +- .../src/pandas/transports/mod.rs | 2 + .../src/pandas/transports/trino.rs | 54 ++ connectorx/src/sources/trino/errors.rs | 2 +- connectorx/src/sources/trino/mod.rs | 325 ++++++++---- connectorx/src/transports/mod.rs | 11 + connectorx/src/transports/trino_arrow.rs | 64 +++ connectorx/src/transports/trino_arrow2.rs | 2 +- .../src/transports/trino_arrowstream.rs | 64 +++ 13 files changed, 1196 insertions(+), 265 deletions(-) create mode 100644 connectorx-python/connectorx/tests/test_trino.py create mode 100644 connectorx-python/src/pandas/transports/trino.rs create mode 100644 connectorx/src/transports/trino_arrow.rs create mode 100644 connectorx/src/transports/trino_arrowstream.rs diff --git a/README.md b/README.md index 4aeb0c9708..5924e7c9ea 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ For more planned data sources, please check out our [discussion](https://github. - [x] Oracle - [x] Big Query - [ ] ODBC (WIP) +- [ ] Trino (WIP) - [ ] ... ## Destinations diff --git a/connectorx-python/Cargo.lock b/connectorx-python/Cargo.lock index 50171df3c3..a877cbce4c 100644 --- a/connectorx-python/Cargo.lock +++ b/connectorx-python/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6619cab21a0cdd8c9b9f1d9e09bfaa9b1974e5ef809a6566aef0b998caf38ace" +checksum = "04a8801ebb147ad240b2d978d3ab9f73c9ccd4557ba6a03e7800496770ed10e0" dependencies = [ "ahash 0.8.3", "arrow-arith", @@ -152,9 +152,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0dc95485623a76e00929bda8caa40c1f838190952365c4f43a7b9ae86d03e94" +checksum = "895263144bd4a69751cbe6a34a53f26626e19770b313a9fa792c415cd0e78f11" dependencies = [ "arrow-array", "arrow-buffer", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3267847f53d3042473cfd2c769afd8d74a6d7d201fc3a34f5cb84c0282ef47a7" +checksum = "226fdc6c3a4ae154a74c24091d36a90b514f0ed7112f5b8322c1d8f354d8e20d" dependencies = [ "ahash 0.8.3", "arrow-buffer", @@ -178,25 +178,26 @@ dependencies = [ "chrono", "chrono-tz", "half 2.3.1", - "hashbrown 0.13.2", + "hashbrown 0.14.0", "num", ] [[package]] name = "arrow-buffer" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f66553e66e120ac4b21570368ee9ebf35ff3f5399f872b0667699e145678f5" +checksum = "fc4843af4dd679c2f35b69c572874da8fde33be53eb549a5fb128e7a4b763510" dependencies = [ + "bytes", "half 2.3.1", "num", ] [[package]] name = "arrow-cast" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65e6f3579dbf0d97c683d451b2550062b0f0e62a3169bf74238b5f59f44ad6d8" +checksum = "35e8b9990733a9b635f656efda3c9b8308c7a19695c9ec2c7046dd154f9b144b" dependencies = [ "arrow-array", "arrow-buffer", @@ -204,16 +205,17 @@ dependencies = [ "arrow-schema", "arrow-select", "chrono", - "comfy-table 6.2.0", + "comfy-table", + "half 2.3.1", "lexical-core", "num", ] [[package]] name = "arrow-csv" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "373579c4c1a8f5307d3125b7a89c700fcf8caf85821c77eb4baab3855ae0aba5" +checksum = "646fbb4e11dd0afb8083e883f53117713b8caadb4413b3c9e63e3f535da3683c" dependencies = [ "arrow-array", "arrow-buffer", @@ -230,9 +232,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61bc8df9912cca6642665fdf989d6fa0de2570f18a7f709bcf59d29de96d2097" +checksum = "da900f31ff01a0a84da0572209be72b2b6f980f3ea58803635de47913191c188" dependencies = [ "arrow-buffer", "arrow-schema", @@ -252,9 +254,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0105dcf5f91daa7182d87b713ee0b32b3bfc88e0c48e7dc3e9d6f1277a07d1ae" +checksum = "2707a8d7ee2d345d045283ece3ae43416175873483e5d96319c929da542a0b1f" dependencies = [ "arrow-array", "arrow-buffer", @@ -266,9 +268,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e73134fb5b5ec8770f8cbb214c2c487b2d350081e403ca4eeeb6f8f5e19846ac" +checksum = "5d1b91a63c356d14eedc778b76d66a88f35ac8498426bb0799a769a49a74a8b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -277,7 +279,7 @@ dependencies = [ "arrow-schema", "chrono", "half 2.3.1", - "indexmap 1.9.3", + "indexmap 2.0.0", "lexical-core", "num", "serde", @@ -286,9 +288,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89f25bc66e18d4c2aa1fe2f9bb03e2269da60e636213210385ae41a107f9965a" +checksum = "584325c91293abbca7aaaabf8da9fe303245d641f5f4a18a6058dc68009c7ebf" dependencies = [ "arrow-array", "arrow-buffer", @@ -301,9 +303,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1095ff85ea4f5ff02d17b30b089de31b51a50be01c6b674f0a0509ab771232f1" +checksum = "0e32afc1329f7b372463b21c6ca502b07cf237e1ed420d87706c1770bb0ebd38" dependencies = [ "ahash 0.8.3", "arrow-array", @@ -311,23 +313,23 @@ dependencies = [ "arrow-data", "arrow-schema", "half 2.3.1", - "hashbrown 0.13.2", + "hashbrown 0.14.0", ] [[package]] name = "arrow-schema" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25187bbef474151a2e4ddec67b9e34bda5cbfba292dc571392fa3a1f71ff5a82" +checksum = "b104f5daa730f00fde22adc03a12aa5a2ae9ccbbf99cbd53d284119ddc90e03d" dependencies = [ "bitflags 2.4.0", ] [[package]] name = "arrow-select" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd0d4ee884aec3aa05e41478e3cd312bf609de9babb5d187a43fb45931da4da4" +checksum = "73b3ca55356d1eae07cf48808d8c462cea674393ae6ad1e0b120f40b422eb2b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -338,15 +340,16 @@ dependencies = [ [[package]] name = "arrow-string" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6d71c3ffe4c07e66ce8fdc6aed5b00e0e60c5144911879b10546f5b72d8fa1c" +checksum = "af1433ce02590cae68da0a18ed3a3ed868ffac2c6f24c533ddd2067f7ee04b4a" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "num", "regex", "regex-syntax 0.7.5", ] @@ -914,18 +917,17 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.29" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87d9d13be47a5b7c3907137f1290b0459a7f80efb26be8c52afb11963bccb02" +checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "serde", - "time 0.1.45", "wasm-bindgen", - "windows-targets", + "windows-targets 0.52.0", ] [[package]] @@ -985,17 +987,6 @@ dependencies = [ "cc", ] -[[package]] -name = "comfy-table" -version = "6.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e959d788268e3bf9d35ace83e81b124190378e4c91c9067524675e33394b8ba" -dependencies = [ - "strum", - "strum_macros 0.24.3", - "unicode-width", -] - [[package]] name = "comfy-table" version = "7.0.1" @@ -1003,7 +994,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ab77dbd8adecaf3f0db40581631b995f312a8a5ae3aa9993188bb8f23d83a5b" dependencies = [ "crossterm", - "strum", + "strum 0.24.1", "strum_macros 0.24.3", "unicode-width", ] @@ -1044,7 +1035,7 @@ dependencies = [ "futures", "gcp-bigquery-client", "hex", - "itertools", + "itertools 0.10.5", "j4rs", "log", "mysql_common", @@ -1057,6 +1048,7 @@ dependencies = [ "postgres", "postgres-native-tls", "postgres-openssl", + "prusto", "r2d2", "r2d2-oracle", "r2d2_mysql", @@ -1067,7 +1059,7 @@ dependencies = [ "rust_decimal", "rust_decimal_macros", "serde_json", - "sqlparser 0.11.0", + "sqlparser 0.37.0", "thiserror", "tiberius", "tokio", @@ -1095,7 +1087,7 @@ dependencies = [ "env_logger", "fehler", "iai", - "itertools", + "itertools 0.10.5", "lazy_static", "libc", "log", @@ -1111,7 +1103,7 @@ dependencies = [ "rayon", "rust_decimal", "serde_json", - "sqlparser 0.11.0", + "sqlparser 0.37.0", "thiserror", "tokio", "tokio-util 0.6.10", @@ -1148,6 +1140,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "core-foundation" version = "0.9.3" @@ -1202,7 +1200,7 @@ dependencies = [ "clap", "criterion-plot", "csv", - "itertools", + "itertools 0.10.5", "lazy_static", "num-traits", "oorandom", @@ -1234,7 +1232,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -1416,9 +1414,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9992c267436551d40b52d65289b144712e7b0ebdc62c8c859fd1574e5f73efbb" +checksum = "6a4e4fc25698a14c90b34dda647ba10a5a966dc04b036d22e77fb1048663375d" dependencies = [ "ahash 0.8.3", "arrow", @@ -1435,15 +1433,14 @@ dependencies = [ "datafusion-expr", "datafusion-optimizer", "datafusion-physical-expr", - "datafusion-row", "datafusion-sql", "flate2", "futures", "glob", - "hashbrown 0.13.2", - "indexmap 1.9.3", - "itertools", - "lazy_static", + "half 2.3.1", + "hashbrown 0.14.0", + "indexmap 2.0.0", + "itertools 0.11.0", "log", "num_cpus", "object_store", @@ -1452,11 +1449,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rand 0.8.5", - "smallvec", - "sqlparser 0.34.0", + "sqlparser 0.37.0", "tempfile", "tokio", - "tokio-stream", "tokio-util 0.7.8", "url", "uuid 1.4.1", @@ -1466,29 +1461,40 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3be97f7a7c720cdbb71e9eeabf814fa6ad8102b9022390f6cac74d3b4af6392" +checksum = "c23ad0229ea4a85bf76b236d8e75edf539881fdb02ce4e2394f9a76de6055206" dependencies = [ "arrow", "arrow-array", + "async-compression", + "bytes", + "bzip2", "chrono", + "flate2", + "futures", "num_cpus", "object_store", "parquet", - "sqlparser 0.34.0", + "sqlparser 0.37.0", + "tokio", + "tokio-util 0.7.8", + "xz2", + "zstd", ] [[package]] name = "datafusion-execution" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c77c4b14b809b0e4c5bb101b6834504f06cdbb0d3c643400c61d0d844b33264e" +checksum = "9b37d2fc1a213baf34e0a57c85b8e6648f1a95152798fd6738163ee96c19203f" dependencies = [ + "arrow", "dashmap", "datafusion-common", "datafusion-expr", - "hashbrown 0.13.2", + "futures", + "hashbrown 0.14.0", "log", "object_store", "parking_lot 0.12.1", @@ -1499,24 +1505,23 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ec7409bd45cf4fae6395d7d1024c8a97e543cadc88363e405d2aad5330e5e7" +checksum = "d6ea9844395f537730a145e5d87f61fecd37c2bc9d54e1dc89b35590d867345d" dependencies = [ "ahash 0.8.3", "arrow", "datafusion-common", - "lazy_static", - "sqlparser 0.34.0", - "strum", - "strum_macros 0.24.3", + "sqlparser 0.37.0", + "strum 0.25.0", + "strum_macros 0.25.2", ] [[package]] name = "datafusion-optimizer" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64b537c93f87989c212db92a448a0f5eb4f0995e27199bb7687ae94f8b64a7a8" +checksum = "c8a30e0f79c5d59ba14d3d70f2500e87e0ff70236ad5e47f9444428f054fd2be" dependencies = [ "arrow", "async-trait", @@ -1524,35 +1529,36 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.13.2", - "itertools", + "hashbrown 0.14.0", + "itertools 0.11.0", "log", "regex-syntax 0.7.5", ] [[package]] name = "datafusion-physical-expr" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60ee3f53340fdef36ee54d9e12d446ae2718b1d0196ac581f791d34808ec876" +checksum = "766c567082c9bbdcb784feec8fe40c7049cedaeb3a18d54f563f75fe0dc1932c" dependencies = [ "ahash 0.8.3", "arrow", "arrow-array", "arrow-buffer", "arrow-schema", + "base64 0.21.3", "blake2", "blake3", "chrono", "datafusion-common", "datafusion-expr", - "datafusion-row", "half 2.3.1", - "hashbrown 0.13.2", - "indexmap 1.9.3", - "itertools", - "lazy_static", + "hashbrown 0.14.0", + "hex", + "indexmap 2.0.0", + "itertools 0.11.0", "libc", + "log", "md-5", "paste 1.0.14", "petgraph 0.6.4", @@ -1563,30 +1569,18 @@ dependencies = [ "uuid 1.4.1", ] -[[package]] -name = "datafusion-row" -version = "26.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d58fc64058aa3bcb00077a0d19474a0d584d31dec8c7ac3406868f485f659af9" -dependencies = [ - "arrow", - "datafusion-common", - "paste 1.0.14", - "rand 0.8.5", -] - [[package]] name = "datafusion-sql" -version = "26.0.0" +version = "31.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1531f0314151a34bf6c0a83c7261525688b7c729876f53e7896b8f4ca8f57d07" +checksum = "811fd084cf2d78aa0c76b74320977c7084ad0383690612528b580795764b4dd0" dependencies = [ "arrow", "arrow-schema", "datafusion-common", "datafusion-expr", "log", - "sqlparser 0.34.0", + "sqlparser 0.37.0", ] [[package]] @@ -1607,6 +1601,19 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "derive_utils" version = "0.13.2" @@ -2127,7 +2134,7 @@ dependencies = [ "serde", "serde_json", "thiserror", - "time 0.3.28", + "time", "tokio", "tokio-stream", "url", @@ -2222,15 +2229,6 @@ version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74721d007512d0cb3338cd20f0654ac913920061a4c4d0d8708edb3f2a698c0c" -[[package]] -name = "hashbrown" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" -dependencies = [ - "ahash 0.7.6", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -2262,11 +2260,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.7.0" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.11.2", + "hashbrown 0.14.0", ] [[package]] @@ -2553,6 +2551,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +[[package]] +name = "iterable" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c151dfd6ab7dff5ca5567d82041bb286f07469ece85c1e2444a6d26d7057a65f" +dependencies = [ + "itertools 0.10.5", +] + [[package]] name = "itertools" version = "0.10.5" @@ -2562,6 +2569,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.9" @@ -2755,9 +2771,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "libsqlite3-sys" -version = "0.24.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898745e570c7d0453cc1fbc4a701eb6c662ed54e8fec8b7d14be137ebeeb9d14" +checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" dependencies = [ "cc", "pkg-config", @@ -3022,7 +3038,7 @@ dependencies = [ "smallvec", "subprocess", "thiserror", - "time 0.3.28", + "time", "uuid 1.4.1", ] @@ -3237,15 +3253,16 @@ dependencies = [ [[package]] name = "object_store" -version = "0.5.6" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec9cd6ca25e796a49fa242876d1c4de36a24a6da5258e9f0bc062dbf5e81c53b" +checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" dependencies = [ "async-trait", "bytes", "chrono", "futures", - "itertools", + "humantime", + "itertools 0.11.0", "parking_lot 0.12.1", "percent-encoding", "snafu", @@ -3423,14 +3440,14 @@ dependencies = [ "libc", "redox_syscall 0.3.5", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] name = "parquet" -version = "40.0.0" +version = "46.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6a656fcc17e641657c955742c689732684e096f790ff30865d9f8dcc39f7c4a" +checksum = "1ad2cba786ae07da4d73371a88b9e0f9d3ffac1a9badc83922e0e15814f5c5fa" dependencies = [ "ahash 0.8.3", "arrow-array", @@ -3446,7 +3463,7 @@ dependencies = [ "chrono", "flate2", "futures", - "hashbrown 0.13.2", + "hashbrown 0.14.0", "lz4", "num", "num-bigint", @@ -3669,7 +3686,7 @@ dependencies = [ "arrow2", "bitflags 2.4.0", "chrono", - "comfy-table 7.0.1", + "comfy-table", "either", "hashbrown 0.14.0", "indexmap 2.0.0", @@ -4026,7 +4043,7 @@ checksum = "355f634b43cdd80724ee7848f95770e7e70eefa6dcf14fea676216573b8fd603" dependencies = [ "bytes", "heck 0.3.3", - "itertools", + "itertools 0.10.5", "log", "multimap", "petgraph 0.5.1", @@ -4043,7 +4060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.5", "proc-macro2", "quote", "syn 1.0.109", @@ -4059,6 +4076,43 @@ dependencies = [ "prost", ] +[[package]] +name = "prusto" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4b88a35eb608a69482012e38b818a77c23bd1f3fe952143217609ad6c43f94" +dependencies = [ + "bigdecimal", + "chrono", + "chrono-tz", + "derive_more", + "futures", + "http", + "iterable", + "lazy_static", + "log", + "prusto-macros", + "regex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "urlencoding", + "uuid 1.4.1", +] + +[[package]] +name = "prusto-macros" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "729a73ec40e80da961c846455ec579c521346392d6f9f5a8c8aadfb5c99f9cf8" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -4164,9 +4218,9 @@ dependencies = [ [[package]] name = "r2d2-oracle" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca5358dca54423e557b30e7b5a6d950d3a442ab4a56cc916965030cead8b02b" +checksum = "e592c29a9d04b2eb9aa5adc8775087200343b486efa8a374cb43a02f4269d67f" dependencies = [ "oracle", "r2d2", @@ -4194,12 +4248,13 @@ dependencies = [ [[package]] name = "r2d2_sqlite" -version = "0.20.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fdc8e4da70586127893be32b7adf21326a4c6b1aba907611edf467d13ffe895" +checksum = "99f31323d6161385f385046738df520e0e8694fa74852d35891fc0be08348ddc" dependencies = [ "r2d2", "rusqlite", + "uuid 1.4.1", ] [[package]] @@ -4483,17 +4538,16 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.27.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a" +checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "chrono", "fallible-iterator", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", - "memchr", "smallvec", ] @@ -4932,32 +4986,23 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "sqlparser" -version = "0.11.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10e1ce16b71375ad72d28d111131069ce0d5f8603f4f86d8acd3456b41b57a51" +checksum = "2eaa1e88e78d2c2460d78b7dc3f0c08dbb606ab4222f9aff36f420d36e307d87" dependencies = [ "log", ] [[package]] name = "sqlparser" -version = "0.34.0" +version = "0.37.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3706eefb17039056234df6b566b0014f303f867f2656108334a55b8096f59" +checksum = "37ae05a8250b968a3f7db93155a84d68b2e6cea1583949af5ca5b5170c76c075" dependencies = [ "log", "sqlparser_derive", ] -[[package]] -name = "sqlparser" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eaa1e88e78d2c2460d78b7dc3f0c08dbb606ab4222f9aff36f420d36e307d87" -dependencies = [ - "log", -] - [[package]] name = "sqlparser_derive" version = "0.1.1" @@ -5021,8 +5066,14 @@ name = "strum" version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" dependencies = [ - "strum_macros 0.24.3", + "strum_macros 0.25.2", ] [[package]] @@ -5232,17 +5283,6 @@ dependencies = [ "winauth", ] -[[package]] -name = "time" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" -dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", -] - [[package]] name = "time" version = "0.3.28" @@ -5320,6 +5360,7 @@ dependencies = [ "num_cpus", "parking_lot 0.12.1", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.3", "tokio-macros", "windows-sys", @@ -5595,6 +5636,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ "getrandom 0.2.10", + "rand 0.8.5", + "serde", ] [[package]] @@ -5640,12 +5683,6 @@ version = "0.9.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" -[[package]] -name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -5816,7 +5853,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -5825,7 +5862,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -5834,13 +5871,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -5849,42 +5901,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winreg" version = "0.50.0" @@ -5932,7 +6026,7 @@ dependencies = [ "http", "hyper", "hyper-rustls 0.23.2", - "itertools", + "itertools 0.10.5", "log", "percent-encoding", "rustls 0.20.9", @@ -5940,7 +6034,7 @@ dependencies = [ "seahash", "serde", "serde_json", - "time 0.3.28", + "time", "tokio", "tower-service", "url", diff --git a/connectorx-python/connectorx/tests/test_trino.py b/connectorx-python/connectorx/tests/test_trino.py new file mode 100644 index 0000000000..9376bf5058 --- /dev/null +++ b/connectorx-python/connectorx/tests/test_trino.py @@ -0,0 +1,470 @@ +import os + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from .. import read_sql + + +@pytest.fixture(scope="module") # type: ignore +def mysql_url() -> str: + conn = os.environ["MYSQL_URL"] + # conn = os.environ["MARIADB_URL"] + return conn + + +def test_mysql_without_partition(mysql_url: str) -> None: + query = "select * from test_table limit 3" + df = read_sql(mysql_url, query) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_enum": pd.Series(["odd", "even", "odd"], dtype="object"), + "test_null": pd.Series([None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_with_partition(mysql_url: str) -> None: + query = "select * from test_table" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_range=(0, 10), + partition_num=6, + ) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series( + ["odd", "even", "odd", "even", "odd", "even"], dtype="object" + ), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_without_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table" + df = read_sql(mysql_url, query) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series( + ["odd", "even", "odd", "even", "odd", "even"], dtype="object" + ), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_limit_without_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table limit 3" + df = read_sql(mysql_url, query) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_enum": pd.Series(["odd", "even", "odd"], dtype="object"), + "test_null": pd.Series([None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_limit_large_without_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table limit 10" + df = read_sql(mysql_url, query) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series( + ["odd", "even", "odd", "even", "odd", "even"], dtype="object" + ), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_with_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_range=(0, 2000), + partition_num=3, + ) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series( + ["odd", "even", "odd", "even", "odd", "even"], dtype="object" + ), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_limit_with_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table limit 3" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_range=(0, 2000), + partition_num=3, + ) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_enum": pd.Series(["odd", "even", "odd"], dtype="object"), + "test_null": pd.Series([None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_limit_large_with_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table limit 10" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_range=(0, 2000), + partition_num=3, + ) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series( + ["odd", "even", "odd", "even", "odd", "even"], dtype="object" + ), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_with_partition_without_partition_range(mysql_url: str) -> None: + query = "SELECT * FROM test_table where test_float > 3" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_num=3, + ) + expected = pd.DataFrame( + index=range(4), + data={ + "test_int": pd.Series([3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series(["odd", "even", "odd", "even"], dtype="object"), + "test_null": pd.Series([None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_manual_partition(mysql_url: str) -> None: + queries = [ + "SELECT * FROM test_table WHERE test_int < 2", + "SELECT * FROM test_table WHERE test_int >= 2", + ] + df = read_sql(mysql_url, query=queries) + expected = pd.DataFrame( + index=range(6), + data={ + "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), + "test_enum": pd.Series( + ["odd", "even", "odd", "even", "odd", "even"], dtype="object" + ), + "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_selection_and_projection(mysql_url: str) -> None: + query = "SELECT test_int FROM test_table WHERE test_float < 5" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_num=3, + ) + expected = pd.DataFrame( + index=range(4), + data={ + "test_int": pd.Series([1, 2, 3, 4], dtype="Int64"), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_join(mysql_url: str) -> None: + query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int" + df = read_sql( + mysql_url, + query, + partition_on="test_int", + partition_num=3, + ) + expected = pd.DataFrame( + index=range(3), + data={ + "test_int": pd.Series([1, 2, 3], dtype="Int64"), + "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "test_str": pd.Series( + [ + "Ha好ち😁ðy̆", + "こんにちは", + "русский", + ], + dtype="object", + ), + }, + ) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_aggregate(mysql_url: str) -> None: + query = "select AVG(test_float) as avg_float, SUM(T.test_int) as sum_int, SUM(test_null) as sum_null from test_table as T INNER JOIN test_table_extra as S where T.test_int = S.test_int GROUP BY test_enum ORDER BY sum_int" + df = read_sql(mysql_url, query) + expected = pd.DataFrame( + index=range(2), + data={ + "avg_float": pd.Series([2.2, 2.2], dtype="float64"), + "sum_int": pd.Series([2.0, 4.0], dtype="float64"), + "sum_null": pd.Series([None, None], dtype="float64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_types_binary(mysql_url: str) -> None: + query = "select * from test_types" + df = read_sql(mysql_url, query, protocol="binary") + expected = pd.DataFrame( + index=range(3), + data={ + "test_timestamp": pd.Series( + ["1970-01-01 00:00:01", "2038-01-19 00:00:00", None], + dtype="datetime64[ns]", + ), + "test_date": pd.Series( + [None, "1970-01-01", "2038-01-19"], dtype="datetime64[ns]" + ), + "test_time": pd.Series(["00:00:00", None, "23:59:59"], dtype="object"), + "test_datetime": pd.Series( + ["1970-01-01 00:00:01", "2038-01-19 00:0:00", None], + dtype="datetime64[ns]", + ), + "test_new_decimal": pd.Series([1.1, None, 3.3], dtype="float"), + "test_decimal": pd.Series([1, 2, None], dtype="float"), + "test_varchar": pd.Series([None, "varchar2", "varchar3"], dtype="object"), + "test_char": pd.Series(["char1", None, "char3"], dtype="object"), + "test_tiny": pd.Series([-128, 127, None], dtype="Int64"), + "test_short": pd.Series([-32768, 32767, None], dtype="Int64"), + "test_int24": pd.Series([-8388608, 8388607, None], dtype="Int64"), + "test_long": pd.Series([-2147483648, 2147483647, None], dtype="Int64"), + "test_longlong": pd.Series( + [-9223372036854775808, 9223372036854775807, None], dtype="Int64" + ), + "test_tiny_unsigned": pd.Series([None, 255, 0], dtype="Int64"), + "test_short_unsigned": pd.Series([None, 65535, 0], dtype="Int64"), + "test_int24_unsigned": pd.Series([None, 16777215, 0], dtype="Int64"), + "test_long_unsigned": pd.Series([None, 4294967295, 0], dtype="Int64"), + "test_longlong_unsigned": pd.Series( + [None, 18446744070000001024.0, 0.0], dtype="float" + ), + "test_long_notnull": pd.Series([1, 2147483647, -2147483648], dtype="int64"), + "test_short_unsigned_notnull": pd.Series([1, 65535, 0], dtype="int64"), + "test_float": pd.Series([None, -1.1e-38, 3.4e38], dtype="float"), + "test_double": pd.Series([-2.2e-308, None, 1.7e308], dtype="float"), + "test_double_notnull": pd.Series([1.2345, -1.1e-3, 1.7e30], dtype="float"), + "test_year": pd.Series([1901, 2155, None], dtype="Int64"), + "test_tinyblob": pd.Series( + [None, b"tinyblob2", b"tinyblob3"], dtype="object" + ), + "test_blob": pd.Series( + [None, b"blobblobblobblob2", b"blobblobblobblob3"], dtype="object" + ), + "test_mediumblob": pd.Series( + [None, b"mediumblob2", b"mediumblob3"], dtype="object" + ), + "test_longblob": pd.Series( + [None, b"longblob2", b"longblob3"], dtype="object" + ), + "test_enum": pd.Series(["apple", None, "mango"], dtype="object"), + "test_json": pd.Series( + ['{"age":1,"name":"piggy"}', '{"age":2,"name":"kitty"}', None], + # mariadb + # [b'{"name": "piggy", "age": 1}', b'{"name": "kitty", "age": 2}', None], + dtype="object", + ), + "test_mediumtext": pd.Series( + [None, b"", b"medium text!!!!"], dtype="object" + ), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_types_text(mysql_url: str) -> None: + query = "select * from test_types" + df = read_sql(mysql_url, query, protocol="text") + expected = pd.DataFrame( + index=range(3), + data={ + "test_timestamp": pd.Series( + ["1970-01-01 00:00:01", "2038-01-19 00:00:00", None], + dtype="datetime64[ns]", + ), + "test_date": pd.Series( + [None, "1970-01-01", "2038-01-19"], dtype="datetime64[ns]" + ), + "test_time": pd.Series(["00:00:00", None, "23:59:59"], dtype="object"), + "test_datetime": pd.Series( + ["1970-01-01 00:00:01", "2038-01-19 00:00:00", None], + dtype="datetime64[ns]", + ), + "test_new_decimal": pd.Series([1.1, None, 3.3], dtype="float"), + "test_decimal": pd.Series([1, 2, None], dtype="float"), + "test_varchar": pd.Series([None, "varchar2", "varchar3"], dtype="object"), + "test_char": pd.Series(["char1", None, "char3"], dtype="object"), + "test_tiny": pd.Series([-128, 127, None], dtype="Int64"), + "test_short": pd.Series([-32768, 32767, None], dtype="Int64"), + "test_int24": pd.Series([-8388608, 8388607, None], dtype="Int64"), + "test_long": pd.Series([-2147483648, 2147483647, None], dtype="Int64"), + "test_longlong": pd.Series( + [-9223372036854775808, 9223372036854775807, None], dtype="Int64" + ), + "test_tiny_unsigned": pd.Series([None, 255, 0], dtype="Int64"), + "test_short_unsigned": pd.Series([None, 65535, 0], dtype="Int64"), + "test_int24_unsigned": pd.Series([None, 16777215, 0], dtype="Int64"), + "test_long_unsigned": pd.Series([None, 4294967295, 0], dtype="Int64"), + "test_longlong_unsigned": pd.Series( + [None, 18446744070000001024.0, 0.0], dtype="float" + ), + "test_long_notnull": pd.Series([1, 2147483647, -2147483648], dtype="int64"), + "test_short_unsigned_notnull": pd.Series([1, 65535, 0], dtype="int64"), + "test_float": pd.Series([None, -1.1e-38, 3.4e38], dtype="float"), + "test_double": pd.Series([-2.2e-308, None, 1.7e308], dtype="float"), + "test_double_notnull": pd.Series([1.2345, -1.1e-3, 1.7e30], dtype="float"), + "test_year": pd.Series([1901, 2155, None], dtype="Int64"), + "test_tinyblob": pd.Series( + [None, b"tinyblob2", b"tinyblob3"], dtype="object" + ), + "test_blob": pd.Series( + [None, b"blobblobblobblob2", b"blobblobblobblob3"], dtype="object" + ), + "test_mediumblob": pd.Series( + [None, b"mediumblob2", b"mediumblob3"], dtype="object" + ), + "test_longblob": pd.Series( + [None, b"longblob2", b"longblob3"], dtype="object" + ), + "test_enum": pd.Series(["apple", None, "mango"], dtype="object"), + "test_json": pd.Series( + ['{"age":1,"name":"piggy"}', '{"age":2,"name":"kitty"}', None], + # mariadb + # [b'{"name": "piggy", "age": 1}', b'{"name": "kitty", "age": 2}', None], + dtype="object", + ), + "test_mediumtext": pd.Series( + [None, b"", b"medium text!!!!"], dtype="object" + ), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_empty_result(mysql_url: str) -> None: + query = "SELECT * FROM test_table where test_int < -100" + df = read_sql(mysql_url, query) + expected = pd.DataFrame( + data={ + "test_int": pd.Series([], dtype="Int64"), + "test_float": pd.Series([], dtype="float64"), + "test_enum": pd.Series([], dtype="object"), + "test_null": pd.Series([], dtype="Int64"), + } + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_empty_result_on_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table where test_int < -100" + df = read_sql(mysql_url, query, partition_on="test_int", partition_num=3) + expected = pd.DataFrame( + data={ + "test_int": pd.Series([], dtype="Int64"), + "test_float": pd.Series([], dtype="float64"), + "test_enum": pd.Series([], dtype="object"), + "test_null": pd.Series([], dtype="Int64"), + } + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_empty_result_on_some_partition(mysql_url: str) -> None: + query = "SELECT * FROM test_table where test_int = 6" + df = read_sql(mysql_url, query, partition_on="test_int", partition_num=3) + expected = pd.DataFrame( + index=range(1), + data={ + "test_int": pd.Series([6], dtype="Int64"), + "test_float": pd.Series([6.6], dtype="float64"), + "test_enum": pd.Series(["even"], dtype="object"), + "test_null": pd.Series([None], dtype="Int64"), + }, + ) + assert_frame_equal(df, expected, check_names=True) + + +def test_mysql_cte(mysql_url: str) -> None: + query = "with test_cte (test_int, test_enum) as (select test_int, test_enum from test_table where test_float > 2) select test_int, test_enum from test_cte" + df = read_sql(mysql_url, query, partition_on="test_int", partition_num=3) + df.sort_values(by="test_int", inplace=True, ignore_index=True) + expected = pd.DataFrame( + index=range(5), + data={ + "test_int": pd.Series([2, 3, 4, 5, 6], dtype="Int64"), + "test_enum": pd.Series( + ["even", "odd", "even", "odd", "even"], dtype="object" + ), + }, + ) + assert_frame_equal(df, expected, check_names=True) diff --git a/connectorx-python/src/errors.rs b/connectorx-python/src/errors.rs index a8754ef2fe..929023e057 100644 --- a/connectorx-python/src/errors.rs +++ b/connectorx-python/src/errors.rs @@ -42,6 +42,9 @@ pub enum ConnectorXPythonError { #[error(transparent)] BigQuerySourceError(#[from] connectorx::sources::bigquery::BigQuerySourceError), + #[error(transparent)] + TrinoSourceError(#[from] connectorx::sources::trino::TrinoSourceError), + #[error(transparent)] ArrowDestinationError(#[from] connectorx::destinations::arrow::ArrowDestinationError), diff --git a/connectorx-python/src/pandas/mod.rs b/connectorx-python/src/pandas/mod.rs index be2e419286..1172808667 100644 --- a/connectorx-python/src/pandas/mod.rs +++ b/connectorx-python/src/pandas/mod.rs @@ -8,7 +8,7 @@ mod typesystem; pub use self::destination::{PandasBlockInfo, PandasDestination, PandasPartitionDestination}; pub use self::transports::{ BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport, - PostgresPandasTransport, SqlitePandasTransport, + PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport, }; pub use self::typesystem::{PandasDType, PandasTypeSystem}; use crate::errors::ConnectorXPythonError; @@ -230,6 +230,17 @@ pub fn write_pandas<'a>( ); dispatcher.run()?; } + SourceType::Trino => { + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let source = TrinoSource::new(rt, &source_conn.conn[..])?; + let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new( + source, + &mut destination, + queries, + origin_query, + ); + dispatcher.run()?; + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), } diff --git a/connectorx-python/src/pandas/transports/mod.rs b/connectorx-python/src/pandas/transports/mod.rs index 9f03abf33e..fbf7952fbf 100644 --- a/connectorx-python/src/pandas/transports/mod.rs +++ b/connectorx-python/src/pandas/transports/mod.rs @@ -4,6 +4,7 @@ mod mysql; mod oracle; mod postgres; mod sqlite; +mod trino; pub use self::postgres::PostgresPandasTransport; pub use bigquery::BigQueryPandasTransport; @@ -11,3 +12,4 @@ pub use mssql::MsSQLPandasTransport; pub use mysql::MysqlPandasTransport; pub use oracle::OraclePandasTransport; pub use sqlite::SqlitePandasTransport; +pub use trino::TrinoPandasTransport; diff --git a/connectorx-python/src/pandas/transports/trino.rs b/connectorx-python/src/pandas/transports/trino.rs new file mode 100644 index 0000000000..fba7a06d50 --- /dev/null +++ b/connectorx-python/src/pandas/transports/trino.rs @@ -0,0 +1,54 @@ +use crate::errors::ConnectorXPythonError; +use crate::pandas::destination::PandasDestination; +use crate::pandas::typesystem::PandasTypeSystem; +use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use connectorx::{ + impl_transport, + sources::trino::{TrinoSource, TrinoTypeSystem}, + typesystem::TypeConversion, +}; + +pub struct TrinoPandasTransport<'py>(&'py ()); + +impl_transport!( + name = TrinoPandasTransport<'tp>, + error = ConnectorXPythonError, + systems = TrinoTypeSystem => PandasTypeSystem, + route = TrinoSource => PandasDestination<'tp>, + mappings = { + { Date[NaiveDate] => DateTime[DateTime] | conversion option } + { Time[NaiveTime] => String[String] | conversion option } + { Timestamp[NaiveDateTime] => DateTime[DateTime] | conversion option } + { Boolean[bool] => Bool[bool] | conversion auto } + { Bigint[i32] => I64[i64] | conversion auto } + { Integer[i32] => I64[i64] | conversion none } + { Smallint[i16] => I64[i64] | conversion auto } + { Tinyint[i8] => I64[i64] | conversion auto } + { Double[f64] => F64[f64] | conversion auto } + { Real[f32] => F64[f64] | conversion auto } + { Varchar[String] => String[String] | conversion auto } + { Char[String] => String[String] | conversion none } + } +); + +impl<'py> TypeConversion> for TrinoPandasTransport<'py> { + fn convert(val: NaiveDate) -> DateTime { + DateTime::from_naive_utc_and_offset( + val.and_hms_opt(0, 0, 0) + .unwrap_or_else(|| panic!("and_hms_opt got None from {:?}", val)), + Utc, + ) + } +} + +impl<'py> TypeConversion for TrinoPandasTransport<'py> { + fn convert(val: NaiveTime) -> String { + val.to_string() + } +} + +impl<'py> TypeConversion> for TrinoPandasTransport<'py> { + fn convert(val: NaiveDateTime) -> DateTime { + DateTime::from_naive_utc_and_offset(val, Utc) + } +} diff --git a/connectorx/src/sources/trino/errors.rs b/connectorx/src/sources/trino/errors.rs index d6862bbb80..8b46eff193 100644 --- a/connectorx/src/sources/trino/errors.rs +++ b/connectorx/src/sources/trino/errors.rs @@ -4,7 +4,7 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum TrinoSourceError { - #[error("Cannot infer type from null for SQLite")] + #[error("Cannot infer type from null for Trino")] InferTypeFromNull, #[error(transparent)] diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index 1f0d86b746..6655a0c1f1 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -12,7 +12,7 @@ use crate::{ data_order::DataOrder, errors::ConnectorXError, sources::Produce, - sql::{count_query, limit1_query, CXQuery}, + sql::{limit1_query, CXQuery}, }; pub use self::{errors::TrinoSourceError, typesystem::TrinoTypeSystem}; @@ -26,13 +26,10 @@ pub mod errors; pub mod typesystem; #[throws(TrinoSourceError)] -async fn get_total_rows(client: Arc, query: &CXQuery) -> usize { - let result = client - .get_all::(count_query(query, &GenericDialect {})?.to_string()) - .await - .map_err(TrinoSourceError::PrustoError)?; - - usize::from(result.as_slice()[0]) +fn get_total_rows(rt: Arc, client: Arc, query: &CXQuery) -> usize { + rt.block_on(client.get_all::(query.to_string())) + .map_err(TrinoSourceError::PrustoError)? + .len() } pub struct TrinoSource { @@ -104,43 +101,20 @@ where fn fetch_metadata(&mut self) { assert!(!self.queries.is_empty()); - match &self.origin_query { - Some(q) => { - /*let cxq = CXQuery::Naked(q.clone()); - let cxq = limit1_query(&cxq, &GenericDialect {})?; - let data_set: DataSet<_> = self - .rt - .block_on(self.client.get_all::(cxq.to_string())) - .map_err(TrinoSourceError::PrustoError)?; - - let x = data_set.into_vec().first().unwrap(); - let ncols = x.value().to_vec().len(); - - let mut parser = - TrinoSourceParser::new(self.rt.clone(), self.client.clone(), cxq, ncols)?; - - // produce the first row - for x in 0..ncols { - let x: TrinoTypeSystem = parser.produce()?; - } + // TODO: prevent from running the same query multiple times (limit1 + no limit) + let first_query = &self.queries[0]; + let cxq = limit1_query(first_query, &GenericDialect {})?; - data_set.into_vec().iter().for_each(|row| { - row.value().iter().for_each(|x| { - println!("{:?}", x); - }); - - println!("{:?}", row); - });*/ - - // TODO: remove hard-coded - self.schema = vec![ - TrinoTypeSystem::Integer(true), - TrinoTypeSystem::Double(true), - TrinoTypeSystem::Varchar(true), - ]; - self.names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - } - None => {} + let dataset: DataSet = self + .rt + .block_on(self.client.get_all::(cxq.to_string())) + .map_err(TrinoSourceError::PrustoError)?; + + let schema = dataset.split().0; + + for (name, t) in schema { + self.names.push(name.clone()); + self.schema.push(TrinoTypeSystem::try_from(t.clone())?); } } @@ -149,8 +123,7 @@ where match &self.origin_query { Some(q) => { let cxq = CXQuery::Naked(q.clone()); - let client = self.client.clone(); - let nrows = self.rt.block_on(get_total_rows(client, &cxq))?; + let nrows = get_total_rows(self.rt.clone(), self.client.clone(), &cxq)?; Some(nrows) } None => None, @@ -170,15 +143,12 @@ where let mut ret = vec![]; for query in self.queries { - ret.push( - TrinoSourcePartition::new( - self.client.clone(), - query, - self.schema.clone(), - self.rt.clone(), - ) - .unwrap(), // TODO: handle error - ); + ret.push(TrinoSourcePartition::new( + self.client.clone(), + query, + self.schema.clone(), + self.rt.clone(), + )?); } ret } @@ -190,7 +160,6 @@ pub struct TrinoSourcePartition { schema: Vec, rt: Arc, nrows: usize, - ncols: usize, } impl TrinoSourcePartition { @@ -203,31 +172,32 @@ impl TrinoSourcePartition { ) -> Self { Self { client, - query, - schema: schema.clone(), + query: query.clone(), + schema: schema.to_vec(), rt, nrows: 0, - ncols: schema.len(), } } } impl SourcePartition for TrinoSourcePartition { type TypeSystem = TrinoTypeSystem; - type Parser<'a> = TrinoSourceParser<'a>; + type Parser<'a> = TrinoSourcePartitionParser<'a>; type Error = TrinoSourceError; #[throws(TrinoSourceError)] fn result_rows(&mut self) { - self.nrows = self - .rt - .block_on(get_total_rows(self.client.clone(), &self.query))? + self.nrows = get_total_rows(self.rt.clone(), self.client.clone(), &self.query)?; } #[throws(TrinoSourceError)] fn parser(&mut self) -> Self::Parser<'_> { - let query = self.query.clone(); - TrinoSourceParser::new(self.rt.clone(), self.client.clone(), query, &self.schema)? + TrinoSourcePartitionParser::new( + self.rt.clone(), + self.client.clone(), + self.query.clone(), + &self.schema, + )? } fn nrows(&self) -> usize { @@ -235,20 +205,19 @@ impl SourcePartition for TrinoSourcePartition { } fn ncols(&self) -> usize { - self.ncols + self.schema.len() } } -pub struct TrinoSourceParser<'a> { +pub struct TrinoSourcePartitionParser<'a> { rows: Vec, - nrows: usize, ncols: usize, current_col: usize, current_row: usize, _phantom: &'a PhantomData>, } -impl<'a> TrinoSourceParser<'a> { +impl<'a> TrinoSourcePartitionParser<'a> { #[throws(TrinoSourceError)] pub fn new( rt: Arc, @@ -262,10 +231,9 @@ impl<'a> TrinoSourceParser<'a> { Self { rows, - nrows: data.len(), ncols: schema.len(), - current_col: 0, current_row: 0, + current_col: 0, _phantom: &PhantomData, } } @@ -279,7 +247,7 @@ impl<'a> TrinoSourceParser<'a> { } } -impl<'a> PartitionParser<'a> for TrinoSourceParser<'a> { +impl<'a> PartitionParser<'a> for TrinoSourcePartitionParser<'a> { type TypeSystem = TrinoTypeSystem; type Error = TrinoSourceError; @@ -287,14 +255,15 @@ impl<'a> PartitionParser<'a> for TrinoSourceParser<'a> { fn fetch_next(&mut self) -> (usize, bool) { assert!(self.current_col == 0); - (self.nrows, true) + // results are always fetched in a single batch for Prusto + (self.rows.len(), true) } } macro_rules! impl_produce_int { ($($t: ty,)+) => { $( - impl<'r, 'a> Produce<'r, $t> for TrinoSourceParser<'a> { + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { type Error = TrinoSourceError; #[throws(TrinoSourceError)] @@ -315,7 +284,7 @@ macro_rules! impl_produce_int { } } - impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourceParser<'a> { + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { type Error = TrinoSourceError; #[throws(TrinoSourceError)] @@ -343,7 +312,7 @@ macro_rules! impl_produce_int { macro_rules! impl_produce_float { ($($t: ty,)+) => { $( - impl<'r, 'a> Produce<'r, $t> for TrinoSourceParser<'a> { + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { type Error = TrinoSourceError; #[throws(TrinoSourceError)] @@ -364,7 +333,7 @@ macro_rules! impl_produce_float { } } - impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourceParser<'a> { + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { type Error = TrinoSourceError; #[throws(TrinoSourceError)] @@ -392,7 +361,7 @@ macro_rules! impl_produce_float { macro_rules! impl_produce_text { ($($t: ty,)+) => { $( - impl<'r, 'a> Produce<'r, $t> for TrinoSourceParser<'a> { + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { type Error = TrinoSourceError; #[throws(TrinoSourceError)] @@ -402,14 +371,14 @@ macro_rules! impl_produce_text { match value { Value::String(x) => { - x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx))? + x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))? } - _ => throw!(anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx)) + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) } } } - impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourceParser<'a> { + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { type Error = TrinoSourceError; #[throws(TrinoSourceError)] @@ -420,9 +389,83 @@ macro_rules! impl_produce_text { match value { Value::Null => None, Value::String(x) => { - Some(x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx))?) + Some(x.parse().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?) } - _ => throw!(anyhow!("Trino cannot parse String at position: ({}, {})", ridx, cidx)) + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_timestamp { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => NaiveDateTime::parse_from_str(x, "%Y-%m-%d %H:%M:%S%.f").map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?, + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => Some(NaiveDateTime::parse_from_str(x, "%Y-%m-%d %H:%M:%S%.f").map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}): {:?}", ridx, cidx, value))?), + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + )+ + }; +} + +macro_rules! impl_produce_bool { + ($($t: ty,)+) => { + $( + impl<'r, 'a> Produce<'r, $t> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> $t { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Bool(x) => *x, + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) + } + } + } + + impl<'r, 'a> Produce<'r, Option<$t>> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option<$t> { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::Bool(x) => Some(*x), + _ => throw!(anyhow!("Trino unknown value at position: ({}, {}): {:?}", ridx, cidx, value)) } } } @@ -430,6 +473,120 @@ macro_rules! impl_produce_text { }; } +impl_produce_bool!(bool,); impl_produce_int!(i8, i16, i32, i64,); impl_produce_float!(f32, f64,); -impl_produce_text!(NaiveDate, NaiveTime, NaiveDateTime, String, bool, char,); +impl_produce_timestamp!(NaiveDateTime,); +impl_produce_text!(String, char,); + +impl<'r, 'a> Produce<'r, NaiveTime> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> NaiveTime { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| { + anyhow!( + "Trino cannot parse String at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?, + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} + +impl<'r, 'a> Produce<'r, Option> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => { + Some(NaiveTime::parse_from_str(x, "%H:%M:%S%.f").map_err(|_| { + anyhow!( + "Trino cannot parse Time at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?) + } + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} + +impl<'r, 'a> Produce<'r, NaiveDate> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> NaiveDate { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::String(x) => NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| { + anyhow!( + "Trino cannot parse Date at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?, + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} + +impl<'r, 'a> Produce<'r, Option> for TrinoSourcePartitionParser<'a> { + type Error = TrinoSourceError; + + #[throws(TrinoSourceError)] + fn produce(&'r mut self) -> Option { + let (ridx, cidx) = self.next_loc()?; + let value = &self.rows[ridx].value()[cidx]; + + match value { + Value::Null => None, + Value::String(x) => Some(NaiveDate::parse_from_str(x, "%Y-%m-%d").map_err(|_| { + anyhow!( + "Trino cannot parse Date at position: ({}, {}): {:?}", + ridx, + cidx, + value + ) + })?), + _ => throw!(anyhow!( + "Trino unknown value at position: ({}, {}): {:?}", + ridx, + cidx, + value + )), + } + } +} diff --git a/connectorx/src/transports/mod.rs b/connectorx/src/transports/mod.rs index 30a209c255..96f90db44e 100644 --- a/connectorx/src/transports/mod.rs +++ b/connectorx/src/transports/mod.rs @@ -44,8 +44,12 @@ mod sqlite_arrow; mod sqlite_arrow2; #[cfg(all(feature = "src_sqlite", feature = "dst_arrow"))] mod sqlite_arrowstream; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +mod trino_arrow; #[cfg(all(feature = "src_trino", feature = "dst_arrow2"))] mod trino_arrow2; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +mod trino_arrowstream; #[cfg(all(feature = "src_bigquery", feature = "dst_arrow"))] pub use bigquery_arrow::{BigQueryArrowTransport, BigQueryArrowTransportError}; #[cfg(all(feature = "src_bigquery", feature = "dst_arrow2"))] @@ -106,5 +110,12 @@ pub use sqlite_arrowstream::{ SQLiteArrowTransport as SQLiteArrowStreamTransport, SQLiteArrowTransportError as SQLiteArrowStreamTransportError, }; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +pub use trino_arrow::{TrinoArrowTransport, TrinoArrowTransportError}; #[cfg(all(feature = "src_trino", feature = "dst_arrow2"))] pub use trino_arrow2::{TrinoArrow2Transport, TrinoArrow2TransportError}; +#[cfg(all(feature = "src_trino", feature = "dst_arrow"))] +pub use trino_arrowstream::{ + TrinoArrowTransport as TrinoArrowStreamTransport, + TrinoArrowTransportError as TrinoArrowStreamTransportError, +}; diff --git a/connectorx/src/transports/trino_arrow.rs b/connectorx/src/transports/trino_arrow.rs new file mode 100644 index 0000000000..d498fb6150 --- /dev/null +++ b/connectorx/src/transports/trino_arrow.rs @@ -0,0 +1,64 @@ +//! Transport from Trino Source to Arrow Destination. + +use crate::{ + destinations::arrow::{ + typesystem::ArrowTypeSystem, ArrowDestination, ArrowDestinationError, + }, + impl_transport, + sources::trino::{TrinoSource, TrinoSourceError, TrinoTypeSystem}, + typesystem::TypeConversion, +}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use num_traits::ToPrimitive; +use rust_decimal::Decimal; +use serde_json::{to_string, Value}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoArrowTransportError { + #[error(transparent)] + Source(#[from] TrinoSourceError), + + #[error(transparent)] + Destination(#[from] ArrowDestinationError), + + #[error(transparent)] + ConnectorX(#[from] crate::errors::ConnectorXError), +} + +/// Convert Trino data types to Arrow data types. +pub struct TrinoArrowTransport(); + +impl_transport!( + name = TrinoArrowTransport, + error = TrinoArrowTransportError, + systems = TrinoTypeSystem => ArrowTypeSystem, + route = TrinoSource => ArrowDestination, + mappings = { + { Date[NaiveDate] => Date32[NaiveDate] | conversion auto } + { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } + { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } + { Boolean[bool] => Boolean[bool] | conversion auto } + { Bigint[i32] => Int64[i64] | conversion auto } + { Integer[i32] => Int64[i64] | conversion none } + { Smallint[i16] => Int64[i64] | conversion auto } + { Tinyint[i8] => Int64[i64] | conversion auto } + { Double[f64] => Float64[f64] | conversion auto } + { Real[f32] => Float64[f64] | conversion auto } + { Varchar[String] => LargeUtf8[String] | conversion auto } + { Char[String] => LargeUtf8[String] | conversion none } + } +); + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Decimal) -> f64 { + val.to_f64() + .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) + } +} + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Value) -> String { + to_string(&val).unwrap() + } +} diff --git a/connectorx/src/transports/trino_arrow2.rs b/connectorx/src/transports/trino_arrow2.rs index 27290d4a90..bc31fe6460 100644 --- a/connectorx/src/transports/trino_arrow2.rs +++ b/connectorx/src/transports/trino_arrow2.rs @@ -35,7 +35,7 @@ impl_transport!( systems = TrinoTypeSystem => Arrow2TypeSystem, route = TrinoSource => Arrow2Destination, mappings = { - { Date[NaiveDate] => Date64[NaiveDate] | conversion auto } + { Date[NaiveDate] => Date32[NaiveDate] | conversion auto } { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } { Boolean[bool] => Boolean[bool] | conversion auto } diff --git a/connectorx/src/transports/trino_arrowstream.rs b/connectorx/src/transports/trino_arrowstream.rs new file mode 100644 index 0000000000..f2b9e220c1 --- /dev/null +++ b/connectorx/src/transports/trino_arrowstream.rs @@ -0,0 +1,64 @@ +//! Transport from Trino Source to Arrow Destination. + +use crate::{ + destinations::arrowstream::{ + typesystem::ArrowTypeSystem, ArrowDestination, ArrowDestinationError, + }, + impl_transport, + sources::trino::{TrinoSource, TrinoSourceError, TrinoTypeSystem}, + typesystem::TypeConversion, +}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use num_traits::ToPrimitive; +use rust_decimal::Decimal; +use serde_json::{to_string, Value}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TrinoArrowTransportError { + #[error(transparent)] + Source(#[from] TrinoSourceError), + + #[error(transparent)] + Destination(#[from] ArrowDestinationError), + + #[error(transparent)] + ConnectorX(#[from] crate::errors::ConnectorXError), +} + +/// Convert Trino data types to Arrow data types. +pub struct TrinoArrowTransport(); + +impl_transport!( + name = TrinoArrowTransport, + error = TrinoArrowTransportError, + systems = TrinoTypeSystem => ArrowTypeSystem, + route = TrinoSource => ArrowDestination, + mappings = { + { Date[NaiveDate] => Date32[NaiveDate] | conversion auto } + { Time[NaiveTime] => Time64[NaiveTime] | conversion auto } + { Timestamp[NaiveDateTime] => Date64[NaiveDateTime] | conversion auto } + { Boolean[bool] => Boolean[bool] | conversion auto } + { Bigint[i32] => Int64[i64] | conversion auto } + { Integer[i32] => Int64[i64] | conversion none } + { Smallint[i16] => Int64[i64] | conversion auto } + { Tinyint[i8] => Int64[i64] | conversion auto } + { Double[f64] => Float64[f64] | conversion auto } + { Real[f32] => Float64[f64] | conversion auto } + { Varchar[String] => LargeUtf8[String] | conversion auto } + { Char[String] => LargeUtf8[String] | conversion none } + } +); + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Decimal) -> f64 { + val.to_f64() + .unwrap_or_else(|| panic!("cannot convert decimal {:?} to float64", val)) + } +} + +impl TypeConversion for TrinoArrowTransport { + fn convert(val: Value) -> String { + to_string(&val).unwrap() + } +} From 5fca5cafb0967256cb1f6185fbffd9801a204625 Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 15 Mar 2024 13:41:24 +0100 Subject: [PATCH 04/10] update deps for Trino source --- connectorx-python/Cargo.lock | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/connectorx-python/Cargo.lock b/connectorx-python/Cargo.lock index a877cbce4c..58a27784a5 100644 --- a/connectorx-python/Cargo.lock +++ b/connectorx-python/Cargo.lock @@ -1854,6 +1854,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -2771,9 +2777,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "libsqlite3-sys" -version = "0.26.0" +version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" dependencies = [ "cc", "pkg-config", @@ -3901,7 +3907,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7915b33ed60abc46040cbcaa25ffa1c7ec240668e0477c4f3070786f5916d451" dependencies = [ "bytes", - "fallible-iterator", + "fallible-iterator 0.2.0", "futures-util", "log", "tokio", @@ -3943,7 +3949,7 @@ dependencies = [ "base64 0.21.3", "byteorder", "bytes", - "fallible-iterator", + "fallible-iterator 0.2.0", "hmac", "md-5", "memchr", @@ -3960,7 +3966,7 @@ checksum = "8d2234cdee9408b523530a9b6d2d6b373d1db34f6a8e51dc03ded1828d7fb67c" dependencies = [ "bytes", "chrono", - "fallible-iterator", + "fallible-iterator 0.2.0", "postgres-protocol", "serde", "serde_json", @@ -4248,9 +4254,9 @@ dependencies = [ [[package]] name = "r2d2_sqlite" -version = "0.22.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99f31323d6161385f385046738df520e0e8694fa74852d35891fc0be08348ddc" +checksum = "4dc290b669d30e20751e813517bbe13662d020419c5c8818ff10b6e8bb7777f6" dependencies = [ "r2d2", "rusqlite", @@ -4538,13 +4544,13 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.29.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" +checksum = "a78046161564f5e7cd9008aff3b2990b3850dc8e0349119b98e8f251e099f24d" dependencies = [ "bitflags 2.4.0", "chrono", - "fallible-iterator", + "fallible-iterator 0.3.0", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", @@ -5408,7 +5414,7 @@ dependencies = [ "async-trait", "byteorder", "bytes", - "fallible-iterator", + "fallible-iterator 0.2.0", "futures-channel", "futures-util", "log", From d3744001745fe5a09878977583d4c5b4334efa39 Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 12 Apr 2024 14:23:28 +0200 Subject: [PATCH 05/10] added tests for Trino and connecting without auth --- Justfile | 1 + connectorx/src/sources/trino/mod.rs | 24 ++++-- connectorx/tests/test_trino.rs | 112 ++++++++++++++++++++++++++++ scripts/trino.sql | 48 ++++++++++++ 4 files changed, 177 insertions(+), 8 deletions(-) create mode 100644 connectorx/tests/test_trino.rs create mode 100644 scripts/trino.sql diff --git a/Justfile b/Justfile index 6a2cb19548..c086d524a2 100644 --- a/Justfile +++ b/Justfile @@ -63,6 +63,7 @@ seed-db-more: ORACLE_URL_SCRIPT=`echo ${ORACLE_URL#oracle://} | sed "s/:/\//"` cat scripts/oracle.sql | sqlplus $ORACLE_URL_SCRIPT mysql --protocol tcp -h$MARIADB_HOST -P$MARIADB_PORT -u$MARIADB_USER -p$MARIADB_PASSWORD $MARIADB_DB < scripts/mysql.sql + trino $TRINO_URL --catalog=$TRINO_CATALOG < scripts/trino.sql # benches flame-tpch conn="POSTGRES_URL": diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index 6655a0c1f1..8a75245a1e 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -50,17 +50,25 @@ impl TrinoSource { .parse::() .map_err(TrinoSourceError::UrlParseError)?; - let client = ClientBuilder::new(url.username(), url.host().unwrap().to_owned()) + let username = match url.username() { + "" => "connectorx", + username => username, + }; + + let builder = ClientBuilder::new(username, url.host().unwrap().to_owned()) .port(url.port().unwrap_or(8080)) - .auth(Auth::Basic( - url.username().to_owned(), - url.password().map(|x| x.to_owned()), - )) .ssl(prusto::ssl::Ssl { root_cert: None }) .secure(url.scheme() == "trino+https") - .catalog(url.path_segments().unwrap().last().unwrap_or("hive")) - .build() - .map_err(TrinoSourceError::PrustoError)?; + .catalog(url.path_segments().unwrap().last().unwrap_or("hive")); + + let builder = match url.password() { + None => builder, + Some(password) => { + builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))) + } + }; + + let client = builder.build().map_err(TrinoSourceError::PrustoError)?; Self { client: Arc::new(client), diff --git a/connectorx/tests/test_trino.rs b/connectorx/tests/test_trino.rs new file mode 100644 index 0000000000..8aa7d5f1c4 --- /dev/null +++ b/connectorx/tests/test_trino.rs @@ -0,0 +1,112 @@ +use arrow::{ + array::{Float64Array, Int64Array}, + record_batch::RecordBatch, +}; +use connectorx::{ + destinations::arrow::ArrowDestination, prelude::*, sources::trino::TrinoSource, sql::CXQuery, + transports::TrinoArrowTransport, +}; +use std::{env, sync::Arc}; + +#[test] +fn test_trino() { + let _ = env_logger::builder().is_test(true).try_init(); + + let dburl = env::var("TRINO_URL").unwrap(); + + let queries = [ + CXQuery::naked("select * from test.test_table where test_int <= 2 order by test_int"), + CXQuery::naked("select * from test.test_table where test_int > 2 order by test_int"), + ]; + + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let builder = TrinoSource::new(rt, &dburl).unwrap(); + let mut destination = ArrowDestination::new(); + let dispatcher = Dispatcher::<_, _, TrinoArrowTransport>::new( + builder, + &mut destination, + &queries, + Some(String::from( + "select * from test.test_table order by test_int", + )), + ); + dispatcher.run().unwrap(); + + let result = destination.arrow().unwrap(); + verify_arrow_results(result); +} + +#[test] +fn test_trino_text() { + let _ = env_logger::builder().is_test(true).try_init(); + + let dburl = env::var("TRINO_URL").unwrap(); + + let queries = [ + CXQuery::naked("select * from test.test_table where test_int <= 2 order by test_int"), + CXQuery::naked("select * from test.test_table where test_int > 2 order by test_int"), + ]; + + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let builder = TrinoSource::new(rt, &dburl).unwrap(); + let mut destination = ArrowDestination::new(); + let dispatcher = + Dispatcher::<_, _, TrinoArrowTransport>::new(builder, &mut destination, &queries, None); + dispatcher.run().unwrap(); + + let result = destination.arrow().unwrap(); + verify_arrow_results(result); +} + +pub fn verify_arrow_results(result: Vec) { + assert!(result.len() == 2); + + for r in result { + match r.num_rows() { + 2 => { + assert!(r + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![1, 2]))); + assert!(r + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Float64Array::from(vec![1.1, 2.2]))); + assert!(r + .column(2) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![None, None]))); + } + 4 => { + assert!(r + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![3, 4, 5, 6]))); + assert!(r + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Float64Array::from(vec![3.3, 4.4, 5.5, 6.6]))); + assert!(r + .column(2) + .as_any() + .downcast_ref::() + .unwrap() + .eq(&Int64Array::from(vec![None, None, None, None]))); + } + _ => { + println!("got {} rows in a record batch!", r.num_rows()); + unreachable!() + } + } + } +} diff --git a/scripts/trino.sql b/scripts/trino.sql new file mode 100644 index 0000000000..8f7482c677 --- /dev/null +++ b/scripts/trino.sql @@ -0,0 +1,48 @@ +CREATE SCHEMA IF NOT EXISTS test; + +CREATE TABLE IF NOT EXISTS test.test_table( + test_int INTEGER, + test_float DOUBLE, + test_null INTEGER +); + +DELETE FROM test.test_table; +INSERT INTO test.test_table VALUES (1, 1.1, NULL); +INSERT INTO test.test_table VALUES (2, 2.2, NULL); +INSERT INTO test.test_table VALUES (3, 3.3, NULL); +INSERT INTO test.test_table VALUES (4, 4.4, NULL); +INSERT INTO test.test_table VALUES (5, 5.5, NULL); +INSERT INTO test.test_table VALUES (6, 6.6, NULL); + +DROP TABLE IF EXISTS test.test_table_extra; + +CREATE TABLE IF NOT EXISTS test.test_table_extra( + test_int INTEGER, + test_str VARCHAR(30) +); + +DELETE FROM test.test_table_extra; +INSERT INTO test.test_table_extra VALUES (1, 'Ha好ち😁ðy̆'); +INSERT INTO test.test_table_extra VALUES (2, 'こんにちは'); +INSERT INTO test.test_table_extra VALUES (3, 'русский'); + +DROP TABLE IF EXISTS test.test_types; + +CREATE TABLE IF NOT EXISTS test.test_types( + test_boolean BOOLEAN, + test_int INT, + test_bigint BIGINT, + test_real REAL, + test_double DOUBLE, + test_decimal DECIMAL(15,2), + test_date DATE, + test_time TIME(6), + test_timestamp TIMESTAMP(6), + test_varchar VARCHAR(15), + test_uuid UUID -- TODO: VARBINARY, ROW, ARRAY, MAP +); + +DELETE FROM test.test_types; +INSERT INTO test.test_types (test_boolean, test_int, test_bigint, test_real, test_double, test_decimal, test_date, test_time, test_timestamp, test_varchar, test_uuid) VALUES +(TRUE, 123, 123456789012345, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', UUID()), +(FALSE, 321, 123456789012345, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', UUID()); From 3099395faa2a471a00eda651d1e94c5ece2ec8aa Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 12 Apr 2024 15:10:56 +0200 Subject: [PATCH 06/10] fixed copy/paste for trino tests --- .../connectorx/tests/test_trino.py | 294 ++++-------------- scripts/trino.sql | 5 +- 2 files changed, 71 insertions(+), 228 deletions(-) diff --git a/connectorx-python/connectorx/tests/test_trino.py b/connectorx-python/connectorx/tests/test_trino.py index 9376bf5058..12d6c8d9af 100644 --- a/connectorx-python/connectorx/tests/test_trino.py +++ b/connectorx-python/connectorx/tests/test_trino.py @@ -8,31 +8,29 @@ @pytest.fixture(scope="module") # type: ignore -def mysql_url() -> str: - conn = os.environ["MYSQL_URL"] - # conn = os.environ["MARIADB_URL"] +def trino_url() -> str: + conn = os.environ["TRINO_URL"] return conn -def test_mysql_without_partition(mysql_url: str) -> None: - query = "select * from test_table limit 3" - df = read_sql(mysql_url, query) +def test_trino_without_partition(trino_url: str) -> None: + query = "select * from test.test_table order by test_int limit 3" + df = read_sql(trino_url, query) expected = pd.DataFrame( index=range(3), data={ "test_int": pd.Series([1, 2, 3], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), - "test_enum": pd.Series(["odd", "even", "odd"], dtype="object"), "test_null": pd.Series([None, None, None], dtype="Int64"), }, ) assert_frame_equal(df, expected, check_names=True) -def test_mysql_with_partition(mysql_url: str) -> None: - query = "select * from test_table" +def test_trino_with_partition(trino_url: str) -> None: + query = "select * from test.test_table order by test_int" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_range=(0, 10), @@ -43,9 +41,6 @@ def test_mysql_with_partition(mysql_url: str) -> None: data={ "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series( - ["odd", "even", "odd", "even", "odd", "even"], dtype="object" - ), "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), }, ) @@ -53,59 +48,52 @@ def test_mysql_with_partition(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_without_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table" - df = read_sql(mysql_url, query) +def test_trino_without_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int" + df = read_sql(trino_url, query) expected = pd.DataFrame( index=range(6), data={ "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series( - ["odd", "even", "odd", "even", "odd", "even"], dtype="object" - ), "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), }, ) assert_frame_equal(df, expected, check_names=True) -def test_mysql_limit_without_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table limit 3" - df = read_sql(mysql_url, query) +def test_trino_limit_without_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 3" + df = read_sql(trino_url, query) expected = pd.DataFrame( index=range(3), data={ "test_int": pd.Series([1, 2, 3], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), - "test_enum": pd.Series(["odd", "even", "odd"], dtype="object"), "test_null": pd.Series([None, None, None], dtype="Int64"), }, ) assert_frame_equal(df, expected, check_names=True) -def test_mysql_limit_large_without_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table limit 10" - df = read_sql(mysql_url, query) +def test_trino_limit_large_without_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 10" + df = read_sql(trino_url, query) expected = pd.DataFrame( index=range(6), data={ "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series( - ["odd", "even", "odd", "even", "odd", "even"], dtype="object" - ), "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), }, ) assert_frame_equal(df, expected, check_names=True) -def test_mysql_with_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table" +def test_trino_with_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_range=(0, 2000), @@ -116,9 +104,6 @@ def test_mysql_with_partition(mysql_url: str) -> None: data={ "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series( - ["odd", "even", "odd", "even", "odd", "even"], dtype="object" - ), "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), }, ) @@ -126,10 +111,10 @@ def test_mysql_with_partition(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_limit_with_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table limit 3" +def test_trino_limit_with_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 3" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_range=(0, 2000), @@ -140,7 +125,6 @@ def test_mysql_limit_with_partition(mysql_url: str) -> None: data={ "test_int": pd.Series([1, 2, 3], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3], dtype="float64"), - "test_enum": pd.Series(["odd", "even", "odd"], dtype="object"), "test_null": pd.Series([None, None, None], dtype="Int64"), }, ) @@ -148,10 +132,10 @@ def test_mysql_limit_with_partition(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_limit_large_with_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table limit 10" +def test_trino_limit_large_with_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table order by test_int limit 10" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_range=(0, 2000), @@ -162,9 +146,6 @@ def test_mysql_limit_large_with_partition(mysql_url: str) -> None: data={ "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series( - ["odd", "even", "odd", "even", "odd", "even"], dtype="object" - ), "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), }, ) @@ -172,10 +153,10 @@ def test_mysql_limit_large_with_partition(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_with_partition_without_partition_range(mysql_url: str) -> None: - query = "SELECT * FROM test_table where test_float > 3" +def test_trino_with_partition_without_partition_range(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_float > 3 order by test_int" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_num=3, @@ -185,7 +166,6 @@ def test_mysql_with_partition_without_partition_range(mysql_url: str) -> None: data={ "test_int": pd.Series([3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series(["odd", "even", "odd", "even"], dtype="object"), "test_null": pd.Series([None, None, None, None], dtype="Int64"), }, ) @@ -193,20 +173,17 @@ def test_mysql_with_partition_without_partition_range(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_manual_partition(mysql_url: str) -> None: +def test_trino_manual_partition(trino_url: str) -> None: queries = [ - "SELECT * FROM test_table WHERE test_int < 2", - "SELECT * FROM test_table WHERE test_int >= 2", + "SELECT * FROM test.test_table WHERE test_int < 2 order by test_int", + "SELECT * FROM test.test_table WHERE test_int >= 2 order by test_int", ] - df = read_sql(mysql_url, query=queries) + df = read_sql(trino_url, query=queries) expected = pd.DataFrame( index=range(6), data={ "test_int": pd.Series([1, 2, 3, 4, 5, 6], dtype="Int64"), "test_float": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5, 6.6], dtype="float64"), - "test_enum": pd.Series( - ["odd", "even", "odd", "even", "odd", "even"], dtype="object" - ), "test_null": pd.Series([None, None, None, None, None, None], dtype="Int64"), }, ) @@ -214,10 +191,10 @@ def test_mysql_manual_partition(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_selection_and_projection(mysql_url: str) -> None: - query = "SELECT test_int FROM test_table WHERE test_float < 5" +def test_trino_selection_and_projection(trino_url: str) -> None: + query = "SELECT test_int FROM test.test_table WHERE test_float < 5 order by test_int" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_num=3, @@ -232,10 +209,10 @@ def test_mysql_selection_and_projection(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_join(mysql_url: str) -> None: - query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int" +def test_trino_join(trino_url: str) -> None: + query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int order by T.test_int" df = read_sql( - mysql_url, + trino_url, query, partition_on="test_int", partition_num=3, @@ -259,212 +236,77 @@ def test_mysql_join(mysql_url: str) -> None: assert_frame_equal(df, expected, check_names=True) -def test_mysql_aggregate(mysql_url: str) -> None: - query = "select AVG(test_float) as avg_float, SUM(T.test_int) as sum_int, SUM(test_null) as sum_null from test_table as T INNER JOIN test_table_extra as S where T.test_int = S.test_int GROUP BY test_enum ORDER BY sum_int" - df = read_sql(mysql_url, query) - expected = pd.DataFrame( - index=range(2), - data={ - "avg_float": pd.Series([2.2, 2.2], dtype="float64"), - "sum_int": pd.Series([2.0, 4.0], dtype="float64"), - "sum_null": pd.Series([None, None], dtype="float64"), - }, - ) - assert_frame_equal(df, expected, check_names=True) - - -def test_mysql_types_binary(mysql_url: str) -> None: - query = "select * from test_types" - df = read_sql(mysql_url, query, protocol="binary") +def test_trino_aggregate(trino_url: str) -> None: + query = "select AVG(test_float) as avg_float, SUM(T.test_int) as sum_int, SUM(test_null) as sum_null from test.test_table as T" + df = read_sql(trino_url, query) expected = pd.DataFrame( - index=range(3), + index=range(1), data={ - "test_timestamp": pd.Series( - ["1970-01-01 00:00:01", "2038-01-19 00:00:00", None], - dtype="datetime64[ns]", - ), - "test_date": pd.Series( - [None, "1970-01-01", "2038-01-19"], dtype="datetime64[ns]" - ), - "test_time": pd.Series(["00:00:00", None, "23:59:59"], dtype="object"), - "test_datetime": pd.Series( - ["1970-01-01 00:00:01", "2038-01-19 00:0:00", None], - dtype="datetime64[ns]", - ), - "test_new_decimal": pd.Series([1.1, None, 3.3], dtype="float"), - "test_decimal": pd.Series([1, 2, None], dtype="float"), - "test_varchar": pd.Series([None, "varchar2", "varchar3"], dtype="object"), - "test_char": pd.Series(["char1", None, "char3"], dtype="object"), - "test_tiny": pd.Series([-128, 127, None], dtype="Int64"), - "test_short": pd.Series([-32768, 32767, None], dtype="Int64"), - "test_int24": pd.Series([-8388608, 8388607, None], dtype="Int64"), - "test_long": pd.Series([-2147483648, 2147483647, None], dtype="Int64"), - "test_longlong": pd.Series( - [-9223372036854775808, 9223372036854775807, None], dtype="Int64" - ), - "test_tiny_unsigned": pd.Series([None, 255, 0], dtype="Int64"), - "test_short_unsigned": pd.Series([None, 65535, 0], dtype="Int64"), - "test_int24_unsigned": pd.Series([None, 16777215, 0], dtype="Int64"), - "test_long_unsigned": pd.Series([None, 4294967295, 0], dtype="Int64"), - "test_longlong_unsigned": pd.Series( - [None, 18446744070000001024.0, 0.0], dtype="float" - ), - "test_long_notnull": pd.Series([1, 2147483647, -2147483648], dtype="int64"), - "test_short_unsigned_notnull": pd.Series([1, 65535, 0], dtype="int64"), - "test_float": pd.Series([None, -1.1e-38, 3.4e38], dtype="float"), - "test_double": pd.Series([-2.2e-308, None, 1.7e308], dtype="float"), - "test_double_notnull": pd.Series([1.2345, -1.1e-3, 1.7e30], dtype="float"), - "test_year": pd.Series([1901, 2155, None], dtype="Int64"), - "test_tinyblob": pd.Series( - [None, b"tinyblob2", b"tinyblob3"], dtype="object" - ), - "test_blob": pd.Series( - [None, b"blobblobblobblob2", b"blobblobblobblob3"], dtype="object" - ), - "test_mediumblob": pd.Series( - [None, b"mediumblob2", b"mediumblob3"], dtype="object" - ), - "test_longblob": pd.Series( - [None, b"longblob2", b"longblob3"], dtype="object" - ), - "test_enum": pd.Series(["apple", None, "mango"], dtype="object"), - "test_json": pd.Series( - ['{"age":1,"name":"piggy"}', '{"age":2,"name":"kitty"}', None], - # mariadb - # [b'{"name": "piggy", "age": 1}', b'{"name": "kitty", "age": 2}', None], - dtype="object", - ), - "test_mediumtext": pd.Series( - [None, b"", b"medium text!!!!"], dtype="object" - ), + "avg_float": pd.Series([3.85], dtype="float64"), + "sum_int": pd.Series([21], dtype="Int64"), + "sum_null": pd.Series([None], dtype="Int64"), }, ) assert_frame_equal(df, expected, check_names=True) -def test_mysql_types_text(mysql_url: str) -> None: - query = "select * from test_types" - df = read_sql(mysql_url, query, protocol="text") +def test_trino_types_binary(trino_url: str) -> None: + query = "select test_boolean, test_int, test_bigint, test_real, test_double, test_decimal, test_date, test_time, test_timestamp, test_varchar, test_uuid from test.test_types order by test_int" + df = read_sql(trino_url, query) expected = pd.DataFrame( index=range(3), data={ - "test_timestamp": pd.Series( - ["1970-01-01 00:00:01", "2038-01-19 00:00:00", None], - dtype="datetime64[ns]", - ), - "test_date": pd.Series( - [None, "1970-01-01", "2038-01-19"], dtype="datetime64[ns]" - ), - "test_time": pd.Series(["00:00:00", None, "23:59:59"], dtype="object"), - "test_datetime": pd.Series( - ["1970-01-01 00:00:01", "2038-01-19 00:00:00", None], - dtype="datetime64[ns]", - ), - "test_new_decimal": pd.Series([1.1, None, 3.3], dtype="float"), - "test_decimal": pd.Series([1, 2, None], dtype="float"), - "test_varchar": pd.Series([None, "varchar2", "varchar3"], dtype="object"), - "test_char": pd.Series(["char1", None, "char3"], dtype="object"), - "test_tiny": pd.Series([-128, 127, None], dtype="Int64"), - "test_short": pd.Series([-32768, 32767, None], dtype="Int64"), - "test_int24": pd.Series([-8388608, 8388607, None], dtype="Int64"), - "test_long": pd.Series([-2147483648, 2147483647, None], dtype="Int64"), - "test_longlong": pd.Series( - [-9223372036854775808, 9223372036854775807, None], dtype="Int64" - ), - "test_tiny_unsigned": pd.Series([None, 255, 0], dtype="Int64"), - "test_short_unsigned": pd.Series([None, 65535, 0], dtype="Int64"), - "test_int24_unsigned": pd.Series([None, 16777215, 0], dtype="Int64"), - "test_long_unsigned": pd.Series([None, 4294967295, 0], dtype="Int64"), - "test_longlong_unsigned": pd.Series( - [None, 18446744070000001024.0, 0.0], dtype="float" - ), - "test_long_notnull": pd.Series([1, 2147483647, -2147483648], dtype="int64"), - "test_short_unsigned_notnull": pd.Series([1, 65535, 0], dtype="int64"), - "test_float": pd.Series([None, -1.1e-38, 3.4e38], dtype="float"), - "test_double": pd.Series([-2.2e-308, None, 1.7e308], dtype="float"), - "test_double_notnull": pd.Series([1.2345, -1.1e-3, 1.7e30], dtype="float"), - "test_year": pd.Series([1901, 2155, None], dtype="Int64"), - "test_tinyblob": pd.Series( - [None, b"tinyblob2", b"tinyblob3"], dtype="object" - ), - "test_blob": pd.Series( - [None, b"blobblobblobblob2", b"blobblobblobblob3"], dtype="object" - ), - "test_mediumblob": pd.Series( - [None, b"mediumblob2", b"mediumblob3"], dtype="object" - ), - "test_longblob": pd.Series( - [None, b"longblob2", b"longblob3"], dtype="object" - ), - "test_enum": pd.Series(["apple", None, "mango"], dtype="object"), - "test_json": pd.Series( - ['{"age":1,"name":"piggy"}', '{"age":2,"name":"kitty"}', None], - # mariadb - # [b'{"name": "piggy", "age": 1}', b'{"name": "kitty", "age": 2}', None], - dtype="object", - ), - "test_mediumtext": pd.Series( - [None, b"", b"medium text!!!!"], dtype="object" - ), + "test_boolean": pd.Series([True, False, None], dtype="boolean"), + "test_int": pd.Series([123, 321, None], dtype="Int64"), + "test_bigint": pd.Series([1000, 2000, None], dtype="Int64"), + "test_real": pd.Series([123.456, 123.456, None], dtype="float64"), + "test_double": pd.Series([123.4567890123, 123.4567890123, None], dtype="float64"), + "test_decimal": pd.Series([1234567890.12, 1234567890.12, None], dtype="float64"), + "test_date": pd.Series([None, "2023-01-01", "2023-01-01"], dtype="datetime64[ns]"), + "test_time": pd.Series(["12:00:00", "12:00:00", None], dtype="object"), + "test_timestamp": pd.Series(["2023-01-01 12:00:00.123456", "2023-01-01 12:00:00.123456", None], dtype="datetime64[ns]"), + "test_varchar": pd.Series(["Sample text", "Sample text", None], dtype="object"), + "test_uuid": pd.Series(["f4967dbb-33e9-4242-a13a-45b56ce60dba", "1c8b79d0-4508-4974-b728-7651bce4a5a5", None], dtype="object"), }, ) assert_frame_equal(df, expected, check_names=True) -def test_empty_result(mysql_url: str) -> None: - query = "SELECT * FROM test_table where test_int < -100" - df = read_sql(mysql_url, query) +def test_empty_result(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_int < -100" + df = read_sql(trino_url, query) expected = pd.DataFrame( data={ "test_int": pd.Series([], dtype="Int64"), "test_float": pd.Series([], dtype="float64"), - "test_enum": pd.Series([], dtype="object"), "test_null": pd.Series([], dtype="Int64"), } ) assert_frame_equal(df, expected, check_names=True) -def test_empty_result_on_partition(mysql_url: str) -> None: - query = "SELECT * FROM test_table where test_int < -100" - df = read_sql(mysql_url, query, partition_on="test_int", partition_num=3) +def test_empty_result_on_partition(trino_url: str) -> None: + query = "SELECT * FROM test.test_table where test_int < -100" + df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) expected = pd.DataFrame( data={ "test_int": pd.Series([], dtype="Int64"), "test_float": pd.Series([], dtype="float64"), - "test_enum": pd.Series([], dtype="object"), "test_null": pd.Series([], dtype="Int64"), } ) assert_frame_equal(df, expected, check_names=True) -def test_empty_result_on_some_partition(mysql_url: str) -> None: +def test_empty_result_on_some_partition(trino_url: str) -> None: query = "SELECT * FROM test_table where test_int = 6" - df = read_sql(mysql_url, query, partition_on="test_int", partition_num=3) + df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) expected = pd.DataFrame( index=range(1), data={ "test_int": pd.Series([6], dtype="Int64"), "test_float": pd.Series([6.6], dtype="float64"), - "test_enum": pd.Series(["even"], dtype="object"), "test_null": pd.Series([None], dtype="Int64"), }, ) assert_frame_equal(df, expected, check_names=True) - - -def test_mysql_cte(mysql_url: str) -> None: - query = "with test_cte (test_int, test_enum) as (select test_int, test_enum from test_table where test_float > 2) select test_int, test_enum from test_cte" - df = read_sql(mysql_url, query, partition_on="test_int", partition_num=3) - df.sort_values(by="test_int", inplace=True, ignore_index=True) - expected = pd.DataFrame( - index=range(5), - data={ - "test_int": pd.Series([2, 3, 4, 5, 6], dtype="Int64"), - "test_enum": pd.Series( - ["even", "odd", "even", "odd", "even"], dtype="object" - ), - }, - ) - assert_frame_equal(df, expected, check_names=True) diff --git a/scripts/trino.sql b/scripts/trino.sql index 8f7482c677..643984dec2 100644 --- a/scripts/trino.sql +++ b/scripts/trino.sql @@ -44,5 +44,6 @@ CREATE TABLE IF NOT EXISTS test.test_types( DELETE FROM test.test_types; INSERT INTO test.test_types (test_boolean, test_int, test_bigint, test_real, test_double, test_decimal, test_date, test_time, test_timestamp, test_varchar, test_uuid) VALUES -(TRUE, 123, 123456789012345, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', UUID()), -(FALSE, 321, 123456789012345, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', UUID()); +(TRUE, 123, 1000, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', CAST('f4967dbb-33e9-4242-a13a-45b56ce60dba' AS UUID)), +(FALSE, 321, 2000, CAST(123.456 AS REAL), CAST(123.4567890123 AS DOUBLE), 1234567890.12, date('2023-01-01'), time '12:00:00', cast(timestamp '2023-01-01 12:00:00.123456' AS timestamp(6)), 'Sample text', CAST('1c8b79d0-4508-4974-b728-7651bce4a5a5' AS UUID)), +(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL); From 6979c653dd11099a60e9e8d36c04a8a199849107 Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 19 Apr 2024 15:34:23 +0200 Subject: [PATCH 07/10] implemented partitioning for Trino --- Cargo.lock | 9 +-- connectorx-python/Cargo.lock | 1 + .../connectorx/tests/test_trino.py | 9 +-- connectorx-python/src/pandas/get_meta.rs | 14 ++++- connectorx/Cargo.toml | 3 +- connectorx/src/partition.rs | 58 ++++++++++++++++++- connectorx/src/sources/trino/mod.rs | 44 ++++++++++---- connectorx/src/sources/trino/typesystem.rs | 4 +- 8 files changed, 119 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2e196209bc..156a4fb3b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1015,6 +1015,7 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", + "serde", "serde_json", "sqlparser 0.37.0", "thiserror", @@ -4523,9 +4524,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.195" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" dependencies = [ "serde_derive", ] @@ -4542,9 +4543,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", diff --git a/connectorx-python/Cargo.lock b/connectorx-python/Cargo.lock index b7b7170ebf..4303ca0183 100644 --- a/connectorx-python/Cargo.lock +++ b/connectorx-python/Cargo.lock @@ -1058,6 +1058,7 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", + "serde", "serde_json", "sqlparser 0.37.0", "thiserror", diff --git a/connectorx-python/connectorx/tests/test_trino.py b/connectorx-python/connectorx/tests/test_trino.py index 12d6c8d9af..5c883fee6c 100644 --- a/connectorx-python/connectorx/tests/test_trino.py +++ b/connectorx-python/connectorx/tests/test_trino.py @@ -154,7 +154,7 @@ def test_trino_limit_large_with_partition(trino_url: str) -> None: def test_trino_with_partition_without_partition_range(trino_url: str) -> None: - query = "SELECT * FROM test.test_table where test_float > 3 order by test_int" + query = "SELECT * FROM test.test_table where test_float > 3" df = read_sql( trino_url, query, @@ -170,6 +170,7 @@ def test_trino_with_partition_without_partition_range(trino_url: str) -> None: }, ) df.sort_values(by="test_int", inplace=True, ignore_index=True) + assert_frame_equal(df, expected, check_names=True) @@ -210,7 +211,7 @@ def test_trino_selection_and_projection(trino_url: str) -> None: def test_trino_join(trino_url: str) -> None: - query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int order by T.test_int" + query = "SELECT T.test_int, T.test_float, S.test_str FROM test.test_table T INNER JOIN test.test_table_extra S ON T.test_int = S.test_int order by T.test_int" df = read_sql( trino_url, query, @@ -262,7 +263,7 @@ def test_trino_types_binary(trino_url: str) -> None: "test_real": pd.Series([123.456, 123.456, None], dtype="float64"), "test_double": pd.Series([123.4567890123, 123.4567890123, None], dtype="float64"), "test_decimal": pd.Series([1234567890.12, 1234567890.12, None], dtype="float64"), - "test_date": pd.Series([None, "2023-01-01", "2023-01-01"], dtype="datetime64[ns]"), + "test_date": pd.Series(["2023-01-01", "2023-01-01", None], dtype="datetime64[ns]"), "test_time": pd.Series(["12:00:00", "12:00:00", None], dtype="object"), "test_timestamp": pd.Series(["2023-01-01 12:00:00.123456", "2023-01-01 12:00:00.123456", None], dtype="datetime64[ns]"), "test_varchar": pd.Series(["Sample text", "Sample text", None], dtype="object"), @@ -299,7 +300,7 @@ def test_empty_result_on_partition(trino_url: str) -> None: def test_empty_result_on_some_partition(trino_url: str) -> None: - query = "SELECT * FROM test_table where test_int = 6" + query = "SELECT * FROM test.test_table where test_int = 6" df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) expected = pd.DataFrame( index=range(1), diff --git a/connectorx-python/src/pandas/get_meta.rs b/connectorx-python/src/pandas/get_meta.rs index bc5e7de950..7ee648e7d1 100644 --- a/connectorx-python/src/pandas/get_meta.rs +++ b/connectorx-python/src/pandas/get_meta.rs @@ -2,7 +2,7 @@ use super::{ destination::PandasDestination, transports::{ BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport, - PostgresPandasTransport, SqlitePandasTransport, + PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport, }, }; use crate::errors::ConnectorXPythonError; @@ -18,6 +18,7 @@ use connectorx::{ PostgresSource, SimpleProtocol, }, sqlite::SQLiteSource, + trino::TrinoSource, }, sql::CXQuery, }; @@ -223,6 +224,17 @@ pub fn get_meta<'a>(py: Python<'a>, conn: &str, protocol: &str, query: String) - debug!("Running dispatcher"); dispatcher.get_meta()?; } + SourceType::Trino => { + let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime")); + let source = TrinoSource::new(rt, &source_conn.conn[..])?; + let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new( + source, + &mut destination, + queries, + None, + ); + dispatcher.run()?; + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), } diff --git a/connectorx/Cargo.toml b/connectorx/Cargo.toml index 44409c2cb8..f7b200e39e 100644 --- a/connectorx/Cargo.toml +++ b/connectorx/Cargo.toml @@ -58,6 +58,7 @@ uuid = {version = "0.8", optional = true} j4rs = {version = "0.15", optional = true} datafusion = {version = "31", optional = true} prusto = {version = "0.5.1", optional = true} +serde = {optional = true} [lib] crate-type = ["cdylib", "rlib"] @@ -98,7 +99,7 @@ src_postgres = [ "postgres-openssl", ] src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"] -src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits"] +src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits", "serde"] federation = ["j4rs"] fed_exec = ["datafusion", "tokio"] integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] diff --git a/connectorx/src/partition.rs b/connectorx/src/partition.rs index 370120e2ab..fedd34fe71 100644 --- a/connectorx/src/partition.rs +++ b/connectorx/src/partition.rs @@ -10,6 +10,7 @@ use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem}; use crate::sources::oracle::{connect_oracle, OracleDialect}; #[cfg(feature = "src_postgres")] use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem}; +use crate::sources::trino::TrinoDialect; #[cfg(feature = "src_sqlite")] use crate::sql::get_partition_range_query_sep; use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery}; @@ -35,7 +36,7 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::dialect::SQLiteDialect; #[cfg(feature = "src_mssql")] use tiberius::Client; -#[cfg(any(feature = "src_bigquery", feature = "src_mssql"))] +#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))] use tokio::{net::TcpStream, runtime::Runtime}; #[cfg(feature = "src_mssql")] use tokio_util::compat::TokioAsyncWriteCompatExt; @@ -100,6 +101,8 @@ pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutRes SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col), #[cfg(feature = "src_bigquery")] SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col), + #[cfg(feature = "src_trino")] + SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col), _ => unimplemented!("{:?} not implemented!", source_conn.ty), } } @@ -137,6 +140,10 @@ pub fn get_part_query( SourceType::BigQuery => { single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})? } + #[cfg(feature = "src_trino")] + SourceType::Trino => { + single_col_partition_query(query, col, lower, upper, &TrinoDialect {})? + } _ => unimplemented!("{:?} not implemented!", source_conn.ty), }; CXQuery::Wrapped(query) @@ -481,3 +488,52 @@ fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64 (min_v, max_v) } + +#[cfg(feature = "src_trino")] +#[throws(ConnectorXOutError)] +fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) { + use prusto::{auth::Auth, ClientBuilder}; + + use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult}; + + let rt = Runtime::new().expect("Failed to create runtime"); + + let username = match conn.username() { + "" => "connectorx", + username => username, + }; + + let builder = ClientBuilder::new(username, conn.host().unwrap().to_owned()) + .port(conn.port().unwrap_or(8080)) + .ssl(prusto::ssl::Ssl { root_cert: None }) + .secure(conn.scheme() == "trino+https") + .catalog(conn.path_segments().unwrap().last().unwrap_or("hive")); + + let builder = match conn.password() { + None => builder, + Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))), + }; + + let client = builder + .build() + .map_err(|e| anyhow!("Failed to build client: {}", e))?; + + let range_query = get_partition_range_query(query, col, &TrinoDialect {})?; + let query_result = rt.block_on(client.get_all::(range_query)); + + let query_result = match query_result { + Ok(query_result) => Ok(query_result.into_vec()), + Err(e) => match e { + prusto::error::Error::EmptyData => { + Ok(vec![TrinoPartitionQueryResult { _col0: 0, _col1: 0 }]) + } + _ => Err(anyhow!("Failed to get query result: {}", e)), + }, + }?; + + let result = query_result + .first() + .unwrap_or(&TrinoPartitionQueryResult { _col0: 0, _col1: 0 }); + + (result._col0, result._col1) +} diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index 8a75245a1e..4f463fb282 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -4,7 +4,7 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use fehler::{throw, throws}; use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row}; use serde_json::Value; -use sqlparser::dialect::GenericDialect; +use sqlparser::dialect::{Dialect, GenericDialect}; use std::convert::TryFrom; use tokio::runtime::Runtime; @@ -32,6 +32,26 @@ fn get_total_rows(rt: Arc, client: Arc, query: &CXQuery .len() } +#[derive(Presto, Debug)] +pub struct TrinoPartitionQueryResult { + pub _col0: i64, + pub _col1: i64, +} + +#[derive(Debug)] +pub struct TrinoDialect {} + +// implementation copy from AnsiDialect +impl Dialect for TrinoDialect { + fn is_identifier_start(&self, ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_uppercase() + } + + fn is_identifier_part(&self, ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_' + } +} + pub struct TrinoSource { client: Arc, rt: Arc, @@ -282,12 +302,12 @@ macro_rules! impl_produce_int { match value { Value::Number(x) => { if (x.is_i64()) { - <$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))? + <$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))? } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } @@ -304,12 +324,12 @@ macro_rules! impl_produce_int { Value::Null => None, Value::Number(x) => { if (x.is_i64()) { - Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?) + Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?) } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } @@ -333,10 +353,11 @@ macro_rules! impl_produce_float { if (x.is_f64()) { x.as_f64().unwrap() as $t } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + Value::String(x) => x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?, + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } @@ -355,10 +376,11 @@ macro_rules! impl_produce_float { if (x.is_f64()) { Some(x.as_f64().unwrap() as $t) } else { - throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x)) } } - _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx)) + Value::String(x) => Some(x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?), + _ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value)) } } } diff --git a/connectorx/src/sources/trino/typesystem.rs b/connectorx/src/sources/trino/typesystem.rs index d5a492f038..f739c9a42d 100644 --- a/connectorx/src/sources/trino/typesystem.rs +++ b/connectorx/src/sources/trino/typesystem.rs @@ -1,7 +1,7 @@ use super::errors::TrinoSourceError; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use fehler::{throw, throws}; -use prusto::{PrestoFloat, PrestoInt, PrestoTy}; +use prusto::{Presto, PrestoFloat, PrestoInt, PrestoTy}; use std::convert::TryFrom; // TODO: implement Tuple, Row, Array and Map as well as UUID @@ -64,6 +64,7 @@ impl TryFrom for TrinoTypeSystem { PrestoTy::Map(_, _) => Varchar(true), PrestoTy::Decimal(_, _) => Double(true), PrestoTy::IpAddress => Varchar(true), + PrestoTy::Uuid => Varchar(true), _ => throw!(TrinoSourceError::InferTypeFromNull), } } @@ -97,6 +98,7 @@ impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem { "map" => Varchar(true), "decimal" => Double(true), "ipaddress" => Varchar(true), + "uuid" => Varchar(true), _ => TrinoTypeSystem::try_from(ty)?, } } From 52da49ac0d2ddee458cc3279a20f5079b8b754f4 Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Fri, 19 Apr 2024 16:14:35 +0200 Subject: [PATCH 08/10] fetch results for Trino more efficiently --- connectorx/src/sources/trino/mod.rs | 39 ++++++++++++++++++---- connectorx/src/sources/trino/typesystem.rs | 4 +-- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index 4f463fb282..f0bdfe6d30 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -129,7 +129,6 @@ where fn fetch_metadata(&mut self) { assert!(!self.queries.is_empty()); - // TODO: prevent from running the same query multiple times (limit1 + no limit) let first_query = &self.queries[0]; let cxq = limit1_query(first_query, &GenericDialect {})?; @@ -238,6 +237,9 @@ impl SourcePartition for TrinoSourcePartition { } pub struct TrinoSourcePartitionParser<'a> { + rt: Arc, + client: Arc, + next_uri: Option, rows: Vec, ncols: usize, current_col: usize, @@ -253,11 +255,19 @@ impl<'a> TrinoSourcePartitionParser<'a> { query: CXQuery, schema: &[TrinoTypeSystem], ) -> Self { - let rows = client.get_all::(query.to_string()); - let data = rt.block_on(rows).map_err(TrinoSourceError::PrustoError)?; - let rows = data.clone().into_vec(); + let results = rt + .block_on(client.get::(query.to_string())) + .map_err(TrinoSourceError::PrustoError)?; + + let rows = match results.data_set { + Some(x) => x.into_vec(), + _ => vec![], + }; Self { + rt, + client, + next_uri: results.next_uri, rows, ncols: schema.len(), current_row: 0, @@ -283,8 +293,25 @@ impl<'a> PartitionParser<'a> for TrinoSourcePartitionParser<'a> { fn fetch_next(&mut self) -> (usize, bool) { assert!(self.current_col == 0); - // results are always fetched in a single batch for Prusto - (self.rows.len(), true) + match self.next_uri.clone() { + Some(uri) => { + let results = self + .rt + .block_on(self.client.get_next::(&uri)) + .map_err(TrinoSourceError::PrustoError)?; + + self.rows = match results.data_set { + Some(x) => x.into_vec(), + _ => vec![], + }; + + self.current_row = 0; + self.next_uri = results.next_uri; + + (self.rows.len(), false) + } + None => return (self.rows.len(), true), + } } } diff --git a/connectorx/src/sources/trino/typesystem.rs b/connectorx/src/sources/trino/typesystem.rs index f739c9a42d..c21c8cf45d 100644 --- a/connectorx/src/sources/trino/typesystem.rs +++ b/connectorx/src/sources/trino/typesystem.rs @@ -1,10 +1,10 @@ use super::errors::TrinoSourceError; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use fehler::{throw, throws}; -use prusto::{Presto, PrestoFloat, PrestoInt, PrestoTy}; +use prusto::{PrestoFloat, PrestoInt, PrestoTy}; use std::convert::TryFrom; -// TODO: implement Tuple, Row, Array and Map as well as UUID +// TODO: implement Tuple, Row, Array and Map #[derive(Copy, Clone, Debug, PartialEq)] pub enum TrinoTypeSystem { Date(bool), From b279cdb3cda11c814f76268d4e0e65232c9098f9 Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Mon, 22 Apr 2024 15:42:51 +0200 Subject: [PATCH 09/10] Trino use count_query for get_total_rows --- .../connectorx/tests/test_trino.py | 51 +++++++++++++++++++ connectorx/src/sources/trino/mod.rs | 20 ++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/connectorx-python/connectorx/tests/test_trino.py b/connectorx-python/connectorx/tests/test_trino.py index 5c883fee6c..783d2b99e1 100644 --- a/connectorx-python/connectorx/tests/test_trino.py +++ b/connectorx-python/connectorx/tests/test_trino.py @@ -13,6 +13,9 @@ def trino_url() -> str: return conn +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_without_partition(trino_url: str) -> None: query = "select * from test.test_table order by test_int limit 3" df = read_sql(trino_url, query) @@ -27,6 +30,9 @@ def test_trino_without_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_with_partition(trino_url: str) -> None: query = "select * from test.test_table order by test_int" df = read_sql( @@ -48,6 +54,9 @@ def test_trino_with_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_without_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table order by test_int" df = read_sql(trino_url, query) @@ -62,6 +71,9 @@ def test_trino_without_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_limit_without_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table order by test_int limit 3" df = read_sql(trino_url, query) @@ -76,6 +88,9 @@ def test_trino_limit_without_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_limit_large_without_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table order by test_int limit 10" df = read_sql(trino_url, query) @@ -90,6 +105,9 @@ def test_trino_limit_large_without_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_with_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table order by test_int" df = read_sql( @@ -111,6 +129,9 @@ def test_trino_with_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_limit_with_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table order by test_int limit 3" df = read_sql( @@ -132,6 +153,9 @@ def test_trino_limit_with_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_limit_large_with_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table order by test_int limit 10" df = read_sql( @@ -153,6 +177,9 @@ def test_trino_limit_large_with_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_with_partition_without_partition_range(trino_url: str) -> None: query = "SELECT * FROM test.test_table where test_float > 3" df = read_sql( @@ -174,6 +201,9 @@ def test_trino_with_partition_without_partition_range(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_manual_partition(trino_url: str) -> None: queries = [ "SELECT * FROM test.test_table WHERE test_int < 2 order by test_int", @@ -192,6 +222,9 @@ def test_trino_manual_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_selection_and_projection(trino_url: str) -> None: query = "SELECT test_int FROM test.test_table WHERE test_float < 5 order by test_int" df = read_sql( @@ -210,6 +243,9 @@ def test_trino_selection_and_projection(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_join(trino_url: str) -> None: query = "SELECT T.test_int, T.test_float, S.test_str FROM test.test_table T INNER JOIN test.test_table_extra S ON T.test_int = S.test_int order by T.test_int" df = read_sql( @@ -237,6 +273,9 @@ def test_trino_join(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_aggregate(trino_url: str) -> None: query = "select AVG(test_float) as avg_float, SUM(T.test_int) as sum_int, SUM(test_null) as sum_null from test.test_table as T" df = read_sql(trino_url, query) @@ -251,6 +290,9 @@ def test_trino_aggregate(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_trino_types_binary(trino_url: str) -> None: query = "select test_boolean, test_int, test_bigint, test_real, test_double, test_decimal, test_date, test_time, test_timestamp, test_varchar, test_uuid from test.test_types order by test_int" df = read_sql(trino_url, query) @@ -273,6 +315,9 @@ def test_trino_types_binary(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_empty_result(trino_url: str) -> None: query = "SELECT * FROM test.test_table where test_int < -100" df = read_sql(trino_url, query) @@ -286,6 +331,9 @@ def test_empty_result(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_empty_result_on_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table where test_int < -100" df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) @@ -299,6 +347,9 @@ def test_empty_result_on_partition(trino_url: str) -> None: assert_frame_equal(df, expected, check_names=True) +@pytest.mark.skipif( + not os.environ.get("TRINO_URL"), reason="Test Trino only when `TRINO_URL` is set" +) def test_empty_result_on_some_partition(trino_url: str) -> None: query = "SELECT * FROM test.test_table where test_int = 6" df = read_sql(trino_url, query, partition_on="test_int", partition_num=3) diff --git a/connectorx/src/sources/trino/mod.rs b/connectorx/src/sources/trino/mod.rs index f0bdfe6d30..072b0ed6f2 100644 --- a/connectorx/src/sources/trino/mod.rs +++ b/connectorx/src/sources/trino/mod.rs @@ -12,7 +12,7 @@ use crate::{ data_order::DataOrder, errors::ConnectorXError, sources::Produce, - sql::{limit1_query, CXQuery}, + sql::{count_query, limit1_query, CXQuery}, }; pub use self::{errors::TrinoSourceError, typesystem::TrinoTypeSystem}; @@ -27,9 +27,23 @@ pub mod typesystem; #[throws(TrinoSourceError)] fn get_total_rows(rt: Arc, client: Arc, query: &CXQuery) -> usize { - rt.block_on(client.get_all::(query.to_string())) + let cquery = count_query(query, &TrinoDialect {})?; + + let row = rt + .block_on(client.get_all::(cquery.to_string())) .map_err(TrinoSourceError::PrustoError)? - .len() + .split() + .1[0] + .clone(); + + let value = row + .value() + .first() + .ok_or_else(|| anyhow!("Trino count dataset is empty"))?; + + value + .as_i64() + .ok_or_else(|| anyhow!("Trino cannot parse i64"))? as usize } #[derive(Presto, Debug)] From cbe16b5f5e0268a93ef189942c05225915c770fb Mon Sep 17 00:00:00 2001 From: Dominik Liebler Date: Mon, 22 Apr 2024 15:58:08 +0200 Subject: [PATCH 10/10] added Trino documentation --- docs/databases.md | 4 +++- docs/databases/trino.md | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 docs/databases/trino.md diff --git a/docs/databases.md b/docs/databases.md index 6eb371fe9e..077f99f2ac 100644 --- a/docs/databases.md +++ b/docs/databases.md @@ -7,4 +7,6 @@ ConnectorX supports retrieving data from Postgres, MsSQL, MySQL, Oracle, SQLite, * [MySQL](./databases/mysql.md) * [Oracle](./databases/oracle.md) * [Postgres](./databases/postgres.md) -* [SQLite](./databases/sqlite.md) \ No newline at end of file +* [SQLite](./databases/sqlite.md) +* [Trino](./databases/trino.md) + diff --git a/docs/databases/trino.md b/docs/databases/trino.md new file mode 100644 index 0000000000..8ea640d380 --- /dev/null +++ b/docs/databases/trino.md @@ -0,0 +1,36 @@ +# Trino + +## Postgres Connection + +```{hint} +Using `trino+http` as connection protocol disables SSL for the connection. Example: `trino+http://host:port/catalog +Notice that basic auth requires SSL for Trino. +``` + +```py +import connectorx as cx +conn = 'trino+https://username:password@server:port/catalog' # connection token +query = "SELECT * FROM table" # query string +cx.read_sql(conn, query) # read data from Trino +``` + +## Trino-Pandas Type Mapping + +| Trino Type | Pandas Type | Comment | +| :--------: | :---------------------: | :-----: | +| BOOLEAN | bool, boolean(nullable) | | +| TINYINT | int64, Int64(nullable) | | +| SMALLINT | int64, Int64(nullable) | | +| INT | int64, Int64(nullable) | | +| BIGINT | int64, Int64(nullable) | | +| REAL | float64 | | +| DOUBLE | float64 | | +| DECIMAL | float64 | | +| VARCHAR | object | | +| CHAR | object | | +| DATE | datetime64[ns] | | +| TIME | object | | +| TIMESTAMP | datetime64[ns] | | +| UUID | object | | +| JSON | object | | +| IPADDRESS | object | |