Skip to content

Commit

Permalink
allow selection of runtimes
Browse files Browse the repository at this point in the history
  • Loading branch information
andrusha committed Aug 18, 2023
1 parent d6489cb commit 24aafd8
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 30 deletions.
21 changes: 20 additions & 1 deletion snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ keywords = ["snowflake", "database", "api"]
categories = ["database", "api-bindings"]
readme = "README.md"
license = "Apache-2.0"
exclude = ["Cargo.lock"]

[features]
default = ["tokio"]
tokio = ["dep:tokio"]
naive-runtime = ["futures-executor", "futures-channel", "futures-util"]
async-std = ["dep:async-std"]

[dependencies]
thiserror = "1"
Expand All @@ -25,10 +32,22 @@ base64 = "0.21"
regex = "1"
object_store = { version = "0.6", features = ["aws"] }
async-trait = "0.1"
retry-policies = "0.2"
tokio = { version = "1", features = ["time"], optional = true }
futures-executor = { version = "0.3", optional = true }
futures-channel = { version = "0.3", optional = true }
futures-util = { version = "0.3", default-features = false, optional = true }
async-std = { version = "1.9", optional = true }

[dev-dependencies]
anyhow = "1"
pretty_env_logger = "0.5.0"
clap = { version = "4", features = ["derive"] }
arrow = { version = "42", features = ["prettyprint"] }
tokio = { version = "1", features=["macros", "rt-multi-thread"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
smol = "1"
futures = "0.3"

[package.metadata.docs.rs]
features = ["naive-runtime", "async-std", "tokio"]
rustdoc-args = ["--cfg", "docsrs"]
4 changes: 2 additions & 2 deletions snowflake-api/examples/filetransfer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::Result;
use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use snowflake_api::{QueryResult, SnowflakeApi};
use snowflake_api::{QueryResult, SnowflakeApi, DefaultRuntime};
use std::fs;

extern crate snowflake_api;
Expand Down Expand Up @@ -51,7 +51,7 @@ async fn main() -> Result<()> {

let args = Args::parse();

let mut api = match (&args.private_key, &args.password) {
let mut api: SnowflakeApi<DefaultRuntime> = match (&args.private_key, &args.password) {
(Some(pkey), None) => {
let pem = fs::read(pkey)?;
SnowflakeApi::with_certificate_auth(
Expand Down
4 changes: 2 additions & 2 deletions snowflake-api/examples/run_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use std::fs;

use snowflake_api::{QueryResult, SnowflakeApi};
use snowflake_api::{QueryResult, SnowflakeApi, DefaultRuntime};

#[derive(clap::ValueEnum, Clone, Debug)]
enum Output {
Expand Down Expand Up @@ -64,7 +64,7 @@ async fn main() -> Result<()> {

let args = Args::parse();

let mut api = match (&args.private_key, &args.password) {
let mut api: SnowflakeApi<DefaultRuntime> = match (&args.private_key, &args.password) {
(Some(pkey), None) => {
let pem = fs::read(pkey)?;
SnowflakeApi::with_certificate_auth(
Expand Down
86 changes: 86 additions & 0 deletions snowflake-api/examples/runtime_smol.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
extern crate snowflake_api;

use std::process;
use std::time::{Duration, Instant};

use anyhow::Result;
use clap::Parser;
use futures::future::{self, FutureExt};

use snowflake_api::{AsyncRuntime, SnowflakeApi};

pub struct SmolRuntime;

impl AsyncRuntime for SmolRuntime {
type Delay = future::Map<smol::Timer, fn(Instant)>;

fn delay_for(duration: Duration) -> Self::Delay {
FutureExt::map(smol::Timer::after(duration), |_| ())
}
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long)]
password: String,

/// <account_identifier> in Snowflake format, uppercase
#[arg(short, long)]
account_identifier: String,

/// Database name
#[arg(short, long)]
database: String,

/// Schema name
#[arg(long)]
schema: String,

/// Warehouse
#[arg(short, long)]
warehouse: String,

/// username to whom the private key belongs to
#[arg(short, long)]
username: String,

/// role which user will assume
#[arg(short, long)]
role: String,

/// sql statement to execute and print result from
#[arg(long)]
sql: String,
}

fn main() -> Result<()> {
pretty_env_logger::init();

let args = Args::parse();

let mut api: SnowflakeApi<SmolRuntime> = SnowflakeApi::with_password_auth(
&args.account_identifier,
&args.warehouse,
Some(&args.database),
Some(&args.schema),
&args.username,
Some(&args.role),
&args.password,
)?;

smol::block_on(async {
let res = api.exec_json(&args.sql).await;
match res {
Ok(r) => {
println!("{}", r.to_string());
}
Err(e) => {
eprintln!("Error querying API: {:?}", e);
process::exit(1);
}
}
});

Ok(())
}
24 changes: 17 additions & 7 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::marker::PhantomData;

use reqwest::{Client, ClientBuilder, header};
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{header, Client, ClientBuilder};
use thiserror::Error;
use url::Url;
use uuid::Uuid;

use crate::{AsyncRuntime, DefaultRuntime};

#[derive(Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Expand Down Expand Up @@ -53,12 +57,15 @@ impl QueryType {

/// Connection pool
/// Minimal session will have at least 2 requests - login and query
pub struct Connection {
pub struct Connection<R = DefaultRuntime> {
// no need for Arc as it's already inside the reqwest client
client: Client,
_runtime: PhantomData<R>,
}

impl Connection {
impl<R> Connection<R>
where
R: AsyncRuntime {
pub fn new() -> Result<Self, ConnectionError> {
// use builder to fail safely, unlike client new
let client = ClientBuilder::new()
Expand All @@ -68,21 +75,24 @@ impl Connection {
.connection_verbose(true)
.build()?;

Ok(Connection { client })
Ok(Connection {
client,
_runtime: PhantomData,
})
}

/// Perform request of given query type with extra body or parameters
// todo: implement retry logic
// todo: implement soft error handling
// todo: is there better way to not repeat myself?
pub async fn request<R: serde::de::DeserializeOwned>(
pub async fn request<DS: serde::de::DeserializeOwned>(
&self,
query_type: QueryType,
account_identifier: &str,
extra_get_params: &[(&str, &str)],
auth: Option<&str>,
body: impl serde::Serialize,
) -> Result<R, ConnectionError> {
) -> Result<DS, ConnectionError> {
let context = query_type.query_context();

// todo: increment subsequent request ids (on retry?)
Expand Down Expand Up @@ -128,6 +138,6 @@ impl Connection {
.send()
.await?;

Ok(resp.json::<R>().await?)
Ok(resp.json::<DS>().await?)
}
}
27 changes: 16 additions & 11 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![doc(
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
)]
#![doc = include_str ! ("../README.md")]
#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]

use std::io;
use std::path::Path;
Expand All @@ -18,18 +19,20 @@ use object_store::ObjectStore;
use regex::Regex;
use thiserror::Error;

use crate::connection::{Connection, ConnectionError};
use put_response::{PutResponse, S3PutResponse};
use query_response::QueryResponse;
pub use runtime::{AsyncRuntime, DefaultRuntime};
use session::{AuthError, Session};

use crate::connection::{Connection, ConnectionError};
use crate::connection::QueryType;

mod auth_response;
mod connection;
mod error_response;
mod put_response;
mod query_response;
mod runtime;
mod session;

#[derive(Error, Debug)]
Expand Down Expand Up @@ -81,14 +84,16 @@ pub enum QueryResult {
}

/// Snowflake API, keeps connection pool and manages session for you
pub struct SnowflakeApi {
connection: Arc<Connection>,
session: Session,
pub struct SnowflakeApi<R = DefaultRuntime> {
connection: Arc<Connection<R>>,
session: Session<R>,
account_identifier: String,
sequence_id: u64,
}

impl SnowflakeApi {
impl<R> SnowflakeApi<R>
where
R: AsyncRuntime {
/// Initialize object with password auth. Authentication happens on the first request.
pub fn with_password_auth(
account_identifier: &str,
Expand Down Expand Up @@ -271,11 +276,11 @@ impl SnowflakeApi {
}
}

async fn run_sql<R: serde::de::DeserializeOwned>(
async fn run_sql<DS: serde::de::DeserializeOwned>(
&mut self,
sql: &str,
query_type: QueryType,
) -> Result<R, SnowflakeApiError> {
) -> Result<DS, SnowflakeApiError> {
log::debug!("Executing: {}", sql);

let token = self.session.get_token().await?;
Expand All @@ -293,7 +298,7 @@ impl SnowflakeApi {

let resp = self
.connection
.request::<R>(query_type, &self.account_identifier, &[], Some(&auth), body)
.request::<DS>(query_type, &self.account_identifier, &[], Some(&auth), body)
.await?;

Ok(resp)
Expand Down
Loading

0 comments on commit 24aafd8

Please sign in to comment.