Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Implement support for external oauth #46

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions snowflake-api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,24 @@ snowflake-api = "0.7.0"

Check [examples](./examples) for working programs using the library.


```rust
use anyhow::Result;
use snowflake_api::{QueryResult, SnowflakeApi};
use snowflake_api::{QueryResult, AuthArgs, PasswordArgs, AuthType, SnowflakeApi, SnowflakeApiBuilder};

async fn run_query(sql: &str) -> Result<QueryResult> {
let mut api = SnowflakeApi::with_password_auth(

let auth = AuthArgs::new(
"ACCOUNT_IDENTIFIER",
Some("WAREHOUSE"),
Some("DATABASE"),
Some("SCHEMA"),
"USERNAME",
Some("ROLE"),
"PASSWORD",
)?;
AuthType::Password(PasswordArgs { password: "password".to_string() })
);

let mut api: SnowflakeApi = SnowflakeApiBuilder::new(auth)
.build()?;
let res = api.exec(sql).await?;

Ok(res)
Expand All @@ -68,7 +71,7 @@ async fn run_query(sql: &str) -> Result<QueryResult> {
Or using environment variables:

```rust
use anyhow::Result;
use anyhow::Result;
use snowflake_api::{QueryResult, SnowflakeApi};

async fn run_query(sql: &str) -> Result<QueryResult> {
Expand Down
46 changes: 27 additions & 19 deletions snowflake-api/examples/run_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use std::fs;

use snowflake_api::{QueryResult, SnowflakeApi};
use snowflake_api::{
AuthArgs, CertificateArgs, PasswordArgs, QueryResult, SnowflakeApi, SnowflakeApiBuilder,
};

#[derive(clap::ValueEnum, Clone, Debug)]
enum Output {
Expand Down Expand Up @@ -67,25 +69,31 @@ async fn main() -> Result<()> {
let mut api = match (&args.private_key, &args.password) {
(Some(pkey), None) => {
let pem = fs::read_to_string(pkey)?;
SnowflakeApi::with_certificate_auth(
&args.account_identifier,
args.warehouse.as_deref(),
args.database.as_deref(),
args.schema.as_deref(),
&args.username,
args.role.as_deref(),
&pem,
)?
SnowflakeApiBuilder::new(AuthArgs {
account_identifier: args.account_identifier,
warehouse: args.warehouse,
database: args.database,
schema: args.schema,
username: args.username,
role: args.role,
auth_type: snowflake_api::AuthType::Certificate(CertificateArgs {
private_key_pem: pem,
}),
})
.build()?
}
(None, Some(pwd)) => SnowflakeApi::with_password_auth(
&args.account_identifier,
args.warehouse.as_deref(),
args.database.as_deref(),
args.schema.as_deref(),
&args.username,
args.role.as_deref(),
pwd,
)?,
(None, Some(pwd)) => SnowflakeApiBuilder::new(AuthArgs {
account_identifier: args.account_identifier,
warehouse: args.warehouse,
database: args.database,
schema: args.schema,
username: args.username,
role: args.role,
auth_type: snowflake_api::AuthType::Password(PasswordArgs {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of having two different internal AuthType enums with slightly different purposes, what if SnowflakeApi was just generic over an AuthType trait, that implements functions for building the login flow? I think that would clean the code up a lot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SnowflakeApi<OAuth> where OAuth impl AuthType, etc.

password: pwd.to_string(),
}),
})
.build()?,
_ => {
panic!("Either private key path or password must be set")
}
Expand Down
123 changes: 41 additions & 82 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use responses::ExecResponse;
use session::{AuthError, Session};
use session::{AuthError, Session, SessionBuilder};

use crate::connection::QueryType;
use crate::connection::{Connection, ConnectionError};
Expand Down Expand Up @@ -193,11 +193,33 @@ pub struct AuthArgs {
}

impl AuthArgs {
pub fn new(
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
auth_type: AuthType,
) -> Self {
Self {
account_identifier: account_identifier.to_string(),
warehouse: warehouse.map(str::to_string),
database: database.map(str::to_string),
schema: schema.map(str::to_string),
username: username.to_string(),
role: role.map(str::to_string),
auth_type,
}
}

pub fn from_env() -> Result<AuthArgs, SnowflakeApiError> {
let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") {
Ok(AuthType::Password(PasswordArgs { password }))
} else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") {
Ok(AuthType::Certificate(CertificateArgs { private_key_pem }))
} else if let Ok(token) = std::env::var("SNOWFLAKE_OAUTH_TOKEN") {
Ok(AuthType::OAuth(OAuthArgs { token }))
} else {
Err(MissingEnvArgument(
"SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(),
Expand All @@ -221,6 +243,7 @@ impl AuthArgs {
pub enum AuthType {
Password(PasswordArgs),
Certificate(CertificateArgs),
OAuth(OAuthArgs),
}

pub struct PasswordArgs {
Expand All @@ -231,6 +254,10 @@ pub struct CertificateArgs {
pub private_key_pem: String,
}

pub struct OAuthArgs {
pub token: String,
}

#[must_use]
pub struct SnowflakeApiBuilder {
pub auth: AuthArgs,
Expand All @@ -253,27 +280,20 @@ impl SnowflakeApiBuilder {
None => Arc::new(Connection::new()?),
};

let session = SessionBuilder::new(&self.auth.account_identifier, &self.auth.username)
.warehouse(self.auth.warehouse.as_deref())
.database(self.auth.database.as_deref())
.schema(self.auth.schema.as_deref())
.role(self.auth.role.as_deref());

let session = match self.auth.auth_type {
AuthType::Password(args) => Session::password_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
&self.auth.username,
self.auth.role.as_deref(),
&args.password,
),
AuthType::Certificate(args) => Session::cert_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
&self.auth.username,
self.auth.role.as_deref(),
&args.private_key_pem,
),
AuthType::Password(args) => {
session.build_password(Arc::clone(&connection), &args.password)
}
AuthType::Certificate(args) => {
session.build_cert(Arc::clone(&connection), &args.private_key_pem)
}
AuthType::OAuth(args) => session.build_oauth(Arc::clone(&connection), &args.token),
};

let account_identifier = self.auth.account_identifier.to_uppercase();
Expand Down Expand Up @@ -302,67 +322,6 @@ impl SnowflakeApi {
account_identifier,
}
}
/// Initialize object with password auth. Authentication happens on the first request.
pub fn with_password_auth(
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
password: &str,
) -> Result<Self, SnowflakeApiError> {
let connection = Arc::new(Connection::new()?);

let session = Session::password_auth(
Arc::clone(&connection),
account_identifier,
warehouse,
database,
schema,
username,
role,
password,
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
))
}

/// Initialize object with private certificate auth. Authentication happens on the first request.
pub fn with_certificate_auth(
account_identifier: &str,
warehouse: Option<&str>,
database: Option<&str>,
schema: Option<&str>,
username: &str,
role: Option<&str>,
private_key_pem: &str,
) -> Result<Self, SnowflakeApiError> {
let connection = Arc::new(Connection::new()?);

let session = Session::cert_auth(
Arc::clone(&connection),
account_identifier,
warehouse,
database,
schema,
username,
role,
private_key_pem,
);

let account_identifier = account_identifier.to_uppercase();
Ok(Self::new(
Arc::clone(&connection),
session,
account_identifier,
))
}

pub fn from_env() -> Result<Self, SnowflakeApiError> {
SnowflakeApiBuilder::new(AuthArgs::from_env()?).build()
Expand Down
10 changes: 10 additions & 0 deletions snowflake-api/src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct LoginRequest<T> {
}

pub type PasswordLoginRequest = LoginRequest<PasswordRequestData>;
pub type OAuthLoginRequest = LoginRequest<OAuthRequestData>;
#[cfg(feature = "cert-auth")]
pub type CertLoginRequest = LoginRequest<CertRequestData>;

Expand Down Expand Up @@ -62,6 +63,15 @@ pub struct CertRequestData {
pub token: String,
}

#[derive(Serialize, Debug)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub struct OAuthRequestData {
#[serde(flatten)]
pub login_request_common: LoginRequestCommon,
pub authenticator: String,
pub token: String,
}

#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct RenewSessionRequest {
Expand Down
Loading
Loading