Skip to content

Commit

Permalink
snowflake-api: GET query support
Browse files Browse the repository at this point in the history
  • Loading branch information
andrusha committed Aug 28, 2023
1 parent 1238765 commit 1508971
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 11 deletions.
7 changes: 4 additions & 3 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ serde_json = "1"
serde = { version = "1", features = ["derive"] }
url = "2"
uuid = { version = "1.4", features = ["v4"] }
arrow = "42"
arrow = "45"
base64 = "0.21"
regex = "1"
object_store = { version = "0.6", features = ["aws"] }
object_store = { version = "0.7", features = ["aws"] }
async-trait = "0.1"

[dev-dependencies]
anyhow = "1"
pretty_env_logger = "0.5.0"
clap = { version = "4", features = ["derive"] }
arrow = { version = "42", features = ["prettyprint"] }
tokio = { version = "1", features=["macros", "rt-multi-thread"] }
parquet = { version = "45", features = ["arrow", "snap"] }
arrow = { version = "45", features = ["prettyprint"] }
36 changes: 36 additions & 0 deletions snowflake-api/examples/filetransfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use snowflake_api::{QueryResult, SnowflakeApi};
use std::fs;
use std::fs::File;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

extern crate snowflake_api;

Expand Down Expand Up @@ -43,6 +45,9 @@ struct Args {

#[arg(long)]
csv_path: String,

#[arg(long)]
output_path: String,
}

#[tokio::main]
Expand Down Expand Up @@ -111,7 +116,38 @@ async fn main() -> Result<()> {
}
}

log::info!("Copy table contents into a stage");
api.exec(
"COPY INTO @%OSCAR_AGE_MALE/output/ FROM OSCAR_AGE_MALE FILE_FORMAT = (TYPE = parquet COMPRESSION = NONE) HEADER = TRUE OVERWRITE = TRUE SINGLE = TRUE;"
).await?;

log::info!("Downloading Parquet files");
api.exec(&format!(
"GET @%OSCAR_AGE_MALE/output/ file://{}",
&args.output_path
))
.await?;

log::info!("Closing Snowflake session");
api.close_session().await?;

log::info!("Reading downloaded files");
let parquet_dir = format!("{}output", &args.output_path);
let paths = fs::read_dir(&parquet_dir).unwrap();

for path in paths {
let path = path?.path();
log::info!("Reading {:?}", path);
let file = File::open(path)?;

let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
let reader = builder.build()?;
let mut batches = Vec::default();
for batch in reader {
batches.push(batch?);
}
println!("{}", pretty_format_batches(batches.as_slice()).unwrap());
}

Ok(())
}
95 changes: 91 additions & 4 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
)]
#![doc = include_str ! ("../README.md")]
#![doc = include_str!("../README.md")]

use std::io;
use std::path::Path;
Expand Down Expand Up @@ -169,24 +169,109 @@ impl SnowflakeApi {
/// Execute a single query against API.
/// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
pub async fn exec(&mut self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
// fixme: can go without regex? but needs different accept-mime for those still
let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
let get_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*get\s+").unwrap();

// put commands go through a different flow and result is side-effect
// put/get commands go through a different flow and result is side-effect
if put_re.is_match(sql) {
log::info!("Detected PUT query");

self.exec_put(sql).await.map(|_| QueryResult::Empty)
} else if get_re.is_match(sql) {
log::info!("Detected GET query");

self.exec_get(sql).await.map(|_| QueryResult::Empty)
} else {
self.exec_arrow(sql).await
}
}

async fn exec_get(&mut self, sql: &str) -> Result<(), SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
.await?;
log::debug!("Got GET response: {:?}", resp);

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.get(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}
}

/// Dispatch a GET response to the backend matching the stage's cloud
/// provider. Only AWS stages are implemented; Azure and GCS currently
/// return `Unimplemented`.
async fn get(&self, resp: PutGetExecResponse) -> Result<(), SnowflakeApiError> {
    match resp.data.stage_info {
        PutGetStageInfo::Aws(info) => {
            // The server tells us where the downloaded files should land;
            // a missing local location makes the response unusable.
            let local_location = resp
                .data
                .local_location
                .ok_or(SnowflakeApiError::BrokenResponse)?;

            self.get_from_s3(local_location, &resp.data.src_locations, info)
                .await
        }
        PutGetStageInfo::Azure(_) => Err(SnowflakeApiError::Unimplemented(
            "GET local file requests for Azure".to_string(),
        )),
        PutGetStageInfo::Gcs(_) => Err(SnowflakeApiError::Unimplemented(
            "GET local file requests for GCS".to_string(),
        )),
    }
}

