Skip to content

Commit

Permalink
Add support for async query response
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov committed May 6, 2024
1 parent 38dcb9e commit c8c6e73
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 19 deletions.
51 changes: 35 additions & 16 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ pub enum ConnectionError {
/// Container for query parameters
/// This API has different endpoints and MIME types for different requests
struct QueryContext {
path: &'static str,
path: String,
accept_mime: &'static str,
method: reqwest::Method
}

pub enum QueryType {
Expand All @@ -39,30 +40,40 @@ pub enum QueryType {
CloseSession,
JsonQuery,
ArrowQuery,
ArrowQueryResult(String),
}

impl QueryType {
const fn query_context(&self) -> QueryContext {
fn query_context(&self) -> QueryContext {
match self {
Self::LoginRequest => QueryContext {
path: "session/v1/login-request",
path: "session/v1/login-request".to_string(),
accept_mime: "application/json",
method: reqwest::Method::POST,
},
Self::TokenRequest => QueryContext {
path: "/session/token-request",
path: "/session/token-request".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::CloseSession => QueryContext {
path: "session",
path: "session".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::JsonQuery => QueryContext {
path: "queries/v1/query-request",
path: "queries/v1/query-request".to_string(),
accept_mime: "application/json",
method: reqwest::Method::POST,
},
Self::ArrowQuery => QueryContext {
path: "queries/v1/query-request",
path: "queries/v1/query-request".to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::POST,
},
Self::ArrowQueryResult(query_result_url) => QueryContext {
path: query_result_url.to_string(),
accept_mime: "application/snowflake",
method: reqwest::Method::GET,
},
}
}
Expand Down Expand Up @@ -163,14 +174,22 @@ impl Connection {
}

// todo: persist client to use connection polling
let resp = self
.client
.post(url)
.headers(headers)
.json(&body)
.send()
.await?;

let resp = match context.method {
reqwest::Method::POST => self
.client
.post(url)
.headers(headers)
.json(&body)
.send()
.await?,
reqwest::Method::GET => self
.client
.get(url)
.headers(headers)
.send()
.await?,
_ => panic!("Unsupported method"),
};
Ok(resp.json::<R>().await?)
}

Expand Down
40 changes: 38 additions & 2 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ impl SnowflakeApi {

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => put::put(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
Expand All @@ -430,14 +431,21 @@ impl SnowflakeApi {
}

async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
let resp = self
let mut resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {:?}", resp);

if let ExecResponse::QueryAsync(data) = &resp {
log::debug!("Got async exec response");
resp = self.get_async_exec_result(&data.data.get_result_url).await?;
log::debug!("Got result for async exec: {:?}", resp);
}

let resp = match resp {
// processable response
ExecResponse::Query(qr) => Ok(qr),
ExecResponse::QueryAsync(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
Expand Down Expand Up @@ -504,10 +512,38 @@ impl SnowflakeApi {
&self.account_identifier,
&[],
Some(&parts.session_token_auth_header),
body,
Some(body),
)
.await?;

Ok(resp)
}

pub async fn get_async_exec_result(&self, query_result_url: &String) -> Result<ExecResponse, SnowflakeApiError>{
log::debug!("Getting async exec result: {}", query_result_url);

let mut delay = 1; // Initial delay of 1 second

loop {
let parts = self.session.get_token().await?;
let resp = self
.connection
.request::<ExecResponse>(
QueryType::ArrowQueryResult(query_result_url.to_string()),
&self.account_identifier,
&[],
Some(&parts.session_token_auth_header),
serde_json::Value::default()
)
.await?;

if let ExecResponse::QueryAsync(_) = &resp {
// simple exponential retry with a maximum wait time of 5 seconds
tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
delay = (delay * 2).min(5); // cap delay to 5 seconds
} else {
return Ok(resp);
}
};
}
}
11 changes: 10 additions & 1 deletion snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use serde::Deserialize;
#[serde(untagged)]
pub enum ExecResponse {
Query(QueryExecResponse),
QueryAsync(QueryAsyncExecResponse),
PutGet(PutGetExecResponse),
Error(ExecErrorResponse),
}
Expand Down Expand Up @@ -34,6 +35,7 @@ pub struct BaseRestResponse<D> {

pub type PutGetExecResponse = BaseRestResponse<PutGetResponseData>;
pub type QueryExecResponse = BaseRestResponse<QueryExecResponseData>;
pub type QueryAsyncExecResponse = BaseRestResponse<QueryAsyncExecResponseData>;
pub type ExecErrorResponse = BaseRestResponse<ExecErrorResponseData>;
pub type AuthErrorResponse = BaseRestResponse<AuthErrorResponseData>;
pub type AuthenticatorResponse = BaseRestResponse<AuthenticatorResponseData>;
Expand All @@ -54,7 +56,7 @@ pub struct ExecErrorResponseData {
pub pos: Option<i64>,

// fixme: only valid for exec query response error? present in any exec query response?
pub query_id: String,
pub query_id: Option<String>,
pub sql_state: String,
}

Expand Down Expand Up @@ -151,6 +153,13 @@ pub struct QueryExecResponseData {
// `sendResultTime`, `queryResultFormat`, `queryContext` also exist
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct QueryAsyncExecResponseData {
pub query_id: String,
pub get_result_url: String,
}

#[derive(Deserialize, Debug)]
pub struct ExecResponseRowType {
pub name: String,
Expand Down

0 comments on commit c8c6e73

Please sign in to comment.