Skip to content

Commit

Permalink
implemented partitioning for Trino
Browse files Browse the repository at this point in the history
  • Loading branch information
domnikl committed Apr 19, 2024
1 parent aaab7de commit 6979c65
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 23 deletions.
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions connectorx-python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions connectorx-python/connectorx/tests/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_trino_limit_large_with_partition(trino_url: str) -> None:


def test_trino_with_partition_without_partition_range(trino_url: str) -> None:
query = "SELECT * FROM test.test_table where test_float > 3 order by test_int"
query = "SELECT * FROM test.test_table where test_float > 3"
df = read_sql(
trino_url,
query,
Expand All @@ -170,6 +170,7 @@ def test_trino_with_partition_without_partition_range(trino_url: str) -> None:
},
)
df.sort_values(by="test_int", inplace=True, ignore_index=True)

assert_frame_equal(df, expected, check_names=True)


Expand Down Expand Up @@ -210,7 +211,7 @@ def test_trino_selection_and_projection(trino_url: str) -> None:


def test_trino_join(trino_url: str) -> None:
query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int order by T.test_int"
query = "SELECT T.test_int, T.test_float, S.test_str FROM test.test_table T INNER JOIN test.test_table_extra S ON T.test_int = S.test_int order by T.test_int"
df = read_sql(
trino_url,
query,
Expand Down Expand Up @@ -262,7 +263,7 @@ def test_trino_types_binary(trino_url: str) -> None:
"test_real": pd.Series([123.456, 123.456, None], dtype="float64"),
"test_double": pd.Series([123.4567890123, 123.4567890123, None], dtype="float64"),
"test_decimal": pd.Series([1234567890.12, 1234567890.12, None], dtype="float64"),
"test_date": pd.Series([None, "2023-01-01", "2023-01-01"], dtype="datetime64[ns]"),
"test_date": pd.Series(["2023-01-01", "2023-01-01", None], dtype="datetime64[ns]"),
"test_time": pd.Series(["12:00:00", "12:00:00", None], dtype="object"),
"test_timestamp": pd.Series(["2023-01-01 12:00:00.123456", "2023-01-01 12:00:00.123456", None], dtype="datetime64[ns]"),
"test_varchar": pd.Series(["Sample text", "Sample text", None], dtype="object"),
Expand Down Expand Up @@ -299,7 +300,7 @@ def test_empty_result_on_partition(trino_url: str) -> None:


def test_empty_result_on_some_partition(trino_url: str) -> None:
query = "SELECT * FROM test_table where test_int = 6"
query = "SELECT * FROM test.test_table where test_int = 6"
df = read_sql(trino_url, query, partition_on="test_int", partition_num=3)
expected = pd.DataFrame(
index=range(1),
Expand Down
14 changes: 13 additions & 1 deletion connectorx-python/src/pandas/get_meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{
destination::PandasDestination,
transports::{
BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport,
PostgresPandasTransport, SqlitePandasTransport,
PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport,
},
};
use crate::errors::ConnectorXPythonError;
Expand All @@ -18,6 +18,7 @@ use connectorx::{
PostgresSource, SimpleProtocol,
},
sqlite::SQLiteSource,
trino::TrinoSource,
},
sql::CXQuery,
};
Expand Down Expand Up @@ -223,6 +224,17 @@ pub fn get_meta<'a>(py: Python<'a>, conn: &str, protocol: &str, query: String) -
debug!("Running dispatcher");
dispatcher.get_meta()?;
}
SourceType::Trino => {
let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
let source = TrinoSource::new(rt, &source_conn.conn[..])?;
let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new(
source,
&mut destination,
queries,
None,
);
dispatcher.run()?;
}
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
}

Expand Down
3 changes: 2 additions & 1 deletion connectorx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ uuid = {version = "0.8", optional = true}
j4rs = {version = "0.15", optional = true}
datafusion = {version = "31", optional = true}
prusto = {version = "0.5.1", optional = true}
serde = {optional = true}

[lib]
crate-type = ["cdylib", "rlib"]
Expand Down Expand Up @@ -98,7 +99,7 @@ src_postgres = [
"postgres-openssl",
]
src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"]
src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits"]
src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits", "serde"]
federation = ["j4rs"]
fed_exec = ["datafusion", "tokio"]
integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"]
Expand Down
58 changes: 57 additions & 1 deletion connectorx/src/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem};
use crate::sources::oracle::{connect_oracle, OracleDialect};
#[cfg(feature = "src_postgres")]
use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem};
use crate::sources::trino::TrinoDialect;
#[cfg(feature = "src_sqlite")]
use crate::sql::get_partition_range_query_sep;
use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery};
Expand All @@ -35,7 +36,7 @@ use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::dialect::SQLiteDialect;
#[cfg(feature = "src_mssql")]
use tiberius::Client;
#[cfg(any(feature = "src_bigquery", feature = "src_mssql"))]
#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))]
use tokio::{net::TcpStream, runtime::Runtime};
#[cfg(feature = "src_mssql")]
use tokio_util::compat::TokioAsyncWriteCompatExt;
Expand Down Expand Up @@ -100,6 +101,8 @@ pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutRes
SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col),
#[cfg(feature = "src_bigquery")]
SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col),
#[cfg(feature = "src_trino")]
SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col),
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
}
}
Expand Down Expand Up @@ -137,6 +140,10 @@ pub fn get_part_query(
SourceType::BigQuery => {
single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})?
}
#[cfg(feature = "src_trino")]
SourceType::Trino => {
single_col_partition_query(query, col, lower, upper, &TrinoDialect {})?
}
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
};
CXQuery::Wrapped(query)
Expand Down Expand Up @@ -481,3 +488,52 @@ fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64

