diff --git a/snowflake-api/Cargo.toml b/snowflake-api/Cargo.toml
index 3afdaa4..d5c57db 100644
--- a/snowflake-api/Cargo.toml
+++ b/snowflake-api/Cargo.toml
@@ -21,6 +21,7 @@ polars = ["dep:polars-core", "dep:polars-io"]
 [dependencies]
 arrow = "51"
 async-trait = "0.1"
+async-stream = "0.3.5"
 base64 = "0.22"
 bytes = "1"
 futures = "0.3"
diff --git a/snowflake-api/src/lib.rs b/snowflake-api/src/lib.rs
index 1fa7b36..a1b4d68 100644
--- a/snowflake-api/src/lib.rs
+++ b/snowflake-api/src/lib.rs
@@ -15,8 +15,11 @@ clippy::missing_panics_doc
 use std::fmt::{Display, Formatter};
 use std::io;
+use std::pin::Pin;
 use std::sync::Arc;
 
+use async_stream::stream;
+
 use arrow::error::ArrowError;
 use arrow::ipc::reader::StreamReader;
 use arrow::record_batch::RecordBatch;
@@ -27,7 +30,7 @@ use regex::Regex;
 use reqwest_middleware::ClientWithMiddleware;
 use thiserror::Error;
 
-use responses::ExecResponse;
+use responses::{ExecResponse, QueryExecResponseData};
 use session::{AuthError, Session};
 
 use crate::connection::QueryType;
@@ -36,6 +39,8 @@ use crate::requests::ExecRequest;
 use crate::responses::{ExecResponseRowType, SnowflakeType};
 use crate::session::AuthError::MissingEnvArgument;
 
+use futures::{future, Stream, StreamExt};
+
 pub mod connection;
 #[cfg(feature = "polars")]
 mod polars;
@@ -98,6 +103,8 @@ pub enum SnowflakeApiError {
     GlobError(#[from] glob::GlobError),
 }
 
+const MAX_CHUNK_DOWNLOAD_WORKERS: usize = 10;
+
 /// Even if Arrow is specified as a return type non-select queries
 /// will return Json array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
 pub struct JsonResult {
@@ -144,12 +151,16 @@ pub enum QueryResult {
     Empty,
 }
 
+pub type BytesStream = Pin<Box<dyn Stream<Item = Result<Bytes, SnowflakeApiError>> + Send>>;
+pub type RecordBatchStream = Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>;
+
 /// Raw query result
 /// Can be transformed into [`QueryResult`]
 pub enum RawQueryResult {
     /// Arrow IPC chunks
     /// see: <https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc>
     Bytes(Vec<Bytes>),
+    Stream(BytesStream),
     /// Json payload is deserialized,
     /// as it's already a part of REST response
     Json(JsonResult),
@@ -157,11 +168,21 @@
 }
 
 impl RawQueryResult {
-    pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
+    pub async fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
         match self {
             RawQueryResult::Bytes(bytes) => {
                 Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
             }
+            RawQueryResult::Stream(bytes_stream) => {
+                let batches = Self::to_record_batches_stream(bytes_stream)
+                    .collect::<Vec<Result<RecordBatch, ArrowError>>>()
+                    .await;
+
+                // Propagate the first decoding error instead of panicking on it
+                batches
+                    .into_iter()
+                    .collect::<Result<Vec<_>, _>>()
+                    .map(QueryResult::Arrow)
+            }
             RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
             RawQueryResult::Empty => Ok(QueryResult::Empty),
         }
@@ -176,6 +197,24 @@
         Ok(res)
     }
 
+    fn to_record_batches_stream(bytes_stream: BytesStream) -> RecordBatchStream {
+        let batch_stream = bytes_stream.flat_map(|bytes_result| match bytes_result {
+            Ok(bytes) => match Self::bytes_to_batches(bytes) {
+                Ok(batches) => futures::stream::iter(batches.into_iter().map(Ok)).boxed(),
+                Err(e) => futures::stream::once(async move { Err(e) }).boxed(),
+            },
+            Err(e) => futures::stream::once(async move {
+                Err(ArrowError::ParseError(format!(
+                    "Unable to parse RecordBatch due to error in bytes stream: {e}"
+                )))
+            })
+            .boxed(),
+        });
+
+        Box::pin(batch_stream)
+    }
+
     fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
         let record_batches = StreamReader::try_new_unbuffered(bytes.reader(), None)?;
         record_batches.into_iter().collect()
     }
@@ -380,10 +419,23 @@ impl SnowflakeApi {
     /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
     pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
         let raw = self.exec_raw(sql).await?;
-        let res = raw.deserialize_arrow()?;
+        let res = raw.deserialize_arrow().await?;
         Ok(res)
     }
 
+    /// Executes a single query against API and returns a stream of Arrow record batches
+    pub async fn exec_streamed(&self, sql: &str) -> Result<RecordBatchStream, SnowflakeApiError> {
+        let raw = self.exec_arrow_raw(sql, true).await?;
+        match raw {
+            RawQueryResult::Empty => Ok(Box::pin(futures::stream::empty())),
+            RawQueryResult::Stream(bytes_stream) => {
+                let arrow_stream = RawQueryResult::to_record_batches_stream(bytes_stream);
+                Ok(arrow_stream)
+            }
+            _ => Err(SnowflakeApiError::UnexpectedResponse),
+        }
+    }
+
     /// Executes a single query against API.
     /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
     /// Returns raw bytes in the Arrow response
@@ -395,7 +447,7 @@ impl SnowflakeApi {
             log::info!("Detected PUT query");
             self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
         } else {
-            self.exec_arrow_raw(sql).await
+            self.exec_arrow_raw(sql, false).await
         }
     }
 
@@ -429,8 +481,12 @@ impl SnowflakeApi {
         .await
     }
 
-    async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
-        let resp = self
+    async fn exec_arrow_raw(
+        &self,
+        sql: &str,
+        enable_streaming: bool,
+    ) -> Result<RawQueryResult, SnowflakeApiError> {
+        let resp = self
             .run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
             .await?;
         log::debug!("Got query response: {:?}", resp);
@@ -459,14 +515,19 @@ impl SnowflakeApi {
                 value,
                 schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
             }))
-        } else if let Some(base64) = resp.data.rowset_base64 {
-            // fixme: is it possible to give streaming interface?
+        } else if resp.data.rowset_base64.is_some() {
+            if enable_streaming {
+                return Ok(self.chunks_to_bytes_stream(&resp.data));
+            }
+
             let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
                 self.connection
                     .get_chunk(&chunk.url, &resp.data.chunk_headers)
             }))
             .await?;
 
+            let base64 = resp.data.rowset_base64.unwrap_or_default();
+
             // fixme: should base64 chunk go first?
             // fixme: if response is chunked is it both base64 + chunks or just chunks?
             if !base64.is_empty() {
@@ -510,4 +571,49 @@ impl SnowflakeApi {
         Ok(resp)
     }
+
+    /// Downloads response chunks lazily and yields their raw Arrow IPC bytes,
+    /// keeping at most [`MAX_CHUNK_DOWNLOAD_WORKERS`] requests in flight at a time.
+    fn chunks_to_bytes_stream(&self, data: &QueryExecResponseData) -> RawQueryResult {
+        let chunk_urls = data
+            .chunks
+            .iter()
+            .map(|chunk| chunk.url.clone())
+            .collect::<Vec<_>>();
+        let chunk_headers = data.chunk_headers.clone();
+        let connection = self.connection.clone();
+        let base64 = data.rowset_base64.clone().unwrap_or_default();
+
+        let stream = stream! {
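+            // The generator body only runs as the consumer polls the stream,
+            // so chunks are downloaded on demand rather than all up front.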
+            let chunks_iter = chunk_urls.chunks(MAX_CHUNK_DOWNLOAD_WORKERS);
+
+            for chunk in chunks_iter {
+                let futures_batch = chunk
+                    .iter()
+                    .map(|chunk_url| {
+                        let headers = chunk_headers.clone();
+                        let connection = connection.clone();
+                        async move {
+                            connection
+                                .get_chunk(chunk_url, &headers)
+                                .await
+                                .map_err(SnowflakeApiError::from)
+                        }
+                    })
+                    .collect::<Vec<_>>();
+
+                // Download one batch of chunks concurrently, then yield the
+                // results in their original order
+                let results = future::join_all(futures_batch).await;
+                for result in results {
+                    yield result;
+                }
+            }
+
+            if !base64.is_empty() {
+                log::debug!("Got base64 encoded response");
+                match base64::engine::general_purpose::STANDARD.decode(&base64) {
+                    Ok(bytes) => {
+                        yield Ok(Bytes::from(bytes));
+                    }
+                    Err(e) => {
+                        yield Err(SnowflakeApiError::from(e));
+                    }
+                }
+            }
+        };
+
+        RawQueryResult::Stream(Box::pin(stream))
+    }
 }
diff --git a/snowflake-api/src/polars.rs b/snowflake-api/src/polars.rs
index c7243b7..74462c7 100644
--- a/snowflake-api/src/polars.rs
+++ b/snowflake-api/src/polars.rs
@@ -26,6 +26,7 @@ impl RawQueryResult {
             RawQueryResult::Bytes(bytes) => dataframe_from_bytes(bytes),
             RawQueryResult::Json(json) => dataframe_from_json(&json),
             RawQueryResult::Empty => Ok(DataFrame::empty()),
+            RawQueryResult::Stream(_) => todo!("streaming results are not yet supported for polars"),
         }
     }
 }
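A minimal consumer sketch (not part of the patch): it assumes an already-authenticated `SnowflakeApi` client and a placeholder table name, and drains the new `exec_streamed` entry point like any other `futures` stream:

```rust
use futures::StreamExt;
use snowflake_api::SnowflakeApi;

// `api` is assumed to be an already-authenticated client; `LARGE_TABLE` is a
// placeholder table name.
async fn consume(api: &SnowflakeApi) -> Result<(), Box<dyn std::error::Error>> {
    let mut batches = api.exec_streamed("SELECT * FROM LARGE_TABLE").await?;

    // Each RecordBatch is yielded as soon as its response chunk has been
    // downloaded and decoded, so the full result set is never held in memory.
    while let Some(batch) = batches.next().await {
        println!("rows in batch: {}", batch?.num_rows());
    }

    Ok(())
}
```

This is the same stream that `deserialize_arrow` now collects internally for the non-streaming `exec` path.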