// fixme: refactor s3 put/get into a single function?
/// Download the given stage files from S3 into `local_location`.
///
/// `local_location` is the destination directory reported by the server,
/// `src_locations` are file keys within the stage, and `info` carries the
/// stage's bucket location plus scoped AWS credentials.
///
/// NOTE(review): both path joins below use plain string concatenation and
/// assume the server-provided prefixes already end in a separator — TODO
/// confirm against actual Snowflake GET responses.
async fn get_from_s3(
    &self,
    local_location: String,
    src_locations: &[String],
    info: AwsPutGetStageInfo,
) -> Result<(), SnowflakeApiError> {
    // todo: use path parser?
    // Stage location is "<bucket>/<prefix...>"; split off the bucket name.
    let (bucket_name, bucket_path) = info
        .location
        .split_once('/')
        .ok_or(SnowflakeApiError::InvalidBucketPath(info.location.clone()))?;

    // Build an S3 client using the temporary credentials issued for this stage.
    let s3 = AmazonS3Builder::new()
        .with_region(info.region)
        .with_bucket_name(bucket_name)
        .with_access_key_id(info.creds.aws_key_id)
        .with_secret_access_key(info.creds.aws_secret_key)
        .with_token(info.creds.aws_token)
        .build()?;

    // todo: implement parallelism for small files
    // todo: security vulnerability, external system tells you which local files to write
    for src_path in src_locations {
        // Local destination: server-provided directory + stage-relative key.
        let dest_path = format!("{}{}", local_location, src_path);
        let dest_path = object_store::path::Path::parse(dest_path)?;

        // Remote source: stage prefix within the bucket + stage-relative key.
        let src_path = format!("{}{}", bucket_path, src_path);
        let src_path = object_store::path::Path::parse(src_path)?;

        // fixme: can we stream the thing or multipart?
        // Whole object is buffered in memory before being written locally.
        let bytes = s3.get(&src_path).await?;
        LocalFileSystem::new()
            .put(&dest_path, bytes.bytes().await?)
            .await?;
    }

    Ok(())
}

async fn exec_put(&mut self, sql: &str) -> Result<(), SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
.await?;
// fixme: don't log secrets maybe?
log::debug!("Got PUT response: {:?}", resp);

// fixme: support PUT for external stage
match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.put(pg).await,
Expand Down Expand Up @@ -227,21 +312,23 @@ impl SnowflakeApi {
.with_token(info.creds.aws_token)
.build()?;

// todo: implement parallelism for small files
// todo: security vulnerability, external system tells you which local files to upload
for src_path in src_locations {
let path = Path::new(src_path);
let filename = path
.file_name()
.ok_or(SnowflakeApiError::InvalidLocalPath(src_path.clone()))?;

// fixme: nicer way to join paths?
// fixme: unwrap
let dest_path = format!("{}{}", bucket_path, filename.to_str().unwrap());
let dest_path = object_store::path::Path::parse(dest_path)?;

let src_path = object_store::path::Path::parse(src_path)?;

// fixme: can we stream the thing or multipart?
let fs = LocalFileSystem::new().get(&src_path).await?;

s3.put(&dest_path, fs.bytes().await?).await?;
}

Expand Down Expand Up @@ -276,7 +363,7 @@ impl SnowflakeApi {
return Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
))
));
}
};

Expand Down
9 changes: 5 additions & 4 deletions snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,19 @@ pub struct PutGetResponseData {
// file upload parallelism
pub parallel: i32,
// file size threshold, small ones are should be uploaded with given parallelism
pub threshold: i64,
pub threshold: Option<i64>,
// doesn't need compression if source is already compressed
pub auto_compress: bool,
pub auto_compress: Option<bool>,
pub overwrite: bool,
// maps to one of the predefined compression algos
// todo: support different compression formats?
pub source_compression: String,
pub source_compression: Option<String>,
pub stage_info: PutGetStageInfo,
pub encryption_material: EncryptionMaterialVariant,
// GCS specific. If you request multiple files?
// might return a [ null ] for AWS responses
#[serde(default)]
pub presigned_urls: Vec<String>,
pub presigned_urls: Vec<Option<String>>,
#[serde(default)]
pub parameters: Vec<NameValueParameter>,
pub statement_type_id: Option<i64>,
Expand Down

0 comments on commit 1508971

Please sign in to comment.