Commit d77e97b: Add streaming support for arrow batches
sgrebnov committed May 8, 2024 (1 parent: 38dcb9e)
Showing 3 changed files with 116 additions and 8 deletions.
1 change: 1 addition & 0 deletions snowflake-api/Cargo.toml
@@ -21,6 +21,7 @@ polars = ["dep:polars-core", "dep:polars-io"]
[dependencies]
arrow = "51"
async-trait = "0.1"
async-stream = "0.3.5"
base64 = "0.22"
bytes = "1"
futures = "0.3"
122 changes: 114 additions & 8 deletions snowflake-api/src/lib.rs
@@ -15,8 +15,11 @@ clippy::missing_panics_doc

use std::fmt::{Display, Formatter};
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use async_stream::stream;

use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::record_batch::RecordBatch;
@@ -27,7 +30,7 @@ use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use responses::{ExecResponse, QueryExecResponseData};
use session::{AuthError, Session};

use crate::connection::QueryType;
@@ -36,6 +39,8 @@ use crate::requests::ExecRequest;
use crate::responses::{ExecResponseRowType, SnowflakeType};
use crate::session::AuthError::MissingEnvArgument;

use futures::{future, Stream, StreamExt};

pub mod connection;
#[cfg(feature = "polars")]
mod polars;
@@ -98,6 +103,8 @@ pub enum SnowflakeApiError {
GlobError(#[from] glob::GlobError),
}

// Upper bound on concurrently downloaded result chunks when streaming.
const MAX_CHUNK_DOWNLOAD_WORKERS: usize = 10;

/// Even if Arrow is specified as the return type, non-select queries
/// will return a JSON array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
pub struct JsonResult {
@@ -144,24 +151,38 @@ pub enum QueryResult {
Empty,
}

/// Stream of raw Arrow IPC chunk payloads downloaded from Snowflake.
pub type BytesStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, SnowflakeApiError>> + Send>>;
/// Stream of decoded Arrow record batches.
pub type RecordBatchStream = Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>;

/// Raw query result
/// Can be transformed into [`QueryResult`]
pub enum RawQueryResult {
/// Arrow IPC chunks
/// see: <https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc>
Bytes(Vec<Bytes>),
/// Stream of Arrow IPC chunks, downloaded and decoded lazily
Stream(BytesStream),
/// Json payload is deserialized eagerly,
/// as it's already part of the REST response
Json(JsonResult),
Empty,
}

impl RawQueryResult {
/// Decodes this raw result into a [`QueryResult`], draining any stream
/// variant fully into memory.
pub async fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
match self {
RawQueryResult::Bytes(bytes) => {
Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
}
RawQueryResult::Stream(bytes_stream) => {
let arrow_records_stream = Self::to_record_batches_stream(bytes_stream);
// Drain the stream, propagating the first decoding error
// instead of panicking on unwrap.
arrow_records_stream
.collect::<Vec<Result<RecordBatch, ArrowError>>>()
.await
.into_iter()
.collect::<Result<Vec<RecordBatch>, ArrowError>>()
.map(QueryResult::Arrow)
}
RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
RawQueryResult::Empty => Ok(QueryResult::Empty),
}
@@ -176,6 +197,24 @@ impl RawQueryResult {
Ok(res)
}

/// Decodes each chunk of the byte stream into zero or more record batches,
/// flattening them into a single stream.
fn to_record_batches_stream(bytes_stream: BytesStream) -> RecordBatchStream {
let batch_stream = bytes_stream.flat_map(|bytes_result| match bytes_result {
Ok(bytes) => match Self::bytes_to_batches(bytes) {
Ok(batches) => futures::stream::iter(batches.into_iter().map(Ok)).boxed(),
Err(e) => futures::stream::once(async move { Err(ArrowError::from(e)) }).boxed(),
},
Err(e) => futures::stream::once(async move {
Err(ArrowError::ParseError(format!(
"Unable to parse RecordBatch due to error in bytes stream: {e}"
)))
})
.boxed(),
});

Box::pin(batch_stream)
}

fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
let record_batches = StreamReader::try_new_unbuffered(bytes.reader(), None)?;
record_batches.into_iter().collect()
@@ -380,10 +419,23 @@ impl SnowflakeApi {
/// If the statement is a PUT, the file is uploaded to Snowflake-managed storage.
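///
/// A minimal usage sketch (hypothetical `api` handle; assumes an
/// already-authenticated [`SnowflakeApi`]):
///
/// ```no_run
/// # use snowflake_api::{QueryResult, SnowflakeApi, SnowflakeApiError};
/// # async fn demo(api: &SnowflakeApi) -> Result<(), SnowflakeApiError> {
/// match api.exec("SELECT 1;").await? {
///     QueryResult::Arrow(batches) => println!("{} record batches", batches.len()),
///     QueryResult::Json(_) => println!("json result"),
///     QueryResult::Empty => println!("empty result"),
/// }
/// # Ok(())
/// # }
/// ```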
pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
let raw = self.exec_raw(sql).await?;
let res = raw.deserialize_arrow().await?;
Ok(res)
}

/// Executes a single query against the API and returns a stream of record
/// batches; result chunks are downloaded and decoded lazily as the stream
/// is consumed.
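///
/// A minimal consumption sketch (hypothetical `api` handle; assumes the
/// `futures` crate for `StreamExt::next`):
///
/// ```no_run
/// # use futures::StreamExt;
/// # async fn demo(api: &snowflake_api::SnowflakeApi) -> Result<(), Box<dyn std::error::Error>> {
/// let mut batches = api.exec_streamed("SELECT * FROM big_table;").await?;
/// while let Some(batch) = batches.next().await {
///     println!("rows in batch: {}", batch?.num_rows());
/// }
/// # Ok(())
/// # }
/// ```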
pub async fn exec_streamed(&self, sql: &str) -> Result<RecordBatchStream, SnowflakeApiError> {
let raw = self.exec_arrow_raw(sql, true).await?;
match raw {
RawQueryResult::Empty => Ok(Box::pin(futures::stream::empty())),
RawQueryResult::Stream(bytes_stream) => {
let arrow_stream = RawQueryResult::to_record_batches_stream(bytes_stream);
Ok(arrow_stream)
}
_ => Err(SnowflakeApiError::UnexpectedResponse),
}
}

/// Executes a single query against the API.
/// If the statement is a PUT, the file is uploaded to Snowflake-managed storage.
/// Returns the raw bytes of the Arrow response.
@@ -395,7 +447,7 @@ impl SnowflakeApi {
log::info!("Detected PUT query");
self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
} else {
self.exec_arrow_raw(sql, false).await
}
}

@@ -429,8 +481,12 @@ impl SnowflakeApi {
.await
}

async fn exec_arrow_raw(
&self,
sql: &str,
enable_streaming: bool,
) -> Result<RawQueryResult, SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
log::debug!("Got query response: {:?}", resp);
@@ -459,14 +515,19 @@ impl SnowflakeApi {
value,
schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
}))
} else if resp.data.rowset_base64.is_some() {
if enable_streaming {
// Defer chunk downloads: hand back a lazy byte stream
// instead of materializing all chunks in memory.
return Ok(self.chunks_to_bytes_stream(&resp.data));
}

let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
self.connection
.get_chunk(&chunk.url, &resp.data.chunk_headers)
}))
.await?;

let base64 = resp.data.rowset_base64.unwrap_or_default();

// fixme: should base64 chunk go first?
// fixme: if response is chunked is it both base64 + chunks or just chunks?
if !base64.is_empty() {
@@ -510,4 +571,49 @@ impl SnowflakeApi {

Ok(resp)
}

/// Builds a lazy byte stream that downloads result chunks in batches of
/// `MAX_CHUNK_DOWNLOAD_WORKERS` and finishes with the inline base64-encoded
/// rowset, if present.
fn chunks_to_bytes_stream(&self, data: &QueryExecResponseData) -> RawQueryResult {
let chunk_urls = data
.chunks
.iter()
.map(|chunk| chunk.url.clone())
.collect::<Vec<String>>();
let chunk_headers = data.chunk_headers.clone();
let connection = self.connection.clone();
let base64 = data.rowset_base64.clone().unwrap_or_default();

let stream = stream! {
// Download chunks in fixed-size batches to bound concurrency;
// within a batch all downloads run concurrently.
for chunk in chunk_urls.chunks(MAX_CHUNK_DOWNLOAD_WORKERS) {
let futures_batch = chunk.iter().map(|chunk_url| {
let headers = chunk_headers.clone();
let connection_clone = connection.clone();
async move {
connection_clone.get_chunk(chunk_url, &headers).await.map_err(SnowflakeApiError::from)
}
}).collect::<Vec<_>>();

let results = future::join_all(futures_batch).await;
for result in results {
yield result;
}
}

if !base64.is_empty() {
log::debug!("Got base64 encoded response");
match base64::engine::general_purpose::STANDARD.decode(&base64) {
Ok(bytes) => {
yield Ok(Bytes::from(bytes));
}
Err(e) => {
yield Err(SnowflakeApiError::from(e));
}
}
}
};

RawQueryResult::Stream(Box::pin(stream))
}
}
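Note: the stream! block above downloads chunks in fixed-size batches and waits for each whole batch (future::join_all) before yielding, so one slow download stalls its batch. A hypothetical alternative, not part of this commit, keeps up to N requests in flight continuously while still yielding chunks in order, via futures' buffered adapter (the sketch below uses plain reqwest::get in place of connection.get_chunk):

use futures::{stream, StreamExt};

// Sketch only: fetch every chunk URL with at most 10 downloads in flight,
// yielding results in their original order.
async fn fetch_all(urls: Vec<String>) -> Vec<Result<bytes::Bytes, reqwest::Error>> {
    stream::iter(urls)
        .map(|url| async move { reqwest::get(url).await?.bytes().await })
        .buffered(10)
        .collect()
        .await
}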
1 change: 1 addition & 0 deletions snowflake-api/src/polars.rs
@@ -26,6 +26,7 @@ impl RawQueryResult {
RawQueryResult::Bytes(bytes) => dataframe_from_bytes(bytes),
RawQueryResult::Json(json) => dataframe_from_json(&json),
RawQueryResult::Empty => Ok(DataFrame::empty()),
// Streaming results cannot be converted to a DataFrame yet.
RawQueryResult::Stream(_) => todo!("streaming results are not yet supported for polars"),
}
}
}