From 3ce5587faae3912ceedae4644732fa9704eb6d76 Mon Sep 17 00:00:00 2001 From: HugoCasa Date: Wed, 25 Sep 2024 18:11:11 +0200 Subject: [PATCH] feat: add return_last_result annotation to sql (#4443) --- backend/windmill-common/src/worker.rs | 16 ++ .../windmill-worker/src/bigquery_executor.rs | 148 ++++++++------ backend/windmill-worker/src/mssql_executor.rs | 52 +++-- backend/windmill-worker/src/mysql_executor.rs | 94 ++++++--- backend/windmill-worker/src/pg_executor.rs | 108 ++++++---- .../windmill-worker/src/snowflake_executor.rs | 189 ++++++++++-------- frontend/src/lib/script_helpers.ts | 3 +- 7 files changed, 366 insertions(+), 244 deletions(-) diff --git a/backend/windmill-common/src/worker.rs b/backend/windmill-common/src/worker.rs index bf355f92e1328..9464eccdcc466 100644 --- a/backend/windmill-common/src/worker.rs +++ b/backend/windmill-common/src/worker.rs @@ -327,6 +327,22 @@ pub fn get_annotation(inner_content: &str) -> Annotations { Annotations { npm_mode, nodejs_mode, native_mode, nobundling } } +pub struct SqlAnnotations { + pub return_last_result: bool, +} + +pub fn get_sql_annotations(inner_content: &str) -> SqlAnnotations { + let annotations = inner_content + .lines() + .take_while(|x| x.starts_with("--")) + .map(|x| x.to_string().replace("--", "").trim().to_string()) + .collect_vec(); + + let return_last_result: bool = annotations.contains(&"return_last_result".to_string()); + + SqlAnnotations { return_last_result } +} + pub async fn load_cache(bin_path: &str, _remote_path: &str) -> (bool, String) { if tokio::fs::metadata(&bin_path).await.is_ok() { (true, format!("loaded from local cache: {}\n", bin_path)) diff --git a/backend/windmill-worker/src/bigquery_executor.rs b/backend/windmill-worker/src/bigquery_executor.rs index d9194c747532b..ea240b87c9baa 100644 --- a/backend/windmill-worker/src/bigquery_executor.rs +++ b/backend/windmill-worker/src/bigquery_executor.rs @@ -5,6 +5,7 @@ use futures::{FutureExt, TryFutureExt}; use serde_json::{json, value::RawValue, Value}; use windmill_common::error::to_anyhow; use windmill_common::jobs::QueuedJob; +use windmill_common::worker::get_sql_annotations; use windmill_common::{error::Error, worker::to_raw_value}; use windmill_parser_sql::{ parse_bigquery_sig, parse_db_resource, parse_sql_blocks, parse_sql_statement_named_params, @@ -69,6 +70,7 @@ fn do_bigquery_inner<'a>( token: &'a str, timeout_ms: i32, column_order: Option<&'a mut Option>>, + skip_collect: bool, ) -> windmill_common::error::Result>>> { let param_names = parse_sql_statement_named_params(query, '@'); @@ -106,76 +108,80 @@ fn do_bigquery_inner<'a>( match response.error_for_status_ref() { Ok(_) => { - let result = response.json::().await.map_err(|e| { - Error::ExecutionErr(format!( - "BigQuery API response could not be parsed: {}", - e.to_string() - )) - })?; - - if !result.jobComplete { - return Err(Error::ExecutionErr( - "BigQuery API did not answer query in time".to_string(), - )); - } + if skip_collect { + return Ok(to_raw_value(&Value::Array(vec![]))); + } else { + let result = response.json::().await.map_err(|e| { + Error::ExecutionErr(format!( + "BigQuery API response could not be parsed: {}", + e.to_string() + )) + })?; + + if !result.jobComplete { + return Err(Error::ExecutionErr( + "BigQuery API did not answer query in time".to_string(), + )); + } - if result.rows.is_none() || result.rows.as_ref().unwrap().len() == 0 { - return Ok(serde_json::from_str("[]").unwrap()); - } + if result.rows.is_none() || result.rows.as_ref().unwrap().len() == 0 { + return Ok(serde_json::from_str("[]").unwrap()); + } - if result.schema.is_none() { - return Err(Error::ExecutionErr( - "Incomplete response from BigQuery API".to_string(), + if result.schema.is_none() { + return Err(Error::ExecutionErr( + "Incomplete response from BigQuery API".to_string(), + )); + } + + if result + .totalRows + .unwrap_or(json!("")) + .as_str() + .unwrap_or("") + .parse::() + .unwrap_or(0) + > 10000 + { + return Err(Error::ExecutionErr( + "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows".to_string(), )); - } + } - if result - .totalRows - .unwrap_or(json!("")) - .as_str() - .unwrap_or("") - .parse::() - .unwrap_or(0) - > 10000 - { - return Err(Error::ExecutionErr( - "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows".to_string(), - )); - } + if let Some(column_order) = column_order { + *column_order = Some( + result + .schema + .as_ref() + .unwrap() + .fields + .iter() + .map(|x| x.name.clone()) + .collect::>(), + ); + } - if let Some(column_order) = column_order { - *column_order = Some( - result - .schema - .as_ref() - .unwrap() - .fields - .iter() - .map(|x| x.name.clone()) - .collect::>(), - ); - } + let rows = result + .rows + .unwrap() + .iter() + .map(|row| { + let mut row_map = serde_json::Map::new(); + row.f + .iter() + .zip(result.schema.as_ref().unwrap().fields.iter()) + .for_each(|(field, schema)| { + row_map.insert( + schema.name.clone(), + parse_val(&field.v, &schema.r#type, &schema), + ); + }); + Value::from(row_map) + }) + .collect::>(); - let rows = result - .rows - .unwrap() - .iter() - .map(|row| { - let mut row_map = serde_json::Map::new(); - row.f - .iter() - .zip(result.schema.as_ref().unwrap().fields.iter()) - .for_each(|(field, schema)| { - row_map.insert( - schema.name.clone(), - parse_val(&field.v, &schema.r#type, &schema), - ); - }); - Value::from(row_map) - }) - .collect::>(); - - return Ok(to_raw_value(&rows)); + Ok(to_raw_value(&rows)) + } } Err(e) => match response.json::().await { Ok(bq_err) => Err(Error::ExecutionErr(format!( @@ -230,6 +236,8 @@ pub async fn do_bigquery( return Err(Error::BadRequest("Missing database argument".to_string())); }; + let annotations = get_sql_annotations(query); + let service_account = CustomServiceAccount::from_json(&database) .map_err(|e| Error::ExecutionErr(e.to_string()))?; @@ -306,7 +314,8 @@ pub async fn do_bigquery( let result_f = if queries.len() > 1 { let futures = queries .iter() - .map(|x| { + .enumerate() + .map(|(i, x)| { do_bigquery_inner( x, &statement_values, @@ -314,17 +323,23 @@ pub async fn do_bigquery( token.as_str(), timeout_ms, None, + annotations.return_last_result && i < queries.len() - 1, ) }) .collect::>>()?; let f = async { let mut res: Vec> = vec![]; + for fut in futures { let r = fut.await?; res.push(r); } - Ok(to_raw_value(&res)) + if annotations.return_last_result && res.len() > 0 { + Ok(res.pop().unwrap()) + } else { + Ok(to_raw_value(&res)) + } }; f.boxed() @@ -336,6 +351,7 @@ pub async fn do_bigquery( token.as_str(), timeout_ms, Some(column_order), + false, )? }; diff --git a/backend/windmill-worker/src/mssql_executor.rs b/backend/windmill-worker/src/mssql_executor.rs index beeb62c2b0c1c..2bffb2ee6136d 100644 --- a/backend/windmill-worker/src/mssql_executor.rs +++ b/backend/windmill-worker/src/mssql_executor.rs @@ -1,6 +1,5 @@ use base64::{engine::general_purpose, Engine as _}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; -use futures::TryFutureExt; use regex::Regex; use serde::Deserialize; use serde_json::value::RawValue; @@ -10,7 +9,7 @@ use tokio::net::TcpStream; use tokio_util::compat::TokioAsyncWriteCompatExt; use uuid::Uuid; use windmill_common::error::{self, Error}; -use windmill_common::worker::to_raw_value; +use windmill_common::worker::{get_sql_annotations, to_raw_value}; use windmill_common::{error::to_anyhow, jobs::QueuedJob}; use windmill_parser_sql::{parse_db_resource, parse_mssql_sig}; use windmill_queue::{append_logs, CanceledBy}; @@ -67,6 +66,8 @@ pub async fn do_mssql( return Err(Error::BadRequest("Missing database argument".to_string())); }; + let annotations = get_sql_annotations(query); + let mut config = Config::new(); config.host(database.host); @@ -124,37 +125,44 @@ pub async fn do_mssql( // polled to the end before querying again. Using streams allows // fetching data in an asynchronous manner, if needed. let stream = prepared_query.query(&mut client).await.map_err(to_anyhow)?; - stream - .into_results() - .await - .map_err(to_anyhow)? - .into_iter() - .map(|rows| { - let result = rows - .into_iter() - .map(|row| row_to_json(row)) - .collect::>, Error>>(); - result - }) - .collect::>>, Error>>() + + let results = stream.into_results().await.map_err(to_anyhow)?; + let len = results.len(); + let mut json_results = vec![]; + for (i, statement_result) in results.into_iter().enumerate() { + if annotations.return_last_result && i < len - 1 { + continue; + } + let mut json_rows = vec![]; + for row in statement_result { + let row = row_to_json(row)?; + json_rows.push(row); + } + json_results.push(json_rows); + } + + if annotations.return_last_result && json_results.len() > 0 { + Ok(to_raw_value(&json_results.pop().unwrap())) + } else { + Ok(to_raw_value(&json_results)) + } }; - let rows = run_future_with_polling_update_job_poller( + let raw_result = run_future_with_polling_update_job_poller( job.id, job.timeout, db, mem_peak, canceled_by, - result_f.map_err(to_anyhow), + result_f, worker_name, &job.workspace_id, ) .await?; - let r = to_raw_value(&rows); - *mem_peak = (r.get().len() / 1000) as i32; + *mem_peak = (raw_result.get().len() / 1000) as i32; - return Ok(to_raw_value(&rows)); + Ok(raw_result) } fn json_value_to_sql<'a>( @@ -221,7 +229,7 @@ fn json_value_to_sql<'a>( Ok(()) } -fn row_to_json(row: Row) -> Result, Error> { +fn row_to_json(row: Row) -> Result { let cols = row .columns() .iter() @@ -231,7 +239,7 @@ fn row_to_json(row: Row) -> Result, Error> { for (col, val) in cols.iter().zip(row.into_iter()) { map.insert(col.name().to_string(), sql_to_json_value(val)?); } - Ok(map) + Ok(Value::Object(map)) } fn value_or_null( diff --git a/backend/windmill-worker/src/mysql_executor.rs b/backend/windmill-worker/src/mysql_executor.rs index 8c138351e4681..9d16c6bde2bc7 100644 --- a/backend/windmill-worker/src/mysql_executor.rs +++ b/backend/windmill-worker/src/mysql_executor.rs @@ -13,6 +13,7 @@ use tokio::sync::Mutex; use windmill_common::{ error::{to_anyhow, Error}, jobs::QueuedJob, + worker::{get_sql_annotations, to_raw_value}, }; use windmill_parser_sql::{ parse_db_resource, parse_mysql_sig, parse_sql_blocks, parse_sql_statement_named_params, @@ -40,7 +41,8 @@ pub fn do_mysql_inner<'a>( all_statement_values: &Params, conn: Arc>, column_order: Option<&'a mut Option>>, -) -> windmill_common::error::Result>>> { + skip_collect: bool, +) -> windmill_common::error::Result>>> { let param_names = parse_sql_statement_named_params(query, ':') .into_iter() .map(|x| x.into_bytes()) @@ -58,31 +60,42 @@ pub fn do_mysql_inner<'a>( }; let result_f = async move { - let rows: Vec = conn - .lock() - .await - .exec(query, statement_values) - .await - .map_err(to_anyhow)?; - - if let Some(column_order) = column_order { - *column_order = Some( - rows.first() - .map(|x| { - x.columns() - .iter() - .map(|x| x.name_str().to_string()) - .collect::>() - }) - .unwrap_or_default(), - ); - } + if skip_collect { + conn.lock() + .await + .exec_drop(query, statement_values) + .await + .map_err(to_anyhow)?; + + Ok(to_raw_value(&Value::Array(vec![]))) + } else { + let rows: Vec = conn + .lock() + .await + .exec(query, statement_values) + .await + .map_err(to_anyhow)?; + + if let Some(column_order) = column_order { + *column_order = Some( + rows.first() + .map(|x| { + x.columns() + .iter() + .map(|x| x.name_str().to_string()) + .collect::>() + }) + .unwrap_or_default(), + ); + } - Ok(rows - .into_iter() - .map(|x| convert_row_to_value(x)) - .collect::>()) - as Result, anyhow::Error> + Ok(to_raw_value( + &rows + .into_iter() + .map(|x| convert_row_to_value(x)) + .collect::>(), + )) + } }; Ok(result_f.boxed()) @@ -133,6 +146,8 @@ pub async fn do_mysql( return Err(Error::BadRequest("Missing database argument".to_string())); }; + let annotations = get_sql_annotations(query); + let opts = OptsBuilder::default() .db_name(Some(database.database)) .user(database.user) @@ -235,21 +250,40 @@ pub async fn do_mysql( let result_f = if queries.len() > 1 { let futures = queries .iter() - .map(|x| do_mysql_inner(x, &statement_values, conn_a.clone(), None)) + .enumerate() + .map(|(i, x)| { + do_mysql_inner( + x, + &statement_values, + conn_a.clone(), + None, + annotations.return_last_result && i < queries.len() - 1, + ) + }) .collect::>>()?; let f = async { - let mut res: Vec = vec![]; + let mut res: Vec> = vec![]; for fut in futures { let r = fut.await?; - res.push(serde_json::to_value(r).map_err(to_anyhow)?); + res.push(r); + } + if annotations.return_last_result && res.len() > 0 { + Ok(res.pop().unwrap()) + } else { + Ok(to_raw_value(&res)) } - Ok(res) }; f.boxed() } else { - do_mysql_inner(query, &statement_values, conn_a.clone(), Some(column_order))? + do_mysql_inner( + query, + &statement_values, + conn_a.clone(), + Some(column_order), + false, + )? }; let result = run_future_with_polling_update_job_poller( diff --git a/backend/windmill-worker/src/pg_executor.rs b/backend/windmill-worker/src/pg_executor.rs index b55f550428169..000042d3e62a4 100644 --- a/backend/windmill-worker/src/pg_executor.rs +++ b/backend/windmill-worker/src/pg_executor.rs @@ -30,7 +30,7 @@ use tokio_postgres::{ }; use uuid::Uuid; use windmill_common::error::{self, Error}; -use windmill_common::worker::{to_raw_value, CLOUD_HOSTED}; +use windmill_common::worker::{get_sql_annotations, to_raw_value, CLOUD_HOSTED}; use windmill_common::{error::to_anyhow, jobs::QueuedJob}; use windmill_parser::{Arg, Typ}; use windmill_parser_sql::{ @@ -68,7 +68,8 @@ fn do_postgresql_inner<'a>( client: &'a Client, column_order: Option<&'a mut Option>>, siz: &'a AtomicUsize, -) -> error::Result>>> { + skip_collect: bool, +) -> error::Result>>> { let mut query_params = vec![]; let arg_indices = parse_pg_statement_arg_indices(&query); @@ -93,51 +94,60 @@ fn do_postgresql_inner<'a>( let result_f = async move { // Now we can execute a simple statement that just returns its parameter. - let rows = client - .query_raw(&query, query_params) - .await - .map_err(to_anyhow)?; - - let rows = rows.try_collect::>().await.map_err(to_anyhow)?; - - if let Some(column_order) = column_order { - *column_order = Some( - rows.first() - .map(|x| { - x.columns() - .iter() - .map(|x| x.name().to_string()) - .collect::>() - }) - .unwrap_or_default(), - ); - } let mut res: Vec = vec![]; - for row in rows.into_iter() { - let r = postgres_row_to_json_value(row); - if let Ok(v) = r.as_ref() { - let size = sizeof_val(v); - siz.fetch_add(size, Ordering::Relaxed); + + if skip_collect { + client + .execute_raw(&query, query_params) + .await + .map_err(to_anyhow)?; + } else { + let rows = client + .query_raw(&query, query_params) + .await + .map_err(to_anyhow)?; + + let rows = rows.try_collect::>().await.map_err(to_anyhow)?; + + if let Some(column_order) = column_order { + *column_order = Some( + rows.first() + .map(|x| { + x.columns() + .iter() + .map(|x| x.name().to_string()) + .collect::>() + }) + .unwrap_or_default(), + ); } - if *CLOUD_HOSTED { - let siz = siz.load(Ordering::Relaxed); - if siz > MAX_RESULT_SIZE * 4 { - return Err(anyhow::anyhow!( - "Query result too large for cloud (size = {} > {})", - siz, - MAX_RESULT_SIZE & 4 - )); + + for row in rows.into_iter() { + let r = postgres_row_to_json_value(row); + if let Ok(v) = r.as_ref() { + let size = sizeof_val(v); + siz.fetch_add(size, Ordering::Relaxed); + } + if *CLOUD_HOSTED { + let siz = siz.load(Ordering::Relaxed); + if siz > MAX_RESULT_SIZE * 4 { + return Err(anyhow::anyhow!( + "Query result too large for cloud (size = {} > {})", + siz, + MAX_RESULT_SIZE & 4 + )); + } + } + if let Ok(v) = r { + res.push(v); + } else { + return Err(to_anyhow(r.err().unwrap())); } - } - if let Ok(v) = r { - res.push(v); - } else { - return Err(to_anyhow(r.err().unwrap())); } } - Ok(res) + Ok(to_raw_value(&res)) }; Ok(result_f.boxed()) @@ -178,6 +188,9 @@ pub async fn do_postgresql( } else { return Err(Error::BadRequest("Missing database argument".to_string())); }; + + let annotations = get_sql_annotations(query); + let sslmode = match database.sslmode.as_deref() { Some("allow") => "prefer".to_string(), Some("verify-ca") | Some("verify-full") => "require".to_string(), @@ -291,24 +304,30 @@ pub async fn do_postgresql( let result_f = if queries.len() > 1 { let futures = queries .iter() - .map(|x| { + .enumerate() + .map(|(i, x)| { do_postgresql_inner( x.to_string(), ¶m_idx_to_arg_and_value, client, None, &size, + annotations.return_last_result && i < queries.len() - 1, ) }) .collect::>>()?; let f = async { - let mut res: Vec = vec![]; + let mut res: Vec> = vec![]; for fut in futures { let r = fut.await?; - res.push(serde_json::to_value(r).map_err(to_anyhow)?); + res.push(r); + } + if annotations.return_last_result && res.len() > 0 { + Ok(res.pop().unwrap()) + } else { + Ok(to_raw_value(&res)) } - Ok(res) }; f.boxed() @@ -319,6 +338,7 @@ pub async fn do_postgresql( client, Some(column_order), &size, + false, )? }; diff --git a/backend/windmill-worker/src/snowflake_executor.rs b/backend/windmill-worker/src/snowflake_executor.rs index aac2210356035..9c9bec7c9f993 100644 --- a/backend/windmill-worker/src/snowflake_executor.rs +++ b/backend/windmill-worker/src/snowflake_executor.rs @@ -9,6 +9,7 @@ use serde_json::{json, value::RawValue, Value}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use windmill_common::error::to_anyhow; +use windmill_common::worker::get_sql_annotations; use windmill_common::jobs::QueuedJob; use windmill_common::{error::Error, worker::to_raw_value}; @@ -74,34 +75,41 @@ struct SnowflakeError { } trait SnowflakeResponseExt { - async fn get_snowflake_response Deserialize<'a>>( + async fn parse_snowflake_response Deserialize<'a>>( self, ) -> windmill_common::error::Result; } +async fn handle_snowflake_result( + result: Result, +) -> windmill_common::error::Result { + match result { + Ok(response) => match response.error_for_status_ref() { + Ok(_) => Ok(response), + Err(e) => { + let resp = response.text().await.unwrap_or("".to_string()); + match serde_json::from_str::(&resp) { + Ok(sf_err) => return Err(Error::ExecutionErr(sf_err.message)), + Err(_) => return Err(Error::ExecutionErr(e.to_string())), + } + } + }, + Err(e) => Err(Error::ExecutionErr(format!( + "Could not send request: {:?}", + e + ))), + } +} + impl SnowflakeResponseExt for Result { - async fn get_snowflake_response Deserialize<'a>>( + async fn parse_snowflake_response Deserialize<'a>>( self, ) -> windmill_common::error::Result { - match self { - Ok(response) => match response.error_for_status_ref() { - Ok(_) => response - .json::() - .await - .map_err(|e| Error::ExecutionErr(e.to_string())), - Err(e) => { - let resp = response.text().await.unwrap_or("".to_string()); - match serde_json::from_str::(&resp) { - Ok(sf_err) => return Err(Error::ExecutionErr(sf_err.message)), - Err(_) => return Err(Error::ExecutionErr(e.to_string())), - } - } - }, - Err(e) => Err(Error::ExecutionErr(format!( - "Could not send request: {:?}", - e - ))), - } + let response = handle_snowflake_result(self).await?; + response + .json::() + .await + .map_err(|e| Error::ExecutionErr(e.to_string())) } } @@ -112,6 +120,7 @@ fn do_snowflake_inner<'a>( account_identifier: &'a str, token: &'a str, column_order: Option<&'a mut Option>>, + skip_collect: bool, ) -> windmill_common::error::Result>>> { body.insert("statement".to_string(), json!(query)); @@ -135,7 +144,7 @@ fn do_snowflake_inner<'a>( } let result_f = async move { - let response = HTTP_CLIENT + let result = HTTP_CLIENT .post(format!( "https://{}.snowflakecomputing.com/api/v2/statements/", account_identifier.to_uppercase() @@ -144,67 +153,76 @@ fn do_snowflake_inner<'a>( .header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") .json(&body) .send() - .await - .get_snowflake_response::() - .await?; - - if response.resultSetMetaData.numRows > 10000 { - return Err(Error::ExecutionErr( - "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows" - .to_string(), - )); - } - if let Some(column_order) = column_order { - *column_order = Some( - response - .resultSetMetaData - .rowType - .iter() - .map(|x| x.name.clone()) - .collect::>(), - ); - } - - let mut rows = response.data; - - if response.resultSetMetaData.partitionInfo.len() > 1 { - for idx in 1..response.resultSetMetaData.partitionInfo.len() { - let url = format!( - "https://{}.snowflakecomputing.com/api/v2/statements/{}", - account_identifier.to_uppercase(), - response.statementHandle + .await; + + if skip_collect { + handle_snowflake_result(result).await?; + Ok(to_raw_value(&Value::Array(vec![]))) + } else { + let response = result + .parse_snowflake_response::() + .await?; + + if response.resultSetMetaData.numRows > 10000 { + return Err(Error::ExecutionErr( + "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows" + .to_string(), + )); + } + if let Some(column_order) = column_order { + *column_order = Some( + response + .resultSetMetaData + .rowType + .iter() + .map(|x| x.name.clone()) + .collect::>(), ); - let response = HTTP_CLIENT - .get(url) - .bearer_auth(token) - .header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") - .query(&[("partition", idx.to_string())]) - .send() - .await - .get_snowflake_response::() - .await?; - - rows.extend(response.data); } - } - let rows = to_raw_value( - &rows - .iter() - .map(|row| { - let mut row_map = serde_json::Map::new(); - row.iter() - .zip(response.resultSetMetaData.rowType.iter()) - .for_each(|(val, row_type)| { - row_map - .insert(row_type.name.clone(), parse_val(&val, &row_type.r#type)); - }); - row_map - }) - .collect::>(), - ); + let mut rows = response.data; + + if response.resultSetMetaData.partitionInfo.len() > 1 { + for idx in 1..response.resultSetMetaData.partitionInfo.len() { + let url = format!( + "https://{}.snowflakecomputing.com/api/v2/statements/{}", + account_identifier.to_uppercase(), + response.statementHandle + ); + let response = HTTP_CLIENT + .get(url) + .bearer_auth(token) + .header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") + .query(&[("partition", idx.to_string())]) + .send() + .await + .parse_snowflake_response::() + .await?; + + rows.extend(response.data); + } + } - Ok(rows) + let rows = to_raw_value( + &rows + .iter() + .map(|row| { + let mut row_map = serde_json::Map::new(); + row.iter() + .zip(response.resultSetMetaData.rowType.iter()) + .for_each(|(val, row_type)| { + row_map.insert( + row_type.name.clone(), + parse_val(&val, &row_type.r#type), + ); + }); + row_map + }) + .collect::>(), + ); + + Ok(rows) + } }; Ok(result_f.boxed()) @@ -246,6 +264,8 @@ pub async fn do_snowflake( return Err(Error::BadRequest("Missing database argument".to_string())); }; + let annotations = get_sql_annotations(query); + let qualified_username = format!( "{}.{}", database.account_identifier.split('.').next().unwrap_or(""), // get first part of account identifier @@ -315,7 +335,8 @@ pub async fn do_snowflake( let result_f = if queries.len() > 1 { let futures = queries .iter() - .map(|x| { + .enumerate() + .map(|(i, x)| { do_snowflake_inner( x, &snowflake_args, @@ -323,6 +344,7 @@ pub async fn do_snowflake( &database.account_identifier, &token, None, + annotations.return_last_result && i < queries.len() - 1, ) }) .collect::>>()?; @@ -333,7 +355,11 @@ pub async fn do_snowflake( let r = fut.await?; res.push(r); } - Ok(to_raw_value(&res)) + if annotations.return_last_result && res.len() > 0 { + Ok(res.pop().unwrap()) + } else { + Ok(to_raw_value(&res)) + } }; f.boxed() @@ -345,6 +371,7 @@ pub async fn do_snowflake( &database.account_identifier, &token, Some(column_order), + false, )? }; let r = run_future_with_polling_update_job_poller( diff --git a/frontend/src/lib/script_helpers.ts b/frontend/src/lib/script_helpers.ts index 458499d3d2dc5..3dcade90c0a2f 100644 --- a/frontend/src/lib/script_helpers.ts +++ b/frontend/src/lib/script_helpers.ts @@ -215,7 +215,8 @@ INSERT INTO demo VALUES (?, ?); UPDATE demo SET col2 = ? WHERE col2 = ?; ` -export const MSSQL_INIT_CODE = `-- to pin the database use '-- database f/your/path' +export const MSSQL_INIT_CODE = `-- return_last_result +-- to pin the database use '-- database f/your/path' -- @p1 name1 (varchar) = default arg -- @p2 name2 (int) -- @p3 name3 (int)