diff --git a/snowflake-api/Cargo.toml b/snowflake-api/Cargo.toml index 0f8193a..8cf96f8 100644 --- a/snowflake-api/Cargo.toml +++ b/snowflake-api/Cargo.toml @@ -10,6 +10,13 @@ keywords = ["snowflake", "database", "api"] categories = ["database", "api-bindings"] readme = "README.md" license = "Apache-2.0" +exclude = ["Cargo.lock"] + +[features] +default = ["tokio"] +tokio = ["dep:tokio"] +naive-runtime = ["futures-executor", "futures-channel", "futures-util"] +async-std = ["dep:async-std"] [dependencies] thiserror = "1" @@ -25,10 +32,22 @@ base64 = "0.21" regex = "1" object_store = { version = "0.6", features = ["aws"] } async-trait = "0.1" +retry-policies = "0.2" +tokio = { version = "1", features = ["time"], optional = true } +futures-executor = { version = "0.3", optional = true } +futures-channel = { version = "0.3", optional = true } +futures-util = { version = "0.3", default-features = false, optional = true } +async-std = { version = "1.9", optional = true } [dev-dependencies] anyhow = "1" pretty_env_logger = "0.5.0" clap = { version = "4", features = ["derive"] } arrow = { version = "42", features = ["prettyprint"] } -tokio = { version = "1", features=["macros", "rt-multi-thread"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +smol = "1" +futures = "0.3" + +[package.metadata.docs.rs] +features = ["naive-runtime", "async-std", "tokio"] +rustdoc-args = ["--cfg", "docsrs"] diff --git a/snowflake-api/examples/filetransfer.rs b/snowflake-api/examples/filetransfer.rs index 49f5b3c..46b9b83 100644 --- a/snowflake-api/examples/filetransfer.rs +++ b/snowflake-api/examples/filetransfer.rs @@ -1,7 +1,7 @@ use anyhow::Result; use arrow::util::pretty::pretty_format_batches; use clap::Parser; -use snowflake_api::{QueryResult, SnowflakeApi}; +use snowflake_api::{QueryResult, SnowflakeApi, DefaultRuntime}; use std::fs; extern crate snowflake_api; @@ -51,7 +51,7 @@ async fn main() -> Result<()> { let args = Args::parse(); - let mut api = match 
(&args.private_key, &args.password) { + let mut api: SnowflakeApi = match (&args.private_key, &args.password) { (Some(pkey), None) => { let pem = fs::read(pkey)?; SnowflakeApi::with_certificate_auth( diff --git a/snowflake-api/examples/run_sql.rs b/snowflake-api/examples/run_sql.rs index c0aab66..d7c37f7 100644 --- a/snowflake-api/examples/run_sql.rs +++ b/snowflake-api/examples/run_sql.rs @@ -5,7 +5,7 @@ use arrow::util::pretty::pretty_format_batches; use clap::Parser; use std::fs; -use snowflake_api::{QueryResult, SnowflakeApi}; +use snowflake_api::{QueryResult, SnowflakeApi, DefaultRuntime}; #[derive(clap::ValueEnum, Clone, Debug)] enum Output { @@ -64,7 +64,7 @@ async fn main() -> Result<()> { let args = Args::parse(); - let mut api = match (&args.private_key, &args.password) { + let mut api: SnowflakeApi = match (&args.private_key, &args.password) { (Some(pkey), None) => { let pem = fs::read(pkey)?; SnowflakeApi::with_certificate_auth( diff --git a/snowflake-api/examples/runtime_smol.rs b/snowflake-api/examples/runtime_smol.rs new file mode 100644 index 0000000..79134a4 --- /dev/null +++ b/snowflake-api/examples/runtime_smol.rs @@ -0,0 +1,86 @@ +extern crate snowflake_api; + +use std::process; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use clap::Parser; +use futures::future::{self, FutureExt}; + +use snowflake_api::{AsyncRuntime, SnowflakeApi}; + +pub struct SmolRuntime; + +impl AsyncRuntime for SmolRuntime { + type Delay = future::Map; + + fn delay_for(duration: Duration) -> Self::Delay { + FutureExt::map(smol::Timer::after(duration), |_| ()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(long)] + password: String, + + /// in Snowflake format, uppercase + #[arg(short, long)] + account_identifier: String, + + /// Database name + #[arg(short, long)] + database: String, + + /// Schema name + #[arg(long)] + schema: String, + + /// Warehouse + #[arg(short, long)] + warehouse: 
String, + + /// username to whom the private key belongs to + #[arg(short, long)] + username: String, + + /// role which user will assume + #[arg(short, long)] + role: String, + + /// sql statement to execute and print result from + #[arg(long)] + sql: String, +} + +fn main() -> Result<()> { + pretty_env_logger::init(); + + let args = Args::parse(); + + let mut api: SnowflakeApi = SnowflakeApi::with_password_auth( + &args.account_identifier, + &args.warehouse, + Some(&args.database), + Some(&args.schema), + &args.username, + Some(&args.role), + &args.password, + )?; + + smol::block_on(async { + let res = api.exec_json(&args.sql).await; + match res { + Ok(r) => { + println!("{}", r.to_string()); + } + Err(e) => { + eprintln!("Error querying API: {:?}", e); + process::exit(1); + } + } + }); + + Ok(()) +} diff --git a/snowflake-api/src/connection.rs b/snowflake-api/src/connection.rs index 39ec913..d67daa7 100644 --- a/snowflake-api/src/connection.rs +++ b/snowflake-api/src/connection.rs @@ -1,9 +1,13 @@ +use std::marker::PhantomData; + +use reqwest::{Client, ClientBuilder, header}; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{header, Client, ClientBuilder}; use thiserror::Error; use url::Url; use uuid::Uuid; +use crate::{AsyncRuntime, DefaultRuntime}; + #[derive(Error, Debug)] pub enum ConnectionError { #[error(transparent)] @@ -53,12 +57,15 @@ impl QueryType { /// Connection pool /// Minimal session will have at least 2 requests - login and query -pub struct Connection { +pub struct Connection { // no need for Arc as it's already inside the reqwest client client: Client, + _runtime: PhantomData, } -impl Connection { +impl Connection + where + R: AsyncRuntime { pub fn new() -> Result { // use builder to fail safely, unlike client new let client = ClientBuilder::new() @@ -68,21 +75,24 @@ impl Connection { .connection_verbose(true) .build()?; - Ok(Connection { client }) + Ok(Connection { + client, + _runtime: PhantomData, + }) } /// Perform request of 
given query type with extra body or parameters // todo: implement retry logic // todo: implement soft error handling // todo: is there better way to not repeat myself? - pub async fn request( + pub async fn request( &self, query_type: QueryType, account_identifier: &str, extra_get_params: &[(&str, &str)], auth: Option<&str>, body: impl serde::Serialize, - ) -> Result { + ) -> Result { let context = query_type.query_context(); // todo: increment subsequent request ids (on retry?) @@ -128,6 +138,6 @@ impl Connection { .send() .await?; - Ok(resp.json::().await?) + Ok(resp.json::().await?) } } diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs index ff18aeb..f2f6ee0 100644 --- a/snowflake-api/src/lib.rs +++ b/snowflake-api/src/lib.rs @@ -1,8 +1,9 @@ #![doc( - issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues", - test(no_crate_inject) +issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues", +test(no_crate_inject) )] -#![doc = include_str ! 
("../README.md")] +#![doc = include_str!("../README.md")] +#![cfg_attr(docsrs, feature(doc_cfg))] use std::io; use std::path::Path; @@ -18,11 +19,12 @@ use object_store::ObjectStore; use regex::Regex; use thiserror::Error; -use crate::connection::{Connection, ConnectionError}; use put_response::{PutResponse, S3PutResponse}; use query_response::QueryResponse; +pub use runtime::{AsyncRuntime, DefaultRuntime}; use session::{AuthError, Session}; +use crate::connection::{Connection, ConnectionError}; use crate::connection::QueryType; mod auth_response; @@ -30,6 +32,7 @@ mod connection; mod error_response; mod put_response; mod query_response; +mod runtime; mod session; #[derive(Error, Debug)] @@ -81,14 +84,16 @@ pub enum QueryResult { } /// Snowflake API, keeps connection pool and manages session for you -pub struct SnowflakeApi { - connection: Arc, - session: Session, +pub struct SnowflakeApi { + connection: Arc>, + session: Session, account_identifier: String, sequence_id: u64, } -impl SnowflakeApi { +impl SnowflakeApi + where + R: AsyncRuntime { /// Initialize object with password auth. Authentication happens on the first request. pub fn with_password_auth( account_identifier: &str, @@ -271,11 +276,11 @@ impl SnowflakeApi { } } - async fn run_sql( + async fn run_sql( &mut self, sql: &str, query_type: QueryType, - ) -> Result { + ) -> Result { log::debug!("Executing: {}", sql); let token = self.session.get_token().await?; @@ -293,7 +298,7 @@ impl SnowflakeApi { let resp = self .connection - .request::(query_type, &self.account_identifier, &[], Some(&auth), body) + .request::(query_type, &self.account_identifier, &[], Some(&auth), body) .await?; Ok(resp) diff --git a/snowflake-api/src/runtime.rs b/snowflake-api/src/runtime.rs new file mode 100644 index 0000000..4a821d3 --- /dev/null +++ b/snowflake-api/src/runtime.rs @@ -0,0 +1,158 @@ +//! Taken pretty much verbatim from the +//! 
[rdkafka](https://github.com/fede1024/rust-rdkafka/blob/8afecbc5ab2c775b8928f6129bbd38777eed11d7/src/util.rs) +//! utils crate. Removes the support for `spawn` and adds an [async-std](https://docs.rs/async-std) implementation as well. + +use std::future::Future; +#[cfg(feature = "async-std")] +use std::pin::Pin; +#[cfg(feature = "naive-runtime")] +use std::thread; +use std::time::Duration; + +#[cfg(feature = "naive-runtime")] +use futures_channel::oneshot; +#[cfg(feature = "naive-runtime")] +use futures_util::future::{FutureExt, Map}; + +/// An abstraction over asynchronous runtimes. +/// +/// There are several asynchronous runtimes available for Rust. By default +/// snowflake-api uses Tokio, via the [`TokioRuntime`], but it has pluggable +/// support for any runtime that can satisfy this trait. +/// +/// For an example of using the [smol] runtime, see the [runtime_smol] example. +/// +/// [async-std]: https://docs.rs/async-std +/// [tokio]: https://docs.rs/tokio +/// [futures_executor]: https://docs.rs/futures_executor +/// [runtime_smol]: https://github.com/mycelial/snowflake-rs/blob/main/snowflake-api/examples/runtime_smol.rs +pub trait AsyncRuntime: Send + Sync + 'static { + /// The type of the future returned by + /// [`delay_for`](AsyncRuntime::delay_for). + type Delay: Future + Send; + + /// Constructs a future that will resolve after `duration` has elapsed. + fn delay_for(duration: Duration) -> Self::Delay; +} + +/// The default [`AsyncRuntime`] used when one is not explicitly specified. +/// +/// This is defined to be the [`TokioRuntime`] when the `tokio` feature is +/// enabled, or the [`NaiveRuntime`] if the `naive-runtime` feature is enabled, +/// or [`AsyncStdRuntime`] if the `async-std` feature is enabled. +/// +/// If neither of the features are enabled, this is +/// defined to be `()`, which is not a valid `AsyncRuntime` and will cause +/// compilation errors if used as one. You will need to explicitly specify a +/// custom async runtime wherever one is required. 
+#[cfg(not(any(feature = "tokio", feature = "naive-runtime", feature = "async-std")))] +pub type DefaultRuntime = (); + +/// The default [`AsyncRuntime`] used when one is not explicitly specified. +/// +/// This is defined to be the [`TokioRuntime`] when the `tokio` feature is +/// enabled, or the [`NaiveRuntime`] if the `naive-runtime` feature is enabled, +/// or [`AsyncStdRuntime`] if the `async-std` feature is enabled. +/// +/// If neither of the features are enabled, this is +/// defined to be `()`, which is not a valid `AsyncRuntime` and will cause +/// compilation errors if used as one. You will need to explicitly specify a +/// custom async runtime wherever one is required. +#[cfg(all( + not(feature = "tokio"), + not(feature = "async-std"), + feature = "naive-runtime" +))] +pub type DefaultRuntime = NaiveRuntime; + +/// The default [`AsyncRuntime`] used when one is not explicitly specified. +/// +/// This is defined to be the [`TokioRuntime`] when the `tokio` feature is +/// enabled, or the [`NaiveRuntime`] if the `naive-runtime` feature is enabled, +/// or [`AsyncStdRuntime`] if the `async-std` feature is enabled. +/// +/// If neither of the features are enabled, this is +/// defined to be `()`, which is not a valid `AsyncRuntime` and will cause +/// compilation errors if used as one. You will need to explicitly specify a +/// custom async runtime wherever one is required. +#[cfg(all( + feature = "tokio", + not(feature = "async-std"), + not(feature = "naive-runtime") +))] +pub type DefaultRuntime = TokioRuntime; + +/// The default [`AsyncRuntime`] used when one is not explicitly specified. +/// +/// This is defined to be the [`TokioRuntime`] when the `tokio` feature is +/// enabled, or the [`NaiveRuntime`] if the `naive-runtime` feature is enabled, +/// or [`AsyncStdRuntime`] if the `async-std` feature is enabled. 
+/// +/// If neither of the features are enabled, this is +/// defined to be `()`, which is not a valid `AsyncRuntime` and will cause +/// compilation errors if used as one. You will need to explicitly specify a +/// custom async runtime wherever one is required. +#[cfg(all( + not(feature = "tokio"), + feature = "async-std", + not(feature = "naive-runtime") +))] +pub type DefaultRuntime = AsyncStdRuntime; + +/// An [`AsyncRuntime`] implementation backed by the executor in the +/// [`futures_executor`](futures_executor) crate. +/// +/// This runtime should not be used when performance is a concern, as it makes +/// heavy use of threads to compensate for the lack of a timer in the futures +/// executor. +#[cfg(feature = "naive-runtime")] +#[cfg_attr(docsrs, doc(cfg(feature = "naive-runtime")))] +pub struct NaiveRuntime; + +#[cfg(feature = "naive-runtime")] +#[cfg_attr(docsrs, doc(cfg(feature = "naive-runtime")))] +impl AsyncRuntime for NaiveRuntime { + type Delay = Map, fn(Result<(), oneshot::Canceled>)>; + + fn delay_for(duration: Duration) -> Self::Delay { + let (tx, rx) = oneshot::channel(); + thread::spawn(move || { + thread::sleep(duration); + tx.send(()) + }); + rx.map(|_| ()) + } +} + +/// An [`AsyncRuntime`] implementation backed by [Tokio](tokio). +/// +/// This runtime is used by default throughout the crate, unless the `tokio` +/// feature is disabled. +#[cfg(feature = "tokio")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] +pub struct TokioRuntime; + +#[cfg(feature = "tokio")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] +impl AsyncRuntime for TokioRuntime { + type Delay = tokio::time::Sleep; + + fn delay_for(duration: Duration) -> Self::Delay { + tokio::time::sleep(duration) + } +} + +/// An [`AsyncRuntime`] implementation backed by [async-std]. 
+#[cfg(feature = "async-std")] +#[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] +pub struct AsyncStdRuntime; + +#[cfg(feature = "async-std")] +#[cfg_attr(docsrs, doc(cfg(feature = "async-std")))] +impl AsyncRuntime for AsyncStdRuntime { + type Delay = Pin + Send>>; + + fn delay_for(duration: Duration) -> Self::Delay { + Box::pin(async_std::task::sleep(duration)) + } +} diff --git a/snowflake-api/src/session.rs b/snowflake-api/src/session.rs index d342249..adab9f4 100644 --- a/snowflake-api/src/session.rs +++ b/snowflake-api/src/session.rs @@ -1,11 +1,13 @@ -use snowflake_jwt::generate_jwt_token; use std::sync::Arc; use std::time::{Duration, Instant}; + +use snowflake_jwt::generate_jwt_token; use thiserror::Error; +use crate::{AsyncRuntime, connection}; use crate::auth_response::AuthResponse; -use crate::connection; use crate::connection::{Connection, QueryType}; +use crate::DefaultRuntime; #[derive(Error, Debug)] pub enum AuthError { @@ -65,8 +67,8 @@ enum AuthType { /// the configuration state and temporary objects (tables, procedures, etc). // todo: split warehouse-database-schema and username-role-key into its own structs // todo: close session after object is dropped -pub struct Session { - connection: Arc, +pub struct Session { + connection: Arc>, auth_token_cached: Option, auth_type: AuthType, @@ -83,12 +85,14 @@ pub struct Session { } // todo: make builder -impl Session { +impl Session + where + R: AsyncRuntime { /// Authenticate using private certificate and JWT // fixme: add builder or introduce structs #[allow(clippy::too_many_arguments)] pub fn cert_auth( - connection: Arc, + connection: Arc>, account_identifier: &str, warehouse: &str, database: Option<&str>, @@ -127,7 +131,7 @@ impl Session { // fixme: add builder or introduce structs #[allow(clippy::too_many_arguments)] pub fn password_auth( - connection: Arc, + connection: Arc>, account_identifier: &str, warehouse: &str, database: Option<&str>,