(min_v, max_v)
}

#[cfg(feature = "src_trino")]
#[throws(ConnectorXOutError)]
fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
use prusto::{auth::Auth, ClientBuilder};

use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult};

let rt = Runtime::new().expect("Failed to create runtime");

let username = match conn.username() {
"" => "connectorx",
username => username,
};

let builder = ClientBuilder::new(username, conn.host().unwrap().to_owned())
.port(conn.port().unwrap_or(8080))
.ssl(prusto::ssl::Ssl { root_cert: None })
.secure(conn.scheme() == "trino+https")
.catalog(conn.path_segments().unwrap().last().unwrap_or("hive"));

let builder = match conn.password() {
None => builder,
Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))),
};

let client = builder
.build()
.map_err(|e| anyhow!("Failed to build client: {}", e))?;

let range_query = get_partition_range_query(query, col, &TrinoDialect {})?;
let query_result = rt.block_on(client.get_all::<TrinoPartitionQueryResult>(range_query));

let query_result = match query_result {
Ok(query_result) => Ok(query_result.into_vec()),
Err(e) => match e {
prusto::error::Error::EmptyData => {
Ok(vec![TrinoPartitionQueryResult { _col0: 0, _col1: 0 }])
}
_ => Err(anyhow!("Failed to get query result: {}", e)),
},
}?;

let result = query_result
.first()
.unwrap_or(&TrinoPartitionQueryResult { _col0: 0, _col1: 0 });

(result._col0, result._col1)
}
44 changes: 33 additions & 11 deletions connectorx/src/sources/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
use fehler::{throw, throws};
use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row};
use serde_json::Value;
use sqlparser::dialect::GenericDialect;
use sqlparser::dialect::{Dialect, GenericDialect};
use std::convert::TryFrom;
use tokio::runtime::Runtime;

Expand Down Expand Up @@ -32,6 +32,26 @@ fn get_total_rows(rt: Arc<Runtime>, client: Arc<Client>, query: &CXQuery<String>
.len()
}

#[derive(Presto, Debug)]
pub struct TrinoPartitionQueryResult {
pub _col0: i64,
pub _col1: i64,
}

#[derive(Debug)]
pub struct TrinoDialect {}

// implementation copy from AnsiDialect
impl Dialect for TrinoDialect {
fn is_identifier_start(&self, ch: char) -> bool {
ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
}

fn is_identifier_part(&self, ch: char) -> bool {
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
}
}

pub struct TrinoSource {
client: Arc<Client>,
rt: Arc<Runtime>,
Expand Down Expand Up @@ -282,12 +302,12 @@ macro_rules! impl_produce_int {
match value {
Value::Number(x) => {
if (x.is_i64()) {
<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?
<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?
} else {
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
}
}
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
}
}
}
Expand All @@ -304,12 +324,12 @@ macro_rules! impl_produce_int {
Value::Null => None,
Value::Number(x) => {
if (x.is_i64()) {
Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?)
Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?)
} else {
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
}
}
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
}
}
}
Expand All @@ -333,10 +353,11 @@ macro_rules! impl_produce_float {
if (x.is_f64()) {
x.as_f64().unwrap() as $t
} else {
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
}
}
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
Value::String(x) => x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?,
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
}
}
}
Expand All @@ -355,10 +376,11 @@ macro_rules! impl_produce_float {
if (x.is_f64()) {
Some(x.as_f64().unwrap() as $t)
} else {
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
}
}
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
Value::String(x) => Some(x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?),
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion connectorx/src/sources/trino/typesystem.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::errors::TrinoSourceError;
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
use fehler::{throw, throws};
use prusto::{PrestoFloat, PrestoInt, PrestoTy};
use prusto::{Presto, PrestoFloat, PrestoInt, PrestoTy};
use std::convert::TryFrom;

// TODO: implement Tuple, Row, Array and Map as well as UUID
Expand Down Expand Up @@ -64,6 +64,7 @@ impl TryFrom<PrestoTy> for TrinoTypeSystem {
PrestoTy::Map(_, _) => Varchar(true),
PrestoTy::Decimal(_, _) => Double(true),
PrestoTy::IpAddress => Varchar(true),
PrestoTy::Uuid => Varchar(true),
_ => throw!(TrinoSourceError::InferTypeFromNull),
}
}
Expand Down Expand Up @@ -97,6 +98,7 @@ impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem {
"map" => Varchar(true),
"decimal" => Double(true),
"ipaddress" => Varchar(true),
"uuid" => Varchar(true),
_ => TrinoTypeSystem::try_from(ty)?,
}
}
Expand Down

0 comments on commit 6979c65

Please sign in to comment.