diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index db186ebdc..1bad88c65 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -131,6 +131,23 @@ jobs: - run: cargo build -p musli --no-default-features --features ${{matrix.base}},simdutf8 - run: cargo build -p musli --no-default-features --features ${{matrix.base}},parse-full + crate_features: + needs: [rustfmt, clippy] + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + crate: + - musli-axum + env: + RUSTFLAGS: -D warnings + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - run: cargo build -p ${{matrix.crate}} --no-default-features + - run: cargo build -p ${{matrix.crate}} --no-default-features --features alloc + - run: cargo build -p ${{matrix.crate}} --no-default-features --features std + recursive: runs-on: ubuntu-latest steps: diff --git a/crates/musli-axum/Cargo.toml b/crates/musli-axum/Cargo.toml new file mode 100644 index 000000000..6e2c991e3 --- /dev/null +++ b/crates/musli-axum/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "musli-axum" +version = "0.0.122" +authors = ["John-John Tedro "] +edition = "2021" +description = """ +Types for integrating Müsli with axum. +""" +documentation = "https://docs.rs/musli" +readme = "README.md" +homepage = "https://github.com/udoprog/musli" +repository = "https://github.com/udoprog/musli" +license = "MIT OR Apache-2.0" +keywords = ["framework", "http", "web"] +categories = ["asynchronous", "network-programming", "web-programming::http-server"] + +[features] +default = ["alloc", "std", "ws", "json"] +alloc = ["musli/alloc"] +std = ["musli/std"] +api = [] +json = ["musli/json", "axum/json", "dep:bytes", "dep:mime"] +ws = ["api", "axum/ws", "dep:rand", "tokio/time", "dep:tokio-stream"] + +[dependencies] +musli = { path = "../musli", version = "0.0.122", default-features = false } +axum = { version = "0.7.5", default-features = false, optional = true } +bytes = { version = "1.6.0", optional = true } +mime = { version = "0.3.17", default-features = false, optional = true } +rand = { version = "0.8.5", default-features = false, optional = true, features = ["small_rng"] } +tracing = { version = "0.1.40", default-features = false } +tokio = { version = "1.37.0", default-features = false, features = ["time"], optional = true } +tokio-stream = { version = "0.1.15", default-features = false, optional = true } diff --git a/crates/musli-axum/README.md b/crates/musli-axum/README.md new file mode 100644 index 000000000..e54f5c6a4 --- /dev/null +++ b/crates/musli-axum/README.md @@ -0,0 +1,11 @@ +# musli-axum + +[github](https://github.com/udoprog/musli) +[crates.io](https://crates.io/crates/musli-axum) +[docs.rs](https://docs.rs/musli-axum) +[build status](https://github.com/udoprog/musli/actions?query=branch%3Amain) + +This crate provides a set of utilities for working with [Axum] and [Müsli]. + +[Axum]: https://github.com/tokio-rs/axum +[Müsli]: https://github.com/udoprog/musli diff --git a/crates/musli-axum/src/api.rs b/crates/musli-axum/src/api.rs new file mode 100644 index 000000000..5dc8b0606 --- /dev/null +++ b/crates/musli-axum/src/api.rs @@ -0,0 +1,45 @@ +//! Shared traits for defining API types. + +use musli::mode::Binary; +use musli::{Decode, Encode}; + +/// A marker indicating a decodable type. +pub trait Marker: 'static { + /// The type that can be decoded. + type Type<'de>: Decode<'de, Binary>; +} + +/// Trait governing requests. +pub trait Request: Encode { + /// The kind of the request. + const KIND: &'static str; + + /// Type acting as a token for the response. + type Marker: Marker; +} + +/// A broadcast type marker. +pub trait Broadcast: Marker { + /// The kind of the broadcast being subscribed to. + const KIND: &'static str; +} + +#[derive(Debug, Clone, Copy, Encode, Decode)] +pub struct RequestHeader<'a> { + pub index: u32, + pub serial: u32, + /// The kind of the request. + pub kind: &'a str, +} + +#[derive(Debug, Clone, Encode, Decode)] +pub struct ResponseHeader<'de> { + pub index: u32, + pub serial: u32, + /// The response is a broadcast. + #[musli(default, skip_encoding_if = Option::is_none)] + pub broadcast: Option<&'de str>, + /// An error message in the response. + #[musli(default, skip_encoding_if = Option::is_none)] + pub error: Option<&'de str>, +} diff --git a/crates/musli-axum/src/json.rs b/crates/musli-axum/src/json.rs new file mode 100644 index 000000000..092ef9088 --- /dev/null +++ b/crates/musli-axum/src/json.rs @@ -0,0 +1,155 @@ +use alloc::boxed::Box; +use alloc::string::{String, ToString}; + +use axum::async_trait; +use axum::extract::rejection::BytesRejection; +use axum::extract::{FromRequest, Request}; +use axum::http::header::{self, HeaderValue}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::{IntoResponse, Response}; +use bytes::{BufMut, Bytes, BytesMut}; +use musli::de::DecodeOwned; +use musli::json::Encoding; +use musli::mode::Text; +use musli::Encode; + +const ENCODING: Encoding = Encoding::new(); + +/// A rejection from the JSON extractor. +pub enum JsonRejection { + ContentType, + Report(String), + BytesRejection(BytesRejection), +} + +impl From for JsonRejection { + #[inline] + fn from(rejection: BytesRejection) -> Self { + JsonRejection::BytesRejection(rejection) + } +} + +impl IntoResponse for JsonRejection { + fn into_response(self) -> Response { + let status; + let body; + + match self { + JsonRejection::ContentType => { + status = StatusCode::UNSUPPORTED_MEDIA_TYPE; + body = String::from("Expected request with `Content-Type: application/json`"); + } + JsonRejection::Report(report) => { + status = StatusCode::BAD_REQUEST; + body = report; + } + JsonRejection::BytesRejection(rejection) => { + return rejection.into_response(); + } + } + + ( + status, + [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), + )], + body, + ) + .into_response() + } +} + +/// Encode the given value as JSON. +pub struct Json(pub T); + +#[async_trait] +impl FromRequest for Json +where + T: DecodeOwned, + S: Send + Sync, +{ + type Rejection = JsonRejection; + + async fn from_request(req: Request, state: &S) -> Result { + if !json_content_type(req.headers()) { + return Err(JsonRejection::ContentType); + } + + let bytes = Bytes::from_request(req, state).await?; + Self::from_bytes(&bytes) + } +} + +fn json_content_type(headers: &HeaderMap) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { + content_type + } else { + return false; + }; + + let content_type = if let Ok(content_type) = content_type.to_str() { + content_type + } else { + return false; + }; + + let mime = if let Ok(mime) = content_type.parse::() { + mime + } else { + return false; + }; + + let is_json_content_type = mime.type_() == "application" + && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); + + is_json_content_type +} + +impl IntoResponse for Json +where + T: Encode, +{ + fn into_response(self) -> Response { + // Use a small initial capacity of 128 bytes like serde_json::to_vec + // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189 + let mut buf = BytesMut::with_capacity(128).writer(); + + match ENCODING.to_writer(&mut buf, &self.0) { + Ok(()) => ( + [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), + )], + buf.into_inner().freeze(), + ) + .into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), + )], + err.to_string(), + ) + .into_response(), + } + } +} + +impl Json +where + T: DecodeOwned, +{ + fn from_bytes(bytes: &[u8]) -> Result { + let cx = musli::context::SystemContext::new(); + + if let Ok(value) = ENCODING.from_slice_with(&cx, bytes) { + return Ok(Json(value)); + } + + let report = cx.report(); + let report = report.to_string(); + Err(JsonRejection::Report(report)) + } +} diff --git a/crates/musli-axum/src/lib.rs b/crates/musli-axum/src/lib.rs new file mode 100644 index 000000000..4ff5e141d --- /dev/null +++ b/crates/musli-axum/src/lib.rs @@ -0,0 +1,27 @@ +//! [github](https://github.com/udoprog/musli) +//! [crates.io](https://crates.io/crates/musli-axum) +//! [docs.rs](https://docs.rs/musli-axum) +//! +//! This crate provides a set of utilities for working with [Axum] and [Müsli]. +//! +//! [Axum]: https://github.com/tokio-rs/axum +//! [Müsli]: https://github.com/udoprog/musli + +#![no_std] + +#[cfg(feature = "std")] +extern crate std; + +#[cfg(feature = "alloc")] +extern crate alloc; + +#[cfg(all(feature = "json", feature = "alloc"))] +mod json; +#[cfg(all(feature = "json", feature = "alloc"))] +pub use self::json::Json; + +#[cfg(feature = "api")] +pub mod api; + +#[cfg(all(feature = "ws", feature = "api", feature = "alloc"))] +pub mod ws; diff --git a/crates/musli-axum/src/ws.rs b/crates/musli-axum/src/ws.rs new file mode 100644 index 000000000..69ede047e --- /dev/null +++ b/crates/musli-axum/src/ws.rs @@ -0,0 +1,374 @@ +use core::fmt::{self, Write}; + +use alloc::borrow::Cow; +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; + +use axum::extract::ws::{CloseFrame, Message, WebSocket}; +use musli::mode::Binary; +use musli::reader::SliceReader; +use musli::{Decode, Encode}; +use rand::prelude::*; +use rand::rngs::SmallRng; +use tokio::time::Duration; +use tokio_stream::StreamExt; + +use crate::api; + +const MAX_CAPACITY: usize = 1048576; + +enum OneOf { + Error(Error), + Handler(E), +} + +impl From for OneOf { + #[inline] + fn from(error: Error) -> Self { + Self::Error(error) + } +} + +impl From for OneOf { + #[inline] + fn from(error: musli::storage::Error) -> Self { + Self::Error(Error::from(error)) + } +} + +impl fmt::Display for OneOf +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OneOf::Error(error) => error.fmt(f), + OneOf::Handler(error) => error.fmt(f), + } + } +} + +#[derive(Debug)] +enum ErrorKind { + Axum { error: axum::Error }, + Musli { error: musli::storage::Error }, + UnknownRequest { kind: Box }, + FormatError, +} + +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, +} + +impl Error { + const fn new(kind: ErrorKind) -> Self { + Self { kind } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::Axum { .. } => write!(f, "Error in axum"), + ErrorKind::Musli { .. } => write!(f, "Error in musli"), + ErrorKind::UnknownRequest { kind } => { + write!(f, "Unknown request kind: {kind}") + } + ErrorKind::FormatError => write!(f, "Error formatting error response"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.kind { + ErrorKind::Axum { error } => Some(error), + ErrorKind::Musli { error } => Some(error), + _ => None, + } + } +} + +impl From for Error { + #[inline] + fn from(error: axum::Error) -> Self { + Self::new(ErrorKind::Axum { error }) + } +} + +impl From for Error { + #[inline] + fn from(error: musli::storage::Error) -> Self { + Self::new(ErrorKind::Musli { error }) + } +} + +type Result = core::result::Result; + +/// A handler for incoming requests. +#[axum::async_trait] +pub trait Handler: Send + Sync { + /// Error returned by handler. + type Error: 'static + Send + Sync; + + async fn handle( + &mut self, + incoming: &mut Incoming<'_>, + outgoing: &mut Outgoing<'_>, + kind: &str, + ) -> Result<(), Self::Error>; +} + +/// The server side of the websocket connection. +pub struct Server<'a, E> { + buf: Buf, + error: String, + socket: WebSocket, + handler: Box + 'a>, +} + +impl<'a, E> Server<'a, E> +where + E: 'static + Send + Sync + fmt::Display, +{ + /// Construct a new server with the specified handler. + pub fn new(socket: WebSocket, handler: H) -> Self + where + H: 'a + Handler, + { + Self { + buf: Buf::default(), + error: String::new(), + socket, + handler: Box::new(handler), + } + } + + /// Run the server. + pub async fn run(&mut self) -> Result<(), Error> { + tracing::trace!("Accepted"); + + const CLOSE_NORMAL: u16 = 1000; + const CLOSE_PROTOCOL_ERROR: u16 = 1002; + const CLOSE_TIMEOUT: Duration = Duration::from_secs(30); + const PING_TIMEOUT: Duration = Duration::from_secs(10); + + let mut last_ping = None::; + let mut rng = SmallRng::seed_from_u64(0x404241112); + let mut close_interval = tokio::time::interval(CLOSE_TIMEOUT); + close_interval.reset(); + + let mut ping_interval = tokio::time::interval(PING_TIMEOUT); + ping_interval.reset(); + + let close_here = loop { + tokio::select! { + _ = close_interval.tick() => { + break Some((CLOSE_NORMAL, "connection timed out")); + } + _ = ping_interval.tick() => { + let payload = rng.gen::(); + last_ping = Some(payload); + let data = payload.to_ne_bytes().into_iter().collect::>(); + tracing::trace!(data = ?&data[..], "Sending ping"); + self.socket.send(Message::Ping(data)).await?; + ping_interval.reset(); + } + message = self.socket.next() => { + let Some(message) = message else { + break None; + }; + + match message? { + Message::Text(_) => break Some((CLOSE_PROTOCOL_ERROR, "unsupported message")), + Message::Binary(bytes) => { + let mut reader = SliceReader::new(&bytes); + + let header = match musli::storage::decode(&mut reader) { + Ok(header) => header, + Err(error) => { + tracing::warn!(?error, "Failed to decode request header"); + break Some((CLOSE_PROTOCOL_ERROR, "invalid request")); + } + }; + + match self.handle_request(reader, header).await { + Ok(()) => { + self.flush().await?; + }, + Err(error) => { + if write!(self.error, "{error}").is_err() { + return Err(Error::new(ErrorKind::FormatError)); + } + + self.buf.buffer.clear(); + + self.buf.write(api::ResponseHeader { + index: header.index, + serial: header.serial, + broadcast: None, + error: Some(self.error.as_str()), + })?; + + self.flush().await?; + } + } + }, + Message::Ping(payload) => { + self.socket.send(Message::Pong(payload)).await?; + continue; + }, + Message::Pong(data) => { + tracing::trace!(data = ?&data[..], "Pong"); + + let Some(expected) = last_ping else { + continue; + }; + + if expected.to_ne_bytes()[..] != data[..] { + continue; + } + + close_interval.reset(); + ping_interval.reset(); + last_ping = None; + }, + Message::Close(_) => break None, + } + } + } + }; + + if let Some((code, reason)) = close_here { + tracing::trace!(code, reason, "Closing websocket with reason"); + + self.socket + .send(Message::Close(Some(CloseFrame { + code, + reason: Cow::Borrowed(reason), + }))) + .await?; + } else { + tracing::trace!("Closing websocket"); + }; + + Ok(()) + } + + async fn flush(&mut self) -> Result<()> { + self.socket + .send(Message::Binary(self.buf.buffer.to_vec())) + .await?; + self.error.clear(); + self.buf.buffer.clear(); + self.buf.buffer.shrink_to(MAX_CAPACITY); + Ok(()) + } + + async fn handle_request( + &mut self, + reader: SliceReader<'_>, + header: api::RequestHeader<'_>, + ) -> Result<(), OneOf> { + tracing::trace!("Got request: {header:?}"); + + self.buf.write(api::ResponseHeader { + index: header.index, + serial: header.serial, + broadcast: None, + error: None, + })?; + + let mut incoming = Incoming { + error: None, + reader, + }; + + let mut outgoing = Outgoing { + error: None, + written: false, + buf: &mut self.buf, + }; + + if let Err(error) = self + .handler + .handle(&mut incoming, &mut outgoing, header.kind) + .await + { + return Err(OneOf::Handler(error)); + } + + if let Some(error) = incoming.error.take() { + return Err(OneOf::Error(Error::new(ErrorKind::Musli { error }))); + } + + if !outgoing.written { + return Err(OneOf::Error(Error::new(ErrorKind::UnknownRequest { + kind: header.kind.into(), + }))); + } + + Ok(()) + } +} + +/// An incoming request. +pub struct Incoming<'de> { + error: Option, + reader: SliceReader<'de>, +} + +impl<'de> Incoming<'de> { + /// Read a request. + pub fn read(&mut self) -> Option + where + T: Decode<'de, Binary>, + { + match musli::storage::decode(&mut self.reader) { + Ok(value) => Some(value), + Err(error) => { + self.error = Some(error); + None + } + } + } +} + +/// Handler for an outgoing buffer. +pub struct Outgoing<'a> { + error: Option, + written: bool, + buf: &'a mut Buf, +} + +impl Outgoing<'_> { + /// Write a response. + pub fn write(&mut self, value: T) + where + T: Encode, + { + if let Err(error) = self.buf.write(value) { + self.error = Some(error); + } else { + self.written = true; + } + } +} + +#[derive(Default)] +pub struct Buf { + buffer: Vec, +} + +impl Buf { + fn write(&mut self, value: T) -> Result<(), musli::storage::Error> + where + T: Encode, + { + musli::storage::to_writer(&mut self.buffer, &value)?; + Ok(()) + } +} diff --git a/crates/musli-yew/Cargo.toml b/crates/musli-yew/Cargo.toml new file mode 100644 index 000000000..eacdbe494 --- /dev/null +++ b/crates/musli-yew/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "musli-yew" +version = "0.0.122" +authors = ["John-John Tedro "] +edition = "2021" +description = """ +Types for integrating Müsli with yew. +""" +documentation = "https://docs.rs/musli" +readme = "README.md" +homepage = "https://github.com/udoprog/musli" +repository = "https://github.com/udoprog/musli" +license = "MIT OR Apache-2.0" +keywords = ["javascript", "web", "webasm"] +categories = ["gui", "wasm", "web-programming"] + +[dependencies] +musli-axum = { version = "0.0.122", path = "../musli-axum", default-features = false, features = ["api"] } + +log = "0.4.21" +slab = "0.4.9" +wasm-bindgen = "0.2.92" +web-sys = { version = "0.3.69", features = ["WebSocket", "MessageEvent", "Performance"] } +yew = "0.21.0" +musli = { version = "0.0.122", path = "../musli", features = ["storage"] } +gloo = { version = "0.11.0", features = ["timers"], default-features = false } diff --git a/crates/musli-yew/README.md b/crates/musli-yew/README.md new file mode 100644 index 000000000..9d1470d60 --- /dev/null +++ b/crates/musli-yew/README.md @@ -0,0 +1,11 @@ +# musli-yew + +[github](https://github.com/udoprog/musli) +[crates.io](https://crates.io/crates/musli-yew) +[docs.rs](https://docs.rs/musli-yew) +[build status](https://github.com/udoprog/musli/actions?query=branch%3Amain) + +This crate provides a set of utilities for working with [Yew] and [Müsli]. + +[Yew]: https://yew.rs +[Müsli]: https://github.com/udoprog/musli diff --git a/crates/musli-yew/src/lib.rs b/crates/musli-yew/src/lib.rs new file mode 100644 index 000000000..da9667545 --- /dev/null +++ b/crates/musli-yew/src/lib.rs @@ -0,0 +1,10 @@ +//! [github](https://github.com/udoprog/musli) +//! [crates.io](https://crates.io/crates/musli-yew) +//! [docs.rs](https://docs.rs/musli-yew) +//! +//! This crate provides a set of utilities for working with [Yew] and [Müsli]. +//! +//! [Yew]: https://yew.rs +//! [Müsli]: https://github.com/udoprog/musli + +pub mod ws; diff --git a/crates/musli-yew/src/ws.rs b/crates/musli-yew/src/ws.rs new file mode 100644 index 000000000..91856347c --- /dev/null +++ b/crates/musli-yew/src/ws.rs @@ -0,0 +1,727 @@ +use core::fmt; +use std::cell::{Cell, RefCell}; +use std::collections::{hash_map, HashMap}; +use std::marker::PhantomData; +use std::mem::take; +use std::rc::Rc; + +use gloo::timers::callback::Timeout; +use musli_axum::api; +use slab::Slab; +use wasm_bindgen::closure::Closure; +use wasm_bindgen::{JsCast, JsValue}; +use web_sys::js_sys::{ArrayBuffer, Uint8Array}; +use web_sys::{window, BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket}; +use yew::html::ImplicitClone; +use yew::{Component, Context}; + +const MAX_CAPACITY: usize = 1048576; + +/// The state of the connection. +/// +/// A listener for state changes can be set up through +/// [`Handle::state_changes`]. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[non_exhaustive] +pub enum State { + /// The connection is open. + Open, + /// The connection is closed. + Closed, +} + +/// Error type for the WebSocket service. +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, +} + +#[derive(Debug)] +enum ErrorKind { + Message(String), + Storage(musli::storage::Error), + Overflow(usize, usize), +} + +impl Error { + #[inline] + fn new(kind: ErrorKind) -> Self { + Self { kind } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::Message(message) => write!(f, "{message}"), + ErrorKind::Storage(error) => write!(f, "Encoding error: {error}"), + ErrorKind::Overflow(at, len) => { + write!(f, "Internal packet overflow, {at} not in range 0-{len}") + } + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.kind { + ErrorKind::Storage(error) => Some(error), + _ => None, + } + } +} + +impl From for Error { + fn from(error: musli::storage::Error) -> Self { + Self::new(ErrorKind::Storage(error)) + } +} + +impl From for Error { + fn from(error: JsValue) -> Self { + Self::new(ErrorKind::Message(format!("{error:?}"))) + } +} + +impl From<&str> for Error { + fn from(error: &str) -> Self { + Self::new(ErrorKind::Message(error.to_string())) + } +} + +type Result = core::result::Result; + +const INITIAL_TIMEOUT: u32 = 250; +const MAX_TIMEOUT: u32 = 16000; + +struct ClientRequest<'a> { + header: api::RequestHeader<'a>, + body: Vec, +} + +enum MsgKind { + Reconnect, + Open, + Close(CloseEvent), + Message(MessageEvent), + Error(ErrorEvent), + ClientRequest(ClientRequest<'static>), +} + +/// A message passed into the WebSocket service. +pub struct Msg { + kind: MsgKind, +} + +impl Msg { + #[inline] + const fn new(kind: MsgKind) -> Self { + Self { kind } + } +} + +/// The WebSocket service. +pub struct Service { + shared: Rc, + socket: Option, + opened: Option, + state: State, + buffer: Vec>, + output: Vec, + timeout: u32, + on_open: Closure, + on_close: Closure, + on_message: Closure, + on_error: Closure, + _timeout: Option, + _ping_timeout: Option, + _marker: PhantomData, +} + +impl Service +where + C: Component, + C::Message: From + From, +{ + /// Construct a new websocket service, and return it and an associated + /// handle to it. + pub fn new(ctx: &Context) -> (Self, Handle) { + let link = ctx.link().clone(); + + let shared = Rc::new(Shared { + serial: Cell::new(0), + onmessage: Box::new(move |request| { + link.send_message(Msg::new(MsgKind::ClientRequest(request))) + }), + requests: RefCell::new(Slab::new()), + broadcasts: RefCell::new(HashMap::new()), + state_changes: RefCell::new(Slab::new()), + }); + + let on_open = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move || { + link.send_message(Msg::new(MsgKind::Open)); + }); + + Closure::wrap(cb) + }; + + let on_close = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move |e: CloseEvent| { + link.send_message(Msg::new(MsgKind::Close(e))); + }); + + Closure::wrap(cb) + }; + + let on_message = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move |e: MessageEvent| { + link.send_message(Msg::new(MsgKind::Message(e))); + }); + + Closure::wrap(cb) + }; + + let on_error = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move |e: ErrorEvent| { + link.send_message(Msg::new(MsgKind::Error(e))); + }); + + Closure::wrap(cb) + }; + + let this = Self { + shared: shared.clone(), + socket: None, + opened: None, + state: State::Closed, + buffer: Vec::new(), + output: Vec::new(), + timeout: INITIAL_TIMEOUT, + on_open, + on_close, + on_message, + on_error, + _timeout: None, + _ping_timeout: None, + _marker: PhantomData, + }; + + let handle = Handle { shared }; + + (this, handle) + } + + /// Send a client message. + fn send_client_request(&mut self, request: ClientRequest<'_>) -> Result<()> { + let Some(socket) = &self.socket else { + return Err("Socket is not connected".into()); + }; + + self.output.clear(); + musli::storage::to_writer(&mut self.output, &request.header)?; + self.output.extend_from_slice(request.body.as_slice()); + socket.send_with_u8_array(&self.output)?; + self.output.shrink_to(MAX_CAPACITY); + Ok(()) + } + + fn message(&mut self, e: MessageEvent) -> Result<()> { + let Ok(array_buffer) = e.data().dyn_into::() else { + return Err("Expected message as ArrayBuffer".into()); + }; + + let body = Rc::from(Uint8Array::new(&array_buffer).to_vec()); + let mut reader = musli::reader::SliceReader::new(&body); + + let header: api::ResponseHeader<'_> = musli::storage::decode(&mut reader)?; + + match header.broadcast { + Some(kind) => { + let broadcasts = self.shared.broadcasts.borrow(); + let at = body.len() - reader.remaining(); + + if let Some(broadcasts) = broadcasts.get(kind) { + let mut it = broadcasts.iter(); + + let last = it.next_back(); + let raw = RawPacket { + body: body.clone(), + at, + }; + + for (_, callback) in it { + (callback)(raw.clone()); + } + + if let Some((_, callback)) = last { + (callback)(raw); + } + } + } + None => { + log::trace!( + "Got response: index={}, serial={}", + header.index, + header.serial + ); + + let requests = self.shared.requests.borrow(); + + let Some(pending) = requests.get(header.index as usize) else { + return Err("Header index out of bound".into()); + }; + + if pending.serial == header.serial { + if let Some(error) = header.error { + (pending.callback)(Err(Error::from(error))); + } else { + let at = body.len() - reader.remaining(); + let raw = RawPacket { body, at }; + (pending.callback)(Ok(raw)); + } + } + } + } + + Ok(()) + } + + fn set_open(&mut self) { + log::trace!("Set open"); + self.opened = Some(Opened { at: now() }); + self.emit_state_change(State::Open); + } + + fn is_open_for_a_while(&self) -> bool { + let Some(opened) = self.opened else { + return false; + }; + + let Some(at) = opened.at else { + return false; + }; + + let Some(now) = now() else { + return false; + }; + + (now - at) >= 250.0 + } + + fn set_closed(&mut self, ctx: &Context) { + log::trace!( + "Set closed timeout={}, opened={:?}", + self.timeout, + self.opened + ); + + if !self.is_open_for_a_while() { + if self.timeout < MAX_TIMEOUT { + self.timeout *= 2; + } + } else { + self.timeout = INITIAL_TIMEOUT; + } + + self.opened = None; + self.reconnect(ctx); + self.emit_state_change(State::Closed); + } + + fn emit_state_change(&mut self, state: State) { + if self.state != state { + let callbacks = self.shared.state_changes.borrow(); + + for (_, callback) in callbacks.iter() { + callback(state); + } + + self.state = state; + } + } + + /// Handle an update message. + pub fn update(&mut self, ctx: &Context, message: Msg) { + match message.kind { + MsgKind::Reconnect => { + log::trace!("Reconnect"); + + if let Err(error) = self.inner_connect() { + ctx.link().send_message(error); + self.inner_reconnect(ctx); + } + } + MsgKind::Open => { + log::trace!("Open"); + self.set_open(); + + let buffer = take(&mut self.buffer); + + for request in buffer { + if let Err(error) = self.send_client_request(request) { + ctx.link().send_message(error); + } + } + } + MsgKind::Close(e) => { + log::trace!("Close: {} ({})", e.code(), e.reason()); + self.set_closed(ctx); + } + MsgKind::Message(e) => { + if let Err(error) = self.message(e) { + ctx.link().send_message(error); + } + } + MsgKind::Error(e) => { + log::error!("{}", e.message()); + self.set_closed(ctx); + } + MsgKind::ClientRequest(request) => { + if self.opened.is_none() { + self.buffer.push(request); + return; + } + + if let Err(error) = self.send_client_request(request) { + ctx.link().send_message(error); + } + } + } + } + + pub(crate) fn reconnect(&mut self, ctx: &Context) { + if let Some(old) = self.socket.take() { + if let Err(error) = old.close() { + ctx.link().send_message(Error::from(error)); + } + } + + let link = ctx.link().clone(); + + self._timeout = Some(Timeout::new(self.timeout, move || { + link.send_message(Msg::new(MsgKind::Reconnect)); + })); + } + + /// Attempt to establish a websocket connection. + pub fn connect(&mut self, ctx: &Context) { + if let Err(error) = self.inner_connect() { + ctx.link().send_message(error); + self.inner_reconnect(ctx); + } + } + + fn inner_connect(&mut self) -> Result<()> { + let window = window().ok_or("No window")?; + let port = window.location().port()?; + let url = format!("ws://127.0.0.1:{port}/ws"); + let ws = WebSocket::new(&url)?; + + ws.set_binary_type(BinaryType::Arraybuffer); + ws.set_onopen(Some(self.on_open.as_ref().unchecked_ref())); + ws.set_onclose(Some(self.on_close.as_ref().unchecked_ref())); + ws.set_onmessage(Some(self.on_message.as_ref().unchecked_ref())); + ws.set_onerror(Some(self.on_error.as_ref().unchecked_ref())); + + if let Some(old) = self.socket.replace(ws) { + old.close()?; + } + + Ok(()) + } + + fn inner_reconnect(&mut self, ctx: &Context) { + let link = ctx.link().clone(); + + self._timeout = Some(Timeout::new(1000, move || { + link.send_message(Msg::new(MsgKind::Reconnect)); + })); + } +} + +/// The handle for a pending request. Dropping this handle cancels the request. +pub struct Request { + inner: Option<(Rc, u32)>, + _marker: PhantomData, +} + +impl Request { + /// An empty request handler. + #[inline] + pub fn empty() -> Self { + Self::default() + } +} + +impl Default for Request { + #[inline] + fn default() -> Self { + Self { + inner: None, + _marker: PhantomData, + } + } +} + +impl Drop for Request { + #[inline] + fn drop(&mut self) { + if let Some((shared, index)) = self.inner.take() { + shared.requests.borrow_mut().try_remove(index as usize); + } + } +} + +/// The handle for a pending request. Dropping this handle cancels the request. +pub struct Listener { + kind: &'static str, + index: usize, + shared: Rc, + _marker: PhantomData, +} + +impl Drop for Listener { + #[inline] + fn drop(&mut self) { + let mut broadcast = self.shared.broadcasts.borrow_mut(); + + if let hash_map::Entry::Occupied(mut e) = broadcast.entry(self.kind) { + e.get_mut().try_remove(self.index); + + if e.get().is_empty() { + e.remove(); + } + } + } +} + +/// The handle for state change listening. Dropping this handle cancels the request. +pub struct StateListener { + index: usize, + shared: Rc, +} + +impl Drop for StateListener { + #[inline] + fn drop(&mut self) { + self.shared + .state_changes + .borrow_mut() + .try_remove(self.index); + } +} + +#[derive(Clone)] +struct RawPacket { + body: Rc<[u8]>, + at: usize, +} + +/// A packet of data. +pub struct Packet +where + T: api::Marker, +{ + raw: RawPacket, + _marker: PhantomData, +} + +impl Packet +where + T: api::Marker, +{ + /// Handle a broadcast packet. + pub fn decode(&self, ctx: &Context, f: F) + where + F: FnOnce(T::Type<'_>), + C: Component, + C::Message: From, + { + let Some(bytes) = self.raw.body.get(self.raw.at..) else { + ctx.link() + .send_message(C::Message::from(Error::new(ErrorKind::Overflow( + self.raw.at, + self.raw.body.len(), + )))); + return; + }; + + match musli::storage::from_slice(bytes) { + Ok(value) => { + f(value); + } + Err(error) => { + ctx.link() + .send_message(C::Message::from(Error::from(error))); + } + } + } +} + +/// A handle to the WebSocket service. +#[derive(Clone)] +pub struct Handle { + shared: Rc, +} + +impl Handle { + /// Send a request of type `T`. + /// + /// Returns a handle for the request. + /// + /// If the handle is dropped, the request is cancelled. + pub fn request(&self, ctx: &Context, request: T) -> Request + where + C: Component, + C::Message: From> + From, + T: api::Request, + { + let body = match musli::storage::to_vec(&request) { + Ok(body) => body, + Err(error) => { + ctx.link() + .send_message(C::Message::from(Error::from(error))); + return Request::default(); + } + }; + + let mut requests = self.shared.requests.borrow_mut(); + let serial = self.shared.serial.get(); + self.shared.serial.set(serial.wrapping_add(1)); + + let link = ctx.link().clone(); + + let pending = Pending { + serial, + callback: Box::new(move |result| { + let raw = match result { + Ok(raw) => raw, + Err(error) => { + link.send_message(C::Message::from(error)); + return; + } + }; + + link.send_message(C::Message::from(Packet { + raw, + _marker: PhantomData, + })); + }), + }; + + let index = requests.insert(pending) as u32; + + (self.shared.onmessage)(ClientRequest { + header: api::RequestHeader { + index, + serial, + kind: T::KIND, + }, + body, + }); + + Request { + inner: Some((self.shared.clone(), index)), + _marker: PhantomData, + } + } + + /// List for broadcasts of type `T`. + /// + /// Returns a handle for the broadcasts. + /// + /// If the handle is dropped, the listener is cancelled. + pub fn listen(&self, ctx: &Context) -> Listener + where + C: Component, + C::Message: From> + From, + T: api::Broadcast, + { + let mut broadcasts = self.shared.broadcasts.borrow_mut(); + + let slots = broadcasts.entry(T::KIND).or_default(); + let link = ctx.link().clone(); + + let index = slots.insert(Box::new(move |raw| { + link.send_message(C::Message::from(Packet { + raw, + _marker: PhantomData, + })); + })); + + Listener { + kind: T::KIND, + index, + shared: self.shared.clone(), + _marker: PhantomData, + } + } + + /// Listen for state changes to the underlying connection. + pub fn state_changes(&self, ctx: &Context) -> StateListener + where + C: Component, + C::Message: From, + { + let link = ctx.link().clone(); + let mut state = self.shared.state_changes.borrow_mut(); + + let index = state.insert(Box::new(move |state| { + link.send_message(C::Message::from(state)) + })); + + StateListener { + index, + shared: self.shared.clone(), + } + } +} + +impl ImplicitClone for Handle { + #[inline] + fn implicit_clone(&self) -> Self { + self.clone() + } +} + +impl PartialEq for Handle { + #[inline] + fn eq(&self, _: &Self) -> bool { + true + } +} + +fn now() -> Option { + Some(window()?.performance()?.now()) +} + +struct Pending { + serial: u32, + callback: Box)>, +} + +type Broadcasts = HashMap<&'static str, Slab>>; +type OnMessageCallback = dyn Fn(ClientRequest<'static>); +type StateCallback = dyn Fn(State); + +struct Shared { + serial: Cell, + onmessage: Box, + requests: RefCell>, + broadcasts: RefCell, + state_changes: RefCell>>, +} + +#[derive(Debug, Clone, Copy)] +struct Opened { + at: Option, +} diff --git a/crates/musli/Cargo.toml b/crates/musli/Cargo.toml index df109cb36..63c9e4fdd 100644 --- a/crates/musli/Cargo.toml +++ b/crates/musli/Cargo.toml @@ -12,7 +12,7 @@ readme = "README.md" homepage = "https://github.com/udoprog/musli" repository = "https://github.com/udoprog/musli" license = "MIT OR Apache-2.0" -keywords = ["no_std", "serialization"] +keywords = ["binary", "json", "no_std", "serialization"] categories = ["encoding", "no-std", "no-std::no-alloc"] [package.metadata.docs.rs]