diff --git a/backend/windmill-api/src/websocket_triggers.rs b/backend/windmill-api/src/websocket_triggers.rs index 6e4b8b3c7a330..3bdd51c3e9cee 100644 --- a/backend/windmill-api/src/websocket_triggers.rs +++ b/backend/windmill-api/src/websocket_triggers.rs @@ -330,29 +330,49 @@ async fn exists_websocket_trigger( Ok(Json(exists)) } -pub async fn start_websockets(db: DB, rsmq: Option) -> () { +async fn listen_to_unlistened_websockets( + db: &DB, + rsmq: &Option, + killpill_rx: &tokio::sync::broadcast::Receiver<()>, +) -> () { + match sqlx::query_as!( + WebsocketTrigger, + r#"SELECT * + FROM websocket_trigger + WHERE enabled IS TRUE AND (server_id IS NULL OR last_server_ping IS NULL OR last_server_ping < now() - interval '15 seconds')"# + ) + .fetch_all(db) + .await + { + Ok(mut triggers) => { + triggers.shuffle(&mut rand::thread_rng()); + for trigger in triggers { + maybe_listen_to_websocket(trigger, db.clone(), rsmq.clone(), killpill_rx.resubscribe()).await; + } + } + Err(err) => { + tracing::error!("Error fetching websocket triggers: {:?}", err); + } + }; +} + +pub async fn start_websockets( + db: DB, + rsmq: Option, + mut killpill_rx: tokio::sync::broadcast::Receiver<()>, +) -> () { tokio::spawn(async move { + listen_to_unlistened_websockets(&db, &rsmq, &killpill_rx).await; loop { - match sqlx::query_as!( - WebsocketTrigger, - r#"SELECT * - FROM websocket_trigger - WHERE enabled IS TRUE AND (server_id IS NULL OR last_server_ping IS NULL OR last_server_ping < now() - interval '15 seconds')"# - ) - .fetch_all(&db) - .await - { - Ok(mut triggers) => { - triggers.shuffle(&mut rand::thread_rng()); - for trigger in triggers { - maybe_listen_to_websocket(trigger, db.clone(), rsmq.clone()).await; - } + tokio::select! { + biased; + _ = killpill_rx.recv() => { + return; } - Err(err) => { - tracing::error!("Error fetching websocket triggers: {:?}", err); + _ = tokio::time::sleep(tokio::time::Duration::from_secs(15)) => { + listen_to_unlistened_websockets(&db, &rsmq, &killpill_rx).await; } - }; - tokio::time::sleep(tokio::time::Duration::from_secs(15)).await; + } } }); } @@ -361,6 +381,7 @@ async fn maybe_listen_to_websocket( ws_trigger: WebsocketTrigger, db: DB, rsmq: Option, + killpill_rx: tokio::sync::broadcast::Receiver<()>, ) -> () { match sqlx::query_scalar!( "UPDATE websocket_trigger SET server_id = $1, last_server_ping = now() WHERE enabled IS TRUE AND workspace_id = $2 AND path = $3 AND (server_id IS NULL OR last_server_ping IS NULL OR last_server_ping < now() - interval '15 seconds') RETURNING true", @@ -370,7 +391,7 @@ async fn maybe_listen_to_websocket( ).fetch_optional(&db).await { Ok(has_lock) => { if has_lock.flatten().unwrap_or(false) { - tokio::spawn(listen_to_websocket(ws_trigger, db, rsmq)); + tokio::spawn(listen_to_websocket(ws_trigger, db, rsmq, killpill_rx)); } else { tracing::info!("Websocket {} already being listened to", ws_trigger.url); } @@ -453,6 +474,7 @@ async fn listen_to_websocket( ws_trigger: WebsocketTrigger, db: DB, rsmq: Option, + mut killpill_rx: tokio::sync::broadcast::Receiver<()>, ) -> () { async fn update_ping(db: DB, ws_trigger: &WebsocketTrigger, error: Option<&str>) -> Option<()> { match sqlx::query_scalar!( @@ -496,84 +518,109 @@ async fn listen_to_websocket( .collect_vec(); loop { - match connect_async(url).await { - Ok((ws_stream, _)) => { - tracing::info!("Listening to websocket {}", url); - if let None = update_ping(db.clone(), &ws_trigger, None).await { - return; - } - let (_, mut read) = ws_stream.split(); - loop { - tokio::select! { - msg = read.next() => { - if let Some(msg) = msg { - match msg { - Ok(msg) => { + tokio::select! { + biased; + _ = killpill_rx.recv() => { + return; + }, + connection = connect_async(url) => { + match connection { + Ok((ws_stream, _)) => { + tracing::info!("Listening to websocket {}", url); + if let None = update_ping(db.clone(), &ws_trigger, None).await { + return; + } + let mut last_ping = tokio::time::Instant::now(); + let (_, mut read) = ws_stream.split(); + loop { + tokio::select! { + biased; + _ = killpill_rx.recv() => { + return; + } + msg = read.next() => { + if let Some(msg) = msg { + if last_ping.elapsed() > tokio::time::Duration::from_secs(5) { + if let None = update_ping(db.clone(), &ws_trigger, None).await { + return; + } + last_ping = tokio::time::Instant::now(); + } match msg { - tokio_tungstenite::tungstenite::Message::Text(text) => { - let mut should_handle = true; - for filter in &filters { - match filter { - Filter::JsonFilter(JsonFilter { key, value }) => { - let mut deserializer = serde_json::Deserializer::from_str(text.as_str()); - should_handle = match is_value_superset(&mut deserializer, key, &value) { - Ok(filter_match) => { - filter_match - }, - Err(err) => { - tracing::warn!("Error deserializing filter for websocket {}: {:?}", url, err); - false + Ok(msg) => { + match msg { + tokio_tungstenite::tungstenite::Message::Text(text) => { + let mut should_handle = true; + for filter in &filters { + match filter { + Filter::JsonFilter(JsonFilter { key, value }) => { + let mut deserializer = serde_json::Deserializer::from_str(text.as_str()); + should_handle = match is_value_superset(&mut deserializer, key, &value) { + Ok(filter_match) => { + filter_match + }, + Err(err) => { + tracing::warn!("Error deserializing filter for websocket {}: {:?}", url, err); + false + } + }; } - }; + } + if !should_handle { + break; + } } - } - if !should_handle { - break; - } - } - if should_handle { - let db_ = db.clone(); - let rsmq_ = rsmq.clone(); - let ws_trigger_ = ws_trigger.clone(); - tokio::spawn(async move { - let url = ws_trigger_.url.clone(); - if let Err(err) = run_job(db_, rsmq_, ws_trigger_, text).await { - tracing::error!("Error running job on websocket {}: {:?}", url, err); - }; - }); + if should_handle { + let db_ = db.clone(); + let rsmq_ = rsmq.clone(); + let ws_trigger_ = ws_trigger.clone(); + tokio::spawn(async move { + let url = ws_trigger_.url.clone(); + if let Err(err) = run_job(db_, rsmq_, ws_trigger_, text).await { + tracing::error!("Error running job on websocket {}: {:?}", url, err); + }; + }); + } + }, + _ => {} } }, - _ => {} + Err(err) => { + tracing::error!("Error reading from websocket {}: {:?}", url, err); + } } - }, - Err(err) => { - tracing::error!("Error reading from websocket {}: {:?}", url, err); + } else { + tracing::error!("Websocket {} closed", url); + if let None = + update_ping(db.clone(), &ws_trigger, Some("Websocket cloesd")).await + { + return; + } + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + break; } - } - } else { - tracing::error!("Websocket {} closed, reconnecting in 5s...", url); - break; - } - }, - _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => { - if let None = update_ping(db.clone(), &ws_trigger, None).await { - return; + }, + _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => { + if let None = update_ping(db.clone(), &ws_trigger, None).await { + return; + } + last_ping = tokio::time::Instant::now(); + }, } - }, + } + } + Err(err) => { + tracing::error!("Error connecting to websocket {}: {:?}", url, err); + if let None = + update_ping(db.clone(), &ws_trigger, Some(err.to_string().as_str())).await + { + return; + } + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; } - } - } - Err(err) => { - tracing::error!("Error connecting to websocket {}: {:?}", url, err); - if let None = - update_ping(db.clone(), &ws_trigger, Some(err.to_string().as_str())).await - { - return; } } } - - tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; } }