diff --git a/apps/superego/src/functions/pubsub.rs b/apps/superego/src/functions/pubsub.rs index 6bcde116..5aac40ae 100644 --- a/apps/superego/src/functions/pubsub.rs +++ b/apps/superego/src/functions/pubsub.rs @@ -2,9 +2,9 @@ use fred::prelude::PubsubInterface; use serde::Serialize; use serde_json::json; -use crate::{db::redis, error::Result}; +use crate::{db::redis, error::Error}; -pub async fn emit_event(op: &str, data: T, channel: &str) -> Result<()> { +pub async fn emit_event(op: &str, data: T, channel: &str) -> Result<(), Error> { let o = json!({ "op": op, "data": data, diff --git a/apps/superego/src/routes/mod.rs b/apps/superego/src/routes/mod.rs index 3950e4e3..d09961b3 100644 --- a/apps/superego/src/routes/mod.rs +++ b/apps/superego/src/routes/mod.rs @@ -1,4 +1,4 @@ -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use aide::{ axum::{routing::get_with, IntoApiResponse}, @@ -9,9 +9,12 @@ use channels::voice; use router::AppRouter; use schemars::JsonSchema; use serde::Serialize; +use tokio::sync::RwLock; use tower_http::cors::CorsLayer; use ws::state::State; +use crate::functions::pubsub::emit_event; + pub mod account; pub mod bots; pub mod channels; @@ -39,6 +42,11 @@ pub async fn index() -> Json<&'static IndexResponse> { })) } +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct Ping { + pub message: String, +} + pub fn router() -> Router { let router = AppRouter::::new() .on_http(|router| { @@ -71,7 +79,13 @@ pub fn router() -> Router { "/spaces/:space_id/members", spaces::members::router(), ) - .nest("roles", "/spaces/:space_id/roles", spaces::roles::router()); + .nest("roles", "/spaces/:space_id/roles", spaces::roles::router()) + .ws_command("ping", |ping: Ping, state: Arc>| async move { + let reader = state.read().await; + emit_event("pong", ping, &format!("conn:{}", reader.conn_id)).await?; + Ok(()) + }) + .ws_event("pong", |ping: Ping, _| async move { Some(ping) }); router .build(OpenApi { diff --git a/apps/superego/src/routes/router.rs b/apps/superego/src/routes/router.rs index fa982ecd..1c87d1b3 100644 --- a/apps/superego/src/routes/router.rs +++ b/apps/superego/src/routes/router.rs @@ -1,9 +1,16 @@ +use std::{future::Future, sync::Arc}; + use aide::{ axum::{routing::ApiMethodRouter, ApiRouter, IntoApiResponse}, openapi::OpenApi, scalar::Scalar, }; use axum::{Extension, Json, Router}; +use schemars::JsonSchema; +use serde::{de::DeserializeOwned, Serialize}; +use tokio::sync::RwLock; + +use crate::error::Error; use super::ws::{self, schema::WebSocketRouter, WebSocketState}; @@ -52,15 +59,25 @@ impl AppRouter { pub fn ws_event(mut self, name: &str, filter: F) -> Self where - T: serde::Serialize + serde::de::DeserializeOwned, - R: schemars::JsonSchema + serde::Serialize, - Fut: std::future::Future> + Send + Sync, - F: Fn(T, std::sync::Arc>) -> Fut + 'static + Send + Sync + Copy, + T: Serialize + DeserializeOwned, + R: JsonSchema + Serialize, + F: Fn(T, Arc>) -> Fut + 'static + Send + Sync + Copy, + Fut: Future> + Send + Sync, { self.ws = self.ws.event(name, filter); self } + pub fn ws_command(mut self, name: &str, func: F) -> Self + where + T: JsonSchema + DeserializeOwned, + F: Fn(T, Arc>) -> Fut + 'static + Send + Sync + Copy, + Fut: Future> + Send, + { + self.ws = self.ws.command(name, func); + self + } + pub fn build(self, mut api: OpenApi) -> Router { let Self { http, ws } = self; diff --git a/apps/superego/src/routes/ws/mod.rs b/apps/superego/src/routes/ws/mod.rs index 44d9b45b..e61e0cc7 100644 --- a/apps/superego/src/routes/ws/mod.rs +++ b/apps/superego/src/routes/ws/mod.rs @@ -53,8 +53,11 @@ async fn handle_socket( }; let state = Arc::new(RwLock::new(state)); + let s2c_state = state.clone(); + let c2s_state = state.clone(); let mut server_to_client: JoinHandle> = tokio::spawn(async move { + let state = s2c_state; while let Ok(msg) = redis.message_rx().recv().await { let msg = if let Some(x) = msg.value.as_string() { x @@ -89,22 +92,26 @@ async fn handle_socket( Err(Error::WebSocketTerminated) }); - let mut client_to_server = tokio::spawn(async move { + let mut client_to_server: JoinHandle> = tokio::spawn(async move { + let state = c2s_state; + while let Some(msg) = reader.next().await { let msg = if let Ok(ws::Message::Text(msg)) = msg { msg } else { - return; // client disconnected + return Err(Error::WebSocketTerminated); // client disconnected }; - emit_event( - "spaces.onDelete", - ObjectWithId { id: Uuid::new_v4() }, - "all", - ) - .await - .unwrap(); - dbg!(msg.clone()); + let msg: Operation = serde_json::from_str(&msg)?; + let filter = if let Some(filter) = router.command_filters.get(&msg.op) { + filter + } else { + continue; // no filter for this event + }; + + filter(msg.data, state.clone()).await?; } + + Err(Error::WebSocketTerminated) }); tokio::select! { diff --git a/apps/superego/src/routes/ws/schema.rs b/apps/superego/src/routes/ws/schema.rs index b9dd224f..067e667b 100644 --- a/apps/superego/src/routes/ws/schema.rs +++ b/apps/superego/src/routes/ws/schema.rs @@ -1,14 +1,27 @@ use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc}; -use jsonwebtoken::errors::Error; use schemars::{schema::Schema, JsonSchema}; use serde::{de::DeserializeOwned, Serialize}; use tokio::sync::RwLock; +use crate::error::Error; + pub struct WebSocketRouter { pub commands: BTreeMap, pub events: BTreeMap, + pub command_filters: BTreeMap< + String, + Box< + dyn Fn( + serde_json::Value, + Arc>, + ) -> Pin> + Send>> + + Send + + Sync, + >, + >, + pub event_filters: BTreeMap< String, Box< @@ -36,16 +49,29 @@ impl WebSocketRouter { Self { commands: BTreeMap::new(), events: BTreeMap::new(), + + command_filters: BTreeMap::new(), event_filters: BTreeMap::new(), } } - pub fn command(mut self, name: &str) -> Self + pub fn command(mut self, name: &str, func: F) -> Self where - T: JsonSchema, + T: JsonSchema + DeserializeOwned, + F: Fn(T, Arc>) -> Fut + 'static + Send + Sync + Copy, + Fut: Future> + Send, { let schema = aide::gen::in_context(|ctx| ctx.schema.subschema_for::()); self.commands.insert(name.to_string(), schema); + self.command_filters.insert( + name.to_string(), + Box::new(move |value, state| { + Box::pin(async move { + let parsed = serde_json::from_value(value)?; + func(parsed, state).await + }) + }), + ); self } @@ -53,8 +79,8 @@ impl WebSocketRouter { where T: Serialize + DeserializeOwned, R: JsonSchema + Serialize, - Fut: Future> + Send + Sync, F: Fn(T, Arc>) -> Fut + 'static + Send + Sync + Copy, + Fut: Future> + Send + Sync, { let schema = aide::gen::in_context(|ctx| ctx.schema.subschema_for::()); self.events.insert(name.to_string(), schema); @@ -78,6 +104,7 @@ impl WebSocketRouter { pub fn merge(mut self, other: Self) -> Self { self.commands.extend(other.commands); self.events.extend(other.events); + self.command_filters.extend(other.command_filters); self.event_filters.extend(other.event_filters); self } @@ -86,6 +113,8 @@ impl WebSocketRouter { let prefix = prefix.trim_end_matches('/'); self.commands.extend(prefix_map(prefix, other.commands)); self.events.extend(prefix_map(prefix, other.events)); + self.command_filters + .extend(prefix_map(prefix, other.command_filters)); self.event_filters .extend(prefix_map(prefix, other.event_filters));