Skip to content

Commit

Permalink
Merge pull request #567 from mikotoIO/develop
Browse files Browse the repository at this point in the history
feat: properly implement client -> server websocket commands
  • Loading branch information
TheCactusBlue authored Sep 16, 2024
2 parents 30927d7 + b67a487 commit 40c9547
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 22 deletions.
4 changes: 2 additions & 2 deletions apps/superego/src/functions/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use fred::prelude::PubsubInterface;
use serde::Serialize;
use serde_json::json;

use crate::{db::redis, error::Result};
use crate::{db::redis, error::Error};

pub async fn emit_event<T: Serialize>(op: &str, data: T, channel: &str) -> Result<()> {
pub async fn emit_event<T: Serialize>(op: &str, data: T, channel: &str) -> Result<(), Error> {
let o = json!({
"op": op,
"data": data,
Expand Down
18 changes: 16 additions & 2 deletions apps/superego/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};

use aide::{
axum::{routing::get_with, IntoApiResponse},
Expand All @@ -9,9 +9,12 @@ use channels::voice;
use router::AppRouter;
use schemars::JsonSchema;
use serde::Serialize;
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use ws::state::State;

use crate::functions::pubsub::emit_event;

pub mod account;
pub mod bots;
pub mod channels;
Expand Down Expand Up @@ -39,6 +42,11 @@ pub async fn index() -> Json<&'static IndexResponse> {
}))
}

#[derive(Serialize, Deserialize, JsonSchema)]
pub struct Ping {
pub message: String,
}

