feat: add return_last_result annotation to sql (#4443)
HugoCasa authored Sep 25, 2024
1 parent fd58e7e commit 3ce5587
Showing 7 changed files with 366 additions and 244 deletions.
16 changes: 16 additions & 0 deletions backend/windmill-common/src/worker.rs
@@ -327,6 +327,22 @@ pub fn get_annotation(inner_content: &str) -> Annotations {
     Annotations { npm_mode, nodejs_mode, native_mode, nobundling }
 }
 
+pub struct SqlAnnotations {
+    pub return_last_result: bool,
+}
+
+pub fn get_sql_annotations(inner_content: &str) -> SqlAnnotations {
+    let annotations = inner_content
+        .lines()
+        .take_while(|x| x.starts_with("--"))
+        .map(|x| x.to_string().replace("--", "").trim().to_string())
+        .collect_vec();
+
+    let return_last_result: bool = annotations.contains(&"return_last_result".to_string());
+
+    SqlAnnotations { return_last_result }
+}
+
 pub async fn load_cache(bin_path: &str, _remote_path: &str) -> (bool, String) {
     if tokio::fs::metadata(&bin_path).await.is_ok() {
         (true, format!("loaded from local cache: {}\n", bin_path))
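The parser only inspects the leading `--` comment lines of a script (`take_while` stops at the first non-comment line), so the annotation must sit at the very top. A minimal usage sketch, assuming `windmill_common` is available as a dependency; the script contents are hypothetical:

```rust
use windmill_common::worker::get_sql_annotations;

fn main() {
    // The annotation line must precede the first SQL statement.
    let script = "-- return_last_result\nUPDATE users SET active = 1;\nSELECT id FROM users;";
    assert!(get_sql_annotations(script).return_last_result);

    // Once a non-comment line is reached, later `--` lines are ignored.
    let late = "SELECT 1;\n-- return_last_result";
    assert!(!get_sql_annotations(late).return_last_result);
}
```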
148 changes: 82 additions & 66 deletions backend/windmill-worker/src/bigquery_executor.rs
@@ -5,6 +5,7 @@ use futures::{FutureExt, TryFutureExt};
 use serde_json::{json, value::RawValue, Value};
 use windmill_common::error::to_anyhow;
 use windmill_common::jobs::QueuedJob;
+use windmill_common::worker::get_sql_annotations;
 use windmill_common::{error::Error, worker::to_raw_value};
 use windmill_parser_sql::{
     parse_bigquery_sig, parse_db_resource, parse_sql_blocks, parse_sql_statement_named_params,
@@ -69,6 +70,7 @@ fn do_bigquery_inner<'a>(
     token: &'a str,
     timeout_ms: i32,
     column_order: Option<&'a mut Option<Vec<String>>>,
+    skip_collect: bool,
 ) -> windmill_common::error::Result<BoxFuture<'a, windmill_common::error::Result<Box<RawValue>>>> {
     let param_names = parse_sql_statement_named_params(query, '@');
 
@@ -106,76 +108,80 @@

         match response.error_for_status_ref() {
             Ok(_) => {
-                let result = response.json::<BigqueryResponse>().await.map_err(|e| {
-                    Error::ExecutionErr(format!(
-                        "BigQuery API response could not be parsed: {}",
-                        e.to_string()
-                    ))
-                })?;
-
-                if !result.jobComplete {
-                    return Err(Error::ExecutionErr(
-                        "BigQuery API did not answer query in time".to_string(),
-                    ));
-                }
-
-                if result.rows.is_none() || result.rows.as_ref().unwrap().len() == 0 {
-                    return Ok(serde_json::from_str("[]").unwrap());
-                }
-
-                if result.schema.is_none() {
-                    return Err(Error::ExecutionErr(
-                        "Incomplete response from BigQuery API".to_string(),
-                    ));
-                }
-
-                if result
-                    .totalRows
-                    .unwrap_or(json!(""))
-                    .as_str()
-                    .unwrap_or("")
-                    .parse::<i64>()
-                    .unwrap_or(0)
-                    > 10000
-                {
-                    return Err(Error::ExecutionErr(
-                        "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows".to_string(),
-                    ));
-                }
-
-                if let Some(column_order) = column_order {
-                    *column_order = Some(
-                        result
-                            .schema
-                            .as_ref()
-                            .unwrap()
-                            .fields
-                            .iter()
-                            .map(|x| x.name.clone())
-                            .collect::<Vec<String>>(),
-                    );
-                }
-
-                let rows = result
-                    .rows
-                    .unwrap()
-                    .iter()
-                    .map(|row| {
-                        let mut row_map = serde_json::Map::new();
-                        row.f
-                            .iter()
-                            .zip(result.schema.as_ref().unwrap().fields.iter())
-                            .for_each(|(field, schema)| {
-                                row_map.insert(
-                                    schema.name.clone(),
-                                    parse_val(&field.v, &schema.r#type, &schema),
-                                );
-                            });
-                        Value::from(row_map)
-                    })
-                    .collect::<Vec<_>>();
-
-                return Ok(to_raw_value(&rows));
+                if skip_collect {
+                    return Ok(to_raw_value(&Value::Array(vec![])));
+                } else {
+                    let result = response.json::<BigqueryResponse>().await.map_err(|e| {
+                        Error::ExecutionErr(format!(
+                            "BigQuery API response could not be parsed: {}",
+                            e.to_string()
+                        ))
+                    })?;
+
+                    if !result.jobComplete {
+                        return Err(Error::ExecutionErr(
+                            "BigQuery API did not answer query in time".to_string(),
+                        ));
+                    }
+
+                    if result.rows.is_none() || result.rows.as_ref().unwrap().len() == 0 {
+                        return Ok(serde_json::from_str("[]").unwrap());
+                    }
+
+                    if result.schema.is_none() {
+                        return Err(Error::ExecutionErr(
+                            "Incomplete response from BigQuery API".to_string(),
+                        ));
+                    }
+
+                    if result
+                        .totalRows
+                        .unwrap_or(json!(""))
+                        .as_str()
+                        .unwrap_or("")
+                        .parse::<i64>()
+                        .unwrap_or(0)
+                        > 10000
+                    {
+                        return Err(Error::ExecutionErr(
+                            "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows".to_string(),
+                        ));
+                    }
+
+                    if let Some(column_order) = column_order {
+                        *column_order = Some(
+                            result
+                                .schema
+                                .as_ref()
+                                .unwrap()
+                                .fields
+                                .iter()
+                                .map(|x| x.name.clone())
+                                .collect::<Vec<String>>(),
+                        );
+                    }
+
+                    let rows = result
+                        .rows
+                        .unwrap()
+                        .iter()
+                        .map(|row| {
+                            let mut row_map = serde_json::Map::new();
+                            row.f
+                                .iter()
+                                .zip(result.schema.as_ref().unwrap().fields.iter())
+                                .for_each(|(field, schema)| {
+                                    row_map.insert(
+                                        schema.name.clone(),
+                                        parse_val(&field.v, &schema.r#type, &schema),
+                                    );
+                                });
+                            Value::from(row_map)
+                        })
+                        .collect::<Vec<_>>();
+
+                    Ok(to_raw_value(&rows))
+                }
             }
             Err(e) => match response.json::<BigqueryErrorResponse>().await {
                 Ok(bq_err) => Err(Error::ExecutionErr(format!(
@@ -230,6 +236,8 @@ pub async fn do_bigquery(
         return Err(Error::BadRequest("Missing database argument".to_string()));
     };
 
+    let annotations = get_sql_annotations(query);
+
     let service_account = CustomServiceAccount::from_json(&database)
         .map_err(|e| Error::ExecutionErr(e.to_string()))?;
 
@@ -306,25 +314,32 @@ pub async fn do_bigquery(
     let result_f = if queries.len() > 1 {
         let futures = queries
             .iter()
-            .map(|x| {
+            .enumerate()
+            .map(|(i, x)| {
                 do_bigquery_inner(
                     x,
                     &statement_values,
                     &project_id,
                     token.as_str(),
                     timeout_ms,
                     None,
+                    annotations.return_last_result && i < queries.len() - 1,
                 )
             })
             .collect::<windmill_common::error::Result<Vec<_>>>()?;
 
         let f = async {
             let mut res: Vec<Box<RawValue>> = vec![];
 
             for fut in futures {
                 let r = fut.await?;
                 res.push(r);
             }
-            Ok(to_raw_value(&res))
+            if annotations.return_last_result && res.len() > 0 {
+                Ok(res.pop().unwrap())
+            } else {
+                Ok(to_raw_value(&res))
+            }
         };
 
         f.boxed()
@@ -336,6 +351,7 @@ pub async fn do_bigquery(
             token.as_str(),
             timeout_ms,
             Some(column_order),
+            false,
         )?
     };
 
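With multiple statements, every statement except the last now runs with `skip_collect` set, so its rows are never parsed, and the final result is popped off the collected vector. A standalone sketch of that selection logic (a simplified stand-in, not the Windmill API):

```rust
use serde_json::{json, Value};

// One JSON value per executed statement, as the executor would collect them.
fn select_output(mut per_statement: Vec<Value>, return_last_result: bool) -> Value {
    if return_last_result && !per_statement.is_empty() {
        // Only the last statement's rows are returned.
        per_statement.pop().unwrap()
    } else {
        // Default: an array with one entry per statement.
        Value::Array(per_statement)
    }
}

fn main() {
    let results = vec![json!([]), json!([{"id": 1}, {"id": 2}])];
    assert_eq!(select_output(results.clone(), true), json!([{"id": 1}, {"id": 2}]));
    assert_eq!(select_output(results, false), json!([[], [{"id": 1}, {"id": 2}]]));
}
```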
52 changes: 30 additions & 22 deletions backend/windmill-worker/src/mssql_executor.rs
@@ -1,6 +1,5 @@
 use base64::{engine::general_purpose, Engine as _};
 use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
-use futures::TryFutureExt;
 use regex::Regex;
 use serde::Deserialize;
 use serde_json::value::RawValue;
@@ -10,7 +9,7 @@ use tokio::net::TcpStream;
 use tokio_util::compat::TokioAsyncWriteCompatExt;
 use uuid::Uuid;
 use windmill_common::error::{self, Error};
-use windmill_common::worker::to_raw_value;
+use windmill_common::worker::{get_sql_annotations, to_raw_value};
 use windmill_common::{error::to_anyhow, jobs::QueuedJob};
 use windmill_parser_sql::{parse_db_resource, parse_mssql_sig};
 use windmill_queue::{append_logs, CanceledBy};
@@ -67,6 +66,8 @@ pub async fn do_mssql(
         return Err(Error::BadRequest("Missing database argument".to_string()));
     };
 
+    let annotations = get_sql_annotations(query);
+
     let mut config = Config::new();
 
     config.host(database.host);
@@ -124,37 +125,44 @@
         // polled to the end before querying again. Using streams allows
         // fetching data in an asynchronous manner, if needed.
         let stream = prepared_query.query(&mut client).await.map_err(to_anyhow)?;
-        stream
-            .into_results()
-            .await
-            .map_err(to_anyhow)?
-            .into_iter()
-            .map(|rows| {
-                let result = rows
-                    .into_iter()
-                    .map(|row| row_to_json(row))
-                    .collect::<Result<Vec<Map<String, Value>>, Error>>();
-                result
-            })
-            .collect::<Result<Vec<Vec<Map<String, Value>>>, Error>>()
+
+        let results = stream.into_results().await.map_err(to_anyhow)?;
+        let len = results.len();
+        let mut json_results = vec![];
+        for (i, statement_result) in results.into_iter().enumerate() {
+            if annotations.return_last_result && i < len - 1 {
+                continue;
+            }
+            let mut json_rows = vec![];
+            for row in statement_result {
+                let row = row_to_json(row)?;
+                json_rows.push(row);
+            }
+            json_results.push(json_rows);
+        }
+
+        if annotations.return_last_result && json_results.len() > 0 {
+            Ok(to_raw_value(&json_results.pop().unwrap()))
+        } else {
+            Ok(to_raw_value(&json_results))
+        }
     };
 
-    let rows = run_future_with_polling_update_job_poller(
+    let raw_result = run_future_with_polling_update_job_poller(
         job.id,
         job.timeout,
         db,
         mem_peak,
         canceled_by,
-        result_f.map_err(to_anyhow),
+        result_f,
         worker_name,
         &job.workspace_id,
     )
     .await?;
 
-    let r = to_raw_value(&rows);
-    *mem_peak = (r.get().len() / 1000) as i32;
+    *mem_peak = (raw_result.get().len() / 1000) as i32;
 
-    return Ok(to_raw_value(&rows));
+    Ok(raw_result)
 }
 
 fn json_value_to_sql<'a>(
@@ -221,7 +229,7 @@ fn json_value_to_sql<'a>(
     Ok(())
 }
 
-fn row_to_json(row: Row) -> Result<Map<String, Value>, Error> {
+fn row_to_json(row: Row) -> Result<Value, Error> {
     let cols = row
         .columns()
         .iter()
@@ -231,7 +239,7 @@ fn row_to_json(row: Row) -> Result<Map<String, Value>, Error> {
     for (col, val) in cols.iter().zip(row.into_iter()) {
         map.insert(col.name().to_string(), sql_to_json_value(val)?);
     }
-    Ok(map)
+    Ok(Value::Object(map))
 }
 
 fn value_or_null<T>(
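Changing `row_to_json` to return `Value` instead of `Map<String, Value>` lets each row be pushed straight into the per-statement `json_rows` vector and serialized uniformly, whether one statement's rows or all of them are returned. A minimal sketch of the equivalent wrapping, with hypothetical column data:

```rust
use serde_json::{json, Map, Value};

// Mirrors the reworked row_to_json: collect columns into a map, then wrap
// the map in a Value so rows compose with other JSON values.
fn row_to_json_like(columns: Vec<(String, Value)>) -> Value {
    let mut map = Map::new();
    for (name, val) in columns {
        map.insert(name, val);
    }
    Value::Object(map)
}

fn main() {
    let row = row_to_json_like(vec![("id".to_string(), json!(1))]);
    assert_eq!(row, json!({"id": 1}));
}
```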