pub fn router() -> Router {
let router = AppRouter::<State>::new()
.on_http(|router| {
Expand Down Expand Up @@ -71,7 +79,13 @@ pub fn router() -> Router {
"/spaces/:space_id/members",
spaces::members::router(),
)
.nest("roles", "/spaces/:space_id/roles", spaces::roles::router());
.nest("roles", "/spaces/:space_id/roles", spaces::roles::router())
.ws_command("ping", |ping: Ping, state: Arc<RwLock<State>>| async move {
let reader = state.read().await;
emit_event("pong", ping, &format!("conn:{}", reader.conn_id)).await?;
Ok(())
})
.ws_event("pong", |ping: Ping, _| async move { Some(ping) });

router
.build(OpenApi {
Expand Down
25 changes: 21 additions & 4 deletions apps/superego/src/routes/router.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
use std::{future::Future, sync::Arc};

use aide::{
axum::{routing::ApiMethodRouter, ApiRouter, IntoApiResponse},
openapi::OpenApi,
scalar::Scalar,
};
use axum::{Extension, Json, Router};
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;

use crate::error::Error;

use super::ws::{self, schema::WebSocketRouter, WebSocketState};

Expand Down Expand Up @@ -52,15 +59,25 @@ impl<W: WebSocketState> AppRouter<W> {

pub fn ws_event<T, R, F, Fut>(mut self, name: &str, filter: F) -> Self
where
T: serde::Serialize + serde::de::DeserializeOwned,
R: schemars::JsonSchema + serde::Serialize,
Fut: std::future::Future<Output = Option<R>> + Send + Sync,
F: Fn(T, std::sync::Arc<tokio::sync::RwLock<W>>) -> Fut + 'static + Send + Sync + Copy,
T: Serialize + DeserializeOwned,
R: JsonSchema + Serialize,
F: Fn(T, Arc<RwLock<W>>) -> Fut + 'static + Send + Sync + Copy,
Fut: Future<Output = Option<R>> + Send + Sync,
{
self.ws = self.ws.event(name, filter);
self
}

pub fn ws_command<T, F, Fut>(mut self, name: &str, func: F) -> Self
where
T: JsonSchema + DeserializeOwned,
F: Fn(T, Arc<RwLock<W>>) -> Fut + 'static + Send + Sync + Copy,
Fut: Future<Output = Result<(), Error>> + Send,
{
self.ws = self.ws.command(name, func);
self
}

pub fn build(self, mut api: OpenApi) -> Router {
let Self { http, ws } = self;

Expand Down
27 changes: 17 additions & 10 deletions apps/superego/src/routes/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@ async fn handle_socket<S: WebSocketState>(
};

let state = Arc::new(RwLock::new(state));
let s2c_state = state.clone();
let c2s_state = state.clone();

let mut server_to_client: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
let state = s2c_state;
while let Ok(msg) = redis.message_rx().recv().await {
let msg = if let Some(x) = msg.value.as_string() {
x
Expand Down Expand Up @@ -89,22 +92,26 @@ async fn handle_socket<S: WebSocketState>(
Err(Error::WebSocketTerminated)
});

let mut client_to_server = tokio::spawn(async move {
let mut client_to_server: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
let state = c2s_state;

while let Some(msg) = reader.next().await {
let msg = if let Ok(ws::Message::Text(msg)) = msg {
msg
} else {
return; // client disconnected
return Err(Error::WebSocketTerminated); // client disconnected
};
emit_event(
"spaces.onDelete",
ObjectWithId { id: Uuid::new_v4() },
"all",
)
.await
.unwrap();
dbg!(msg.clone());
let msg: Operation = serde_json::from_str(&msg)?;
let filter = if let Some(filter) = router.command_filters.get(&msg.op) {
filter
} else {
continue; // no filter for this event
};

filter(msg.data, state.clone()).await?;
}

Err(Error::WebSocketTerminated)
});

tokio::select! {
Expand Down
37 changes: 33 additions & 4 deletions apps/superego/src/routes/ws/schema.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc};

use jsonwebtoken::errors::Error;
use schemars::{schema::Schema, JsonSchema};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;

use crate::error::Error;

pub struct WebSocketRouter<S> {
pub commands: BTreeMap<String, Schema>,
pub events: BTreeMap<String, Schema>,

pub command_filters: BTreeMap<
String,
Box<
dyn Fn(
serde_json::Value,
Arc<RwLock<S>>,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>
+ Send
+ Sync,
>,
>,

pub event_filters: BTreeMap<
String,
Box<
Expand Down Expand Up @@ -36,25 +49,38 @@ impl<S: 'static + Send + Sync> WebSocketRouter<S> {
Self {
commands: BTreeMap::new(),
events: BTreeMap::new(),

command_filters: BTreeMap::new(),
event_filters: BTreeMap::new(),
}
}

pub fn command<T>(mut self, name: &str) -> Self
pub fn command<T, F, Fut>(mut self, name: &str, func: F) -> Self
where
T: JsonSchema,
T: JsonSchema + DeserializeOwned,
F: Fn(T, Arc<RwLock<S>>) -> Fut + 'static + Send + Sync + Copy,
Fut: Future<Output = Result<(), Error>> + Send,
{
let schema = aide::gen::in_context(|ctx| ctx.schema.subschema_for::<T>());
self.commands.insert(name.to_string(), schema);
self.command_filters.insert(
name.to_string(),
Box::new(move |value, state| {
Box::pin(async move {
let parsed = serde_json::from_value(value)?;
func(parsed, state).await
})
}),
);
self
}

pub fn event<T, R, F, Fut>(mut self, name: &str, filter: F) -> Self
where
T: Serialize + DeserializeOwned,
R: JsonSchema + Serialize,
Fut: Future<Output = Option<R>> + Send + Sync,
F: Fn(T, Arc<RwLock<S>>) -> Fut + 'static + Send + Sync + Copy,
Fut: Future<Output = Option<R>> + Send + Sync,
{
let schema = aide::gen::in_context(|ctx| ctx.schema.subschema_for::<R>());
self.events.insert(name.to_string(), schema);
Expand All @@ -78,6 +104,7 @@ impl<S: 'static + Send + Sync> WebSocketRouter<S> {
pub fn merge(mut self, other: Self) -> Self {
self.commands.extend(other.commands);
self.events.extend(other.events);
self.command_filters.extend(other.command_filters);
self.event_filters.extend(other.event_filters);
self
}
Expand All @@ -86,6 +113,8 @@ impl<S: 'static + Send + Sync> WebSocketRouter<S> {
let prefix = prefix.trim_end_matches('/');
self.commands.extend(prefix_map(prefix, other.commands));
self.events.extend(prefix_map(prefix, other.events));
self.command_filters
.extend(prefix_map(prefix, other.command_filters));
self.event_filters
.extend(prefix_map(prefix, other.event_filters));

Expand Down

0 comments on commit 40c9547

Please sign in to comment.